# auto_sigma_rule_generator/backend/llm_client.py
#
# 1389 lines
# No EOL
# 61 KiB
# Python

"""
LangChain-based LLM client for enhanced SIGMA rule generation.
Supports multiple LLM providers: OpenAI, Anthropic, and local models.
"""
import os
import logging
from typing import Optional, Dict, Any, List
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_community.llms import Ollama
from langchain_core.output_parsers import StrOutputParser
import yaml
from cve2capec_client import CVE2CAPECClient
logger = logging.getLogger(__name__)
class LLMClient:
"""Multi-provider LLM client for SIGMA rule generation using LangChain."""
# Provider registry. For each provider:
#   'models'        - model names reported as supported (get_provider_info)
#   'env_key'       - environment variable consulted for configuration; its
#                     presence is reported as 'api_key_configured'
#   'default_model' - model used when none is passed (_get_default_model)
SUPPORTED_PROVIDERS = {
'openai': {
'models': ['gpt-4o', 'gpt-4o-mini', 'gpt-4-turbo', 'gpt-3.5-turbo'],
'env_key': 'OPENAI_API_KEY',
'default_model': 'gpt-4o-mini'
},
'anthropic': {
'models': ['claude-3-5-sonnet-20241022', 'claude-3-haiku-20240307', 'claude-3-opus-20240229'],
'env_key': 'ANTHROPIC_API_KEY',
'default_model': 'claude-3-5-sonnet-20241022'
},
# NOTE: for Ollama the env_key holds the server base URL, not an API key.
'ollama': {
'models': ['llama3.2', 'codellama', 'mistral', 'llama2'],
'env_key': 'OLLAMA_BASE_URL',
'default_model': 'llama3.2'
}
}
def __init__(self, provider: str = None, model: str = None):
    """Set up the client for a provider/model pair.

    Args:
        provider: Key from SUPPORTED_PROVIDERS; auto-detected from the
            environment when omitted or empty.
        model: Model name; the provider's default model is used when
            omitted or empty.
    """
    chosen_provider = provider if provider else self._detect_provider()
    self.provider = chosen_provider
    self.model = model if model else self._get_default_model(chosen_provider)
    # Filled in by _initialize_llm(); stays None when configuration is missing.
    self.llm = None
    self.output_parser = StrOutputParser()
    self.cve2capec_client = CVE2CAPECClient()
    self._initialize_llm()
