1389 lines
No EOL
61 KiB
Python
1389 lines
No EOL
61 KiB
Python
"""
LangChain-based LLM client for enhanced SIGMA rule generation.

Supports multiple LLM providers: OpenAI, Anthropic, and local models served
through Ollama. The client builds prompts from a CVE description and
proof-of-concept code, then post-processes the model output into SIGMA YAML.
"""
|
|
import os
|
|
import logging
|
|
from typing import Optional, Dict, Any, List
|
|
from langchain_core.messages import HumanMessage, SystemMessage
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
from langchain_openai import ChatOpenAI
|
|
from langchain_anthropic import ChatAnthropic
|
|
from langchain_community.llms import Ollama
|
|
from langchain_core.output_parsers import StrOutputParser
|
|
import yaml
|
|
from cve2capec_client import CVE2CAPECClient
|
|
|
|
# Module-level logger used by the LLMClient methods for diagnostics.
logger = logging.getLogger(__name__)
|
|
|
|
class LLMClient:
|
|
"""Multi-provider LLM client for SIGMA rule generation using LangChain."""

# Provider catalogue. Each entry lists the selectable model names, the
# environment variable that configures the provider (an API key, or the
# server URL for Ollama), and the model used when none is requested.
SUPPORTED_PROVIDERS = dict(
    openai=dict(
        models=['gpt-4o', 'gpt-4o-mini', 'gpt-4-turbo', 'gpt-3.5-turbo'],
        env_key='OPENAI_API_KEY',
        default_model='gpt-4o-mini',
    ),
    anthropic=dict(
        models=[
            'claude-3-5-sonnet-20241022',
            'claude-3-haiku-20240307',
            'claude-3-opus-20240229',
        ],
        env_key='ANTHROPIC_API_KEY',
        default_model='claude-3-5-sonnet-20241022',
    ),
    ollama=dict(
        models=['llama3.2', 'codellama', 'mistral', 'llama2'],
        env_key='OLLAMA_BASE_URL',
        default_model='llama3.2',
    ),
)
|
|
|
|
def __init__(self, provider: Optional[str] = None, model: Optional[str] = None):
    """Initialize LLM client with specified provider and model.

    Args:
        provider: Provider key (see SUPPORTED_PROVIDERS). When None it is
            chosen by ``_detect_provider()`` from environment variables.
        model: Model name. When None the provider's ``default_model`` is used.

    A missing API key does not raise here: ``_initialize_llm`` logs a warning
    and leaves ``self.llm`` as None, so ``is_available()`` reports False.
    """
    self.provider = provider or self._detect_provider()
    self.model = model or self._get_default_model(self.provider)
    self.llm = None
    self.output_parser = StrOutputParser()
    self.cve2capec_client = CVE2CAPECClient()

    self._initialize_llm()
|
|
|
|
def _detect_provider(self) -> str:
|
|
"""Auto-detect available LLM provider based on environment variables."""
|
|
# Check for API keys in order of preference
|
|
if os.getenv('ANTHROPIC_API_KEY'):
|
|
return 'anthropic'
|
|
elif os.getenv('OPENAI_API_KEY'):
|
|
return 'openai'
|
|
elif os.getenv('OLLAMA_BASE_URL'):
|
|
return 'ollama'
|
|
else:
|
|
# Default to OpenAI if no keys found
|
|
return 'openai'
|
|
|
|
def _get_default_model(self, provider: str) -> str:
|
|
"""Get default model for the specified provider."""
|
|
return self.SUPPORTED_PROVIDERS.get(provider, {}).get('default_model', 'gpt-4o-mini')
|
|
|
|
def _initialize_llm(self):
|
|
"""Initialize the LLM based on provider and model."""
|
|
try:
|
|
if self.provider == 'openai':
|
|
api_key = os.getenv('OPENAI_API_KEY')
|
|
if not api_key:
|
|
logger.warning("OpenAI API key not found")
|
|
return
|
|
|
|
self.llm = ChatOpenAI(
|
|
model=self.model,
|
|
api_key=api_key,
|
|
temperature=0.1,
|
|
max_tokens=2000
|
|
)
|
|
|
|
elif self.provider == 'anthropic':
|
|
api_key = os.getenv('ANTHROPIC_API_KEY')
|
|
if not api_key:
|
|
logger.warning("Anthropic API key not found")
|
|
return
|
|
|
|
self.llm = ChatAnthropic(
|
|
model=self.model,
|
|
api_key=api_key,
|
|
temperature=0.1,
|
|
max_tokens=2000
|
|
)
|
|
|
|
elif self.provider == 'ollama':
|
|
base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
|
|
|
|
# Check if model is available, if not try to pull it
|
|
if not self._check_ollama_model_available(base_url, self.model):
|
|
logger.info(f"Model {self.model} not found, attempting to pull...")
|
|
if self._pull_ollama_model(base_url, self.model):
|
|
logger.info(f"Successfully pulled model {self.model}")
|
|
else:
|
|
logger.error(f"Failed to pull model {self.model}")
|
|
return
|
|
|
|
self.llm = Ollama(
|
|
model=self.model,
|
|
base_url=base_url,
|
|
temperature=0.1,
|
|
num_ctx=4096, # Context window size
|
|
top_p=0.9,
|
|
top_k=40
|
|
)
|
|
|
|
if self.llm:
|
|
logger.info(f"LLM client initialized: {self.provider} with model {self.model}")
|
|
else:
|
|
logger.error(f"Failed to initialize LLM client for provider: {self.provider}")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error initializing LLM client: {e}")
|
|
self.llm = None
|
|
|
|
def is_available(self) -> bool:
    """Report whether an LLM backend object was successfully constructed."""
    backend = self.llm
    return backend is not None
|
|
|
|
def get_provider_info(self) -> Dict[str, Any]:
    """Describe the active provider configuration.

    The returned dict contains:
    - the provider name and chosen model;
    - whether the LLM handle exists;
    - the catalogue's model list;
    - the models actually usable (for Ollama, the non-empty result of
      ``_get_ollama_available_models``);
    - the configuring environment variable, and whether it is set.
    """
    catalogue_entry = self.SUPPORTED_PROVIDERS.get(self.provider, {})
    env_var = catalogue_entry.get('env_key', '')

    usable_models = catalogue_entry.get('models', [])
    if self.provider == 'ollama':
        # Prefer the locally installed model list whenever one is reported.
        usable_models = self._get_ollama_available_models() or usable_models

    info = {}
    info['provider'] = self.provider
    info['model'] = self.model
    info['available'] = self.is_available()
    info['supported_models'] = catalogue_entry.get('models', [])
    info['available_models'] = usable_models
    info['env_key'] = env_var
    info['api_key_configured'] = bool(os.getenv(env_var))
    return info
