auto_sigma_rule_generator/backend/enhanced_sigma_generator.py
bpmcdevitt a6fb367ed4 refactor: modularize backend architecture for improved maintainability
- Extract database models from monolithic main.py (2,373 lines) into organized modules
- Implement service layer pattern with dedicated business logic classes
- Split API endpoints into modular FastAPI routers by functionality
- Add centralized configuration management with environment variable handling
- Create proper separation of concerns across data, service, and presentation layers

**Architecture Changes:**
- models/: SQLAlchemy database models (CVE, SigmaRule, RuleTemplate, BulkProcessingJob)
- config/: Centralized settings and database configuration
- services/: Business logic (CVEService, SigmaRuleService, GitHubExploitAnalyzer)
- routers/: Modular API endpoints (cves, sigma_rules, bulk_operations, llm_operations)
- schemas/: Pydantic request/response models

**Key Improvements:**
- 95% reduction in main.py size (2,373 → 120 lines)
- Updated 15+ backend files with proper import structure
- Eliminated circular dependencies and tight coupling
- Enhanced testability with isolated service components
- Better code organization for team collaboration

**Backward Compatibility:**
- All API endpoints maintain same URLs and behavior
- Zero breaking changes to existing functionality
- Database schema unchanged
- Environment variables preserved

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-14 17:51:23 -05:00

774 lines
No EOL
31 KiB
Python