def _detect_provider(self) -> str:
"""Auto-detect available LLM provider based on environment variables."""
# Check for API keys in order of preference
if os.getenv('ANTHROPIC_API_KEY'):
return 'anthropic'
elif os.getenv('OPENAI_API_KEY'):
return 'openai'
elif os.getenv('OLLAMA_BASE_URL'):
return 'ollama'
else:
# Default to OpenAI if no keys found
return 'openai'
def _get_default_model(self, provider: str) -> str:
"""Get default model for the specified provider."""
return self.SUPPORTED_PROVIDERS.get(provider, {}).get('default_model', 'gpt-4o-mini')
def _initialize_llm(self):
"""Initialize the LLM based on provider and model."""
try:
if self.provider == 'openai':
api_key = os.getenv('OPENAI_API_KEY')
if not api_key:
logger.warning("OpenAI API key not found")
return
self.llm = ChatOpenAI(
model=self.model,
api_key=api_key,
temperature=0.1,
max_tokens=2000
)
elif self.provider == 'anthropic':
api_key = os.getenv('ANTHROPIC_API_KEY')
if not api_key:
logger.warning("Anthropic API key not found")
return
self.llm = ChatAnthropic(
model=self.model,
api_key=api_key,
temperature=0.1,
max_tokens=2000
)
elif self.provider == 'ollama':
base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
# Check if model is available, if not try to pull it
if not self._check_ollama_model_available(base_url, self.model):
logger.info(f"Model {self.model} not found, attempting to pull...")
if self._pull_ollama_model(base_url, self.model):
logger.info(f"Successfully pulled model {self.model}")
else:
logger.error(f"Failed to pull model {self.model}")
return
self.llm = Ollama(
model=self.model,
base_url=base_url,
temperature=0.1,
num_ctx=4096, # Context window size
top_p=0.9,
top_k=40
)
if self.llm:
logger.info(f"LLM client initialized: {self.provider} with model {self.model}")
else:
logger.error(f"Failed to initialize LLM client for provider: {self.provider}")
except Exception as e:
logger.error(f"Error initializing LLM client: {e}")
self.llm = None
def is_available(self) -> bool:
"""Check if LLM client is available and configured."""
return self.llm is not None
def get_provider_info(self) -> Dict[str, Any]:
"""Get information about the current provider and configuration."""
provider_info = self.SUPPORTED_PROVIDERS.get(self.provider, {})
# For Ollama, get actually available models
available_models = provider_info.get('models', [])
if self.provider == 'ollama':
ollama_models = self._get_ollama_available_models()
if ollama_models:
available_models = ollama_models
return {
'provider': self.provider,
'model': self.model,
'available': self.is_available(),
'supported_models': provider_info.get('models', []),
'available_models': available_models,
'env_key': provider_info.get('env_key', ''),
'api_key_configured': bool(os.getenv(provider_info.get('env_key', '')))
}
async def generate_sigma_rule(self,
cve_id: str,
poc_content: str,
cve_description: str,
existing_rule: Optional[str] = None) -> Optional[str]:
"""
Generate or enhance a SIGMA rule using the configured LLM.
Args:
cve_id: CVE identifier
poc_content: Proof-of-concept code content from GitHub
cve_description: CVE description from NVD
existing_rule: Optional existing SIGMA rule to enhance
Returns:
Generated SIGMA rule YAML content or None if failed
"""
if not self.is_available():
logger.warning("LLM client not available")
return None
try:
# Create the prompt template
prompt = self._build_sigma_generation_prompt(
cve_id, poc_content, cve_description, existing_rule
)
# Create the chain
chain = prompt | self.llm | self.output_parser
# Debug: Log what we're sending to the LLM
input_data = {
"cve_id": cve_id,
"poc_content": poc_content[:4000], # Truncate if too long
"cve_description": cve_description,
"existing_rule": existing_rule or "None"
}
logger.info(f"Sending to LLM for {cve_id}: CVE={cve_id}, Description length={len(cve_description)}, PoC length={len(poc_content)}")
logger.info(f"CVE Description for {cve_id}: {cve_description[:200]}...")
logger.info(f"PoC Content sample for {cve_id}: {poc_content[:200]}...")
# Generate the response with timeout handling
logger.info(f"Final prompt variables for {cve_id}: {list(input_data.keys())}")
import asyncio
try:
# Add timeout wrapper around the LLM call
response = await asyncio.wait_for(
chain.ainvoke(input_data),
timeout=150 # 2.5 minutes total timeout
)
except asyncio.TimeoutError:
logger.error(f"LLM request timed out for {cve_id}")
return None
except Exception as llm_error:
logger.error(f"LLM generation error for {cve_id}: {llm_error}")
return None
# Debug: Log raw LLM response
logger.info(f"Raw LLM response for {cve_id}: {response[:200]}...")
# Extract the SIGMA rule from response
sigma_rule = self._extract_sigma_rule(response)
# Post-process to ensure clean YAML
sigma_rule = self._post_process_sigma_rule(sigma_rule)
# Fix common YAML syntax errors
sigma_rule = self._fix_yaml_syntax_errors(sigma_rule)
# CRITICAL: Validate and fix CVE ID hallucination
sigma_rule = self._fix_hallucinated_cve_id(sigma_rule, cve_id)
# Additional fallback: If no CVE ID found, inject it into the rule
if not sigma_rule or 'CVE-' not in sigma_rule:
sigma_rule = self._inject_cve_id_into_rule(sigma_rule, cve_id)
# Debug: Log final processed rule
logger.info(f"Final processed rule for {cve_id}: {sigma_rule[:200]}...")
logger.info(f"Successfully generated SIGMA rule for {cve_id} using {self.provider}")
return sigma_rule
except Exception as e:
logger.error(f"Failed to generate SIGMA rule for {cve_id} using {self.provider}: {e}")
return None
def _build_sigma_generation_prompt(self,
                                   cve_id: str,
                                   poc_content: str,
                                   cve_description: str,
                                   existing_rule: Optional[str] = None) -> ChatPromptTemplate:
    """Build the chat prompt template for SIGMA rule generation.

    The returned template expects the input variables ``cve_id``,
    ``cve_description`` and ``poc_content``, plus ``existing_rule`` when an
    existing rule is being enhanced. ChatPromptTemplate treats ``{name}`` as
    a variable, so braces meant literally for the model are written as
    ``{{...}}``.

    Args:
        cve_id: CVE identifier. Used here to look up MITRE ATT&CK/CWE
            mappings when no existing rule is given.
        poc_content: PoC content. Supplied later as a template variable;
            not embedded here.
        cve_description: CVE description. Supplied later as a template
            variable; not embedded here.
        existing_rule: When given, the prompt asks the model to enhance it.

    Returns:
        A ChatPromptTemplate containing a system message and a human message.
    """
    import datetime

    # Plain (non-f) string. "{{cve_id}}" reaches the model as the literal
    # text "{cve_id}". Left unescaped, LangChain would substitute the real
    # CVE ID and the whole PoC into the "never include variables" rule.
    system_message = """You are a cybersecurity expert specializing in SIGMA rule creation following the official SIGMA specification.
**OFFICIAL SIGMA RULE SPECIFICATION JSON SCHEMA:**
The official SIGMA rule specification (v2.0.0) defines these requirements:
**MANDATORY Fields (must include):**
- title: Brief description (max 256 chars) - string
- logsource: Log data source specification - object with category/product/service
- detection: Search identifiers and conditions - object with selections and condition
**RECOMMENDED Fields:**
- id: Unique UUID (version 4) - string with UUID format
- status: Rule state - enum: "stable", "test", "experimental", "deprecated", "unsupported"
- description: Detailed explanation - string
- author: Rule creator - string (use "AI Generated")
- date: Creation date - string in YYYY/MM/DD format
- modified: Last modification date - string in YYYY/MM/DD format
- references: Sources for rule derivation - array of strings (URLs)
- tags: MITRE ATT&CK techniques - array of strings
- level: Rule severity - enum: "informational", "low", "medium", "high", "critical"
- falsepositives: Known false positives - array of strings
- fields: Related fields - array of strings
- related: Related rules - array of objects with type and id
**YAML Structure Requirements:**
- Use UTF-8 encoding with LF line breaks
- Indent with 4 spaces (no tabs)
- Use lowercase keys only
- Use single quotes for string values
- No quotes for numeric values
- Follow proper YAML syntax
**Detection Structure:**
- Use selection blocks (selection, selection1, etc.)
- Condition references these selections
- Use proper field names (Image, CommandLine, ProcessName, etc.)
- Support wildcards (*) and value lists
- Condition can be string expression or object with keywords
**ABSOLUTE REQUIREMENTS:**
- Output ONLY valid YAML
- NO explanatory text before or after
- NO comments or instructions
- NO markdown formatting or code blocks
- NEVER repeat the input prompt or template
- NEVER include variables like {{cve_id}} or {{poc_content}}
- NO "Human:", "CVE ID:", "Description:" headers
- NO "Analyze this" or "Output only" text
- Start IMMEDIATELY with 'title:'
- End with the last YAML line only
- Ensure perfect YAML syntax
**STRUCTURE REQUIREMENTS:**
- title: Descriptive title that MUST include the exact CVE ID provided by the user
- id: Generate a unique UUID (not '12345678-1234-1234-1234-123456789012')
- status: experimental
- description: Specific description based on CVE and PoC analysis
- author: 'AI Generated'
- date: Current date (__CURRENT_DATE__)
- references: Include the EXACT CVE URL with the CVE ID provided by the user
- tags: Relevant MITRE ATT&CK techniques based on PoC analysis
- logsource: Appropriate category based on exploit type
- detection: Specific indicators from PoC analysis (NOT generic examples)
- condition: Logic connecting the detection selections
**CRITICAL ANTI-HALLUCINATION RULES:**
1. You MUST use the EXACT CVE ID provided in the user input - NEVER generate a different CVE ID
2. NEVER use example CVE IDs like CVE-2022-1234, CVE-2023-5678, or CVE-2024-1234
3. NEVER use placeholder CVE IDs from your training data
4. Analyze the provided CVE description and PoC content to create SPECIFIC detection patterns
5. DO NOT hallucinate or invent CVE IDs from your training data
6. Use the CVE ID exactly as provided in the title and references
7. Generate rules based ONLY on the provided CVE description and PoC code analysis
8. Do not reference vulnerabilities or techniques not present in the provided content
9. CVE-2022-1234 is a FORBIDDEN example CVE ID - NEVER use it
10. The user will provide the EXACT CVE ID to use - use that and ONLY that"""
    # The prompt previously hard-coded a stale "current date"; use today's
    # date in the YYYY/MM/DD format the prompt itself requires.
    system_message = system_message.replace(
        "__CURRENT_DATE__", datetime.date.today().strftime('%Y/%m/%d')
    )

    if existing_rule:
        user_template = """CVE ID: {cve_id}
CVE Description: {cve_description}
PoC Code:
{poc_content}
Existing SIGMA Rule:
{existing_rule}
Enhance this rule with PoC insights. Output only valid SIGMA YAML starting with 'title:'."""
    else:
        # Get MITRE ATT&CK mappings for the CVE (missing keys mean "none").
        mitre_mappings = self.cve2capec_client.get_full_mapping_for_cve(cve_id) or {}
        techniques = mitre_mappings.get('mitre_techniques') or []
        cwe_codes = mitre_mappings.get('cwe_codes') or []
        mitre_suggestions = ""
        if techniques:
            technique_details = []
            for tech in techniques:
                tech_name = self.cve2capec_client.get_technique_name(tech)
                technique_details.append(f"    - {tech}: {tech_name}")
            mitre_suggestions = f"""
**MITRE ATT&CK TECHNIQUE MAPPINGS FOR {cve_id}:**
{chr(10).join(technique_details)}
**IMPORTANT:** Use these exact MITRE ATT&CK techniques in your tags section. Convert them to lowercase attack.t format (e.g., T1059 becomes attack.t1059)."""
        if cwe_codes:
            mitre_suggestions += f"""
**CWE MAPPINGS:** {', '.join(cwe_codes)}"""
        # The suggestions become part of a prompt *template*; escape braces
        # so mapping text can never be parsed as a template variable.
        escaped_suggestions = mitre_suggestions.replace('{', '{{').replace('}', '}}')
        user_template = f"""CREATE A SPECIFIC SIGMA RULE FOR THIS EXACT CVE:
**MANDATORY CVE ID TO USE: {{cve_id}}**
**CVE Description: {{cve_description}}**
**Proof-of-Concept Code Analysis:**
{{poc_content}}
{escaped_suggestions}
**CRITICAL REQUIREMENTS:**
1. Use EXACTLY this CVE ID in the title: {{cve_id}}
2. Use EXACTLY this CVE URL in references: https://nvd.nist.gov/vuln/detail/{{cve_id}}
3. Analyze the CVE description to understand the vulnerability type
4. Extract specific indicators from the PoC code (files, processes, commands, network patterns)
5. Create detection logic based on the actual exploit behavior
6. Use relevant logsource category (process_creation, file_event, network_connection, etc.)
7. Include the MITRE ATT&CK tags listed above in your tags section (convert to attack.t format)
**CRITICAL ANTI-HALLUCINATION REQUIREMENTS:**
- THE CVE ID IS: {{cve_id}}
- DO NOT use CVE-2022-1234, CVE-2023-1234, CVE-2024-1234, or any other example CVE ID
- DO NOT generate a different CVE ID from your training data
- You MUST use the exact CVE ID "{{cve_id}}" - this is the ONLY acceptable CVE ID for this rule
- Base your analysis ONLY on the provided CVE description and PoC code above
- Do not reference other vulnerabilities or exploits not mentioned in the provided content
- NEVER use placeholder CVE IDs like CVE-YYYY-NNNN or CVE-2022-1234
**ABSOLUTE REQUIREMENT: THE EXACT CVE ID TO USE IS: {{cve_id}}**
**FORBIDDEN: Do not use CVE-2022-1234, CVE-2023-5678, or any other example CVE ID**
Output ONLY valid SIGMA YAML starting with 'title:' that includes the exact CVE ID {{cve_id}}."""

    return ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("human", user_template)
    ])