|
|
|
|
async def generate_sigma_rule(self,
                              cve_id: str,
                              poc_content: str,
                              cve_description: str,
                              existing_rule: Optional[str] = None,
                              *,
                              timeout: float = 150) -> Optional[str]:
    """
    Generate or enhance a SIGMA rule using the configured LLM.

    Args:
        cve_id: CVE identifier
        poc_content: Proof-of-concept code content from GitHub (only the
            first 4000 characters are sent to the model). None is treated
            as empty.
        cve_description: CVE description from NVD. None is treated as empty.
        existing_rule: Optional existing SIGMA rule to enhance
        timeout: Seconds to wait for the LLM call before giving up
            (keyword-only; defaults to 150, i.e. 2.5 minutes).

    Returns:
        Generated SIGMA rule YAML content, or None if the client is
        unavailable, the call times out or fails, or post-processing leaves
        no usable rule.
    """
    import asyncio

    if not self.is_available():
        logger.warning("LLM client not available")
        return None

    # Normalise missing inputs so the slicing/len() calls below cannot raise.
    poc_content = poc_content or ""
    cve_description = cve_description or ""

    try:
        prompt = self._build_sigma_generation_prompt(
            cve_id, poc_content, cve_description, existing_rule
        )
        chain = prompt | self.llm | self.output_parser

        input_data = {
            "cve_id": cve_id,
            "poc_content": poc_content[:4000],  # Truncate to bound prompt size
            "cve_description": cve_description,
            "existing_rule": existing_rule or "None"
        }

        logger.info(f"Sending to LLM for {cve_id}: CVE={cve_id}, Description length={len(cve_description)}, PoC length={len(poc_content)}")
        # Content samples can be large and sensitive; keep them at DEBUG.
        logger.debug(f"CVE Description for {cve_id}: {cve_description[:200]}...")
        logger.debug(f"PoC Content sample for {cve_id}: {poc_content[:200]}...")
        logger.debug(f"Final prompt variables for {cve_id}: {list(input_data.keys())}")

        try:
            response = await asyncio.wait_for(chain.ainvoke(input_data), timeout=timeout)
        except asyncio.TimeoutError:
            logger.error(f"LLM request timed out for {cve_id}")
            return None
        except Exception as llm_error:
            logger.error(f"LLM generation error for {cve_id}: {llm_error}")
            return None

        logger.debug(f"Raw LLM response for {cve_id}: {response[:200]}...")

        # Extract, clean, repair, then enforce the requested CVE ID.
        sigma_rule = self._extract_sigma_rule(response)
        sigma_rule = self._post_process_sigma_rule(sigma_rule)
        sigma_rule = self._fix_yaml_syntax_errors(sigma_rule)
        sigma_rule = self._fix_hallucinated_cve_id(sigma_rule, cve_id)

        # Fallback: if no CVE ID survived, inject it into the rule.
        if not sigma_rule or 'CVE-' not in sigma_rule:
            sigma_rule = self._inject_cve_id_into_rule(sigma_rule, cve_id)

        if not sigma_rule:
            logger.warning(f"LLM produced no usable SIGMA rule for {cve_id}")
            return None

        logger.debug(f"Final processed rule for {cve_id}: {sigma_rule[:200]}...")
        logger.info(f"Successfully generated SIGMA rule for {cve_id} using {self.provider}")
        return sigma_rule

    except Exception as e:
        logger.exception(f"Failed to generate SIGMA rule for {cve_id} using {self.provider}: {e}")
        return None
|
|
|
|
def _build_sigma_generation_prompt(self,
                                   cve_id: str,
                                   poc_content: str,
                                   cve_description: str,
                                   existing_rule: Optional[str] = None) -> ChatPromptTemplate:
    """Build the prompt template for SIGMA rule generation.

    The returned template expects the input variables ``cve_id``,
    ``cve_description`` and ``poc_content``, plus ``existing_rule`` when an
    existing rule is being enhanced. Today's date is pre-filled as a partial
    variable.

    ChatPromptTemplate treats every ``{name}`` in message text as a template
    variable. Literal braces are therefore written as ``{{ }}``, and dynamic
    text spliced into a template is brace-escaped.

    Args:
        cve_id: CVE identifier. When no existing rule is given, it is also
            used here to look up MITRE ATT&CK / CWE mappings.
        poc_content: Not read while building the template (it is supplied
            later as a template variable); kept for interface compatibility.
        cve_description: Same as poc_content.
        existing_rule: When truthy, selects the "enhance this rule" template.

    Returns:
        A ChatPromptTemplate with a system and a human message.
    """
    from datetime import date

    today = date.today().strftime('%Y/%m/%d')

    def _escape_braces(text: str) -> str:
        # Prevents dynamic text from introducing template variables.
        return text.replace('{', '{{').replace('}', '}}')

    # NOTE: plain (non-f) string. {{...}} renders as literal braces, and
    # {today} is filled via .partial() below.
    system_message = """You are a cybersecurity expert specializing in SIGMA rule creation following the official SIGMA specification.

**OFFICIAL SIGMA RULE SPECIFICATION JSON SCHEMA:**
The official SIGMA rule specification (v2.0.0) defines these requirements:

**MANDATORY Fields (must include):**
- title: Brief description (max 256 chars) - string
- logsource: Log data source specification - object with category/product/service
- detection: Search identifiers and conditions - object with selections and condition

**RECOMMENDED Fields:**
- id: Unique UUID (version 4) - string with UUID format
- status: Rule state - enum: "stable", "test", "experimental", "deprecated", "unsupported"
- description: Detailed explanation - string
- author: Rule creator - string (use "AI Generated")
- date: Creation date - string in YYYY/MM/DD format
- modified: Last modification date - string in YYYY/MM/DD format
- references: Sources for rule derivation - array of strings (URLs)
- tags: MITRE ATT&CK techniques - array of strings
- level: Rule severity - enum: "informational", "low", "medium", "high", "critical"
- falsepositives: Known false positives - array of strings
- fields: Related fields - array of strings
- related: Related rules - array of objects with type and id

**YAML Structure Requirements:**
- Use UTF-8 encoding with LF line breaks
- Indent with 4 spaces (no tabs)
- Use lowercase keys only
- Use single quotes for string values
- No quotes for numeric values
- Follow proper YAML syntax

**Detection Structure:**
- Use selection blocks (selection, selection1, etc.)
- Condition references these selections
- Use proper field names (Image, CommandLine, ProcessName, etc.)
- Support wildcards (*) and value lists
- Condition can be string expression or object with keywords

**ABSOLUTE REQUIREMENTS:**
- Output ONLY valid YAML
- NO explanatory text before or after
- NO comments or instructions
- NO markdown formatting or code blocks
- NEVER repeat the input prompt or template
- NEVER include variables like {{cve_id}} or {{poc_content}}
- NO "Human:", "CVE ID:", "Description:" headers
- NO "Analyze this" or "Output only" text
- Start IMMEDIATELY with 'title:'
- End with the last YAML line only
- Ensure perfect YAML syntax

**STRUCTURE REQUIREMENTS:**
- title: Descriptive title that MUST include the exact CVE ID provided by the user
- id: Generate a unique UUID (not '12345678-1234-1234-1234-123456789012')
- status: experimental
- description: Specific description based on CVE and PoC analysis
- author: 'AI Generated'
- date: Current date ({today})
- references: Include the EXACT CVE URL with the CVE ID provided by the user
- tags: Relevant MITRE ATT&CK techniques based on PoC analysis
- logsource: Appropriate category based on exploit type
- detection: Specific indicators from PoC analysis (NOT generic examples)
- condition: Logic connecting the detection selections

**CRITICAL ANTI-HALLUCINATION RULES:**
1. You MUST use the EXACT CVE ID provided in the user input - NEVER generate a different CVE ID
2. NEVER use example CVE IDs like CVE-2022-1234, CVE-2023-5678, or CVE-2024-1234
3. NEVER use placeholder CVE IDs from your training data
4. Analyze the provided CVE description and PoC content to create SPECIFIC detection patterns
5. DO NOT hallucinate or invent CVE IDs from your training data
6. Use the CVE ID exactly as provided in the title and references
7. Generate rules based ONLY on the provided CVE description and PoC code analysis
8. Do not reference vulnerabilities or techniques not present in the provided content
9. CVE-2022-1234 is a FORBIDDEN example CVE ID - NEVER use it
10. The user will provide the EXACT CVE ID to use - use that and ONLY that"""

    if existing_rule:
        user_template = """CVE ID: {cve_id}
CVE Description: {cve_description}

PoC Code:
{poc_content}

Existing SIGMA Rule:
{existing_rule}

Enhance this rule with PoC insights. Output only valid SIGMA YAML starting with 'title:'."""
    else:
        # Get MITRE ATT&CK mappings for the CVE (tolerate missing keys/None).
        mitre_mappings = self.cve2capec_client.get_full_mapping_for_cve(cve_id) or {}
        techniques = mitre_mappings.get('mitre_techniques') or []
        cwe_codes = mitre_mappings.get('cwe_codes') or []

        mitre_suggestions = ""
        if techniques:
            technique_details = "\n".join(
                f"  - {tech}: {self.cve2capec_client.get_technique_name(tech)}"
                for tech in techniques
            )
            mitre_suggestions = f"""
**MITRE ATT&CK TECHNIQUE MAPPINGS FOR {cve_id}:**
{technique_details}

**IMPORTANT:** Use these exact MITRE ATT&CK techniques in your tags section. Convert them to lowercase attack.t format (e.g., T1059 becomes attack.t1059)."""

        if cwe_codes:
            mitre_suggestions += f"""

**CWE MAPPINGS:** {', '.join(str(code) for code in cwe_codes)}"""

        # The mapping text is spliced into the template, so escape its braces.
        mitre_block = _escape_braces(mitre_suggestions)

        # f-string: {{cve_id}} etc. become template variables {cve_id}.
        user_template = f"""CREATE A SPECIFIC SIGMA RULE FOR THIS EXACT CVE:

**MANDATORY CVE ID TO USE: {{cve_id}}**
**CVE Description: {{cve_description}}**

**Proof-of-Concept Code Analysis:**
{{poc_content}}

{mitre_block}

**CRITICAL REQUIREMENTS:**
1. Use EXACTLY this CVE ID in the title: {{cve_id}}
2. Use EXACTLY this CVE URL in references: https://nvd.nist.gov/vuln/detail/{{cve_id}}
3. Analyze the CVE description to understand the vulnerability type
4. Extract specific indicators from the PoC code (files, processes, commands, network patterns)
5. Create detection logic based on the actual exploit behavior
6. Use relevant logsource category (process_creation, file_event, network_connection, etc.)
7. Include the MITRE ATT&CK tags listed above in your tags section (convert to attack.t format)

**CRITICAL ANTI-HALLUCINATION REQUIREMENTS:**
- THE CVE ID IS: {{cve_id}}
- DO NOT use CVE-2022-1234, CVE-2023-1234, CVE-2024-1234, or any other example CVE ID
- DO NOT generate a different CVE ID from your training data
- You MUST use the exact CVE ID "{{cve_id}}" - this is the ONLY acceptable CVE ID for this rule
- Base your analysis ONLY on the provided CVE description and PoC code above
- Do not reference other vulnerabilities or exploits not mentioned in the provided content
- NEVER use placeholder CVE IDs like CVE-YYYY-NNNN or CVE-2022-1234

**ABSOLUTE REQUIREMENT: THE EXACT CVE ID TO USE IS: {{cve_id}}**
**FORBIDDEN: Do not use CVE-2022-1234, CVE-2023-5678, or any other example CVE ID**

Output ONLY valid SIGMA YAML starting with 'title:' that includes the exact CVE ID {{cve_id}}."""

    prompt_template = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("human", user_template)
    ])

    # Fill the date now so callers only supply the CVE/PoC variables.
    return prompt_template.partial(today=today)
