""" LangChain-based LLM client for enhanced SIGMA rule generation. Supports multiple LLM providers: OpenAI, Anthropic, and local models. """ import os import logging from typing import Optional, Dict, Any, List from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.prompts import ChatPromptTemplate from langchain_openai import ChatOpenAI from langchain_anthropic import ChatAnthropic from langchain_community.llms import Ollama from langchain_core.output_parsers import StrOutputParser import yaml logger = logging.getLogger(__name__) class LLMClient: """Multi-provider LLM client for SIGMA rule generation using LangChain.""" SUPPORTED_PROVIDERS = { 'openai': { 'models': ['gpt-4o', 'gpt-4o-mini', 'gpt-4-turbo', 'gpt-3.5-turbo'], 'env_key': 'OPENAI_API_KEY', 'default_model': 'gpt-4o-mini' }, 'anthropic': { 'models': ['claude-3-5-sonnet-20241022', 'claude-3-haiku-20240307', 'claude-3-opus-20240229'], 'env_key': 'ANTHROPIC_API_KEY', 'default_model': 'claude-3-5-sonnet-20241022' }, 'ollama': { 'models': ['llama3.2', 'codellama', 'mistral', 'llama2'], 'env_key': 'OLLAMA_BASE_URL', 'default_model': 'llama3.2' } } def __init__(self, provider: str = None, model: str = None): """Initialize LLM client with specified provider and model.""" self.provider = provider or self._detect_provider() self.model = model or self._get_default_model(self.provider) self.llm = None self.output_parser = StrOutputParser() self._initialize_llm() def _detect_provider(self) -> str: """Auto-detect available LLM provider based on environment variables.""" # Check for API keys in order of preference if os.getenv('ANTHROPIC_API_KEY'): return 'anthropic' elif os.getenv('OPENAI_API_KEY'): return 'openai' elif os.getenv('OLLAMA_BASE_URL'): return 'ollama' else: # Default to OpenAI if no keys found return 'openai' def _get_default_model(self, provider: str) -> str: """Get default model for the specified provider.""" return self.SUPPORTED_PROVIDERS.get(provider, {}).get('default_model', 'gpt-4o-mini') def _initialize_llm(self): """Initialize the LLM based on provider and model.""" try: if self.provider == 'openai': api_key = os.getenv('OPENAI_API_KEY') if not api_key: logger.warning("OpenAI API key not found") return self.llm = ChatOpenAI( model=self.model, api_key=api_key, temperature=0.1, max_tokens=2000 ) elif self.provider == 'anthropic': api_key = os.getenv('ANTHROPIC_API_KEY') if not api_key: logger.warning("Anthropic API key not found") return self.llm = ChatAnthropic( model=self.model, api_key=api_key, temperature=0.1, max_tokens=2000 ) elif self.provider == 'ollama': base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434') # Check if model is available, if not try to pull it if not self._check_ollama_model_available(base_url, self.model): logger.info(f"Model {self.model} not found, attempting to pull...") if self._pull_ollama_model(base_url, self.model): logger.info(f"Successfully pulled model {self.model}") else: logger.error(f"Failed to pull model {self.model}") return self.llm = Ollama( model=self.model, base_url=base_url, temperature=0.1 ) if self.llm: logger.info(f"LLM client initialized: {self.provider} with model {self.model}") else: logger.error(f"Failed to initialize LLM client for provider: {self.provider}") except Exception as e: logger.error(f"Error initializing LLM client: {e}") self.llm = None def is_available(self) -> bool: """Check if LLM client is available and configured.""" return self.llm is not None def get_provider_info(self) -> Dict[str, Any]: """Get information about the current provider and configuration.""" provider_info = self.SUPPORTED_PROVIDERS.get(self.provider, {}) # For Ollama, get actually available models available_models = provider_info.get('models', []) if self.provider == 'ollama': ollama_models = self._get_ollama_available_models() if ollama_models: available_models = ollama_models return { 'provider': self.provider, 'model': self.model, 'available': self.is_available(), 'supported_models': provider_info.get('models', []), 'available_models': available_models, 'env_key': provider_info.get('env_key', ''), 'api_key_configured': bool(os.getenv(provider_info.get('env_key', ''))) } async def generate_sigma_rule(self, cve_id: str, poc_content: str, cve_description: str, existing_rule: Optional[str] = None) -> Optional[str]: """ Generate or enhance a SIGMA rule using the configured LLM. Args: cve_id: CVE identifier poc_content: Proof-of-concept code content from GitHub cve_description: CVE description from NVD existing_rule: Optional existing SIGMA rule to enhance Returns: Generated SIGMA rule YAML content or None if failed """ if not self.is_available(): logger.warning("LLM client not available") return None try: # Create the prompt template prompt = self._build_sigma_generation_prompt( cve_id, poc_content, cve_description, existing_rule ) # Create the chain chain = prompt | self.llm | self.output_parser # Debug: Log what we're sending to the LLM input_data = { "cve_id": cve_id, "poc_content": poc_content[:4000], # Truncate if too long "cve_description": cve_description, "existing_rule": existing_rule or "None" } logger.info(f"Sending to LLM for {cve_id}: CVE={cve_id}, Description length={len(cve_description)}, PoC length={len(poc_content)}") # Generate the response response = await chain.ainvoke(input_data) # Debug: Log raw LLM response logger.info(f"Raw LLM response for {cve_id}: {response[:200]}...") # Extract the SIGMA rule from response sigma_rule = self._extract_sigma_rule(response) # Post-process to ensure clean YAML sigma_rule = self._post_process_sigma_rule(sigma_rule) # Debug: Log final processed rule logger.info(f"Final processed rule for {cve_id}: {sigma_rule[:200]}...") logger.info(f"Successfully generated SIGMA rule for {cve_id} using {self.provider}") return sigma_rule except Exception as e: logger.error(f"Failed to generate SIGMA rule for {cve_id} using {self.provider}: {e}") return None def _build_sigma_generation_prompt(self, cve_id: str, poc_content: str, cve_description: str, existing_rule: Optional[str] = None) -> ChatPromptTemplate: """Build the prompt template for SIGMA rule generation.""" system_message = """You are a cybersecurity expert specializing in SIGMA rule creation following the official SIGMA specification. **CRITICAL: You must follow the exact SIGMA specification format:** 1. **YAML Structure Requirements:** - Use UTF-8 encoding with LF line breaks - Indent with 4 spaces (no tabs) - Use lowercase keys only - Use single quotes for string values - No quotes for numeric values - Follow proper YAML syntax 2. **MANDATORY Fields (must include):** - title: Brief description (max 256 chars) - logsource: Log data source specification - detection: Search identifiers and conditions - condition: How detection elements combine 3. **RECOMMENDED Fields:** - id: Unique UUID - status: 'experimental' (for new rules) - description: Detailed explanation - author: 'AI Generated' - date: Current date (YYYY/MM/DD) - references: Array with CVE link - tags: MITRE ATT&CK techniques 4. **Detection Structure:** - Use selection blocks (selection, selection1, etc.) - Condition references these selections - Use proper field names (Image, CommandLine, ProcessName, etc.) - Support wildcards (*) and value lists **ABSOLUTE REQUIREMENTS:** - Output ONLY valid YAML - NO explanatory text before or after - NO comments or instructions - NO markdown formatting or code blocks - NEVER repeat the input prompt or template - NEVER include variables like {cve_id} or {poc_content} - NO "Human:", "CVE ID:", "Description:" headers - NO "Analyze this" or "Output only" text - Start IMMEDIATELY with 'title:' - End with the last YAML line only - Ensure perfect YAML syntax **STRUCTURE REQUIREMENTS:** - title: Descriptive title that MUST include the exact CVE ID provided by the user - id: Generate a unique UUID (not '12345678-1234-1234-1234-123456789012') - status: experimental - description: Specific description based on CVE and PoC analysis - author: 'AI Generated' - date: Current date (2025/01/11) - references: Include the EXACT CVE URL with the CVE ID provided by the user - tags: Relevant MITRE ATT&CK techniques based on PoC analysis - logsource: Appropriate category based on exploit type - detection: Specific indicators from PoC analysis (NOT generic examples) - condition: Logic connecting the detection selections **CRITICAL RULES:** 1. You MUST use the EXACT CVE ID provided in the user input - NEVER generate a different CVE ID 2. Analyze the provided CVE and PoC content to create SPECIFIC detection patterns 3. DO NOT hallucinate or invent CVE IDs from your training data 4. Use the CVE ID exactly as provided in the title and references""" if existing_rule: user_template = """CVE ID: {cve_id} CVE Description: {cve_description} PoC Code: {poc_content} Existing SIGMA Rule: {existing_rule} Enhance this rule with PoC insights. Output only valid SIGMA YAML starting with 'title:'.""" else: user_template = """CREATE A SPECIFIC SIGMA RULE FOR THIS EXACT CVE: **MANDATORY CVE ID TO USE: {cve_id}** **CVE Description: {cve_description}** **Proof-of-Concept Code Analysis:** {poc_content} **CRITICAL REQUIREMENTS:** 1. Use EXACTLY this CVE ID in the title: {cve_id} 2. Use EXACTLY this CVE URL in references: https://nvd.nist.gov/vuln/detail/{cve_id} 3. Analyze the CVE description to understand the vulnerability type 4. Extract specific indicators from the PoC code (files, processes, commands, network patterns) 5. Create detection logic based on the actual exploit behavior 6. Use relevant logsource category (process_creation, file_event, network_connection, etc.) 7. Include appropriate MITRE ATT&CK tags based on the exploit techniques **IMPORTANT: You MUST use the exact CVE ID "{cve_id}" - do NOT generate a different CVE ID!** Output ONLY valid SIGMA YAML starting with 'title:' that includes the exact CVE ID {cve_id}.""" return ChatPromptTemplate.from_messages([ SystemMessage(content=system_message), HumanMessage(content=user_template) ]) def _extract_sigma_rule(self, response_text: str) -> str: """Extract and clean SIGMA rule YAML from LLM response.""" lines = response_text.split('\n') yaml_lines = [] in_yaml_block = False found_title = False for line in lines: stripped = line.strip() # Skip code block markers if stripped.startswith('```'): if stripped.startswith('```yaml'): in_yaml_block = True elif stripped == '```' and in_yaml_block: break continue # Skip obvious non-YAML content if not in_yaml_block and not found_title: if not stripped.startswith('title:'): # Skip explanatory text and prompt artifacts skip_phrases = [ 'please note', 'this rule', 'you should', 'analysis:', 'explanation:', 'based on', 'the following', 'here is', 'note that', 'important:', 'remember', 'this is a', 'make sure to', 'you can modify', 'adjust the', 'human:', 'cve id:', 'cve description:', 'poc code:', 'exploit code:', 'analyze this', 'create a', 'output only' ] if any(phrase in stripped.lower() for phrase in skip_phrases): continue # Skip template variables and prompt artifacts if '{' in stripped and '}' in stripped: continue # Skip lines that are clearly not YAML structure if stripped and not ':' in stripped and len(stripped) > 20: continue # Start collecting when we find title or are in YAML block if stripped.startswith('title:') or in_yaml_block: found_title = True in_yaml_block = True # Skip explanatory comments if stripped.startswith('#') and ('please' in stripped.lower() or 'note' in stripped.lower()): continue yaml_lines.append(line) # Stop if we encounter obvious non-YAML after starting elif found_title: # Stop at explanatory text after the rule stop_phrases = [ 'please note', 'this rule should', 'make sure to', 'you can modify', 'adjust the', 'also, this is', 'based on the analysis', 'the rule above' ] if any(phrase in stripped.lower() for phrase in stop_phrases): break # Stop at lines without colons that aren't indented (likely explanations) if stripped and not stripped.startswith(' ') and ':' not in stripped and '-' not in stripped: break if not yaml_lines: # Fallback: look for any line with YAML-like structure for line in lines: if ':' in line and not line.strip().startswith('#'): yaml_lines.append(line) return '\n'.join(yaml_lines).strip() def _post_process_sigma_rule(self, rule_content: str) -> str: """Post-process SIGMA rule to ensure clean YAML format.""" lines = rule_content.split('\n') cleaned_lines = [] for line in lines: stripped = line.strip() # Skip obvious non-YAML content and prompt artifacts if any(phrase in stripped.lower() for phrase in [ 'please note', 'you should replace', 'this is a proof-of-concept', 'please make sure', 'note that', 'important:', 'remember to', 'analysis shows', 'based on the', 'the rule above', 'this rule', 'human:', 'cve id:', 'cve description:', 'poc code:', 'exploit code:', 'analyze this', 'create a', 'output only', 'generate a' ]): continue # Skip template variables if '{' in stripped and '}' in stripped: continue # Skip lines that look like explanations if stripped and not ':' in stripped and not stripped.startswith('-') and not stripped.startswith(' '): # This might be explanatory text, skip it if any(word in stripped.lower() for word in ['rule', 'detect', 'should', 'will', 'can', 'may']): continue # Skip empty explanatory sections if stripped.lower() in ['explanation:', 'analysis:', 'notes:', 'important:', '']: continue cleaned_lines.append(line) return '\n'.join(cleaned_lines).strip() async def enhance_existing_rule(self, existing_rule: str, poc_content: str, cve_id: str) -> Optional[str]: """ Enhance an existing SIGMA rule with PoC analysis. Args: existing_rule: Existing SIGMA rule YAML poc_content: PoC code content cve_id: CVE identifier Returns: Enhanced SIGMA rule or None if failed """ if not self.is_available(): return None try: system_message = """You are a SIGMA rule enhancement expert. Analyze the following PoC code and enhance the existing SIGMA rule with more specific detection patterns. **Task:** Enhance the existing rule by: 1. Adding more specific detection patterns found in the PoC 2. Improving the condition logic 3. Adding relevant tags or MITRE ATT&CK mappings 4. Keeping the rule structure intact but making it more effective Output ONLY the enhanced SIGMA rule in valid YAML format.""" user_template = """**CVE ID:** {cve_id} **PoC Code:** ``` {poc_content} ``` **Existing SIGMA Rule:** ```yaml {existing_rule} ```""" prompt = ChatPromptTemplate.from_messages([ SystemMessage(content=system_message), HumanMessage(content=user_template) ]) chain = prompt | self.llm | self.output_parser response = await chain.ainvoke({ "cve_id": cve_id, "poc_content": poc_content[:3000], "existing_rule": existing_rule }) enhanced_rule = self._extract_sigma_rule(response) enhanced_rule = self._post_process_sigma_rule(enhanced_rule) logger.info(f"Successfully enhanced SIGMA rule for {cve_id}") return enhanced_rule except Exception as e: logger.error(f"Failed to enhance SIGMA rule for {cve_id}: {e}") return None def validate_sigma_rule(self, rule_content: str) -> bool: """Validate that the generated rule follows SIGMA specification.""" try: # Parse as YAML parsed = yaml.safe_load(rule_content) if not isinstance(parsed, dict): logger.warning("Rule must be a YAML dictionary") return False # Check MANDATORY fields per SIGMA spec mandatory_fields = ['title', 'logsource', 'detection'] for field in mandatory_fields: if field not in parsed: logger.warning(f"Missing mandatory field: {field}") return False # Validate title title = parsed.get('title', '') if not isinstance(title, str) or len(title) > 256: logger.warning("Title must be string ≤256 characters") return False # Validate logsource structure logsource = parsed.get('logsource', {}) if not isinstance(logsource, dict): logger.warning("Logsource must be a dictionary") return False # Validate detection structure detection = parsed.get('detection', {}) if not isinstance(detection, dict): logger.warning("Detection must be a dictionary") return False # Check for condition (can be in detection or at root) has_condition = 'condition' in detection or 'condition' in parsed if not has_condition: logger.warning("Missing condition field") return False # Check for at least one selection selection_found = any(key.startswith('selection') or key in ['selection', 'keywords', 'filter'] for key in detection.keys() if key != 'condition') if not selection_found: logger.warning("Detection must have at least one selection") return False # Validate status if present status = parsed.get('status') if status and status not in ['stable', 'test', 'experimental', 'deprecated', 'unsupported']: logger.warning(f"Invalid status: {status}") return False logger.info("SIGMA rule validation passed") return True except yaml.YAMLError as e: logger.warning(f"YAML parsing error: {e}") return False except Exception as e: logger.warning(f"Rule validation error: {e}") return False @classmethod def get_available_providers(cls) -> List[Dict[str, Any]]: """Get list of available LLM providers and their configuration status.""" providers = [] for provider_name, provider_info in cls.SUPPORTED_PROVIDERS.items(): env_key = provider_info.get('env_key', '') api_key_configured = bool(os.getenv(env_key)) providers.append({ 'name': provider_name, 'models': provider_info.get('models', []), 'default_model': provider_info.get('default_model', ''), 'env_key': env_key, 'api_key_configured': api_key_configured, 'available': api_key_configured or provider_name == 'ollama' }) return providers def switch_provider(self, provider: str, model: str = None): """Switch to a different LLM provider and model.""" if provider not in self.SUPPORTED_PROVIDERS: raise ValueError(f"Unsupported provider: {provider}") self.provider = provider self.model = model or self._get_default_model(provider) self._initialize_llm() logger.info(f"Switched to provider: {provider} with model: {self.model}") def _check_ollama_model_available(self, base_url: str, model: str) -> bool: """Check if an Ollama model is available locally""" try: import requests response = requests.get(f"{base_url}/api/tags", timeout=10) if response.status_code == 200: data = response.json() models = data.get('models', []) for m in models: if m.get('name', '').startswith(model + ':') or m.get('name') == model: return True return False except Exception as e: logger.error(f"Error checking Ollama models: {e}") return False def _pull_ollama_model(self, base_url: str, model: str) -> bool: """Pull an Ollama model""" try: import requests import json # Use the pull API endpoint payload = {"name": model} response = requests.post( f"{base_url}/api/pull", json=payload, timeout=300, # 5 minutes timeout for model download stream=True ) if response.status_code == 200: # Stream the response to monitor progress for line in response.iter_lines(): if line: try: data = json.loads(line.decode('utf-8')) if data.get('status'): logger.info(f"Ollama pull progress: {data.get('status')}") if data.get('error'): logger.error(f"Ollama pull error: {data.get('error')}") return False except json.JSONDecodeError: continue return True else: logger.error(f"Failed to pull model {model}: HTTP {response.status_code}") return False except Exception as e: logger.error(f"Error pulling Ollama model {model}: {e}") return False def _get_ollama_available_models(self) -> List[str]: """Get list of available Ollama models""" try: import requests base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434') response = requests.get(f"{base_url}/api/tags", timeout=10) if response.status_code == 200: data = response.json() models = data.get('models', []) return [m.get('name', '') for m in models if m.get('name')] return [] except Exception as e: logger.error(f"Error getting Ollama models: {e}") return []