""" LangChain-based LLM client for enhanced SIGMA rule generation. Supports multiple LLM providers: OpenAI, Anthropic, and local models. """ import os import logging from typing import Optional, Dict, Any, List from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.prompts import ChatPromptTemplate from langchain_openai import ChatOpenAI from langchain_anthropic import ChatAnthropic from langchain_community.llms import Ollama from langchain_core.output_parsers import StrOutputParser import yaml from cve2capec_client import CVE2CAPECClient logger = logging.getLogger(__name__) class LLMClient: """Multi-provider LLM client for SIGMA rule generation using LangChain.""" SUPPORTED_PROVIDERS = { 'openai': { 'models': ['gpt-4o', 'gpt-4o-mini', 'gpt-4-turbo', 'gpt-3.5-turbo'], 'env_key': 'OPENAI_API_KEY', 'default_model': 'gpt-4o-mini' }, 'anthropic': { 'models': ['claude-3-5-sonnet-20241022', 'claude-3-haiku-20240307', 'claude-3-opus-20240229'], 'env_key': 'ANTHROPIC_API_KEY', 'default_model': 'claude-3-5-sonnet-20241022' }, 'ollama': { 'models': ['llama3.2', 'codellama', 'mistral', 'llama2'], 'env_key': 'OLLAMA_BASE_URL', 'default_model': 'llama3.2' } } def __init__(self, provider: str = None, model: str = None): """Initialize LLM client with specified provider and model.""" self.provider = provider or self._detect_provider() self.model = model or self._get_default_model(self.provider) self.llm = None self.output_parser = StrOutputParser() self.cve2capec_client = CVE2CAPECClient() self._initialize_llm() def _detect_provider(self) -> str: """Auto-detect available LLM provider based on environment variables.""" # Check for API keys in order of preference if os.getenv('ANTHROPIC_API_KEY'): return 'anthropic' elif os.getenv('OPENAI_API_KEY'): return 'openai' elif os.getenv('OLLAMA_BASE_URL'): return 'ollama' else: # Default to OpenAI if no keys found return 'openai' def _get_default_model(self, provider: str) -> str: """Get default model for the specified provider.""" return self.SUPPORTED_PROVIDERS.get(provider, {}).get('default_model', 'gpt-4o-mini') def _initialize_llm(self): """Initialize the LLM based on provider and model.""" try: if self.provider == 'openai': api_key = os.getenv('OPENAI_API_KEY') if not api_key: logger.warning("OpenAI API key not found") return self.llm = ChatOpenAI( model=self.model, api_key=api_key, temperature=0.1, max_tokens=2000 ) elif self.provider == 'anthropic': api_key = os.getenv('ANTHROPIC_API_KEY') if not api_key: logger.warning("Anthropic API key not found") return self.llm = ChatAnthropic( model=self.model, api_key=api_key, temperature=0.1, max_tokens=2000 ) elif self.provider == 'ollama': base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434') # Check if model is available, if not try to pull it if not self._check_ollama_model_available(base_url, self.model): logger.info(f"Model {self.model} not found, attempting to pull...") if self._pull_ollama_model(base_url, self.model): logger.info(f"Successfully pulled model {self.model}") else: logger.error(f"Failed to pull model {self.model}") return self.llm = Ollama( model=self.model, base_url=base_url, temperature=0.1, num_ctx=4096, # Context window size top_p=0.9, top_k=40 ) if self.llm: logger.info(f"LLM client initialized: {self.provider} with model {self.model}") else: logger.error(f"Failed to initialize LLM client for provider: {self.provider}") except Exception as e: logger.error(f"Error initializing LLM client: {e}") self.llm = None def is_available(self) -> bool: """Check if LLM client is available and configured.""" return self.llm is not None def get_provider_info(self) -> Dict[str, Any]: """Get information about the current provider and configuration.""" provider_info = self.SUPPORTED_PROVIDERS.get(self.provider, {}) # For Ollama, get actually available models available_models = provider_info.get('models', []) if self.provider == 'ollama': ollama_models = self._get_ollama_available_models() if ollama_models: available_models = ollama_models return { 'provider': self.provider, 'model': self.model, 'available': self.is_available(), 'supported_models': provider_info.get('models', []), 'available_models': available_models, 'env_key': provider_info.get('env_key', ''), 'api_key_configured': bool(os.getenv(provider_info.get('env_key', ''))) } async def generate_sigma_rule(self, cve_id: str, poc_content: str, cve_description: str, existing_rule: Optional[str] = None) -> Optional[str]: """ Generate or enhance a SIGMA rule using the configured LLM. Args: cve_id: CVE identifier poc_content: Proof-of-concept code content from GitHub cve_description: CVE description from NVD existing_rule: Optional existing SIGMA rule to enhance Returns: Generated SIGMA rule YAML content or None if failed """ if not self.is_available(): logger.warning("LLM client not available") return None try: # Create the prompt template prompt = self._build_sigma_generation_prompt( cve_id, poc_content, cve_description, existing_rule ) # Create the chain chain = prompt | self.llm | self.output_parser # Debug: Log what we're sending to the LLM input_data = { "cve_id": cve_id, "poc_content": poc_content[:4000], # Truncate if too long "cve_description": cve_description, "existing_rule": existing_rule or "None" } logger.info(f"Sending to LLM for {cve_id}: CVE={cve_id}, Description length={len(cve_description)}, PoC length={len(poc_content)}") logger.info(f"CVE Description for {cve_id}: {cve_description[:200]}...") logger.info(f"PoC Content sample for {cve_id}: {poc_content[:200]}...") # Generate the response with timeout handling logger.info(f"Final prompt variables for {cve_id}: {list(input_data.keys())}") import asyncio try: # Add timeout wrapper around the LLM call response = await asyncio.wait_for( chain.ainvoke(input_data), timeout=150 # 2.5 minutes total timeout ) except asyncio.TimeoutError: logger.error(f"LLM request timed out for {cve_id}") return None except Exception as llm_error: logger.error(f"LLM generation error for {cve_id}: {llm_error}") return None # Debug: Log raw LLM response logger.info(f"Raw LLM response for {cve_id}: {response[:200]}...") # Extract the SIGMA rule from response sigma_rule = self._extract_sigma_rule(response) # Post-process to ensure clean YAML sigma_rule = self._post_process_sigma_rule(sigma_rule) # Fix common YAML syntax errors sigma_rule = self._fix_yaml_syntax_errors(sigma_rule) # CRITICAL: Validate and fix CVE ID hallucination sigma_rule = self._fix_hallucinated_cve_id(sigma_rule, cve_id) # Additional fallback: If no CVE ID found, inject it into the rule if not sigma_rule or 'CVE-' not in sigma_rule: sigma_rule = self._inject_cve_id_into_rule(sigma_rule, cve_id) # Debug: Log final processed rule logger.info(f"Final processed rule for {cve_id}: {sigma_rule[:200]}...") logger.info(f"Successfully generated SIGMA rule for {cve_id} using {self.provider}") return sigma_rule except Exception as e: logger.error(f"Failed to generate SIGMA rule for {cve_id} using {self.provider}: {e}") return None def _build_sigma_generation_prompt(self, cve_id: str, poc_content: str, cve_description: str, existing_rule: Optional[str] = None) -> ChatPromptTemplate: """Build the prompt template for SIGMA rule generation.""" system_message = """You are a cybersecurity expert specializing in SIGMA rule creation following the official SIGMA specification. **OFFICIAL SIGMA RULE SPECIFICATION JSON SCHEMA:** The official SIGMA rule specification (v2.0.0) defines these requirements: **MANDATORY Fields (must include):** - title: Brief description (max 256 chars) - string - logsource: Log data source specification - object with category/product/service - detection: Search identifiers and conditions - object with selections and condition **RECOMMENDED Fields:** - id: Unique UUID (version 4) - string with UUID format - status: Rule state - enum: "stable", "test", "experimental", "deprecated", "unsupported" - description: Detailed explanation - string - author: Rule creator - string (use "AI Generated") - date: Creation date - string in YYYY/MM/DD format - modified: Last modification date - string in YYYY/MM/DD format - references: Sources for rule derivation - array of strings (URLs) - tags: MITRE ATT&CK techniques - array of strings - level: Rule severity - enum: "informational", "low", "medium", "high", "critical" - falsepositives: Known false positives - array of strings - fields: Related fields - array of strings - related: Related rules - array of objects with type and id **YAML Structure Requirements:** - Use UTF-8 encoding with LF line breaks - Indent with 4 spaces (no tabs) - Use lowercase keys only - Use single quotes for string values - No quotes for numeric values - Follow proper YAML syntax **Detection Structure:** - Use selection blocks (selection, selection1, etc.) - Condition references these selections - Use proper field names (Image, CommandLine, ProcessName, etc.) - Support wildcards (*) and value lists - Condition can be string expression or object with keywords **ABSOLUTE REQUIREMENTS:** - Output ONLY valid YAML - NO explanatory text before or after - NO comments or instructions - NO markdown formatting or code blocks - NEVER repeat the input prompt or template - NEVER include variables like {cve_id} or {poc_content} - NO "Human:", "CVE ID:", "Description:" headers - NO "Analyze this" or "Output only" text - Start IMMEDIATELY with 'title:' - End with the last YAML line only - Ensure perfect YAML syntax **STRUCTURE REQUIREMENTS:** - title: Descriptive title that MUST include the exact CVE ID provided by the user - id: Generate a unique UUID (not '12345678-1234-1234-1234-123456789012') - status: experimental - description: Specific description based on CVE and PoC analysis - author: 'AI Generated' - date: Current date (2025/01/14) - references: Include the EXACT CVE URL with the CVE ID provided by the user - tags: Relevant MITRE ATT&CK techniques based on PoC analysis - logsource: Appropriate category based on exploit type - detection: Specific indicators from PoC analysis (NOT generic examples) - condition: Logic connecting the detection selections **CRITICAL ANTI-HALLUCINATION RULES:** 1. You MUST use the EXACT CVE ID provided in the user input - NEVER generate a different CVE ID 2. NEVER use example CVE IDs like CVE-2022-1234, CVE-2023-5678, or CVE-2024-1234 3. NEVER use placeholder CVE IDs from your training data 4. Analyze the provided CVE description and PoC content to create SPECIFIC detection patterns 5. DO NOT hallucinate or invent CVE IDs from your training data 6. Use the CVE ID exactly as provided in the title and references 7. Generate rules based ONLY on the provided CVE description and PoC code analysis 8. Do not reference vulnerabilities or techniques not present in the provided content 9. CVE-2022-1234 is a FORBIDDEN example CVE ID - NEVER use it 10. The user will provide the EXACT CVE ID to use - use that and ONLY that""" if existing_rule: user_template = """CVE ID: {cve_id} CVE Description: {cve_description} PoC Code: {poc_content} Existing SIGMA Rule: {existing_rule} Enhance this rule with PoC insights. Output only valid SIGMA YAML starting with 'title:'.""" else: # Get MITRE ATT&CK mappings for the CVE mitre_mappings = self.cve2capec_client.get_full_mapping_for_cve(cve_id) mitre_suggestions = "" if mitre_mappings['mitre_techniques']: technique_details = [] for tech in mitre_mappings['mitre_techniques']: tech_name = self.cve2capec_client.get_technique_name(tech) technique_details.append(f" - {tech}: {tech_name}") mitre_suggestions = f""" **MITRE ATT&CK TECHNIQUE MAPPINGS FOR {cve_id}:** {chr(10).join(technique_details)} **IMPORTANT:** Use these exact MITRE ATT&CK techniques in your tags section. Convert them to lowercase attack.t format (e.g., T1059 becomes attack.t1059).""" if mitre_mappings['cwe_codes']: mitre_suggestions += f""" **CWE MAPPINGS:** {', '.join(mitre_mappings['cwe_codes'])}""" user_template = f"""CREATE A SPECIFIC SIGMA RULE FOR THIS EXACT CVE: **MANDATORY CVE ID TO USE: {{cve_id}}** **CVE Description: {{cve_description}}** **Proof-of-Concept Code Analysis:** {{poc_content}} {mitre_suggestions} **CRITICAL REQUIREMENTS:** 1. Use EXACTLY this CVE ID in the title: {{cve_id}} 2. Use EXACTLY this CVE URL in references: https://nvd.nist.gov/vuln/detail/{{cve_id}} 3. Analyze the CVE description to understand the vulnerability type 4. Extract specific indicators from the PoC code (files, processes, commands, network patterns) 5. Create detection logic based on the actual exploit behavior 6. Use relevant logsource category (process_creation, file_event, network_connection, etc.) 7. Include the MITRE ATT&CK tags listed above in your tags section (convert to attack.t format) **CRITICAL ANTI-HALLUCINATION REQUIREMENTS:** - THE CVE ID IS: {{cve_id}} - DO NOT use CVE-2022-1234, CVE-2023-1234, CVE-2024-1234, or any other example CVE ID - DO NOT generate a different CVE ID from your training data - You MUST use the exact CVE ID "{{cve_id}}" - this is the ONLY acceptable CVE ID for this rule - Base your analysis ONLY on the provided CVE description and PoC code above - Do not reference other vulnerabilities or exploits not mentioned in the provided content - NEVER use placeholder CVE IDs like CVE-YYYY-NNNN or CVE-2022-1234 **ABSOLUTE REQUIREMENT: THE EXACT CVE ID TO USE IS: {{cve_id}}** **FORBIDDEN: Do not use CVE-2022-1234, CVE-2023-5678, or any other example CVE ID** Output ONLY valid SIGMA YAML starting with 'title:' that includes the exact CVE ID {{cve_id}}.""" # Create the prompt template with proper variable definitions prompt_template = ChatPromptTemplate.from_messages([ ("system", system_message), ("human", user_template) ]) return prompt_template def _extract_sigma_rule(self, response_text: str) -> str: """Extract and clean SIGMA rule YAML from LLM response.""" lines = response_text.split('\n') yaml_lines = [] in_yaml_block = False found_title = False for line in lines: stripped = line.strip() # Skip code block markers if stripped.startswith('```'): if stripped.startswith('```yaml'): in_yaml_block = True elif stripped == '```' and in_yaml_block: break continue # Skip obvious non-YAML content if not in_yaml_block and not found_title: if not stripped.startswith('title:'): # Skip explanatory text and prompt artifacts skip_phrases = [ 'please note', 'this rule', 'you should', 'analysis:', 'explanation:', 'based on', 'the following', 'here is', 'note that', 'important:', 'remember', 'this is a', 'make sure to', 'you can modify', 'adjust the', 'human:', 'cve id:', 'cve description:', 'poc code:', 'exploit code:', 'analyze this', 'create a', 'output only' ] if any(phrase in stripped.lower() for phrase in skip_phrases): continue # Skip template variables and prompt artifacts if '{' in stripped and '}' in stripped: continue # Skip lines that contain template placeholder text if 'cve_id' in stripped.lower() or 'cve description' in stripped.lower(): continue # Skip lines that are clearly not YAML structure if stripped and not ':' in stripped and len(stripped) > 20: continue # Start collecting when we find title or are in YAML block if stripped.startswith('title:') or in_yaml_block: found_title = True in_yaml_block = True # Skip explanatory comments if stripped.startswith('#') and ('please' in stripped.lower() or 'note' in stripped.lower()): continue yaml_lines.append(line) # Stop if we encounter obvious non-YAML after starting elif found_title: # Stop at explanatory text after the rule stop_phrases = [ 'please note', 'this rule should', 'make sure to', 'you can modify', 'adjust the', 'also, this is', 'based on the analysis', 'the rule above' ] if any(phrase in stripped.lower() for phrase in stop_phrases): break # Stop at lines without colons that aren't indented (likely explanations) if stripped and not stripped.startswith(' ') and ':' not in stripped and '-' not in stripped: break if not yaml_lines: # Fallback: look for any line with YAML-like structure for line in lines: if ':' in line and not line.strip().startswith('#'): yaml_lines.append(line) return '\n'.join(yaml_lines).strip() def _post_process_sigma_rule(self, rule_content: str) -> str: """Post-process SIGMA rule to ensure clean YAML format.""" lines = rule_content.split('\n') cleaned_lines = [] for line in lines: stripped = line.strip() # Skip obvious non-YAML content and prompt artifacts if any(phrase in stripped.lower() for phrase in [ 'please note', 'you should replace', 'this is a proof-of-concept', 'please make sure', 'note that', 'important:', 'remember to', 'analysis shows', 'based on the', 'the rule above', 'this rule', 'human:', 'cve id:', 'cve description:', 'poc code:', 'exploit code:', 'analyze this', 'create a', 'output only', 'generate a' ]): continue # Skip template variables and placeholder text if '{' in stripped and '}' in stripped: continue # Skip lines that contain template placeholder patterns if any(placeholder in stripped.lower() for placeholder in [ 'cve_id', 'cve description', 'poc_content', 'existing_rule', '{cve_id}', '{cve_description}', '{poc_content}' ]): continue # Skip lines that look like explanations if stripped and not ':' in stripped and not stripped.startswith('-') and not stripped.startswith(' '): # This might be explanatory text, skip it if any(word in stripped.lower() for word in ['rule', 'detect', 'should', 'will', 'can', 'may']): continue # Skip empty explanatory sections if stripped.lower() in ['explanation:', 'analysis:', 'notes:', 'important:', '']: continue cleaned_lines.append(line) return '\n'.join(cleaned_lines).strip() def _fix_yaml_syntax_errors(self, rule_content: str) -> str: """Fix common YAML syntax errors in LLM-generated rules.""" import re if not rule_content: return rule_content lines = rule_content.split('\n') fixed_lines = [] fixes_applied = [] for line in lines: fixed_line = line # Fix invalid YAML alias syntax: - *image* -> - '*image*' # YAML aliases must be alphanumeric, but LLM uses *word* or *multiple words* for wildcards if '- *' in line and '*' in line: # Match patterns like "- *image*" or "- *process*" or "- *unpatched system*" pattern = r'(\s*-\s*)(\*[^*]+\*)' if re.search(pattern, line): fixed_line = re.sub(pattern, r"\1'\2'", line) fixes_applied.append(f"Fixed invalid YAML alias syntax: {line.strip()} -> {fixed_line.strip()}") # Also fix similar patterns in values: key: *value* -> key: '*value*' elif re.search(r':\s*\*[^*]+\*\s*$', line) and not re.search(r'[\'"]', line): pattern = r'(:\s*)(\*[^*]+\*)' fixed_line = re.sub(pattern, r"\1'\2'", line) fixes_applied.append(f"Fixed invalid YAML alias in value: {line.strip()} -> {fixed_line.strip()}") # Fix unquoted strings that start with special characters elif re.match(r'^\s*-\s*[*&|>]', line): # If line starts with -, followed by special YAML chars, quote it parts = line.split('-', 1) if len(parts) == 2: indent = parts[0] content = parts[1].strip() if content and not content.startswith(("'", '"')): fixed_line = f"{indent}- '{content}'" fixes_applied.append(f"Quoted special character value: {line.strip()} -> {fixed_line.strip()}") # Fix invalid boolean values elif ': *' in line and not line.strip().startswith('#'): # Replace ": *something*" with ": '*something*'" if not already quoted pattern = r'(:\s*)(\*[^*]+\*)' if not re.search(r'[\'"]', line) and re.search(pattern, line): fixed_line = re.sub(pattern, r"\1'\2'", line) fixes_applied.append(f"Fixed unquoted wildcard value: {line.strip()} -> {fixed_line.strip()}") # Fix missing quotes around values with special characters (but not YAML indicators) elif re.search(r':\s*[*&]', line) and not re.search(r'[\'"]', line): # Don't quote YAML multiline indicators (|, >) if not re.search(r':\s*[|>]\s*$', line): parts = line.split(':', 1) if len(parts) == 2: key = parts[0] value = parts[1].strip() if value and not value.startswith(("'", '"', '[', '{')): fixed_line = f"{key}: '{value}'" fixes_applied.append(f"Quoted special character value: {line.strip()} -> {fixed_line.strip()}") # Fix invalid array syntax elif re.search(r'^\s*\*[^*]+\*\s*$', line): # Standalone *word* or *multiple words* lines should be quoted indent = len(line) - len(line.lstrip()) content = line.strip() fixed_line = f"{' ' * indent}'{content}'" fixes_applied.append(f"Fixed standalone wildcard: {line.strip()} -> {fixed_line.strip()}") fixed_lines.append(fixed_line) result = '\n'.join(fixed_lines) # Additional fixes for common YAML issues # Fix missing spaces after colons colon_fix = re.sub(r':([^\s])', r': \1', result) if colon_fix != result: fixes_applied.append("Added missing spaces after colons") result = colon_fix # Fix multiple spaces after colons space_fix = re.sub(r':\s{2,}', ': ', result) if space_fix != result: fixes_applied.append("Fixed multiple spaces after colons") result = space_fix # Fix incorrect reference format: references: - https://... -> references:\n - https://... ref_fix = re.sub(r'references:\s*-\s*', 'references:\n - ', result) if ref_fix != result: fixes_applied.append("Fixed references array format") result = ref_fix # Fix broken URLs in references (spaces in URLs) url_fix = re.sub(r'https:\s*//nvd\.nist\.gov', 'https://nvd.nist.gov', result) if url_fix != result: fixes_applied.append("Fixed broken URLs in references") result = url_fix # Fix incorrect logsource format: logsource: category: X -> logsource:\n category: X logsource_fix = re.sub(r'logsource:\s*(category|product|service):\s*', r'logsource:\n \1: ', result) if logsource_fix != result: fixes_applied.append("Fixed logsource structure format") result = logsource_fix # Fix incorrect detection format: detection: selection: key: value -> detection:\n selection:\n key: value detection_fix = re.sub(r'detection:\s*(\w+):\s*(\w+):\s*', r'detection:\n \1:\n \2: ', result) if detection_fix != result: fixes_applied.append("Fixed detection structure format") result = detection_fix # Fix detection lines with == operators: detection: selection1: image == *value* -> detection:\n selection1:\n image: '*value*' # This handles compressed syntax with equality operators # Make the pattern more flexible to catch various formats detection_eq_patterns = [ (r'detection:\s*(\w+):\s*(\w+)\s*==\s*(\*[^*\s]+\*)', r'detection:\n \1:\n \2: \'\3\''), (r'detection:\s*(\w+):\s*(\w+)\s*==\s*([^\s]+)', r'detection:\n \1:\n \2: \'\3\''), ] for pattern, replacement in detection_eq_patterns: detection_eq_fix = re.sub(pattern, replacement, result) if detection_eq_fix != result: fixes_applied.append("Fixed detection equality operator syntax") result = detection_eq_fix break # Fix standalone equality operators in detection sections: key == *value* -> key: '*value*' # Also handle lines with multiple keys/values separated by colons and == lines = result.split('\n') eq_fixed_lines = [] for line in lines: original_line = line # Look for pattern: whitespace + key == *value* or key == value if ' == ' in line: # Handle complex patterns like "detection: selection1: image == *value*" if line.strip().startswith('detection:') and ' == ' in line: # Split by colons to handle nested structure parts = line.split(':') if len(parts) >= 3: # This looks like "detection: selection1: image == *value*" base_indent = len(line) - len(line.lstrip()) # Extract the parts detection_part = parts[0].strip() # "detection" selection_part = parts[1].strip() # "selection1" key_value_part = ':'.join(parts[2:]).strip() # "image == *value*" # Parse the key == value part if ' == ' in key_value_part: eq_parts = key_value_part.split(' == ', 1) key = eq_parts[0].strip() value = eq_parts[1].strip() # Quote the value if needed if value.startswith('*') and value.endswith('*') and not value.startswith("'"): value = f"'{value}'" elif not value.startswith(("'", '"', '[', '{')): value = f"'{value}'" # Reconstruct as proper YAML eq_fixed_lines.append(f"{' ' * base_indent}detection:") eq_fixed_lines.append(f"{' ' * (base_indent + 4)}{selection_part}:") eq_fixed_lines.append(f"{' ' * (base_indent + 8)}{key}: {value}") fixes_applied.append(f"Fixed complex detection equality: {selection_part}: {key} == {value}") continue # Handle simpler patterns: " key == value" elif re.match(r'^(\s+)(\w+)\s*==\s*(.+)$', line): match = re.match(r'^(\s+)(\w+)\s*==\s*(.+)$', line) indent = match.group(1) key = match.group(2) value = match.group(3).strip() # Ensure wildcards are quoted if value.startswith('*') and value.endswith('*') and not value.startswith("'"): value = f"'{value}'" elif not value.startswith(("'", '"', '[', '{')): value = f"'{value}'" eq_fixed_lines.append(f"{indent}{key}: {value}") fixes_applied.append(f"Fixed equality operator: {key} == {value}") continue eq_fixed_lines.append(original_line) if len(eq_fixed_lines) != len(lines): result = '\n'.join(eq_fixed_lines) # Fix invalid array-as-value syntax: key: - value -> key:\n - value # This handles cases like "CommandLine: - '*image*'" which should be "CommandLine:\n - '*image*'" lines = result.split('\n') fixed_lines = [] for line in lines: # Look for pattern: whitespace + key: - value if re.match(r'^(\s+)(\w+):\s*-\s*(.+)$', line): match = re.match(r'^(\s+)(\w+):\s*-\s*(.+)$', line) indent = match.group(1) key = match.group(2) value = match.group(3) # Convert to proper array format fixed_lines.append(f"{indent}{key}:") fixed_lines.append(f"{indent} - {value}") fixes_applied.append(f"Fixed array-as-value syntax: {key}: - {value}") else: fixed_lines.append(line) if len(fixed_lines) != len(lines): result = '\n'.join(fixed_lines) # Fix complex nested syntax errors like "selection1: Image: - '*path*': value" # This should be "selection1:\n Image:\n - '*path*': value" complex_fix = re.sub(r'^(\s+)(\w+):\s*(\w+):\s*-\s*(.+)$', r'\1\2:\n\1 \3:\n\1 - \4', result, flags=re.MULTILINE) if complex_fix != result: fixes_applied.append("Fixed complex nested structure syntax") result = complex_fix # Fix incorrect tags format: tags: - T1059.001 -> tags:\n - T1059.001 tags_fix = re.sub(r'tags:\s*-\s*', 'tags:\n - ', result) if tags_fix != result: fixes_applied.append("Fixed tags array format") result = tags_fix # Fix other common single-line array formats for field in ['falsepositives', 'level', 'related']: field_pattern = f'{field}:\\s*-\\s*' field_replacement = f'{field}:\n - ' field_fix = re.sub(field_pattern, field_replacement, result) if field_fix != result: fixes_applied.append(f"Fixed {field} array format") result = field_fix # Fix placeholder UUID if LLM used the example one import uuid placeholder_uuid = '12345678-1234-1234-1234-123456789012' if placeholder_uuid in result: new_uuid = str(uuid.uuid4()) result = result.replace(placeholder_uuid, new_uuid) fixes_applied.append(f"Replaced placeholder UUID with {new_uuid[:8]}...") # Fix orphaned list items (standalone lines starting with -) lines = result.split('\n') fixed_lines = [] for i, line in enumerate(lines): stripped = line.strip() # Check for orphaned list items (lines starting with - but not part of an array) if (stripped.startswith('- ') and i > 0 and not lines[i-1].strip().endswith(':') and ':' not in stripped and not stripped.startswith('- https://')): # Don't remove reference URLs # Check if this looks like a MITRE ATT&CK tag if re.match(r'- T\d{4}', stripped): # Try to find the tags section and add it there tags_line_found = False for j in range(len(fixed_lines)-1, -1, -1): if fixed_lines[j].strip().startswith('tags:'): # This is an orphaned tag, add it to the tags array fixed_lines.append(f" {stripped}") fixes_applied.append(f"Fixed orphaned MITRE tag: {stripped}") tags_line_found = True break if not tags_line_found: # No tags section found, remove the orphaned item fixes_applied.append(f"Removed orphaned tag (no tags section): {stripped}") continue else: # Other orphaned list items, remove them fixes_applied.append(f"Removed orphaned list item: {stripped}") continue fixed_lines.append(line) result = '\n'.join(fixed_lines) # Final pass: Remove lines that are still malformed and would cause YAML parsing errors lines = result.split('\n') final_lines = [] for line in lines: stripped = line.strip() # Skip lines that have multiple colons in problematic patterns if re.search(r':\s*\w+:\s*-\s*[\'"][^\'":]*[\'"]:\s*', line): # This looks like "key: subkey: - 'value': more_stuff" which is malformed fixes_applied.append(f"Removed malformed nested line: {stripped[:50]}...") continue # Skip lines with invalid YAML mapping structures if re.search(r'^\s*\w+:\s*\w+:\s*-\s*[\'"][^\'":]*[\'"]:\s*\w+', line): fixes_applied.append(f"Removed invalid mapping structure: {stripped[:50]}...") continue final_lines.append(line) if len(final_lines) != len(lines): result = '\n'.join(final_lines) # Log if we made any fixes if fixes_applied: logger.info(f"Applied YAML syntax fixes: {', '.join(fixes_applied)}") # Final YAML structure validation and repair result = self._validate_and_repair_yaml_structure(result, fixes_applied) return result def _validate_and_repair_yaml_structure(self, content: str, fixes_applied: list) -> str: """Use YAML library to validate and repair structural issues.""" try: # First, try to parse the YAML to see if it's valid yaml.safe_load(content) # If we get here, the YAML is valid return content except yaml.YAMLError as e: logger.warning(f"YAML structure validation failed: {e}") # Try to repair common structural issues repaired_content = self._repair_yaml_structure(content, str(e)) # Test if the repair worked try: yaml.safe_load(repaired_content) fixes_applied.append("Repaired YAML document structure") logger.info("Successfully repaired YAML structure") return repaired_content except yaml.YAMLError as e2: logger.warning(f"YAML repair attempt failed: {e2}") # Last resort: try to build a minimal valid SIGMA rule return self._build_minimal_valid_rule(content, fixes_applied) def _repair_yaml_structure(self, content: str, error_msg: str) -> str: """Attempt to repair common YAML structural issues.""" lines = content.split('\n') repaired_lines = [] # Track indentation levels to detect issues expected_indent = 0 in_detection = False detection_indent = 0 for i, line in enumerate(lines): stripped = line.strip() current_indent = len(line) - len(line.lstrip()) # Skip empty lines if not stripped: repaired_lines.append(line) continue # Track if we're in the detection section if stripped.startswith('detection:'): in_detection = True detection_indent = current_indent repaired_lines.append(line) continue elif in_detection and current_indent <= detection_indent and not stripped.startswith(('condition:', 'timeframe:')): # We've left the detection section in_detection = False # Fix indentation issues in detection section if in_detection: # Ensure proper indentation for detection subsections if stripped.startswith(('selection', 'filter', 'condition')): # This should be indented under detection if current_indent <= detection_indent: corrected_line = ' ' * (detection_indent + 4) + stripped repaired_lines.append(corrected_line) continue elif current_indent > detection_indent + 4: # This might be a detection field that needs proper indentation if ':' in stripped and not stripped.startswith('-'): # This looks like a field under a selection if i > 0 and 'selection' in lines[i-1]: corrected_line = ' ' * (detection_indent + 8) + stripped repaired_lines.append(corrected_line) continue # Fix lines that start with wrong indentation if ':' in stripped and not stripped.startswith('-'): # This is a key-value pair key = stripped.split(':')[0].strip() # Top-level keys should not be indented if key in ['title', 'id', 'status', 'description', 'author', 'date', 'references', 'tags', 'logsource', 'detection', 'falsepositives', 'level']: if current_indent > 0: corrected_line = stripped repaired_lines.append(corrected_line) continue repaired_lines.append(line) return '\n'.join(repaired_lines) def _build_minimal_valid_rule(self, content: str, fixes_applied: list) -> str: """Build a minimal valid SIGMA rule from the content.""" lines = content.split('\n') # Extract key components title = "Unknown SIGMA Rule" rule_id = "00000000-0000-0000-0000-000000000000" description = "Generated SIGMA rule" for line in lines: stripped = line.strip() if stripped.startswith('title:'): title = stripped.split(':', 1)[1].strip().strip('"\'') elif stripped.startswith('id:'): rule_id = stripped.split(':', 1)[1].strip().strip('"\'') elif stripped.startswith('description:'): description = stripped.split(':', 1)[1].strip().strip('"\'') # Build minimal valid rule minimal_rule = f"""title: '{title}' id: {rule_id} status: experimental description: '{description}' author: 'AI Generated' date: 2025/01/14 references: - https://example.com logsource: category: process_creation detection: selection: Image: '*' condition: selection level: medium""" fixes_applied.append("Built minimal valid SIGMA rule structure") logger.warning("Generated minimal valid SIGMA rule as fallback") return minimal_rule def _fix_hallucinated_cve_id(self, rule_content: str, correct_cve_id: str) -> str: """Detect and fix hallucinated CVE IDs in the generated rule.""" import re # Pattern to match CVE IDs (CVE-YYYY-NNNNN format) cve_pattern = r'CVE-\d{4}-\d{4,7}' # Find all CVE IDs in the rule content found_cves = re.findall(cve_pattern, rule_content, re.IGNORECASE) if found_cves: # Check if any found CVE is different from the correct one hallucinated_cves = [cve for cve in found_cves if cve.upper() != correct_cve_id.upper()] if hallucinated_cves: logger.error(f"CRITICAL: LLM hallucinated CVE IDs: {hallucinated_cves}, expected: {correct_cve_id}") logger.error(f"This indicates the LLM is not following the prompt correctly!") # Replace all hallucinated CVE IDs with the correct one corrected_content = rule_content for hallucinated_cve in set(hallucinated_cves): # Use set to avoid duplicates corrected_content = re.sub( re.escape(hallucinated_cve), correct_cve_id, corrected_content, flags=re.IGNORECASE ) logger.info(f"Successfully corrected hallucinated CVE IDs to {correct_cve_id}") return corrected_content else: logger.info(f"CVE ID validation passed: found correct {correct_cve_id}") else: # No CVE ID found in rule - this might be an issue, but let's add it logger.warning(f"No CVE ID found in generated rule for {correct_cve_id}, this might need manual review") return rule_content def _inject_cve_id_into_rule(self, rule_content: str, cve_id: str) -> str: """Inject CVE ID into a rule that lacks it.""" if not rule_content: logger.warning(f"Empty rule content for {cve_id}, cannot inject CVE ID") return rule_content lines = rule_content.split('\n') modified_lines = [] for i, line in enumerate(lines): stripped = line.strip() # Fix title line if it has placeholders if stripped.startswith('title:'): if '{cve_id}' in line.lower() or '{cve_description}' in line.lower(): # Replace with a proper title modified_lines.append(f"title: 'Detection of {cve_id} exploitation'") elif cve_id not in line: # Add CVE ID to existing title title_text = line.split(':', 1)[1].strip(' \'"') modified_lines.append(f"title: '{cve_id}: {title_text}'") else: modified_lines.append(line) # Fix references section if it has placeholders elif stripped.startswith('- https://nvd.nist.gov/vuln/detail/') and '{cve_id}' in line: modified_lines.append(f" - https://nvd.nist.gov/vuln/detail/{cve_id}") # Skip lines with template placeholders elif any(placeholder in line.lower() for placeholder in ['{cve_id}', '{cve_description}', '{poc_content}']): continue else: modified_lines.append(line) result = '\n'.join(modified_lines) logger.info(f"Injected CVE ID {cve_id} into rule") return result async def enhance_existing_rule(self, existing_rule: str, poc_content: str, cve_id: str) -> Optional[str]: """ Enhance an existing SIGMA rule with PoC analysis. Args: existing_rule: Existing SIGMA rule YAML poc_content: PoC code content cve_id: CVE identifier Returns: Enhanced SIGMA rule or None if failed """ if not self.is_available(): return None try: system_message = """You are a SIGMA rule enhancement expert. Analyze the following PoC code and enhance the existing SIGMA rule with more specific detection patterns. **Task:** Enhance the existing rule by: 1. Adding more specific detection patterns found in the PoC 2. Improving the condition logic 3. Adding relevant tags or MITRE ATT&CK mappings 4. Keeping the rule structure intact but making it more effective Output ONLY the enhanced SIGMA rule in valid YAML format.""" user_template = """**CVE ID:** {cve_id} **PoC Code:** ``` {poc_content} ``` **Existing SIGMA Rule:** ```yaml {existing_rule} ```""" prompt = ChatPromptTemplate.from_messages([ SystemMessage(content=system_message), HumanMessage(content=user_template) ]) chain = prompt | self.llm | self.output_parser response = await chain.ainvoke({ "cve_id": cve_id, "poc_content": poc_content[:3000], "existing_rule": existing_rule }) enhanced_rule = self._extract_sigma_rule(response) enhanced_rule = self._post_process_sigma_rule(enhanced_rule) logger.info(f"Successfully enhanced SIGMA rule for {cve_id}") return enhanced_rule except Exception as e: logger.error(f"Failed to enhance SIGMA rule for {cve_id}: {e}") return None def validate_sigma_rule(self, rule_content: str, expected_cve_id: str = None) -> bool: """Validate that the generated rule follows SIGMA specification.""" try: # Parse as YAML parsed = yaml.safe_load(rule_content) if not isinstance(parsed, dict): logger.warning("Rule must be a YAML dictionary") return False # Check MANDATORY fields per SIGMA spec mandatory_fields = ['title', 'logsource', 'detection'] for field in mandatory_fields: if field not in parsed: logger.warning(f"Missing mandatory field: {field}") return False # Validate title title = parsed.get('title', '') if not isinstance(title, str) or len(title) > 256: logger.warning("Title must be string ≤256 characters") return False # Validate logsource structure logsource = parsed.get('logsource', {}) if not isinstance(logsource, dict): logger.warning("Logsource must be a dictionary") return False # Validate detection structure detection = parsed.get('detection', {}) if not isinstance(detection, dict): logger.warning("Detection must be a dictionary") return False # Check for condition (can be in detection or at root) has_condition = 'condition' in detection or 'condition' in parsed if not has_condition: logger.warning("Missing condition field") return False # Check for at least one selection selection_found = any(key.startswith('selection') or key in ['selection', 'keywords', 'filter'] for key in detection.keys() if key != 'condition') if not selection_found: logger.warning("Detection must have at least one selection") return False # Validate status if present status = parsed.get('status') if status and status not in ['stable', 'test', 'experimental', 'deprecated', 'unsupported']: logger.warning(f"Invalid status: {status}") return False # Additional validation: Check for correct CVE ID if provided if expected_cve_id: import re cve_pattern = r'CVE-\d{4}-\d{4,7}' found_cves = re.findall(cve_pattern, rule_content, re.IGNORECASE) if found_cves: # Check if all found CVE IDs match the expected one wrong_cves = [cve for cve in found_cves if cve.upper() != expected_cve_id.upper()] if wrong_cves: logger.warning(f"Rule contains wrong CVE IDs: {wrong_cves}, expected {expected_cve_id}") return False else: logger.warning(f"Rule does not contain expected CVE ID: {expected_cve_id}") # Don't fail validation for missing CVE ID, just warn logger.info("SIGMA rule validation passed") return True except yaml.YAMLError as e: error_msg = str(e) if "alias" in error_msg.lower() and "*" in error_msg: logger.warning(f"YAML alias syntax error (likely unquoted wildcard): {e}") elif "expected" in error_msg.lower(): logger.warning(f"YAML structure error: {e}") else: logger.warning(f"YAML parsing error: {e}") return False except Exception as e: logger.warning(f"Rule validation error: {e}") return False @classmethod def get_available_providers(cls) -> List[Dict[str, Any]]: """Get list of available LLM providers and their configuration status.""" providers = [] for provider_name, provider_info in cls.SUPPORTED_PROVIDERS.items(): env_key = provider_info.get('env_key', '') api_key_configured = bool(os.getenv(env_key)) providers.append({ 'name': provider_name, 'models': provider_info.get('models', []), 'default_model': provider_info.get('default_model', ''), 'env_key': env_key, 'api_key_configured': api_key_configured, 'available': api_key_configured or provider_name == 'ollama' }) return providers def switch_provider(self, provider: str, model: str = None): """Switch to a different LLM provider and model.""" if provider not in self.SUPPORTED_PROVIDERS: raise ValueError(f"Unsupported provider: {provider}") self.provider = provider self.model = model or self._get_default_model(provider) self._initialize_llm() logger.info(f"Switched to provider: {provider} with model: {self.model}") def _check_ollama_model_available(self, base_url: str, model: str) -> bool: """Check if an Ollama model is available locally""" try: import requests response = requests.get(f"{base_url}/api/tags", timeout=10) if response.status_code == 200: data = response.json() models = data.get('models', []) for m in models: if m.get('name', '').startswith(model + ':') or m.get('name') == model: return True return False except Exception as e: logger.error(f"Error checking Ollama models: {e}") return False def _pull_ollama_model(self, base_url: str, model: str) -> bool: """Pull an Ollama model""" try: import requests import json # Use the pull API endpoint payload = {"name": model} response = requests.post( f"{base_url}/api/pull", json=payload, timeout=300, # 5 minutes timeout for model download stream=True ) if response.status_code == 200: # Stream the response to monitor progress for line in response.iter_lines(): if line: try: data = json.loads(line.decode('utf-8')) if data.get('status'): logger.info(f"Ollama pull progress: {data.get('status')}") if data.get('error'): logger.error(f"Ollama pull error: {data.get('error')}") return False except json.JSONDecodeError: continue return True else: logger.error(f"Failed to pull model {model}: HTTP {response.status_code}") return False except Exception as e: logger.error(f"Error pulling Ollama model {model}: {e}") return False def _get_ollama_available_models(self) -> List[str]: """Get list of available Ollama models""" try: import requests base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434') response = requests.get(f"{base_url}/api/tags", timeout=10) if response.status_code == 200: data = response.json() models = data.get('models', []) return [m.get('name', '') for m in models if m.get('name')] return [] except Exception as e: logger.error(f"Error getting Ollama models: {e}") return [] async def test_connection(self) -> Dict[str, Any]: """Test connection to the configured LLM provider.""" try: if self.provider == 'openai': api_key = os.getenv('OPENAI_API_KEY') if not api_key: return { "available": False, "error": "OpenAI API key not configured", "models": [], "current_model": self.model, "has_api_key": False } # Test OpenAI connection without actual API call to avoid timeouts if self.llm: return { "available": True, "models": self.SUPPORTED_PROVIDERS['openai']['models'], "current_model": self.model, "has_api_key": True } else: return { "available": False, "error": "OpenAI client not initialized", "models": [], "current_model": self.model, "has_api_key": True } elif self.provider == 'anthropic': api_key = os.getenv('ANTHROPIC_API_KEY') if not api_key: return { "available": False, "error": "Anthropic API key not configured", "models": [], "current_model": self.model, "has_api_key": False } # Test Anthropic connection without actual API call to avoid timeouts if self.llm: return { "available": True, "models": self.SUPPORTED_PROVIDERS['anthropic']['models'], "current_model": self.model, "has_api_key": True } else: return { "available": False, "error": "Anthropic client not initialized", "models": [], "current_model": self.model, "has_api_key": True } elif self.provider == 'ollama': base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434') # Test Ollama connection try: import requests response = requests.get(f"{base_url}/api/tags", timeout=10) if response.status_code == 200: available_models = self._get_ollama_available_models() # Check if model is available using proper model name matching model_available = self._check_ollama_model_available(base_url, self.model) return { "available": model_available, "models": available_models, "current_model": self.model, "base_url": base_url, "error": None if model_available else f"Model {self.model} not available" } else: return { "available": False, "error": f"Ollama server not responding (HTTP {response.status_code})", "models": [], "current_model": self.model, "base_url": base_url } except Exception as e: return { "available": False, "error": f"Cannot connect to Ollama server: {str(e)}", "models": [], "current_model": self.model, "base_url": base_url } else: return { "available": False, "error": f"Unsupported provider: {self.provider}", "models": [], "current_model": self.model } except Exception as e: return { "available": False, "error": f"Connection test failed: {str(e)}", "models": [], "current_model": self.model }