def _extract_sigma_rule(self, response_text: str) -> str:
"""Extract and clean SIGMA rule YAML from LLM response."""
lines = response_text.split('\n')
yaml_lines = []
in_yaml_block = False
found_title = False
for line in lines:
stripped = line.strip()
# Skip code block markers
if stripped.startswith('```'):
if stripped.startswith('```yaml'):
in_yaml_block = True
elif stripped == '```' and in_yaml_block:
break
continue
# Skip obvious non-YAML content
if not in_yaml_block and not found_title:
if not stripped.startswith('title:'):
# Skip explanatory text and prompt artifacts
skip_phrases = [
'please note', 'this rule', 'you should', 'analysis:',
'explanation:', 'based on', 'the following', 'here is',
'note that', 'important:', 'remember', 'this is a',
'make sure to', 'you can modify', 'adjust the',
'human:', 'cve id:', 'cve description:', 'poc code:',
'exploit code:', 'analyze this', 'create a', 'output only'
]
if any(phrase in stripped.lower() for phrase in skip_phrases):
continue
# Skip template variables and prompt artifacts
if '{' in stripped and '}' in stripped:
continue
# Skip lines that contain template placeholder text
if 'cve_id' in stripped.lower() or 'cve description' in stripped.lower():
continue
# Skip lines that are clearly not YAML structure
if stripped and not ':' in stripped and len(stripped) > 20:
continue
# Start collecting when we find title or are in YAML block
if stripped.startswith('title:') or in_yaml_block:
found_title = True
in_yaml_block = True
# Skip explanatory comments
if stripped.startswith('#') and ('please' in stripped.lower() or 'note' in stripped.lower()):
continue
yaml_lines.append(line)
# Stop if we encounter obvious non-YAML after starting
elif found_title:
# Stop at explanatory text after the rule
stop_phrases = [
'please note', 'this rule should', 'make sure to',
'you can modify', 'adjust the', 'also, this is',
'based on the analysis', 'the rule above'
]
if any(phrase in stripped.lower() for phrase in stop_phrases):
break
# Stop at lines without colons that aren't indented (likely explanations)
if stripped and not stripped.startswith(' ') and ':' not in stripped and '-' not in stripped:
break
if not yaml_lines:
# Fallback: look for any line with YAML-like structure
for line in lines:
if ':' in line and not line.strip().startswith('#'):
yaml_lines.append(line)
return '\n'.join(yaml_lines).strip()
def _post_process_sigma_rule(self, rule_content: str) -> str:
"""Post-process SIGMA rule to ensure clean YAML format."""
lines = rule_content.split('\n')
cleaned_lines = []
for line in lines:
stripped = line.strip()
# Skip obvious non-YAML content and prompt artifacts
if any(phrase in stripped.lower() for phrase in [
'please note', 'you should replace', 'this is a proof-of-concept',
'please make sure', 'note that', 'important:', 'remember to',
'analysis shows', 'based on the', 'the rule above', 'this rule',
'human:', 'cve id:', 'cve description:', 'poc code:', 'exploit code:',
'analyze this', 'create a', 'output only', 'generate a'
]):
continue
# Skip template variables and placeholder text
if '{' in stripped and '}' in stripped:
continue
# Skip lines that contain template placeholder patterns
if any(placeholder in stripped.lower() for placeholder in [
'cve_id', 'cve description', 'poc_content', 'existing_rule',
'{cve_id}', '{cve_description}', '{poc_content}'
]):
continue
# Skip lines that look like explanations
if stripped and not ':' in stripped and not stripped.startswith('-') and not stripped.startswith(' '):
# This might be explanatory text, skip it
if any(word in stripped.lower() for word in ['rule', 'detect', 'should', 'will', 'can', 'may']):
continue
# Skip empty explanatory sections
if stripped.lower() in ['explanation:', 'analysis:', 'notes:', 'important:', '']:
continue
cleaned_lines.append(line)
return '\n'.join(cleaned_lines).strip()
def _fix_yaml_syntax_errors(self, rule_content: str) -> str:
    """Fix common YAML syntax errors in LLM-generated rules.

    Applies a series of heuristic repairs:

    - quote ``*wildcard*`` values that YAML would read as aliases;
    - add a missing space after mapping keys;
    - expand single-line forms such as ``references: - url``,
      ``logsource: category: x`` and ``detection: sel: field: x``;
    - rewrite ``field == value`` into mappings;
    - replace the example placeholder UUID;
    - drop orphaned list items and malformed nested lines.

    It then logs the fixes it made and hands the text to
    ``_validate_and_repair_yaml_structure`` for a final parse check.

    The structure-expanding patterns only match within a single line
    (spaces/tabs, never newlines). Already well-formed multi-line YAML is
    therefore left alone rather than partially re-indented.

    Args:
        rule_content: Rule text after extraction and post-processing.

    Returns:
        The repaired rule text. Empty or None input is returned unchanged.
    """
    import re
    import uuid

    if not rule_content:
        return rule_content

    fixes_applied = []

    def quote_value(value: str) -> str:
        # Wrap a bare scalar in single quotes unless it is already quoted
        # or is a flow sequence/mapping.
        if value.startswith(("'", '"', '[', '{')):
            return value
        return f"'{value}'"

    def substitute(text, pattern, replacement, message, flags=0):
        # re.sub that records `message` when it changed anything.
        updated = re.sub(pattern, replacement, text, flags=flags)
        if updated != text:
            fixes_applied.append(message)
        return updated

    # --- Pass 1: per-line quoting of wildcard / indicator-led values ---
    fixed_lines = []
    for line in rule_content.split('\n'):
        fixed_line = line
        if '- *' in line:
            # Invalid YAML alias syntax: - *image* -> - '*image*'
            # YAML aliases must be alphanumeric, but LLMs write *word* or
            # *multiple words* as wildcards.
            pattern = r'(\s*-\s*)(\*[^*]+\*)'
            if re.search(pattern, line):
                fixed_line = re.sub(pattern, r"\1'\2'", line)
                fixes_applied.append(f"Fixed invalid YAML alias syntax: {line.strip()} -> {fixed_line.strip()}")
        elif re.search(r':\s*\*[^*]+\*\s*$', line) and not re.search(r'[\'"]', line):
            # Same problem in a mapping value: key: *value* -> key: '*value*'
            fixed_line = re.sub(r'(:\s*)(\*[^*]+\*)', r"\1'\2'", line)
            fixes_applied.append(f"Fixed invalid YAML alias in value: {line.strip()} -> {fixed_line.strip()}")
        elif re.match(r'^\s*-\s*[*&|>]', line):
            # List item that starts with a YAML indicator character: quote it.
            indent, _, rest = line.partition('-')
            content = rest.strip()
            if content and not content.startswith(("'", '"')):
                fixed_line = f"{indent}- '{content}'"
                fixes_applied.append(f"Quoted special character value: {line.strip()} -> {fixed_line.strip()}")
        elif ': *' in line and not line.strip().startswith('#'):
            # Unquoted wildcard value: ": *something*" -> ": '*something*'"
            pattern = r'(:\s*)(\*[^*]+\*)'
            if not re.search(r'[\'"]', line) and re.search(pattern, line):
                fixed_line = re.sub(pattern, r"\1'\2'", line)
                fixes_applied.append(f"Fixed unquoted wildcard value: {line.strip()} -> {fixed_line.strip()}")
        elif re.search(r':\s*[*&]', line) and not re.search(r'[\'"]', line):
            # Other values starting with * or &. Block scalar indicators
            # (| and >) are deliberately left alone.
            if not re.search(r':\s*[|>]\s*$', line):
                key, _, raw_value = line.partition(':')
                value = raw_value.strip()
                if value and not value.startswith(("'", '"', '[', '{')):
                    fixed_line = f"{key}: '{value}'"
                    fixes_applied.append(f"Quoted special character value: {line.strip()} -> {fixed_line.strip()}")
        elif re.search(r'^\s*\*[^*]+\*\s*$', line):
            # A bare *word* line: quote it, keeping its indentation.
            indent_width = len(line) - len(line.lstrip())
            fixed_line = f"{' ' * indent_width}'{line.strip()}'"
            fixes_applied.append(f"Fixed standalone wildcard: {line.strip()} -> {fixed_line.strip()}")
        fixed_lines.append(fixed_line)
    result = '\n'.join(fixed_lines)

    # --- Pass 2: spacing and single-line structure fixes ---
    # Missing space after a mapping key: "title:Foo" -> "title: Foo".
    # Only the key separator at the start of a line is touched. Colons
    # inside values (URLs, Windows paths, "${jndi:ldap:...}" payloads,
    # times) are left alone. Values starting with "/" or a backslash are
    # skipped, so "- https://..." and "- C:\Windows" stay intact.
    result = substitute(result, r'^([ \t]*(?:-[ \t]+)?[\w.|-]+):(?=[^\s/\\])', r'\1: ',
                        "Added missing spaces after colons", re.MULTILINE)
    # Runs of spaces/tabs after a colon on the same line. Newlines must not
    # be matched: that would fold every nested mapping onto its parent line.
    result = substitute(result, r':[ \t]{2,}', ': ', "Fixed multiple spaces after colons")
    # references: - https://... -> references:\n    - https://...
    result = substitute(result, r'references:[ \t]*-[ \t]*', 'references:\n    - ',
                        "Fixed references array format")
    # Broken URLs such as "https: //host" (any http/https URL).
    result = substitute(result, r'\b(https?):[ \t]+//', r'\1://', "Fixed broken URLs in references")
    # logsource: category: X -> logsource:\n    category: X
    result = substitute(result, r'logsource:[ \t]*(category|product|service):[ \t]*',
                        r'logsource:\n    \1: ', "Fixed logsource structure format")
    # detection: selection: key: value -> nested mapping (field modifiers allowed)
    result = substitute(result, r'detection:[ \t]*(\w+):[ \t]*([\w.|-]+):[ \t]*',
                        r'detection:\n    \1:\n        \2: ', "Fixed detection structure format")

    # detection: selection1: image == *value* -> nested mapping with a quoted value
    detection_eq_patterns = [
        (r'detection:[ \t]*(\w+):[ \t]*(\w+)[ \t]*==[ \t]*(\*[^*\s]+\*)', r"detection:\n    \1:\n        \2: '\3'"),
        (r'detection:[ \t]*(\w+):[ \t]*(\w+)[ \t]*==[ \t]*([^\s]+)', r"detection:\n    \1:\n        \2: '\3'"),
    ]
    for pattern, replacement in detection_eq_patterns:
        updated = re.sub(pattern, replacement, result)
        if updated != result:
            fixes_applied.append("Fixed detection equality operator syntax")
            result = updated
            break

    # Remaining equality operators, line by line.
    eq_fixed_lines = []
    for line in result.split('\n'):
        if ' == ' in line:
            if line.strip().startswith('detection:'):
                # "detection: selection1: image == *value*"
                parts = line.split(':')
                if len(parts) >= 3:
                    base_indent = len(line) - len(line.lstrip())
                    selection_part = parts[1].strip()
                    key_value_part = ':'.join(parts[2:]).strip()
                    if ' == ' in key_value_part:
                        key, value = key_value_part.split(' == ', 1)
                        key = key.strip()
                        value = quote_value(value.strip())
                        eq_fixed_lines.append(f"{' ' * base_indent}detection:")
                        eq_fixed_lines.append(f"{' ' * (base_indent + 4)}{selection_part}:")
                        eq_fixed_lines.append(f"{' ' * (base_indent + 8)}{key}: {value}")
                        fixes_applied.append(f"Fixed complex detection equality: {selection_part}: {key} == {value}")
                        continue
            else:
                # "    key == value"
                match = re.match(r'^(\s+)(\w+)\s*==\s*(.+)$', line)
                if match:
                    indent, key = match.group(1), match.group(2)
                    value = quote_value(match.group(3).strip())
                    eq_fixed_lines.append(f"{indent}{key}: {value}")
                    fixes_applied.append(f"Fixed equality operator: {key} == {value}")
                    continue
        eq_fixed_lines.append(line)
    # Always rebuild: one-for-one rewrites do not change the line count, so
    # a length comparison would silently discard them.
    result = '\n'.join(eq_fixed_lines)

    # key: - value -> key:\n    - value   (e.g. "CommandLine|contains: - 'x'")
    # A space is required after the dash so values such as "-enc" survive.
    array_lines = []
    for line in result.split('\n'):
        match = re.match(r'^(\s+)([\w.|-]+):\s*-\s+(.+)$', line)
        if match:
            indent, key, value = match.groups()
            array_lines.append(f"{indent}{key}:")
            array_lines.append(f"{indent}    - {value}")
            fixes_applied.append(f"Fixed array-as-value syntax: {key}: - {value}")
        else:
            array_lines.append(line)
    result = '\n'.join(array_lines)

    # "selection1: Image: - '*path*'" -> three nested lines
    result = substitute(result, r'^([ \t]+)(\w+):[ \t]*([\w.|-]+):[ \t]*-[ \t]+(.+)$',
                        r'\1\2:\n\1    \3:\n\1        - \4',
                        "Fixed complex nested structure syntax", re.MULTILINE)

    # tags: - T1059.001 -> tags:\n    - T1059.001 (and similar fields)
    result = substitute(result, r'tags:[ \t]*-[ \t]*', 'tags:\n    - ', "Fixed tags array format")
    for field in ['falsepositives', 'level', 'related']:
        result = substitute(result, rf'{field}:[ \t]*-[ \t]*', f'{field}:\n    - ',
                            f"Fixed {field} array format")

    # Replace the example UUID from the prompt if the LLM copied it.
    placeholder_uuid = '12345678-1234-1234-1234-123456789012'
    if placeholder_uuid in result:
        new_uuid = str(uuid.uuid4())
        result = result.replace(placeholder_uuid, new_uuid)
        fixes_applied.append(f"Replaced placeholder UUID with {new_uuid[:8]}...")

    # Orphaned list items: an item is orphaned only when the previous kept,
    # non-blank, non-comment line is neither a "key:" header nor another
    # list item. (Comparing against the raw previous line flagged every
    # item after the first in any list.)
    cleaned_lines = []
    previous = None
    for line in result.split('\n'):
        stripped = line.strip()
        is_orphan = (
            stripped.startswith('- ')
            and previous is not None
            and not previous.endswith(':')
            and not previous.startswith('-')
            and ':' not in stripped
            and not stripped.startswith('- https://')  # Don't remove reference URLs
        )
        if is_orphan:
            if re.match(r'- T\d{4}', stripped):
                # Looks like a MITRE ATT&CK tag: keep it if a tags section exists.
                if any(kept.strip().startswith('tags:') for kept in cleaned_lines):
                    line = f"    {stripped}"
                    fixes_applied.append(f"Fixed orphaned MITRE tag: {stripped}")
                else:
                    fixes_applied.append(f"Removed orphaned tag (no tags section): {stripped}")
                    continue
            else:
                fixes_applied.append(f"Removed orphaned list item: {stripped}")
                continue
        cleaned_lines.append(line)
        if stripped and not stripped.startswith('#'):
            previous = line.strip()
    result = '\n'.join(cleaned_lines)

    # Final pass: drop lines still malformed enough to break YAML parsing.
    final_lines = []
    for line in result.split('\n'):
        stripped = line.strip()
        if re.search(r':\s*\w+:\s*-\s*[\'"][^\'":]*[\'"]:\s*', line):
            # "key: subkey: - 'value': more_stuff"
            fixes_applied.append(f"Removed malformed nested line: {stripped[:50]}...")
            continue
        if re.search(r'^\s*\w+:\s*\w+:\s*-\s*[\'"][^\'":]*[\'"]:\s*\w+', line):
            fixes_applied.append(f"Removed invalid mapping structure: {stripped[:50]}...")
            continue
        final_lines.append(line)
    result = '\n'.join(final_lines)

    if fixes_applied:
        logger.info(f"Applied YAML syntax fixes: {', '.join(fixes_applied)}")

    # Final YAML structure validation and repair
    return self._validate_and_repair_yaml_structure(result, fixes_applied)
