auto_sigma_rule_generator/backend/llm_client.py

652 lines
No EOL
26 KiB
Python

"""
LangChain-based LLM client for enhanced SIGMA rule generation.
Supports multiple LLM providers: OpenAI, Anthropic, and local models.
"""
import os
import logging
from typing import Optional, Dict, Any, List
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_community.llms import Ollama
from langchain_core.output_parsers import StrOutputParser
import yaml
logger = logging.getLogger(__name__)
class LLMClient:
    """Multi-provider LLM client for SIGMA rule generation using LangChain."""
    # Registry of the providers this client knows about. Each entry holds:
    #   'models'        - model names offered for the provider
    #   'env_key'       - environment variable that configures the provider
    #   'default_model' - model used when the caller does not choose one
    # NOTE: for 'ollama' the env_key names the server base URL rather than an
    # API key (it is read as a URL when the Ollama client is initialized).
    SUPPORTED_PROVIDERS = {
        'openai': {
            'models': ['gpt-4o', 'gpt-4o-mini', 'gpt-4-turbo', 'gpt-3.5-turbo'],
            'env_key': 'OPENAI_API_KEY',
            'default_model': 'gpt-4o-mini'
        },
        'anthropic': {
            'models': ['claude-3-5-sonnet-20241022', 'claude-3-haiku-20240307', 'claude-3-opus-20240229'],
            'env_key': 'ANTHROPIC_API_KEY',
            'default_model': 'claude-3-5-sonnet-20241022'
        },
        'ollama': {
            'models': ['llama3.2', 'codellama', 'mistral', 'llama2'],
            'env_key': 'OLLAMA_BASE_URL',
            'default_model': 'llama3.2'
        }
    }
def __init__(self, provider: str = None, model: str = None):
    """Set up the client, auto-detecting whatever the caller left out.

    Args:
        provider: Provider key; when falsy it is detected from the
            environment via ``_detect_provider``.
        model: Model name; when falsy the provider's default model is used.
    """
    resolved_provider = provider if provider else self._detect_provider()
    self.provider = resolved_provider
    resolved_model = model if model else self._get_default_model(resolved_provider)
    self.model = resolved_model
    # No backend yet; _initialize_llm() populates it when configuration allows.
    self.llm = None
    self.output_parser = StrOutputParser()
    self._initialize_llm()