"""
Enhanced SIGMA Rule Generator
Generates improved SIGMA rules using nomi-sec PoC data and traditional indicators
"""
import json
import logging
from datetime import datetime
from typing import Dict, List, Optional, Tuple
from sqlalchemy.orm import Session
import re
from llm_client import LLMClient
from cve2capec_client import CVE2CAPECClient
# Module-level logger for this generator. basicConfig installs an INFO-level
# root handler only if the root logger has no handlers configured yet.
logger = logging.getLogger(name=__name__)
logging.basicConfig(level=logging.INFO)
class EnhancedSigmaGenerator:
    """Enhanced SIGMA rule generator using nomi-sec PoC data.

    A rule is produced either by an LLM (when enabled, available and a PoC
    exists) or by filling a rule template with indicators merged from the CVE's
    PoC records. The result is inserted into, or updated in, the ``SigmaRule``
    table.

    Template placeholders may be written as ``{name}`` (lower-case, e.g.
    ``{suspicious_processes}``) or as ``{{NAME}}`` (upper-case, as used by the
    built-in default templates, e.g. ``{{PROCESSES}}``). Both spellings resolve
    to the same values.
    """

    # Leading whitespace of indicator list items. With 4-space YAML indentation
    # the items sit under a field (8 spaces) inside a selection (4 spaces).
    _INDICATOR_INDENT = " " * 12

    # Indicator categories merged from every PoC's ``exploit_indicators`` mapping.
    _INDICATOR_KEYS = ('processes', 'files', 'commands', 'network', 'urls', 'registry')

    # Candidate exploit file names, matched as substrings of the lower-cased
    # repository file name. README files are handled separately as a
    # lower-priority fallback.
    _POC_TARGET_FILES = (
        'exploit.py', 'poc.py', 'exploit.c', 'exploit.cpp',
        'exploit.java', 'exploit.rb', 'exploit.php',
        'exploit.js', 'exploit.sh', 'exploit.ps1',
        'main.py', 'index.js',
    )
    _EXPLOIT_CONTENT_LIMIT = 10000  # characters kept from an exploit file
    _README_CONTENT_LIMIT = 5000  # characters kept from a README

    # CVE description words surfaced in the generated rule description.
    _DESCRIPTION_KEY_TERMS = frozenset({
        'remote', 'execution', 'injection', 'bypass', 'privilege', 'escalation',
        'overflow', 'disclosure', 'traversal', 'deserialization',
    })

    # Confidence points awarded for the best PoC's quality tier.
    _TIER_BONUS = {'excellent': 20, 'good': 15, 'fair': 10, 'poor': 5, 'very_poor': 0}

    # ``{{NAME}}`` (group 1) or ``{name}`` (group 2) template placeholders.
    _PLACEHOLDER_RE = re.compile(r"\{\{([A-Z_]+)\}\}|\{([a-z_]+)\}")

    # A detection ``condition:`` line whose expression starts with ``selection``.
    _CONDITION_RE = re.compile(
        r"^(?P<indent>[ \t]*)condition:[ \t]*(?P<expr>selection\b.*?)[ \t]*$",
        re.MULTILINE,
    )

    # Metadata lines shared by every built-in default template. ``{{TAGS}}``
    # follows a dash because _generate_tags returns the first tag bare and the
    # remaining tags as further "    - tag" lines.
    _TEMPLATE_HEADER_LINES = (
        "title: {{TITLE}}",
        "id: {{RULE_ID}}",
        "status: experimental",
        "description: {{DESCRIPTION}}",
        "author: CVE-SIGMA Auto Generator",
        "date: {{DATE}}",
        "references:",
        "{{REFERENCES}}",
        "tags:",
        "    - {{TAGS}}",
    )

    # Built-in templates used when the database holds none. List placeholders
    # ({{REFERENCES}}, {{PROCESSES}}, ...) sit at column 0 because the values
    # substituted for them already carry their own indentation.
    _DEFAULT_TEMPLATES = {
        'network': "\n".join(_TEMPLATE_HEADER_LINES + (
            "logsource:",
            "    category: network_connection",
            "    product: windows",
            "detection:",
            "    selection:",
            "        Initiated: true",
            "        DestinationIp:",
            "{{NETWORK}}",
            "    selection_url:",
            "        DestinationHostname|contains:",
            "{{URLS}}",
            "    condition: selection or selection_url",
            "falsepositives:",
            "    - Legitimate network connections",
            "level: {{LEVEL}}",
        )),
        'file': "\n".join(_TEMPLATE_HEADER_LINES + (
            "logsource:",
            "    category: file_event",
            "    product: windows",
            "detection:",
            "    selection:",
            "        TargetFilename|contains:",
            "{{FILES}}",
            "    condition: selection",
            "falsepositives:",
            "    - Legitimate file operations",
            "level: {{LEVEL}}",
        )),
        'powershell': "\n".join(_TEMPLATE_HEADER_LINES + (
            "logsource:",
            "    category: process_creation",
            "    product: windows",
            "detection:",
            "    selection:",
            "        Image|endswith:",
            "            - '\\powershell.exe'",
            "            - '\\pwsh.exe'",
            "        CommandLine|contains:",
            "{{COMMANDS}}",
            "    condition: selection",
            "falsepositives:",
            "    - Legitimate PowerShell scripts",
            "level: {{LEVEL}}",
        )),
        'process': "\n".join(_TEMPLATE_HEADER_LINES + (
            "logsource:",
            "    category: process_creation",
            "    product: windows",
            "detection:",
            "    selection:",
            "        Image|endswith:",
            "{{PROCESSES}}",
            "    selection_cmd:",
            "        CommandLine|contains:",
            "{{COMMANDS}}",
            "    condition: selection or selection_cmd",
            "falsepositives:",
            "    - Legitimate software usage",
            "level: {{LEVEL}}",
        )),
    }

    class _TemplateStub:
        """In-memory stand-in exposing the ``RuleTemplate`` attributes this class reads."""

        def __init__(self, name: str, content: str = "", patterns: Optional[list] = None):
            self.template_name = name
            self.template_content = content
            self.applicable_product_patterns = patterns or []

    def __init__(self, db_session: "Session", llm_provider: Optional[str] = None, llm_model: Optional[str] = None):
        """Create a generator bound to ``db_session``.

        Args:
            db_session: SQLAlchemy session used to read rule templates and to
                insert/update ``SigmaRule`` rows.
            llm_provider: Provider name forwarded to ``LLMClient``.
            llm_model: Model name forwarded to ``LLMClient``.
        """
        self.db_session = db_session
        self.llm_client = LLMClient(provider=llm_provider, model=llm_model)
        self.cve2capec_client = CVE2CAPECClient()

    # ------------------------------------------------------------------
    # PoC helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _poc_quality_score(poc: dict):
        """Return a PoC's ``quality_analysis.quality_score``, treating missing/None as 0."""
        return (poc.get('quality_analysis') or {}).get('quality_score') or 0

    @staticmethod
    def _poc_quality_tier(poc: dict) -> str:
        """Return a PoC's ``quality_analysis.quality_tier``, defaulting to ``'poor'``."""
        return (poc.get('quality_analysis') or {}).get('quality_tier', 'poor')

    @classmethod
    def _select_best_poc(cls, poc_data: list) -> Optional[dict]:
        """Return the PoC with the highest quality score (first wins ties), or None if empty."""
        if not poc_data:
            return None
        return max(poc_data, key=cls._poc_quality_score)

    @staticmethod
    def _total_stars(poc_data: list) -> int:
        """Sum ``stargazers_count`` over all PoCs, treating missing/None as 0."""
        return sum(poc.get('stargazers_count') or 0 for poc in poc_data)

    # ------------------------------------------------------------------
    # Entry point
    # ------------------------------------------------------------------

    async def generate_enhanced_rule(self, cve, use_llm: bool = True) -> dict:
        """Generate, and insert or update, the SIGMA rule for ``cve``.

        LLM generation is attempted first when ``use_llm`` is set, the LLM client
        reports itself available and the CVE has at least one PoC. Otherwise, or
        if the LLM yields nothing valid, a template is selected and filled.

        Returns:
            On success: ``{'success': True, 'cve_id', 'template',
            'confidence_level', 'poc_count', 'quality_score'}``. On failure:
            ``{'success': False, 'error': <message>}``. A failure after database
            work rolls the session back so it stays usable for later calls.
        """
        from models import SigmaRule

        try:
            poc_data = cve.poc_data or []
            best_poc = self._select_best_poc(poc_data)
            best_quality = self._poc_quality_score(best_poc) if best_poc else 0

            rule_content = None
            generation_method = "template"
            template = None

            if use_llm and self.llm_client.is_available() and best_poc:
                provider = self.llm_client.provider
                logger.info("Attempting LLM-enhanced rule generation for %s using %s", cve.cve_id, provider)
                rule_content = await self._generate_llm_enhanced_rule(cve, best_poc, poc_data)
                if rule_content:
                    generation_method = f"llm_{provider}"
                    template = self._TemplateStub(f"LLM Generated ({provider})")

            if not rule_content:
                logger.info("Using template-based rule generation for %s", cve.cve_id)
                template = await self._select_template(cve, best_poc)
                if not template:
                    logger.warning("No suitable template found for %s", cve.cve_id)
                    return {'success': False, 'error': 'No suitable template'}
                rule_content = await self._generate_rule_content(cve, template, poc_data)

            confidence_level = self._calculate_confidence_level(cve, poc_data)

            existing_rule = self.db_session.query(SigmaRule).filter(
                SigmaRule.cve_id == cve.cve_id
            ).first()

            total_stars = self._total_stars(poc_data)
            poc_source = getattr(cve, 'poc_source', 'nomi_sec')
            rule_data = {
                'cve_id': cve.cve_id,
                'rule_name': f"{cve.cve_id} Enhanced Detection",
                'rule_content': rule_content,
                'detection_type': f"{generation_method}_generated",
                'log_source': self._extract_log_source_from_content(rule_content),
                'confidence_level': confidence_level,
                'auto_generated': True,
                'exploit_based': bool(poc_data),
                'poc_source': poc_source,
                'poc_quality_score': best_quality,
                'nomi_sec_data': {
                    'total_pocs': len(poc_data),
                    'best_poc_quality': best_quality,
                    'total_stars': total_stars,
                    'avg_stars': total_stars / len(poc_data) if poc_data else 0,
                    'source': poc_source,
                    'generation_method': generation_method,
                },
                'github_repos': [poc.get('html_url', '') for poc in poc_data],
                'exploit_indicators': json.dumps(self._combine_exploit_indicators(poc_data)),
                'updated_at': datetime.utcnow(),
            }

            if existing_rule:
                for key, value in rule_data.items():
                    setattr(existing_rule, key, value)
                logger.info("Updated SIGMA rule for %s", cve.cve_id)
            else:
                self.db_session.add(SigmaRule(**rule_data))
                logger.info("Created new SIGMA rule for %s", cve.cve_id)

            self.db_session.commit()

            return {
                'success': True,
                'cve_id': cve.cve_id,
                'template': template.template_name if template else 'Unknown',
                'confidence_level': confidence_level,
                'poc_count': len(poc_data),
                'quality_score': best_quality,
            }
        except Exception as e:
            logger.error("Error generating enhanced rule for %s: %s", cve.cve_id, e, exc_info=True)
            # After a failed flush/commit the session refuses further work until it
            # is rolled back; reset it so the caller's shared session stays usable.
            try:
                self.db_session.rollback()
            except Exception:
                logger.exception("Rollback failed after rule generation error")
            return {'success': False, 'error': str(e)}

    # ------------------------------------------------------------------
    # LLM path
    # ------------------------------------------------------------------

    async def _generate_llm_enhanced_rule(self, cve, best_poc: dict, poc_data: list) -> Optional[str]:
        """Ask the LLM for a rule built from the best PoC's content.

        Returns the rule text only when the LLM produced one and
        ``validate_sigma_rule`` accepts it for this CVE; otherwise ``None``.
        Errors are logged and yield ``None`` so the caller can fall back to
        templates. ``poc_data`` is unused and kept for interface compatibility.
        """
        try:
            poc_content = await self._extract_poc_content(best_poc)
            if not poc_content:
                logger.warning("No PoC content available for %s", cve.cve_id)
                return None

            rule_content = await self.llm_client.generate_sigma_rule(
                cve_id=cve.cve_id,
                poc_content=poc_content,
                cve_description=cve.description or "",
                existing_rule=None,
            )
            if not rule_content:
                return None
            if not self.llm_client.validate_sigma_rule(rule_content, cve.cve_id):
                logger.warning("Generated rule for %s failed validation", cve.cve_id)
                return None
            logger.info("Successfully generated LLM-enhanced rule for %s", cve.cve_id)
            return rule_content
        except Exception as e:
            logger.error("Error generating LLM-enhanced rule for %s: %s", cve.cve_id, e)
            return None

    async def _extract_poc_content(self, poc: dict) -> Optional[str]:
        """Return text from a PoC repository for LLM analysis.

        Tries, in order:
            1. a recognised exploit file in the GitHub repository root
               (truncated to 10,000 characters);
            2. the repository README (first 5,000 characters);
            3. the PoC's ``description`` field.

        Returns ``None`` when the PoC has no ``html_url`` or nothing usable is
        found. Network/API failures are logged and fall through to the
        description fallback.
        """
        repo_url = poc.get('html_url', '')
        if not repo_url:
            return None

        content = None
        owner_repo = self._parse_github_repo(repo_url)
        if owner_repo is not None:
            try:
                content = await self._fetch_github_poc_content(*owner_repo)
            except Exception as e:
                logger.error("Error extracting PoC content: %s", e)
        if content:
            return content

        description = poc.get('description', '')
        if description:
            return f"Repository Description: {description}"
        return None

    @staticmethod
    def _parse_github_repo(repo_url: str) -> Optional[Tuple[str, str]]:
        """Return ``(owner, repo)`` for a github.com repository URL, else ``None``.

        Tolerates a missing scheme, trailing slashes, extra path segments and a
        ``.git`` suffix.
        """
        from urllib.parse import urlparse

        if '://' not in repo_url:
            repo_url = 'https://' + repo_url
        parsed = urlparse(repo_url)
        if parsed.hostname not in ('github.com', 'www.github.com'):
            return None
        segments = [segment for segment in parsed.path.split('/') if segment]
        if len(segments) < 2:
            return None
        owner, repo = segments[0], segments[1]
        if repo.endswith('.git'):
            repo = repo[:-4]
        return owner, repo

    async def _fetch_github_poc_content(self, owner: str, repo: str) -> Optional[str]:
        """Download the most relevant root-level file of ``owner/repo`` via the GitHub contents API.

        Exploit-looking files (see ``_POC_TARGET_FILES``) are preferred over the
        README. Returns ``None`` if the listing fails or no candidate downloads.
        """
        import aiohttp

        api_url = f"https://api.github.com/repos/{owner}/{repo}/contents"
        # Per-request timeout so a slow GitHub response cannot hang generation.
        timeout = aiohttp.ClientTimeout(total=30)
        async with aiohttp.ClientSession() as session:
            async with session.get(api_url, timeout=timeout) as response:
                if response.status != 200:
                    return None
                contents = await response.json()

            # Anything but a list (e.g. an error object) has no files to inspect.
            if not isinstance(contents, list):
                return None
            files = [
                ((entry.get('name') or '').lower(), entry['download_url'])
                for entry in contents
                if isinstance(entry, dict) and entry.get('type') == 'file' and entry.get('download_url')
            ]

            for name, url in files:
                if any(target in name for target in self._POC_TARGET_FILES):
                    text = await self._download_text(session, url, timeout)
                    if text is not None:
                        if len(text) > self._EXPLOIT_CONTENT_LIMIT:
                            text = text[:self._EXPLOIT_CONTENT_LIMIT] + "\n... [content truncated]"
                        return text

            for name, url in files:
                if 'readme' in name:
                    text = await self._download_text(session, url, timeout)
                    if text is not None:
                        return text[:self._README_CONTENT_LIMIT]
        return None

    @staticmethod
    async def _download_text(session, url: str, timeout) -> Optional[str]:
        """GET ``url`` through ``session``; return the body text, or ``None`` on a non-200 status."""
        async with session.get(url, timeout=timeout) as response:
            if response.status != 200:
                return None
            return await response.text()

    def _extract_log_source_from_content(self, rule_content: str) -> str:
        """Return the rule's ``logsource.category``, else ``logsource.product``, else ``'generic'``.

        Best-effort: unparsable YAML, a non-mapping document or a missing
        ``yaml`` module all yield ``'generic'``.
        """
        try:
            import yaml

            logsource = yaml.safe_load(rule_content).get('logsource', {})
            return logsource.get('category', '') or logsource.get('product', '') or 'generic'
        except Exception:
            return 'generic'

    # ------------------------------------------------------------------
    # Template selection and scoring
    # ------------------------------------------------------------------

    async def _select_template(self, cve, best_poc: Optional[dict]) -> Optional[object]:
        """Pick the highest-scoring database template (first wins ties).

        Scores combine PoC-indicator, CVE-description and affected-product
        matches. Falls back to a built-in default template when the database
        holds no templates.
        """
        from models import RuleTemplate

        templates = self.db_session.query(RuleTemplate).all()
        if not templates:
            logger.warning("No rule templates found in database - creating default template")
            return self._create_default_template(cve, best_poc)

        poc_indicators = (best_poc.get('exploit_indicators') or {}) if best_poc else None
        best_template = None
        best_score = None
        for template in templates:
            score = 0
            if poc_indicators is not None:
                score += self._score_template_poc_match(template, poc_indicators)
            score += self._score_template_cve_match(template, cve)
            if cve.affected_products:
                score += self._score_template_product_match(template, cve.affected_products)
            if best_score is None or score > best_score:
                best_template, best_score = template, score

        logger.info("Selected template %s with score %s", best_template.template_name, best_score)
        return best_template

    def _score_template_poc_match(self, template: object, indicators: dict) -> int:
        """Score a template by how well its name fits the PoC's indicator categories."""
        indicators = indicators or {}
        name = template.template_name.lower()
        score = 0
        if ('process' in name or 'execution' in name) and (indicators.get('processes') or indicators.get('commands')):
            score += 30
        if ('network' in name or 'connection' in name) and (indicators.get('network') or indicators.get('urls')):
            score += 30
        if ('file' in name or 'modification' in name) and indicators.get('files'):
            score += 30
        if 'powershell' in name and any(
            'powershell' in str(process).lower() for process in indicators.get('processes') or []
        ):
            score += 35
        return score

    def _score_template_cve_match(self, template: object, cve) -> int:
        """Score a template by keyword overlap between its name and the CVE description."""
        name = template.template_name.lower()
        description = (cve.description or '').lower()
        score = 0
        if 'remote' in description and 'execution' in description and (
            'process' in name or 'execution' in name
        ):
            score += 20
        if 'powershell' in description and 'powershell' in name:
            score += 25
        if ('network' in description or 'http' in description) and 'network' in name:
            score += 20
        if ('file' in description or 'upload' in description) and 'file' in name:
            score += 20
        return score

    def _score_template_product_match(self, template: object, affected_products: list) -> int:
        """Award 10 points per template product pattern found in any affected product."""
        patterns = template.applicable_product_patterns
        if not patterns:
            return 0
        products = [str(product).lower() for product in affected_products]
        return sum(
            10 for pattern in patterns
            if any(pattern.lower() in product for product in products)
        )

    # ------------------------------------------------------------------
    # Rule rendering
    # ------------------------------------------------------------------

    @classmethod
    def _fill_placeholders(cls, content: str, values: Dict[str, str]) -> str:
        """Substitute ``{name}`` / ``{{NAME}}`` placeholders in one pass.

        ``values`` is keyed by lower-case name. Unknown placeholders are left
        untouched. Substituted text is never re-scanned, so values that
        themselves contain braces are inserted verbatim.
        """
        def _replace(match):
            name = (match.group(1) or match.group(2)).lower()
            return values.get(name, match.group(0))

        return cls._PLACEHOLDER_RE.sub(_replace, content)

    async def _generate_rule_content(self, cve, template: object, poc_data: list) -> str:
        """Render ``template.template_content`` for ``cve`` using the merged PoC indicators."""
        import uuid

        combined_indicators = self._combine_exploit_indicators(poc_data)
        values = {
            'title': f"{cve.cve_id} Enhanced Detection",
            'description': self._generate_description(cve, poc_data),
            'rule_id': str(uuid.uuid4()),
            'date': datetime.now().strftime('%Y/%m/%d'),
            'level': self._calculate_confidence_level(cve, poc_data).lower(),
            'cve_url': f"https://nvd.nist.gov/vuln/detail/{cve.cve_id}",
            'references': self._generate_references(cve, poc_data),
            'tags': self._generate_tags(cve, poc_data),
        }
        # Each indicator list is exposed as both {suspicious_<kind>} and {{<KIND>}}.
        for kind in (*self._INDICATOR_KEYS, 'ports'):
            formatted = self._format_indicators(combined_indicators.get(kind, []))
            values[kind] = formatted
            values[f'suspicious_{kind}'] = formatted

        rule_content = self._fill_placeholders(template.template_content, values)
        rule_content = self._clean_empty_sections(rule_content)
        if poc_data:
            rule_content = self._enhance_detection_logic(rule_content, combined_indicators, poc_data)
        return rule_content

    def _combine_exploit_indicators(self, poc_data: list) -> dict:
        """Merge ``exploit_indicators`` from all PoCs into one dict of lists.

        Keys are ``_INDICATOR_KEYS``. Values keep first-seen order, are
        de-duplicated, and only strings longer than 2 characters are kept. A
        bare string value is treated as a single indicator.
        """
        merged = {key: [] for key in self._INDICATOR_KEYS}
        for poc in poc_data:
            indicators = poc.get('exploit_indicators') or {}
            if not isinstance(indicators, dict):
                continue
            for key, bucket in merged.items():
                values = indicators.get(key)
                if not values:
                    continue
                if isinstance(values, str):
                    values = [values]
                bucket.extend(values)

        return {
            key: list(dict.fromkeys(item for item in items if isinstance(item, str) and len(item) > 2))
            for key, items in merged.items()
        }

    def _generate_description(self, cve, poc_data: list) -> str:
        """Build the rule description: CVE id, up to 3 key terms, and PoC/star counts."""
        description = f"Detection for {cve.cve_id}"
        if cve.description:
            words = re.findall(r"[a-z]+", cve.description.lower())
            key_terms = list(dict.fromkeys(word for word in words if word in self._DESCRIPTION_KEY_TERMS))
            if key_terms:
                description += f" involving {', '.join(key_terms[:3])}"
        if poc_data:
            description += f" [Enhanced with {len(poc_data)} PoC(s), {self._total_stars(poc_data)} stars]"
        return description

    def _generate_references(self, cve, poc_data: list) -> str:
        """Return YAML list lines: the NVD page, then up to 3 most-starred PoC URLs."""
        refs = [f"https://nvd.nist.gov/vuln/detail/{cve.cve_id}"]
        top_pocs = sorted(poc_data, key=lambda poc: poc.get('stargazers_count') or 0, reverse=True)[:3]
        refs.extend(poc['html_url'] for poc in top_pocs if poc.get('html_url'))
        return '\n'.join(f"    - {ref}" for ref in refs)

    def _generate_tags(self, cve, poc_data: list) -> str:
        """Build rule tags from CVE2CAPEC mappings, falling back to indicator heuristics.

        The first tag is returned bare and the rest as ``"    - tag"`` lines, so
        the template is expected to place the placeholder right after ``- ``.
        """
        tags = [cve.cve_id.lower()]

        mitre_techniques = self.cve2capec_client.get_mitre_techniques_for_cve(cve.cve_id)
        if mitre_techniques:
            logger.info("Found %s MITRE techniques for %s: %s", len(mitre_techniques), cve.cve_id, mitre_techniques)
            for technique in mitre_techniques:
                attack_tag = f"attack.{technique.lower()}"
                if attack_tag not in tags:
                    tags.append(attack_tag)
        else:
            logger.info("No CVE2CAPEC mapping found for %s, using indicator-based detection", cve.cve_id)
            combined_indicators = self._combine_exploit_indicators(poc_data)
            processes = combined_indicators.get('processes', [])
            if processes:
                tags.append('attack.t1059')  # Command and Scripting Interpreter
            if combined_indicators.get('network'):
                tags.append('attack.t1071')  # Application Layer Protocol
            if combined_indicators.get('files'):
                tags.append('attack.t1105')  # Ingress Tool Transfer
            if any('powershell' in process.lower() for process in processes):
                tags.append('attack.t1059.001')  # PowerShell

        cwe_codes = self.cve2capec_client.get_cwe_for_cve(cve.cve_id)
        if cwe_codes:
            tags.append(cwe_codes[0].lower().replace('-', '.'))

        best_poc = self._select_best_poc(poc_data)
        if best_poc is not None:
            tags.append('exploit.poc')
            tags.append(f'poc.quality.{self._poc_quality_tier(best_poc)}')

        if not tags:
            return "unknown"
        first_tag, *other_tags = tags
        return '\n'.join([first_tag, *(f"    - {tag}" for tag in other_tags)])

    def _format_indicators(self, indicators: list, indent: str = _INDICATOR_INDENT) -> str:
        """Render up to 10 indicators as YAML list items prefixed by ``indent``.

        Values are SIGMA-escaped (backslashes doubled; ``*`` and ``?`` escaped so
        they match literally) and then written as YAML single-quoted scalars,
        where backslashes need no further escaping and ``'`` is doubled. An empty
        input yields a ``- "*" # No ...`` placeholder line, which
        ``_clean_empty_sections`` recognises.
        """
        if not indicators:
            return f'{indent}- "*" # No specific indicators available'

        lines = []
        for indicator in indicators[:10]:
            cleaned = str(indicator).strip()
            if not cleaned:
                continue
            sigma_value = cleaned.replace('\\', '\\\\').replace('*', '\\*').replace('?', '\\?')
            lines.append(f"{indent}- '{sigma_value.replace(chr(39), chr(39) * 2)}'")
        if not lines:
            return f'{indent}- "*" # No valid indicators'
        return '\n'.join(lines)

    def _enhance_detection_logic(self, rule_content: str, indicators: dict, poc_data: list) -> str:
        """Add a ``process_and_command`` selection for high-quality PoCs (score > 60).

        Applies only when both process and command indicators exist and the rule
        has a ``condition:`` line starting with ``selection``. The new selection
        is inserted before that line, and the condition is OR-ed with it.
        Compound conditions are parenthesised first so their meaning is kept.
        """
        best_poc = self._select_best_poc(poc_data)
        if best_poc is None or self._poc_quality_score(best_poc) <= 60:
            return rule_content

        processes = indicators.get('processes')
        commands = indicators.get('commands')
        if not (processes and commands):
            return rule_content

        match = self._CONDITION_RE.search(rule_content)
        if match is None:
            return rule_content

        indent = match.group('indent')
        expr = match.group('expr')
        value_indent = indent + ' ' * 8
        if re.fullmatch(r"\w+(?:\s+or\s+\w+)*", expr):
            new_expr = f"{expr} or process_and_command"
        else:
            new_expr = f"({expr}) or process_and_command"

        addition = '\n'.join([
            f"{indent}process_and_command:",
            f"{indent}    Image|contains:",
            self._format_indicators(processes, indent=value_indent),
            f"{indent}    CommandLine|contains:",
            self._format_indicators(commands, indent=value_indent),
            f"{indent}condition: {new_expr}",
        ])
        return rule_content[:match.start()] + addition + rule_content[match.end():]

    def _calculate_confidence_level(self, cve, poc_data: list) -> str:
        """Map CVSS and PoC evidence to HIGH (>=80), MEDIUM (>=60), LOW (>=40) or INFORMATIONAL.

        Points:
            - CVSS: 40 / 30 / 20 / 10 for scores >= 9 / >= 7 / >= 5 / other
              non-zero values.
            - Stars: up to 30.
            - PoC count: 5 each, up to 20.
            - Best PoC's quality tier: bonus from ``_TIER_BONUS``.
        """
        score = 0
        cvss = cve.cvss_score
        if cvss:
            if cvss >= 9.0:
                score += 40
            elif cvss >= 7.0:
                score += 30
            elif cvss >= 5.0:
                score += 20
            else:
                score += 10

        best_poc = self._select_best_poc(poc_data)
        if best_poc is not None:
            score += min(self._total_stars(poc_data), 30)
            score += min(len(poc_data) * 5, 20)
            score += self._TIER_BONUS.get(self._poc_quality_tier(best_poc), 0)

        if score >= 80:
            return 'HIGH'
        if score >= 60:
            return 'MEDIUM'
        if score >= 40:
            return 'LOW'
        return 'INFORMATIONAL'

    def _create_default_template(self, cve, best_poc: Optional[dict]) -> object:
        """Return a built-in template chosen from the best PoC's indicators.

        Precedence: network/URLs, then files, then PowerShell processes, else
        generic process. ``cve`` is unused and kept for interface compatibility.
        """
        indicators = (best_poc.get('exploit_indicators') or {}) if best_poc else {}
        template_type = "process"
        if indicators.get('network') or indicators.get('urls'):
            template_type = "network"
        elif indicators.get('files'):
            template_type = "file"
        elif any('powershell' in str(process).lower() for process in indicators.get('processes') or []):
            template_type = "powershell"

        return self._TemplateStub(
            f"Default {template_type.title()} Template",
            self._DEFAULT_TEMPLATES[template_type],
        )

    def _clean_empty_sections(self, rule_content: str) -> str:
        """Drop ``- "*" # No ...`` placeholder items when the rule ORs selections.

        The check is the substring ``'or selection'`` anywhere in the rule. When
        it is absent, the content is returned unchanged.
        """
        if 'or selection' not in rule_content:
            return rule_content
        return '\n'.join(line for line in rule_content.split('\n') if '- "*" # No' not in line)

    def _extract_log_source(self, template_name: str) -> str:
        """Guess a SIGMA log source category from keywords in a template name."""
        lowered = template_name.lower()
        for keywords, log_source in (
            (('process', 'execution'), 'process_creation'),
            (('network',), 'network_connection'),
            (('file',), 'file_event'),
            (('powershell',), 'powershell'),
            (('registry',), 'registry_event'),
        ):
            if any(keyword in lowered for keyword in keywords):
                return log_source
        return 'generic'