def _validate_and_repair_yaml_structure(self, content: str, fixes_applied: list) -> str:
    """Return ``content`` unchanged if it parses as YAML; otherwise repair it.

    ``_repair_yaml_structure`` is tried first. If its output still fails to
    parse, ``_build_minimal_valid_rule`` provides a fallback rule. A note
    about a successful repair is appended to ``fixes_applied`` in place.
    """
    try:
        yaml.safe_load(content)
        return content
    except yaml.YAMLError as first_error:
        logger.warning(f"YAML structure validation failed: {first_error}")
        error_text = str(first_error)

    candidate = self._repair_yaml_structure(content, error_text)
    try:
        yaml.safe_load(candidate)
    except yaml.YAMLError as second_error:
        logger.warning(f"YAML repair attempt failed: {second_error}")
        # Last resort: salvage what we can into a minimal rule.
        return self._build_minimal_valid_rule(content, fixes_applied)

    fixes_applied.append("Repaired YAML document structure")
    logger.info("Successfully repaired YAML structure")
    return candidate
def _repair_yaml_structure(self, content: str, error_msg: str) -> str:
"""Attempt to repair common YAML structural issues."""
lines = content.split('\n')
repaired_lines = []
# Track indentation levels to detect issues
expected_indent = 0
in_detection = False
detection_indent = 0
for i, line in enumerate(lines):
stripped = line.strip()
current_indent = len(line) - len(line.lstrip())
# Skip empty lines
if not stripped:
repaired_lines.append(line)
continue
# Track if we're in the detection section
if stripped.startswith('detection:'):
in_detection = True
detection_indent = current_indent
repaired_lines.append(line)
continue
elif in_detection and current_indent <= detection_indent and not stripped.startswith(('condition:', 'timeframe:')):
# We've left the detection section
in_detection = False
# Fix indentation issues in detection section
if in_detection:
# Ensure proper indentation for detection subsections
if stripped.startswith(('selection', 'filter', 'condition')):
# This should be indented under detection
if current_indent <= detection_indent:
corrected_line = ' ' * (detection_indent + 4) + stripped
repaired_lines.append(corrected_line)
continue
elif current_indent > detection_indent + 4:
# This might be a detection field that needs proper indentation
if ':' in stripped and not stripped.startswith('-'):
# This looks like a field under a selection
if i > 0 and 'selection' in lines[i-1]:
corrected_line = ' ' * (detection_indent + 8) + stripped
repaired_lines.append(corrected_line)
continue
# Fix lines that start with wrong indentation
if ':' in stripped and not stripped.startswith('-'):
# This is a key-value pair
key = stripped.split(':')[0].strip()
# Top-level keys should not be indented
if key in ['title', 'id', 'status', 'description', 'author', 'date', 'references', 'tags', 'logsource', 'detection', 'falsepositives', 'level']:
if current_indent > 0:
corrected_line = stripped
repaired_lines.append(corrected_line)
continue
repaired_lines.append(line)
return '\n'.join(repaired_lines)
def _build_minimal_valid_rule(self, content: str, fixes_applied: list) -> str:
"""Build a minimal valid SIGMA rule from the content."""
lines = content.split('\n')
# Extract key components
title = "Unknown SIGMA Rule"
rule_id = "00000000-0000-0000-0000-000000000000"
description = "Generated SIGMA rule"
for line in lines:
stripped = line.strip()
if stripped.startswith('title:'):
title = stripped.split(':', 1)[1].strip().strip('"\'')
elif stripped.startswith('id:'):
rule_id = stripped.split(':', 1)[1].strip().strip('"\'')
elif stripped.startswith('description:'):
description = stripped.split(':', 1)[1].strip().strip('"\'')
# Build minimal valid rule
minimal_rule = f"""title: '{title}'
id: {rule_id}
status: experimental
description: '{description}'
author: 'AI Generated'
date: 2025/01/14
references:
- https://example.com
logsource:
category: process_creation
detection:
selection:
Image: '*'
condition: selection
level: medium"""
fixes_applied.append("Built minimal valid SIGMA rule structure")
logger.warning("Generated minimal valid SIGMA rule as fallback")
return minimal_rule
def _fix_hallucinated_cve_id(self, rule_content: str, correct_cve_id: str) -> str:
"""Detect and fix hallucinated CVE IDs in the generated rule."""
import re
# Pattern to match CVE IDs (CVE-YYYY-NNNNN format)
cve_pattern = r'CVE-\d{4}-\d{4,7}'
# Find all CVE IDs in the rule content
found_cves = re.findall(cve_pattern, rule_content, re.IGNORECASE)
if found_cves:
# Check if any found CVE is different from the correct one
hallucinated_cves = [cve for cve in found_cves if cve.upper() != correct_cve_id.upper()]
if hallucinated_cves:
logger.error(f"CRITICAL: LLM hallucinated CVE IDs: {hallucinated_cves}, expected: {correct_cve_id}")
logger.error(f"This indicates the LLM is not following the prompt correctly!")
# Replace all hallucinated CVE IDs with the correct one
corrected_content = rule_content
for hallucinated_cve in set(hallucinated_cves): # Use set to avoid duplicates
corrected_content = re.sub(
re.escape(hallucinated_cve),
correct_cve_id,
corrected_content,
flags=re.IGNORECASE
)
logger.info(f"Successfully corrected hallucinated CVE IDs to {correct_cve_id}")
return corrected_content
else:
logger.info(f"CVE ID validation passed: found correct {correct_cve_id}")
else:
# No CVE ID found in rule - this might be an issue, but let's add it
logger.warning(f"No CVE ID found in generated rule for {correct_cve_id}, this might need manual review")
return rule_content
def _inject_cve_id_into_rule(self, rule_content: str, cve_id: str) -> str:
"""Inject CVE ID into a rule that lacks it."""
if not rule_content:
logger.warning(f"Empty rule content for {cve_id}, cannot inject CVE ID")
return rule_content
lines = rule_content.split('\n')
modified_lines = []
for i, line in enumerate(lines):
stripped = line.strip()
# Fix title line if it has placeholders
if stripped.startswith('title:'):
if '{cve_id}' in line.lower() or '{cve_description}' in line.lower():
# Replace with a proper title
modified_lines.append(f"title: 'Detection of {cve_id} exploitation'")
elif cve_id not in line:
# Add CVE ID to existing title
title_text = line.split(':', 1)[1].strip(' \'"')
modified_lines.append(f"title: '{cve_id}: {title_text}'")
else:
modified_lines.append(line)
# Fix references section if it has placeholders
elif stripped.startswith('- https://nvd.nist.gov/vuln/detail/') and '{cve_id}' in line:
modified_lines.append(f" - https://nvd.nist.gov/vuln/detail/{cve_id}")
# Skip lines with template placeholders
elif any(placeholder in line.lower() for placeholder in ['{cve_id}', '{cve_description}', '{poc_content}']):
continue
else:
modified_lines.append(line)
result = '\n'.join(modified_lines)
logger.info(f"Injected CVE ID {cve_id} into rule")
return result
async def enhance_existing_rule(self,
                                existing_rule: str,
                                poc_content: str,
                                cve_id: str) -> Optional[str]:
    """
    Enhance an existing SIGMA rule with PoC analysis.

    Args:
        existing_rule: Existing SIGMA rule YAML
        poc_content: PoC code content (only the first 3000 characters are
            sent to the model; None is treated as empty)
        cve_id: CVE identifier

    Returns:
        Enhanced SIGMA rule or None if the client is unavailable or any step
        fails (the failure is logged).
    """
    if not self.is_available():
        return None

    try:
        system_message = """You are a SIGMA rule enhancement expert. Analyze the following PoC code and enhance the existing SIGMA rule with more specific detection patterns.
**Task:** Enhance the existing rule by:
1. Adding more specific detection patterns found in the PoC
2. Improving the condition logic
3. Adding relevant tags or MITRE ATT&CK mappings
4. Keeping the rule structure intact but making it more effective
Output ONLY the enhanced SIGMA rule in valid YAML format."""

        user_template = """**CVE ID:** {cve_id}
**PoC Code:**
```
{poc_content}
```
**Existing SIGMA Rule:**
```yaml
{existing_rule}
```"""

        # ChatPromptTemplate passes message *instances* through verbatim, so
        # the human turn must be a ("human", template) tuple for
        # {cve_id}/{poc_content}/{existing_rule} to be substituted. The system
        # text has no variables and stays a literal SystemMessage so any
        # braces added to it later can never be parsed as template fields.
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=system_message),
            ("human", user_template),
        ])

        chain = prompt | self.llm | self.output_parser

        response = await chain.ainvoke({
            "cve_id": cve_id,
            "poc_content": (poc_content or "")[:3000],
            "existing_rule": existing_rule,
        })

        enhanced_rule = self._extract_sigma_rule(response)
        enhanced_rule = self._post_process_sigma_rule(enhanced_rule)

        logger.info(f"Successfully enhanced SIGMA rule for {cve_id}")
        return enhanced_rule

    except Exception as e:
        logger.error(f"Failed to enhance SIGMA rule for {cve_id}: {e}")
        return None