def _detect_provider(self) -> str:
"""Auto-detect available LLM provider based on environment variables."""
# Check for API keys in order of preference
if os.getenv('ANTHROPIC_API_KEY'):
return 'anthropic'
elif os.getenv('OPENAI_API_KEY'):
return 'openai'
elif os.getenv('OLLAMA_BASE_URL'):
return 'ollama'
else:
# Default to OpenAI if no keys found
return 'openai'
def _get_default_model(self, provider: str) -> str:
"""Get default model for the specified provider."""
return self.SUPPORTED_PROVIDERS.get(provider, {}).get('default_model', 'gpt-4o-mini')
def _initialize_llm(self):
"""Initialize the LLM based on provider and model."""
try:
if self.provider == 'openai':
api_key = os.getenv('OPENAI_API_KEY')
if not api_key:
logger.warning("OpenAI API key not found")
return
self.llm = ChatOpenAI(
model=self.model,
api_key=api_key,
temperature=0.1,
max_tokens=2000
)
elif self.provider == 'anthropic':
api_key = os.getenv('ANTHROPIC_API_KEY')
if not api_key:
logger.warning("Anthropic API key not found")
return
self.llm = ChatAnthropic(
model=self.model,
api_key=api_key,
temperature=0.1,
max_tokens=2000
)
elif self.provider == 'ollama':
base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
# Check if model is available, if not try to pull it
if not self._check_ollama_model_available(base_url, self.model):
logger.info(f"Model {self.model} not found, attempting to pull...")
if self._pull_ollama_model(base_url, self.model):
logger.info(f"Successfully pulled model {self.model}")
else:
logger.error(f"Failed to pull model {self.model}")
return
self.llm = Ollama(
model=self.model,
base_url=base_url,
temperature=0.1
)
if self.llm:
logger.info(f"LLM client initialized: {self.provider} with model {self.model}")
else:
logger.error(f"Failed to initialize LLM client for provider: {self.provider}")
except Exception as e:
logger.error(f"Error initializing LLM client: {e}")
self.llm = None
def is_available(self) -> bool:
"""Check if LLM client is available and configured."""
return self.llm is not None
def get_provider_info(self) -> Dict[str, Any]:
"""Get information about the current provider and configuration."""
provider_info = self.SUPPORTED_PROVIDERS.get(self.provider, {})
# For Ollama, get actually available models
available_models = provider_info.get('models', [])
if self.provider == 'ollama':
ollama_models = self._get_ollama_available_models()
if ollama_models:
available_models = ollama_models
return {
'provider': self.provider,
'model': self.model,
'available': self.is_available(),
'supported_models': provider_info.get('models', []),
'available_models': available_models,
'env_key': provider_info.get('env_key', ''),
'api_key_configured': bool(os.getenv(provider_info.get('env_key', '')))
}
async def generate_sigma_rule(self,
cve_id: str,
poc_content: str,
cve_description: str,
existing_rule: Optional[str] = None) -> Optional[str]:
"""
Generate or enhance a SIGMA rule using the configured LLM.
Args:
cve_id: CVE identifier
poc_content: Proof-of-concept code content from GitHub
cve_description: CVE description from NVD
existing_rule: Optional existing SIGMA rule to enhance
Returns:
Generated SIGMA rule YAML content or None if failed
"""
if not self.is_available():
logger.warning("LLM client not available")
return None
try:
# Create the prompt template
prompt = self._build_sigma_generation_prompt(
cve_id, poc_content, cve_description, existing_rule
)
# Create the chain
chain = prompt | self.llm | self.output_parser
# Debug: Log what we're sending to the LLM
input_data = {
"cve_id": cve_id,
"poc_content": poc_content[:4000], # Truncate if too long
"cve_description": cve_description,
"existing_rule": existing_rule or "None"
}
logger.info(f"Sending to LLM for {cve_id}: CVE={cve_id}, Description length={len(cve_description)}, PoC length={len(poc_content)}")
# Generate the response
response = await chain.ainvoke(input_data)
# Debug: Log raw LLM response
logger.info(f"Raw LLM response for {cve_id}: {response[:200]}...")
# Extract the SIGMA rule from response
sigma_rule = self._extract_sigma_rule(response)
# Post-process to ensure clean YAML
sigma_rule = self._post_process_sigma_rule(sigma_rule)
# Debug: Log final processed rule
logger.info(f"Final processed rule for {cve_id}: {sigma_rule[:200]}...")
logger.info(f"Successfully generated SIGMA rule for {cve_id} using {self.provider}")
return sigma_rule
except Exception as e:
logger.error(f"Failed to generate SIGMA rule for {cve_id} using {self.provider}: {e}")
return None
def _build_sigma_generation_prompt(self,
                                   cve_id: str,
                                   poc_content: str,
                                   cve_description: str,
                                   existing_rule: Optional[str] = None) -> ChatPromptTemplate:
    """Build the prompt template for SIGMA rule generation.

    Only ``existing_rule`` influences which user template is chosen; the
    other arguments are accepted for interface compatibility and are
    substituted later, when the returned template is invoked with the
    ``cve_id``, ``cve_description``, ``poc_content`` (and, for the
    enhancement variant, ``existing_rule``) variables.

    Returns:
        A ChatPromptTemplate made of a literal system message and a
        templated human message.
    """
    # Local import keeps this block self-contained.
    from datetime import date

    # Inject today's date instead of a fixed one so generated rules are not
    # all stamped with the same stale day. A plain replace() is used because
    # the text below intentionally contains literal braces.
    current_date = date.today().strftime('%Y/%m/%d')
    system_message = """You are a cybersecurity expert specializing in SIGMA rule creation following the official SIGMA specification.
**CRITICAL: You must follow the exact SIGMA specification format:**
1. **YAML Structure Requirements:**
- Use UTF-8 encoding with LF line breaks
- Indent with 4 spaces (no tabs)
- Use lowercase keys only
- Use single quotes for string values
- No quotes for numeric values
- Follow proper YAML syntax
2. **MANDATORY Fields (must include):**
- title: Brief description (max 256 chars)
- logsource: Log data source specification
- detection: Search identifiers and conditions
- condition: How detection elements combine
3. **RECOMMENDED Fields:**
- id: Unique UUID
- status: 'experimental' (for new rules)
- description: Detailed explanation
- author: 'AI Generated'
- date: Current date (YYYY/MM/DD)
- references: Array with CVE link
- tags: MITRE ATT&CK techniques
4. **Detection Structure:**
- Use selection blocks (selection, selection1, etc.)
- Condition references these selections
- Use proper field names (Image, CommandLine, ProcessName, etc.)
- Support wildcards (*) and value lists
**ABSOLUTE REQUIREMENTS:**
- Output ONLY valid YAML
- NO explanatory text before or after
- NO comments or instructions
- NO markdown formatting or code blocks
- NEVER repeat the input prompt or template
- NEVER include variables like {cve_id} or {poc_content}
- NO "Human:", "CVE ID:", "Description:" headers
- NO "Analyze this" or "Output only" text
- Start IMMEDIATELY with 'title:'
- End with the last YAML line only
- Ensure perfect YAML syntax
**STRUCTURE REQUIREMENTS:**
- title: Descriptive title that MUST include the exact CVE ID provided by the user
- id: Generate a unique UUID (not '12345678-1234-1234-1234-123456789012')
- status: experimental
- description: Specific description based on CVE and PoC analysis
- author: 'AI Generated'
- date: Current date (__CURRENT_DATE__)
- references: Include the EXACT CVE URL with the CVE ID provided by the user
- tags: Relevant MITRE ATT&CK techniques based on PoC analysis
- logsource: Appropriate category based on exploit type
- detection: Specific indicators from PoC analysis (NOT generic examples)
- condition: Logic connecting the detection selections
**CRITICAL RULES:**
1. You MUST use the EXACT CVE ID provided in the user input - NEVER generate a different CVE ID
2. Analyze the provided CVE and PoC content to create SPECIFIC detection patterns
3. DO NOT hallucinate or invent CVE IDs from your training data
4. Use the CVE ID exactly as provided in the title and references""".replace('__CURRENT_DATE__', current_date)
    if existing_rule:
        user_template = """CVE ID: {cve_id}
CVE Description: {cve_description}
PoC Code:
{poc_content}
Existing SIGMA Rule:
{existing_rule}
Enhance this rule with PoC insights. Output only valid SIGMA YAML starting with 'title:'."""
    else:
        user_template = """CREATE A SPECIFIC SIGMA RULE FOR THIS EXACT CVE:
**MANDATORY CVE ID TO USE: {cve_id}**
**CVE Description: {cve_description}**
**Proof-of-Concept Code Analysis:**
{poc_content}
**CRITICAL REQUIREMENTS:**
1. Use EXACTLY this CVE ID in the title: {cve_id}
2. Use EXACTLY this CVE URL in references: https://nvd.nist.gov/vuln/detail/{cve_id}
3. Analyze the CVE description to understand the vulnerability type
4. Extract specific indicators from the PoC code (files, processes, commands, network patterns)
5. Create detection logic based on the actual exploit behavior
6. Use relevant logsource category (process_creation, file_event, network_connection, etc.)
7. Include appropriate MITRE ATT&CK tags based on the exploit techniques
**IMPORTANT: You MUST use the exact CVE ID "{cve_id}" - do NOT generate a different CVE ID!**
Output ONLY valid SIGMA YAML starting with 'title:' that includes the exact CVE ID {cve_id}."""
    return ChatPromptTemplate.from_messages([
        # Kept as a message instance on purpose: instances are passed
        # through verbatim, which preserves the literal "{cve_id}" /
        # "{poc_content}" examples in the system text.
        SystemMessage(content=system_message),
        # A (role, template) tuple is what makes LangChain substitute the
        # {placeholders}; a HumanMessage instance would be sent to the model
        # unformatted, placeholders included.
        ("human", user_template),
    ])