|
|
|
|
def _extract_sigma_rule(self, response_text: str) -> str:
|
|
"""Extract and clean SIGMA rule YAML from LLM response."""
|
|
lines = response_text.split('\n')
|
|
yaml_lines = []
|
|
in_yaml_block = False
|
|
found_title = False
|
|
|
|
for line in lines:
|
|
stripped = line.strip()
|
|
|
|
# Skip code block markers
|
|
if stripped.startswith('```'):
|
|
if stripped.startswith('```yaml'):
|
|
in_yaml_block = True
|
|
elif stripped == '```' and in_yaml_block:
|
|
break
|
|
continue
|
|
|
|
# Skip obvious non-YAML content
|
|
if not in_yaml_block and not found_title:
|
|
if not stripped.startswith('title:'):
|
|
# Skip explanatory text and prompt artifacts
|
|
skip_phrases = [
|
|
'please note', 'this rule', 'you should', 'analysis:',
|
|
'explanation:', 'based on', 'the following', 'here is',
|
|
'note that', 'important:', 'remember', 'this is a',
|
|
'make sure to', 'you can modify', 'adjust the',
|
|
'human:', 'cve id:', 'cve description:', 'poc code:',
|
|
'exploit code:', 'analyze this', 'create a', 'output only'
|
|
]
|
|
if any(phrase in stripped.lower() for phrase in skip_phrases):
|
|
continue
|
|
|
|
# Skip template variables and prompt artifacts
|
|
if '{' in stripped and '}' in stripped:
|
|
continue
|
|
|
|
# Skip lines that contain template placeholder text
|
|
if 'cve_id' in stripped.lower() or 'cve description' in stripped.lower():
|
|
continue
|
|
|
|
# Skip lines that are clearly not YAML structure
|
|
if stripped and not ':' in stripped and len(stripped) > 20:
|
|
continue
|
|
|
|
# Start collecting when we find title or are in YAML block
|
|
if stripped.startswith('title:') or in_yaml_block:
|
|
found_title = True
|
|
in_yaml_block = True
|
|
|
|
# Skip explanatory comments
|
|
if stripped.startswith('#') and ('please' in stripped.lower() or 'note' in stripped.lower()):
|
|
continue
|
|
|
|
yaml_lines.append(line)
|
|
|
|
# Stop if we encounter obvious non-YAML after starting
|
|
elif found_title:
|
|
# Stop at explanatory text after the rule
|
|
stop_phrases = [
|
|
'please note', 'this rule should', 'make sure to',
|
|
'you can modify', 'adjust the', 'also, this is',
|
|
'based on the analysis', 'the rule above'
|
|
]
|
|
if any(phrase in stripped.lower() for phrase in stop_phrases):
|
|
break
|
|
|
|
# Stop at lines without colons that aren't indented (likely explanations)
|
|
if stripped and not stripped.startswith(' ') and ':' not in stripped and '-' not in stripped:
|
|
break
|
|
|
|
if not yaml_lines:
|
|
# Fallback: look for any line with YAML-like structure
|
|
for line in lines:
|
|
if ':' in line and not line.strip().startswith('#'):
|
|
yaml_lines.append(line)
|
|
|
|
return '\n'.join(yaml_lines).strip()
|
|
|
|
def _post_process_sigma_rule(self, rule_content: str) -> str:
|
|
"""Post-process SIGMA rule to ensure clean YAML format."""
|
|
lines = rule_content.split('\n')
|
|
cleaned_lines = []
|
|
|
|
for line in lines:
|
|
stripped = line.strip()
|
|
|
|
# Skip obvious non-YAML content and prompt artifacts
|
|
if any(phrase in stripped.lower() for phrase in [
|
|
'please note', 'you should replace', 'this is a proof-of-concept',
|
|
'please make sure', 'note that', 'important:', 'remember to',
|
|
'analysis shows', 'based on the', 'the rule above', 'this rule',
|
|
'human:', 'cve id:', 'cve description:', 'poc code:', 'exploit code:',
|
|
'analyze this', 'create a', 'output only', 'generate a'
|
|
]):
|
|
continue
|
|
|
|
# Skip template variables and placeholder text
|
|
if '{' in stripped and '}' in stripped:
|
|
continue
|
|
|
|
# Skip lines that contain template placeholder patterns
|
|
if any(placeholder in stripped.lower() for placeholder in [
|
|
'cve_id', 'cve description', 'poc_content', 'existing_rule',
|
|
'{cve_id}', '{cve_description}', '{poc_content}'
|
|
]):
|
|
continue
|
|
|
|
# Skip lines that look like explanations
|
|
if stripped and not ':' in stripped and not stripped.startswith('-') and not stripped.startswith(' '):
|
|
# This might be explanatory text, skip it
|
|
if any(word in stripped.lower() for word in ['rule', 'detect', 'should', 'will', 'can', 'may']):
|
|
continue
|
|
|
|
# Skip empty explanatory sections
|
|
if stripped.lower() in ['explanation:', 'analysis:', 'notes:', 'important:', '']:
|
|
continue
|
|
|
|
cleaned_lines.append(line)
|
|
|
|
return '\n'.join(cleaned_lines).strip()
|
|
|
|
def _fix_yaml_syntax_errors(self, rule_content: str) -> str:
|
|
"""Fix common YAML syntax errors in LLM-generated rules."""
|
|
import re
|
|
|
|
if not rule_content:
|
|
return rule_content
|
|
|
|
lines = rule_content.split('\n')
|
|
fixed_lines = []
|
|
fixes_applied = []
|
|
|
|
for line in lines:
|
|
fixed_line = line
|
|
|
|
# Fix invalid YAML alias syntax: - *image* -> - '*image*'
|
|
# YAML aliases must be alphanumeric, but LLM uses *word* or *multiple words* for wildcards
|
|
if '- *' in line and '*' in line:
|
|
# Match patterns like "- *image*" or "- *process*" or "- *unpatched system*"
|
|
pattern = r'(\s*-\s*)(\*[^*]+\*)'
|
|
if re.search(pattern, line):
|
|
fixed_line = re.sub(pattern, r"\1'\2'", line)
|
|
fixes_applied.append(f"Fixed invalid YAML alias syntax: {line.strip()} -> {fixed_line.strip()}")
|
|
|
|
# Also fix similar patterns in values: key: *value* -> key: '*value*'
|
|
elif re.search(r':\s*\*[^*]+\*\s*$', line) and not re.search(r'[\'"]', line):
|
|
pattern = r'(:\s*)(\*[^*]+\*)'
|
|
fixed_line = re.sub(pattern, r"\1'\2'", line)
|
|
fixes_applied.append(f"Fixed invalid YAML alias in value: {line.strip()} -> {fixed_line.strip()}")
|
|
|
|
# Fix unquoted strings that start with special characters
|
|
elif re.match(r'^\s*-\s*[*&|>]', line):
|
|
# If line starts with -, followed by special YAML chars, quote it
|
|
parts = line.split('-', 1)
|
|
if len(parts) == 2:
|
|
indent = parts[0]
|
|
content = parts[1].strip()
|
|
if content and not content.startswith(("'", '"')):
|
|
fixed_line = f"{indent}- '{content}'"
|
|
fixes_applied.append(f"Quoted special character value: {line.strip()} -> {fixed_line.strip()}")
|
|
|
|
# Fix invalid boolean values
|
|
elif ': *' in line and not line.strip().startswith('#'):
|
|
# Replace ": *something*" with ": '*something*'" if not already quoted
|
|
pattern = r'(:\s*)(\*[^*]+\*)'
|
|
if not re.search(r'[\'"]', line) and re.search(pattern, line):
|
|
fixed_line = re.sub(pattern, r"\1'\2'", line)
|
|
fixes_applied.append(f"Fixed unquoted wildcard value: {line.strip()} -> {fixed_line.strip()}")
|
|
|
|
# Fix missing quotes around values with special characters (but not YAML indicators)
|
|
elif re.search(r':\s*[*&]', line) and not re.search(r'[\'"]', line):
|
|
# Don't quote YAML multiline indicators (|, >)
|
|
if not re.search(r':\s*[|>]\s*$', line):
|
|
parts = line.split(':', 1)
|
|
if len(parts) == 2:
|
|
key = parts[0]
|
|
value = parts[1].strip()
|
|
if value and not value.startswith(("'", '"', '[', '{')):
|
|
fixed_line = f"{key}: '{value}'"
|
|
fixes_applied.append(f"Quoted special character value: {line.strip()} -> {fixed_line.strip()}")
|
|
|
|
# Fix invalid array syntax
|
|
elif re.search(r'^\s*\*[^*]+\*\s*$', line):
|
|
# Standalone *word* or *multiple words* lines should be quoted
|
|
indent = len(line) - len(line.lstrip())
|
|
content = line.strip()
|
|
fixed_line = f"{' ' * indent}'{content}'"
|
|
fixes_applied.append(f"Fixed standalone wildcard: {line.strip()} -> {fixed_line.strip()}")
|
|
|
|
fixed_lines.append(fixed_line)
|
|
|
|
result = '\n'.join(fixed_lines)
|
|
|
|
# Additional fixes for common YAML issues
|
|
# Fix missing spaces after colons
|
|
colon_fix = re.sub(r':([^\s])', r': \1', result)
|
|
if colon_fix != result:
|
|
fixes_applied.append("Added missing spaces after colons")
|
|
result = colon_fix
|
|
|
|
# Fix multiple spaces after colons
|
|
space_fix = re.sub(r':\s{2,}', ': ', result)
|
|
if space_fix != result:
|
|
fixes_applied.append("Fixed multiple spaces after colons")
|
|
result = space_fix
|
|
|
|
# Fix incorrect reference format: references: - https://... -> references:\n - https://...
|
|
ref_fix = re.sub(r'references:\s*-\s*', 'references:\n - ', result)
|
|
if ref_fix != result:
|
|
fixes_applied.append("Fixed references array format")
|
|
result = ref_fix
|
|
|
|
# Fix broken URLs in references (spaces in URLs)
|
|
url_fix = re.sub(r'https:\s*//nvd\.nist\.gov', 'https://nvd.nist.gov', result)
|
|
if url_fix != result:
|
|
fixes_applied.append("Fixed broken URLs in references")
|
|
result = url_fix
|
|
|
|
# Fix incorrect logsource format: logsource: category: X -> logsource:\n category: X
|
|
logsource_fix = re.sub(r'logsource:\s*(category|product|service):\s*', r'logsource:\n \1: ', result)
|
|
if logsource_fix != result:
|
|
fixes_applied.append("Fixed logsource structure format")
|
|
result = logsource_fix
|
|
|
|
# Fix incorrect detection format: detection: selection: key: value -> detection:\n selection:\n key: value
|
|
detection_fix = re.sub(r'detection:\s*(\w+):\s*(\w+):\s*', r'detection:\n \1:\n \2: ', result)
|
|
if detection_fix != result:
|
|
fixes_applied.append("Fixed detection structure format")
|
|
result = detection_fix
|
|
|
|
# Fix detection lines with == operators: detection: selection1: image == *value* -> detection:\n selection1:\n image: '*value*'
|
|
# This handles compressed syntax with equality operators
|
|
# Make the pattern more flexible to catch various formats
|
|
detection_eq_patterns = [
|
|
(r'detection:\s*(\w+):\s*(\w+)\s*==\s*(\*[^*\s]+\*)', r'detection:\n \1:\n \2: \'\3\''),
|
|
(r'detection:\s*(\w+):\s*(\w+)\s*==\s*([^\s]+)', r'detection:\n \1:\n \2: \'\3\''),
|
|
]
|
|
|
|
for pattern, replacement in detection_eq_patterns:
|
|
detection_eq_fix = re.sub(pattern, replacement, result)
|
|
if detection_eq_fix != result:
|
|
fixes_applied.append("Fixed detection equality operator syntax")
|
|
result = detection_eq_fix
|
|
break
|
|
|
|
# Fix standalone equality operators in detection sections: key == *value* -> key: '*value*'
|
|
# Also handle lines with multiple keys/values separated by colons and ==
|
|
lines = result.split('\n')
|
|
eq_fixed_lines = []
|
|
for line in lines:
|
|
original_line = line
|
|
|
|
# Look for pattern: whitespace + key == *value* or key == value
|
|
if ' == ' in line:
|
|
# Handle complex patterns like "detection: selection1: image == *value*"
|
|
if line.strip().startswith('detection:') and ' == ' in line:
|
|
# Split by colons to handle nested structure
|
|
parts = line.split(':')
|
|
if len(parts) >= 3:
|
|
# This looks like "detection: selection1: image == *value*"
|
|
base_indent = len(line) - len(line.lstrip())
|
|
|
|
# Extract the parts
|
|
detection_part = parts[0].strip() # "detection"
|
|
selection_part = parts[1].strip() # "selection1"
|
|
key_value_part = ':'.join(parts[2:]).strip() # "image == *value*"
|
|
|
|
# Parse the key == value part
|
|
if ' == ' in key_value_part:
|
|
eq_parts = key_value_part.split(' == ', 1)
|
|
key = eq_parts[0].strip()
|
|
value = eq_parts[1].strip()
|
|
|
|
# Quote the value if needed
|
|
if value.startswith('*') and value.endswith('*') and not value.startswith("'"):
|
|
value = f"'{value}'"
|
|
elif not value.startswith(("'", '"', '[', '{')):
|
|
value = f"'{value}'"
|
|
|
|
# Reconstruct as proper YAML
|
|
eq_fixed_lines.append(f"{' ' * base_indent}detection:")
|
|
eq_fixed_lines.append(f"{' ' * (base_indent + 4)}{selection_part}:")
|
|
eq_fixed_lines.append(f"{' ' * (base_indent + 8)}{key}: {value}")
|
|
fixes_applied.append(f"Fixed complex detection equality: {selection_part}: {key} == {value}")
|
|
continue
|
|
|
|
# Handle simpler patterns: " key == value"
|
|
elif re.match(r'^(\s+)(\w+)\s*==\s*(.+)$', line):
|
|
match = re.match(r'^(\s+)(\w+)\s*==\s*(.+)$', line)
|
|
indent = match.group(1)
|
|
key = match.group(2)
|
|
value = match.group(3).strip()
|
|
|
|
# Ensure wildcards are quoted
|
|
if value.startswith('*') and value.endswith('*') and not value.startswith("'"):
|
|
value = f"'{value}'"
|
|
elif not value.startswith(("'", '"', '[', '{')):
|
|
value = f"'{value}'"
|
|
|
|
eq_fixed_lines.append(f"{indent}{key}: {value}")
|
|
fixes_applied.append(f"Fixed equality operator: {key} == {value}")
|
|
continue
|
|
|
|
eq_fixed_lines.append(original_line)
|
|
|
|
if len(eq_fixed_lines) != len(lines):
|
|
result = '\n'.join(eq_fixed_lines)
|
|
|
|
# Fix invalid array-as-value syntax: key: - value -> key:\n - value
|
|
# This handles cases like "CommandLine: - '*image*'" which should be "CommandLine:\n - '*image*'"
|
|
lines = result.split('\n')
|
|
fixed_lines = []
|
|
for line in lines:
|
|
# Look for pattern: whitespace + key: - value
|
|
if re.match(r'^(\s+)(\w+):\s*-\s*(.+)$', line):
|
|
match = re.match(r'^(\s+)(\w+):\s*-\s*(.+)$', line)
|
|
indent = match.group(1)
|
|
key = match.group(2)
|
|
value = match.group(3)
|
|
# Convert to proper array format
|
|
fixed_lines.append(f"{indent}{key}:")
|
|
fixed_lines.append(f"{indent} - {value}")
|
|
fixes_applied.append(f"Fixed array-as-value syntax: {key}: - {value}")
|
|
else:
|
|
fixed_lines.append(line)
|
|
|
|
if len(fixed_lines) != len(lines):
|
|
result = '\n'.join(fixed_lines)
|
|
|
|
# Fix complex nested syntax errors like "selection1: Image: - '*path*': value"
|
|
# This should be "selection1:\n Image:\n - '*path*': value"
|
|
complex_fix = re.sub(r'^(\s+)(\w+):\s*(\w+):\s*-\s*(.+)$',
|
|
r'\1\2:\n\1 \3:\n\1 - \4',
|
|
result, flags=re.MULTILINE)
|
|
if complex_fix != result:
|
|
fixes_applied.append("Fixed complex nested structure syntax")
|
|
result = complex_fix
|
|
|
|
# Fix incorrect tags format: tags: - T1059.001 -> tags:\n - T1059.001
|
|
tags_fix = re.sub(r'tags:\s*-\s*', 'tags:\n - ', result)
|
|
if tags_fix != result:
|
|
fixes_applied.append("Fixed tags array format")
|
|
result = tags_fix
|
|
|
|
# Fix other common single-line array formats
|
|
for field in ['falsepositives', 'level', 'related']:
|
|
field_pattern = f'{field}:\\s*-\\s*'
|
|
field_replacement = f'{field}:\n - '
|
|
field_fix = re.sub(field_pattern, field_replacement, result)
|
|
if field_fix != result:
|
|
fixes_applied.append(f"Fixed {field} array format")
|
|
result = field_fix
|
|
|
|
# Fix placeholder UUID if LLM used the example one
|
|
import uuid
|
|
placeholder_uuid = '12345678-1234-1234-1234-123456789012'
|
|
if placeholder_uuid in result:
|
|
new_uuid = str(uuid.uuid4())
|
|
result = result.replace(placeholder_uuid, new_uuid)
|
|
fixes_applied.append(f"Replaced placeholder UUID with {new_uuid[:8]}...")
|
|
|
|
# Fix orphaned list items (standalone lines starting with -)
|
|
lines = result.split('\n')
|
|
fixed_lines = []
|
|
|
|
for i, line in enumerate(lines):
|
|
stripped = line.strip()
|
|
|
|
# Check for orphaned list items (lines starting with - but not part of an array)
|
|
if (stripped.startswith('- ') and
|
|
i > 0 and
|
|
not lines[i-1].strip().endswith(':') and
|
|
':' not in stripped and
|
|
not stripped.startswith('- https://')): # Don't remove reference URLs
|
|
|
|
# Check if this looks like a MITRE ATT&CK tag
|
|
if re.match(r'- T\d{4}', stripped):
|
|
# Try to find the tags section and add it there
|
|
tags_line_found = False
|
|
for j in range(len(fixed_lines)-1, -1, -1):
|
|
if fixed_lines[j].strip().startswith('tags:'):
|
|
# This is an orphaned tag, add it to the tags array
|
|
fixed_lines.append(f" {stripped}")
|
|
fixes_applied.append(f"Fixed orphaned MITRE tag: {stripped}")
|
|
tags_line_found = True
|
|
break
|
|
|
|
if not tags_line_found:
|
|
# No tags section found, remove the orphaned item
|
|
fixes_applied.append(f"Removed orphaned tag (no tags section): {stripped}")
|
|
continue
|
|
else:
|
|
# Other orphaned list items, remove them
|
|
fixes_applied.append(f"Removed orphaned list item: {stripped}")
|
|
continue
|
|
|
|
fixed_lines.append(line)
|
|
|
|
result = '\n'.join(fixed_lines)
|
|
|
|
# Final pass: Remove lines that are still malformed and would cause YAML parsing errors
|
|
lines = result.split('\n')
|
|
final_lines = []
|
|
for line in lines:
|
|
stripped = line.strip()
|
|
|
|
# Skip lines that have multiple colons in problematic patterns
|
|
if re.search(r':\s*\w+:\s*-\s*[\'"][^\'":]*[\'"]:\s*', line):
|
|
# This looks like "key: subkey: - 'value': more_stuff" which is malformed
|
|
fixes_applied.append(f"Removed malformed nested line: {stripped[:50]}...")
|
|
continue
|
|
|
|
# Skip lines with invalid YAML mapping structures
|
|
if re.search(r'^\s*\w+:\s*\w+:\s*-\s*[\'"][^\'":]*[\'"]:\s*\w+', line):
|
|
fixes_applied.append(f"Removed invalid mapping structure: {stripped[:50]}...")
|
|
continue
|
|
|
|
final_lines.append(line)
|
|
|
|
if len(final_lines) != len(lines):
|
|
result = '\n'.join(final_lines)
|
|
|
|
# Log if we made any fixes
|
|
if fixes_applied:
|
|
logger.info(f"Applied YAML syntax fixes: {', '.join(fixes_applied)}")
|
|
|
|
# Final YAML structure validation and repair
|
|
result = self._validate_and_repair_yaml_structure(result, fixes_applied)
|
|
|
|
return result
|
|
|
|
def _validate_and_repair_yaml_structure(self, content: str, fixes_applied: list) -> str:
    """Return *content* if it parses as YAML, otherwise a repaired version.

    Strategy, tried in order:

    1. ``yaml.safe_load(content)`` succeeds -> return *content* unchanged.
    2. Run ``self._repair_yaml_structure`` and re-parse; on success record
       ``"Repaired YAML document structure"`` in *fixes_applied* and return
       the repaired text.
    3. Otherwise fall back to ``self._build_minimal_valid_rule``.

    Args:
        content: Candidate SIGMA rule YAML text.
        fixes_applied: Running log of applied fixes; appended to in place.

    Returns:
        YAML text (original, repaired, or the minimal fallback rule).
    """
    parse_error_text = None
    try:
        yaml.safe_load(content)
    except yaml.YAMLError as exc:
        logger.warning(f"YAML structure validation failed: {exc}")
        # The exception name is unbound after the except clause, so keep its text.
        parse_error_text = str(exc)

    if parse_error_text is None:
        # Parsed cleanly: nothing to repair.
        return content

    candidate = self._repair_yaml_structure(content, parse_error_text)

    try:
        yaml.safe_load(candidate)
    except yaml.YAMLError as exc:
        logger.warning(f"YAML repair attempt failed: {exc}")
        # Last resort: salvage what we can into a minimal rule.
        return self._build_minimal_valid_rule(content, fixes_applied)

    fixes_applied.append("Repaired YAML document structure")
    logger.info("Successfully repaired YAML structure")
    return candidate