def validate_sigma_rule(self, rule_content: str, expected_cve_id: str = None) -> bool:
    """Check a generated rule against the SIGMA structure this project requires.

    The rule must parse with ``yaml.safe_load`` into a mapping that has
    ``title`` (a string of at most 256 characters), a mapping ``logsource``
    and a mapping ``detection``. A ``condition`` must exist in ``detection``
    or at the root. ``detection`` must contain a key that starts with
    ``selection`` or is ``keywords``/``filter``. A present, truthy ``status``
    must be one of the recognised values.

    When ``expected_cve_id`` is given, any CVE ID in the text that differs
    from it (ignoring case) fails validation; having no CVE ID at all only
    logs a warning.

    Every rejection is logged at WARNING level.

    Returns:
        True if all checks pass, otherwise False.
    """
    def reject(reason: str) -> bool:
        logger.warning(reason)
        return False

    try:
        rule = yaml.safe_load(rule_content)
        if not isinstance(rule, dict):
            return reject("Rule must be a YAML dictionary")

        missing = next(
            (name for name in ('title', 'logsource', 'detection') if name not in rule),
            None,
        )
        if missing is not None:
            return reject(f"Missing mandatory field: {missing}")

        title = rule.get('title', '')
        if not (isinstance(title, str) and len(title) <= 256):
            return reject("Title must be string ≤256 characters")

        if not isinstance(rule.get('logsource', {}), dict):
            return reject("Logsource must be a dictionary")

        detection = rule.get('detection', {})
        if not isinstance(detection, dict):
            return reject("Detection must be a dictionary")

        # The condition is accepted either inside detection or at the root.
        if 'condition' not in detection and 'condition' not in rule:
            return reject("Missing condition field")

        search_ids = (key for key in detection if key != 'condition')
        if not any(key.startswith('selection') or key in ('selection', 'keywords', 'filter')
                   for key in search_ids):
            return reject("Detection must have at least one selection")

        status = rule.get('status')
        if status and status not in ('stable', 'test', 'experimental', 'deprecated', 'unsupported'):
            return reject(f"Invalid status: {status}")

        if expected_cve_id:
            import re
            mentioned = re.findall(r'CVE-\d{4}-\d{4,7}', rule_content, re.IGNORECASE)
            if not mentioned:
                # A missing CVE ID is only worth a warning, not a rejection.
                logger.warning(f"Rule does not contain expected CVE ID: {expected_cve_id}")
            else:
                mismatched = [cve for cve in mentioned if cve.upper() != expected_cve_id.upper()]
                if mismatched:
                    return reject(f"Rule contains wrong CVE IDs: {mismatched}, expected {expected_cve_id}")

        logger.info("SIGMA rule validation passed")
        return True

    except yaml.YAMLError as exc:
        lowered = str(exc).lower()
        if "alias" in lowered and "*" in lowered:
            logger.warning(f"YAML alias syntax error (likely unquoted wildcard): {exc}")
        elif "expected" in lowered:
            logger.warning(f"YAML structure error: {exc}")
        else:
            logger.warning(f"YAML parsing error: {exc}")
        return False
    except Exception as exc:
        logger.warning(f"Rule validation error: {exc}")
        return False