def _extract_sigma_rule(self, response_text: str) -> str:
"""Extract and clean SIGMA rule YAML from LLM response."""
lines = response_text.split('\n')
yaml_lines = []
in_yaml_block = False
found_title = False
for line in lines:
stripped = line.strip()
# Skip code block markers
if stripped.startswith('```'):
if stripped.startswith('```yaml'):
in_yaml_block = True
elif stripped == '```' and in_yaml_block:
break
continue
# Skip obvious non-YAML content
if not in_yaml_block and not found_title:
if not stripped.startswith('title:'):
# Skip explanatory text and prompt artifacts
skip_phrases = [
'please note', 'this rule', 'you should', 'analysis:',
'explanation:', 'based on', 'the following', 'here is',
'note that', 'important:', 'remember', 'this is a',
'make sure to', 'you can modify', 'adjust the',
'human:', 'cve id:', 'cve description:', 'poc code:',
'exploit code:', 'analyze this', 'create a', 'output only'
]
if any(phrase in stripped.lower() for phrase in skip_phrases):
continue
# Skip template variables and prompt artifacts
if '{' in stripped and '}' in stripped:
continue
# Skip lines that are clearly not YAML structure
if stripped and not ':' in stripped and len(stripped) > 20:
continue
# Start collecting when we find title or are in YAML block
if stripped.startswith('title:') or in_yaml_block:
found_title = True
in_yaml_block = True
# Skip explanatory comments
if stripped.startswith('#') and ('please' in stripped.lower() or 'note' in stripped.lower()):
continue
yaml_lines.append(line)
# Stop if we encounter obvious non-YAML after starting
elif found_title:
# Stop at explanatory text after the rule
stop_phrases = [
'please note', 'this rule should', 'make sure to',
'you can modify', 'adjust the', 'also, this is',
'based on the analysis', 'the rule above'
]
if any(phrase in stripped.lower() for phrase in stop_phrases):
break
# Stop at lines without colons that aren't indented (likely explanations)
if stripped and not stripped.startswith(' ') and ':' not in stripped and '-' not in stripped:
break
if not yaml_lines:
# Fallback: look for any line with YAML-like structure
for line in lines:
if ':' in line and not line.strip().startswith('#'):
yaml_lines.append(line)
return '\n'.join(yaml_lines).strip()
def _post_process_sigma_rule(self, rule_content: str) -> str:
"""Post-process SIGMA rule to ensure clean YAML format."""
lines = rule_content.split('\n')
cleaned_lines = []
for line in lines:
stripped = line.strip()
# Skip obvious non-YAML content and prompt artifacts
if any(phrase in stripped.lower() for phrase in [
'please note', 'you should replace', 'this is a proof-of-concept',
'please make sure', 'note that', 'important:', 'remember to',
'analysis shows', 'based on the', 'the rule above', 'this rule',
'human:', 'cve id:', 'cve description:', 'poc code:', 'exploit code:',
'analyze this', 'create a', 'output only', 'generate a'
]):
continue
# Skip template variables
if '{' in stripped and '}' in stripped:
continue
# Skip lines that look like explanations
if stripped and not ':' in stripped and not stripped.startswith('-') and not stripped.startswith(' '):
# This might be explanatory text, skip it
if any(word in stripped.lower() for word in ['rule', 'detect', 'should', 'will', 'can', 'may']):
continue
# Skip empty explanatory sections
if stripped.lower() in ['explanation:', 'analysis:', 'notes:', 'important:', '']:
continue
cleaned_lines.append(line)
return '\n'.join(cleaned_lines).strip()
async def enhance_existing_rule(self,
existing_rule: str,
poc_content: str,
cve_id: str) -> Optional[str]:
"""
Enhance an existing SIGMA rule with PoC analysis.
Args:
existing_rule: Existing SIGMA rule YAML
poc_content: PoC code content
cve_id: CVE identifier
Returns:
Enhanced SIGMA rule or None if failed
"""
if not self.is_available():
return None
try:
system_message = """You are a SIGMA rule enhancement expert. Analyze the following PoC code and enhance the existing SIGMA rule with more specific detection patterns.
**Task:** Enhance the existing rule by:
1. Adding more specific detection patterns found in the PoC
2. Improving the condition logic
3. Adding relevant tags or MITRE ATT&CK mappings
4. Keeping the rule structure intact but making it more effective
Output ONLY the enhanced SIGMA rule in valid YAML format."""
user_template = """**CVE ID:** {cve_id}
**PoC Code:**
```
{poc_content}
```
**Existing SIGMA Rule:**
```yaml
{existing_rule}
```"""
prompt = ChatPromptTemplate.from_messages([
SystemMessage(content=system_message),
HumanMessage(content=user_template)
])
chain = prompt | self.llm | self.output_parser
response = await chain.ainvoke({
"cve_id": cve_id,
"poc_content": poc_content[:3000],
"existing_rule": existing_rule
})
enhanced_rule = self._extract_sigma_rule(response)
enhanced_rule = self._post_process_sigma_rule(enhanced_rule)
logger.info(f"Successfully enhanced SIGMA rule for {cve_id}")
return enhanced_rule
except Exception as e:
logger.error(f"Failed to enhance SIGMA rule for {cve_id}: {e}")
return None
def validate_sigma_rule(self, rule_content: str) -> bool:
    """Validate that the generated rule follows SIGMA specification.

    Checks that the text parses as a YAML mapping, that ``title``,
    ``logsource`` and ``detection`` are present and well-typed, that a
    condition exists, that ``detection`` defines at least one search
    identifier, and that ``status`` (if present) is a known value.

    Args:
        rule_content: Rule text in YAML.

    Returns:
        True when every check passes; False otherwise (the reason is
        logged). Never raises.
    """
    try:
        # Parse as YAML
        parsed = yaml.safe_load(rule_content)
        if not isinstance(parsed, dict):
            logger.warning("Rule must be a YAML dictionary")
            return False
        # Check MANDATORY fields per SIGMA spec
        mandatory_fields = ['title', 'logsource', 'detection']
        for field in mandatory_fields:
            if field not in parsed:
                logger.warning(f"Missing mandatory field: {field}")
                return False
        # Validate title: a blank title is as good as a missing one.
        title = parsed.get('title', '')
        if not isinstance(title, str) or not title.strip() or len(title) > 256:
            logger.warning("Title must be a non-empty string ≤256 characters")
            return False
        # Validate logsource structure
        logsource = parsed.get('logsource', {})
        if not isinstance(logsource, dict):
            logger.warning("Logsource must be a dictionary")
            return False
        # Validate detection structure
        detection = parsed.get('detection', {})
        if not isinstance(detection, dict):
            logger.warning("Detection must be a dictionary")
            return False
        # Check for condition (can be in detection or at root)
        has_condition = 'condition' in detection or 'condition' in parsed
        if not has_condition:
            logger.warning("Missing condition field")
            return False
        # Check for at least one search identifier. Identifier names are
        # free-form (e.g. 'sel_image', 'filter_main'), so any key other
        # than the reserved ones counts, not only 'selection*'.
        reserved_keys = ('condition', 'timeframe')
        selection_found = any(key not in reserved_keys for key in detection)
        if not selection_found:
            logger.warning("Detection must have at least one selection")
            return False
        # Validate status if present
        status = parsed.get('status')
        if status and status not in ['stable', 'test', 'experimental', 'deprecated', 'unsupported']:
            logger.warning(f"Invalid status: {status}")
            return False
        logger.info("SIGMA rule validation passed")
        return True
    except yaml.YAMLError as e:
        logger.warning(f"YAML parsing error: {e}")
        return False
    except Exception as e:
        logger.warning(f"Rule validation error: {e}")
        return False
