""" LangChain-based LLM client for enhanced SIGMA rule generation. Supports multiple LLM providers: OpenAI, Anthropic, and Ollama. """ import os import logging from typing import Optional, Dict, Any, List from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.prompts import ChatPromptTemplate from langchain_openai import ChatOpenAI from langchain_anthropic import ChatAnthropic from langchain_community.llms import Ollama from langchain_core.output_parsers import StrOutputParser import yaml from cve2capec_client import CVE2CAPECClient logger = logging.getLogger(__name__) class LLMClient: """Multi-provider LLM client for SIGMA rule generation using LangChain.""" SUPPORTED_PROVIDERS = { 'openai': { 'models': ['gpt-4o', 'gpt-4o-mini', 'gpt-4-turbo', 'gpt-3.5-turbo'], 'env_key': 'OPENAI_API_KEY', 'default_model': 'gpt-4o-mini' }, 'anthropic': { 'models': ['claude-3-5-sonnet-20241022', 'claude-3-haiku-20240307', 'claude-3-opus-20240229'], 'env_key': 'ANTHROPIC_API_KEY', 'default_model': 'claude-3-5-sonnet-20241022' }, 'ollama': { 'models': ['llama3.2', 'codellama', 'mistral', 'llama2', 'sigma-llama-finetuned'], 'env_key': 'OLLAMA_BASE_URL', 'default_model': 'llama3.2' } } def __init__(self, provider: str = None, model: str = None): """Initialize LLM client with specified provider and model.""" self.provider = provider or self._detect_provider() self.model = model or self._get_default_model(self.provider) self.llm = None self.output_parser = StrOutputParser() self.cve2capec_client = CVE2CAPECClient() self._initialize_llm() def _detect_provider(self) -> str: """Auto-detect available LLM provider based on environment variables.""" # Check for API keys in order of preference if os.getenv('ANTHROPIC_API_KEY'): return 'anthropic' elif os.getenv('OPENAI_API_KEY'): return 'openai' elif os.getenv('OLLAMA_BASE_URL'): return 'ollama' else: # Default to OpenAI if no keys found return 'openai' def _get_default_model(self, provider: str) -> str: """Get default model for the specified provider.""" return self.SUPPORTED_PROVIDERS.get(provider, {}).get('default_model', 'gpt-4o-mini') def _initialize_llm(self): """Initialize the LLM based on provider and model.""" try: if self.provider == 'openai': api_key = os.getenv('OPENAI_API_KEY') if not api_key: logger.warning("OpenAI API key not found") return self.llm = ChatOpenAI( model=self.model, api_key=api_key, temperature=0.1, max_tokens=2000 ) elif self.provider == 'anthropic': api_key = os.getenv('ANTHROPIC_API_KEY') if not api_key: logger.warning("Anthropic API key not found") return self.llm = ChatAnthropic( model=self.model, api_key=api_key, temperature=0.1, max_tokens=2000 ) elif self.provider == 'ollama': base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434') # Check if model is available, if not try to pull it if not self._check_ollama_model_available(base_url, self.model): logger.info(f"Model {self.model} not found, attempting to pull...") if self._pull_ollama_model(base_url, self.model): logger.info(f"Successfully pulled model {self.model}") else: logger.error(f"Failed to pull model {self.model}") return self.llm = Ollama( model=self.model, base_url=base_url, temperature=0.1 ) if self.llm: logger.info(f"LLM client initialized: {self.provider} with model {self.model}") else: logger.error(f"Failed to initialize LLM client for provider: {self.provider}") except Exception as e: logger.error(f"Error initializing LLM client: {e}") self.llm = None def is_available(self) -> bool: """Check if LLM client is available and configured.""" return self.llm is not None def get_provider_info(self) -> Dict[str, Any]: """Get information about the current provider and configuration.""" provider_info = self.SUPPORTED_PROVIDERS.get(self.provider, {}) # For Ollama and fine-tuned, get actually available models available_models = provider_info.get('models', []) if self.provider in ['ollama', 'finetuned']: ollama_models = self._get_ollama_available_models() if ollama_models: available_models = ollama_models return { 'provider': self.provider, 'model': self.model, 'available': self.is_available(), 'supported_models': provider_info.get('models', []), 'available_models': available_models, 'env_key': provider_info.get('env_key', ''), 'api_key_configured': bool(os.getenv(provider_info.get('env_key', ''))) } async def generate_sigma_rule(self, cve_id: str, poc_content: str, cve_description: str, existing_rule: Optional[str] = None) -> Optional[str]: """ Generate or enhance a SIGMA rule using the configured LLM. Args: cve_id: CVE identifier poc_content: Proof-of-concept code content from GitHub cve_description: CVE description from NVD existing_rule: Optional existing SIGMA rule to enhance Returns: Generated SIGMA rule YAML content or None if failed """ if not self.is_available(): logger.warning("LLM client not available") return None try: # Create the prompt template prompt = self._build_sigma_generation_prompt( cve_id, poc_content, cve_description, existing_rule ) # Create the chain chain = prompt | self.llm | self.output_parser # Debug: Log what we're sending to the LLM input_data = { "cve_id": cve_id, "poc_content": poc_content[:4000], # Truncate if too long "cve_description": cve_description, "existing_rule": existing_rule or "None" } logger.info(f"Sending to LLM for {cve_id}: CVE={cve_id}, Description length={len(cve_description)}, PoC length={len(poc_content)}") logger.info(f"CVE Description for {cve_id}: {cve_description[:200]}...") logger.info(f"PoC Content sample for {cve_id}: {poc_content[:200]}...") # Generate the response with memory error handling logger.info(f"Final prompt variables for {cve_id}: {list(input_data.keys())}") try: response = await chain.ainvoke(input_data) except Exception as llm_error: # Handle memory issues or model loading failures error_msg = str(llm_error).lower() if any(keyword in error_msg for keyword in ["memory", "out of memory", "too large", "available", "model request"]): logger.error(f"LLM memory error for {cve_id}: {llm_error}") # For memory errors, we don't have specific fallback logic currently logger.error(f"No fallback available for provider {self.provider}") return None else: # Re-raise non-memory errors raise llm_error # Debug: Log raw LLM response logger.info(f"Raw LLM response for {cve_id}: {response[:200]}...") # Extract the SIGMA rule from response sigma_rule = self._extract_sigma_rule(response) # Post-process to ensure clean YAML sigma_rule = self._post_process_sigma_rule(sigma_rule) # Fix common YAML syntax errors sigma_rule = self._fix_yaml_syntax_errors(sigma_rule) # CRITICAL: Validate and fix CVE ID hallucination sigma_rule = self._fix_hallucinated_cve_id(sigma_rule, cve_id) # Additional fallback: If no CVE ID found, inject it into the rule if not sigma_rule or 'CVE-' not in sigma_rule: sigma_rule = self._inject_cve_id_into_rule(sigma_rule, cve_id) # Debug: Log final processed rule logger.info(f"Final processed rule for {cve_id}: {sigma_rule[:200]}...") logger.info(f"Successfully generated SIGMA rule for {cve_id} using {self.provider}") return sigma_rule except Exception as e: logger.error(f"Failed to generate SIGMA rule for {cve_id} using {self.provider}: {e}") return None def _build_sigma_generation_prompt(self, cve_id: str, poc_content: str, cve_description: str, existing_rule: Optional[str] = None) -> ChatPromptTemplate: """Build the prompt template for SIGMA rule generation.""" system_message = """You are a cybersecurity expert specializing in SIGMA rule creation following the official SIGMA specification. **CRITICAL: You must follow the exact SIGMA specification format:** 1. **YAML Structure Requirements:** - Use UTF-8 encoding with LF line breaks - Indent with 4 spaces (no tabs) - Use lowercase keys only - Use single quotes for string values - No quotes for numeric values - Follow proper YAML syntax 2. **MANDATORY Fields (must include):** - title: Brief description (max 256 chars) - logsource: Log data source specification - detection: Search identifiers and conditions - condition: How detection elements combine 3. **RECOMMENDED Fields:** - id: Unique UUID - status: 'experimental' (for new rules) - description: Detailed explanation - author: 'AI Generated' - date: Current date (YYYY/MM/DD) - references: Array with CVE link - tags: MITRE ATT&CK techniques 4. **Detection Structure:** - Use selection blocks (selection, selection1, etc.) - Condition references these selections - Use proper field names (Image, CommandLine, ProcessName, etc.) - Support wildcards (*) and value lists **ABSOLUTE REQUIREMENTS:** - Output ONLY valid YAML - NO explanatory text before or after - NO comments or instructions - NO markdown formatting or code blocks - NEVER repeat the input prompt or template - NEVER include variables like {cve_id} or {poc_content} - NO "Human:", "CVE ID:", "Description:" headers - NO "Analyze this" or "Output only" text - Start IMMEDIATELY with 'title:' - End with the last YAML line only - Ensure perfect YAML syntax **STRUCTURE REQUIREMENTS:** - title: Descriptive title that MUST include the exact CVE ID provided by the user - id: Generate a unique UUID (not '12345678-1234-1234-1234-123456789012') - status: experimental - description: Specific description based on CVE and PoC analysis - author: 'AI Generated' - date: Current date (2025/01/16) - references: Include the EXACT CVE URL with the CVE ID provided by the user - tags: Relevant MITRE ATT&CK techniques based on PoC analysis - logsource: Appropriate category based on exploit type - detection: Specific indicators from PoC analysis (NOT generic examples) - condition: Logic connecting the detection selections **MITRE ATT&CK TAGS FORMAT REQUIREMENTS:** - Use ONLY the MITRE ATT&CK techniques provided in the "MITRE ATT&CK TECHNIQUE MAPPINGS" section above - Convert technique IDs to lowercase attack.t format (e.g., T1134 becomes attack.t1134) - Include specific sub-techniques when available (e.g., T1134.001 becomes attack.t1134.001) - DO NOT use generic techniques not listed in the mappings - DO NOT add additional techniques based on your training data **CRITICAL:** ONLY use the MITRE ATT&CK techniques explicitly provided in the technique mappings above. Do not add any other techniques. **COMPLETE SIGMA RULE EXAMPLE (TECHNIQUE TAGS MUST MATCH PROVIDED MAPPINGS):** ```yaml title: 'CVE-2024-XXXX Detection Rule' id: a1b2c3d4-e5f6-7890-abcd-ef1234567890 status: experimental description: 'Detection for CVE-2024-XXXX vulnerability' author: 'AI Generated' date: 2025/01/16 references: - https://nvd.nist.gov/vuln/detail/CVE-2024-XXXX tags: - attack.t1134 # Access Token Manipulation (example - use actual mappings) - attack.t1134.001 # Token Impersonation/Theft (example - use actual mappings) logsource: category: process_creation product: windows detection: selection: Image|contains: 'specific_indicator' condition: selection level: medium ``` **IMPORTANT:** The tags section above is just an example format. You MUST use the exact techniques provided in the MITRE ATT&CK TECHNIQUE MAPPINGS section for the specific CVE you're analyzing. **CRITICAL ANTI-HALLUCINATION RULES:** 1. You MUST use the EXACT CVE ID provided in the user input - NEVER generate a different CVE ID 2. NEVER use example CVE IDs like CVE-2022-1234, CVE-2023-5678, or CVE-2024-1234 3. NEVER use placeholder CVE IDs from your training data 4. Analyze the provided CVE description and PoC content to create SPECIFIC detection patterns 5. DO NOT hallucinate or invent CVE IDs from your training data 6. Use the CVE ID exactly as provided in the title and references 7. Generate rules based ONLY on the provided CVE description and PoC code analysis 8. Do not reference vulnerabilities or techniques not present in the provided content 9. CVE-2022-1234 is a FORBIDDEN example CVE ID - NEVER use it 10. The user will provide the EXACT CVE ID to use - use that and ONLY that""" if existing_rule: user_template = """CVE ID: {cve_id} CVE Description: {cve_description} PoC Code: {poc_content} Existing SIGMA Rule: {existing_rule} Enhance this rule with PoC insights. Output only valid SIGMA YAML starting with 'title:'.""" else: # Get MITRE ATT&CK mappings for the CVE mitre_mappings = self.cve2capec_client.get_full_mapping_for_cve(cve_id) mitre_suggestions = "" if mitre_mappings['mitre_techniques']: technique_details = [] for tech in mitre_mappings['mitre_techniques']: tech_name = self.cve2capec_client.get_technique_name(tech) technique_details.append(f" - {tech}: {tech_name}") mitre_suggestions = f""" **MITRE ATT&CK TECHNIQUE MAPPINGS FOR {cve_id}:** {chr(10).join(technique_details)} **CRITICAL REQUIREMENT:** Use ONLY these exact MITRE ATT&CK techniques in your tags section. Convert them to lowercase attack.t format (e.g., T1134 becomes attack.t1134, T1134.001 becomes attack.t1134.001). **ABSOLUTELY FORBIDDEN:** - Do not use T1059, T1071, T1105, T1055, T1068, T1140, T1036, T1112, T1547 or any other techniques not listed above - Do not add techniques based on PoC analysis if they're not in the provided mappings - Do not use generic techniques from your training data If no MITRE techniques are provided above, use only CVE and CWE tags.""" if mitre_mappings['cwe_codes']: mitre_suggestions += f""" **CWE MAPPINGS:** {', '.join(mitre_mappings['cwe_codes'])}""" user_template = f"""CREATE A SPECIFIC SIGMA RULE FOR THIS EXACT CVE: **MANDATORY CVE ID TO USE: {{cve_id}}** **CVE Description: {{cve_description}}** **Proof-of-Concept Code Analysis:** {{poc_content}} {mitre_suggestions} **CRITICAL REQUIREMENTS:** 1. Use EXACTLY this CVE ID in the title: {{cve_id}} 2. Use EXACTLY this CVE URL in references: https://nvd.nist.gov/vuln/detail/{{cve_id}} 3. Analyze the CVE description to understand the vulnerability type 4. If the PoC analysis above contains structured indicators, use those EXACT indicators in your detection rules 5. **USE ONLY THE MITRE ATT&CK TECHNIQUES LISTED IN THE MAPPINGS ABOVE** - Do not add any other techniques 6. Choose the appropriate logsource category based on the primary indicator types (process_creation, file_event, network_connection, registry_event, etc.) 7. Convert the mapped MITRE techniques to lowercase attack.t format (T1134 → attack.t1134, T1134.001 → attack.t1134.001) **DETECTION PATTERN GUIDANCE:** - For Process Execution indicators: Use Image, CommandLine, or ProcessName fields - For File System indicators: Use TargetFilename, SourceFilename, or FilePath fields - For Network indicators: Use DestinationHostname, DestinationIp, or DestinationPort fields - For Registry indicators: Use TargetObject, Details, or EventType fields - For Command indicators: Use CommandLine or ProcessCommandLine fields **TAGS FORMATTING REQUIREMENTS:** - Use ONLY the MITRE ATT&CK techniques provided in the "MITRE ATT&CK TECHNIQUE MAPPINGS" section above - Convert to lowercase attack.t format: T1134 → attack.t1134, T1134.001 → attack.t1134.001 - Include comments for clarity: attack.t1134 # Access Token Manipulation - Use specific sub-techniques when available - DO NOT add techniques not listed in the provided mappings - DO NOT use generic techniques from your training data **CRITICAL ANTI-HALLUCINATION REQUIREMENTS:** - THE CVE ID IS: {{cve_id}} - DO NOT use CVE-2022-1234, CVE-2023-1234, CVE-2024-1234, or any other example CVE ID - DO NOT generate a different CVE ID from your training data - You MUST use the exact CVE ID "{{cve_id}}" - this is the ONLY acceptable CVE ID for this rule - Base your analysis ONLY on the provided CVE description and PoC code above - If structured indicators are provided in the PoC analysis, use those exact values - Do not reference other vulnerabilities or exploits not mentioned in the provided content - NEVER use placeholder CVE IDs like CVE-YYYY-NNNN or CVE-2022-1234 **ABSOLUTE REQUIREMENT: THE EXACT CVE ID TO USE IS: {{cve_id}}** **FORBIDDEN: Do not use CVE-2022-1234, CVE-2023-5678, or any other example CVE ID** Output ONLY valid SIGMA YAML starting with 'title:' that includes the exact CVE ID {{cve_id}}.""" # Create the prompt template with proper variable definitions prompt_template = ChatPromptTemplate.from_messages([ ("system", system_message), ("human", user_template) ]) return prompt_template def _extract_sigma_rule(self, response_text: str) -> str: """Extract and clean SIGMA rule YAML from LLM response.""" lines = response_text.split('\n') yaml_lines = [] in_yaml_block = False found_title = False for line in lines: stripped = line.strip() # Skip code block markers if stripped.startswith('```'): if stripped.startswith('```yaml'): in_yaml_block = True elif stripped == '```' and in_yaml_block: break continue # Skip obvious non-YAML content if not in_yaml_block and not found_title: if not stripped.startswith('title:'): # Skip explanatory text and prompt artifacts skip_phrases = [ 'please note', 'this rule', 'you should', 'analysis:', 'explanation:', 'based on', 'the following', 'here is', 'note that', 'important:', 'remember', 'this is a', 'make sure to', 'you can modify', 'adjust the', 'human:', 'cve id:', 'cve description:', 'poc code:', 'exploit code:', 'analyze this', 'create a', 'output only' ] if any(phrase in stripped.lower() for phrase in skip_phrases): continue # Skip template variables and prompt artifacts if '{' in stripped and '}' in stripped: continue # Skip lines that contain template placeholder text if 'cve_id' in stripped.lower() or 'cve description' in stripped.lower(): continue # Skip lines that are clearly not YAML structure if stripped and not ':' in stripped and len(stripped) > 20: continue # Start collecting when we find title or are in YAML block if stripped.startswith('title:') or in_yaml_block: found_title = True in_yaml_block = True # Skip explanatory comments if stripped.startswith('#') and ('please' in stripped.lower() or 'note' in stripped.lower()): continue yaml_lines.append(line) # Stop if we encounter obvious non-YAML after starting elif found_title: # Stop at explanatory text after the rule stop_phrases = [ 'please note', 'this rule should', 'make sure to', 'you can modify', 'adjust the', 'also, this is', 'based on the analysis', 'the rule above' ] if any(phrase in stripped.lower() for phrase in stop_phrases): break # Stop at lines without colons that aren't indented (likely explanations) if stripped and not stripped.startswith(' ') and ':' not in stripped and '-' not in stripped: break if not yaml_lines: # Fallback: look for any line with YAML-like structure for line in lines: if ':' in line and not line.strip().startswith('#'): yaml_lines.append(line) return '\n'.join(yaml_lines).strip() def _post_process_sigma_rule(self, rule_content: str) -> str: """Post-process SIGMA rule to ensure clean YAML format.""" lines = rule_content.split('\n') cleaned_lines = [] for line in lines: stripped = line.strip() # Skip obvious non-YAML content and prompt artifacts if any(phrase in stripped.lower() for phrase in [ 'please note', 'you should replace', 'this is a proof-of-concept', 'please make sure', 'note that', 'important:', 'remember to', 'analysis shows', 'based on the', 'the rule above', 'this rule', 'human:', 'cve id:', 'cve description:', 'poc code:', 'exploit code:', 'analyze this', 'create a', 'output only', 'generate a' ]): continue # Skip template variables and placeholder text if '{' in stripped and '}' in stripped: continue # Skip lines that contain template placeholder patterns if any(placeholder in stripped.lower() for placeholder in [ 'cve_id', 'cve description', 'poc_content', 'existing_rule', '{cve_id}', '{cve_description}', '{poc_content}' ]): continue # Skip lines that look like explanations if stripped and not ':' in stripped and not stripped.startswith('-') and not stripped.startswith(' '): # This might be explanatory text, skip it if any(word in stripped.lower() for word in ['rule', 'detect', 'should', 'will', 'can', 'may']): continue # Skip empty explanatory sections if stripped.lower() in ['explanation:', 'analysis:', 'notes:', 'important:', '']: continue cleaned_lines.append(line) return '\n'.join(cleaned_lines).strip() def _fix_yaml_syntax_errors(self, rule_content: str) -> str: """Fix common YAML syntax errors in LLM-generated rules.""" import re if not rule_content: return rule_content lines = rule_content.split('\n') fixed_lines = [] fixes_applied = [] for line in lines: fixed_line = line # Fix invalid YAML alias syntax: - *image* -> - '*image*' # YAML aliases must be alphanumeric, but LLM uses *word* or *multiple words* for wildcards if '- *' in line and '*' in line: # Match patterns like "- *image*" or "- *process*" or "- *unpatched system*" pattern = r'(\s*-\s*)(\*[^*]+\*)' if re.search(pattern, line): fixed_line = re.sub(pattern, r"\1'\2'", line) fixes_applied.append(f"Fixed invalid YAML alias syntax: {line.strip()} -> {fixed_line.strip()}") # Also fix similar patterns in values: key: *value* -> key: '*value*' elif re.search(r':\s*\*[^*]+\*\s*$', line) and not re.search(r'[\'"]', line): pattern = r'(:\s*)(\*[^*]+\*)' fixed_line = re.sub(pattern, r"\1'\2'", line) fixes_applied.append(f"Fixed invalid YAML alias in value: {line.strip()} -> {fixed_line.strip()}") # Fix unquoted strings that start with special characters elif re.match(r'^\s*-\s*[*&|>]', line): # If line starts with -, followed by special YAML chars, quote it parts = line.split('-', 1) if len(parts) == 2: indent = parts[0] content = parts[1].strip() if content and not content.startswith(("'", '"')): fixed_line = f"{indent}- '{content}'" fixes_applied.append(f"Quoted special character value: {line.strip()} -> {fixed_line.strip()}") # Fix invalid boolean values elif ': *' in line and not line.strip().startswith('#'): # Replace ": *something*" with ": '*something*'" if not already quoted pattern = r'(:\s*)(\*[^*]+\*)' if not re.search(r'[\'"]', line) and re.search(pattern, line): fixed_line = re.sub(pattern, r"\1'\2'", line) fixes_applied.append(f"Fixed unquoted wildcard value: {line.strip()} -> {fixed_line.strip()}") # Fix missing quotes around values with special characters (but not YAML indicators) elif re.search(r':\s*[*&]', line) and not re.search(r'[\'"]', line): # Don't quote YAML multiline indicators (|, >) if not re.search(r':\s*[|>]\s*$', line): parts = line.split(':', 1) if len(parts) == 2: key = parts[0] value = parts[1].strip() if value and not value.startswith(("'", '"', '[', '{')): fixed_line = f"{key}: '{value}'" fixes_applied.append(f"Quoted special character value: {line.strip()} -> {fixed_line.strip()}") # Fix invalid array syntax elif re.search(r'^\s*\*[^*]+\*\s*$', line): # Standalone *word* or *multiple words* lines should be quoted indent = len(line) - len(line.lstrip()) content = line.strip() fixed_line = f"{' ' * indent}'{content}'" fixes_applied.append(f"Fixed standalone wildcard: {line.strip()} -> {fixed_line.strip()}") fixed_lines.append(fixed_line) result = '\n'.join(fixed_lines) # Additional fixes for common YAML issues # Fix missing spaces after colons colon_fix = re.sub(r':([^\s])', r': \1', result) if colon_fix != result: fixes_applied.append("Added missing spaces after colons") result = colon_fix # Fix multiple spaces after colons space_fix = re.sub(r':\s{2,}', ': ', result) if space_fix != result: fixes_applied.append("Fixed multiple spaces after colons") result = space_fix # Fix incorrect reference format: references: - https://... -> references:\n - https://... ref_fix = re.sub(r'references:\s*-\s*', 'references:\n - ', result) if ref_fix != result: fixes_applied.append("Fixed references array format") result = ref_fix # Fix broken URLs in references (spaces in URLs) url_fix = re.sub(r'https:\s*//nvd\.nist\.gov', 'https://nvd.nist.gov', result) if url_fix != result: fixes_applied.append("Fixed broken URLs in references") result = url_fix # Fix incorrect logsource format: logsource: category: X -> logsource:\n category: X logsource_fix = re.sub(r'logsource:\s*(category|product|service):\s*', r'logsource:\n \1: ', result) if logsource_fix != result: fixes_applied.append("Fixed logsource structure format") result = logsource_fix # Fix incorrect detection format: detection: selection: key: value -> detection:\n selection:\n key: value detection_fix = re.sub(r'detection:\s*(\w+):\s*(\w+):\s*', r'detection:\n \1:\n \2: ', result) if detection_fix != result: fixes_applied.append("Fixed detection structure format") result = detection_fix # Fix detection lines with == operators: detection: selection1: image == *value* -> detection:\n selection1:\n image: '*value*' # This handles compressed syntax with equality operators # Make the pattern more flexible to catch various formats detection_eq_patterns = [ (r'detection:\s*(\w+):\s*(\w+)\s*==\s*(\*[^*\s]+\*)', r'detection:\n \1:\n \2: \'\3\''), (r'detection:\s*(\w+):\s*(\w+)\s*==\s*([^\s]+)', r'detection:\n \1:\n \2: \'\3\''), ] for pattern, replacement in detection_eq_patterns: detection_eq_fix = re.sub(pattern, replacement, result) if detection_eq_fix != result: fixes_applied.append("Fixed detection equality operator syntax") result = detection_eq_fix break # Fix standalone equality operators in detection sections: key == *value* -> key: '*value*' # Also handle lines with multiple keys/values separated by colons and == lines = result.split('\n') eq_fixed_lines = [] for line in lines: original_line = line # Look for pattern: whitespace + key == *value* or key == value if ' == ' in line: # Handle complex patterns like "detection: selection1: image == *value*" if line.strip().startswith('detection:') and ' == ' in line: # Split by colons to handle nested structure parts = line.split(':') if len(parts) >= 3: # This looks like "detection: selection1: image == *value*" base_indent = len(line) - len(line.lstrip()) # Extract the parts detection_part = parts[0].strip() # "detection" selection_part = parts[1].strip() # "selection1" key_value_part = ':'.join(parts[2:]).strip() # "image == *value*" # Parse the key == value part if ' == ' in key_value_part: eq_parts = key_value_part.split(' == ', 1) key = eq_parts[0].strip() value = eq_parts[1].strip() # Quote the value if needed if value.startswith('*') and value.endswith('*') and not value.startswith("'"): value = f"'{value}'" elif not value.startswith(("'", '"', '[', '{')): value = f"'{value}'" # Reconstruct as proper YAML eq_fixed_lines.append(f"{' ' * base_indent}detection:") eq_fixed_lines.append(f"{' ' * (base_indent + 4)}{selection_part}:") eq_fixed_lines.append(f"{' ' * (base_indent + 8)}{key}: {value}") fixes_applied.append(f"Fixed complex detection equality: {selection_part}: {key} == {value}") continue # Handle simpler patterns: " key == value" elif re.match(r'^(\s+)(\w+)\s*==\s*(.+)$', line): match = re.match(r'^(\s+)(\w+)\s*==\s*(.+)$', line) indent = match.group(1) key = match.group(2) value = match.group(3).strip() # Ensure wildcards are quoted if value.startswith('*') and value.endswith('*') and not value.startswith("'"): value = f"'{value}'" elif not value.startswith(("'", '"', '[', '{')): value = f"'{value}'" eq_fixed_lines.append(f"{indent}{key}: {value}") fixes_applied.append(f"Fixed equality operator: {key} == {value}") continue eq_fixed_lines.append(original_line) if len(eq_fixed_lines) != len(lines): result = '\n'.join(eq_fixed_lines) # Fix invalid array-as-value syntax: key: - value -> key:\n - value # This handles cases like "CommandLine: - '*image*'" which should be "CommandLine:\n - '*image*'" lines = result.split('\n') fixed_lines = [] for line in lines: # Look for pattern: whitespace + key: - value if re.match(r'^(\s+)(\w+):\s*-\s*(.+)$', line): match = re.match(r'^(\s+)(\w+):\s*-\s*(.+)$', line) indent = match.group(1) key = match.group(2) value = match.group(3) # Convert to proper array format fixed_lines.append(f"{indent}{key}:") fixed_lines.append(f"{indent} - {value}") fixes_applied.append(f"Fixed array-as-value syntax: {key}: - {value}") else: fixed_lines.append(line) if len(fixed_lines) != len(lines): result = '\n'.join(fixed_lines) # Fix complex nested syntax errors like "selection1: Image: - '*path*': value" # This should be "selection1:\n Image:\n - '*path*': value" complex_fix = re.sub(r'^(\s+)(\w+):\s*(\w+):\s*-\s*(.+)$', r'\1\2:\n\1 \3:\n\1 - \4', result, flags=re.MULTILINE) if complex_fix != result: fixes_applied.append("Fixed complex nested structure syntax") result = complex_fix # Fix incorrect tags format: tags: - T1059.001 -> tags:\n - T1059.001 tags_fix = re.sub(r'tags:\s*-\s*', 'tags:\n - ', result) if tags_fix != result: fixes_applied.append("Fixed tags array format") result = tags_fix # Fix other common single-line array formats for field in ['falsepositives', 'level', 'related']: field_pattern = f'{field}:\\s*-\\s*' field_replacement = f'{field}:\n - ' field_fix = re.sub(field_pattern, field_replacement, result) if field_fix != result: fixes_applied.append(f"Fixed {field} array format") result = field_fix # Fix placeholder UUID if LLM used the example one import uuid placeholder_uuid = '12345678-1234-1234-1234-123456789012' if placeholder_uuid in result: new_uuid = str(uuid.uuid4()) result = result.replace(placeholder_uuid, new_uuid) fixes_applied.append(f"Replaced placeholder UUID with {new_uuid[:8]}...") # Fix orphaned list items (standalone lines starting with -) lines = result.split('\n') fixed_lines = [] for i, line in enumerate(lines): stripped = line.strip() # Check for orphaned list items (lines starting with - but not part of an array) # But be more careful - don't remove items that are properly indented under a parent if (stripped.startswith('- ') and i > 0 and not lines[i-1].strip().endswith(':') and ':' not in stripped and not stripped.startswith('- https://') and # Don't remove reference URLs not stripped.startswith('- attack.') and # Don't remove MITRE ATT&CK tags not re.match(r'- [a-z0-9._-]+$', stripped)): # Don't remove simple tags # Check if this is properly indented under a parent (like tags:) is_properly_indented = False current_indent = len(line) - len(line.lstrip()) # Look backwards to find a parent with less indentation for j in range(i-1, -1, -1): prev_line = lines[j] prev_stripped = prev_line.strip() prev_indent = len(prev_line) - len(prev_line.lstrip()) if prev_stripped and prev_indent < current_indent: # Found a parent with less indentation if prev_stripped.endswith(':'): is_properly_indented = True break else: # This is likely orphaned break if not is_properly_indented: # This is truly orphaned, remove it fixes_applied.append(f"Removed orphaned list item: {stripped}") continue fixed_lines.append(line) result = '\n'.join(fixed_lines) # Final pass: Remove lines that are still malformed and would cause YAML parsing errors lines = result.split('\n') final_lines = [] for line in lines: stripped = line.strip() # Skip lines that have multiple colons in problematic patterns if re.search(r':\s*\w+:\s*-\s*[\'"][^\'":]*[\'"]:\s*', line): # This looks like "key: subkey: - 'value': more_stuff" which is malformed fixes_applied.append(f"Removed malformed nested line: {stripped[:50]}...") continue # Skip lines with invalid YAML mapping structures if re.search(r'^\s*\w+:\s*\w+:\s*-\s*[\'"][^\'":]*[\'"]:\s*\w+', line): fixes_applied.append(f"Removed invalid mapping structure: {stripped[:50]}...") continue final_lines.append(line) if len(final_lines) != len(lines): result = '\n'.join(final_lines) # Log if we made any fixes if fixes_applied: logger.info(f"Applied YAML syntax fixes: {', '.join(fixes_applied)}") # Final YAML structure validation and repair result = self._validate_and_repair_yaml_structure(result, fixes_applied) return result def _validate_and_repair_yaml_structure(self, content: str, fixes_applied: list) -> str: """Use YAML library to validate and repair structural issues.""" try: # First, try to parse the YAML to see if it's valid yaml.safe_load(content) # If we get here, the YAML is valid return content except yaml.YAMLError as e: logger.warning(f"YAML structure validation failed: {e}") # Try to repair common structural issues repaired_content = self._repair_yaml_structure(content, str(e)) # Test if the repair worked try: yaml.safe_load(repaired_content) fixes_applied.append("Repaired YAML document structure") logger.info("Successfully repaired YAML structure") return repaired_content except yaml.YAMLError as e2: logger.warning(f"YAML repair attempt failed: {e2}") # Try a more aggressive repair before falling back to minimal rule aggressive_repair = self._aggressive_yaml_repair(content) try: yaml.safe_load(aggressive_repair) fixes_applied.append("Applied aggressive YAML repair") logger.info("Successfully repaired YAML with aggressive method") return aggressive_repair except yaml.YAMLError as e3: logger.warning(f"Aggressive repair also failed: {e3}") # Last resort: try to build a minimal valid SIGMA rule return self._build_minimal_valid_rule(content, fixes_applied) def _repair_yaml_structure(self, content: str, error_msg: str) -> str: """Attempt to repair common YAML structural issues.""" lines = content.split('\n') repaired_lines = [] # Track indentation levels to detect issues expected_indent = 0 in_detection = False detection_indent = 0 in_tags = False tags_indent = 0 for i, line in enumerate(lines): stripped = line.strip() current_indent = len(line) - len(line.lstrip()) # Skip empty lines if not stripped: repaired_lines.append(line) continue # Track if we're in the tags section if stripped.startswith('tags:'): in_tags = True tags_indent = current_indent repaired_lines.append(line) continue elif in_tags and current_indent <= tags_indent and not stripped.startswith('-'): # We've left the tags section in_tags = False # Fix tags section indentation if in_tags and stripped.startswith('-'): # Ensure proper indentation for tag items if current_indent <= tags_indent: corrected_line = ' ' * (tags_indent + 2) + stripped repaired_lines.append(corrected_line) continue # Track if we're in the detection section if stripped.startswith('detection:'): in_detection = True detection_indent = current_indent repaired_lines.append(line) continue elif in_detection and current_indent <= detection_indent and not stripped.startswith(('condition:', 'timeframe:')): # We've left the detection section in_detection = False # Fix indentation issues in detection section if in_detection: # Ensure proper indentation for detection subsections if stripped.startswith(('selection', 'filter', 'condition')): # This should be indented under detection if current_indent <= detection_indent: corrected_line = ' ' * (detection_indent + 4) + stripped repaired_lines.append(corrected_line) continue elif current_indent > detection_indent + 4: # This might be a detection field that needs proper indentation if ':' in stripped and not stripped.startswith('-'): # This looks like a field under a selection if i > 0 and 'selection' in lines[i-1]: corrected_line = ' ' * (detection_indent + 8) + stripped repaired_lines.append(corrected_line) continue # Fix logsource section indentation if stripped.startswith('logsource:'): # Logsource should be at root level (no indentation) if current_indent > 0: corrected_line = stripped repaired_lines.append(corrected_line) continue elif line.lstrip().startswith(('category:', 'product:', 'service:')) and i > 0: # These should be indented under logsource prev_line = lines[i-1].strip() if prev_line.startswith('logsource:') or any('logsource' in repaired_lines[j] for j in range(max(0, len(repaired_lines)-5), len(repaired_lines))): corrected_line = ' ' + stripped repaired_lines.append(corrected_line) continue # Fix lines that start with wrong indentation if ':' in stripped and not stripped.startswith('-'): # This is a key-value pair key = stripped.split(':')[0].strip() # Top-level keys should not be indented if key in ['title', 'id', 'status', 'description', 'author', 'date', 'references', 'tags', 'logsource', 'detection', 'falsepositives', 'level']: if current_indent > 0: corrected_line = stripped repaired_lines.append(corrected_line) continue repaired_lines.append(line) return '\n'.join(repaired_lines) def _aggressive_yaml_repair(self, content: str) -> str: """Aggressive YAML repair that reconstructs the document structure.""" lines = content.split('\n') # Extract key components title = "Generated SIGMA Rule" rule_id = "00000000-0000-0000-0000-000000000000" description = "Generated detection rule" author = "AI Generated" date = "2025/01/16" references = [] tags = [] logsource_category = "process_creation" logsource_product = "windows" detection_rules = [] condition = "selection" level = "medium" # Parse existing content for i, line in enumerate(lines): stripped = line.strip() if stripped.startswith('title:'): title = stripped.split(':', 1)[1].strip().strip('"\'') elif stripped.startswith('id:'): rule_id = stripped.split(':', 1)[1].strip().strip('"\'') elif stripped.startswith('description:'): description = stripped.split(':', 1)[1].strip().strip('"\'') elif stripped.startswith('author:'): author = stripped.split(':', 1)[1].strip().strip('"\'') elif stripped.startswith('date:'): date = stripped.split(':', 1)[1].strip().strip('"\'') elif stripped.startswith('level:'): level = stripped.split(':', 1)[1].strip().strip('"\'') elif stripped.startswith('condition:'): condition = stripped.split(':', 1)[1].strip().strip('"\'') elif stripped.startswith('- http'): references.append(stripped[2:].strip()) elif stripped.startswith('- attack.') or stripped.startswith('- cve-') or stripped.startswith('- exploit.') or stripped.startswith('- poc.') or stripped.startswith('- cwe.'): tags.append(stripped[2:].strip()) elif 'category:' in stripped: logsource_category = stripped.split(':', 1)[1].strip().strip('"\'') elif 'product:' in stripped: logsource_product = stripped.split(':', 1)[1].strip().strip('"\'') # Build a clean YAML structure yaml_content = f"""title: '{title}' id: {rule_id} status: experimental description: '{description}' author: '{author}' date: {date} references:""" if references: for ref in references: yaml_content += f"\n - {ref}" else: yaml_content += "\n - https://example.com" yaml_content += "\ntags:" if tags: for tag in tags: yaml_content += f"\n - {tag}" else: yaml_content += "\n - unknown" yaml_content += f""" logsource: category: {logsource_category} product: {logsource_product} detection: selection: Image: '*' condition: {condition} level: {level}""" return yaml_content def _build_minimal_valid_rule(self, content: str, fixes_applied: list) -> str: """Build a minimal valid SIGMA rule from the content.""" lines = content.split('\n') # Extract key components title = "Unknown SIGMA Rule" rule_id = "00000000-0000-0000-0000-000000000000" description = "Generated SIGMA rule" for line in lines: stripped = line.strip() if stripped.startswith('title:'): title = stripped.split(':', 1)[1].strip().strip('"\'') elif stripped.startswith('id:'): rule_id = stripped.split(':', 1)[1].strip().strip('"\'') elif stripped.startswith('description:'): description = stripped.split(':', 1)[1].strip().strip('"\'') # Build minimal valid rule minimal_rule = f"""title: '{title}' id: {rule_id} status: experimental description: '{description}' author: 'AI Generated' date: 2025/01/16 references: - https://example.com logsource: category: process_creation detection: selection: Image: '*' condition: selection level: medium""" fixes_applied.append("Built minimal valid SIGMA rule structure") logger.warning("Generated minimal valid SIGMA rule as fallback") return minimal_rule def _fix_hallucinated_cve_id(self, rule_content: str, correct_cve_id: str) -> str: """Detect and fix hallucinated CVE IDs in the generated rule.""" import re # Pattern to match CVE IDs (CVE-YYYY-NNNNN format) cve_pattern = r'CVE-\d{4}-\d{4,7}' # Find all CVE IDs in the rule content found_cves = re.findall(cve_pattern, rule_content, re.IGNORECASE) if found_cves: # Check if any found CVE is different from the correct one hallucinated_cves = [cve for cve in found_cves if cve.upper() != correct_cve_id.upper()] if hallucinated_cves: logger.error(f"CRITICAL: LLM hallucinated CVE IDs: {hallucinated_cves}, expected: {correct_cve_id}") logger.error(f"This indicates the LLM is not following the prompt correctly!") # Replace all hallucinated CVE IDs with the correct one corrected_content = rule_content for hallucinated_cve in set(hallucinated_cves): # Use set to avoid duplicates corrected_content = re.sub( re.escape(hallucinated_cve), correct_cve_id, corrected_content, flags=re.IGNORECASE ) logger.info(f"Successfully corrected hallucinated CVE IDs to {correct_cve_id}") return corrected_content else: logger.info(f"CVE ID validation passed: found correct {correct_cve_id}") else: # No CVE ID found in rule - this might be an issue, but let's add it logger.warning(f"No CVE ID found in generated rule for {correct_cve_id}, this might need manual review") return rule_content def _inject_cve_id_into_rule(self, rule_content: str, cve_id: str) -> str: """Inject CVE ID into a rule that lacks it.""" if not rule_content: logger.warning(f"Empty rule content for {cve_id}, cannot inject CVE ID") return rule_content lines = rule_content.split('\n') modified_lines = [] for i, line in enumerate(lines): stripped = line.strip() # Fix title line if it has placeholders if stripped.startswith('title:'): if '{cve_id}' in line.lower() or '{cve_description}' in line.lower(): # Replace with a proper title modified_lines.append(f"title: 'Detection of {cve_id} exploitation'") elif cve_id not in line: # Add CVE ID to existing title title_text = line.split(':', 1)[1].strip(' \'"') modified_lines.append(f"title: '{cve_id}: {title_text}'") else: modified_lines.append(line) # Fix references section if it has placeholders elif stripped.startswith('- https://nvd.nist.gov/vuln/detail/') and '{cve_id}' in line: modified_lines.append(f" - https://nvd.nist.gov/vuln/detail/{cve_id}") # Skip lines with template placeholders elif any(placeholder in line.lower() for placeholder in ['{cve_id}', '{cve_description}', '{poc_content}']): continue else: modified_lines.append(line) result = '\n'.join(modified_lines) logger.info(f"Injected CVE ID {cve_id} into rule") return result async def enhance_existing_rule(self, existing_rule: str, poc_content: str, cve_id: str) -> Optional[str]: """ Enhance an existing SIGMA rule with PoC analysis. Args: existing_rule: Existing SIGMA rule YAML poc_content: PoC code content cve_id: CVE identifier Returns: Enhanced SIGMA rule or None if failed """ if not self.is_available(): return None try: system_message = """You are a SIGMA rule enhancement expert. Analyze the following PoC code and enhance the existing SIGMA rule with more specific detection patterns. **Task:** Enhance the existing rule by: 1. Adding more specific detection patterns found in the PoC 2. Improving the condition logic 3. Adding relevant tags or MITRE ATT&CK mappings 4. Keeping the rule structure intact but making it more effective Output ONLY the enhanced SIGMA rule in valid YAML format.""" user_template = """**CVE ID:** {cve_id} **PoC Code:** ``` {poc_content} ``` **Existing SIGMA Rule:** ```yaml {existing_rule} ```""" prompt = ChatPromptTemplate.from_messages([ SystemMessage(content=system_message), HumanMessage(content=user_template) ]) chain = prompt | self.llm | self.output_parser response = await chain.ainvoke({ "cve_id": cve_id, "poc_content": poc_content[:3000], "existing_rule": existing_rule }) enhanced_rule = self._extract_sigma_rule(response) enhanced_rule = self._post_process_sigma_rule(enhanced_rule) logger.info(f"Successfully enhanced SIGMA rule for {cve_id}") return enhanced_rule except Exception as e: logger.error(f"Failed to enhance SIGMA rule for {cve_id}: {e}") return None def validate_sigma_rule(self, rule_content: str, expected_cve_id: str = None) -> bool: """Validate that the generated rule follows SIGMA specification.""" try: # Parse as YAML parsed = yaml.safe_load(rule_content) if not isinstance(parsed, dict): logger.warning("Rule must be a YAML dictionary") return False # Check MANDATORY fields per SIGMA spec mandatory_fields = ['title', 'logsource', 'detection'] for field in mandatory_fields: if field not in parsed: logger.warning(f"Missing mandatory field: {field}") return False # Validate title title = parsed.get('title', '') if not isinstance(title, str) or len(title) > 256: logger.warning("Title must be string ≤256 characters") return False # Validate logsource structure logsource = parsed.get('logsource', {}) if not isinstance(logsource, dict): logger.warning("Logsource must be a dictionary") return False # Validate detection structure detection = parsed.get('detection', {}) if not isinstance(detection, dict): logger.warning("Detection must be a dictionary") return False # Check for condition (can be in detection or at root) has_condition = 'condition' in detection or 'condition' in parsed if not has_condition: logger.warning("Missing condition field") return False # Check for at least one selection selection_found = any(key.startswith('selection') or key in ['selection', 'keywords', 'filter'] for key in detection.keys() if key != 'condition') if not selection_found: logger.warning("Detection must have at least one selection") return False # Validate status if present status = parsed.get('status') if status and status not in ['stable', 'test', 'experimental', 'deprecated', 'unsupported']: logger.warning(f"Invalid status: {status}") return False # Additional validation: Check for correct CVE ID if provided if expected_cve_id: import re cve_pattern = r'CVE-\d{4}-\d{4,7}' found_cves = re.findall(cve_pattern, rule_content, re.IGNORECASE) if found_cves: # Check if all found CVE IDs match the expected one wrong_cves = [cve for cve in found_cves if cve.upper() != expected_cve_id.upper()] if wrong_cves: logger.warning(f"Rule contains wrong CVE IDs: {wrong_cves}, expected {expected_cve_id}") return False else: logger.warning(f"Rule does not contain expected CVE ID: {expected_cve_id}") # Don't fail validation for missing CVE ID, just warn logger.info("SIGMA rule validation passed") return True except yaml.YAMLError as e: error_msg = str(e) if "alias" in error_msg.lower() and "*" in error_msg: logger.warning(f"YAML alias syntax error (likely unquoted wildcard): {e}") elif "expected" in error_msg.lower(): logger.warning(f"YAML structure error: {e}") else: logger.warning(f"YAML parsing error: {e}") return False except Exception as e: logger.warning(f"Rule validation error: {e}") return False @classmethod def get_available_providers(cls) -> List[Dict[str, Any]]: """Get list of available LLM providers and their configuration status.""" providers = [] for provider_name, provider_info in cls.SUPPORTED_PROVIDERS.items(): env_key = provider_info.get('env_key', '') api_key_configured = bool(os.getenv(env_key)) available = api_key_configured or provider_name == 'ollama' providers.append({ 'name': provider_name, 'models': provider_info.get('models', []), 'default_model': provider_info.get('default_model', ''), 'env_key': env_key, 'api_key_configured': api_key_configured, 'available': available }) return providers def switch_provider(self, provider: str, model: str = None): """Switch to a different LLM provider and model.""" if provider not in self.SUPPORTED_PROVIDERS: raise ValueError(f"Unsupported provider: {provider}") self.provider = provider self.model = model or self._get_default_model(provider) self._initialize_llm() logger.info(f"Switched to provider: {provider} with model: {self.model}") def _check_ollama_model_available(self, base_url: str, model: str) -> bool: """Check if an Ollama model is available locally""" try: import requests response = requests.get(f"{base_url}/api/tags", timeout=10) if response.status_code == 200: data = response.json() models = data.get('models', []) for m in models: if m.get('name', '').startswith(model + ':') or m.get('name') == model: return True return False except Exception as e: logger.error(f"Error checking Ollama models: {e}") return False def _pull_ollama_model(self, base_url: str, model: str) -> bool: """Pull an Ollama model""" try: import requests import json # Use the pull API endpoint payload = {"name": model} response = requests.post( f"{base_url}/api/pull", json=payload, timeout=300, # 5 minutes timeout for model download stream=True ) if response.status_code == 200: # Stream the response to monitor progress for line in response.iter_lines(): if line: try: data = json.loads(line.decode('utf-8')) if data.get('status'): logger.info(f"Ollama pull progress: {data.get('status')}") if data.get('error'): logger.error(f"Ollama pull error: {data.get('error')}") return False except json.JSONDecodeError: continue return True else: logger.error(f"Failed to pull model {model}: HTTP {response.status_code}") return False except Exception as e: logger.error(f"Error pulling Ollama model {model}: {e}") return False def _get_ollama_available_models(self) -> List[str]: """Get list of available Ollama models""" try: import requests base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434') response = requests.get(f"{base_url}/api/tags", timeout=10) if response.status_code == 200: data = response.json() models = data.get('models', []) return [m.get('name', '') for m in models if m.get('name')] return [] except Exception as e: logger.error(f"Error getting Ollama models: {e}") return []