|
|
|
|
def _repair_yaml_structure(self, content: str, error_msg: str) -> str:
|
|
"""Attempt to repair common YAML structural issues."""
|
|
lines = content.split('\n')
|
|
repaired_lines = []
|
|
|
|
# Track indentation levels to detect issues
|
|
expected_indent = 0
|
|
in_detection = False
|
|
detection_indent = 0
|
|
|
|
for i, line in enumerate(lines):
|
|
stripped = line.strip()
|
|
current_indent = len(line) - len(line.lstrip())
|
|
|
|
# Skip empty lines
|
|
if not stripped:
|
|
repaired_lines.append(line)
|
|
continue
|
|
|
|
# Track if we're in the detection section
|
|
if stripped.startswith('detection:'):
|
|
in_detection = True
|
|
detection_indent = current_indent
|
|
repaired_lines.append(line)
|
|
continue
|
|
elif in_detection and current_indent <= detection_indent and not stripped.startswith(('condition:', 'timeframe:')):
|
|
# We've left the detection section
|
|
in_detection = False
|
|
|
|
# Fix indentation issues in detection section
|
|
if in_detection:
|
|
# Ensure proper indentation for detection subsections
|
|
if stripped.startswith(('selection', 'filter', 'condition')):
|
|
# This should be indented under detection
|
|
if current_indent <= detection_indent:
|
|
corrected_line = ' ' * (detection_indent + 4) + stripped
|
|
repaired_lines.append(corrected_line)
|
|
continue
|
|
elif current_indent > detection_indent + 4:
|
|
# This might be a detection field that needs proper indentation
|
|
if ':' in stripped and not stripped.startswith('-'):
|
|
# This looks like a field under a selection
|
|
if i > 0 and 'selection' in lines[i-1]:
|
|
corrected_line = ' ' * (detection_indent + 8) + stripped
|
|
repaired_lines.append(corrected_line)
|
|
continue
|
|
|
|
# Fix lines that start with wrong indentation
|
|
if ':' in stripped and not stripped.startswith('-'):
|
|
# This is a key-value pair
|
|
key = stripped.split(':')[0].strip()
|
|
|
|
# Top-level keys should not be indented
|
|
if key in ['title', 'id', 'status', 'description', 'author', 'date', 'references', 'tags', 'logsource', 'detection', 'falsepositives', 'level']:
|
|
if current_indent > 0:
|
|
corrected_line = stripped
|
|
repaired_lines.append(corrected_line)
|
|
continue
|
|
|
|
repaired_lines.append(line)
|
|
|
|
return '\n'.join(repaired_lines)
|
|
|
|
def _build_minimal_valid_rule(self, content: str, fixes_applied: list) -> str:
    """Build a minimal, parseable SIGMA rule as a last-resort fallback.

    Only ``title``, ``id`` and ``description`` are salvaged from *content*.
    Each line is scanned for ``key: value``, and the last usable occurrence
    wins. Everything else is a fixed placeholder: a ``process_creation``
    log source with a match-all ``Image: '*'`` selection.

    Salvaged values are emitted as YAML single-quoted scalars with embedded
    quotes escaped, so odd titles or descriptions cannot break the output.

    Args:
        content: The (unparseable) rule text to salvage fields from.
        fixes_applied: Running log of applied fixes; appended to in place.

    Returns:
        The fallback rule as YAML text.
    """
    salvaged = {
        'title': "Unknown SIGMA Rule",
        'id': "00000000-0000-0000-0000-000000000000",
        'description': "Generated SIGMA rule",
    }
    # An empty value, or a bare block-scalar indicator (whose real text is
    # on the following lines), cannot be salvaged from one line. Keep the
    # default in that case.
    unusable_values = {'', '|', '>', '|-', '>-', '|+', '>+'}

    for line in content.split('\n'):
        key, sep, value = line.strip().partition(':')
        if not sep or key not in salvaged:
            continue
        value = value.strip().strip('"\'')
        if value not in unusable_values:
            salvaged[key] = value

    def quoted(text: str) -> str:
        # YAML single-quoted scalar: an embedded ' is written as ''.
        return "'" + text.replace("'", "''") + "'"

    minimal_rule = f"""title: {quoted(salvaged['title'])}
id: {quoted(salvaged['id'])}
status: experimental
description: {quoted(salvaged['description'])}
author: 'AI Generated'
date: 2025/01/14
references:
    - https://example.com
logsource:
    category: process_creation
detection:
    selection:
        Image: '*'
    condition: selection
level: medium"""

    fixes_applied.append("Built minimal valid SIGMA rule structure")
    logger.warning("Generated minimal valid SIGMA rule as fallback")
    return minimal_rule