@classmethod
def get_available_providers(cls) -> List[Dict[str, Any]]:
"""Get list of available LLM providers and their configuration status."""
providers = []
for provider_name, provider_info in cls.SUPPORTED_PROVIDERS.items():
env_key = provider_info.get('env_key', '')
api_key_configured = bool(os.getenv(env_key))
providers.append({
'name': provider_name,
'models': provider_info.get('models', []),
'default_model': provider_info.get('default_model', ''),
'env_key': env_key,
'api_key_configured': api_key_configured,
'available': api_key_configured or provider_name == 'ollama'
})
return providers
def switch_provider(self, provider: str, model: str = None):
"""Switch to a different LLM provider and model."""
if provider not in self.SUPPORTED_PROVIDERS:
raise ValueError(f"Unsupported provider: {provider}")
self.provider = provider
self.model = model or self._get_default_model(provider)
self._initialize_llm()
logger.info(f"Switched to provider: {provider} with model: {self.model}")
def _check_ollama_model_available(self, base_url: str, model: str) -> bool:
    """Ask the Ollama server at ``base_url`` whether ``model`` is listed.

    A listed model matches when its name equals ``model`` or starts with
    ``model`` followed by ``':'``. Any error, or a non-200 reply, is
    reported as "not available".
    """
    try:
        import requests
        reply = requests.get(f"{base_url}/api/tags", timeout=10)
        if reply.status_code != 200:
            return False
        listed = reply.json().get('models', [])
        return any(
            entry.get('name', '').startswith(model + ':') or entry.get('name') == model
            for entry in listed
        )
    except Exception as e:
        logger.error(f"Error checking Ollama models: {e}")
        return False
