from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from sqlalchemy import create_engine, Column, String, Text, DECIMAL, TIMESTAMP, Boolean, ARRAY, Integer, JSON, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.dialects.postgresql import UUID
import uuid
from datetime import datetime, timedelta
import requests
import json
import re
import os
from typing import List, Optional
from pydantic import BaseModel
import asyncio
from contextlib import asynccontextmanager
import base64
from github import Github
from urllib.parse import urlparse
import hashlib
import logging
import threading

from mcdevitt_poc_client import GitHubPoCClient
from cve2capec_client import CVE2CAPECClient

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global job tracking
running_jobs = {}
job_cancellation_flags = {}

# Database setup
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://cve_user:cve_password@localhost:5432/cve_sigma_db")
engine = create_engine(DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()


# Database Models
class CVE(Base):
    __tablename__ = "cves"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    cve_id = Column(String(20), unique=True, nullable=False)
    description = Column(Text)
    cvss_score = Column(DECIMAL(3, 1))
    severity = Column(String(20))
    published_date = Column(TIMESTAMP)
    modified_date = Column(TIMESTAMP)
    affected_products = Column(ARRAY(String))
    reference_urls = Column(ARRAY(String))
    # Bulk processing fields
    data_source = Column(String(20), default='nvd_api')  # 'nvd_api', 'nvd_bulk', 'manual'
    nvd_json_version = Column(String(10), default='2.0')
    bulk_processed = Column(Boolean, default=False)
    # nomi-sec PoC fields
    poc_count = Column(Integer, default=0)
    poc_data = Column(JSON)  # Store nomi-sec PoC metadata
    # Reference data fields
    reference_data = Column(JSON)  # Store extracted reference content and analysis
    reference_sync_status = Column(String(20), default='pending')  # 'pending', 'processing', 'completed', 'failed'
    reference_last_synced = Column(TIMESTAMP)
    created_at = Column(TIMESTAMP, default=datetime.utcnow)
    updated_at = Column(TIMESTAMP, default=datetime.utcnow)


class SigmaRule(Base):
    __tablename__ = "sigma_rules"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    cve_id = Column(String(20))
    rule_name = Column(String(255), nullable=False)
    rule_content = Column(Text, nullable=False)
    detection_type = Column(String(50))
    log_source = Column(String(100))
    confidence_level = Column(String(20))
    auto_generated = Column(Boolean, default=True)
    exploit_based = Column(Boolean, default=False)
    github_repos = Column(ARRAY(String))
    exploit_indicators = Column(Text)  # JSON string of extracted indicators
    # Enhanced fields for new data sources
    poc_source = Column(String(20), default='github_search')  # 'github_search', 'nomi_sec', 'manual'
    poc_quality_score = Column(Integer, default=0)  # Based on star count, activity, etc.
    nomi_sec_data = Column(JSON)  # Store nomi-sec PoC metadata
    created_at = Column(TIMESTAMP, default=datetime.utcnow)
    updated_at = Column(TIMESTAMP, default=datetime.utcnow)


class RuleTemplate(Base):
    __tablename__ = "rule_templates"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    template_name = Column(String(255), nullable=False)
    template_content = Column(Text, nullable=False)
    applicable_product_patterns = Column(ARRAY(String))
    description = Column(Text)
    created_at = Column(TIMESTAMP, default=datetime.utcnow)


class BulkProcessingJob(Base):
    __tablename__ = "bulk_processing_jobs"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    job_type = Column(String(50), nullable=False)  # 'nvd_bulk_seed', 'nomi_sec_sync', 'incremental_update'
    status = Column(String(20), default='pending')  # 'pending', 'running', 'completed', 'failed', 'cancelled'
    year = Column(Integer)  # For year-based processing
    total_items = Column(Integer, default=0)
    processed_items = Column(Integer, default=0)
    failed_items = Column(Integer, default=0)
    error_message = Column(Text)
    job_metadata = Column(JSON)  # Additional job-specific data
    started_at = Column(TIMESTAMP)
    completed_at = Column(TIMESTAMP)
    cancelled_at = Column(TIMESTAMP)
    created_at = Column(TIMESTAMP, default=datetime.utcnow)


# Pydantic models
class CVEResponse(BaseModel):
    id: str
    cve_id: str
    description: Optional[str] = None
    cvss_score: Optional[float] = None
    severity: Optional[str] = None
    published_date: Optional[datetime] = None
    affected_products: Optional[List[str]] = None
    reference_urls: Optional[List[str]] = None

    class Config:
        from_attributes = True


class SigmaRuleResponse(BaseModel):
    id: str
    cve_id: str
    rule_name: str
    rule_content: str
    detection_type: Optional[str] = None
    log_source: Optional[str] = None
    confidence_level: Optional[str] = None
    auto_generated: bool = True
    exploit_based: bool = False
    github_repos: Optional[List[str]] = None
    exploit_indicators: Optional[str] = None
    created_at: datetime

    class Config:
        from_attributes = True


# Request models
class BulkSeedRequest(BaseModel):
    start_year: int = 2002
    end_year: Optional[int] = None
    skip_nvd: bool = False
    skip_nomi_sec: bool = True


class NomiSecSyncRequest(BaseModel):
    cve_id: Optional[str] = None
    batch_size: int = 50


class GitHubPoCSyncRequest(BaseModel):
    cve_id: Optional[str] = None
    batch_size: int = 50


class ExploitDBSyncRequest(BaseModel):
    cve_id: Optional[str] = None
    batch_size: int = 30


class CISAKEVSyncRequest(BaseModel):
    cve_id: Optional[str] = None
    batch_size: int = 100


class ReferenceSyncRequest(BaseModel):
    cve_id: Optional[str] = None
    batch_size: int = 30
    max_cves: Optional[int] = None
    force_resync: bool = False


class RuleRegenRequest(BaseModel):
    force: bool = False


# GitHub Exploit Analysis Service
class GitHubExploitAnalyzer:
    def __init__(self):
        self.github_token = os.getenv("GITHUB_TOKEN")
        self.github = Github(self.github_token) if self.github_token else None

    async def search_exploits_for_cve(self, cve_id: str) -> List[dict]:
        """Search GitHub for exploit code related to a CVE"""
        if not self.github:
            print(f"No GitHub token configured, skipping exploit search for {cve_id}")
            return []

        try:
            print(f"Searching GitHub for exploits for {cve_id}")

            # Search queries to find exploit code
            search_queries = [
                f"{cve_id} exploit",
                f"{cve_id} poc",
                f"{cve_id} vulnerability",
                f'"{cve_id}" exploit code',
                f"{cve_id.replace('-', '_')} exploit"
            ]

            exploits = []
            seen_repos = set()

            for query in search_queries[:2]:  # Limit to 2 queries to avoid rate limits
                try:
                    # Search repositories
                    repos = self.github.search_repositories(
                        query=query,
                        sort="updated",
                        order="desc"
                    )

                    # Get top 5 results per query
                    for repo in repos[:5]:
                        if repo.full_name in seen_repos:
                            continue
                        seen_repos.add(repo.full_name)

                        # Analyze repository
                        exploit_info = await self._analyze_repository(repo, cve_id)
                        if exploit_info:
                            exploits.append(exploit_info)

                        if len(exploits) >= 10:  # Limit total exploits
                            break

                    if len(exploits) >= 10:
                        break

                except Exception as e:
                    print(f"Error searching GitHub with query '{query}': {str(e)}")
                    continue

            print(f"Found {len(exploits)} potential exploits for {cve_id}")
            return exploits

        except Exception as e:
            print(f"Error searching GitHub for {cve_id}: {str(e)}")
            return []

    async def _analyze_repository(self, repo, cve_id: str) -> Optional[dict]:
        """Analyze a GitHub repository for exploit code"""
        try:
            # Check if repo name or description mentions the CVE
            repo_text = f"{repo.name} {repo.description or ''}".lower()
            if cve_id.lower() not in repo_text and cve_id.replace('-', '_').lower() not in repo_text:
                return None

            # Get repository contents
            exploit_files = []
            indicators = {
                'processes': set(),
                'files': set(),
                'registry': set(),
                'network': set(),
                'commands': set(),
                'powershell': set(),
                'urls': set()
            }

            try:
                contents = repo.get_contents("")
                for content in contents[:20]:  # Limit files to analyze
                    if content.type == "file" and self._is_exploit_file(content.name):
                        file_analysis = await self._analyze_file_content(repo, content, cve_id)
                        if file_analysis:
                            exploit_files.append(file_analysis)
                            # Merge indicators
                            for key, values in file_analysis.get('indicators', {}).items():
                                if key in indicators:
                                    indicators[key].update(values)
            except Exception as e:
                print(f"Error analyzing repo contents for {repo.full_name}: {str(e)}")

            if not exploit_files:
                return None

            return {
                'repo_name': repo.full_name,
                'repo_url': repo.html_url,
                'description': repo.description,
                'language': repo.language,
                'stars': repo.stargazers_count,
                'updated': repo.updated_at.isoformat(),
                'files': exploit_files,
                'indicators': {k: list(v) for k, v in indicators.items()}
            }

        except Exception as e:
            print(f"Error analyzing repository {repo.full_name}: {str(e)}")
            return None

    def _is_exploit_file(self, filename: str) -> bool:
        """Check if a file is likely to contain exploit code"""
        exploit_extensions = ['.py', '.ps1', '.sh', '.c', '.cpp', '.js', '.rb', '.pl', '.php', '.java']
        exploit_names = ['exploit', 'poc', 'payload', 'shell', 'reverse', 'bind', 'attack']

        filename_lower = filename.lower()

        # Check extension
        if not any(filename_lower.endswith(ext) for ext in exploit_extensions):
            return False

        # Check filename for exploit-related terms
        return any(term in filename_lower for term in exploit_names) or 'cve' in filename_lower

    async def _analyze_file_content(self, repo, file_content, cve_id: str) -> Optional[dict]:
        """Analyze individual file content for exploit indicators"""
        try:
            if file_content.size > 100000:  # Skip files larger than 100KB
                return None

            # Decode file content
            content = file_content.decoded_content.decode('utf-8', errors='ignore')

            # Check if file actually mentions the CVE
            if cve_id.lower() not in content.lower() and cve_id.replace('-', '_').lower() not in content.lower():
                return None

            indicators = self._extract_indicators_from_code(content, file_content.name)

            if not any(indicators.values()):
                return None

            return {
                'filename': file_content.name,
                'path': file_content.path,
                'size': file_content.size,
                'indicators': indicators
            }

        except Exception as e:
            print(f"Error analyzing file {file_content.name}: {str(e)}")
            return None
    def _extract_indicators_from_code(self, content: str, filename: str) -> dict:
        """Extract security indicators from exploit code"""
        indicators = {
            'processes': set(),
            'files': set(),
            'registry': set(),
            'network': set(),
            'commands': set(),
            'powershell': set(),
            'urls': set()
        }

        # Process patterns
        process_patterns = [
            r'CreateProcess[AW]?\s*\(\s*["\']([^"\']+)["\']',
            r'ShellExecute[AW]?\s*\([^,]*,\s*["\']([^"\']+)["\']',
            r'system\s*\(\s*["\']([^"\']+)["\']',
            r'exec\s*\(\s*["\']([^"\']+)["\']',
            r'subprocess\.(?:call|run|Popen)\s*\(\s*["\']([^"\']+)["\']'
        ]

        # File patterns
        file_patterns = [
            r'(?:fopen|CreateFile|WriteFile|ReadFile)\s*\(\s*["\']([^"\']+\.[a-zA-Z0-9]+)["\']',
            r'(?:copy|move|del|rm)\s+["\']?([^\s"\']+\.[a-zA-Z0-9]+)["\']?',
            r'\\\\[^\\]+\\[^\\]+\\([^\\]+\.[a-zA-Z0-9]+)',
            r'[C-Z]:\\\\[^\\]+\\\\([^\\]+\.[a-zA-Z0-9]+)'
        ]

        # Registry patterns
        registry_patterns = [
            r'(?:RegOpenKey|RegSetValue|RegCreateKey)\s*\([^,]*,\s*["\']([^"\']+)["\']',
            r'HKEY_[A-Z_]+\\\\([^"\'\\]+)',
            r'reg\s+add\s+["\']?([^"\'\\]+\\\\[^"\']+)["\']?'
        ]

        # Network patterns
        network_patterns = [
            r'(?:connect|bind|listen)\s*\([^,]*,\s*(\d+)',
            r'socket\.connect\s*\(\s*\(["\']?([^"\']+)["\']?,\s*(\d+)\)',
            r'(?:http|https|ftp)://([^\s"\'<>]+)',
            r'(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)'
        ]

        # PowerShell patterns
        powershell_patterns = [
            r'(?:powershell|pwsh)\s+(?:-[a-zA-Z]+\s+)*["\']?([^"\']+)["\']?',
            r'Invoke-(?:Expression|Command|WebRequest|RestMethod)\s+["\']?([^"\']+)["\']?',
            r'Start-Process\s+["\']?([^"\']+)["\']?',
            r'Get-Process\s+["\']?([^"\']+)["\']?'
        ]

        # Command patterns
        command_patterns = [
            r'(?:cmd|command)\s+(?:/[a-zA-Z]+\s+)*["\']?([^"\']+)["\']?',
            r'(?:ping|nslookup|netstat|tasklist|wmic)\s+([^\s"\']+)',
            r'(?:net|sc|schtasks)\s+[a-zA-Z]+\s+([^\s"\']+)'
        ]

        # Extract indicators using regex patterns
        patterns = {
            'processes': process_patterns,
            'files': file_patterns,
            'registry': registry_patterns,
            'powershell': powershell_patterns,
            'commands': command_patterns
        }

        for category, pattern_list in patterns.items():
            for pattern in pattern_list:
                matches = re.findall(pattern, content, re.IGNORECASE | re.MULTILINE)
                for match in matches:
                    if isinstance(match, tuple):
                        indicators[category].add(match[0])
                    else:
                        indicators[category].add(match)

        # Special handling for network indicators
        for pattern in network_patterns:
            matches = re.findall(pattern, content, re.IGNORECASE)
            for match in matches:
                if isinstance(match, tuple):
                    if len(match) >= 2:
                        indicators['network'].add(f"{match[0]}:{match[1]}")
                    else:
                        indicators['network'].add(match[0])
                else:
                    indicators['network'].add(match)

        # Convert sets to lists and filter out empty/invalid indicators
        cleaned_indicators = {}
        for key, values in indicators.items():
            cleaned_values = [v for v in values if v and len(v.strip()) > 2 and len(v) < 200]
            if cleaned_values:
                cleaned_indicators[key] = cleaned_values[:10]  # Limit to 10 per category

        return cleaned_indicators


class CVESigmaService:
    def __init__(self, db: Session):
        self.db = db
        self.nvd_api_key = os.getenv("NVD_API_KEY")

    async def fetch_recent_cves(self, days_back: int = 7):
        """Fetch recent CVEs from NVD API"""
        end_date = datetime.utcnow()
        start_date = end_date - timedelta(days=days_back)

        url = "https://services.nvd.nist.gov/rest/json/cves/2.0"
        params = {
            "pubStartDate": start_date.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z",
            "pubEndDate": end_date.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z",
            "resultsPerPage": 100
        }

        headers = {}
        if self.nvd_api_key:
            headers["apiKey"] = self.nvd_api_key

        try:
            response = requests.get(url, params=params, headers=headers, timeout=30)
            response.raise_for_status()
            data = response.json()

            new_cves = []
            for vuln in data.get("vulnerabilities", []):
                cve_data = vuln.get("cve", {})
                cve_id = cve_data.get("id")

                # Check if CVE already exists
                existing = self.db.query(CVE).filter(CVE.cve_id == cve_id).first()
                if existing:
                    continue

                # Extract CVE information
                description = ""
                if cve_data.get("descriptions"):
                    description = cve_data["descriptions"][0].get("value", "")

                cvss_score = None
                severity = None
                if cve_data.get("metrics", {}).get("cvssMetricV31"):
                    cvss_data = cve_data["metrics"]["cvssMetricV31"][0]
                    cvss_score = cvss_data.get("cvssData", {}).get("baseScore")
                    severity = cvss_data.get("cvssData", {}).get("baseSeverity")

                affected_products = []
                if cve_data.get("configurations"):
                    for config in cve_data["configurations"]:
                        for node in config.get("nodes", []):
                            for cpe_match in node.get("cpeMatch", []):
                                if cpe_match.get("vulnerable"):
                                    affected_products.append(cpe_match.get("criteria", ""))

                reference_urls = []
                if cve_data.get("references"):
                    reference_urls = [ref.get("url", "") for ref in cve_data["references"]]

                cve_obj = CVE(
                    cve_id=cve_id,
                    description=description,
                    cvss_score=cvss_score,
                    severity=severity,
                    published_date=datetime.fromisoformat(cve_data.get("published", "").replace("Z", "+00:00")),
                    modified_date=datetime.fromisoformat(cve_data.get("lastModified", "").replace("Z", "+00:00")),
                    affected_products=affected_products,
                    reference_urls=reference_urls
                )

                self.db.add(cve_obj)
                new_cves.append(cve_obj)

            self.db.commit()
            return new_cves

        except Exception as e:
            print(f"Error fetching CVEs: {str(e)}")
            return []

    def generate_sigma_rule(self, cve: CVE) -> Optional[SigmaRule]:
        """Generate SIGMA rule based on CVE data"""
        if not cve.description:
            return None

        # Analyze CVE to determine appropriate template
        description_lower = cve.description.lower()
        affected_products = [p.lower() for p in (cve.affected_products or [])]

        template = self._select_template(description_lower, affected_products)
        if not template:
            return None

        # Generate rule content
        rule_content = self._populate_template(cve, template)
        if not rule_content:
            return None

        # Determine detection type and confidence
        detection_type = self._determine_detection_type(description_lower)
        confidence_level = self._calculate_confidence(cve)

        sigma_rule = SigmaRule(
            cve_id=cve.cve_id,
            rule_name=f"CVE-{cve.cve_id.split('-')[1]}-{cve.cve_id.split('-')[2]} Detection",
            rule_content=rule_content,
            detection_type=detection_type,
            log_source=template.template_name.lower().replace(" ", "_"),
            confidence_level=confidence_level,
            auto_generated=True
        )

        self.db.add(sigma_rule)
        return sigma_rule

    def _select_template(self, description: str, affected_products: List[str], exploit_indicators: dict = None):
        """Select appropriate SIGMA rule template based on CVE and exploit analysis"""
        templates = self.db.query(RuleTemplate).all()

        # If we have exploit indicators, use them to determine the best template
        if exploit_indicators:
            if exploit_indicators.get('powershell'):
                powershell_template = next((t for t in templates if "PowerShell" in t.template_name), None)
                if powershell_template:
                    return powershell_template

            if exploit_indicators.get('network'):
                network_template = next((t for t in templates if "Network Connection" in t.template_name), None)
                if network_template:
                    return network_template

            if exploit_indicators.get('files'):
                file_template = next((t for t in templates if "File Modification" in t.template_name), None)
                if file_template:
                    return file_template

            if exploit_indicators.get('processes') or exploit_indicators.get('commands'):
                process_template = next((t for t in templates if "Process Execution" in t.template_name), None)
                if process_template:
                    return process_template

        # Fallback to original logic
        if any("windows" in p or "microsoft" in p for p in affected_products):
            if "process" in description or "execution" in description:
                return next((t for t in templates if "Process Execution" in t.template_name), None)
            elif "network" in description or "remote" in description:
                return next((t for t in templates if "Network Connection" in t.template_name), None)
            elif "file" in description or "write" in description:
                return next((t for t in templates if "File Modification" in t.template_name), None)

        # Default to process execution template
        return next((t for t in templates if "Process Execution" in t.template_name), None)

    def _populate_template(self, cve: CVE, template: RuleTemplate, exploit_indicators: dict = None) -> str:
        """Populate template with CVE-specific data and exploit indicators"""
        try:
            # Use exploit indicators if available, otherwise extract from description
            if exploit_indicators:
                suspicious_processes = exploit_indicators.get('processes', []) + exploit_indicators.get('commands', [])
                suspicious_ports = []
                file_patterns = exploit_indicators.get('files', [])

                # Extract ports from network indicators
                for net_indicator in exploit_indicators.get('network', []):
                    if ':' in str(net_indicator):
                        try:
                            port = int(str(net_indicator).split(':')[-1])
                            suspicious_ports.append(port)
                        except ValueError:
                            pass
            else:
                # Fallback to original extraction
                suspicious_processes = self._extract_suspicious_indicators(cve.description, "process")
                suspicious_ports = self._extract_suspicious_indicators(cve.description, "port")
                file_patterns = self._extract_suspicious_indicators(cve.description, "file")

            # Determine severity level
            level = "high" if cve.cvss_score and cve.cvss_score >= 7.0 else "medium"

            # Create enhanced description
            enhanced_description = cve.description[:200] + "..." if len(cve.description) > 200 else cve.description
            if exploit_indicators:
                enhanced_description += " [Enhanced with GitHub exploit analysis]"

            # Build tags
            tags = [f"attack.{self._get_mitre_technique(cve.description, exploit_indicators)}", cve.cve_id.lower()]
            if exploit_indicators:
                tags.append("exploit.github")

            rule_content = template.template_content.format(
                title=f"CVE-{cve.cve_id} {'Exploit-Based ' if exploit_indicators else ''}Detection",
                description=enhanced_description,
                rule_id=str(uuid.uuid4()),
                date=datetime.utcnow().strftime("%Y/%m/%d"),
                cve_url=f"https://nvd.nist.gov/vuln/detail/{cve.cve_id}",
                cve_id=cve.cve_id.lower(),
                tags="\n - ".join(tags),
                suspicious_processes=suspicious_processes or ["suspicious.exe", "malware.exe"],
                suspicious_ports=suspicious_ports or [4444, 8080, 9999],
                file_patterns=file_patterns or ["temp", "malware", "exploit"],
                level=level
            )

            return rule_content

        except Exception as e:
            print(f"Error populating template: {str(e)}")
            return None

    def _get_mitre_technique(self, description: str, exploit_indicators: dict = None) -> str:
        """Map CVE and exploit indicators to MITRE ATT&CK techniques"""
        desc_lower = description.lower()

        # Check exploit indicators first
        if exploit_indicators:
            if exploit_indicators.get('powershell'):
                return "t1059.001"  # PowerShell
            elif exploit_indicators.get('commands'):
                return "t1059.003"  # Windows Command Shell
            elif exploit_indicators.get('network'):
                return "t1071.001"  # Web Protocols
            elif exploit_indicators.get('files'):
                return "t1105"  # Ingress Tool Transfer
            elif exploit_indicators.get('processes'):
                return "t1106"  # Native API

        # Fallback to description analysis
        if "powershell" in desc_lower:
            return "t1059.001"
        elif "command" in desc_lower or "cmd" in desc_lower:
            return "t1059.003"
        elif "network" in desc_lower or "remote" in desc_lower:
            return "t1071.001"
        elif "file" in desc_lower or "upload" in desc_lower:
            return "t1105"
        elif "process" in desc_lower or "execution" in desc_lower:
            return "t1106"
        else:
            return "execution"  # Generic

    def _extract_suspicious_indicators(self, description: str, indicator_type: str) -> List:
        """Extract suspicious indicators from CVE description"""
        if indicator_type == "process":
            # Look for executable names or process patterns
            exe_pattern = re.findall(r'(\w+\.exe)', description, re.IGNORECASE)
            return exe_pattern[:5] if exe_pattern else None

        elif indicator_type == "port":
            # Look for port numbers
            port_pattern = re.findall(r'port\s+(\d+)', description, re.IGNORECASE)
            return [int(p) for p in port_pattern[:3]] if port_pattern else None

        elif indicator_type == "file":
            # Look for file extensions or paths
            file_pattern = re.findall(r'(\w+\.\w{3,4})', description, re.IGNORECASE)
            return file_pattern[:5] if file_pattern else None

        return None

    def _determine_detection_type(self, description: str, exploit_indicators: dict = None) -> str:
        """Determine detection type based on CVE description and exploit indicators"""
        if exploit_indicators:
            if exploit_indicators.get('powershell'):
                return "powershell"
            elif exploit_indicators.get('network'):
                return "network"
            elif exploit_indicators.get('files'):
                return "file"
            elif exploit_indicators.get('processes') or exploit_indicators.get('commands'):
                return "process"

        # Fallback to original logic
        if "remote" in description or "network" in description:
            return "network"
        elif "process" in description or "execution" in description:
            return "process"
        elif "file" in description or "filesystem" in description:
            return "file"
        else:
            return "general"
    def _calculate_confidence(self, cve: CVE, exploit_based: bool = False) -> str:
        """Calculate confidence level for the generated rule"""
        base_confidence = 0

        # CVSS score contributes to confidence
        if cve.cvss_score:
            if cve.cvss_score >= 9.0:
                base_confidence += 3
            elif cve.cvss_score >= 7.0:
                base_confidence += 2
            else:
                base_confidence += 1

        # Exploit-based rules get higher confidence
        if exploit_based:
            base_confidence += 2

        # Map to confidence levels
        if base_confidence >= 4:
            return "high"
        elif base_confidence >= 2:
            return "medium"
        else:
            return "low"


# Dependency
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


# Background task to fetch CVEs and generate rules
async def background_cve_fetch():
    retry_count = 0
    max_retries = 3

    while True:
        try:
            db = SessionLocal()
            service = CVESigmaService(db)
            current_time = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
            print(f"[{current_time}] Starting CVE fetch cycle...")

            # Use a longer initial period (30 days) to find CVEs
            new_cves = await service.fetch_recent_cves(days_back=30)

            if new_cves:
                print(f"Found {len(new_cves)} new CVEs, generating SIGMA rules...")
                rules_generated = 0
                for cve in new_cves:
                    try:
                        sigma_rule = service.generate_sigma_rule(cve)
                        if sigma_rule:
                            rules_generated += 1
                            print(f"Generated SIGMA rule for {cve.cve_id}")
                        else:
                            print(f"Could not generate rule for {cve.cve_id} - insufficient data")
                    except Exception as e:
                        print(f"Error generating rule for {cve.cve_id}: {str(e)}")

                db.commit()
                print(f"Successfully generated {rules_generated} SIGMA rules")
                retry_count = 0  # Reset retry count on success
            else:
                print("No new CVEs found in this cycle")
                # After first successful run, reduce to 7 days for regular updates
                if retry_count == 0:
                    print("Switching to 7-day lookback for future runs...")

            db.close()

        except Exception as e:
            retry_count += 1
            print(f"Background task error (attempt {retry_count}/{max_retries}): {str(e)}")
            if retry_count >= max_retries:
                print(f"Max retries reached, waiting longer before next attempt...")
                await asyncio.sleep(1800)  # Wait 30 minutes on repeated failures
                retry_count = 0
            else:
                await asyncio.sleep(300)  # Wait 5 minutes before retry
            continue

        # Wait 1 hour before next fetch (or 30 minutes if there were errors)
        wait_time = 3600 if retry_count == 0 else 1800
        print(f"Next CVE fetch in {wait_time//60} minutes...")
        await asyncio.sleep(wait_time)


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Initialize database
    Base.metadata.create_all(bind=engine)

    # Initialize rule templates
    db = SessionLocal()
    try:
        existing_templates = db.query(RuleTemplate).count()
        if existing_templates == 0:
            logger.info("No rule templates found. Database initialization will handle template creation.")
    except Exception as e:
        logger.error(f"Error checking rule templates: {e}")
    finally:
        db.close()

    # Initialize and start the job scheduler
    try:
        from job_scheduler import initialize_scheduler
        from job_executors import register_all_executors

        # Initialize scheduler
        scheduler = initialize_scheduler()
        scheduler.set_db_session_factory(SessionLocal)

        # Register all job executors
        register_all_executors(scheduler)

        # Start the scheduler
        scheduler.start()
        logger.info("Job scheduler initialized and started")
    except Exception as e:
        logger.error(f"Error initializing job scheduler: {e}")

    yield

    # Shutdown
    try:
        from job_scheduler import get_scheduler
        scheduler = get_scheduler()
        scheduler.stop()
        logger.info("Job scheduler stopped")
    except Exception as e:
        logger.error(f"Error stopping job scheduler: {e}")


# FastAPI app
app = FastAPI(title="CVE-SIGMA Auto Generator", lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/api/cves", response_model=List[CVEResponse])
async def get_cves(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
    cves = db.query(CVE).order_by(CVE.published_date.desc()).offset(skip).limit(limit).all()
    # Convert UUID to string for each CVE
    result = []
    for cve in cves:
        cve_dict = {
            'id': str(cve.id),
            'cve_id': cve.cve_id,
            'description': cve.description,
            'cvss_score': float(cve.cvss_score) if cve.cvss_score else None,
            'severity': cve.severity,
            'published_date': cve.published_date,
            'affected_products': cve.affected_products,
            'reference_urls': cve.reference_urls
        }
        result.append(CVEResponse(**cve_dict))
    return result


@app.get("/api/cves/{cve_id}", response_model=CVEResponse)
async def get_cve(cve_id: str, db: Session = Depends(get_db)):
    cve = db.query(CVE).filter(CVE.cve_id == cve_id).first()
    if not cve:
        raise HTTPException(status_code=404, detail="CVE not found")

    cve_dict = {
        'id': str(cve.id),
        'cve_id': cve.cve_id,
        'description': cve.description,
        'cvss_score': float(cve.cvss_score) if cve.cvss_score else None,
        'severity': cve.severity,
        'published_date': cve.published_date,
        'affected_products': cve.affected_products,
        'reference_urls': cve.reference_urls
    }
    return CVEResponse(**cve_dict)


@app.get("/api/sigma-rules", response_model=List[SigmaRuleResponse])
async def get_sigma_rules(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
    rules = db.query(SigmaRule).order_by(SigmaRule.created_at.desc()).offset(skip).limit(limit).all()
    # Convert UUID to string for each rule
    result = []
    for rule in rules:
        rule_dict = {
            'id': str(rule.id),
            'cve_id': rule.cve_id,
            'rule_name': rule.rule_name,
            'rule_content': rule.rule_content,
            'detection_type': rule.detection_type,
            'log_source': rule.log_source,
            'confidence_level': rule.confidence_level,
            'auto_generated': rule.auto_generated,
            'exploit_based': rule.exploit_based or False,
            'github_repos': rule.github_repos or [],
            'exploit_indicators': rule.exploit_indicators,
            'created_at': rule.created_at
        }
        result.append(SigmaRuleResponse(**rule_dict))
    return result


@app.get("/api/sigma-rules/{cve_id}", response_model=List[SigmaRuleResponse])
async def get_sigma_rules_by_cve(cve_id: str, db: Session = Depends(get_db)):
    rules = db.query(SigmaRule).filter(SigmaRule.cve_id == cve_id).all()
    # Convert UUID to string for each rule
    result = []
    for rule in rules:
        rule_dict = {
            'id': str(rule.id),
            'cve_id': rule.cve_id,
            'rule_name': rule.rule_name,
            'rule_content': rule.rule_content,
            'detection_type': rule.detection_type,
            'log_source': rule.log_source,
            'confidence_level': rule.confidence_level,
            'auto_generated': rule.auto_generated,
            'exploit_based': rule.exploit_based or False,
            'github_repos': rule.github_repos or [],
            'exploit_indicators': rule.exploit_indicators,
            'created_at': rule.created_at
        }
        result.append(SigmaRuleResponse(**rule_dict))
    return result


@app.post("/api/fetch-cves")
async def manual_fetch_cves(background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    async def fetch_task():
        try:
            service = CVESigmaService(db)
            print("Manual CVE fetch initiated...")
            # Use 30 days for manual fetch to get more results
            new_cves = await service.fetch_recent_cves(days_back=30)
            rules_generated = 0

            for cve in new_cves:
                sigma_rule = service.generate_sigma_rule(cve)
                if sigma_rule:
                    rules_generated += 1

            db.commit()
            print(f"Manual fetch complete: {len(new_cves)} CVEs, {rules_generated} rules generated")
        except Exception as e:
            print(f"Manual fetch error: {str(e)}")
            import traceback
            traceback.print_exc()

    background_tasks.add_task(fetch_task)
    return {"message": "CVE fetch initiated (30-day lookback)", "status": "started"}


@app.get("/api/test-nvd")
async def test_nvd_connection():
    """Test endpoint to check NVD API connectivity"""
    try:
        # Test with a simple request using current date
        end_date = datetime.utcnow()
        start_date = end_date - timedelta(days=30)

        url = "https://services.nvd.nist.gov/rest/json/cves/2.0/"
        params = {
            "lastModStartDate": start_date.strftime("%Y-%m-%dT%H:%M:%S.000+00:00"),
            "lastModEndDate": end_date.strftime("%Y-%m-%dT%H:%M:%S.000+00:00"),
            "resultsPerPage": 5,
            "startIndex": 0
        }

        headers = {
            "User-Agent": "CVE-SIGMA-Generator/1.0",
            "Accept": "application/json"
        }

        nvd_api_key = os.getenv("NVD_API_KEY")
        if nvd_api_key:
            headers["apiKey"] = nvd_api_key

        print(f"Testing NVD API with URL: {url}")
        print(f"Test params: {params}")
        print(f"Test headers: {headers}")

        response = requests.get(url, params=params, headers=headers, timeout=15)

        result = {
            "status": "success" if response.status_code == 200 else "error",
            "status_code": response.status_code,
            "has_api_key": bool(nvd_api_key),
            "request_url": f"{url}?{requests.compat.urlencode(params)}",
            "response_headers": dict(response.headers)
        }

        if response.status_code == 200:
            data = response.json()
            result.update({
                "total_results": data.get("totalResults", 0),
                "results_per_page": data.get("resultsPerPage", 0),
                "vulnerabilities_returned": len(data.get("vulnerabilities", [])),
                "message": "NVD API is accessible and returning data"
            })
        else:
            result.update({
                "error_message": response.text[:200],
                "message": f"NVD API returned {response.status_code}"
            })

            # Try fallback without date filters if we get 404
            if response.status_code == 404:
                print("Trying fallback without date filters...")
                fallback_params = {
                    "resultsPerPage": 5,
                    "startIndex": 0
                }
                fallback_response = requests.get(url, params=fallback_params, headers=headers, timeout=15)
                result["fallback_status_code"] = fallback_response.status_code
                if fallback_response.status_code == 200:
                    fallback_data = fallback_response.json()
                    result.update({
                        "fallback_success": True,
                        "fallback_total_results": fallback_data.get("totalResults", 0),
                        "message": "NVD API works without date filters"
                    })

        return result

    except Exception as e:
        print(f"NVD API test error: {str(e)}")
        return {
            "status": "error",
            "message": f"Failed to connect to NVD API: {str(e)}"
        }


@app.get("/api/stats")
async def get_stats(db: Session = Depends(get_db)):
    total_cves = db.query(CVE).count()
    total_rules = db.query(SigmaRule).count()
    recent_cves = db.query(CVE).filter(CVE.published_date >= datetime.utcnow() - timedelta(days=7)).count()

    # Enhanced stats with bulk processing info
    bulk_processed_cves = db.query(CVE).filter(CVE.bulk_processed == True).count()
    cves_with_pocs = db.query(CVE).filter(CVE.poc_count > 0).count()
    nomi_sec_rules = db.query(SigmaRule).filter(SigmaRule.poc_source == 'nomi_sec').count()

    return {
        "total_cves": total_cves,
        "total_sigma_rules": total_rules,
        "recent_cves_7_days": recent_cves,
        "bulk_processed_cves": bulk_processed_cves,
        "cves_with_pocs": cves_with_pocs,
        "nomi_sec_rules": nomi_sec_rules,
        "poc_coverage": (cves_with_pocs / total_cves * 100) if total_cves > 0 else 0,
        "nomi_sec_coverage": (nomi_sec_rules / total_rules * 100) if total_rules > 0 else 0
    }


# New bulk processing endpoints
@app.post("/api/bulk-seed")
async def start_bulk_seed(background_tasks: BackgroundTasks, request: BulkSeedRequest, db: Session = Depends(get_db)):
    """Start bulk seeding process"""

    async def bulk_seed_task():
        try:
            from bulk_seeder import BulkSeeder
            seeder = BulkSeeder(db)
            result = await seeder.full_bulk_seed(
                start_year=request.start_year,
                end_year=request.end_year,
                skip_nvd=request.skip_nvd,
                skip_nomi_sec=request.skip_nomi_sec
            )
            logger.info(f"Bulk seed completed: {result}")
        except Exception as e:
            logger.error(f"Bulk seed failed: {e}")
            import traceback
            traceback.print_exc()

    background_tasks.add_task(bulk_seed_task)

    return {
        "message": "Bulk seeding process started",
        "status": "started",
        "start_year": request.start_year,
        "end_year": request.end_year or datetime.now().year,
        "skip_nvd": request.skip_nvd,
        "skip_nomi_sec": request.skip_nomi_sec
    }


@app.post("/api/incremental-update")
async def start_incremental_update(background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Start incremental update process"""

    async def incremental_update_task():
        try:
            from bulk_seeder import BulkSeeder
            seeder = BulkSeeder(db)
            result = await seeder.incremental_update()
            logger.info(f"Incremental update completed: {result}")
        except Exception as e:
            logger.error(f"Incremental update failed: {e}")
            import traceback
            traceback.print_exc()

    background_tasks.add_task(incremental_update_task)

    return {
        "message": "Incremental update process started",
        "status": "started"
    }


@app.post("/api/sync-nomi-sec")
async def sync_nomi_sec(background_tasks: BackgroundTasks, request: NomiSecSyncRequest, db: Session = Depends(get_db)):
    """Synchronize nomi-sec PoC data"""
    # Create job record
    job = BulkProcessingJob(
        job_type='nomi_sec_sync',
        status='pending',
        job_metadata={
            'cve_id': request.cve_id,
            'batch_size': request.batch_size
        }
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    job_id = str(job.id)
    running_jobs[job_id] = job
    job_cancellation_flags[job_id] = False

    async def sync_task():
        try:
            job.status = 'running'
            job.started_at = datetime.utcnow()
            db.commit()

            from nomi_sec_client import NomiSecClient
            client = NomiSecClient(db)

            if request.cve_id:
                # Sync specific CVE
                if job_cancellation_flags.get(job_id, False):
                    logger.info(f"Job {job_id} cancelled before starting")
                    return
                result = await client.sync_cve_pocs(request.cve_id)
                logger.info(f"Nomi-sec sync for {request.cve_id}: {result}")
            else:
                # Sync all CVEs with cancellation support
                result = await client.bulk_sync_all_cves(
                    batch_size=request.batch_size,
                    cancellation_flag=lambda: job_cancellation_flags.get(job_id, False)
                )
                logger.info(f"Nomi-sec bulk sync completed: {result}")

            # Update job status if not cancelled
            if not job_cancellation_flags.get(job_id, False):
                job.status = 'completed'
                job.completed_at = datetime.utcnow()
db.commit() except Exception as e: if not job_cancellation_flags.get(job_id, False): job.status = 'failed' job.error_message = str(e) job.completed_at = datetime.utcnow() db.commit() logger.error(f"Nomi-sec sync failed: {e}") import traceback traceback.print_exc() finally: # Clean up tracking running_jobs.pop(job_id, None) job_cancellation_flags.pop(job_id, None) background_tasks.add_task(sync_task) return { "message": f"Nomi-sec sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"), "status": "started", "job_id": job_id, "cve_id": request.cve_id, "batch_size": request.batch_size } @app.post("/api/sync-github-pocs") async def sync_github_pocs(background_tasks: BackgroundTasks, request: GitHubPoCSyncRequest, db: Session = Depends(get_db)): """Synchronize GitHub PoC data""" # Create job record job = BulkProcessingJob( job_type='github_poc_sync', status='pending', job_metadata={ 'cve_id': request.cve_id, 'batch_size': request.batch_size } ) db.add(job) db.commit() db.refresh(job) job_id = str(job.id) running_jobs[job_id] = job job_cancellation_flags[job_id] = False async def sync_task(): try: job.status = 'running' job.started_at = datetime.utcnow() db.commit() client = GitHubPoCClient(db) if request.cve_id: # Sync specific CVE if job_cancellation_flags.get(job_id, False): logger.info(f"Job {job_id} cancelled before starting") return result = await client.sync_cve_pocs(request.cve_id) logger.info(f"GitHub PoC sync for {request.cve_id}: {result}") else: # Sync all CVEs with cancellation support result = await client.bulk_sync_all_cves(batch_size=request.batch_size) logger.info(f"GitHub PoC bulk sync completed: {result}") # Update job status if not cancelled if not job_cancellation_flags.get(job_id, False): job.status = 'completed' job.completed_at = datetime.utcnow() db.commit() except Exception as e: if not job_cancellation_flags.get(job_id, False): job.status = 'failed' job.error_message = str(e) job.completed_at = datetime.utcnow() db.commit() logger.error(f"GitHub PoC sync failed: {e}") import traceback traceback.print_exc() finally: # Clean up tracking running_jobs.pop(job_id, None) job_cancellation_flags.pop(job_id, None) background_tasks.add_task(sync_task) return { "message": f"GitHub PoC sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"), "status": "started", "job_id": job_id, "cve_id": request.cve_id, "batch_size": request.batch_size } @app.post("/api/sync-exploitdb") async def sync_exploitdb(background_tasks: BackgroundTasks, request: ExploitDBSyncRequest, db: Session = Depends(get_db)): """Synchronize ExploitDB data from git mirror""" # Create job record job = BulkProcessingJob( job_type='exploitdb_sync', status='pending', job_metadata={ 'cve_id': request.cve_id, 'batch_size': request.batch_size } ) db.add(job) db.commit() db.refresh(job) job_id = str(job.id) running_jobs[job_id] = job job_cancellation_flags[job_id] = False async def sync_task(): # Create a new database session for the background task task_db = SessionLocal() try: # Get the job in the new session task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first() if not task_job: logger.error(f"Job {job_id} not found in task session") return task_job.status = 'running' task_job.started_at = datetime.utcnow() task_db.commit() from exploitdb_client_local import ExploitDBLocalClient client = ExploitDBLocalClient(task_db) if request.cve_id: # Sync specific CVE if job_cancellation_flags.get(job_id, False): logger.info(f"Job {job_id} 
cancelled before starting") return result = await client.sync_cve_exploits(request.cve_id) logger.info(f"ExploitDB sync for {request.cve_id}: {result}") else: # Sync all CVEs with cancellation support result = await client.bulk_sync_exploitdb( batch_size=request.batch_size, cancellation_flag=lambda: job_cancellation_flags.get(job_id, False) ) logger.info(f"ExploitDB bulk sync completed: {result}") # Update job status if not cancelled if not job_cancellation_flags.get(job_id, False): task_job.status = 'completed' task_job.completed_at = datetime.utcnow() task_db.commit() except Exception as e: if not job_cancellation_flags.get(job_id, False): # Get the job again in case it was modified task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first() if task_job: task_job.status = 'failed' task_job.error_message = str(e) task_job.completed_at = datetime.utcnow() task_db.commit() logger.error(f"ExploitDB sync failed: {e}") import traceback traceback.print_exc() finally: # Clean up tracking and close the task session running_jobs.pop(job_id, None) job_cancellation_flags.pop(job_id, None) task_db.close() background_tasks.add_task(sync_task) return { "message": f"ExploitDB sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"), "status": "started", "job_id": job_id, "cve_id": request.cve_id, "batch_size": request.batch_size } @app.post("/api/sync-cisa-kev") async def sync_cisa_kev(background_tasks: BackgroundTasks, request: CISAKEVSyncRequest, db: Session = Depends(get_db)): """Synchronize CISA Known Exploited Vulnerabilities data""" # Create job record job = BulkProcessingJob( job_type='cisa_kev_sync', status='pending', job_metadata={ 'cve_id': request.cve_id, 'batch_size': request.batch_size } ) db.add(job) db.commit() db.refresh(job) job_id = str(job.id) running_jobs[job_id] = job job_cancellation_flags[job_id] = False async def sync_task(): # Create a new database session for the background task task_db = SessionLocal() try: # Get the job in the new session task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first() if not task_job: logger.error(f"Job {job_id} not found in task session") return task_job.status = 'running' task_job.started_at = datetime.utcnow() task_db.commit() from cisa_kev_client import CISAKEVClient client = CISAKEVClient(task_db) if request.cve_id: # Sync specific CVE if job_cancellation_flags.get(job_id, False): logger.info(f"Job {job_id} cancelled before starting") return result = await client.sync_cve_kev_data(request.cve_id) logger.info(f"CISA KEV sync for {request.cve_id}: {result}") else: # Sync all CVEs with cancellation support result = await client.bulk_sync_kev_data( batch_size=request.batch_size, cancellation_flag=lambda: job_cancellation_flags.get(job_id, False) ) logger.info(f"CISA KEV bulk sync completed: {result}") # Update job status if not cancelled if not job_cancellation_flags.get(job_id, False): task_job.status = 'completed' task_job.completed_at = datetime.utcnow() task_db.commit() except Exception as e: if not job_cancellation_flags.get(job_id, False): # Get the job again in case it was modified task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first() if task_job: task_job.status = 'failed' task_job.error_message = str(e) task_job.completed_at = datetime.utcnow() task_db.commit() logger.error(f"CISA KEV sync failed: {e}") import traceback traceback.print_exc() finally: # Clean up tracking and close the task session 
running_jobs.pop(job_id, None) job_cancellation_flags.pop(job_id, None) task_db.close() background_tasks.add_task(sync_task) return { "message": f"CISA KEV sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"), "status": "started", "job_id": job_id, "cve_id": request.cve_id, "batch_size": request.batch_size } @app.post("/api/sync-references") async def sync_references(request: ReferenceSyncRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)): """Start reference data synchronization""" try: from reference_client import ReferenceClient client = ReferenceClient(db) # Create job ID job_id = str(uuid.uuid4()) # Add job to tracking running_jobs[job_id] = { 'type': 'reference_sync', 'status': 'running', 'cve_id': request.cve_id, 'batch_size': request.batch_size, 'max_cves': request.max_cves, 'force_resync': request.force_resync, 'started_at': datetime.utcnow() } # Create cancellation flag job_cancellation_flags[job_id] = False async def sync_task(): try: if request.cve_id: # Single CVE sync result = await client.sync_cve_references(request.cve_id) running_jobs[job_id]['result'] = result running_jobs[job_id]['status'] = 'completed' else: # Bulk sync result = await client.bulk_sync_references( batch_size=request.batch_size, max_cves=request.max_cves, force_resync=request.force_resync, cancellation_flag=lambda: job_cancellation_flags.get(job_id, False) ) running_jobs[job_id]['result'] = result running_jobs[job_id]['status'] = 'completed' running_jobs[job_id]['completed_at'] = datetime.utcnow() except Exception as e: logger.error(f"Reference sync task failed: {e}") running_jobs[job_id]['status'] = 'failed' running_jobs[job_id]['error'] = str(e) running_jobs[job_id]['completed_at'] = datetime.utcnow() finally: # Clean up cancellation flag job_cancellation_flags.pop(job_id, None) background_tasks.add_task(sync_task) return { "message": f"Reference sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"), "status": "started", "job_id": job_id, "cve_id": request.cve_id, "batch_size": request.batch_size, "max_cves": request.max_cves, "force_resync": request.force_resync } except Exception as e: logger.error(f"Failed to start reference sync: {e}") raise HTTPException(status_code=500, detail=f"Failed to start reference sync: {str(e)}") @app.get("/api/reference-stats") async def get_reference_stats(db: Session = Depends(get_db)): """Get reference synchronization statistics""" try: from reference_client import ReferenceClient client = ReferenceClient(db) # Get sync status status = await client.get_reference_sync_status() # Get quality distribution from reference data quality_distribution = {} from sqlalchemy import text cves_with_references = db.query(CVE).filter( text("reference_data::text LIKE '%\"reference_analysis\"%'") ).all() for cve in cves_with_references: if cve.reference_data and 'reference_analysis' in cve.reference_data: ref_analysis = cve.reference_data['reference_analysis'] high_conf_refs = ref_analysis.get('high_confidence_references', 0) total_refs = ref_analysis.get('reference_count', 0) if total_refs > 0: quality_ratio = high_conf_refs / total_refs if quality_ratio >= 0.8: quality_tier = 'excellent' elif quality_ratio >= 0.6: quality_tier = 'good' elif quality_ratio >= 0.4: quality_tier = 'fair' else: quality_tier = 'poor' quality_distribution[quality_tier] = quality_distribution.get(quality_tier, 0) + 1 # Get reference type distribution reference_type_distribution = {} for cve in cves_with_references: if 
cve.reference_data and 'reference_analysis' in cve.reference_data: ref_analysis = cve.reference_data['reference_analysis'] ref_types = ref_analysis.get('reference_types', []) for ref_type in ref_types: reference_type_distribution[ref_type] = reference_type_distribution.get(ref_type, 0) + 1 return { 'reference_sync_status': status, 'quality_distribution': quality_distribution, 'reference_type_distribution': reference_type_distribution, 'total_with_reference_analysis': len(cves_with_references), 'source': 'reference_extraction' } except Exception as e: logger.error(f"Failed to get reference stats: {e}") raise HTTPException(status_code=500, detail=f"Failed to get reference stats: {str(e)}") @app.get("/api/exploitdb-stats") async def get_exploitdb_stats(db: Session = Depends(get_db)): """Get ExploitDB-related statistics""" try: from exploitdb_client_local import ExploitDBLocalClient client = ExploitDBLocalClient(db) # Get sync status status = await client.get_exploitdb_sync_status() # Get quality distribution from ExploitDB data quality_distribution = {} from sqlalchemy import text cves_with_exploitdb = db.query(CVE).filter( text("poc_data::text LIKE '%\"exploitdb\"%'") ).all() for cve in cves_with_exploitdb: if cve.poc_data and 'exploitdb' in cve.poc_data: exploits = cve.poc_data['exploitdb'].get('exploits', []) for exploit in exploits: quality_tier = exploit.get('quality_analysis', {}).get('quality_tier', 'unknown') quality_distribution[quality_tier] = quality_distribution.get(quality_tier, 0) + 1 # Get category distribution category_distribution = {} for cve in cves_with_exploitdb: if cve.poc_data and 'exploitdb' in cve.poc_data: exploits = cve.poc_data['exploitdb'].get('exploits', []) for exploit in exploits: category = exploit.get('category', 'unknown') category_distribution[category] = category_distribution.get(category, 0) + 1 return { "exploitdb_sync_status": status, "quality_distribution": quality_distribution, "category_distribution": category_distribution, "total_exploitdb_cves": len(cves_with_exploitdb), "total_exploits": sum( len(cve.poc_data.get('exploitdb', {}).get('exploits', [])) for cve in cves_with_exploitdb if cve.poc_data and 'exploitdb' in cve.poc_data ) } except Exception as e: logger.error(f"Error getting ExploitDB stats: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.get("/api/github-poc-stats") async def get_github_poc_stats(db: Session = Depends(get_db)): """Get GitHub PoC-related statistics""" try: # Get basic statistics github_poc_rules = db.query(SigmaRule).filter(SigmaRule.poc_source == 'github_poc').count() cves_with_github_pocs = db.query(CVE).filter( CVE.poc_data.isnot(None), # Check if poc_data exists func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc' ).count() # Get quality distribution quality_distribution = {} try: quality_results = db.query( func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_tier').label('tier'), func.count().label('count') ).filter( CVE.poc_data.isnot(None), func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc' ).group_by('tier').all() for tier, count in quality_results: if tier: quality_distribution[tier] = count except Exception as e: logger.warning(f"Error getting quality distribution: {e}") quality_distribution = {} # Calculate average quality score try: avg_quality = db.query( func.avg(func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_score').cast(Integer)) ).filter( CVE.poc_data.isnot(None), 
func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc' ).scalar() or 0 except Exception as e: logger.warning(f"Error calculating average quality: {e}") avg_quality = 0 return { 'github_poc_rules': github_poc_rules, 'cves_with_github_pocs': cves_with_github_pocs, 'quality_distribution': quality_distribution, 'average_quality_score': float(avg_quality) if avg_quality else 0, 'source': 'github_poc' } except Exception as e: logger.error(f"Error getting GitHub PoC stats: {e}") return {"error": str(e)} @app.get("/api/github-poc-status") async def get_github_poc_status(db: Session = Depends(get_db)): """Get GitHub PoC data availability status""" try: client = GitHubPoCClient(db) # Check if GitHub PoC data is available github_poc_data = client.load_github_poc_data() return { 'github_poc_data_available': len(github_poc_data) > 0, 'total_cves_with_pocs': len(github_poc_data), 'sample_cve_ids': list(github_poc_data.keys())[:10], # First 10 CVE IDs 'data_path': str(client.github_poc_path), 'path_exists': client.github_poc_path.exists() } except Exception as e: logger.error(f"Error checking GitHub PoC status: {e}") return {"error": str(e)} @app.get("/api/cisa-kev-stats") async def get_cisa_kev_stats(db: Session = Depends(get_db)): """Get CISA KEV-related statistics""" try: from cisa_kev_client import CISAKEVClient client = CISAKEVClient(db) # Get sync status status = await client.get_kev_sync_status() # Get threat level distribution from CISA KEV data threat_level_distribution = {} from sqlalchemy import text cves_with_kev = db.query(CVE).filter( text("poc_data::text LIKE '%\"cisa_kev\"%'") ).all() for cve in cves_with_kev: if cve.poc_data and 'cisa_kev' in cve.poc_data: vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {}) threat_level = vuln_data.get('threat_level', 'unknown') threat_level_distribution[threat_level] = threat_level_distribution.get(threat_level, 0) + 1 # Get vulnerability category distribution category_distribution = {} for cve in cves_with_kev: if cve.poc_data and 'cisa_kev' in cve.poc_data: vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {}) category = vuln_data.get('vulnerability_category', 'unknown') category_distribution[category] = category_distribution.get(category, 0) + 1 # Get ransomware usage statistics ransomware_stats = {'known': 0, 'unknown': 0} for cve in cves_with_kev: if cve.poc_data and 'cisa_kev' in cve.poc_data: vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {}) ransomware_use = vuln_data.get('known_ransomware_use', 'Unknown').lower() if ransomware_use == 'known': ransomware_stats['known'] += 1 else: ransomware_stats['unknown'] += 1 # Calculate average threat score threat_scores = [] for cve in cves_with_kev: if cve.poc_data and 'cisa_kev' in cve.poc_data: vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {}) threat_score = vuln_data.get('threat_score', 0) if threat_score: threat_scores.append(threat_score) avg_threat_score = sum(threat_scores) / len(threat_scores) if threat_scores else 0 return { "cisa_kev_sync_status": status, "threat_level_distribution": threat_level_distribution, "category_distribution": category_distribution, "ransomware_stats": ransomware_stats, "average_threat_score": round(avg_threat_score, 2), "total_kev_cves": len(cves_with_kev), "total_with_threat_scores": len(threat_scores) } except Exception as e: logger.error(f"Error getting CISA KEV stats: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.get("/api/bulk-jobs") async def get_bulk_jobs(limit: int = 
10, db: Session = Depends(get_db)): """Get bulk processing job status""" jobs = db.query(BulkProcessingJob).order_by( BulkProcessingJob.created_at.desc() ).limit(limit).all() result = [] for job in jobs: job_dict = { 'id': str(job.id), 'job_type': job.job_type, 'status': job.status, 'year': job.year, 'total_items': job.total_items, 'processed_items': job.processed_items, 'failed_items': job.failed_items, 'error_message': job.error_message, 'metadata': job.job_metadata, 'started_at': job.started_at, 'completed_at': job.completed_at, 'created_at': job.created_at } result.append(job_dict) return result @app.get("/api/bulk-status") async def get_bulk_status(db: Session = Depends(get_db)): """Get comprehensive bulk processing status""" try: from bulk_seeder import BulkSeeder seeder = BulkSeeder(db) status = await seeder.get_seeding_status() return status except Exception as e: logger.error(f"Error getting bulk status: {e}") return {"error": str(e)} @app.get("/api/poc-stats") async def get_poc_stats(db: Session = Depends(get_db)): """Get PoC-related statistics""" try: from nomi_sec_client import NomiSecClient client = NomiSecClient(db) stats = await client.get_sync_status() # Additional PoC statistics high_quality_cves = db.query(CVE).filter( CVE.poc_count > 0, func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_score').cast(Integer) > 60 ).count() stats.update({ 'high_quality_cves': high_quality_cves, 'avg_poc_count': db.query(func.avg(CVE.poc_count)).filter(CVE.poc_count > 0).scalar() or 0 }) return stats except Exception as e: logger.error(f"Error getting PoC stats: {e}") return {"error": str(e)} @app.get("/api/cve2capec-stats") async def get_cve2capec_stats(): """Get CVE2CAPEC MITRE ATT&CK mapping statistics""" try: client = CVE2CAPECClient() stats = client.get_stats() return { "status": "success", "data": stats, "description": "CVE to MITRE ATT&CK technique mappings from CVE2CAPEC repository" } except Exception as e: logger.error(f"Error getting CVE2CAPEC stats: {e}") return {"error": str(e)} @app.post("/api/regenerate-rules") async def regenerate_sigma_rules(background_tasks: BackgroundTasks, request: RuleRegenRequest, db: Session = Depends(get_db)): """Regenerate SIGMA rules using enhanced nomi-sec data""" async def regenerate_task(): try: from enhanced_sigma_generator import EnhancedSigmaGenerator generator = EnhancedSigmaGenerator(db) # Get CVEs with PoC data cves_with_pocs = db.query(CVE).filter(CVE.poc_count > 0).all() rules_generated = 0 rules_updated = 0 for cve in cves_with_pocs: # Check if we should regenerate existing_rule = db.query(SigmaRule).filter( SigmaRule.cve_id == cve.cve_id ).first() if existing_rule and existing_rule.poc_source == 'nomi_sec' and not request.force: continue # Generate enhanced rule result = await generator.generate_enhanced_rule(cve) if result['success']: if existing_rule: rules_updated += 1 else: rules_generated += 1 logger.info(f"Rule regeneration completed: {rules_generated} new, {rules_updated} updated") except Exception as e: logger.error(f"Rule regeneration failed: {e}") import traceback traceback.print_exc() background_tasks.add_task(regenerate_task) return { "message": "SIGMA rule regeneration started", "status": "started", "force": request.force } @app.post("/api/llm-enhanced-rules") async def generate_llm_enhanced_rules(request: dict, background_tasks: BackgroundTasks, db: Session = Depends(get_db)): """Generate SIGMA rules using LLM API for enhanced analysis""" # Parse request parameters cve_id = request.get('cve_id') 
@app.post("/api/llm-enhanced-rules")
async def generate_llm_enhanced_rules(request: dict, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Generate SIGMA rules using LLM API for enhanced analysis"""
    # Parse request parameters
    cve_id = request.get('cve_id')
    force = request.get('force', False)
    llm_provider = request.get('provider', os.getenv('LLM_PROVIDER'))
    llm_model = request.get('model', os.getenv('LLM_MODEL'))

    # Validation
    if cve_id and not re.match(r'^CVE-\d{4}-\d{4,}$', cve_id):
        raise HTTPException(status_code=400, detail="Invalid CVE ID format")

    async def llm_generation_task():
        """Background task for LLM-enhanced rule generation"""
        try:
            from enhanced_sigma_generator import EnhancedSigmaGenerator

            generator = EnhancedSigmaGenerator(db, llm_provider, llm_model)

            # Process specific CVE or all CVEs with PoC data
            if cve_id:
                cve = db.query(CVE).filter(CVE.cve_id == cve_id).first()
                if not cve:
                    logger.error(f"CVE {cve_id} not found")
                    return
                cves_to_process = [cve]
            else:
                # Process CVEs with PoC data that either have no rules or force update
                query = db.query(CVE).filter(CVE.poc_count > 0)
                if not force:
                    # Only process CVEs without existing LLM-generated rules
                    existing_llm_rules = db.query(SigmaRule).filter(
                        SigmaRule.detection_type.like('llm_%')
                    ).all()
                    existing_cve_ids = {rule.cve_id for rule in existing_llm_rules}
                    cves_to_process = [cve for cve in query.all() if cve.cve_id not in existing_cve_ids]
                else:
                    cves_to_process = query.all()

            logger.info(f"Processing {len(cves_to_process)} CVEs for LLM-enhanced rule generation using {llm_provider}")

            rules_generated = 0
            rules_updated = 0
            failures = 0

            for cve in cves_to_process:
                try:
                    # Check if CVE has sufficient PoC data
                    if not cve.poc_data or not cve.poc_count:
                        logger.debug(f"Skipping {cve.cve_id} - no PoC data")
                        continue

                    # Generate LLM-enhanced rule
                    result = await generator.generate_enhanced_rule(cve, use_llm=True)

                    if result.get('success'):
                        if result.get('updated'):
                            rules_updated += 1
                        else:
                            rules_generated += 1
                        logger.info(f"Successfully generated LLM-enhanced rule for {cve.cve_id}")
                    else:
                        failures += 1
                        logger.warning(f"Failed to generate LLM-enhanced rule for {cve.cve_id}: {result.get('error')}")
                except Exception as e:
                    failures += 1
                    logger.error(f"Error generating LLM-enhanced rule for {cve.cve_id}: {e}")
                    continue

            logger.info(f"LLM-enhanced rule generation completed: {rules_generated} new, {rules_updated} updated, {failures} failures")
        except Exception as e:
            logger.error(f"LLM-enhanced rule generation failed: {e}")
            import traceback
            traceback.print_exc()

    background_tasks.add_task(llm_generation_task)

    return {
        "message": "LLM-enhanced SIGMA rule generation started",
        "status": "started",
        "cve_id": cve_id,
        "force": force,
        "provider": llm_provider,
        "model": llm_model,
        "note": "Requires appropriate LLM API key to be set"
    }


@app.get("/api/llm-status")
async def get_llm_status():
    """Check LLM API availability status"""
    try:
        from llm_client import LLMClient

        # Get current provider configuration
        provider = os.getenv('LLM_PROVIDER')
        model = os.getenv('LLM_MODEL')

        client = LLMClient(provider=provider, model=model)
        provider_info = client.get_provider_info()

        # Get all available providers
        all_providers = LLMClient.get_available_providers()

        return {
            "current_provider": provider_info,
            "available_providers": all_providers,
            "status": "ready" if client.is_available() else "unavailable"
        }
    except Exception as e:
        logger.error(f"Error checking LLM status: {e}")
        return {
            "current_provider": {"provider": "unknown", "available": False},
            "available_providers": [],
            "status": "error",
            "error": str(e)
        }
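# Illustrative usage (comment only): the LLM endpoints take a plain JSON dict rather than
# a Pydantic model; provider/model fall back to the LLM_PROVIDER / LLM_MODEL environment
# variables when omitted. The CVE ID and provider values below are examples, not defaults.
#
#   requests.post(
#       "http://localhost:8000/api/llm-enhanced-rules",
#       json={"cve_id": "CVE-2021-44228", "provider": "ollama", "force": False},
#   )
#   requests.get("http://localhost:8000/api/llm-status").json()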
HTTPException(status_code=400, detail="Provider is required") # Validate provider if provider not in LLMClient.SUPPORTED_PROVIDERS: raise HTTPException(status_code=400, detail=f"Unsupported provider: {provider}") # Test the new configuration client = LLMClient(provider=provider, model=model) if not client.is_available(): raise HTTPException(status_code=400, detail=f"Provider {provider} is not available or not configured") # Update environment variables (note: this only affects the current session) os.environ['LLM_PROVIDER'] = provider if model: os.environ['LLM_MODEL'] = model provider_info = client.get_provider_info() return { "message": f"Switched to {provider}", "provider_info": provider_info, "status": "success" } except HTTPException: raise except Exception as e: logger.error(f"Error switching LLM provider: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/api/cancel-job/{job_id}") async def cancel_job(job_id: str, db: Session = Depends(get_db)): """Cancel a running job""" try: # Find the job in the database job = db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job_id).first() if not job: raise HTTPException(status_code=404, detail="Job not found") if job.status not in ['pending', 'running']: raise HTTPException(status_code=400, detail=f"Cannot cancel job with status: {job.status}") # Set cancellation flag job_cancellation_flags[job_id] = True # Update job status job.status = 'cancelled' job.cancelled_at = datetime.utcnow() job.error_message = "Job cancelled by user" db.commit() logger.info(f"Job {job_id} cancellation requested") return { "message": f"Job {job_id} cancellation requested", "status": "cancelled", "job_id": job_id } except HTTPException: raise except Exception as e: logger.error(f"Error cancelling job {job_id}: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.get("/api/running-jobs") async def get_running_jobs(db: Session = Depends(get_db)): """Get all currently running jobs""" try: jobs = db.query(BulkProcessingJob).filter( BulkProcessingJob.status.in_(['pending', 'running']) ).order_by(BulkProcessingJob.created_at.desc()).all() result = [] for job in jobs: result.append({ 'id': str(job.id), 'job_type': job.job_type, 'status': job.status, 'year': job.year, 'total_items': job.total_items, 'processed_items': job.processed_items, 'failed_items': job.failed_items, 'error_message': job.error_message, 'started_at': job.started_at, 'created_at': job.created_at, 'can_cancel': job.status in ['pending', 'running'] }) return result except Exception as e: logger.error(f"Error getting running jobs: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/api/ollama-pull-model") async def pull_ollama_model(request: dict, background_tasks: BackgroundTasks): """Pull an Ollama model""" try: from llm_client import LLMClient model = request.get('model') if not model: raise HTTPException(status_code=400, detail="Model name is required") # Create a background task to pull the model def pull_model_task(): try: client = LLMClient(provider='ollama', model=model) base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434') if client._pull_ollama_model(base_url, model): logger.info(f"Successfully pulled Ollama model: {model}") else: logger.error(f"Failed to pull Ollama model: {model}") except Exception as e: logger.error(f"Error in model pull task: {e}") background_tasks.add_task(pull_model_task) return { "message": f"Started pulling model {model}", "status": "started", "model": model } except Exception as e: logger.error(f"Error 
starting model pull: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.get("/api/ollama-models") async def get_ollama_models(): """Get available Ollama models""" try: from llm_client import LLMClient client = LLMClient(provider='ollama') available_models = client._get_ollama_available_models() return { "available_models": available_models, "total_models": len(available_models), "status": "success" } except Exception as e: logger.error(f"Error getting Ollama models: {e}") raise HTTPException(status_code=500, detail=str(e)) # ============================================================================ # SCHEDULER ENDPOINTS # ============================================================================ class SchedulerControlRequest(BaseModel): action: str # 'start', 'stop', 'restart' class JobControlRequest(BaseModel): job_name: str action: str # 'enable', 'disable', 'trigger' class UpdateScheduleRequest(BaseModel): job_name: str schedule: str # Cron expression @app.get("/api/scheduler/status") async def get_scheduler_status(): """Get scheduler status and job information""" try: from job_scheduler import get_scheduler scheduler = get_scheduler() status = scheduler.get_job_status() return { "scheduler_status": status, "timestamp": datetime.utcnow().isoformat() } except Exception as e: logger.error(f"Error getting scheduler status: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/api/scheduler/control") async def control_scheduler(request: SchedulerControlRequest): """Control scheduler (start/stop/restart)""" try: from job_scheduler import get_scheduler scheduler = get_scheduler() if request.action == 'start': scheduler.start() message = "Scheduler started" elif request.action == 'stop': scheduler.stop() message = "Scheduler stopped" elif request.action == 'restart': scheduler.stop() scheduler.start() message = "Scheduler restarted" else: raise HTTPException(status_code=400, detail=f"Invalid action: {request.action}") return { "message": message, "action": request.action, "timestamp": datetime.utcnow().isoformat() } except Exception as e: logger.error(f"Error controlling scheduler: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/api/scheduler/job/control") async def control_job(request: JobControlRequest): """Control individual jobs (enable/disable/trigger)""" try: from job_scheduler import get_scheduler scheduler = get_scheduler() if request.action == 'enable': success = scheduler.enable_job(request.job_name) message = f"Job {request.job_name} enabled" if success else f"Job {request.job_name} not found" elif request.action == 'disable': success = scheduler.disable_job(request.job_name) message = f"Job {request.job_name} disabled" if success else f"Job {request.job_name} not found" elif request.action == 'trigger': success = scheduler.trigger_job(request.job_name) message = f"Job {request.job_name} triggered" if success else f"Failed to trigger job {request.job_name}" else: raise HTTPException(status_code=400, detail=f"Invalid action: {request.action}") return { "message": message, "job_name": request.job_name, "action": request.action, "success": success, "timestamp": datetime.utcnow().isoformat() } except Exception as e: logger.error(f"Error controlling job: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/api/scheduler/job/schedule") async def update_job_schedule(request: UpdateScheduleRequest): """Update job schedule""" try: from job_scheduler import get_scheduler scheduler = get_scheduler() success = 
@app.post("/api/scheduler/job/schedule")
async def update_job_schedule(request: UpdateScheduleRequest):
    """Update job schedule"""
    try:
        from job_scheduler import get_scheduler

        scheduler = get_scheduler()
        success = scheduler.update_job_schedule(request.job_name, request.schedule)

        if success:
            # Get updated job info
            job_status = scheduler.get_job_status(request.job_name)
            return {
                "message": f"Schedule updated for job {request.job_name}",
                "job_name": request.job_name,
                "new_schedule": request.schedule,
                "next_run": job_status.get("next_run"),
                "success": True,
                "timestamp": datetime.utcnow().isoformat()
            }
        else:
            raise HTTPException(status_code=400, detail=f"Failed to update schedule for job {request.job_name}")
    except HTTPException:
        # Preserve the 400 for a failed update instead of converting it to a 500
        raise
    except Exception as e:
        logger.error(f"Error updating job schedule: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/scheduler/job/{job_name}")
async def get_job_status(job_name: str):
    """Get status of a specific job"""
    try:
        from job_scheduler import get_scheduler

        scheduler = get_scheduler()
        status = scheduler.get_job_status(job_name)

        if "error" in status:
            raise HTTPException(status_code=404, detail=status["error"])

        return {
            "job_status": status,
            "timestamp": datetime.utcnow().isoformat()
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting job status: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/scheduler/reload")
async def reload_scheduler_config():
    """Reload scheduler configuration from file"""
    try:
        from job_scheduler import get_scheduler

        scheduler = get_scheduler()
        success = scheduler.reload_config()

        if success:
            return {
                "message": "Scheduler configuration reloaded successfully",
                "success": True,
                "timestamp": datetime.utcnow().isoformat()
            }
        else:
            raise HTTPException(status_code=500, detail="Failed to reload configuration")
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error reloading scheduler config: {e}")
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
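# Illustrative usage (comment only): schedules are updated with a cron expression string;
# the five-field expression below is an example and its interpretation is up to the
# job_scheduler implementation.
#
#   requests.post(
#       "http://localhost:8000/api/scheduler/job/schedule",
#       json={"job_name": "<job-name>", "schedule": "0 2 * * *"},
#   )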