|
|
|
|
def _fix_hallucinated_cve_id(self, rule_content: str, correct_cve_id: str) -> str:
    """Replace every CVE ID in *rule_content* that differs from *correct_cve_id*.

    CVE IDs are matched case-insensitively with ``CVE-YYYY-NNNN..NNNNNNN``.
    IDs equal (ignoring case) to *correct_cve_id* are left untouched. All
    others are rewritten to *correct_cve_id*.

    Substitution is done in a single pass over the regex matches. This stops
    a short wrong ID from also rewriting a longer ID that merely starts with
    it. For example, fixing ``CVE-2021-4422`` must not turn a correct
    ``CVE-2021-44228`` into ``CVE-2021-442288``.

    Args:
        rule_content: Generated SIGMA rule text.
        correct_cve_id: The CVE ID the rule is supposed to reference.

    Returns:
        The corrected rule text, or *rule_content* unchanged when no wrong
        IDs (or no IDs at all) were found.
    """
    import re

    # Same pattern as the CVE check in validate_sigma_rule.
    cve_pattern = re.compile(r'CVE-\d{4}-\d{4,7}', re.IGNORECASE)

    found_cves = cve_pattern.findall(rule_content)
    if not found_cves:
        # Nothing to correct. The ID is not injected here; only flag it.
        logger.warning(f"No CVE ID found in generated rule for {correct_cve_id}, this might need manual review")
        return rule_content

    expected = correct_cve_id.upper()
    hallucinated_cves = [cve for cve in found_cves if cve.upper() != expected]
    if not hallucinated_cves:
        logger.info(f"CVE ID validation passed: found correct {correct_cve_id}")
        return rule_content

    logger.error(f"CRITICAL: LLM hallucinated CVE IDs: {hallucinated_cves}, expected: {correct_cve_id}")
    logger.error("This indicates the LLM is not following the prompt correctly!")

    def _replace(match):
        cve = match.group(0)
        return cve if cve.upper() == expected else correct_cve_id

    corrected_content = cve_pattern.sub(_replace, rule_content)

    logger.info(f"Successfully corrected hallucinated CVE IDs to {correct_cve_id}")
    return corrected_content
