from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from sqlalchemy import create_engine, Column, String, Text, DECIMAL, TIMESTAMP, Boolean, ARRAY, Integer, JSON, func from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker, Session from sqlalchemy.dialects.postgresql import UUID import uuid from datetime import datetime, timedelta import requests import json import re import os from typing import List, Optional from pydantic import BaseModel import asyncio from contextlib import asynccontextmanager import base64 from github import Github from urllib.parse import urlparse import hashlib import logging import threading # Setup logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Global job tracking running_jobs = {} job_cancellation_flags = {} # Database setup DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://cve_user:cve_password@localhost:5432/cve_sigma_db") engine = create_engine(DATABASE_URL) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) Base = declarative_base() # Database Models class CVE(Base): __tablename__ = "cves" id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) cve_id = Column(String(20), unique=True, nullable=False) description = Column(Text) cvss_score = Column(DECIMAL(3, 1)) severity = Column(String(20)) published_date = Column(TIMESTAMP) modified_date = Column(TIMESTAMP) affected_products = Column(ARRAY(String)) reference_urls = Column(ARRAY(String)) # Bulk processing fields data_source = Column(String(20), default='nvd_api') # 'nvd_api', 'nvd_bulk', 'manual' nvd_json_version = Column(String(10), default='2.0') bulk_processed = Column(Boolean, default=False) # nomi-sec PoC fields poc_count = Column(Integer, default=0) poc_data = Column(JSON) # Store nomi-sec PoC metadata created_at = Column(TIMESTAMP, default=datetime.utcnow) updated_at = Column(TIMESTAMP, default=datetime.utcnow) class SigmaRule(Base): __tablename__ = "sigma_rules" id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) cve_id = Column(String(20)) rule_name = Column(String(255), nullable=False) rule_content = Column(Text, nullable=False) detection_type = Column(String(50)) log_source = Column(String(100)) confidence_level = Column(String(20)) auto_generated = Column(Boolean, default=True) exploit_based = Column(Boolean, default=False) github_repos = Column(ARRAY(String)) exploit_indicators = Column(Text) # JSON string of extracted indicators # Enhanced fields for new data sources poc_source = Column(String(20), default='github_search') # 'github_search', 'nomi_sec', 'manual' poc_quality_score = Column(Integer, default=0) # Based on star count, activity, etc. nomi_sec_data = Column(JSON) # Store nomi-sec PoC metadata created_at = Column(TIMESTAMP, default=datetime.utcnow) updated_at = Column(TIMESTAMP, default=datetime.utcnow) class RuleTemplate(Base): __tablename__ = "rule_templates" id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) template_name = Column(String(255), nullable=False) template_content = Column(Text, nullable=False) applicable_product_patterns = Column(ARRAY(String)) description = Column(Text) created_at = Column(TIMESTAMP, default=datetime.utcnow) class BulkProcessingJob(Base): __tablename__ = "bulk_processing_jobs" id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) job_type = Column(String(50), nullable=False) # 'nvd_bulk_seed', 'nomi_sec_sync', 'incremental_update' status = Column(String(20), default='pending') # 'pending', 'running', 'completed', 'failed', 'cancelled' year = Column(Integer) # For year-based processing total_items = Column(Integer, default=0) processed_items = Column(Integer, default=0) failed_items = Column(Integer, default=0) error_message = Column(Text) job_metadata = Column(JSON) # Additional job-specific data started_at = Column(TIMESTAMP) completed_at = Column(TIMESTAMP) cancelled_at = Column(TIMESTAMP) created_at = Column(TIMESTAMP, default=datetime.utcnow) # Pydantic models class CVEResponse(BaseModel): id: str cve_id: str description: Optional[str] = None cvss_score: Optional[float] = None severity: Optional[str] = None published_date: Optional[datetime] = None affected_products: Optional[List[str]] = None reference_urls: Optional[List[str]] = None class Config: from_attributes = True class SigmaRuleResponse(BaseModel): id: str cve_id: str rule_name: str rule_content: str detection_type: Optional[str] = None log_source: Optional[str] = None confidence_level: Optional[str] = None auto_generated: bool = True exploit_based: bool = False github_repos: Optional[List[str]] = None exploit_indicators: Optional[str] = None created_at: datetime class Config: from_attributes = True # GitHub Exploit Analysis Service class GitHubExploitAnalyzer: def __init__(self): self.github_token = os.getenv("GITHUB_TOKEN") self.github = Github(self.github_token) if self.github_token else None async def search_exploits_for_cve(self, cve_id: str) -> List[dict]: """Search GitHub for exploit code related to a CVE""" if not self.github: print(f"No GitHub token configured, skipping exploit search for {cve_id}") return [] try: print(f"Searching GitHub for exploits for {cve_id}") # Search queries to find exploit code search_queries = [ f"{cve_id} exploit", f"{cve_id} poc", f"{cve_id} vulnerability", f'"{cve_id}" exploit code', f"{cve_id.replace('-', '_')} exploit" ] exploits = [] seen_repos = set() for query in search_queries[:2]: # Limit to 2 queries to avoid rate limits try: # Search repositories repos = self.github.search_repositories( query=query, sort="updated", order="desc" ) # Get top 5 results per query for repo in repos[:5]: if repo.full_name in seen_repos: continue seen_repos.add(repo.full_name) # Analyze repository exploit_info = await self._analyze_repository(repo, cve_id) if exploit_info: exploits.append(exploit_info) if len(exploits) >= 10: # Limit total exploits break if len(exploits) >= 10: break except Exception as e: print(f"Error searching GitHub with query '{query}': {str(e)}") continue print(f"Found {len(exploits)} potential exploits for {cve_id}") return exploits except Exception as e: print(f"Error searching GitHub for {cve_id}: {str(e)}") return [] async def _analyze_repository(self, repo, cve_id: str) -> Optional[dict]: """Analyze a GitHub repository for exploit code""" try: # Check if repo name or description mentions the CVE repo_text = f"{repo.name} {repo.description or ''}".lower() if cve_id.lower() not in repo_text and cve_id.replace('-', '_').lower() not in repo_text: return None # Get repository contents exploit_files = [] indicators = { 'processes': set(), 'files': set(), 'registry': set(), 'network': set(), 'commands': set(), 'powershell': set(), 'urls': set() } try: contents = repo.get_contents("") for content in contents[:20]: # Limit files to analyze if content.type == "file" and self._is_exploit_file(content.name): file_analysis = await self._analyze_file_content(repo, content, cve_id) if file_analysis: exploit_files.append(file_analysis) # Merge indicators for key, values in file_analysis.get('indicators', {}).items(): if key in indicators: indicators[key].update(values) except Exception as e: print(f"Error analyzing repo contents for {repo.full_name}: {str(e)}") if not exploit_files: return None return { 'repo_name': repo.full_name, 'repo_url': repo.html_url, 'description': repo.description, 'language': repo.language, 'stars': repo.stargazers_count, 'updated': repo.updated_at.isoformat(), 'files': exploit_files, 'indicators': {k: list(v) for k, v in indicators.items()} } except Exception as e: print(f"Error analyzing repository {repo.full_name}: {str(e)}") return None def _is_exploit_file(self, filename: str) -> bool: """Check if a file is likely to contain exploit code""" exploit_extensions = ['.py', '.ps1', '.sh', '.c', '.cpp', '.js', '.rb', '.pl', '.php', '.java'] exploit_names = ['exploit', 'poc', 'payload', 'shell', 'reverse', 'bind', 'attack'] filename_lower = filename.lower() # Check extension if not any(filename_lower.endswith(ext) for ext in exploit_extensions): return False # Check filename for exploit-related terms return any(term in filename_lower for term in exploit_names) or 'cve' in filename_lower async def _analyze_file_content(self, repo, file_content, cve_id: str) -> Optional[dict]: """Analyze individual file content for exploit indicators""" try: if file_content.size > 100000: # Skip files larger than 100KB return None # Decode file content content = file_content.decoded_content.decode('utf-8', errors='ignore') # Check if file actually mentions the CVE if cve_id.lower() not in content.lower() and cve_id.replace('-', '_').lower() not in content.lower(): return None indicators = self._extract_indicators_from_code(content, file_content.name) if not any(indicators.values()): return None return { 'filename': file_content.name, 'path': file_content.path, 'size': file_content.size, 'indicators': indicators } except Exception as e: print(f"Error analyzing file {file_content.name}: {str(e)}") return None def _extract_indicators_from_code(self, content: str, filename: str) -> dict: """Extract security indicators from exploit code""" indicators = { 'processes': set(), 'files': set(), 'registry': set(), 'network': set(), 'commands': set(), 'powershell': set(), 'urls': set() } # Process patterns process_patterns = [ r'CreateProcess[AW]?\s*\(\s*["\']([^"\']+)["\']', r'ShellExecute[AW]?\s*\([^,]*,\s*["\']([^"\']+)["\']', r'system\s*\(\s*["\']([^"\']+)["\']', r'exec\s*\(\s*["\']([^"\']+)["\']', r'subprocess\.(?:call|run|Popen)\s*\(\s*["\']([^"\']+)["\']' ] # File patterns file_patterns = [ r'(?:fopen|CreateFile|WriteFile|ReadFile)\s*\(\s*["\']([^"\']+\.[a-zA-Z0-9]+)["\']', r'(?:copy|move|del|rm)\s+["\']?([^\s"\']+\.[a-zA-Z0-9]+)["\']?', r'\\\\[^\\]+\\[^\\]+\\([^\\]+\.[a-zA-Z0-9]+)', r'[C-Z]:\\\\[^\\]+\\\\([^\\]+\.[a-zA-Z0-9]+)' ] # Registry patterns registry_patterns = [ r'(?:RegOpenKey|RegSetValue|RegCreateKey)\s*\([^,]*,\s*["\']([^"\']+)["\']', r'HKEY_[A-Z_]+\\\\([^"\'\\]+)', r'reg\s+add\s+["\']?([^"\'\\]+\\\\[^"\']+)["\']?' ] # Network patterns network_patterns = [ r'(?:connect|bind|listen)\s*\([^,]*,\s*(\d+)', r'socket\.connect\s*\(\s*\(["\']?([^"\']+)["\']?,\s*(\d+)\)', r'(?:http|https|ftp)://([^\s"\'<>]+)', r'(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)' ] # PowerShell patterns powershell_patterns = [ r'(?:powershell|pwsh)\s+(?:-[a-zA-Z]+\s+)*["\']?([^"\']+)["\']?', r'Invoke-(?:Expression|Command|WebRequest|RestMethod)\s+["\']?([^"\']+)["\']?', r'Start-Process\s+["\']?([^"\']+)["\']?', r'Get-Process\s+["\']?([^"\']+)["\']?' ] # Command patterns command_patterns = [ r'(?:cmd|command)\s+(?:/[a-zA-Z]+\s+)*["\']?([^"\']+)["\']?', r'(?:ping|nslookup|netstat|tasklist|wmic)\s+([^\s"\']+)', r'(?:net|sc|schtasks)\s+[a-zA-Z]+\s+([^\s"\']+)' ] # Extract indicators using regex patterns patterns = { 'processes': process_patterns, 'files': file_patterns, 'registry': registry_patterns, 'powershell': powershell_patterns, 'commands': command_patterns } for category, pattern_list in patterns.items(): for pattern in pattern_list: matches = re.findall(pattern, content, re.IGNORECASE | re.MULTILINE) for match in matches: if isinstance(match, tuple): indicators[category].add(match[0]) else: indicators[category].add(match) # Special handling for network indicators for pattern in network_patterns: matches = re.findall(pattern, content, re.IGNORECASE) for match in matches: if isinstance(match, tuple): if len(match) >= 2: indicators['network'].add(f"{match[0]}:{match[1]}") else: indicators['network'].add(match[0]) else: indicators['network'].add(match) # Convert sets to lists and filter out empty/invalid indicators cleaned_indicators = {} for key, values in indicators.items(): cleaned_values = [v for v in values if v and len(v.strip()) > 2 and len(v) < 200] if cleaned_values: cleaned_indicators[key] = cleaned_values[:10] # Limit to 10 per category return cleaned_indicators class CVESigmaService: def __init__(self, db: Session): self.db = db self.nvd_api_key = os.getenv("NVD_API_KEY") async def fetch_recent_cves(self, days_back: int = 7): """Fetch recent CVEs from NVD API""" end_date = datetime.utcnow() start_date = end_date - timedelta(days=days_back) url = "https://services.nvd.nist.gov/rest/json/cves/2.0" params = { "pubStartDate": start_date.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z", "pubEndDate": end_date.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z", "resultsPerPage": 100 } headers = {} if self.nvd_api_key: headers["apiKey"] = self.nvd_api_key try: response = requests.get(url, params=params, headers=headers, timeout=30) response.raise_for_status() data = response.json() new_cves = [] for vuln in data.get("vulnerabilities", []): cve_data = vuln.get("cve", {}) cve_id = cve_data.get("id") # Check if CVE already exists existing = self.db.query(CVE).filter(CVE.cve_id == cve_id).first() if existing: continue # Extract CVE information description = "" if cve_data.get("descriptions"): description = cve_data["descriptions"][0].get("value", "") cvss_score = None severity = None if cve_data.get("metrics", {}).get("cvssMetricV31"): cvss_data = cve_data["metrics"]["cvssMetricV31"][0] cvss_score = cvss_data.get("cvssData", {}).get("baseScore") severity = cvss_data.get("cvssData", {}).get("baseSeverity") affected_products = [] if cve_data.get("configurations"): for config in cve_data["configurations"]: for node in config.get("nodes", []): for cpe_match in node.get("cpeMatch", []): if cpe_match.get("vulnerable"): affected_products.append(cpe_match.get("criteria", "")) reference_urls = [] if cve_data.get("references"): reference_urls = [ref.get("url", "") for ref in cve_data["references"]] cve_obj = CVE( cve_id=cve_id, description=description, cvss_score=cvss_score, severity=severity, published_date=datetime.fromisoformat(cve_data.get("published", "").replace("Z", "+00:00")), modified_date=datetime.fromisoformat(cve_data.get("lastModified", "").replace("Z", "+00:00")), affected_products=affected_products, reference_urls=reference_urls ) self.db.add(cve_obj) new_cves.append(cve_obj) self.db.commit() return new_cves except Exception as e: print(f"Error fetching CVEs: {str(e)}") return [] def generate_sigma_rule(self, cve: CVE) -> Optional[SigmaRule]: """Generate SIGMA rule based on CVE data""" if not cve.description: return None # Analyze CVE to determine appropriate template description_lower = cve.description.lower() affected_products = [p.lower() for p in (cve.affected_products or [])] template = self._select_template(description_lower, affected_products) if not template: return None # Generate rule content rule_content = self._populate_template(cve, template) if not rule_content: return None # Determine detection type and confidence detection_type = self._determine_detection_type(description_lower) confidence_level = self._calculate_confidence(cve) sigma_rule = SigmaRule( cve_id=cve.cve_id, rule_name=f"CVE-{cve.cve_id.split('-')[1]}-{cve.cve_id.split('-')[2]} Detection", rule_content=rule_content, detection_type=detection_type, log_source=template.template_name.lower().replace(" ", "_"), confidence_level=confidence_level, auto_generated=True ) self.db.add(sigma_rule) return sigma_rule def _select_template(self, description: str, affected_products: List[str], exploit_indicators: dict = None): """Select appropriate SIGMA rule template based on CVE and exploit analysis""" templates = self.db.query(RuleTemplate).all() # If we have exploit indicators, use them to determine the best template if exploit_indicators: if exploit_indicators.get('powershell'): powershell_template = next((t for t in templates if "PowerShell" in t.template_name), None) if powershell_template: return powershell_template if exploit_indicators.get('network'): network_template = next((t for t in templates if "Network Connection" in t.template_name), None) if network_template: return network_template if exploit_indicators.get('files'): file_template = next((t for t in templates if "File Modification" in t.template_name), None) if file_template: return file_template if exploit_indicators.get('processes') or exploit_indicators.get('commands'): process_template = next((t for t in templates if "Process Execution" in t.template_name), None) if process_template: return process_template # Fallback to original logic if any("windows" in p or "microsoft" in p for p in affected_products): if "process" in description or "execution" in description: return next((t for t in templates if "Process Execution" in t.template_name), None) elif "network" in description or "remote" in description: return next((t for t in templates if "Network Connection" in t.template_name), None) elif "file" in description or "write" in description: return next((t for t in templates if "File Modification" in t.template_name), None) # Default to process execution template return next((t for t in templates if "Process Execution" in t.template_name), None) def _populate_template(self, cve: CVE, template: RuleTemplate, exploit_indicators: dict = None) -> str: """Populate template with CVE-specific data and exploit indicators""" try: # Use exploit indicators if available, otherwise extract from description if exploit_indicators: suspicious_processes = exploit_indicators.get('processes', []) + exploit_indicators.get('commands', []) suspicious_ports = [] file_patterns = exploit_indicators.get('files', []) # Extract ports from network indicators for net_indicator in exploit_indicators.get('network', []): if ':' in str(net_indicator): try: port = int(str(net_indicator).split(':')[-1]) suspicious_ports.append(port) except ValueError: pass else: # Fallback to original extraction suspicious_processes = self._extract_suspicious_indicators(cve.description, "process") suspicious_ports = self._extract_suspicious_indicators(cve.description, "port") file_patterns = self._extract_suspicious_indicators(cve.description, "file") # Determine severity level level = "high" if cve.cvss_score and cve.cvss_score >= 7.0 else "medium" # Create enhanced description enhanced_description = cve.description[:200] + "..." if len(cve.description) > 200 else cve.description if exploit_indicators: enhanced_description += " [Enhanced with GitHub exploit analysis]" # Build tags tags = [f"attack.{self._get_mitre_technique(cve.description, exploit_indicators)}", cve.cve_id.lower()] if exploit_indicators: tags.append("exploit.github") rule_content = template.template_content.format( title=f"CVE-{cve.cve_id} {'Exploit-Based ' if exploit_indicators else ''}Detection", description=enhanced_description, rule_id=str(uuid.uuid4()), date=datetime.utcnow().strftime("%Y/%m/%d"), cve_url=f"https://nvd.nist.gov/vuln/detail/{cve.cve_id}", cve_id=cve.cve_id.lower(), tags="\n - ".join(tags), suspicious_processes=suspicious_processes or ["suspicious.exe", "malware.exe"], suspicious_ports=suspicious_ports or [4444, 8080, 9999], file_patterns=file_patterns or ["temp", "malware", "exploit"], level=level ) return rule_content except Exception as e: print(f"Error populating template: {str(e)}") return None def _get_mitre_technique(self, description: str, exploit_indicators: dict = None) -> str: """Map CVE and exploit indicators to MITRE ATT&CK techniques""" desc_lower = description.lower() # Check exploit indicators first if exploit_indicators: if exploit_indicators.get('powershell'): return "t1059.001" # PowerShell elif exploit_indicators.get('commands'): return "t1059.003" # Windows Command Shell elif exploit_indicators.get('network'): return "t1071.001" # Web Protocols elif exploit_indicators.get('files'): return "t1105" # Ingress Tool Transfer elif exploit_indicators.get('processes'): return "t1106" # Native API # Fallback to description analysis if "powershell" in desc_lower: return "t1059.001" elif "command" in desc_lower or "cmd" in desc_lower: return "t1059.003" elif "network" in desc_lower or "remote" in desc_lower: return "t1071.001" elif "file" in desc_lower or "upload" in desc_lower: return "t1105" elif "process" in desc_lower or "execution" in desc_lower: return "t1106" else: return "execution" # Generic def _extract_suspicious_indicators(self, description: str, indicator_type: str) -> List: """Extract suspicious indicators from CVE description""" if indicator_type == "process": # Look for executable names or process patterns exe_pattern = re.findall(r'(\w+\.exe)', description, re.IGNORECASE) return exe_pattern[:5] if exe_pattern else None elif indicator_type == "port": # Look for port numbers port_pattern = re.findall(r'port\s+(\d+)', description, re.IGNORECASE) return [int(p) for p in port_pattern[:3]] if port_pattern else None elif indicator_type == "file": # Look for file extensions or paths file_pattern = re.findall(r'(\w+\.\w{3,4})', description, re.IGNORECASE) return file_pattern[:5] if file_pattern else None return None def _determine_detection_type(self, description: str, exploit_indicators: dict = None) -> str: """Determine detection type based on CVE description and exploit indicators""" if exploit_indicators: if exploit_indicators.get('powershell'): return "powershell" elif exploit_indicators.get('network'): return "network" elif exploit_indicators.get('files'): return "file" elif exploit_indicators.get('processes') or exploit_indicators.get('commands'): return "process" # Fallback to original logic if "remote" in description or "network" in description: return "network" elif "process" in description or "execution" in description: return "process" elif "file" in description or "filesystem" in description: return "file" else: return "general" def _calculate_confidence(self, cve: CVE, exploit_based: bool = False) -> str: """Calculate confidence level for the generated rule""" base_confidence = 0 # CVSS score contributes to confidence if cve.cvss_score: if cve.cvss_score >= 9.0: base_confidence += 3 elif cve.cvss_score >= 7.0: base_confidence += 2 else: base_confidence += 1 # Exploit-based rules get higher confidence if exploit_based: base_confidence += 2 # Map to confidence levels if base_confidence >= 4: return "high" elif base_confidence >= 2: return "medium" else: return "low" # Dependency def get_db(): db = SessionLocal() try: yield db finally: db.close() # Background task to fetch CVEs and generate rules async def background_cve_fetch(): retry_count = 0 max_retries = 3 while True: try: db = SessionLocal() service = CVESigmaService(db) current_time = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S') print(f"[{current_time}] Starting CVE fetch cycle...") # Use a longer initial period (30 days) to find CVEs new_cves = await service.fetch_recent_cves(days_back=30) if new_cves: print(f"Found {len(new_cves)} new CVEs, generating SIGMA rules...") rules_generated = 0 for cve in new_cves: try: sigma_rule = service.generate_sigma_rule(cve) if sigma_rule: rules_generated += 1 print(f"Generated SIGMA rule for {cve.cve_id}") else: print(f"Could not generate rule for {cve.cve_id} - insufficient data") except Exception as e: print(f"Error generating rule for {cve.cve_id}: {str(e)}") db.commit() print(f"Successfully generated {rules_generated} SIGMA rules") retry_count = 0 # Reset retry count on success else: print("No new CVEs found in this cycle") # After first successful run, reduce to 7 days for regular updates if retry_count == 0: print("Switching to 7-day lookback for future runs...") db.close() except Exception as e: retry_count += 1 print(f"Background task error (attempt {retry_count}/{max_retries}): {str(e)}") if retry_count >= max_retries: print(f"Max retries reached, waiting longer before next attempt...") await asyncio.sleep(1800) # Wait 30 minutes on repeated failures retry_count = 0 else: await asyncio.sleep(300) # Wait 5 minutes before retry continue # Wait 1 hour before next fetch (or 30 minutes if there were errors) wait_time = 3600 if retry_count == 0 else 1800 print(f"Next CVE fetch in {wait_time//60} minutes...") await asyncio.sleep(wait_time) @asynccontextmanager async def lifespan(app: FastAPI): # Start background task task = asyncio.create_task(background_cve_fetch()) yield # Clean up task.cancel() # FastAPI app app = FastAPI(title="CVE-SIGMA Auto Generator", lifespan=lifespan) app.add_middleware( CORSMiddleware, allow_origins=["http://localhost:3000"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.get("/api/cves", response_model=List[CVEResponse]) async def get_cves(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)): cves = db.query(CVE).order_by(CVE.published_date.desc()).offset(skip).limit(limit).all() # Convert UUID to string for each CVE result = [] for cve in cves: cve_dict = { 'id': str(cve.id), 'cve_id': cve.cve_id, 'description': cve.description, 'cvss_score': float(cve.cvss_score) if cve.cvss_score else None, 'severity': cve.severity, 'published_date': cve.published_date, 'affected_products': cve.affected_products, 'reference_urls': cve.reference_urls } result.append(CVEResponse(**cve_dict)) return result @app.get("/api/cves/{cve_id}", response_model=CVEResponse) async def get_cve(cve_id: str, db: Session = Depends(get_db)): cve = db.query(CVE).filter(CVE.cve_id == cve_id).first() if not cve: raise HTTPException(status_code=404, detail="CVE not found") cve_dict = { 'id': str(cve.id), 'cve_id': cve.cve_id, 'description': cve.description, 'cvss_score': float(cve.cvss_score) if cve.cvss_score else None, 'severity': cve.severity, 'published_date': cve.published_date, 'affected_products': cve.affected_products, 'reference_urls': cve.reference_urls } return CVEResponse(**cve_dict) @app.get("/api/sigma-rules", response_model=List[SigmaRuleResponse]) async def get_sigma_rules(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)): rules = db.query(SigmaRule).order_by(SigmaRule.created_at.desc()).offset(skip).limit(limit).all() # Convert UUID to string for each rule result = [] for rule in rules: rule_dict = { 'id': str(rule.id), 'cve_id': rule.cve_id, 'rule_name': rule.rule_name, 'rule_content': rule.rule_content, 'detection_type': rule.detection_type, 'log_source': rule.log_source, 'confidence_level': rule.confidence_level, 'auto_generated': rule.auto_generated, 'exploit_based': rule.exploit_based or False, 'github_repos': rule.github_repos or [], 'exploit_indicators': rule.exploit_indicators, 'created_at': rule.created_at } result.append(SigmaRuleResponse(**rule_dict)) return result @app.get("/api/sigma-rules/{cve_id}", response_model=List[SigmaRuleResponse]) async def get_sigma_rules_by_cve(cve_id: str, db: Session = Depends(get_db)): rules = db.query(SigmaRule).filter(SigmaRule.cve_id == cve_id).all() # Convert UUID to string for each rule result = [] for rule in rules: rule_dict = { 'id': str(rule.id), 'cve_id': rule.cve_id, 'rule_name': rule.rule_name, 'rule_content': rule.rule_content, 'detection_type': rule.detection_type, 'log_source': rule.log_source, 'confidence_level': rule.confidence_level, 'auto_generated': rule.auto_generated, 'exploit_based': rule.exploit_based or False, 'github_repos': rule.github_repos or [], 'exploit_indicators': rule.exploit_indicators, 'created_at': rule.created_at } result.append(SigmaRuleResponse(**rule_dict)) return result @app.post("/api/fetch-cves") async def manual_fetch_cves(background_tasks: BackgroundTasks, db: Session = Depends(get_db)): async def fetch_task(): try: service = CVESigmaService(db) print("Manual CVE fetch initiated...") # Use 30 days for manual fetch to get more results new_cves = await service.fetch_recent_cves(days_back=30) rules_generated = 0 for cve in new_cves: sigma_rule = service.generate_sigma_rule(cve) if sigma_rule: rules_generated += 1 db.commit() print(f"Manual fetch complete: {len(new_cves)} CVEs, {rules_generated} rules generated") except Exception as e: print(f"Manual fetch error: {str(e)}") import traceback traceback.print_exc() background_tasks.add_task(fetch_task) return {"message": "CVE fetch initiated (30-day lookback)", "status": "started"} @app.get("/api/test-nvd") async def test_nvd_connection(): """Test endpoint to check NVD API connectivity""" try: # Test with a simple request using current date end_date = datetime.utcnow() start_date = end_date - timedelta(days=30) url = "https://services.nvd.nist.gov/rest/json/cves/2.0/" params = { "lastModStartDate": start_date.strftime("%Y-%m-%dT%H:%M:%S.000+00:00"), "lastModEndDate": end_date.strftime("%Y-%m-%dT%H:%M:%S.000+00:00"), "resultsPerPage": 5, "startIndex": 0 } headers = { "User-Agent": "CVE-SIGMA-Generator/1.0", "Accept": "application/json" } nvd_api_key = os.getenv("NVD_API_KEY") if nvd_api_key: headers["apiKey"] = nvd_api_key print(f"Testing NVD API with URL: {url}") print(f"Test params: {params}") print(f"Test headers: {headers}") response = requests.get(url, params=params, headers=headers, timeout=15) result = { "status": "success" if response.status_code == 200 else "error", "status_code": response.status_code, "has_api_key": bool(nvd_api_key), "request_url": f"{url}?{requests.compat.urlencode(params)}", "response_headers": dict(response.headers) } if response.status_code == 200: data = response.json() result.update({ "total_results": data.get("totalResults", 0), "results_per_page": data.get("resultsPerPage", 0), "vulnerabilities_returned": len(data.get("vulnerabilities", [])), "message": "NVD API is accessible and returning data" }) else: result.update({ "error_message": response.text[:200], "message": f"NVD API returned {response.status_code}" }) # Try fallback without date filters if we get 404 if response.status_code == 404: print("Trying fallback without date filters...") fallback_params = { "resultsPerPage": 5, "startIndex": 0 } fallback_response = requests.get(url, params=fallback_params, headers=headers, timeout=15) result["fallback_status_code"] = fallback_response.status_code if fallback_response.status_code == 200: fallback_data = fallback_response.json() result.update({ "fallback_success": True, "fallback_total_results": fallback_data.get("totalResults", 0), "message": "NVD API works without date filters" }) return result except Exception as e: print(f"NVD API test error: {str(e)}") return { "status": "error", "message": f"Failed to connect to NVD API: {str(e)}" } @app.get("/api/stats") async def get_stats(db: Session = Depends(get_db)): total_cves = db.query(CVE).count() total_rules = db.query(SigmaRule).count() recent_cves = db.query(CVE).filter(CVE.published_date >= datetime.utcnow() - timedelta(days=7)).count() # Enhanced stats with bulk processing info bulk_processed_cves = db.query(CVE).filter(CVE.bulk_processed == True).count() cves_with_pocs = db.query(CVE).filter(CVE.poc_count > 0).count() nomi_sec_rules = db.query(SigmaRule).filter(SigmaRule.poc_source == 'nomi_sec').count() return { "total_cves": total_cves, "total_sigma_rules": total_rules, "recent_cves_7_days": recent_cves, "bulk_processed_cves": bulk_processed_cves, "cves_with_pocs": cves_with_pocs, "nomi_sec_rules": nomi_sec_rules, "poc_coverage": (cves_with_pocs / total_cves * 100) if total_cves > 0 else 0, "nomi_sec_coverage": (nomi_sec_rules / total_rules * 100) if total_rules > 0 else 0 } # New bulk processing endpoints @app.post("/api/bulk-seed") async def start_bulk_seed(background_tasks: BackgroundTasks, start_year: int = 2002, end_year: Optional[int] = None, skip_nvd: bool = False, skip_nomi_sec: bool = False, db: Session = Depends(get_db)): """Start bulk seeding process""" async def bulk_seed_task(): try: from bulk_seeder import BulkSeeder seeder = BulkSeeder(db) result = await seeder.full_bulk_seed( start_year=start_year, end_year=end_year, skip_nvd=skip_nvd, skip_nomi_sec=skip_nomi_sec ) logger.info(f"Bulk seed completed: {result}") except Exception as e: logger.error(f"Bulk seed failed: {e}") import traceback traceback.print_exc() background_tasks.add_task(bulk_seed_task) return { "message": "Bulk seeding process started", "status": "started", "start_year": start_year, "end_year": end_year or datetime.now().year, "skip_nvd": skip_nvd, "skip_nomi_sec": skip_nomi_sec } @app.post("/api/incremental-update") async def start_incremental_update(background_tasks: BackgroundTasks, db: Session = Depends(get_db)): """Start incremental update process""" async def incremental_update_task(): try: from bulk_seeder import BulkSeeder seeder = BulkSeeder(db) result = await seeder.incremental_update() logger.info(f"Incremental update completed: {result}") except Exception as e: logger.error(f"Incremental update failed: {e}") import traceback traceback.print_exc() background_tasks.add_task(incremental_update_task) return { "message": "Incremental update process started", "status": "started" } @app.post("/api/sync-nomi-sec") async def sync_nomi_sec(background_tasks: BackgroundTasks, cve_id: Optional[str] = None, batch_size: int = 50, db: Session = Depends(get_db)): """Synchronize nomi-sec PoC data""" # Create job record job = BulkProcessingJob( job_type='nomi_sec_sync', status='pending', job_metadata={ 'cve_id': cve_id, 'batch_size': batch_size } ) db.add(job) db.commit() db.refresh(job) job_id = str(job.id) running_jobs[job_id] = job job_cancellation_flags[job_id] = False async def sync_task(): try: job.status = 'running' job.started_at = datetime.utcnow() db.commit() from nomi_sec_client import NomiSecClient client = NomiSecClient(db) if cve_id: # Sync specific CVE if job_cancellation_flags.get(job_id, False): logger.info(f"Job {job_id} cancelled before starting") return result = await client.sync_cve_pocs(cve_id) logger.info(f"Nomi-sec sync for {cve_id}: {result}") else: # Sync all CVEs with cancellation support result = await client.bulk_sync_all_cves( batch_size=batch_size, cancellation_flag=lambda: job_cancellation_flags.get(job_id, False) ) logger.info(f"Nomi-sec bulk sync completed: {result}") # Update job status if not cancelled if not job_cancellation_flags.get(job_id, False): job.status = 'completed' job.completed_at = datetime.utcnow() db.commit() except Exception as e: if not job_cancellation_flags.get(job_id, False): job.status = 'failed' job.error_message = str(e) job.completed_at = datetime.utcnow() db.commit() logger.error(f"Nomi-sec sync failed: {e}") import traceback traceback.print_exc() finally: # Clean up tracking running_jobs.pop(job_id, None) job_cancellation_flags.pop(job_id, None) background_tasks.add_task(sync_task) return { "message": f"Nomi-sec sync started" + (f" for {cve_id}" if cve_id else " for all CVEs"), "status": "started", "job_id": job_id, "cve_id": cve_id, "batch_size": batch_size } @app.get("/api/bulk-jobs") async def get_bulk_jobs(limit: int = 10, db: Session = Depends(get_db)): """Get bulk processing job status""" jobs = db.query(BulkProcessingJob).order_by( BulkProcessingJob.created_at.desc() ).limit(limit).all() result = [] for job in jobs: job_dict = { 'id': str(job.id), 'job_type': job.job_type, 'status': job.status, 'year': job.year, 'total_items': job.total_items, 'processed_items': job.processed_items, 'failed_items': job.failed_items, 'error_message': job.error_message, 'metadata': job.job_metadata, 'started_at': job.started_at, 'completed_at': job.completed_at, 'created_at': job.created_at } result.append(job_dict) return result @app.get("/api/bulk-status") async def get_bulk_status(db: Session = Depends(get_db)): """Get comprehensive bulk processing status""" try: from bulk_seeder import BulkSeeder seeder = BulkSeeder(db) status = await seeder.get_seeding_status() return status except Exception as e: logger.error(f"Error getting bulk status: {e}") return {"error": str(e)} @app.get("/api/poc-stats") async def get_poc_stats(db: Session = Depends(get_db)): """Get PoC-related statistics""" try: from nomi_sec_client import NomiSecClient client = NomiSecClient(db) stats = await client.get_sync_status() # Additional PoC statistics high_quality_cves = db.query(CVE).filter( CVE.poc_count > 0, func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_score').cast(Integer) > 60 ).count() stats.update({ 'high_quality_cves': high_quality_cves, 'avg_poc_count': db.query(func.avg(CVE.poc_count)).filter(CVE.poc_count > 0).scalar() or 0 }) return stats except Exception as e: logger.error(f"Error getting PoC stats: {e}") return {"error": str(e)} @app.post("/api/regenerate-rules") async def regenerate_sigma_rules(background_tasks: BackgroundTasks, force: bool = False, db: Session = Depends(get_db)): """Regenerate SIGMA rules using enhanced nomi-sec data""" async def regenerate_task(): try: from enhanced_sigma_generator import EnhancedSigmaGenerator generator = EnhancedSigmaGenerator(db) # Get CVEs with PoC data cves_with_pocs = db.query(CVE).filter(CVE.poc_count > 0).all() rules_generated = 0 rules_updated = 0 for cve in cves_with_pocs: # Check if we should regenerate existing_rule = db.query(SigmaRule).filter( SigmaRule.cve_id == cve.cve_id ).first() if existing_rule and existing_rule.poc_source == 'nomi_sec' and not force: continue # Generate enhanced rule result = await generator.generate_enhanced_rule(cve) if result['success']: if existing_rule: rules_updated += 1 else: rules_generated += 1 logger.info(f"Rule regeneration completed: {rules_generated} new, {rules_updated} updated") except Exception as e: logger.error(f"Rule regeneration failed: {e}") import traceback traceback.print_exc() background_tasks.add_task(regenerate_task) return { "message": "SIGMA rule regeneration started", "status": "started", "force": force } @app.post("/api/cancel-job/{job_id}") async def cancel_job(job_id: str, db: Session = Depends(get_db)): """Cancel a running job""" try: # Find the job in the database job = db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job_id).first() if not job: raise HTTPException(status_code=404, detail="Job not found") if job.status not in ['pending', 'running']: raise HTTPException(status_code=400, detail=f"Cannot cancel job with status: {job.status}") # Set cancellation flag job_cancellation_flags[job_id] = True # Update job status job.status = 'cancelled' job.cancelled_at = datetime.utcnow() job.error_message = "Job cancelled by user" db.commit() logger.info(f"Job {job_id} cancellation requested") return { "message": f"Job {job_id} cancellation requested", "status": "cancelled", "job_id": job_id } except HTTPException: raise except Exception as e: logger.error(f"Error cancelling job {job_id}: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.get("/api/running-jobs") async def get_running_jobs(db: Session = Depends(get_db)): """Get all currently running jobs""" try: jobs = db.query(BulkProcessingJob).filter( BulkProcessingJob.status.in_(['pending', 'running']) ).order_by(BulkProcessingJob.created_at.desc()).all() result = [] for job in jobs: result.append({ 'id': str(job.id), 'job_type': job.job_type, 'status': job.status, 'year': job.year, 'total_items': job.total_items, 'processed_items': job.processed_items, 'failed_items': job.failed_items, 'error_message': job.error_message, 'started_at': job.started_at, 'created_at': job.created_at, 'can_cancel': job.status in ['pending', 'running'] }) return result except Exception as e: logger.error(f"Error getting running jobs: {e}") raise HTTPException(status_code=500, detail=str(e)) if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000)