import asyncio
import functools
import json
import logging
import os
import re
import uuid
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from typing import Any, Dict, Iterable, List, Optional, Tuple

import requests
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import ARRAY, DECIMAL, TIMESTAMP, Boolean, Column, String, Text, create_engine
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker

logger = logging.getLogger(__name__)

# Database setup
# NOTE(review): the fallback URL embeds development credentials; production
# deployments should always set DATABASE_URL explicitly.
DATABASE_URL = os.getenv(
    "DATABASE_URL", "postgresql://cve_user:cve_password@localhost:5432/cve_sigma_db"
)
# pool_pre_ping tests pooled connections before handing them out, so the
# hourly background job does not trip over connections dropped while idle.
engine = create_engine(DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()

# NVD CVE API 2.0 settings
NVD_API_URL = "https://services.nvd.nist.gov/rest/json/cves/2.0"
# Largest page size the NVD CVE API 2.0 accepts; fewer round trips means
# fewer chances to hit the rate limit.
NVD_RESULTS_PER_PAGE = 2000
# Pause between paged requests to stay under NVD's rolling rate limit, which
# is much tighter for clients that do not send an API key.
NVD_PAGE_DELAY_SECONDS = 6.0
NVD_PAGE_DELAY_WITH_KEY_SECONDS = 0.6
NVD_REQUEST_TIMEOUT_SECONDS = 30

# Template used when nothing more specific matches (or the specific one is missing).
DEFAULT_TEMPLATE_NAME = "Process Execution"

# Indicator-extraction patterns, compiled once. The trailing \b stops a match
# from ending inside a longer token (e.g. "foo.javascript" -> "foo.java"); the
# leading \b on the port pattern stops "support 443" from counting as a port.
_EXE_PATTERN = re.compile(r"\b(\w+\.exe)\b", re.IGNORECASE)
_PORT_PATTERN = re.compile(r"\bports?\s+(\d{1,5})\b", re.IGNORECASE)
_FILE_PATTERN = re.compile(r"\b(\w+\.\w{3,4})\b")


def _unique(items: Iterable) -> list:
    """Return *items* with duplicates removed, keeping first-seen order."""
    return list(dict.fromkeys(items))


# Database Models
class CVE(Base):
    """A vulnerability record imported from the NVD CVE API."""

    __tablename__ = "cves"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    cve_id = Column(String(20), unique=True, nullable=False)
    description = Column(Text)
    cvss_score = Column(DECIMAL(3, 1))
    severity = Column(String(20))
    published_date = Column(TIMESTAMP)
    modified_date = Column(TIMESTAMP)
    # CPE "criteria" strings of configuration matches NVD marks as vulnerable.
    affected_products = Column(ARRAY(String))
    reference_urls = Column(ARRAY(String))
    created_at = Column(TIMESTAMP, default=datetime.utcnow)
    # onupdate refreshes the timestamp on every ORM-issued UPDATE of the row.
    updated_at = Column(TIMESTAMP, default=datetime.utcnow, onupdate=datetime.utcnow)


class SigmaRule(Base):
    """A SIGMA detection rule, generated automatically or authored by hand."""

    __tablename__ = "sigma_rules"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    cve_id = Column(String(20))
    rule_name = Column(String(255), nullable=False)
    rule_content = Column(Text, nullable=False)
    detection_type = Column(String(50))
    log_source = Column(String(100))
    confidence_level = Column(String(20))
    auto_generated = Column(Boolean, default=True)
    created_at = Column(TIMESTAMP, default=datetime.utcnow)
    # onupdate refreshes the timestamp on every ORM-issued UPDATE of the row.
    updated_at = Column(TIMESTAMP, default=datetime.utcnow, onupdate=datetime.utcnow)


class RuleTemplate(Base):
    """A ``str.format`` template from which SIGMA rule text is rendered."""

    __tablename__ = "rule_templates"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    template_name = Column(String(255), nullable=False)
    template_content = Column(Text, nullable=False)
    applicable_product_patterns = Column(ARRAY(String))
    description = Column(Text)
    created_at = Column(TIMESTAMP, default=datetime.utcnow)


# Pydantic models
class CVEResponse(BaseModel):
    """API representation of a stored CVE."""

    model_config = ConfigDict(from_attributes=True)

    # The ORM column yields uuid.UUID; a plain ``str`` field rejects it.
    # A UUID field still serializes to the same string in JSON.
    id: uuid.UUID
    cve_id: str
    description: Optional[str]
    cvss_score: Optional[float]
    severity: Optional[str]
    published_date: Optional[datetime]
    affected_products: Optional[List[str]]
    reference_urls: Optional[List[str]]


class SigmaRuleResponse(BaseModel):
    """API representation of a stored SIGMA rule."""

    model_config = ConfigDict(from_attributes=True)

    # See CVEResponse.id: the ORM yields uuid.UUID, serialized as a string.
    id: uuid.UUID
    cve_id: str
    rule_name: str
    rule_content: str
    detection_type: Optional[str]
    log_source: Optional[str]
    confidence_level: Optional[str]
    auto_generated: bool
    created_at: datetime


# CVE and SIGMA Rule Generator Service
class CVESigmaService:
    """Import CVEs from NVD and derive SIGMA rules from them.

    All database work goes through the session given to ``__init__``.
    ``fetch_recent_cves`` commits the CVEs it inserts; ``generate_sigma_rule``
    only adds the rule to the session and leaves committing to the caller.
    """

    def __init__(self, db: Session):
        self.db = db
        self.nvd_api_key = os.getenv("NVD_API_KEY")

    async def fetch_recent_cves(self, days_back: int = 7) -> List[CVE]:
        """Fetch CVEs published in the last *days_back* days and store new ones.

        Every result page NVD reports for the window is read. CVEs whose
        ``cve_id`` is already stored, or that appear twice in the feed, are
        skipped. All new rows are committed in one transaction.

        Args:
            days_back: Length of the look-back window in days, ending now (UTC).

        Returns:
            The newly inserted CVE rows. Returns an empty list if the download
            or the database write fails; the error is logged and any pending
            inserts are rolled back.
        """
        end_date = datetime.utcnow()
        start_date = end_date - timedelta(days=days_back)

        try:
            vulnerabilities = await self._fetch_nvd_vulnerabilities(start_date, end_date)
        except (requests.RequestException, ValueError) as e:
            # ValueError also covers a response body that is not valid JSON.
            logger.error("Error fetching CVEs: %s", e)
            return []

        try:
            new_cves = self._store_new_cves(vulnerabilities)
            self.db.commit()
        except Exception:
            # Best-effort import: drop this batch but leave the session usable
            # for the caller (e.g. the periodic background loop).
            self.db.rollback()
            logger.exception("Error storing fetched CVEs")
            return []
        return new_cves

    async def _fetch_nvd_vulnerabilities(
        self, start_date: datetime, end_date: datetime
    ) -> List[Dict[str, Any]]:
        """Return every ``vulnerabilities`` entry NVD reports for the window.

        Advances ``startIndex`` page by page until ``totalResults`` entries
        have been read or NVD returns an empty page.

        Raises:
            requests.RequestException: on network failures or HTTP error statuses.
            ValueError: if a response body is not valid JSON.
        """
        headers = {"apiKey": self.nvd_api_key} if self.nvd_api_key else {}
        page_delay = (
            NVD_PAGE_DELAY_WITH_KEY_SECONDS if self.nvd_api_key else NVD_PAGE_DELAY_SECONDS
        )
        params = {
            "pubStartDate": self._format_nvd_timestamp(start_date),
            "pubEndDate": self._format_nvd_timestamp(end_date),
            "resultsPerPage": NVD_RESULTS_PER_PAGE,
            "startIndex": 0,
        }
        loop = asyncio.get_running_loop()
        vulnerabilities: List[Dict[str, Any]] = []

        while True:
            # requests is synchronous: run it in the default thread pool so the
            # event loop keeps serving API requests while NVD responds.
            request = functools.partial(
                requests.get,
                NVD_API_URL,
                params=dict(params),
                headers=headers,
                timeout=NVD_REQUEST_TIMEOUT_SECONDS,
            )
            response = await loop.run_in_executor(None, request)
            response.raise_for_status()
            data = response.json()

            page = data.get("vulnerabilities", [])
            vulnerabilities.extend(page)
            params["startIndex"] += len(page)
            if not page or params["startIndex"] >= data.get("totalResults", 0):
                return vulnerabilities
            await asyncio.sleep(page_delay)

    @staticmethod
    def _format_nvd_timestamp(moment: datetime) -> str:
        """Format *moment* with millisecond precision and a ``Z`` suffix for NVD query params."""
        return moment.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"

    def _store_new_cves(self, vulnerabilities: List[Dict[str, Any]]) -> List[CVE]:
        """Add a CVE row for each entry not yet stored; does not commit.

        Entries without an ``id`` are skipped, because ``cve_id`` is NOT NULL
        and a single such row would fail the whole commit.
        """
        entries = [vuln.get("cve", {}) for vuln in vulnerabilities]
        candidate_ids = [entry["id"] for entry in entries if entry.get("id")]
        if not candidate_ids:
            return []

        # One query for the whole batch instead of one per CVE.
        rows = self.db.query(CVE.cve_id).filter(CVE.cve_id.in_(candidate_ids)).all()
        known_ids = {cve_id for (cve_id,) in rows}

        new_cves: List[CVE] = []
        for cve_data in entries:
            cve_id = cve_data.get("id")
            if not cve_id or cve_id in known_ids:
                continue
            # Also guards against the same ID appearing twice in one fetch.
            known_ids.add(cve_id)
            cve_obj = self._build_cve(cve_data)
            self.db.add(cve_obj)
            new_cves.append(cve_obj)
        return new_cves

    @classmethod
    def _build_cve(cls, cve_data: Dict[str, Any]) -> CVE:
        """Map one NVD ``cve`` object (which must have an ``id``) to a CVE row."""
        cvss_score, severity = cls._extract_cvss(cve_data.get("metrics") or {})
        return CVE(
            cve_id=cve_data["id"],
            description=cls._extract_description(cve_data),
            cvss_score=cvss_score,
            severity=severity,
            published_date=cls._parse_nvd_timestamp(cve_data.get("published")),
            modified_date=cls._parse_nvd_timestamp(cve_data.get("lastModified")),
            affected_products=cls._extract_affected_products(cve_data),
            reference_urls=[ref.get("url", "") for ref in cve_data.get("references") or []],
        )

    @staticmethod
    def _extract_description(cve_data: Dict[str, Any]) -> str:
        """Return the English description, else the first one, else ``""``."""
        descriptions = cve_data.get("descriptions") or []
        english = next((d for d in descriptions if d.get("lang") == "en"), None)
        chosen = english or (descriptions[0] if descriptions else None)
        return chosen.get("value", "") if chosen else ""

    @staticmethod
    def _extract_cvss(metrics: Dict[str, Any]) -> Tuple[Optional[float], Optional[str]]:
        """Return ``(baseScore, baseSeverity)`` from the first available CVSS metric.

        CVSS v3.1 is preferred, then v4.0, then v3.0. Returns ``(None, None)``
        if none of them is present.
        """
        for key in ("cvssMetricV31", "cvssMetricV40", "cvssMetricV30"):
            entries = metrics.get(key)
            if entries:
                cvss_data = entries[0].get("cvssData", {})
                return cvss_data.get("baseScore"), cvss_data.get("baseSeverity")
        return None, None

    @staticmethod
    def _extract_affected_products(cve_data: Dict[str, Any]) -> List[str]:
        """Collect the ``criteria`` of every CPE match flagged ``vulnerable``."""
        products: List[str] = []
        for config in cve_data.get("configurations") or []:
            for node in config.get("nodes", []):
                for cpe_match in node.get("cpeMatch", []):
                    if cpe_match.get("vulnerable"):
                        products.append(cpe_match.get("criteria", ""))
        return products

    @staticmethod
    def _parse_nvd_timestamp(value: Optional[str]) -> Optional[datetime]:
        """Parse an NVD ISO-8601 timestamp; return None if missing or malformed.

        A trailing ``Z`` is rewritten to ``+00:00`` so older ``fromisoformat``
        implementations accept it. A bad value no longer aborts the whole batch.
        """
        if not value:
            return None
        try:
            return datetime.fromisoformat(value.replace("Z", "+00:00"))
        except ValueError:
            logger.warning("Unparseable NVD timestamp: %r", value)
            return None

    def generate_sigma_rule(self, cve: CVE) -> Optional[SigmaRule]:
        """Build a SIGMA rule for *cve* and add it to the session (uncommitted).

        Returns:
            The new SigmaRule, or None if the CVE has no description, no
            suitable template exists, or the template could not be rendered.
        """
        if not cve.description:
            return None

        description_lower = cve.description.lower()
        affected_products = [p.lower() for p in (cve.affected_products or [])]

        template = self._select_template(description_lower, affected_products)
        if not template:
            return None

        rule_content = self._populate_template(cve, template)
        if not rule_content:
            return None

        sigma_rule = SigmaRule(
            cve_id=cve.cve_id,
            # cve_id already carries the "CVE-" prefix. Using it directly also
            # avoids an IndexError on IDs that do not split into three parts.
            rule_name=f"{cve.cve_id} Detection",
            rule_content=rule_content,
            detection_type=self._determine_detection_type(description_lower),
            log_source=template.template_name.lower().replace(" ", "_"),
            confidence_level=self._calculate_confidence(cve),
            auto_generated=True,
        )
        self.db.add(sigma_rule)
        return sigma_rule

    def _select_template(
        self, description: str, affected_products: List[str]
    ) -> Optional[RuleTemplate]:
        """Pick a rule template using simple keyword heuristics.

        Windows/Microsoft products get a process, network, or file template,
        chosen from keywords in the lower-cased *description*. Everything else,
        and any preferred template that is not in the database, falls back to
        DEFAULT_TEMPLATE_NAME. Returns None only if that template is missing too.
        """
        wanted = DEFAULT_TEMPLATE_NAME
        if any("windows" in p or "microsoft" in p for p in affected_products):
            if "process" in description or "execution" in description:
                wanted = "Process Execution"
            elif "network" in description or "remote" in description:
                wanted = "Network Connection"
            elif "file" in description or "write" in description:
                wanted = "File Modification"

        templates = self.db.query(RuleTemplate).all()
        chosen = next((t for t in templates if wanted in t.template_name), None)
        if chosen is None and wanted != DEFAULT_TEMPLATE_NAME:
            chosen = next((t for t in templates if DEFAULT_TEMPLATE_NAME in t.template_name), None)
        return chosen

    def _populate_template(self, cve: CVE, template: RuleTemplate) -> Optional[str]:
        """Render *template* for *cve* via ``str.format``.

        Indicators scraped from the description fill the list placeholders.
        Generic fallback values are used when none are found.

        Returns:
            The rendered rule text, or None if the template references an
            unknown placeholder or is otherwise malformed (error is logged).
        """
        suspicious_processes = self._extract_suspicious_indicators(cve.description, "process")
        suspicious_ports = self._extract_suspicious_indicators(cve.description, "port")
        file_patterns = self._extract_suspicious_indicators(cve.description, "file")

        level = "high" if cve.cvss_score is not None and cve.cvss_score >= 7.0 else "medium"
        description = (
            cve.description if len(cve.description) <= 200 else cve.description[:200] + "..."
        )

        try:
            return template.template_content.format(
                # cve_id already includes the "CVE-" prefix.
                title=f"{cve.cve_id} Exploitation Attempt",
                description=description,
                rule_id=str(uuid.uuid4()),
                date=datetime.utcnow().strftime("%Y/%m/%d"),
                cve_url=f"https://nvd.nist.gov/vuln/detail/{cve.cve_id}",
                cve_id=cve.cve_id.lower(),
                suspicious_processes=suspicious_processes or ["suspicious.exe", "malware.exe"],
                suspicious_ports=suspicious_ports or [4444, 8080, 9999],
                file_patterns=file_patterns or ["temp", "malware", "exploit"],
                level=level,
            )
        except (KeyError, IndexError, ValueError, AttributeError) as e:
            # Raised by str.format for unknown/positional placeholders or bad format specs.
            logger.error("Error populating template: %s", e)
            return None

    def _extract_suspicious_indicators(
        self, description: str, indicator_type: str
    ) -> Optional[List]:
        """Scrape indicators of *indicator_type* from a CVE description.

        Args:
            description: Free-text CVE description.
            indicator_type: ``"process"`` (up to 5 ``*.exe`` names),
                ``"port"`` (up to 3 ints in 1..65535), or ``"file"`` (up to 5
                ``name.ext`` tokens with a 3-4 character extension).

        Returns:
            Unique matches in first-seen order. Returns None (not ``[]``) when
            nothing is found or the type is unknown, so callers can use ``or``
            to substitute defaults.
        """
        if indicator_type == "process":
            found = _unique(_EXE_PATTERN.findall(description))[:5]
        elif indicator_type == "port":
            ports = (int(p) for p in _PORT_PATTERN.findall(description))
            found = _unique(p for p in ports if 1 <= p <= 65535)[:3]
        elif indicator_type == "file":
            found = _unique(_FILE_PATTERN.findall(description))[:5]
        else:
            return None
        return found or None

    def _determine_detection_type(self, description: str) -> str:
        """Classify the lower-cased description as network/process/file/general."""
        if "remote" in description or "network" in description:
            return "network"
        if "process" in description or "execution" in description:
            return "process"
        if "file" in description or "filesystem" in description:
            return "file"
        return "general"

    def _calculate_confidence(self, cve: CVE) -> str:
        """Map the CVSS score to a confidence label: >=9 high, >=7 medium, else low."""
        if cve.cvss_score and cve.cvss_score >= 9.0:
            return "high"
        if cve.cvss_score and cve.cvss_score >= 7.0:
            return "medium"
        return "low"


# Dependency
def get_db():
    """Yield a request-scoped session and always close it afterwards."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
# Background task to fetch CVEs and generate rules

# Delay between automatic fetch cycles, in seconds.
FETCH_INTERVAL_SECONDS = 3600


async def _run_fetch_cycle(db: Session) -> None:
    """Fetch recent CVEs into *db* and generate a SIGMA rule for each new one.

    ``fetch_recent_cves`` commits the CVEs itself; this commits the generated
    rules. Exceptions propagate to the caller, which owns (and closes) *db*.
    """
    service = CVESigmaService(db)
    print("Fetching recent CVEs...")
    new_cves = await service.fetch_recent_cves()
    print(f"Found {len(new_cves)} new CVEs")
    for cve in new_cves:
        print(f"Generating SIGMA rule for {cve.cve_id}")
        sigma_rule = service.generate_sigma_rule(cve)
        if sigma_rule:
            print(f"Generated rule: {sigma_rule.rule_name}")
    db.commit()


async def background_cve_fetch():
    """Run a fetch cycle every FETCH_INTERVAL_SECONDS until cancelled.

    Each cycle gets a fresh session that is closed even if the cycle fails.
    Closing also discards any uncommitted work. Ordinary exceptions are
    reported and the loop continues; cancellation propagates.
    """
    while True:
        db = SessionLocal()
        try:
            await _run_fetch_cycle(db)
        except Exception as e:
            print(f"Background task error: {str(e)}")
        finally:
            db.close()
        await asyncio.sleep(FETCH_INTERVAL_SECONDS)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the periodic CVE fetch for the lifetime of the application."""
    task = asyncio.create_task(background_cve_fetch())
    try:
        yield
    finally:
        task.cancel()
        # Wait for the task to unwind so its session is closed before shutdown completes.
        try:
            await task
        except asyncio.CancelledError:
            pass


# FastAPI app
app = FastAPI(title="CVE-SIGMA Auto Generator", lifespan=lifespan)

# Comma-separated list of allowed browser origins; defaults to the local dev frontend.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ORIGINS", "http://localhost:3000").split(",")
    if origin.strip()
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def _validate_pagination(skip: int, limit: int) -> None:
    """Reject negative offsets/limits with a 422 instead of a database error (500)."""
    if skip < 0 or limit < 0:
        raise HTTPException(status_code=422, detail="skip and limit must be non-negative")


# The path operations below are plain ``def``: their SQLAlchemy calls block,
# and FastAPI runs sync handlers in a thread pool instead of on the event loop.
@app.get("/api/cves", response_model=List[CVEResponse])
def get_cves(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
    """List stored CVEs, most recently published first."""
    _validate_pagination(skip, limit)
    return (
        db.query(CVE)
        .order_by(CVE.published_date.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )


@app.get("/api/cves/{cve_id}", response_model=CVEResponse)
def get_cve(cve_id: str, db: Session = Depends(get_db)):
    """Return one CVE by its identifier; 404 if it is not stored."""
    cve = db.query(CVE).filter(CVE.cve_id == cve_id).first()
    if not cve:
        raise HTTPException(status_code=404, detail="CVE not found")
    return cve


@app.get("/api/sigma-rules", response_model=List[SigmaRuleResponse])
def get_sigma_rules(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
    """List SIGMA rules, newest first."""
    _validate_pagination(skip, limit)
    return (
        db.query(SigmaRule)
        .order_by(SigmaRule.created_at.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )


@app.get("/api/sigma-rules/{cve_id}", response_model=List[SigmaRuleResponse])
def get_sigma_rules_by_cve(cve_id: str, db: Session = Depends(get_db)):
    """List all SIGMA rules generated for one CVE, newest first."""
    return (
        db.query(SigmaRule)
        .filter(SigmaRule.cve_id == cve_id)
        .order_by(SigmaRule.created_at.desc())
        .all()
    )


@app.post("/api/fetch-cves")
def manual_fetch_cves(background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Trigger a fetch-and-generate cycle in the background and return immediately.

    ``db`` is kept for signature compatibility but deliberately unused. The
    request-scoped session is closed by ``get_db``'s teardown, which recent
    FastAPI versions run before background tasks. So the background job
    opens and closes its own session.
    """

    async def fetch_task():
        task_db = SessionLocal()
        try:
            await _run_fetch_cycle(task_db)
        except Exception as e:
            print(f"Manual CVE fetch error: {str(e)}")
        finally:
            task_db.close()

    background_tasks.add_task(fetch_task)
    return {"message": "CVE fetch initiated"}


@app.get("/api/stats")
def get_stats(db: Session = Depends(get_db)):
    """Return total CVE/rule counts and the number of CVEs published in the last 7 days."""
    total_cves = db.query(CVE).count()
    total_rules = db.query(SigmaRule).count()
    week_ago = datetime.utcnow() - timedelta(days=7)
    recent_cves = db.query(CVE).filter(CVE.published_date >= week_ago).count()

    return {
        "total_cves": total_cves,
        "total_sigma_rules": total_rules,
        "recent_cves_7_days": recent_cves,
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)