""" LangChain-based LLM client for enhanced SIGMA rule generation. Supports multiple LLM providers: OpenAI, Anthropic, and local models. """ import os import logging from typing import Optional, Dict, Any, List from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.prompts import ChatPromptTemplate from langchain_openai import ChatOpenAI from langchain_anthropic import ChatAnthropic from langchain_community.llms import Ollama from langchain_core.output_parsers import StrOutputParser import yaml logger = logging.getLogger(__name__) class LLMClient: """Multi-provider LLM client for SIGMA rule generation using LangChain.""" SUPPORTED_PROVIDERS = { 'openai': { 'models': ['gpt-4o', 'gpt-4o-mini', 'gpt-4-turbo', 'gpt-3.5-turbo'], 'env_key': 'OPENAI_API_KEY', 'default_model': 'gpt-4o-mini' }, 'anthropic': { 'models': ['claude-3-5-sonnet-20241022', 'claude-3-haiku-20240307', 'claude-3-opus-20240229'], 'env_key': 'ANTHROPIC_API_KEY', 'default_model': 'claude-3-5-sonnet-20241022' }, 'ollama': { 'models': ['llama3.2', 'codellama', 'mistral', 'llama2'], 'env_key': 'OLLAMA_BASE_URL', 'default_model': 'llama3.2' } } def __init__(self, provider: str = None, model: str = None): """Initialize LLM client with specified provider and model.""" self.provider = provider or self._detect_provider() self.model = model or self._get_default_model(self.provider) self.llm = None self.output_parser = StrOutputParser() self._initialize_llm() def _detect_provider(self) -> str: """Auto-detect available LLM provider based on environment variables.""" # Check for API keys in order of preference if os.getenv('ANTHROPIC_API_KEY'): return 'anthropic' elif os.getenv('OPENAI_API_KEY'): return 'openai' elif os.getenv('OLLAMA_BASE_URL'): return 'ollama' else: # Default to OpenAI if no keys found return 'openai' def _get_default_model(self, provider: str) -> str: """Get default model for the specified provider.""" return self.SUPPORTED_PROVIDERS.get(provider, {}).get('default_model', 'gpt-4o-mini') def _initialize_llm(self): """Initialize the LLM based on provider and model.""" try: if self.provider == 'openai': api_key = os.getenv('OPENAI_API_KEY') if not api_key: logger.warning("OpenAI API key not found") return self.llm = ChatOpenAI( model=self.model, api_key=api_key, temperature=0.1, max_tokens=2000 ) elif self.provider == 'anthropic': api_key = os.getenv('ANTHROPIC_API_KEY') if not api_key: logger.warning("Anthropic API key not found") return self.llm = ChatAnthropic( model=self.model, api_key=api_key, temperature=0.1, max_tokens=2000 ) elif self.provider == 'ollama': base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434') # Check if model is available, if not try to pull it if not self._check_ollama_model_available(base_url, self.model): logger.info(f"Model {self.model} not found, attempting to pull...") if self._pull_ollama_model(base_url, self.model): logger.info(f"Successfully pulled model {self.model}") else: logger.error(f"Failed to pull model {self.model}") return self.llm = Ollama( model=self.model, base_url=base_url, temperature=0.1 ) if self.llm: logger.info(f"LLM client initialized: {self.provider} with model {self.model}") else: logger.error(f"Failed to initialize LLM client for provider: {self.provider}") except Exception as e: logger.error(f"Error initializing LLM client: {e}") self.llm = None def is_available(self) -> bool: """Check if LLM client is available and configured.""" return self.llm is not None def get_provider_info(self) -> Dict[str, Any]: """Get information about the current provider and configuration.""" provider_info = self.SUPPORTED_PROVIDERS.get(self.provider, {}) # For Ollama, get actually available models available_models = provider_info.get('models', []) if self.provider == 'ollama': ollama_models = self._get_ollama_available_models() if ollama_models: available_models = ollama_models return { 'provider': self.provider, 'model': self.model, 'available': self.is_available(), 'supported_models': provider_info.get('models', []), 'available_models': available_models, 'env_key': provider_info.get('env_key', ''), 'api_key_configured': bool(os.getenv(provider_info.get('env_key', ''))) } async def generate_sigma_rule(self, cve_id: str, poc_content: str, cve_description: str, existing_rule: Optional[str] = None) -> Optional[str]: """ Generate or enhance a SIGMA rule using the configured LLM. Args: cve_id: CVE identifier poc_content: Proof-of-concept code content from GitHub cve_description: CVE description from NVD existing_rule: Optional existing SIGMA rule to enhance Returns: Generated SIGMA rule YAML content or None if failed """ if not self.is_available(): logger.warning("LLM client not available") return None try: # Create the prompt template prompt = self._build_sigma_generation_prompt( cve_id, poc_content, cve_description, existing_rule ) # Create the chain chain = prompt | self.llm | self.output_parser # Generate the response response = await chain.ainvoke({ "cve_id": cve_id, "poc_content": poc_content[:4000], # Truncate if too long "cve_description": cve_description, "existing_rule": existing_rule or "None" }) # Extract the SIGMA rule from response sigma_rule = self._extract_sigma_rule(response) logger.info(f"Successfully generated SIGMA rule for {cve_id} using {self.provider}") return sigma_rule except Exception as e: logger.error(f"Failed to generate SIGMA rule for {cve_id} using {self.provider}: {e}") return None def _build_sigma_generation_prompt(self, cve_id: str, poc_content: str, cve_description: str, existing_rule: Optional[str] = None) -> ChatPromptTemplate: """Build the prompt template for SIGMA rule generation.""" system_message = """You are a cybersecurity expert specializing in SIGMA rule creation for threat detection. Your goal is to analyze exploit code from GitHub PoC repositories and create syntactically correct SIGMA rules. **Your Task:** 1. Analyze the exploit code to identify: - Process execution patterns - File system activities - Network connections - Registry modifications - Command line arguments - Suspicious behaviors 2. Create a SIGMA rule that: - Follows proper SIGMA syntax (YAML format) - Includes appropriate detection logic - Has relevant metadata (title, description, author, date, references) - Uses correct field names for the target log source - Includes proper condition logic - Maps to relevant MITRE ATT&CK techniques when applicable 3. Focus on detection patterns that would catch this specific exploit in action **Important Requirements:** - Output ONLY the SIGMA rule in valid YAML format - Do not include explanations or comments outside the YAML - Use proper SIGMA rule structure with title, id, status, description, references, author, date, logsource, detection, and condition - Make the rule specific enough to detect the exploit but not too narrow to miss variants - Include relevant tags and MITRE ATT&CK technique mappings""" if existing_rule: user_template = """**CVE Information:** - CVE ID: {cve_id} - Description: {cve_description} **Proof-of-Concept Code:** ``` {poc_content} ``` **Existing SIGMA Rule (to enhance):** ```yaml {existing_rule} ``` Please enhance the existing rule with insights from the PoC code analysis.""" else: user_template = """**CVE Information:** - CVE ID: {cve_id} - Description: {cve_description} **Proof-of-Concept Code:** ``` {poc_content} ``` Please create a new SIGMA rule based on the PoC code analysis.""" return ChatPromptTemplate.from_messages([ SystemMessage(content=system_message), HumanMessage(content=user_template) ]) def _extract_sigma_rule(self, response_text: str) -> str: """Extract SIGMA rule YAML from LLM response.""" # Look for YAML content in the response lines = response_text.split('\n') yaml_lines = [] in_yaml = False for line in lines: if line.strip().startswith('```yaml') or line.strip().startswith('```'): in_yaml = True continue elif line.strip() == '```' and in_yaml: break elif in_yaml or line.strip().startswith('title:'): yaml_lines.append(line) in_yaml = True if not yaml_lines: # If no YAML block found, return the whole response return response_text.strip() return '\n'.join(yaml_lines).strip() async def enhance_existing_rule(self, existing_rule: str, poc_content: str, cve_id: str) -> Optional[str]: """ Enhance an existing SIGMA rule with PoC analysis. Args: existing_rule: Existing SIGMA rule YAML poc_content: PoC code content cve_id: CVE identifier Returns: Enhanced SIGMA rule or None if failed """ if not self.is_available(): return None try: system_message = """You are a SIGMA rule enhancement expert. Analyze the following PoC code and enhance the existing SIGMA rule with more specific detection patterns. **Task:** Enhance the existing rule by: 1. Adding more specific detection patterns found in the PoC 2. Improving the condition logic 3. Adding relevant tags or MITRE ATT&CK mappings 4. Keeping the rule structure intact but making it more effective Output ONLY the enhanced SIGMA rule in valid YAML format.""" user_template = """**CVE ID:** {cve_id} **PoC Code:** ``` {poc_content} ``` **Existing SIGMA Rule:** ```yaml {existing_rule} ```""" prompt = ChatPromptTemplate.from_messages([ SystemMessage(content=system_message), HumanMessage(content=user_template) ]) chain = prompt | self.llm | self.output_parser response = await chain.ainvoke({ "cve_id": cve_id, "poc_content": poc_content[:3000], "existing_rule": existing_rule }) enhanced_rule = self._extract_sigma_rule(response) logger.info(f"Successfully enhanced SIGMA rule for {cve_id}") return enhanced_rule except Exception as e: logger.error(f"Failed to enhance SIGMA rule for {cve_id}: {e}") return None def validate_sigma_rule(self, rule_content: str) -> bool: """Validate that the generated rule is syntactically correct SIGMA.""" try: # Parse as YAML parsed = yaml.safe_load(rule_content) # Check required fields required_fields = ['title', 'id', 'description', 'logsource', 'detection'] for field in required_fields: if field not in parsed: logger.warning(f"Missing required field: {field}") return False # Check detection structure detection = parsed.get('detection', {}) if not isinstance(detection, dict): logger.warning("Detection field must be a dictionary") return False # Should have at least one selection and a condition if 'condition' not in detection: logger.warning("Detection must have a condition") return False # Check logsource structure logsource = parsed.get('logsource', {}) if not isinstance(logsource, dict): logger.warning("Logsource field must be a dictionary") return False logger.info("SIGMA rule validation passed") return True except yaml.YAMLError as e: logger.warning(f"YAML parsing error: {e}") return False except Exception as e: logger.warning(f"Rule validation error: {e}") return False @classmethod def get_available_providers(cls) -> List[Dict[str, Any]]: """Get list of available LLM providers and their configuration status.""" providers = [] for provider_name, provider_info in cls.SUPPORTED_PROVIDERS.items(): env_key = provider_info.get('env_key', '') api_key_configured = bool(os.getenv(env_key)) providers.append({ 'name': provider_name, 'models': provider_info.get('models', []), 'default_model': provider_info.get('default_model', ''), 'env_key': env_key, 'api_key_configured': api_key_configured, 'available': api_key_configured or provider_name == 'ollama' }) return providers def switch_provider(self, provider: str, model: str = None): """Switch to a different LLM provider and model.""" if provider not in self.SUPPORTED_PROVIDERS: raise ValueError(f"Unsupported provider: {provider}") self.provider = provider self.model = model or self._get_default_model(provider) self._initialize_llm() logger.info(f"Switched to provider: {provider} with model: {self.model}") def _check_ollama_model_available(self, base_url: str, model: str) -> bool: """Check if an Ollama model is available locally""" try: import requests response = requests.get(f"{base_url}/api/tags", timeout=10) if response.status_code == 200: data = response.json() models = data.get('models', []) for m in models: if m.get('name', '').startswith(model + ':') or m.get('name') == model: return True return False except Exception as e: logger.error(f"Error checking Ollama models: {e}") return False def _pull_ollama_model(self, base_url: str, model: str) -> bool: """Pull an Ollama model""" try: import requests import json # Use the pull API endpoint payload = {"name": model} response = requests.post( f"{base_url}/api/pull", json=payload, timeout=300, # 5 minutes timeout for model download stream=True ) if response.status_code == 200: # Stream the response to monitor progress for line in response.iter_lines(): if line: try: data = json.loads(line.decode('utf-8')) if data.get('status'): logger.info(f"Ollama pull progress: {data.get('status')}") if data.get('error'): logger.error(f"Ollama pull error: {data.get('error')}") return False except json.JSONDecodeError: continue return True else: logger.error(f"Failed to pull model {model}: HTTP {response.status_code}") return False except Exception as e: logger.error(f"Error pulling Ollama model {model}: {e}") return False def _get_ollama_available_models(self) -> List[str]: """Get list of available Ollama models""" try: import requests base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434') response = requests.get(f"{base_url}/api/tags", timeout=10) if response.status_code == 200: data = response.json() models = data.get('models', []) return [m.get('name', '') for m in models if m.get('name')] return [] except Exception as e: logger.error(f"Error getting Ollama models: {e}") return []