|
|
|
|
def _inject_cve_id_into_rule(self, rule_content: str, cve_id: str) -> str:
    """Make sure *cve_id* appears in the rule and strip leftover template placeholders.

    Line by line:

    * a ``title:`` line containing ``{cve_id}``/``{cve_description}`` is
      replaced by ``title: 'Detection of <cve_id> exploitation'``;
    * a ``title:`` line lacking *cve_id* gets it prefixed. The title is
      re-emitted as a single-quoted YAML scalar with embedded quotes escaped;
    * an NVD reference line containing ``{cve_id}`` is rewritten to the real
      NVD URL, keeping its original indentation;
    * any other line containing ``{cve_id}``, ``{cve_description}`` or
      ``{poc_content}`` (case-insensitive) is dropped.

    Args:
        rule_content: Generated SIGMA rule text.
        cve_id: CVE identifier to inject.

    Returns:
        The modified rule text, or *rule_content* unchanged (with a warning)
        if it is empty.
    """
    if not rule_content:
        logger.warning(f"Empty rule content for {cve_id}, cannot inject CVE ID")
        return rule_content

    placeholders = ('{cve_id}', '{cve_description}', '{poc_content}')
    modified_lines = []

    for line in rule_content.split('\n'):
        stripped = line.strip()
        lowered = line.lower()

        if stripped.startswith('title:'):
            if '{cve_id}' in lowered or '{cve_description}' in lowered:
                # Unfilled template title: replace it with a proper one.
                modified_lines.append(f"title: 'Detection of {cve_id} exploitation'")
            elif cve_id not in line:
                title_text = line.split(':', 1)[1].strip(' \'"')
                # Re-escape for a single-quoted YAML scalar. Collapse any
                # existing '' escapes first so they are not doubled again.
                title_text = title_text.replace("''", "'").replace("'", "''")
                modified_lines.append(f"title: '{cve_id}: {title_text}'")
            else:
                modified_lines.append(line)

        elif stripped.startswith('- https://nvd.nist.gov/vuln/detail/') and '{cve_id}' in line:
            # Keep the list item's original indentation so it stays inside
            # the ``references`` sequence.
            indent = line[:len(line) - len(line.lstrip())]
            modified_lines.append(f"{indent}- https://nvd.nist.gov/vuln/detail/{cve_id}")

        elif any(placeholder in lowered for placeholder in placeholders):
            # Drop any other line that still carries an unfilled placeholder.
            continue

        else:
            modified_lines.append(line)

    result = '\n'.join(modified_lines)
    logger.info(f"Injected CVE ID {cve_id} into rule")
    return result
