diff --git a/.env.example b/.env.example index 785ac9c..415794b 100644 --- a/.env.example +++ b/.env.example @@ -7,6 +7,31 @@ NVD_API_KEY=your_nvd_api_key_here # Only needs "public_repo" scope for searching public repositories GITHUB_TOKEN=your_github_token_here +# LLM API Configuration (Optional - for enhanced SIGMA rule generation) +# Choose your preferred LLM provider and configure the corresponding API key + +# OpenAI Configuration +# Get your API key at: https://platform.openai.com/api-keys +OPENAI_API_KEY=your_openai_api_key_here + +# Anthropic Configuration +# Get your API key at: https://console.anthropic.com/ +ANTHROPIC_API_KEY=your_anthropic_api_key_here + +# Ollama Configuration (for local models) +# Install Ollama locally: https://ollama.ai/ +OLLAMA_BASE_URL=http://localhost:11434 + +# LLM Provider Selection (optional - auto-detects if not specified) +# Options: openai, anthropic, ollama +LLM_PROVIDER=openai + +# LLM Model Selection (optional - uses provider default if not specified) +# OpenAI: gpt-4o, gpt-4o-mini, gpt-4-turbo, gpt-3.5-turbo +# Anthropic: claude-3-5-sonnet-20241022, claude-3-haiku-20240307, claude-3-opus-20240229 +# Ollama: llama3.2, codellama, mistral, llama2 +LLM_MODEL=gpt-4o-mini + # Database Configuration (Docker Compose will use defaults) # DATABASE_URL=postgresql://cve_user:cve_password@localhost:5432/cve_sigma_db diff --git a/backend/claude_client.py b/backend/claude_client.py new file mode 100644 index 0000000..565372a --- /dev/null +++ b/backend/claude_client.py @@ -0,0 +1,221 @@ +""" +Claude API client for enhanced SIGMA rule generation. +""" +import os +import logging +from typing import Optional, Dict, Any +from anthropic import Anthropic + +logger = logging.getLogger(__name__) + +class ClaudeClient: + """Client for interacting with Claude API for SIGMA rule generation.""" + + def __init__(self, api_key: Optional[str] = None): + """Initialize Claude client with API key.""" + self.api_key = api_key or os.getenv('CLAUDE_API_KEY') + self.client = None + + if self.api_key: + try: + self.client = Anthropic(api_key=self.api_key) + logger.info("Claude API client initialized successfully") + except Exception as e: + logger.error(f"Failed to initialize Claude API client: {e}") + self.client = None + else: + logger.warning("No Claude API key provided. Claude-enhanced rule generation disabled.") + + def is_available(self) -> bool: + """Check if Claude API client is available.""" + return self.client is not None + + async def generate_sigma_rule(self, + cve_id: str, + poc_content: str, + cve_description: str, + existing_rule: Optional[str] = None) -> Optional[str]: + """ + Generate or enhance a SIGMA rule using Claude API. 
+ + Args: + cve_id: CVE identifier + poc_content: Proof-of-concept code content from GitHub + cve_description: CVE description from NVD + existing_rule: Optional existing SIGMA rule to enhance + + Returns: + Generated SIGMA rule YAML content or None if failed + """ + if not self.is_available(): + logger.warning("Claude API client not available") + return None + + try: + # Construct the prompt for Claude + prompt = self._build_sigma_generation_prompt( + cve_id, poc_content, cve_description, existing_rule + ) + + # Make API call to Claude + response = self.client.messages.create( + model="claude-3-5-sonnet-20241022", + max_tokens=2000, + temperature=0.1, + messages=[ + { + "role": "user", + "content": prompt + } + ] + ) + + # Extract the SIGMA rule from response + sigma_rule = self._extract_sigma_rule(response.content[0].text) + + logger.info(f"Successfully generated SIGMA rule for {cve_id} using Claude") + return sigma_rule + + except Exception as e: + logger.error(f"Failed to generate SIGMA rule for {cve_id} using Claude: {e}") + return None + + def _build_sigma_generation_prompt(self, + cve_id: str, + poc_content: str, + cve_description: str, + existing_rule: Optional[str] = None) -> str: + """Build the prompt for Claude to generate SIGMA rules.""" + + base_prompt = f"""You are a cybersecurity expert specializing in SIGMA rule creation for threat detection. Your goal is to analyze exploit code from GitHub PoC repositories and create syntactically correct SIGMA rules. + +**CVE Information:** +- CVE ID: {cve_id} +- Description: {cve_description} + +**Proof-of-Concept Code:** +``` +{poc_content[:4000]} # Truncate if too long +``` + +**Your Task:** +1. Analyze the exploit code to identify: + - Process execution patterns + - File system activities + - Network connections + - Registry modifications + - Command line arguments + - Suspicious behaviors + +2. Create a SIGMA rule that: + - Follows proper SIGMA syntax (YAML format) + - Includes appropriate detection logic + - Has relevant metadata (title, description, author, date, references) + - Uses correct field names for the target log source + - Includes proper condition logic + - Maps to relevant MITRE ATT&CK techniques when applicable + +3. 
Focus on detection patterns that would catch this specific exploit in action + +**Important Requirements:** +- Output ONLY the SIGMA rule in valid YAML format +- Do not include explanations or comments outside the YAML +- Use proper SIGMA rule structure with title, id, status, description, references, author, date, logsource, detection, and condition +- Make the rule specific enough to detect the exploit but not too narrow to miss variants +- Include relevant tags and MITRE ATT&CK technique mappings""" + + if existing_rule: + base_prompt += f""" + +**Existing SIGMA Rule (to enhance):** +```yaml +{existing_rule} +``` + +Please enhance the existing rule with insights from the PoC code analysis.""" + + return base_prompt + + def _extract_sigma_rule(self, response_text: str) -> str: + """Extract SIGMA rule YAML from Claude's response.""" + # Look for YAML content in the response + lines = response_text.split('\n') + yaml_lines = [] + in_yaml = False + + for line in lines: + if line.strip().startswith('```yaml') or line.strip().startswith('```'): + in_yaml = True + continue + elif line.strip() == '```' and in_yaml: + break + elif in_yaml or line.strip().startswith('title:'): + yaml_lines.append(line) + in_yaml = True + + if not yaml_lines: + # If no YAML block found, return the whole response + return response_text.strip() + + return '\n'.join(yaml_lines).strip() + + async def enhance_existing_rule(self, + existing_rule: str, + poc_content: str, + cve_id: str) -> Optional[str]: + """ + Enhance an existing SIGMA rule with PoC analysis. + + Args: + existing_rule: Existing SIGMA rule YAML + poc_content: PoC code content + cve_id: CVE identifier + + Returns: + Enhanced SIGMA rule or None if failed + """ + if not self.is_available(): + return None + + try: + prompt = f"""You are a SIGMA rule enhancement expert. Analyze the following PoC code and enhance the existing SIGMA rule with more specific detection patterns. + +**CVE ID:** {cve_id} + +**PoC Code:** +``` +{poc_content[:3000]} +``` + +**Existing SIGMA Rule:** +```yaml +{existing_rule} +``` + +**Task:** Enhance the existing rule by: +1. Adding more specific detection patterns found in the PoC +2. Improving the condition logic +3. Adding relevant tags or MITRE ATT&CK mappings +4. 
Keeping the rule structure intact but making it more effective + +Output ONLY the enhanced SIGMA rule in valid YAML format.""" + + response = self.client.messages.create( + model="claude-3-5-sonnet-20241022", + max_tokens=2000, + temperature=0.1, + messages=[ + { + "role": "user", + "content": prompt + } + ] + ) + + enhanced_rule = self._extract_sigma_rule(response.content[0].text) + logger.info(f"Successfully enhanced SIGMA rule for {cve_id}") + return enhanced_rule + + except Exception as e: + logger.error(f"Failed to enhance SIGMA rule for {cve_id}: {e}") + return None \ No newline at end of file diff --git a/backend/enhanced_sigma_generator.py b/backend/enhanced_sigma_generator.py index 8d73873..b25d637 100644 --- a/backend/enhanced_sigma_generator.py +++ b/backend/enhanced_sigma_generator.py @@ -9,6 +9,7 @@ from datetime import datetime from typing import Dict, List, Optional, Tuple from sqlalchemy.orm import Session import re +from llm_client import LLMClient # Configure logging logging.basicConfig(level=logging.INFO) @@ -17,10 +18,11 @@ logger = logging.getLogger(__name__) class EnhancedSigmaGenerator: """Enhanced SIGMA rule generator using nomi-sec PoC data""" - def __init__(self, db_session: Session): + def __init__(self, db_session: Session, llm_provider: str = None, llm_model: str = None): self.db_session = db_session + self.llm_client = LLMClient(provider=llm_provider, model=llm_model) - async def generate_enhanced_rule(self, cve) -> dict: + async def generate_enhanced_rule(self, cve, use_llm: bool = True) -> dict: """Generate enhanced SIGMA rule for a CVE using PoC data""" from main import SigmaRule, RuleTemplate @@ -33,15 +35,29 @@ class EnhancedSigmaGenerator: if poc_data: best_poc = max(poc_data, key=lambda x: x.get('quality_analysis', {}).get('quality_score', 0)) - # Select appropriate template based on PoC analysis - template = await self._select_template(cve, best_poc) + # Try LLM-enhanced generation first if enabled and available + rule_content = None + generation_method = "template" - if not template: - logger.warning(f"No suitable template found for {cve.cve_id}") - return {'success': False, 'error': 'No suitable template'} + if use_llm and self.llm_client.is_available() and best_poc: + logger.info(f"Attempting LLM-enhanced rule generation for {cve.cve_id} using {self.llm_client.provider}") + rule_content = await self._generate_llm_enhanced_rule(cve, best_poc, poc_data) + if rule_content: + generation_method = f"llm_{self.llm_client.provider}" - # Generate rule content - rule_content = await self._generate_rule_content(cve, template, poc_data) + # Fallback to template-based generation + if not rule_content: + logger.info(f"Using template-based rule generation for {cve.cve_id}") + + # Select appropriate template based on PoC analysis + template = await self._select_template(cve, best_poc) + + if not template: + logger.warning(f"No suitable template found for {cve.cve_id}") + return {'success': False, 'error': 'No suitable template'} + + # Generate rule content + rule_content = await self._generate_rule_content(cve, template, poc_data) # Calculate confidence level confidence_level = self._calculate_confidence_level(cve, poc_data) @@ -55,8 +71,8 @@ class EnhancedSigmaGenerator: 'cve_id': cve.cve_id, 'rule_name': f"{cve.cve_id} Enhanced Detection", 'rule_content': rule_content, - 'detection_type': template.template_name, - 'log_source': self._extract_log_source(template.template_name), + 'detection_type': f"{generation_method}_generated", + 'log_source': 
self._extract_log_source_from_content(rule_content), 'confidence_level': confidence_level, 'auto_generated': True, 'exploit_based': len(poc_data) > 0, @@ -67,7 +83,8 @@ class EnhancedSigmaGenerator: 'best_poc_quality': best_poc.get('quality_analysis', {}).get('quality_score', 0) if best_poc else 0, 'total_stars': sum(p.get('stargazers_count', 0) for p in poc_data), 'avg_stars': sum(p.get('stargazers_count', 0) for p in poc_data) / len(poc_data) if poc_data else 0, - 'source': getattr(cve, 'poc_source', 'nomi_sec') + 'source': getattr(cve, 'poc_source', 'nomi_sec'), + 'generation_method': generation_method }, 'github_repos': [p.get('html_url', '') for p in poc_data], 'exploit_indicators': json.dumps(self._combine_exploit_indicators(poc_data)), @@ -100,6 +117,135 @@ class EnhancedSigmaGenerator: logger.error(f"Error generating enhanced rule for {cve.cve_id}: {e}") return {'success': False, 'error': str(e)} + async def _generate_llm_enhanced_rule(self, cve, best_poc: dict, poc_data: list) -> Optional[str]: + """Generate SIGMA rule using LLM API with PoC analysis""" + try: + # Get PoC content from the best quality PoC + poc_content = await self._extract_poc_content(best_poc) + if not poc_content: + logger.warning(f"No PoC content available for {cve.cve_id}") + return None + + # Generate rule using LLM + rule_content = await self.llm_client.generate_sigma_rule( + cve_id=cve.cve_id, + poc_content=poc_content, + cve_description=cve.description or "", + existing_rule=None + ) + + if rule_content: + # Validate the generated rule + if self.llm_client.validate_sigma_rule(rule_content): + logger.info(f"Successfully generated LLM-enhanced rule for {cve.cve_id}") + return rule_content + else: + logger.warning(f"Generated rule for {cve.cve_id} failed validation") + return None + + return None + + except Exception as e: + logger.error(f"Error generating LLM-enhanced rule for {cve.cve_id}: {e}") + return None + + async def _extract_poc_content(self, poc: dict) -> Optional[str]: + """Extract actual code content from PoC repository""" + try: + import aiohttp + import asyncio + + # Get repository information + repo_url = poc.get('html_url', '') + if not repo_url: + return None + + # Convert GitHub URL to API URL for repository content + if 'github.com' in repo_url: + # Extract owner and repo from URL + parts = repo_url.rstrip('/').split('/') + if len(parts) >= 2: + owner = parts[-2] + repo = parts[-1] + + # Get repository files via GitHub API + api_url = f"https://api.github.com/repos/{owner}/{repo}/contents" + + async with aiohttp.ClientSession() as session: + # Add timeout to prevent hanging + timeout = aiohttp.ClientTimeout(total=30) + async with session.get(api_url, timeout=timeout) as response: + if response.status == 200: + contents = await response.json() + + # Look for common exploit files + target_files = [ + 'exploit.py', 'poc.py', 'exploit.c', 'exploit.cpp', + 'exploit.java', 'exploit.rb', 'exploit.php', + 'exploit.js', 'exploit.sh', 'exploit.ps1', + 'README.md', 'main.py', 'index.js' + ] + + for file_info in contents: + if file_info.get('type') == 'file': + filename = file_info.get('name', '').lower() + + # Check if this is a target file + if any(target in filename for target in target_files): + file_url = file_info.get('download_url') + if file_url: + async with session.get(file_url, timeout=timeout) as file_response: + if file_response.status == 200: + content = await file_response.text() + # Limit content size + if len(content) > 10000: + content = content[:10000] + "\n... 
[content truncated]" + return content + + # If no specific exploit file found, return description/README + for file_info in contents: + if file_info.get('type') == 'file': + filename = file_info.get('name', '').lower() + if 'readme' in filename: + file_url = file_info.get('download_url') + if file_url: + async with session.get(file_url, timeout=timeout) as file_response: + if file_response.status == 200: + content = await file_response.text() + return content[:5000] # Smaller limit for README + + # Fallback to description and metadata + description = poc.get('description', '') + if description: + return f"Repository Description: {description}" + + return None + + except Exception as e: + logger.error(f"Error extracting PoC content: {e}") + return None + + + def _extract_log_source_from_content(self, rule_content: str) -> str: + """Extract log source from the generated rule content""" + try: + import yaml + parsed = yaml.safe_load(rule_content) + logsource = parsed.get('logsource', {}) + + category = logsource.get('category', '') + product = logsource.get('product', '') + + if category: + return category + elif product: + return product + else: + return 'generic' + + except Exception: + return 'generic' + async def _select_template(self, cve, best_poc: Optional[dict]) -> Optional[object]: """Select the most appropriate template based on CVE and PoC analysis""" from main import RuleTemplate diff --git a/backend/llm_client.py b/backend/llm_client.py new file mode 100644 index 0000000..c892e26 --- /dev/null +++ b/backend/llm_client.py @@ -0,0 +1,398 @@ +""" +LangChain-based LLM client for enhanced SIGMA rule generation. +Supports multiple LLM providers: OpenAI, Anthropic, and local models. +""" +import os +import logging +from typing import Optional, Dict, Any, List +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_core.prompts import ChatPromptTemplate +from langchain_openai import ChatOpenAI +from langchain_anthropic import ChatAnthropic +from langchain_community.llms import Ollama +from langchain_core.output_parsers import StrOutputParser +import yaml + +logger = logging.getLogger(__name__) + +class LLMClient: + """Multi-provider LLM client for SIGMA rule generation using LangChain.""" + + SUPPORTED_PROVIDERS = { + 'openai': { + 'models': ['gpt-4o', 'gpt-4o-mini', 'gpt-4-turbo', 'gpt-3.5-turbo'], + 'env_key': 'OPENAI_API_KEY', + 'default_model': 'gpt-4o-mini' + }, + 'anthropic': { + 'models': ['claude-3-5-sonnet-20241022', 'claude-3-haiku-20240307', 'claude-3-opus-20240229'], + 'env_key': 'ANTHROPIC_API_KEY', + 'default_model': 'claude-3-5-sonnet-20241022' + }, + 'ollama': { + 'models': ['llama3.2', 'codellama', 'mistral', 'llama2'], + 'env_key': 'OLLAMA_BASE_URL', + 'default_model': 'llama3.2' + } + } + + def __init__(self, provider: str = None, model: str = None): + """Initialize LLM client with specified provider and model.""" + self.provider = provider or self._detect_provider() + self.model = model or self._get_default_model(self.provider) + self.llm = None + self.output_parser = StrOutputParser() + + self._initialize_llm() + + def _detect_provider(self) -> str: + """Auto-detect available LLM provider based on environment variables.""" + # Check for API keys in order of preference + if os.getenv('ANTHROPIC_API_KEY'): + return 'anthropic' + elif os.getenv('OPENAI_API_KEY'): + return 'openai' + elif os.getenv('OLLAMA_BASE_URL'): + return 'ollama' + else: + # Default to OpenAI if no keys found + return 'openai' + + def _get_default_model(self, provider: str) 
-> str: + """Get default model for the specified provider.""" + return self.SUPPORTED_PROVIDERS.get(provider, {}).get('default_model', 'gpt-4o-mini') + + def _initialize_llm(self): + """Initialize the LLM based on provider and model.""" + try: + if self.provider == 'openai': + api_key = os.getenv('OPENAI_API_KEY') + if not api_key: + logger.warning("OpenAI API key not found") + return + + self.llm = ChatOpenAI( + model=self.model, + api_key=api_key, + temperature=0.1, + max_tokens=2000 + ) + + elif self.provider == 'anthropic': + api_key = os.getenv('ANTHROPIC_API_KEY') + if not api_key: + logger.warning("Anthropic API key not found") + return + + self.llm = ChatAnthropic( + model=self.model, + api_key=api_key, + temperature=0.1, + max_tokens=2000 + ) + + elif self.provider == 'ollama': + base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434') + + self.llm = Ollama( + model=self.model, + base_url=base_url, + temperature=0.1 + ) + + if self.llm: + logger.info(f"LLM client initialized: {self.provider} with model {self.model}") + else: + logger.error(f"Failed to initialize LLM client for provider: {self.provider}") + + except Exception as e: + logger.error(f"Error initializing LLM client: {e}") + self.llm = None + + def is_available(self) -> bool: + """Check if LLM client is available and configured.""" + return self.llm is not None + + def get_provider_info(self) -> Dict[str, Any]: + """Get information about the current provider and configuration.""" + provider_info = self.SUPPORTED_PROVIDERS.get(self.provider, {}) + return { + 'provider': self.provider, + 'model': self.model, + 'available': self.is_available(), + 'supported_models': provider_info.get('models', []), + 'env_key': provider_info.get('env_key', ''), + 'api_key_configured': bool(os.getenv(provider_info.get('env_key', ''))) + } + + async def generate_sigma_rule(self, + cve_id: str, + poc_content: str, + cve_description: str, + existing_rule: Optional[str] = None) -> Optional[str]: + """ + Generate or enhance a SIGMA rule using the configured LLM. + + Args: + cve_id: CVE identifier + poc_content: Proof-of-concept code content from GitHub + cve_description: CVE description from NVD + existing_rule: Optional existing SIGMA rule to enhance + + Returns: + Generated SIGMA rule YAML content or None if failed + """ + if not self.is_available(): + logger.warning("LLM client not available") + return None + + try: + # Create the prompt template + prompt = self._build_sigma_generation_prompt( + cve_id, poc_content, cve_description, existing_rule + ) + + # Create the chain + chain = prompt | self.llm | self.output_parser + + # Generate the response + response = await chain.ainvoke({ + "cve_id": cve_id, + "poc_content": poc_content[:4000], # Truncate if too long + "cve_description": cve_description, + "existing_rule": existing_rule or "None" + }) + + # Extract the SIGMA rule from response + sigma_rule = self._extract_sigma_rule(response) + + logger.info(f"Successfully generated SIGMA rule for {cve_id} using {self.provider}") + return sigma_rule + + except Exception as e: + logger.error(f"Failed to generate SIGMA rule for {cve_id} using {self.provider}: {e}") + return None + + def _build_sigma_generation_prompt(self, + cve_id: str, + poc_content: str, + cve_description: str, + existing_rule: Optional[str] = None) -> ChatPromptTemplate: + """Build the prompt template for SIGMA rule generation.""" + + system_message = """You are a cybersecurity expert specializing in SIGMA rule creation for threat detection. 
Your goal is to analyze exploit code from GitHub PoC repositories and create syntactically correct SIGMA rules. + +**Your Task:** +1. Analyze the exploit code to identify: + - Process execution patterns + - File system activities + - Network connections + - Registry modifications + - Command line arguments + - Suspicious behaviors + +2. Create a SIGMA rule that: + - Follows proper SIGMA syntax (YAML format) + - Includes appropriate detection logic + - Has relevant metadata (title, description, author, date, references) + - Uses correct field names for the target log source + - Includes proper condition logic + - Maps to relevant MITRE ATT&CK techniques when applicable + +3. Focus on detection patterns that would catch this specific exploit in action + +**Important Requirements:** +- Output ONLY the SIGMA rule in valid YAML format +- Do not include explanations or comments outside the YAML +- Use proper SIGMA rule structure with title, id, status, description, references, author, date, logsource, detection, and condition +- Make the rule specific enough to detect the exploit but not too narrow to miss variants +- Include relevant tags and MITRE ATT&CK technique mappings""" + + if existing_rule: + user_template = """**CVE Information:** +- CVE ID: {cve_id} +- Description: {cve_description} + +**Proof-of-Concept Code:** +``` +{poc_content} +``` + +**Existing SIGMA Rule (to enhance):** +```yaml +{existing_rule} +``` + +Please enhance the existing rule with insights from the PoC code analysis.""" + else: + user_template = """**CVE Information:** +- CVE ID: {cve_id} +- Description: {cve_description} + +**Proof-of-Concept Code:** +``` +{poc_content} +``` + +Please create a new SIGMA rule based on the PoC code analysis.""" + + return ChatPromptTemplate.from_messages([ + SystemMessage(content=system_message), + HumanMessage(content=user_template) + ]) + + def _extract_sigma_rule(self, response_text: str) -> str: + """Extract SIGMA rule YAML from LLM response.""" + # Look for YAML content in the response + lines = response_text.split('\n') + yaml_lines = [] + in_yaml = False + + for line in lines: + if line.strip().startswith('```yaml') or line.strip().startswith('```'): + in_yaml = True + continue + elif line.strip() == '```' and in_yaml: + break + elif in_yaml or line.strip().startswith('title:'): + yaml_lines.append(line) + in_yaml = True + + if not yaml_lines: + # If no YAML block found, return the whole response + return response_text.strip() + + return '\n'.join(yaml_lines).strip() + + async def enhance_existing_rule(self, + existing_rule: str, + poc_content: str, + cve_id: str) -> Optional[str]: + """ + Enhance an existing SIGMA rule with PoC analysis. + + Args: + existing_rule: Existing SIGMA rule YAML + poc_content: PoC code content + cve_id: CVE identifier + + Returns: + Enhanced SIGMA rule or None if failed + """ + if not self.is_available(): + return None + + try: + system_message = """You are a SIGMA rule enhancement expert. Analyze the following PoC code and enhance the existing SIGMA rule with more specific detection patterns. + +**Task:** Enhance the existing rule by: +1. Adding more specific detection patterns found in the PoC +2. Improving the condition logic +3. Adding relevant tags or MITRE ATT&CK mappings +4. 
Keeping the rule structure intact but making it more effective + +Output ONLY the enhanced SIGMA rule in valid YAML format.""" + + user_template = """**CVE ID:** {cve_id} + +**PoC Code:** +``` +{poc_content} +``` + +**Existing SIGMA Rule:** +```yaml +{existing_rule} +```""" + + prompt = ChatPromptTemplate.from_messages([ + SystemMessage(content=system_message), + HumanMessage(content=user_template) + ]) + + chain = prompt | self.llm | self.output_parser + + response = await chain.ainvoke({ + "cve_id": cve_id, + "poc_content": poc_content[:3000], + "existing_rule": existing_rule + }) + + enhanced_rule = self._extract_sigma_rule(response) + logger.info(f"Successfully enhanced SIGMA rule for {cve_id}") + return enhanced_rule + + except Exception as e: + logger.error(f"Failed to enhance SIGMA rule for {cve_id}: {e}") + return None + + def validate_sigma_rule(self, rule_content: str) -> bool: + """Validate that the generated rule is syntactically correct SIGMA.""" + try: + # Parse as YAML + parsed = yaml.safe_load(rule_content) + + # Check required fields + required_fields = ['title', 'id', 'description', 'logsource', 'detection'] + for field in required_fields: + if field not in parsed: + logger.warning(f"Missing required field: {field}") + return False + + # Check detection structure + detection = parsed.get('detection', {}) + if not isinstance(detection, dict): + logger.warning("Detection field must be a dictionary") + return False + + # Should have at least one selection and a condition + if 'condition' not in detection: + logger.warning("Detection must have a condition") + return False + + # Check logsource structure + logsource = parsed.get('logsource', {}) + if not isinstance(logsource, dict): + logger.warning("Logsource field must be a dictionary") + return False + + logger.info("SIGMA rule validation passed") + return True + + except yaml.YAMLError as e: + logger.warning(f"YAML parsing error: {e}") + return False + except Exception as e: + logger.warning(f"Rule validation error: {e}") + return False + + @classmethod + def get_available_providers(cls) -> List[Dict[str, Any]]: + """Get list of available LLM providers and their configuration status.""" + providers = [] + + for provider_name, provider_info in cls.SUPPORTED_PROVIDERS.items(): + env_key = provider_info.get('env_key', '') + api_key_configured = bool(os.getenv(env_key)) + + providers.append({ + 'name': provider_name, + 'models': provider_info.get('models', []), + 'default_model': provider_info.get('default_model', ''), + 'env_key': env_key, + 'api_key_configured': api_key_configured, + 'available': api_key_configured or provider_name == 'ollama' + }) + + return providers + + def switch_provider(self, provider: str, model: str = None): + """Switch to a different LLM provider and model.""" + if provider not in self.SUPPORTED_PROVIDERS: + raise ValueError(f"Unsupported provider: {provider}") + + self.provider = provider + self.model = model or self._get_default_model(provider) + self._initialize_llm() + + logger.info(f"Switched to provider: {provider} with model: {self.model}") \ No newline at end of file diff --git a/backend/main.py b/backend/main.py index f13e87f..e5a29d1 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1414,6 +1414,171 @@ async def regenerate_sigma_rules(background_tasks: BackgroundTasks, "force": request.force } +@app.post("/api/llm-enhanced-rules") +async def generate_llm_enhanced_rules(request: dict, background_tasks: BackgroundTasks, db: Session = Depends(get_db)): + """Generate SIGMA rules using LLM 
API for enhanced analysis""" + + # Parse request parameters + cve_id = request.get('cve_id') + force = request.get('force', False) + llm_provider = request.get('provider', os.getenv('LLM_PROVIDER')) + llm_model = request.get('model', os.getenv('LLM_MODEL')) + + # Validation + if cve_id and not re.match(r'^CVE-\d{4}-\d{4,}$', cve_id): + raise HTTPException(status_code=400, detail="Invalid CVE ID format") + + async def llm_generation_task(): + """Background task for LLM-enhanced rule generation""" + try: + from enhanced_sigma_generator import EnhancedSigmaGenerator + + generator = EnhancedSigmaGenerator(db, llm_provider, llm_model) + + # Process specific CVE or all CVEs with PoC data + if cve_id: + cve = db.query(CVE).filter(CVE.cve_id == cve_id).first() + if not cve: + logger.error(f"CVE {cve_id} not found") + return + + cves_to_process = [cve] + else: + # Process CVEs with PoC data that either have no rules or force update + query = db.query(CVE).filter(CVE.poc_count > 0) + + if not force: + # Only process CVEs without existing LLM-generated rules + existing_llm_rules = db.query(SigmaRule).filter( + SigmaRule.detection_type.like('llm_%') + ).all() + existing_cve_ids = {rule.cve_id for rule in existing_llm_rules} + cves_to_process = [cve for cve in query.all() if cve.cve_id not in existing_cve_ids] + else: + cves_to_process = query.all() + + logger.info(f"Processing {len(cves_to_process)} CVEs for LLM-enhanced rule generation using {llm_provider}") + + rules_generated = 0 + rules_updated = 0 + failures = 0 + + for cve in cves_to_process: + try: + # Check if CVE has sufficient PoC data + if not cve.poc_data or not cve.poc_count: + logger.debug(f"Skipping {cve.cve_id} - no PoC data") + continue + + # Generate LLM-enhanced rule + result = await generator.generate_enhanced_rule(cve, use_llm=True) + + if result.get('success'): + if result.get('updated'): + rules_updated += 1 + else: + rules_generated += 1 + + logger.info(f"Successfully generated LLM-enhanced rule for {cve.cve_id}") + else: + failures += 1 + logger.warning(f"Failed to generate LLM-enhanced rule for {cve.cve_id}: {result.get('error')}") + + except Exception as e: + failures += 1 + logger.error(f"Error generating LLM-enhanced rule for {cve.cve_id}: {e}") + continue + + logger.info(f"LLM-enhanced rule generation completed: {rules_generated} new, {rules_updated} updated, {failures} failures") + + except Exception as e: + logger.error(f"LLM-enhanced rule generation failed: {e}") + import traceback + traceback.print_exc() + + background_tasks.add_task(llm_generation_task) + + return { + "message": "LLM-enhanced SIGMA rule generation started", + "status": "started", + "cve_id": cve_id, + "force": force, + "provider": llm_provider, + "model": llm_model, + "note": "Requires appropriate LLM API key to be set" + } + +@app.get("/api/llm-status") +async def get_llm_status(): + """Check LLM API availability status""" + try: + from llm_client import LLMClient + + # Get current provider configuration + provider = os.getenv('LLM_PROVIDER') + model = os.getenv('LLM_MODEL') + + client = LLMClient(provider=provider, model=model) + provider_info = client.get_provider_info() + + # Get all available providers + all_providers = LLMClient.get_available_providers() + + return { + "current_provider": provider_info, + "available_providers": all_providers, + "status": "ready" if client.is_available() else "unavailable" + } + except Exception as e: + logger.error(f"Error checking LLM status: {e}") + return { + "current_provider": {"provider": "unknown", 
"available": False}, + "available_providers": [], + "status": "error", + "error": str(e) + } + +@app.post("/api/llm-switch") +async def switch_llm_provider(request: dict): + """Switch LLM provider and model""" + try: + from llm_client import LLMClient + + provider = request.get('provider') + model = request.get('model') + + if not provider: + raise HTTPException(status_code=400, detail="Provider is required") + + # Validate provider + if provider not in LLMClient.SUPPORTED_PROVIDERS: + raise HTTPException(status_code=400, detail=f"Unsupported provider: {provider}") + + # Test the new configuration + client = LLMClient(provider=provider, model=model) + + if not client.is_available(): + raise HTTPException(status_code=400, detail=f"Provider {provider} is not available or not configured") + + # Update environment variables (note: this only affects the current session) + os.environ['LLM_PROVIDER'] = provider + if model: + os.environ['LLM_MODEL'] = model + + provider_info = client.get_provider_info() + + return { + "message": f"Switched to {provider}", + "provider_info": provider_info, + "status": "success" + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error switching LLM provider: {e}") + raise HTTPException(status_code=500, detail=str(e)) + @app.post("/api/cancel-job/{job_id}") async def cancel_job(job_id: str, db: Session = Depends(get_db)): """Cancel a running job""" diff --git a/backend/requirements.txt b/backend/requirements.txt index ead22b2..6a03a5d 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -15,3 +15,10 @@ lxml==4.9.3 aiohttp==3.9.1 aiofiles pyyaml==6.0.1 +langchain==0.2.0 +langchain-openai==0.1.17 +langchain-anthropic==0.1.15 +langchain-community==0.2.0 +langchain-core>=0.2.20 +openai>=1.32.0 +anthropic==0.40.0 diff --git a/frontend/src/App.js b/frontend/src/App.js index d8e66b9..9a815d4 100644 --- a/frontend/src/App.js +++ b/frontend/src/App.js @@ -21,6 +21,7 @@ function App() { const [gitHubPocStats, setGitHubPocStats] = useState({}); const [bulkProcessing, setBulkProcessing] = useState(false); const [hasRunningJobs, setHasRunningJobs] = useState(false); + const [llmStatus, setLlmStatus] = useState({}); useEffect(() => { fetchData(); @@ -29,14 +30,15 @@ function App() { const fetchData = async () => { try { setLoading(true); - const [cvesRes, rulesRes, statsRes, bulkJobsRes, bulkStatusRes, pocStatsRes, githubPocStatsRes] = await Promise.all([ + const [cvesRes, rulesRes, statsRes, bulkJobsRes, bulkStatusRes, pocStatsRes, githubPocStatsRes, llmStatusRes] = await Promise.all([ axios.get(`${API_BASE_URL}/api/cves`), axios.get(`${API_BASE_URL}/api/sigma-rules`), axios.get(`${API_BASE_URL}/api/stats`), axios.get(`${API_BASE_URL}/api/bulk-jobs`), axios.get(`${API_BASE_URL}/api/bulk-status`), axios.get(`${API_BASE_URL}/api/poc-stats`), - axios.get(`${API_BASE_URL}/api/github-poc-stats`).catch(err => ({ data: {} })) + axios.get(`${API_BASE_URL}/api/github-poc-stats`).catch(err => ({ data: {} })), + axios.get(`${API_BASE_URL}/api/llm-status`).catch(err => ({ data: {} })) ]); setCves(cvesRes.data); @@ -46,6 +48,7 @@ function App() { setBulkStatus(bulkStatusRes.data); setPocStats(pocStatsRes.data); setGitHubPocStats(githubPocStatsRes.data); + setLlmStatus(llmStatusRes.data); // Update running jobs state const runningJobs = bulkJobsRes.data.filter(job => job.status === 'running' || job.status === 'pending'); @@ -166,6 +169,32 @@ function App() { } }; + const generateLlmRules = async (force = false) => { + try { + const response = 
await axios.post(`${API_BASE_URL}/api/llm-enhanced-rules`, { + force: force + }); + console.log('LLM rule generation response:', response.data); + fetchData(); + } catch (error) { + console.error('Error generating LLM-enhanced rules:', error); + } + }; + + const switchLlmProvider = async (provider, model) => { + try { + const response = await axios.post(`${API_BASE_URL}/api/llm-switch`, { + provider: provider, + model: model + }); + console.log('LLM provider switch response:', response.data); + fetchData(); // Refresh to get updated status + } catch (error) { + console.error('Error switching LLM provider:', error); + alert('Failed to switch LLM provider. Please check configuration.'); + } + }; + const getSeverityColor = (severity) => { switch (severity?.toLowerCase()) { case 'critical': return 'bg-red-100 text-red-800'; @@ -197,6 +226,9 @@ function App() {
             {stats.total_sigma_rules || 0}
             Nomi-sec: {stats.nomi_sec_rules || 0}
             GitHub PoCs: {gitHubPocStats.github_poc_rules || 0}
+            LLM: {llmStatus.current_provider?.provider || 'Not Available'}
+            Provider: {llmStatus.current_provider?.provider || 'Not configured'}
+            Model: {llmStatus.current_provider?.model || 'Not configured'}
+            Status: {llmStatus.status || 'Unknown'}
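
For reference, a minimal sketch of how the new `LLMClient` in `backend/llm_client.py` is expected to be driven end to end. Everything below comes from the patch itself except the CVE id, the PoC snippet, and the assumption that one of `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, or `OLLAMA_BASE_URL` is exported as documented in `.env.example`; those values are illustrative placeholders, not project fixtures.

```python
# Sketch only: standalone driver for the new multi-provider LLM client.
import asyncio
from llm_client import LLMClient

async def main():
    # Provider is auto-detected from ANTHROPIC_API_KEY / OPENAI_API_KEY / OLLAMA_BASE_URL.
    client = LLMClient()
    print(client.get_provider_info())  # provider, model, supported models, key status

    if not client.is_available():
        return

    rule = await client.generate_sigma_rule(
        cve_id="CVE-2024-0001",  # placeholder identifier for illustration
        poc_content="IEX (New-Object Net.WebClient).DownloadString(...)",  # truncated PoC excerpt
        cve_description="Example description pulled from NVD",
    )
    if rule and client.validate_sigma_rule(rule):
        print(rule)

asyncio.run(main())
```

This is the same flow that `EnhancedSigmaGenerator._generate_llm_enhanced_rule` runs internally (extract PoC content, generate, validate) before falling back to template-based generation.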