""" Enhanced SIGMA Rule Generator Generates improved SIGMA rules using nomi-sec PoC data and traditional indicators """ import json import logging from datetime import datetime from typing import Dict, List, Optional, Tuple from sqlalchemy.orm import Session import re from llm_client import LLMClient # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class EnhancedSigmaGenerator: """Enhanced SIGMA rule generator using nomi-sec PoC data""" def __init__(self, db_session: Session, llm_provider: str = None, llm_model: str = None): self.db_session = db_session self.llm_client = LLMClient(provider=llm_provider, model=llm_model) async def generate_enhanced_rule(self, cve, use_llm: bool = True) -> dict: """Generate enhanced SIGMA rule for a CVE using PoC data""" from main import SigmaRule, RuleTemplate try: # Get PoC data poc_data = cve.poc_data or [] # Find the best quality PoC best_poc = None if poc_data: best_poc = max(poc_data, key=lambda x: x.get('quality_analysis', {}).get('quality_score', 0)) # Try LLM-enhanced generation first if enabled and available rule_content = None generation_method = "template" if use_llm and self.llm_client.is_available() and best_poc: logger.info(f"Attempting LLM-enhanced rule generation for {cve.cve_id} using {self.llm_client.provider}") rule_content = await self._generate_llm_enhanced_rule(cve, best_poc, poc_data) if rule_content: generation_method = f"llm_{self.llm_client.provider}" # Fallback to template-based generation if not rule_content: logger.info(f"Using template-based rule generation for {cve.cve_id}") # Select appropriate template based on PoC analysis template = await self._select_template(cve, best_poc) if not template: logger.warning(f"No suitable template found for {cve.cve_id}") return {'success': False, 'error': 'No suitable template'} # Generate rule content rule_content = await self._generate_rule_content(cve, template, poc_data) # Calculate confidence level confidence_level = self._calculate_confidence_level(cve, poc_data) # Store or update SIGMA rule existing_rule = self.db_session.query(SigmaRule).filter( SigmaRule.cve_id == cve.cve_id ).first() rule_data = { 'cve_id': cve.cve_id, 'rule_name': f"{cve.cve_id} Enhanced Detection", 'rule_content': rule_content, 'detection_type': f"{generation_method}_generated", 'log_source': self._extract_log_source_from_content(rule_content), 'confidence_level': confidence_level, 'auto_generated': True, 'exploit_based': len(poc_data) > 0, 'poc_source': getattr(cve, 'poc_source', 'nomi_sec'), 'poc_quality_score': best_poc.get('quality_analysis', {}).get('quality_score', 0) if best_poc else 0, 'nomi_sec_data': { 'total_pocs': len(poc_data), 'best_poc_quality': best_poc.get('quality_analysis', {}).get('quality_score', 0) if best_poc else 0, 'total_stars': sum(p.get('stargazers_count', 0) for p in poc_data), 'avg_stars': sum(p.get('stargazers_count', 0) for p in poc_data) / len(poc_data) if poc_data else 0, 'source': getattr(cve, 'poc_source', 'nomi_sec'), 'generation_method': generation_method }, 'github_repos': [p.get('html_url', '') for p in poc_data], 'exploit_indicators': json.dumps(self._combine_exploit_indicators(poc_data)), 'updated_at': datetime.utcnow() } if existing_rule: # Update existing rule for key, value in rule_data.items(): setattr(existing_rule, key, value) logger.info(f"Updated SIGMA rule for {cve.cve_id}") else: # Create new rule new_rule = SigmaRule(**rule_data) self.db_session.add(new_rule) logger.info(f"Created new SIGMA rule for {cve.cve_id}") self.db_session.commit() return { 'success': True, 'cve_id': cve.cve_id, 'template': template.template_name, 'confidence_level': confidence_level, 'poc_count': len(poc_data), 'quality_score': best_poc.get('quality_analysis', {}).get('quality_score', 0) if best_poc else 0 } except Exception as e: logger.error(f"Error generating enhanced rule for {cve.cve_id}: {e}") return {'success': False, 'error': str(e)} async def _generate_llm_enhanced_rule(self, cve, best_poc: dict, poc_data: list) -> Optional[str]: """Generate SIGMA rule using LLM API with PoC analysis""" try: # Get PoC content from the best quality PoC poc_content = await self._extract_poc_content(best_poc) if not poc_content: logger.warning(f"No PoC content available for {cve.cve_id}") return None # Generate rule using LLM rule_content = await self.llm_client.generate_sigma_rule( cve_id=cve.cve_id, poc_content=poc_content, cve_description=cve.description or "", existing_rule=None ) if rule_content: # Validate the generated rule if self.llm_client.validate_sigma_rule(rule_content): logger.info(f"Successfully generated LLM-enhanced rule for {cve.cve_id}") return rule_content else: logger.warning(f"Generated rule for {cve.cve_id} failed validation") return None return None except Exception as e: logger.error(f"Error generating LLM-enhanced rule for {cve.cve_id}: {e}") return None async def _extract_poc_content(self, poc: dict) -> Optional[str]: """Extract actual code content from PoC repository""" try: import aiohttp import asyncio # Get repository information repo_url = poc.get('html_url', '') if not repo_url: return None # Convert GitHub URL to API URL for repository content if 'github.com' in repo_url: # Extract owner and repo from URL parts = repo_url.rstrip('/').split('/') if len(parts) >= 2: owner = parts[-2] repo = parts[-1] # Get repository files via GitHub API api_url = f"https://api.github.com/repos/{owner}/{repo}/contents" async with aiohttp.ClientSession() as session: # Add timeout to prevent hanging timeout = aiohttp.ClientTimeout(total=30) async with session.get(api_url, timeout=timeout) as response: if response.status == 200: contents = await response.json() # Look for common exploit files target_files = [ 'exploit.py', 'poc.py', 'exploit.c', 'exploit.cpp', 'exploit.java', 'exploit.rb', 'exploit.php', 'exploit.js', 'exploit.sh', 'exploit.ps1', 'README.md', 'main.py', 'index.js' ] for file_info in contents: if file_info.get('type') == 'file': filename = file_info.get('name', '').lower() # Check if this is a target file if any(target in filename for target in target_files): file_url = file_info.get('download_url') if file_url: async with session.get(file_url, timeout=timeout) as file_response: if file_response.status == 200: content = await file_response.text() # Limit content size if len(content) > 10000: content = content[:10000] + "\n... [content truncated]" return content # If no specific exploit file found, return description/README for file_info in contents: if file_info.get('type') == 'file': filename = file_info.get('name', '').lower() if 'readme' in filename: file_url = file_info.get('download_url') if file_url: async with session.get(file_url, timeout=timeout) as file_response: if file_response.status == 200: content = await file_response.text() return content[:5000] # Smaller limit for README # Fallback to description and metadata description = poc.get('description', '') if description: return f"Repository Description: {description}" return None except Exception as e: logger.error(f"Error extracting PoC content: {e}") return None def _extract_log_source_from_content(self, rule_content: str) -> str: """Extract log source from the generated rule content""" try: import yaml parsed = yaml.safe_load(rule_content) logsource = parsed.get('logsource', {}) category = logsource.get('category', '') product = logsource.get('product', '') if category: return category elif product: return product else: return 'generic' except Exception: return 'generic' async def _select_template(self, cve, best_poc: Optional[dict]) -> Optional[object]: """Select the most appropriate template based on CVE and PoC analysis""" from main import RuleTemplate templates = self.db_session.query(RuleTemplate).all() if not templates: logger.warning("No rule templates found in database - creating default template") # Create a default template if none exist return self._create_default_template(cve, best_poc) # Score templates based on relevance template_scores = {} for template in templates: score = 0 # Score based on PoC indicators (highest priority) if best_poc: indicators = best_poc.get('exploit_indicators', {}) score += self._score_template_poc_match(template, indicators) # Score based on CVE description score += self._score_template_cve_match(template, cve) # Score based on affected products if cve.affected_products: score += self._score_template_product_match(template, cve.affected_products) template_scores[template] = score # Return template with highest score if template_scores: best_template = max(template_scores, key=template_scores.get) logger.info(f"Selected template {best_template.template_name} with score {template_scores[best_template]}") return best_template return self._create_default_template(cve, best_poc) def _score_template_poc_match(self, template: object, indicators: dict) -> int: """Score template based on PoC indicators""" score = 0 template_name = template.template_name.lower() # Process-based templates if 'process' in template_name or 'execution' in template_name: if indicators.get('processes') or indicators.get('commands'): score += 30 # Network-based templates if 'network' in template_name or 'connection' in template_name: if indicators.get('network') or indicators.get('urls'): score += 30 # File-based templates if 'file' in template_name or 'modification' in template_name: if indicators.get('files'): score += 30 # PowerShell templates if 'powershell' in template_name: processes = indicators.get('processes', []) if any('powershell' in p.lower() for p in processes): score += 35 return score def _score_template_cve_match(self, template: object, cve) -> int: """Score template based on CVE description""" score = 0 template_name = template.template_name.lower() description = (cve.description or '').lower() # Keyword matching if 'remote' in description and 'execution' in description: if 'process' in template_name or 'execution' in template_name: score += 20 if 'powershell' in description: if 'powershell' in template_name: score += 25 if 'network' in description or 'http' in description: if 'network' in template_name: score += 20 if 'file' in description or 'upload' in description: if 'file' in template_name: score += 20 return score def _score_template_product_match(self, template: object, affected_products: list) -> int: """Score template based on affected products""" score = 0 if not template.applicable_product_patterns: return 0 for pattern in template.applicable_product_patterns: pattern_lower = pattern.lower() for product in affected_products: product_lower = product.lower() if pattern_lower in product_lower: score += 10 break return score async def _generate_rule_content(self, cve, template: object, poc_data: list) -> str: """Generate the actual SIGMA rule content""" # Combine all exploit indicators combined_indicators = self._combine_exploit_indicators(poc_data) # Get base template content rule_content = template.template_content # Generate a unique rule ID import uuid rule_id = str(uuid.uuid4()) # Replace template placeholders replacements = { '{title}': f"{cve.cve_id} Enhanced Detection", '{description}': self._generate_description(cve, poc_data), '{rule_id}': rule_id, '{date}': datetime.now().strftime('%Y/%m/%d'), '{level}': self._calculate_confidence_level(cve, poc_data).lower(), '{cve_url}': f"https://nvd.nist.gov/vuln/detail/{cve.cve_id}", '{tags}': self._generate_tags(cve, poc_data), '{suspicious_processes}': self._format_indicators(combined_indicators.get('processes', [])), '{suspicious_files}': self._format_indicators(combined_indicators.get('files', [])), '{suspicious_commands}': self._format_indicators(combined_indicators.get('commands', [])), '{suspicious_network}': self._format_indicators(combined_indicators.get('network', [])), '{suspicious_urls}': self._format_indicators(combined_indicators.get('urls', [])), '{suspicious_registry}': self._format_indicators(combined_indicators.get('registry', [])), '{suspicious_ports}': self._format_indicators(combined_indicators.get('ports', [])) } # Apply replacements for placeholder, value in replacements.items(): rule_content = rule_content.replace(placeholder, value) # Clean up empty sections rule_content = self._clean_empty_sections(rule_content) # Add enhanced detection based on PoC quality if poc_data: rule_content = self._enhance_detection_logic(rule_content, combined_indicators, poc_data) return rule_content def _combine_exploit_indicators(self, poc_data: list) -> dict: """Combine exploit indicators from all PoCs""" combined = { 'processes': [], 'files': [], 'commands': [], 'network': [], 'urls': [], 'registry': [] } for poc in poc_data: indicators = poc.get('exploit_indicators', {}) for key in combined.keys(): if key in indicators: combined[key].extend(indicators[key]) # Deduplicate and filter for key in combined.keys(): combined[key] = list(set(combined[key])) # Remove empty and invalid entries combined[key] = [item for item in combined[key] if item and len(item) > 2] return combined def _generate_description(self, cve, poc_data: list) -> str: """Generate enhanced rule description""" base_desc = f"Detection for {cve.cve_id}" if cve.description: # Extract key terms from CVE description desc_words = cve.description.lower().split() key_terms = [word for word in desc_words if word in [ 'remote', 'execution', 'injection', 'bypass', 'privilege', 'escalation', 'overflow', 'disclosure', 'traversal', 'deserialization' ]] if key_terms: base_desc += f" involving {', '.join(set(key_terms[:3]))}" if poc_data: total_pocs = len(poc_data) total_stars = sum(p.get('stargazers_count', 0) for p in poc_data) base_desc += f" [Enhanced with {total_pocs} PoC(s), {total_stars} stars]" return base_desc def _generate_references(self, cve, poc_data: list) -> str: """Generate references section""" refs = [] # Add CVE reference refs.append(f"https://nvd.nist.gov/vuln/detail/{cve.cve_id}") # Add top PoC references (max 3) if poc_data: sorted_pocs = sorted(poc_data, key=lambda x: x.get('stargazers_count', 0), reverse=True) for poc in sorted_pocs[:3]: if poc.get('html_url'): refs.append(poc['html_url']) return '\\n'.join(f" - {ref}" for ref in refs) def _generate_tags(self, cve, poc_data: list) -> str: """Generate MITRE ATT&CK tags and other tags""" tags = [] # CVE tag tags.append(cve.cve_id.lower()) # Add technique tags based on indicators combined_indicators = self._combine_exploit_indicators(poc_data) if combined_indicators.get('processes'): tags.append('attack.t1059') # Command and Scripting Interpreter if combined_indicators.get('network'): tags.append('attack.t1071') # Application Layer Protocol if combined_indicators.get('files'): tags.append('attack.t1105') # Ingress Tool Transfer if any('powershell' in p.lower() for p in combined_indicators.get('processes', [])): tags.append('attack.t1059.001') # PowerShell # Add PoC quality tags if poc_data: tags.append('exploit.poc') best_poc = max(poc_data, key=lambda x: x.get('quality_analysis', {}).get('quality_score', 0)) quality_tier = best_poc.get('quality_analysis', {}).get('quality_tier', 'poor') tags.append(f'poc.quality.{quality_tier}') # Return tags as a single line for first tag, then additional tags on new lines if not tags: return "unknown" if len(tags) == 1: return tags[0] else: # First tag goes directly after the dash, rest are on new lines first_tag = tags[0] additional_tags = '\\n'.join(f" - {tag}" for tag in tags[1:]) return f"{first_tag}\\n{additional_tags}" def _format_indicators(self, indicators: list) -> str: """Format indicators for SIGMA rule""" if not indicators: return ' - "*" # No specific indicators available' # Limit indicators to avoid overly complex rules limited_indicators = indicators[:10] formatted = [] for indicator in limited_indicators: # Clean and escape special characters for SIGMA cleaned = str(indicator).strip() if cleaned: escaped = cleaned.replace('\\\\', '\\\\\\\\').replace('*', '\\\\*').replace('?', '\\\\?') formatted.append(f' - "{escaped}"') return '\\n'.join(formatted) if formatted else ' - "*" # No valid indicators' def _enhance_detection_logic(self, rule_content: str, indicators: dict, poc_data: list) -> str: """Enhance detection logic based on PoC quality and indicators""" # If we have high-quality PoCs, add additional detection conditions best_poc = max(poc_data, key=lambda x: x.get('quality_analysis', {}).get('quality_score', 0)) quality_score = best_poc.get('quality_analysis', {}).get('quality_score', 0) if quality_score > 60: # High quality PoC # Add more specific detection conditions if indicators.get('processes') and indicators.get('commands'): additional_condition = """ process_and_command: Image|contains: {{PROCESSES}} CommandLine|contains: {{COMMANDS}}""" # Insert before the condition line rule_content = rule_content.replace( 'condition: selection', additional_condition + '\\n condition: selection or process_and_command' ) return rule_content def _calculate_confidence_level(self, cve, poc_data: list) -> str: """Calculate confidence level based on CVE and PoC data""" score = 0 # CVSS score factor if cve.cvss_score: if cve.cvss_score >= 9.0: score += 40 elif cve.cvss_score >= 7.0: score += 30 elif cve.cvss_score >= 5.0: score += 20 else: score += 10 # PoC quality factor if poc_data: total_stars = sum(p.get('stargazers_count', 0) for p in poc_data) poc_count = len(poc_data) score += min(total_stars, 30) # Max 30 points for stars score += min(poc_count * 5, 20) # Max 20 points for PoC count # Quality tier bonus best_poc = max(poc_data, key=lambda x: x.get('quality_analysis', {}).get('quality_score', 0)) quality_tier = best_poc.get('quality_analysis', {}).get('quality_tier', 'poor') tier_bonus = { 'excellent': 20, 'good': 15, 'fair': 10, 'poor': 5, 'very_poor': 0 } score += tier_bonus.get(quality_tier, 0) # Determine confidence level if score >= 80: return 'HIGH' elif score >= 60: return 'MEDIUM' elif score >= 40: return 'LOW' else: return 'INFORMATIONAL' def _create_default_template(self, cve, best_poc: Optional[dict]) -> object: """Create a default template based on CVE and PoC analysis""" from main import RuleTemplate import uuid # Analyze the best PoC to determine the most appropriate template type template_type = "process" if best_poc: indicators = best_poc.get('exploit_indicators', {}) if indicators.get('network') or indicators.get('urls'): template_type = "network" elif indicators.get('files'): template_type = "file" elif any('powershell' in p.lower() for p in indicators.get('processes', [])): template_type = "powershell" # Create template content based on type if template_type == "network": template_content = """title: {{TITLE}} id: {{RULE_ID}} status: experimental description: {{DESCRIPTION}} author: CVE-SIGMA Auto Generator date: {{DATE}} references: {{REFERENCES}} tags: {{TAGS}} logsource: category: network_connection product: windows detection: selection: Initiated: true DestinationIp: {{NETWORK}} selection_url: DestinationHostname|contains: {{URLS}} condition: selection or selection_url falsepositives: - Legitimate network connections level: {{LEVEL}}""" elif template_type == "file": template_content = """title: {{TITLE}} id: {{RULE_ID}} status: experimental description: {{DESCRIPTION}} author: CVE-SIGMA Auto Generator date: {{DATE}} references: {{REFERENCES}} tags: {{TAGS}} logsource: category: file_event product: windows detection: selection: TargetFilename|contains: {{FILES}} condition: selection falsepositives: - Legitimate file operations level: {{LEVEL}}""" elif template_type == "powershell": template_content = """title: {{TITLE}} id: {{RULE_ID}} status: experimental description: {{DESCRIPTION}} author: CVE-SIGMA Auto Generator date: {{DATE}} references: {{REFERENCES}} tags: {{TAGS}} logsource: category: process_creation product: windows detection: selection: Image|endswith: - '\\powershell.exe' - '\\pwsh.exe' CommandLine|contains: {{COMMANDS}} condition: selection falsepositives: - Legitimate PowerShell scripts level: {{LEVEL}}""" else: # default to process template_content = """title: {{TITLE}} id: {{RULE_ID}} status: experimental description: {{DESCRIPTION}} author: CVE-SIGMA Auto Generator date: {{DATE}} references: {{REFERENCES}} tags: {{TAGS}} logsource: category: process_creation product: windows detection: selection: Image|endswith: {{PROCESSES}} selection_cmd: CommandLine|contains: {{COMMANDS}} condition: selection or selection_cmd falsepositives: - Legitimate software usage level: {{LEVEL}}""" # Create a temporary template object class DefaultTemplate: def __init__(self, name, content): self.template_name = name self.template_content = content self.applicable_product_patterns = [] return DefaultTemplate(f"Default {template_type.title()} Template", template_content) def _clean_empty_sections(self, rule_content: str) -> str: """Clean up empty sections in the SIGMA rule""" # Remove lines that contain only placeholder indicators lines = rule_content.split('\n') cleaned_lines = [] for line in lines: # Skip lines that are just placeholder indicators if '- "*" # No' in line and 'or selection' in rule_content: continue cleaned_lines.append(line) return '\n'.join(cleaned_lines) def _extract_log_source(self, template_name: str) -> str: """Extract log source from template name""" template_lower = template_name.lower() if 'process' in template_lower or 'execution' in template_lower: return 'process_creation' elif 'network' in template_lower: return 'network_connection' elif 'file' in template_lower: return 'file_event' elif 'powershell' in template_lower: return 'powershell' elif 'registry' in template_lower: return 'registry_event' else: return 'generic'