|
|
|
|
async def enhance_existing_rule(self,
                                existing_rule: str,
                                poc_content: str,
                                cve_id: str) -> Optional[str]:
    """
    Enhance an existing SIGMA rule with PoC analysis.

    The prompt is built from ``(role, template)`` tuples so that LangChain
    treats the messages as templates and substitutes ``{cve_id}``,
    ``{poc_content}`` and ``{existing_rule}``. The LLM output is passed
    through ``_extract_sigma_rule`` and ``_post_process_sigma_rule``.

    Args:
        existing_rule: Existing SIGMA rule YAML
        poc_content: PoC code content (only the first 3000 characters are sent)
        cve_id: CVE identifier

    Returns:
        Enhanced SIGMA rule, or None if the client is unavailable or any
        step fails (the error is logged).
    """
    if not self.is_available():
        return None

    try:
        system_message = """You are a SIGMA rule enhancement expert. Analyze the following PoC code and enhance the existing SIGMA rule with more specific detection patterns.

**Task:** Enhance the existing rule by:
1. Adding more specific detection patterns found in the PoC
2. Improving the condition logic
3. Adding relevant tags or MITRE ATT&CK mappings
4. Keeping the rule structure intact but making it more effective

Output ONLY the enhanced SIGMA rule in valid YAML format."""

        user_template = """**CVE ID:** {cve_id}

**PoC Code:**
```
{poc_content}
```

**Existing SIGMA Rule:**
```yaml
{existing_rule}
```"""

        # Use (role, template) tuples, NOT SystemMessage/HumanMessage
        # instances: ChatPromptTemplate passes message objects through
        # verbatim, so their "{...}" placeholders would never be filled in
        # and the LLM would receive the literal template text.
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_message),
            ("human", user_template),
        ])

        chain = prompt | self.llm | self.output_parser

        response = await chain.ainvoke({
            "cve_id": cve_id,
            # Truncate the PoC to keep the prompt bounded.
            "poc_content": poc_content[:3000],
            "existing_rule": existing_rule
        })

        enhanced_rule = self._extract_sigma_rule(response)
        enhanced_rule = self._post_process_sigma_rule(enhanced_rule)
        logger.info(f"Successfully enhanced SIGMA rule for {cve_id}")
        return enhanced_rule

    except Exception as e:
        logger.error(f"Failed to enhance SIGMA rule for {cve_id}: {e}")
        return None
|
|
|
|
def validate_sigma_rule(self, rule_content: str, expected_cve_id: str = None) -> bool:
    """Return True if *rule_content* passes basic SIGMA structure checks.

    Checks run in this order. The first failure is logged as a warning and
    returns False:

    1. the text parses (``yaml.safe_load``) to a dict;
    2. ``title``, ``logsource`` and ``detection`` are present;
    3. ``title`` is a string of at most 256 characters;
    4. ``logsource`` and ``detection`` are dicts;
    5. a ``condition`` key exists in ``detection`` or at the document root;
    6. ``detection`` has a key (other than ``condition``) that starts with
       ``selection`` or is ``keywords``/``filter``. NOTE: other search
       identifier names are not recognised by this check;
    7. ``status``, if set, is one of stable/test/experimental/deprecated/
       unsupported;
    8. if *expected_cve_id* is given, every ``CVE-YYYY-NNNN..`` ID in the
       text matches it case-insensitively. A rule with no CVE ID only
       triggers a warning.

    YAML errors and any other exception are logged and yield False.
    """
    allowed_statuses = ('stable', 'test', 'experimental', 'deprecated', 'unsupported')

    try:
        rule = yaml.safe_load(rule_content)

        if not isinstance(rule, dict):
            logger.warning("Rule must be a YAML dictionary")
            return False

        missing_fields = [name for name in ('title', 'logsource', 'detection') if name not in rule]
        if missing_fields:
            logger.warning(f"Missing mandatory field: {missing_fields[0]}")
            return False

        title = rule.get('title', '')
        if not (isinstance(title, str) and len(title) <= 256):
            logger.warning("Title must be string ≤256 characters")
            return False

        if not isinstance(rule.get('logsource', {}), dict):
            logger.warning("Logsource must be a dictionary")
            return False

        detection = rule.get('detection', {})
        if not isinstance(detection, dict):
            logger.warning("Detection must be a dictionary")
            return False

        # The condition is accepted either inside detection or at the root.
        if 'condition' not in detection and 'condition' not in rule:
            logger.warning("Missing condition field")
            return False

        identifiers = (key for key in detection.keys() if key != 'condition')
        if not any(key.startswith('selection') or key in ('selection', 'keywords', 'filter')
                   for key in identifiers):
            logger.warning("Detection must have at least one selection")
            return False

        status = rule.get('status')
        if status and status not in allowed_statuses:
            logger.warning(f"Invalid status: {status}")
            return False

        if expected_cve_id:
            import re
            found_cves = re.findall(r'CVE-\d{4}-\d{4,7}', rule_content, re.IGNORECASE)

            if not found_cves:
                # Not fatal: the rule is still structurally valid.
                logger.warning(f"Rule does not contain expected CVE ID: {expected_cve_id}")
            else:
                expected = expected_cve_id.upper()
                wrong_cves = [cve for cve in found_cves if cve.upper() != expected]
                if wrong_cves:
                    logger.warning(f"Rule contains wrong CVE IDs: {wrong_cves}, expected {expected_cve_id}")
                    return False

        logger.info("SIGMA rule validation passed")
        return True

    except yaml.YAMLError as e:
        message = str(e)
        lowered = message.lower()
        if "alias" in lowered and "*" in message:
            label = "YAML alias syntax error (likely unquoted wildcard)"
        elif "expected" in lowered:
            label = "YAML structure error"
        else:
            label = "YAML parsing error"
        logger.warning(f"{label}: {e}")
        return False

    except Exception as e:
        logger.warning(f"Rule validation error: {e}")
        return False