@classmethod
def get_available_providers(cls) -> List[Dict[str, Any]]:
    """Describe each entry of ``SUPPORTED_PROVIDERS`` and whether it can be used.

    ``api_key_configured`` is True when the provider's ``env_key``
    environment variable is set to a non-empty value. ``available`` is that
    flag, except that ``ollama`` is always reported available.

    Returns:
        One dict per provider, in ``SUPPORTED_PROVIDERS`` order.
    """
    def summarize(name: str, info: Dict[str, Any]) -> Dict[str, Any]:
        env_key = info.get('env_key', '')
        key_present = bool(os.getenv(env_key))
        return {
            'name': name,
            'models': info.get('models', []),
            'default_model': info.get('default_model', ''),
            'env_key': env_key,
            'api_key_configured': key_present,
            'available': key_present or name == 'ollama',
        }

    return [summarize(name, info) for name, info in cls.SUPPORTED_PROVIDERS.items()]
def switch_provider(self, provider: str, model: str = None):
    """Point this client at another provider/model and re-initialise the LLM.

    Args:
        provider: Key of ``SUPPORTED_PROVIDERS``.
        model: Model name; when falsy, ``_get_default_model(provider)`` is used.

    Raises:
        ValueError: If ``provider`` is not in ``SUPPORTED_PROVIDERS``.
    """
    supported = self.SUPPORTED_PROVIDERS
    if provider not in supported:
        raise ValueError(f"Unsupported provider: {provider}")

    self.provider = provider
    self.model = model if model else self._get_default_model(provider)
    self._initialize_llm()

    logger.info(f"Switched to provider: {provider} with model: {self.model}")