def _pull_ollama_model(self, base_url: str, model: str) -> bool:
    """Pull an Ollama model through the server's ``/api/pull`` endpoint.

    The reply is consumed as a stream of JSON lines; each ``status`` is
    logged as progress and the first ``error`` aborts the pull.

    Args:
        base_url: Base URL of the Ollama server.
        model: Name of the model to pull.

    Returns:
        True when the stream finished without an error entry; False on a
        non-200 reply, a reported error, or any exception.
    """
    try:
        import requests
        import json
        # Use the pull API endpoint
        payload = {"name": model}
        # The context manager releases the streamed connection on every
        # exit path, including the early returns below.
        with requests.post(
            f"{base_url}/api/pull",
            json=payload,
            # NOTE: requests applies this to connecting and to each read,
            # so it bounds stalls rather than the whole download.
            timeout=300,
            stream=True
        ) as response:
            if response.status_code != 200:
                logger.error(f"Failed to pull model {model}: HTTP {response.status_code}")
                return False
            # Stream the response to monitor progress
            for line in response.iter_lines():
                if not line:
                    continue
                try:
                    data = json.loads(line.decode('utf-8'))
                except json.JSONDecodeError:
                    # Ignore lines that are not valid JSON.
                    continue
                if data.get('status'):
                    logger.info(f"Ollama pull progress: {data.get('status')}")
                if data.get('error'):
                    logger.error(f"Ollama pull error: {data.get('error')}")
                    return False
            return True
    except Exception as e:
        logger.error(f"Error pulling Ollama model {model}: {e}")
        return False
def _get_ollama_available_models(self) -> List[str]:
    """Return the model names reported by the Ollama server.

    The server address comes from ``OLLAMA_BASE_URL`` (default
    ``http://localhost:11434``). Entries without a name are skipped. Any
    error, or a non-200 reply, yields an empty list.
    """
    try:
        import requests
        server = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
        reply = requests.get(f"{server}/api/tags", timeout=10)
        if reply.status_code != 200:
            return []
        names = []
        for entry in reply.json().get('models', []):
            if entry.get('name'):
                names.append(entry.get('name', ''))
        return names
    except Exception as e:
        logger.error(f"Error getting Ollama models: {e}")
        return []