|
|
|
|
@classmethod
def get_available_providers(cls) -> List[Dict[str, Any]]:
    """Describe every entry of ``SUPPORTED_PROVIDERS`` and whether it looks usable.

    A provider counts as configured when the environment variable named by
    its ``env_key`` is set to a non-empty value. ``ollama`` is always
    reported as available.

    Returns:
        One dict per provider, in ``SUPPORTED_PROVIDERS`` order, with keys
        ``name``, ``models``, ``default_model``, ``env_key``,
        ``api_key_configured`` and ``available``.
    """
    def describe(name: str, info: Dict[str, Any]) -> Dict[str, Any]:
        env_key = info.get('env_key', '')
        configured = bool(os.getenv(env_key))
        return {
            'name': name,
            'models': info.get('models', []),
            'default_model': info.get('default_model', ''),
            'env_key': env_key,
            'api_key_configured': configured,
            # ollama needs no key, so it is always offered.
            'available': configured or name == 'ollama',
        }

    return [describe(name, info) for name, info in cls.SUPPORTED_PROVIDERS.items()]
|
|
|
|
def switch_provider(self, provider: str, model: str = None):
    """Point this client at *provider* (and optionally *model*) and rebuild the LLM.

    When *model* is falsy, the provider's default model
    (``self._get_default_model``) is used. The client's state is left
    untouched if *provider* is rejected.

    Raises:
        ValueError: if *provider* is not a key of ``SUPPORTED_PROVIDERS``.
    """
    if provider in self.SUPPORTED_PROVIDERS:
        self.provider = provider
        self.model = model if model else self._get_default_model(provider)
        self._initialize_llm()
        logger.info(f"Switched to provider: {provider} with model: {self.model}")
        return

    raise ValueError(f"Unsupported provider: {provider}")
|
|
|
|
def _check_ollama_model_available(self, base_url: str, model: str) -> bool:
    """Return True if the Ollama server at *base_url* lists *model*.

    Queries ``GET {base_url}/api/tags``. A listed ``name`` matches when it
    equals *model* exactly or is *model* followed by ``:`` and a tag.
    A non-200 reply yields False. Any exception is logged and yields False.
    """
    try:
        import requests
        reply = requests.get(f"{base_url}/api/tags", timeout=10)
        if reply.status_code != 200:
            return False

        listed = reply.json().get('models', [])
        return any(
            entry.get('name', '').startswith(model + ':') or entry.get('name') == model
            for entry in listed
        )
    except Exception as e:
        logger.error(f"Error checking Ollama models: {e}")
        return False
|
|
|
|
def _pull_ollama_model(self, base_url: str, model: str) -> bool:
    """Pull (download) *model* on the Ollama server at *base_url*.

    Posts to ``/api/pull`` and streams the newline-delimited JSON progress
    objects. Each ``status`` is logged at INFO, and an ``error`` entry aborts
    the pull. Lines that are not valid JSON are skipped.

    Returns:
        True if the server answered HTTP 200 and the stream ended without an
        ``error`` entry. False on any other HTTP status, an ``error`` entry,
        or any exception. All failures are logged.
    """
    try:
        import requests
        import json

        payload = {"name": model}
        # With stream=True the timeout applies to connecting and to each wait
        # for the next chunk, not to the total download time. The ``with``
        # block releases the streamed connection on every exit path (early
        # return, error entry, or exception).
        with requests.post(
            f"{base_url}/api/pull",
            json=payload,
            timeout=300,
            stream=True
        ) as response:
            if response.status_code != 200:
                logger.error(f"Failed to pull model {model}: HTTP {response.status_code}")
                return False

            # Stream the response to monitor progress.
            for line in response.iter_lines():
                if not line:
                    continue
                try:
                    data = json.loads(line.decode('utf-8'))
                except json.JSONDecodeError:
                    continue
                if data.get('status'):
                    logger.info(f"Ollama pull progress: {data.get('status')}")
                if data.get('error'):
                    logger.error(f"Ollama pull error: {data.get('error')}")
                    return False
            return True

    except Exception as e:
        logger.error(f"Error pulling Ollama model {model}: {e}")
        return False
|
|
|
|
def _get_ollama_available_models(self) -> List[str]:
    """Return the model names listed by the Ollama server.

    The server URL comes from ``OLLAMA_BASE_URL`` (default
    ``http://localhost:11434``). Entries without a truthy ``name`` are
    skipped. A non-200 reply yields ``[]``. Any exception is logged and
    yields ``[]``.
    """
    try:
        import requests
        base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
        reply = requests.get(f"{base_url}/api/tags", timeout=10)
        if reply.status_code != 200:
            return []

        names = []
        for entry in reply.json().get('models', []):
            if entry.get('name'):
                names.append(entry.get('name', ''))
        return names
    except Exception as e:
        logger.error(f"Error getting Ollama models: {e}")
        return []
|
|
|
|
async def test_connection(self) -> Dict[str, Any]:
    """Test connection to the configured LLM provider.

    For ``openai`` and ``anthropic`` no network request is made. The result
    only reflects whether the provider's API-key environment variable is set
    and whether ``self.llm`` has been initialized.

    For ``ollama``, the server's ``/api/tags`` endpoint is queried once. That
    single response is used both to list the installed models and to decide
    whether ``self.model`` is among them. A listed name matches if it equals
    ``self.model`` or is ``self.model`` plus a ``:tag`` suffix.

    Returns:
        A status dict. It always contains ``available``, ``models`` and
        ``current_model``. ``error`` is set on failure, and is ``None`` when
        the Ollama server is reachable and has the model. Hosted providers
        add ``has_api_key``; Ollama adds ``base_url``.
    """
    # provider -> (label used in error messages, API-key environment variable)
    hosted_providers = {
        'openai': ('OpenAI', 'OPENAI_API_KEY'),
        'anthropic': ('Anthropic', 'ANTHROPIC_API_KEY'),
    }

    try:
        if self.provider in hosted_providers:
            label, env_key = hosted_providers[self.provider]
            if not os.getenv(env_key):
                return {
                    "available": False,
                    "error": f"{label} API key not configured",
                    "models": [],
                    "current_model": self.model,
                    "has_api_key": False
                }

            # Deliberately no API call here, to avoid slow or timeout-prone probes.
            if self.llm:
                return {
                    "available": True,
                    "models": self.SUPPORTED_PROVIDERS[self.provider]['models'],
                    "current_model": self.model,
                    "has_api_key": True
                }
            return {
                "available": False,
                "error": f"{label} client not initialized",
                "models": [],
                "current_model": self.model,
                "has_api_key": True
            }

        if self.provider == 'ollama':
            base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
            try:
                import requests
                # NOTE: requests is synchronous, so this blocks the event loop
                # for up to the 10 s timeout. Only one request is made.
                response = requests.get(f"{base_url}/api/tags", timeout=10)
                if response.status_code != 200:
                    return {
                        "available": False,
                        "error": f"Ollama server not responding (HTTP {response.status_code})",
                        "models": [],
                        "current_model": self.model,
                        "base_url": base_url
                    }

                available_models = [
                    m.get('name', '') for m in response.json().get('models', []) if m.get('name')
                ]
                model_available = any(
                    name == self.model or name.startswith(f"{self.model}:")
                    for name in available_models
                )

                return {
                    "available": model_available,
                    "models": available_models,
                    "current_model": self.model,
                    "base_url": base_url,
                    "error": None if model_available else f"Model {self.model} not available"
                }
            except Exception as e:
                return {
                    "available": False,
                    "error": f"Cannot connect to Ollama server: {str(e)}",
                    "models": [],
                    "current_model": self.model,
                    "base_url": base_url
                }

        return {
            "available": False,
            "error": f"Unsupported provider: {self.provider}",
            "models": [],
            "current_model": self.model
        }

    except Exception as e:
        return {
            "available": False,
            "error": f"Connection test failed: {str(e)}",
            "models": [],
            "current_model": self.model
        }