def _check_ollama_model_available(self, base_url: str, model: str) -> bool:
"""Check if an Ollama model is available locally"""
try:
import requests
response = requests.get(f"{base_url}/api/tags", timeout=10)
if response.status_code == 200:
data = response.json()
models = data.get('models', [])
for m in models:
if m.get('name', '').startswith(model + ':') or m.get('name') == model:
return True
return False
except Exception as e:
logger.error(f"Error checking Ollama models: {e}")
return False
def _pull_ollama_model(self, base_url: str, model: str) -> bool:
"""Pull an Ollama model"""
try:
import requests
import json
# Use the pull API endpoint
payload = {"name": model}
response = requests.post(
f"{base_url}/api/pull",
json=payload,
timeout=300, # 5 minutes timeout for model download
stream=True
)
if response.status_code == 200:
# Stream the response to monitor progress
for line in response.iter_lines():
if line:
try:
data = json.loads(line.decode('utf-8'))
if data.get('status'):
logger.info(f"Ollama pull progress: {data.get('status')}")
if data.get('error'):
logger.error(f"Ollama pull error: {data.get('error')}")
return False
except json.JSONDecodeError:
continue
return True
else:
logger.error(f"Failed to pull model {model}: HTTP {response.status_code}")
return False
except Exception as e:
logger.error(f"Error pulling Ollama model {model}: {e}")
return False
def _get_ollama_available_models(self) -> List[str]:
"""Get list of available Ollama models"""
try:
import requests
base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
response = requests.get(f"{base_url}/api/tags", timeout=10)
if response.status_code == 200:
data = response.json()
models = data.get('models', [])
return [m.get('name', '') for m in models if m.get('name')]
return []
except Exception as e:
logger.error(f"Error getting Ollama models: {e}")
return []
async def test_connection(self) -> Dict[str, Any]:
    """Report whether the configured LLM provider is usable.

    * openai / anthropic: local check only. The provider's ``env_key`` (from
      ``SUPPORTED_PROVIDERS``) must be set and ``self.llm`` must be
      initialised; no request is sent to the provider.
    * ollama: a single GET to ``{OLLAMA_BASE_URL}/api/tags`` (default
      ``http://localhost:11434``). It runs in the default executor so the
      blocking ``requests`` call does not stall the event loop. The installed
      model list and the availability of ``self.model`` (exact name, or any
      ``<model>:<tag>``) both come from that one response.

    Returns:
        A status dict that always has ``available``, ``models`` and
        ``current_model``, and has ``error`` on failure (ollama success sets
        ``error`` to None). Hosted providers add ``has_api_key``; ollama adds
        ``base_url``.
    """
    try:
        if self.provider in ('openai', 'anthropic'):
            provider_info = self.SUPPORTED_PROVIDERS[self.provider]
            label = 'OpenAI' if self.provider == 'openai' else 'Anthropic'

            if not os.getenv(provider_info['env_key']):
                return {
                    "available": False,
                    "error": f"{label} API key not configured",
                    "models": [],
                    "current_model": self.model,
                    "has_api_key": False
                }

            # No live API call here, to avoid slow network timeouts.
            if self.llm:
                return {
                    "available": True,
                    "models": provider_info['models'],
                    "current_model": self.model,
                    "has_api_key": True
                }
            return {
                "available": False,
                "error": f"{label} client not initialized",
                "models": [],
                "current_model": self.model,
                "has_api_key": True
            }

        if self.provider == 'ollama':
            base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
            try:
                import asyncio
                import requests

                loop = asyncio.get_running_loop()
                response = await loop.run_in_executor(
                    None, lambda: requests.get(f"{base_url}/api/tags", timeout=10)
                )
                if response.status_code != 200:
                    return {
                        "available": False,
                        "error": f"Ollama server not responding (HTTP {response.status_code})",
                        "models": [],
                        "current_model": self.model,
                        "base_url": base_url
                    }

                try:
                    entries = response.json().get('models', [])
                    available_models = [m.get('name', '') for m in entries if m.get('name')]
                except (ValueError, AttributeError, TypeError) as e:
                    # A malformed listing is reported as "model not available",
                    # not as a connection failure.
                    logger.error(f"Error getting Ollama models: {e}")
                    available_models = []

                model_available = any(
                    name == self.model or name.startswith(f"{self.model}:")
                    for name in available_models
                )
                return {
                    "available": model_available,
                    "models": available_models,
                    "current_model": self.model,
                    "base_url": base_url,
                    "error": None if model_available else f"Model {self.model} not available"
                }
            except Exception as e:
                return {
                    "available": False,
                    "error": f"Cannot connect to Ollama server: {str(e)}",
                    "models": [],
                    "current_model": self.model,
                    "base_url": base_url
                }

        return {
            "available": False,
            "error": f"Unsupported provider: {self.provider}",
            "models": [],
            "current_model": self.model
        }

    except Exception as e:
        return {
            "available": False,
            "error": f"Connection test failed: {str(e)}",
            "models": [],
            "current_model": self.model
        }