# auto_sigma_rule_generator/backend/main.py
#
# 2373 lines
# 92 KiB
# Python

from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from sqlalchemy import create_engine, Column, String, Text, DECIMAL, TIMESTAMP, Boolean, ARRAY, Integer, JSON, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.dialects.postgresql import UUID
import uuid
from datetime import datetime, timedelta
import requests
import json
import re
import os
from typing import List, Optional
from pydantic import BaseModel
import asyncio
from contextlib import asynccontextmanager
import base64
from github import Github
from urllib.parse import urlparse
import hashlib
import logging
import threading
from mcdevitt_poc_client import GitHubPoCClient
from cve2capec_client import CVE2CAPECClient
# Setup logging
# The root logger is configured at INFO; this module logs through `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global job tracking
# Two module-level dicts, both empty at import time.
# NOTE(review): presumably keyed by job id and consulted by the job endpoints
# elsewhere in this file -- confirm against the callers.
running_jobs = {}
job_cancellation_flags = {}
# Database setup
# DATABASE_URL is read from the environment; the fallback targets a local
# PostgreSQL instance.
# NOTE(review): the fallback URL embeds a username and password -- make sure
# it is only ever relied on in local development.
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://cve_user:cve_password@localhost:5432/cve_sigma_db")
engine = create_engine(DATABASE_URL)
# Sessions neither autocommit nor autoflush, so callers must commit explicitly.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Declarative base shared by every ORM model below.
Base = declarative_base()
# Database Models
class CVE(Base):
    """ORM model for one CVE record stored in the ``cves`` table."""

    __tablename__ = "cves"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    cve_id = Column(String(20), unique=True, nullable=False)
    description = Column(Text)
    cvss_score = Column(DECIMAL(3, 1))
    severity = Column(String(20))
    published_date = Column(TIMESTAMP)
    modified_date = Column(TIMESTAMP)
    affected_products = Column(ARRAY(String))
    reference_urls = Column(ARRAY(String))

    # Bulk processing fields
    data_source = Column(String(20), default='nvd_api')  # 'nvd_api', 'nvd_bulk', 'manual'
    nvd_json_version = Column(String(10), default='2.0')
    bulk_processed = Column(Boolean, default=False)

    # nomi-sec PoC fields
    poc_count = Column(Integer, default=0)
    poc_data = Column(JSON)  # Store nomi-sec PoC metadata

    # Reference data fields
    reference_data = Column(JSON)  # Store extracted reference content and analysis
    reference_sync_status = Column(String(20), default='pending')  # 'pending', 'processing', 'completed', 'failed'
    reference_last_synced = Column(TIMESTAMP)

    created_at = Column(TIMESTAMP, default=datetime.utcnow)
    # `onupdate` refreshes the timestamp on every UPDATE issued through the
    # ORM; with only `default` the value was frozen at insert time.
    updated_at = Column(TIMESTAMP, default=datetime.utcnow, onupdate=datetime.utcnow)
class SigmaRule(Base):
    """ORM model for one generated SIGMA rule stored in ``sigma_rules``."""

    __tablename__ = "sigma_rules"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # Plain string reference to the CVE identifier (no foreign key is declared).
    cve_id = Column(String(20))
    rule_name = Column(String(255), nullable=False)
    rule_content = Column(Text, nullable=False)
    detection_type = Column(String(50))
    log_source = Column(String(100))
    confidence_level = Column(String(20))
    auto_generated = Column(Boolean, default=True)
    exploit_based = Column(Boolean, default=False)
    github_repos = Column(ARRAY(String))
    exploit_indicators = Column(Text)  # JSON string of extracted indicators

    # Enhanced fields for new data sources
    poc_source = Column(String(20), default='github_search')  # 'github_search', 'nomi_sec', 'manual'
    poc_quality_score = Column(Integer, default=0)  # Based on star count, activity, etc.
    nomi_sec_data = Column(JSON)  # Store nomi-sec PoC metadata

    created_at = Column(TIMESTAMP, default=datetime.utcnow)
    # `onupdate` refreshes the timestamp on every UPDATE issued through the
    # ORM; with only `default` the value was frozen at insert time.
    updated_at = Column(TIMESTAMP, default=datetime.utcnow, onupdate=datetime.utcnow)
class RuleTemplate(Base):
    """ORM model for a SIGMA rule template stored in ``rule_templates``.

    ``template_content`` is the text that the rule generator fills in with
    ``str.format``; ``template_name`` is what the generator matches on when
    choosing a template.
    """

    __tablename__ = "rule_templates"

    id = Column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )
    template_name = Column(String(255), nullable=False)
    template_content = Column(Text, nullable=False)
    applicable_product_patterns = Column(ARRAY(String))
    description = Column(Text)
    created_at = Column(TIMESTAMP, default=datetime.utcnow)
class BulkProcessingJob(Base):
    """ORM model tracking one bulk-processing job in ``bulk_processing_jobs``."""

    __tablename__ = "bulk_processing_jobs"

    id = Column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )

    # What the job is and where it stands.
    job_type = Column(String(50), nullable=False)  # 'nvd_bulk_seed', 'nomi_sec_sync', 'incremental_update'
    status = Column(String(20), default='pending')  # 'pending', 'running', 'completed', 'failed', 'cancelled'
    year = Column(Integer)  # For year-based processing

    # Progress counters.
    total_items = Column(Integer, default=0)
    processed_items = Column(Integer, default=0)
    failed_items = Column(Integer, default=0)

    error_message = Column(Text)
    job_metadata = Column(JSON)  # Additional job-specific data

    # Lifecycle timestamps.
    started_at = Column(TIMESTAMP)
    completed_at = Column(TIMESTAMP)
    cancelled_at = Column(TIMESTAMP)
    created_at = Column(TIMESTAMP, default=datetime.utcnow)
# Pydantic models
class CVEResponse(BaseModel):
    """API response shape for a single CVE.

    ``id`` is the database UUID rendered as a string; every field other than
    ``id`` and ``cve_id`` is optional and defaults to ``None``.
    """

    id: str
    cve_id: str
    description: Optional[str] = None
    cvss_score: Optional[float] = None
    severity: Optional[str] = None
    published_date: Optional[datetime] = None
    affected_products: Optional[List[str]] = None
    reference_urls: Optional[List[str]] = None

    class Config:
        # Allow the model to be populated from attribute-bearing objects.
        from_attributes = True
class SigmaRuleResponse(BaseModel):
    """API response shape for a single SIGMA rule.

    ``id`` is the database UUID rendered as a string.
    """

    id: str
    # The `sigma_rules.cve_id` column is nullable, so a rule without a CVE
    # must still serialise instead of failing response validation.
    cve_id: Optional[str] = None
    rule_name: str
    rule_content: str
    detection_type: Optional[str] = None
    log_source: Optional[str] = None
    confidence_level: Optional[str] = None
    auto_generated: bool = True
    exploit_based: bool = False
    github_repos: Optional[List[str]] = None
    exploit_indicators: Optional[str] = None
    created_at: datetime

    class Config:
        # Allow the model to be populated from attribute-bearing objects.
        from_attributes = True
# Request models
class BulkSeedRequest(BaseModel):
    """Request body for starting a bulk seed.

    ``end_year`` may be omitted; the bulk-seed endpoint reports the current
    year in that case. By default the NVD phase runs and the nomi-sec phase
    is skipped.
    """

    start_year: int = 2002
    end_year: Optional[int] = None
    skip_nvd: bool = False
    skip_nomi_sec: bool = True
class NomiSecSyncRequest(BaseModel):
    """Request body for a nomi-sec sync: optional CVE id plus a batch size (default 50)."""

    cve_id: Optional[str] = None
    batch_size: int = 50
class GitHubPoCSyncRequest(BaseModel):
    """Request body for a GitHub PoC sync: optional CVE id plus a batch size (default 50)."""

    cve_id: Optional[str] = None
    batch_size: int = 50
class ExploitDBSyncRequest(BaseModel):
    """Request body for an ExploitDB sync: optional CVE id plus a batch size (default 30)."""

    cve_id: Optional[str] = None
    batch_size: int = 30
class CISAKEVSyncRequest(BaseModel):
    """Request body for a CISA KEV sync: optional CVE id plus a batch size (default 100)."""

    cve_id: Optional[str] = None
    batch_size: int = 100
class ReferenceSyncRequest(BaseModel):
    """Request body for a reference-data sync.

    Carries an optional CVE id, a batch size (default 30), an optional cap on
    the number of CVEs, and a flag requesting a forced re-sync.
    """

    cve_id: Optional[str] = None
    batch_size: int = 30
    max_cves: Optional[int] = None
    force_resync: bool = False
class RuleRegenRequest(BaseModel):
    """Request body for rule regeneration; ``force`` defaults to ``False``."""

    force: bool = False
# GitHub Exploit Analysis Service
class GitHubExploitAnalyzer:
def __init__(self):
self.github_token = os.getenv("GITHUB_TOKEN")
self.github = Github(self.github_token) if self.github_token else None
async def search_exploits_for_cve(self, cve_id: str) -> List[dict]:
"""Search GitHub for exploit code related to a CVE"""
if not self.github:
print(f"No GitHub token configured, skipping exploit search for {cve_id}")
return []
try:
print(f"Searching GitHub for exploits for {cve_id}")
# Search queries to find exploit code
search_queries = [
f"{cve_id} exploit",
f"{cve_id} poc",
f"{cve_id} vulnerability",
f'"{cve_id}" exploit code',
f"{cve_id.replace('-', '_')} exploit"
]
exploits = []
seen_repos = set()
for query in search_queries[:2]: # Limit to 2 queries to avoid rate limits
try:
# Search repositories
repos = self.github.search_repositories(
query=query,
sort="updated",
order="desc"
)
# Get top 5 results per query
for repo in repos[:5]:
if repo.full_name in seen_repos:
continue
seen_repos.add(repo.full_name)
# Analyze repository
exploit_info = await self._analyze_repository(repo, cve_id)
if exploit_info:
exploits.append(exploit_info)
if len(exploits) >= 10: # Limit total exploits
break
if len(exploits) >= 10:
break
except Exception as e:
print(f"Error searching GitHub with query '{query}': {str(e)}")
continue
print(f"Found {len(exploits)} potential exploits for {cve_id}")
return exploits
except Exception as e:
print(f"Error searching GitHub for {cve_id}: {str(e)}")
return []
async def _analyze_repository(self, repo, cve_id: str) -> Optional[dict]:
"""Analyze a GitHub repository for exploit code"""
try:
# Check if repo name or description mentions the CVE
repo_text = f"{repo.name} {repo.description or ''}".lower()
if cve_id.lower() not in repo_text and cve_id.replace('-', '_').lower() not in repo_text:
return None
# Get repository contents
exploit_files = []
indicators = {
'processes': set(),
'files': set(),
'registry': set(),
'network': set(),
'commands': set(),
'powershell': set(),
'urls': set()
}
try:
contents = repo.get_contents("")
for content in contents[:20]: # Limit files to analyze
if content.type == "file" and self._is_exploit_file(content.name):
file_analysis = await self._analyze_file_content(repo, content, cve_id)
if file_analysis:
exploit_files.append(file_analysis)
# Merge indicators
for key, values in file_analysis.get('indicators', {}).items():
if key in indicators:
indicators[key].update(values)
except Exception as e:
print(f"Error analyzing repo contents for {repo.full_name}: {str(e)}")
if not exploit_files:
return None
return {
'repo_name': repo.full_name,
'repo_url': repo.html_url,
'description': repo.description,
'language': repo.language,
'stars': repo.stargazers_count,
'updated': repo.updated_at.isoformat(),
'files': exploit_files,
'indicators': {k: list(v) for k, v in indicators.items()}
}
except Exception as e:
print(f"Error analyzing repository {repo.full_name}: {str(e)}")
return None
def _is_exploit_file(self, filename: str) -> bool:
"""Check if a file is likely to contain exploit code"""
exploit_extensions = ['.py', '.ps1', '.sh', '.c', '.cpp', '.js', '.rb', '.pl', '.php', '.java']
exploit_names = ['exploit', 'poc', 'payload', 'shell', 'reverse', 'bind', 'attack']
filename_lower = filename.lower()
# Check extension
if not any(filename_lower.endswith(ext) for ext in exploit_extensions):
return False
# Check filename for exploit-related terms
return any(term in filename_lower for term in exploit_names) or 'cve' in filename_lower
async def _analyze_file_content(self, repo, file_content, cve_id: str) -> Optional[dict]:
"""Analyze individual file content for exploit indicators"""
try:
if file_content.size > 100000: # Skip files larger than 100KB
return None
# Decode file content
content = file_content.decoded_content.decode('utf-8', errors='ignore')
# Check if file actually mentions the CVE
if cve_id.lower() not in content.lower() and cve_id.replace('-', '_').lower() not in content.lower():
return None
indicators = self._extract_indicators_from_code(content, file_content.name)
if not any(indicators.values()):
return None
return {
'filename': file_content.name,
'path': file_content.path,
'size': file_content.size,
'indicators': indicators
}
except Exception as e:
print(f"Error analyzing file {file_content.name}: {str(e)}")
return None
def _extract_indicators_from_code(self, content: str, filename: str) -> dict:
"""Extract security indicators from exploit code"""
indicators = {
'processes': set(),
'files': set(),
'registry': set(),
'network': set(),
'commands': set(),
'powershell': set(),
'urls': set()
}
# Process patterns
process_patterns = [
r'CreateProcess[AW]?\s*\(\s*["\']([^"\']+)["\']',
r'ShellExecute[AW]?\s*\([^,]*,\s*["\']([^"\']+)["\']',
r'system\s*\(\s*["\']([^"\']+)["\']',
r'exec\s*\(\s*["\']([^"\']+)["\']',
r'subprocess\.(?:call|run|Popen)\s*\(\s*["\']([^"\']+)["\']'
]
# File patterns
file_patterns = [
r'(?:fopen|CreateFile|WriteFile|ReadFile)\s*\(\s*["\']([^"\']+\.[a-zA-Z0-9]+)["\']',
r'(?:copy|move|del|rm)\s+["\']?([^\s"\']+\.[a-zA-Z0-9]+)["\']?',
r'\\\\[^\\]+\\[^\\]+\\([^\\]+\.[a-zA-Z0-9]+)',
r'[C-Z]:\\\\[^\\]+\\\\([^\\]+\.[a-zA-Z0-9]+)'
]
# Registry patterns
registry_patterns = [
r'(?:RegOpenKey|RegSetValue|RegCreateKey)\s*\([^,]*,\s*["\']([^"\']+)["\']',
r'HKEY_[A-Z_]+\\\\([^"\'\\]+)',
r'reg\s+add\s+["\']?([^"\'\\]+\\\\[^"\']+)["\']?'
]
# Network patterns
network_patterns = [
r'(?:connect|bind|listen)\s*\([^,]*,\s*(\d+)',
r'socket\.connect\s*\(\s*\(["\']?([^"\']+)["\']?,\s*(\d+)\)',
r'(?:http|https|ftp)://([^\s"\'<>]+)',
r'(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)'
]
# PowerShell patterns
powershell_patterns = [
r'(?:powershell|pwsh)\s+(?:-[a-zA-Z]+\s+)*["\']?([^"\']+)["\']?',
r'Invoke-(?:Expression|Command|WebRequest|RestMethod)\s+["\']?([^"\']+)["\']?',
r'Start-Process\s+["\']?([^"\']+)["\']?',
r'Get-Process\s+["\']?([^"\']+)["\']?'
]
# Command patterns
command_patterns = [
r'(?:cmd|command)\s+(?:/[a-zA-Z]+\s+)*["\']?([^"\']+)["\']?',
r'(?:ping|nslookup|netstat|tasklist|wmic)\s+([^\s"\']+)',
r'(?:net|sc|schtasks)\s+[a-zA-Z]+\s+([^\s"\']+)'
]
# Extract indicators using regex patterns
patterns = {
'processes': process_patterns,
'files': file_patterns,
'registry': registry_patterns,
'powershell': powershell_patterns,
'commands': command_patterns
}
for category, pattern_list in patterns.items():
for pattern in pattern_list:
matches = re.findall(pattern, content, re.IGNORECASE | re.MULTILINE)
for match in matches:
if isinstance(match, tuple):
indicators[category].add(match[0])
else:
indicators[category].add(match)
# Special handling for network indicators
for pattern in network_patterns:
matches = re.findall(pattern, content, re.IGNORECASE)
for match in matches:
if isinstance(match, tuple):
if len(match) >= 2:
indicators['network'].add(f"{match[0]}:{match[1]}")
else:
indicators['network'].add(match[0])
else:
indicators['network'].add(match)
# Convert sets to lists and filter out empty/invalid indicators
cleaned_indicators = {}
for key, values in indicators.items():
cleaned_values = [v for v in values if v and len(v.strip()) > 2 and len(v) < 200]
if cleaned_values:
cleaned_indicators[key] = cleaned_values[:10] # Limit to 10 per category
return cleaned_indicators
class CVESigmaService:
    """Fetch CVEs from the NVD API and turn them into template-based SIGMA rules.

    The service works on the SQLAlchemy session it is given. Rule generation
    adds objects to that session but leaves committing to the caller.
    """

    # CVSS metric blocks in order of preference; v3.1 stays first so records
    # that carry it are read exactly as before.
    _CVSS_METRIC_KEYS = ("cvssMetricV31", "cvssMetricV30", "cvssMetricV2")

    def __init__(self, db: "Session"):
        self.db = db
        self.nvd_api_key = os.getenv("NVD_API_KEY")

    @staticmethod
    def _parse_nvd_timestamp(value) -> Optional[datetime]:
        """Parse an NVD ISO-8601 timestamp; return None when missing or malformed."""
        if not value:
            return None
        try:
            return datetime.fromisoformat(value.replace("Z", "+00:00"))
        except ValueError:
            return None

    @staticmethod
    def _pick_description(descriptions) -> str:
        """Return the English description if present, else the first one, else ''."""
        if not descriptions:
            return ""
        for entry in descriptions:
            if entry.get("lang") == "en":
                return entry.get("value", "")
        return descriptions[0].get("value", "")

    @classmethod
    def _extract_cvss(cls, metrics):
        """Return ``(base_score, severity)`` from the first available CVSS block.

        CVSS v3.x keeps ``baseSeverity`` inside ``cvssData``; v2 keeps it on
        the metric entry itself, so both locations are checked.
        """
        for key in cls._CVSS_METRIC_KEYS:
            entries = (metrics or {}).get(key)
            if not entries:
                continue
            metric = entries[0]
            cvss_data = metric.get("cvssData", {})
            severity = cvss_data.get("baseSeverity") or metric.get("baseSeverity")
            return cvss_data.get("baseScore"), severity
        return None, None

    async def fetch_recent_cves(self, days_back: int = 7):
        """Fetch recent CVEs from NVD API.

        Stores every CVE published in the last ``days_back`` days that is not
        already in the database and returns the newly created ``CVE`` objects.
        On any error the session is rolled back and an empty list is returned.

        NOTE(review): only the first page (100 results) is requested, and the
        HTTP call is synchronous, so it blocks the event loop.
        """
        end_date = datetime.utcnow()
        start_date = end_date - timedelta(days=days_back)

        url = "https://services.nvd.nist.gov/rest/json/cves/2.0"
        params = {
            "pubStartDate": start_date.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z",
            "pubEndDate": end_date.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z",
            "resultsPerPage": 100
        }

        headers = {}
        if self.nvd_api_key:
            headers["apiKey"] = self.nvd_api_key

        try:
            response = requests.get(url, params=params, headers=headers, timeout=30)
            response.raise_for_status()
            data = response.json()

            new_cves = []
            seen_ids = set()
            for vuln in data.get("vulnerabilities", []):
                cve_data = vuln.get("cve", {})
                cve_id = cve_data.get("id")

                # Skip records without an id and ids repeated within this
                # response (the session does not autoflush, so the query
                # below would not see a pending duplicate).
                if not cve_id or cve_id in seen_ids:
                    continue
                seen_ids.add(cve_id)

                # Check if CVE already exists
                existing = self.db.query(CVE).filter(CVE.cve_id == cve_id).first()
                if existing:
                    continue

                cvss_score, severity = self._extract_cvss(cve_data.get("metrics"))

                affected_products = [
                    cpe_match.get("criteria", "")
                    for config in cve_data.get("configurations") or []
                    for node in config.get("nodes", [])
                    for cpe_match in node.get("cpeMatch", [])
                    if cpe_match.get("vulnerable")
                ]
                reference_urls = [
                    ref.get("url", "") for ref in cve_data.get("references") or []
                ]

                cve_obj = CVE(
                    cve_id=cve_id,
                    description=self._pick_description(cve_data.get("descriptions")),
                    cvss_score=cvss_score,
                    severity=severity,
                    # A missing/odd timestamp is stored as NULL instead of
                    # aborting the whole batch.
                    published_date=self._parse_nvd_timestamp(cve_data.get("published")),
                    modified_date=self._parse_nvd_timestamp(cve_data.get("lastModified")),
                    affected_products=affected_products,
                    reference_urls=reference_urls
                )
                self.db.add(cve_obj)
                new_cves.append(cve_obj)

            self.db.commit()
            return new_cves
        except Exception as e:
            print(f"Error fetching CVEs: {str(e)}")
            # Discard half-built state so the session stays usable.
            self.db.rollback()
            return []

    def generate_sigma_rule(self, cve: "CVE") -> Optional["SigmaRule"]:
        """Generate SIGMA rule based on CVE data.

        Returns ``None`` when the CVE has no description, no template can be
        selected, or the template cannot be populated. The rule is added to
        the session but not committed.
        """
        if not cve.description:
            return None

        # Analyze CVE to determine appropriate template
        description_lower = cve.description.lower()
        affected_products = [p.lower() for p in (cve.affected_products or [])]

        template = self._select_template(description_lower, affected_products)
        if not template:
            return None

        rule_content = self._populate_template(cve, template)
        if not rule_content:
            return None

        id_parts = cve.cve_id.split('-')
        sigma_rule = SigmaRule(
            cve_id=cve.cve_id,
            rule_name=f"CVE-{id_parts[1]}-{id_parts[2]} Detection",
            rule_content=rule_content,
            detection_type=self._determine_detection_type(description_lower),
            log_source=template.template_name.lower().replace(" ", "_"),
            confidence_level=self._calculate_confidence(cve),
            auto_generated=True
        )
        self.db.add(sigma_rule)
        return sigma_rule

    def _select_template(self, description: str, affected_products: List[str],
                         exploit_indicators: Optional[dict] = None):
        """Select appropriate SIGMA rule template based on CVE and exploit analysis.

        Exploit indicators, when given, take priority (PowerShell, network,
        files, then processes/commands). Otherwise Windows/Microsoft products
        are matched on description keywords, and everything else falls back
        to the "Process Execution" template. Returns ``None`` when the chosen
        template does not exist.
        """
        templates = self.db.query(RuleTemplate).all()

        def first_named(fragment):
            return next((t for t in templates if fragment in t.template_name), None)

        if exploit_indicators:
            by_priority = (
                (exploit_indicators.get('powershell'), "PowerShell"),
                (exploit_indicators.get('network'), "Network Connection"),
                (exploit_indicators.get('files'), "File Modification"),
                (exploit_indicators.get('processes') or exploit_indicators.get('commands'),
                 "Process Execution"),
            )
            for present, fragment in by_priority:
                if present:
                    match = first_named(fragment)
                    if match:
                        return match

        # Fallback to original logic
        if any("windows" in p or "microsoft" in p for p in affected_products):
            if "process" in description or "execution" in description:
                return first_named("Process Execution")
            elif "network" in description or "remote" in description:
                return first_named("Network Connection")
            elif "file" in description or "write" in description:
                return first_named("File Modification")

        # Default to process execution template
        return first_named("Process Execution")

    def _populate_template(self, cve: "CVE", template: "RuleTemplate",
                           exploit_indicators: Optional[dict] = None) -> Optional[str]:
        """Populate template with CVE-specific data and exploit indicators.

        Returns the formatted rule text, or ``None`` if formatting fails
        (the error is printed).
        """
        try:
            # Use exploit indicators if available, otherwise extract from description
            if exploit_indicators:
                suspicious_processes = (
                    exploit_indicators.get('processes', []) + exploit_indicators.get('commands', [])
                )
                file_patterns = exploit_indicators.get('files', [])

                # Extract ports from network indicators, keeping each valid
                # TCP/UDP port number once.
                suspicious_ports = []
                for net_indicator in exploit_indicators.get('network', []):
                    text = str(net_indicator)
                    if ':' not in text:
                        continue
                    try:
                        port = int(text.split(':')[-1])
                    except ValueError:
                        continue
                    if 0 < port <= 65535 and port not in suspicious_ports:
                        suspicious_ports.append(port)
            else:
                # Fallback to original extraction
                suspicious_processes = self._extract_suspicious_indicators(cve.description, "process")
                suspicious_ports = self._extract_suspicious_indicators(cve.description, "port")
                file_patterns = self._extract_suspicious_indicators(cve.description, "file")

            # Determine severity level
            level = "high" if cve.cvss_score and cve.cvss_score >= 7.0 else "medium"

            # Create enhanced description
            if len(cve.description) > 200:
                enhanced_description = cve.description[:200] + "..."
            else:
                enhanced_description = cve.description
            if exploit_indicators:
                enhanced_description += " [Enhanced with GitHub exploit analysis]"

            # Build tags
            tags = [
                f"attack.{self._get_mitre_technique(cve.description, exploit_indicators)}",
                cve.cve_id.lower(),
            ]
            if exploit_indicators:
                tags.append("exploit.github")

            return template.template_content.format(
                # cve_id already starts with "CVE-"; prefixing it again
                # produced titles such as "CVE-CVE-2023-1234 Detection".
                title=f"{cve.cve_id} {'Exploit-Based ' if exploit_indicators else ''}Detection",
                description=enhanced_description,
                rule_id=str(uuid.uuid4()),
                date=datetime.utcnow().strftime("%Y/%m/%d"),
                cve_url=f"https://nvd.nist.gov/vuln/detail/{cve.cve_id}",
                cve_id=cve.cve_id.lower(),
                tags="\n    - ".join(tags),
                suspicious_processes=suspicious_processes or ["suspicious.exe", "malware.exe"],
                suspicious_ports=suspicious_ports or [4444, 8080, 9999],
                file_patterns=file_patterns or ["temp", "malware", "exploit"],
                level=level
            )
        except Exception as e:
            print(f"Error populating template: {str(e)}")
            return None

    def _get_mitre_technique(self, description: str, exploit_indicators: Optional[dict] = None) -> str:
        """Map CVE and exploit indicators to MITRE ATT&CK techniques.

        Exploit indicators are checked first; otherwise keywords in the
        description decide. Falls back to the generic ``execution`` tag.
        """
        desc_lower = description.lower()

        # Check exploit indicators first
        if exploit_indicators:
            if exploit_indicators.get('powershell'):
                return "t1059.001"  # PowerShell
            elif exploit_indicators.get('commands'):
                return "t1059.003"  # Windows Command Shell
            elif exploit_indicators.get('network'):
                return "t1071.001"  # Web Protocols
            elif exploit_indicators.get('files'):
                return "t1105"  # Ingress Tool Transfer
            elif exploit_indicators.get('processes'):
                return "t1106"  # Native API

        # Fallback to description analysis
        if "powershell" in desc_lower:
            return "t1059.001"
        elif "command" in desc_lower or "cmd" in desc_lower:
            return "t1059.003"
        elif "network" in desc_lower or "remote" in desc_lower:
            return "t1071.001"
        elif "file" in desc_lower or "upload" in desc_lower:
            return "t1105"
        elif "process" in desc_lower or "execution" in desc_lower:
            return "t1106"
        else:
            return "execution"  # Generic

    def _extract_suspicious_indicators(self, description: str, indicator_type: str) -> Optional[List]:
        """Extract suspicious indicators from CVE description.

        ``indicator_type`` is ``"process"`` (up to 5 ``*.exe`` names),
        ``"port"`` (up to 3 integers following the word "port") or ``"file"``
        (up to 5 ``name.ext`` tokens). Returns ``None`` when nothing matches
        or the type is unknown.
        """
        if indicator_type == "process":
            # Look for executable names or process patterns
            exe_names = re.findall(r'(\w+\.exe)', description, re.IGNORECASE)
            return exe_names[:5] if exe_names else None
        elif indicator_type == "port":
            # Look for port numbers
            ports = re.findall(r'port\s+(\d+)', description, re.IGNORECASE)
            return [int(p) for p in ports[:3]] if ports else None
        elif indicator_type == "file":
            # Look for file extensions or paths
            file_names = re.findall(r'(\w+\.\w{3,4})', description, re.IGNORECASE)
            return file_names[:5] if file_names else None
        return None

    def _determine_detection_type(self, description: str, exploit_indicators: Optional[dict] = None) -> str:
        """Determine detection type based on CVE description and exploit indicators.

        ``description`` is matched as given, so callers should pass it
        lower-cased.
        """
        if exploit_indicators:
            if exploit_indicators.get('powershell'):
                return "powershell"
            elif exploit_indicators.get('network'):
                return "network"
            elif exploit_indicators.get('files'):
                return "file"
            elif exploit_indicators.get('processes') or exploit_indicators.get('commands'):
                return "process"

        # Fallback to original logic
        if "remote" in description or "network" in description:
            return "network"
        elif "process" in description or "execution" in description:
            return "process"
        elif "file" in description or "filesystem" in description:
            return "file"
        else:
            return "general"

    def _calculate_confidence(self, cve: "CVE", exploit_based: bool = False) -> str:
        """Calculate confidence level for the generated rule.

        The CVSS score contributes 1-3 points and an exploit-based rule adds
        2; four or more points is "high", two or three "medium", else "low".
        """
        base_confidence = 0

        # CVSS score contributes to confidence
        if cve.cvss_score:
            if cve.cvss_score >= 9.0:
                base_confidence += 3
            elif cve.cvss_score >= 7.0:
                base_confidence += 2
            else:
                base_confidence += 1

        # Exploit-based rules get higher confidence
        if exploit_based:
            base_confidence += 2

        # Map to confidence levels
        if base_confidence >= 4:
            return "high"
        elif base_confidence >= 2:
            return "medium"
        else:
            return "low"
# Dependency
def get_db():
    """Yield a fresh database session and close it once the caller is done.

    Written as a generator so it can be used as a FastAPI dependency; the
    session is closed even if the consumer raises.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
# Background task to fetch CVEs and generate rules
async def background_cve_fetch():
    """Run forever: fetch recent CVEs and generate SIGMA rules for new ones.

    Each cycle opens its own session and always closes it, including when the
    cycle fails. The first cycles look back 30 days; once a cycle has stored
    CVEs, later cycles look back 7 days. A failed cycle is retried after
    5 minutes, or after 30 minutes once ``max_retries`` consecutive failures
    have been reached. A clean cycle is followed by a one-hour wait (30
    minutes if earlier failures have not yet been cleared by a cycle that
    found CVEs).
    """
    retry_count = 0
    max_retries = 3
    # Use a longer initial period (30 days) to find CVEs
    days_back = 30

    while True:
        db = None
        failed = False
        try:
            db = SessionLocal()
            service = CVESigmaService(db)

            current_time = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
            print(f"[{current_time}] Starting CVE fetch cycle...")

            new_cves = await service.fetch_recent_cves(days_back=days_back)

            if new_cves:
                print(f"Found {len(new_cves)} new CVEs, generating SIGMA rules...")
                rules_generated = 0
                for cve in new_cves:
                    try:
                        sigma_rule = service.generate_sigma_rule(cve)
                        if sigma_rule:
                            rules_generated += 1
                            print(f"Generated SIGMA rule for {cve.cve_id}")
                        else:
                            print(f"Could not generate rule for {cve.cve_id} - insufficient data")
                    except Exception as e:
                        print(f"Error generating rule for {cve.cve_id}: {str(e)}")

                db.commit()
                print(f"Successfully generated {rules_generated} SIGMA rules")
                retry_count = 0  # Reset retry count on success

                # After first successful run, reduce to 7 days for regular
                # updates. Only a cycle that returned CVEs proves the fetch
                # worked, because fetch errors also surface as an empty list.
                if days_back != 7:
                    print("Switching to 7-day lookback for future runs...")
                    days_back = 7
            else:
                print("No new CVEs found in this cycle")
        except Exception as e:
            failed = True
            retry_count += 1
            print(f"Background task error (attempt {retry_count}/{max_retries}): {str(e)}")
        finally:
            # Release the connection before sleeping, on success and failure
            # alike (the session used to leak whenever the cycle raised).
            if db is not None:
                db.close()

        if failed:
            if retry_count >= max_retries:
                print(f"Max retries reached, waiting longer before next attempt...")
                await asyncio.sleep(1800)  # Wait 30 minutes on repeated failures
                retry_count = 0
            else:
                await asyncio.sleep(300)  # Wait 5 minutes before retry
            continue

        # Wait 1 hour before next fetch (or 30 minutes if there were errors)
        wait_time = 3600 if retry_count == 0 else 1800
        print(f"Next CVE fetch in {wait_time//60} minutes...")
        await asyncio.sleep(wait_time)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: prepare the database and run the job scheduler.

    Startup creates any missing tables, logs a notice when no rule templates
    exist, and starts the job scheduler. Shutdown stops the scheduler.
    Failures in the template check or the scheduler are logged and do not
    prevent the application from starting or stopping.
    """
    # --- startup: schema ---
    Base.metadata.create_all(bind=engine)

    # --- startup: rule-template check ---
    session = SessionLocal()
    try:
        if db_template_count(session) == 0:
            logger.info("No rule templates found. Database initialization will handle template creation.")
    except Exception as e:
        logger.error(f"Error checking rule templates: {e}")
    finally:
        session.close()

    # --- startup: job scheduler ---
    try:
        from job_scheduler import initialize_scheduler
        from job_executors import register_all_executors

        job_scheduler = initialize_scheduler()
        job_scheduler.set_db_session_factory(SessionLocal)
        register_all_executors(job_scheduler)
        job_scheduler.start()
        logger.info("Job scheduler initialized and started")
    except Exception as e:
        logger.error(f"Error initializing job scheduler: {e}")

    yield

    # --- shutdown ---
    try:
        from job_scheduler import get_scheduler

        get_scheduler().stop()
        logger.info("Job scheduler stopped")
    except Exception as e:
        logger.error(f"Error stopping job scheduler: {e}")


def db_template_count(session):
    """Return the number of rows in the rule-template table."""
    return session.query(RuleTemplate).count()
# FastAPI app
# Application instance; startup and shutdown work is delegated to `lifespan`.
app = FastAPI(title="CVE-SIGMA Auto Generator", lifespan=lifespan)
# CORS: only http://localhost:3000 is an allowed origin; credentials are
# permitted and every method and header is accepted from that origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def _cve_to_response(cve: CVE) -> CVEResponse:
    """Convert a ``CVE`` row into its API response model.

    The UUID primary key is rendered as a string and the DECIMAL score as a
    float. The score is tested against ``None`` rather than for truthiness so
    that a legitimate score of 0.0 is not reported as missing.
    """
    return CVEResponse(
        id=str(cve.id),
        cve_id=cve.cve_id,
        description=cve.description,
        cvss_score=float(cve.cvss_score) if cve.cvss_score is not None else None,
        severity=cve.severity,
        published_date=cve.published_date,
        affected_products=cve.affected_products,
        reference_urls=cve.reference_urls,
    )


@app.get("/api/cves", response_model=List[CVEResponse])
async def get_cves(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
    """Return one page of CVEs, most recently published first."""
    cves = db.query(CVE).order_by(CVE.published_date.desc()).offset(skip).limit(limit).all()
    return [_cve_to_response(cve) for cve in cves]


@app.get("/api/cves/{cve_id}", response_model=CVEResponse)
async def get_cve(cve_id: str, db: Session = Depends(get_db)):
    """Return the CVE with the given identifier, or respond 404 if unknown."""
    cve = db.query(CVE).filter(CVE.cve_id == cve_id).first()
    if not cve:
        raise HTTPException(status_code=404, detail="CVE not found")
    return _cve_to_response(cve)
def _sigma_rule_to_response(rule: SigmaRule) -> SigmaRuleResponse:
    """Convert a ``SigmaRule`` row into its API response model.

    The UUID primary key is rendered as a string; a NULL ``exploit_based``
    becomes ``False`` and a NULL ``github_repos`` becomes an empty list.
    """
    return SigmaRuleResponse(
        id=str(rule.id),
        cve_id=rule.cve_id,
        rule_name=rule.rule_name,
        rule_content=rule.rule_content,
        detection_type=rule.detection_type,
        log_source=rule.log_source,
        confidence_level=rule.confidence_level,
        auto_generated=rule.auto_generated,
        exploit_based=rule.exploit_based or False,
        github_repos=rule.github_repos or [],
        exploit_indicators=rule.exploit_indicators,
        created_at=rule.created_at,
    )


@app.get("/api/sigma-rules", response_model=List[SigmaRuleResponse])
async def get_sigma_rules(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
    """Return one page of SIGMA rules, most recently created first."""
    rules = db.query(SigmaRule).order_by(SigmaRule.created_at.desc()).offset(skip).limit(limit).all()
    return [_sigma_rule_to_response(rule) for rule in rules]


@app.get("/api/sigma-rules/{cve_id}", response_model=List[SigmaRuleResponse])
async def get_sigma_rules_by_cve(cve_id: str, db: Session = Depends(get_db)):
    """Return every SIGMA rule stored for the given CVE identifier."""
    rules = db.query(SigmaRule).filter(SigmaRule.cve_id == cve_id).all()
    return [_sigma_rule_to_response(rule) for rule in rules]
@app.post("/api/fetch-cves")
async def manual_fetch_cves(background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Queue a background CVE fetch (30-day lookback) with rule generation.

    Responds immediately; the fetch itself runs as a background task and
    reports its outcome on stdout.
    """
    async def fetch_task():
        # The request-scoped `db` dependency may already be closed when this
        # task runs, so the task owns a session of its own. `db` stays in the
        # signature only to keep the endpoint's interface unchanged.
        task_db = SessionLocal()
        try:
            service = CVESigmaService(task_db)
            print("Manual CVE fetch initiated...")
            # Use 30 days for manual fetch to get more results
            new_cves = await service.fetch_recent_cves(days_back=30)

            rules_generated = 0
            for cve in new_cves:
                sigma_rule = service.generate_sigma_rule(cve)
                if sigma_rule:
                    rules_generated += 1

            task_db.commit()
            print(f"Manual fetch complete: {len(new_cves)} CVEs, {rules_generated} rules generated")
        except Exception as e:
            # Drop any half-built rules instead of leaving them pending.
            task_db.rollback()
            print(f"Manual fetch error: {str(e)}")
            import traceback
            traceback.print_exc()
        finally:
            task_db.close()

    background_tasks.add_task(fetch_task)
    return {"message": "CVE fetch initiated (30-day lookback)", "status": "started"}
@app.get("/api/test-nvd")
async def test_nvd_connection():
    """Test endpoint to check NVD API connectivity.

    Requests five CVEs modified in the last 30 days and reports the status
    code, response headers and result counts. On a 404 it retries once
    without date filters. Any exception is reported as an error payload
    instead of being raised.
    """
    try:
        # Test with a simple request using current date
        end_date = datetime.utcnow()
        start_date = end_date - timedelta(days=30)

        url = "https://services.nvd.nist.gov/rest/json/cves/2.0/"
        params = {
            "lastModStartDate": start_date.strftime("%Y-%m-%dT%H:%M:%S.000+00:00"),
            "lastModEndDate": end_date.strftime("%Y-%m-%dT%H:%M:%S.000+00:00"),
            "resultsPerPage": 5,
            "startIndex": 0
        }
        headers = {
            "User-Agent": "CVE-SIGMA-Generator/1.0",
            "Accept": "application/json"
        }

        nvd_api_key = os.getenv("NVD_API_KEY")
        if nvd_api_key:
            headers["apiKey"] = nvd_api_key

        # Never write the API key to stdout/logs: print a redacted copy.
        printable_headers = {
            name: ("***REDACTED***" if name == "apiKey" else value)
            for name, value in headers.items()
        }
        print(f"Testing NVD API with URL: {url}")
        print(f"Test params: {params}")
        print(f"Test headers: {printable_headers}")

        response = requests.get(url, params=params, headers=headers, timeout=15)

        result = {
            "status": "success" if response.status_code == 200 else "error",
            "status_code": response.status_code,
            "has_api_key": bool(nvd_api_key),
            "request_url": f"{url}?{requests.compat.urlencode(params)}",
            "response_headers": dict(response.headers)
        }

        if response.status_code == 200:
            data = response.json()
            result.update({
                "total_results": data.get("totalResults", 0),
                "results_per_page": data.get("resultsPerPage", 0),
                "vulnerabilities_returned": len(data.get("vulnerabilities", [])),
                "message": "NVD API is accessible and returning data"
            })
        else:
            result.update({
                "error_message": response.text[:200],
                "message": f"NVD API returned {response.status_code}"
            })

            # Try fallback without date filters if we get 404
            if response.status_code == 404:
                print("Trying fallback without date filters...")
                fallback_params = {
                    "resultsPerPage": 5,
                    "startIndex": 0
                }
                fallback_response = requests.get(url, params=fallback_params, headers=headers, timeout=15)
                result["fallback_status_code"] = fallback_response.status_code

                if fallback_response.status_code == 200:
                    fallback_data = fallback_response.json()
                    result.update({
                        "fallback_success": True,
                        "fallback_total_results": fallback_data.get("totalResults", 0),
                        "message": "NVD API works without date filters"
                    })

        return result
    except Exception as e:
        print(f"NVD API test error: {str(e)}")
        return {
            "status": "error",
            "message": f"Failed to connect to NVD API: {str(e)}"
        }
@app.get("/api/stats")
async def get_stats(db: Session = Depends(get_db)):
    """Return aggregate CVE / SIGMA rule counts and two coverage percentages.

    Issues six COUNT queries through the request session. Each percentage
    falls back to 0 when its denominator is zero.
    """
    def tally(model, *conditions):
        # COUNT over ``model`` restricted by the given filter conditions.
        return db.query(model).filter(*conditions).count()

    cve_total = tally(CVE)
    rule_total = tally(SigmaRule)

    one_week_ago = datetime.utcnow() - timedelta(days=7)
    last_week_cves = tally(CVE, CVE.published_date >= one_week_ago)

    # Enhanced stats with bulk processing info
    bulk_cves = tally(CVE, CVE.bulk_processed == True)
    poc_cves = tally(CVE, CVE.poc_count > 0)
    nomi_rules = tally(SigmaRule, SigmaRule.poc_source == 'nomi_sec')

    poc_pct = 0
    if cve_total > 0:
        poc_pct = poc_cves / cve_total * 100

    nomi_pct = 0
    if rule_total > 0:
        nomi_pct = nomi_rules / rule_total * 100

    return {
        "total_cves": cve_total,
        "total_sigma_rules": rule_total,
        "recent_cves_7_days": last_week_cves,
        "bulk_processed_cves": bulk_cves,
        "cves_with_pocs": poc_cves,
        "nomi_sec_rules": nomi_rules,
        "poc_coverage": poc_pct,
        "nomi_sec_coverage": nomi_pct
    }
# New bulk processing endpoints
@app.post("/api/bulk-seed")
async def start_bulk_seed(background_tasks: BackgroundTasks,
                          request: BulkSeedRequest,
                          db: Session = Depends(get_db)):
    """Start bulk seeding process.

    Schedules ``BulkSeeder.full_bulk_seed`` as a background task and returns
    immediately with an echo of the requested options.

    The background task opens its own ``SessionLocal`` session instead of
    borrowing the request-scoped ``db``: the request session may already be
    closed by the time the task runs, and this matches how the ExploitDB and
    CISA KEV sync endpoints in this file handle background work.
    """
    async def bulk_seed_task():
        # Dedicated session for the lifetime of the background task
        task_db = SessionLocal()
        try:
            from bulk_seeder import BulkSeeder
            seeder = BulkSeeder(task_db)
            result = await seeder.full_bulk_seed(
                start_year=request.start_year,
                end_year=request.end_year,
                skip_nvd=request.skip_nvd,
                skip_nomi_sec=request.skip_nomi_sec
            )
            logger.info(f"Bulk seed completed: {result}")
        except Exception as e:
            # logger.exception records the traceback through the logging config
            logger.exception(f"Bulk seed failed: {e}")
        finally:
            task_db.close()

    background_tasks.add_task(bulk_seed_task)
    return {
        "message": "Bulk seeding process started",
        "status": "started",
        "start_year": request.start_year,
        # The response reports the current year when no end year was given
        "end_year": request.end_year or datetime.now().year,
        "skip_nvd": request.skip_nvd,
        "skip_nomi_sec": request.skip_nomi_sec
    }
@app.post("/api/incremental-update")
async def start_incremental_update(background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Start incremental update process.

    Schedules ``BulkSeeder.incremental_update`` as a background task and
    returns a "started" acknowledgement right away.

    The task uses its own ``SessionLocal`` session rather than the
    request-scoped ``db``, which may already be closed when the task runs.
    """
    async def incremental_update_task():
        # Dedicated session for the lifetime of the background task
        task_db = SessionLocal()
        try:
            from bulk_seeder import BulkSeeder
            seeder = BulkSeeder(task_db)
            result = await seeder.incremental_update()
            logger.info(f"Incremental update completed: {result}")
        except Exception as e:
            # logger.exception records the traceback through the logging config
            logger.exception(f"Incremental update failed: {e}")
        finally:
            task_db.close()

    background_tasks.add_task(incremental_update_task)
    return {
        "message": "Incremental update process started",
        "status": "started"
    }
@app.post("/api/sync-nomi-sec")
async def sync_nomi_sec(background_tasks: BackgroundTasks,
                        request: NomiSecSyncRequest,
                        db: Session = Depends(get_db)):
    """Synchronize nomi-sec PoC data.

    Persists a ``nomi_sec_sync`` BulkProcessingJob, registers it in the
    in-process ``running_jobs`` / ``job_cancellation_flags`` dicts and
    schedules the actual sync as a background task. Returns the job id so
    the caller can track or cancel the job.

    The background task works on its own ``SessionLocal`` session (same
    pattern as the ExploitDB and CISA KEV sync endpoints): the request-scoped
    ``db`` may already be closed when the task runs, in which case status
    updates made through it would not be persisted.
    """
    # Create job record
    job = BulkProcessingJob(
        job_type='nomi_sec_sync',
        status='pending',
        job_metadata={
            'cve_id': request.cve_id,
            'batch_size': request.batch_size
        }
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    # Capture the primary key now so the task never touches the request-bound object
    job_pk = job.id
    job_id = str(job_pk)
    running_jobs[job_id] = job
    job_cancellation_flags[job_id] = False

    def is_cancelled():
        # True once the cancel endpoint has flagged this job
        return job_cancellation_flags.get(job_id, False)

    async def sync_task():
        # Create a new database session for the background task
        task_db = SessionLocal()
        try:
            task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job_pk).first()
            if not task_job:
                logger.error(f"Job {job_id} not found in task session")
                return

            task_job.status = 'running'
            task_job.started_at = datetime.utcnow()
            task_db.commit()

            from nomi_sec_client import NomiSecClient
            client = NomiSecClient(task_db)

            if request.cve_id:
                # Sync specific CVE
                if is_cancelled():
                    logger.info(f"Job {job_id} cancelled before starting")
                    return
                result = await client.sync_cve_pocs(request.cve_id)
                logger.info(f"Nomi-sec sync for {request.cve_id}: {result}")
            else:
                # Sync all CVEs with cancellation support
                result = await client.bulk_sync_all_cves(
                    batch_size=request.batch_size,
                    cancellation_flag=is_cancelled
                )
                logger.info(f"Nomi-sec bulk sync completed: {result}")

            # Update job status if not cancelled
            if not is_cancelled():
                task_job.status = 'completed'
                task_job.completed_at = datetime.utcnow()
                task_db.commit()
        except Exception as e:
            logger.exception(f"Nomi-sec sync failed: {e}")
            if not is_cancelled():
                try:
                    # A failed flush/statement leaves the session unusable until rollback
                    task_db.rollback()
                    failed_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job_pk).first()
                    if failed_job:
                        failed_job.status = 'failed'
                        failed_job.error_message = str(e)
                        failed_job.completed_at = datetime.utcnow()
                        task_db.commit()
                except Exception:
                    logger.exception(f"Could not record failure for job {job_id}")
        finally:
            # Clean up tracking and close the task session
            running_jobs.pop(job_id, None)
            job_cancellation_flags.pop(job_id, None)
            task_db.close()

    background_tasks.add_task(sync_task)
    return {
        "message": f"Nomi-sec sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
        "status": "started",
        "job_id": job_id,
        "cve_id": request.cve_id,
        "batch_size": request.batch_size
    }
@app.post("/api/sync-github-pocs")
async def sync_github_pocs(background_tasks: BackgroundTasks,
                           request: GitHubPoCSyncRequest,
                           db: Session = Depends(get_db)):
    """Synchronize GitHub PoC data.

    Persists a ``github_poc_sync`` BulkProcessingJob, registers it in the
    in-process ``running_jobs`` / ``job_cancellation_flags`` dicts and
    schedules the sync as a background task. Returns the job id.

    The background task works on its own ``SessionLocal`` session (same
    pattern as the ExploitDB and CISA KEV sync endpoints): the request-scoped
    ``db`` may already be closed when the task runs, in which case status
    updates made through it would not be persisted.
    """
    # Create job record
    job = BulkProcessingJob(
        job_type='github_poc_sync',
        status='pending',
        job_metadata={
            'cve_id': request.cve_id,
            'batch_size': request.batch_size
        }
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    # Capture the primary key now so the task never touches the request-bound object
    job_pk = job.id
    job_id = str(job_pk)
    running_jobs[job_id] = job
    job_cancellation_flags[job_id] = False

    def is_cancelled():
        # True once the cancel endpoint has flagged this job
        return job_cancellation_flags.get(job_id, False)

    async def sync_task():
        # Create a new database session for the background task
        task_db = SessionLocal()
        try:
            task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job_pk).first()
            if not task_job:
                logger.error(f"Job {job_id} not found in task session")
                return

            task_job.status = 'running'
            task_job.started_at = datetime.utcnow()
            task_db.commit()

            client = GitHubPoCClient(task_db)

            if request.cve_id:
                # Sync specific CVE
                if is_cancelled():
                    logger.info(f"Job {job_id} cancelled before starting")
                    return
                result = await client.sync_cve_pocs(request.cve_id)
                logger.info(f"GitHub PoC sync for {request.cve_id}: {result}")
            else:
                # Sync all CVEs.
                # NOTE(review): no cancellation flag is passed here, so a bulk run
                # is not interrupted by a cancel request - confirm whether
                # GitHubPoCClient.bulk_sync_all_cves accepts one.
                result = await client.bulk_sync_all_cves(batch_size=request.batch_size)
                logger.info(f"GitHub PoC bulk sync completed: {result}")

            # Update job status if not cancelled
            if not is_cancelled():
                task_job.status = 'completed'
                task_job.completed_at = datetime.utcnow()
                task_db.commit()
        except Exception as e:
            logger.exception(f"GitHub PoC sync failed: {e}")
            if not is_cancelled():
                try:
                    # A failed flush/statement leaves the session unusable until rollback
                    task_db.rollback()
                    failed_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job_pk).first()
                    if failed_job:
                        failed_job.status = 'failed'
                        failed_job.error_message = str(e)
                        failed_job.completed_at = datetime.utcnow()
                        task_db.commit()
                except Exception:
                    logger.exception(f"Could not record failure for job {job_id}")
        finally:
            # Clean up tracking and close the task session
            running_jobs.pop(job_id, None)
            job_cancellation_flags.pop(job_id, None)
            task_db.close()

    background_tasks.add_task(sync_task)
    return {
        "message": f"GitHub PoC sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
        "status": "started",
        "job_id": job_id,
        "cve_id": request.cve_id,
        "batch_size": request.batch_size
    }
@app.post("/api/sync-exploitdb")
async def sync_exploitdb(background_tasks: BackgroundTasks,
                         request: ExploitDBSyncRequest,
                         db: Session = Depends(get_db)):
    """Synchronize ExploitDB data from git mirror.

    Persists an ``exploitdb_sync`` BulkProcessingJob, registers it in the
    in-process tracking dicts and schedules the sync as a background task
    that runs on its own ``SessionLocal`` session. Returns the job id.

    On failure the task rolls its session back before recording the
    ``failed`` status, so a database error during the sync does not also
    prevent the failure from being recorded.
    """
    # Create job record
    job = BulkProcessingJob(
        job_type='exploitdb_sync',
        status='pending',
        job_metadata={
            'cve_id': request.cve_id,
            'batch_size': request.batch_size
        }
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    # Capture the primary key now so the task never touches the request-bound object
    job_pk = job.id
    job_id = str(job_pk)
    running_jobs[job_id] = job
    job_cancellation_flags[job_id] = False

    def is_cancelled():
        # True once the cancel endpoint has flagged this job
        return job_cancellation_flags.get(job_id, False)

    async def sync_task():
        # Create a new database session for the background task
        task_db = SessionLocal()
        try:
            # Get the job in the new session
            task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job_pk).first()
            if not task_job:
                logger.error(f"Job {job_id} not found in task session")
                return

            task_job.status = 'running'
            task_job.started_at = datetime.utcnow()
            task_db.commit()

            from exploitdb_client_local import ExploitDBLocalClient
            client = ExploitDBLocalClient(task_db)

            if request.cve_id:
                # Sync specific CVE
                if is_cancelled():
                    logger.info(f"Job {job_id} cancelled before starting")
                    return
                result = await client.sync_cve_exploits(request.cve_id)
                logger.info(f"ExploitDB sync for {request.cve_id}: {result}")
            else:
                # Sync all CVEs with cancellation support
                result = await client.bulk_sync_exploitdb(
                    batch_size=request.batch_size,
                    cancellation_flag=is_cancelled
                )
                logger.info(f"ExploitDB bulk sync completed: {result}")

            # Update job status if not cancelled
            if not is_cancelled():
                task_job.status = 'completed'
                task_job.completed_at = datetime.utcnow()
                task_db.commit()
        except Exception as e:
            # Log first so the original error is never lost if the status write fails too
            logger.exception(f"ExploitDB sync failed: {e}")
            if not is_cancelled():
                try:
                    # A failed flush/statement leaves the session unusable until rollback
                    task_db.rollback()
                    # Get the job again in case it was modified
                    failed_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job_pk).first()
                    if failed_job:
                        failed_job.status = 'failed'
                        failed_job.error_message = str(e)
                        failed_job.completed_at = datetime.utcnow()
                        task_db.commit()
                except Exception:
                    logger.exception(f"Could not record failure for job {job_id}")
        finally:
            # Clean up tracking and close the task session
            running_jobs.pop(job_id, None)
            job_cancellation_flags.pop(job_id, None)
            task_db.close()

    background_tasks.add_task(sync_task)
    return {
        "message": f"ExploitDB sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
        "status": "started",
        "job_id": job_id,
        "cve_id": request.cve_id,
        "batch_size": request.batch_size
    }
@app.post("/api/sync-cisa-kev")
async def sync_cisa_kev(background_tasks: BackgroundTasks,
                        request: CISAKEVSyncRequest,
                        db: Session = Depends(get_db)):
    """Synchronize CISA Known Exploited Vulnerabilities data.

    Persists a ``cisa_kev_sync`` BulkProcessingJob, registers it in the
    in-process tracking dicts and schedules the sync as a background task
    that runs on its own ``SessionLocal`` session. Returns the job id.

    On failure the task rolls its session back before recording the
    ``failed`` status, so a database error during the sync does not also
    prevent the failure from being recorded.
    """
    # Create job record
    job = BulkProcessingJob(
        job_type='cisa_kev_sync',
        status='pending',
        job_metadata={
            'cve_id': request.cve_id,
            'batch_size': request.batch_size
        }
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    # Capture the primary key now so the task never touches the request-bound object
    job_pk = job.id
    job_id = str(job_pk)
    running_jobs[job_id] = job
    job_cancellation_flags[job_id] = False

    def is_cancelled():
        # True once the cancel endpoint has flagged this job
        return job_cancellation_flags.get(job_id, False)

    async def sync_task():
        # Create a new database session for the background task
        task_db = SessionLocal()
        try:
            # Get the job in the new session
            task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job_pk).first()
            if not task_job:
                logger.error(f"Job {job_id} not found in task session")
                return

            task_job.status = 'running'
            task_job.started_at = datetime.utcnow()
            task_db.commit()

            from cisa_kev_client import CISAKEVClient
            client = CISAKEVClient(task_db)

            if request.cve_id:
                # Sync specific CVE
                if is_cancelled():
                    logger.info(f"Job {job_id} cancelled before starting")
                    return
                result = await client.sync_cve_kev_data(request.cve_id)
                logger.info(f"CISA KEV sync for {request.cve_id}: {result}")
            else:
                # Sync all CVEs with cancellation support
                result = await client.bulk_sync_kev_data(
                    batch_size=request.batch_size,
                    cancellation_flag=is_cancelled
                )
                logger.info(f"CISA KEV bulk sync completed: {result}")

            # Update job status if not cancelled
            if not is_cancelled():
                task_job.status = 'completed'
                task_job.completed_at = datetime.utcnow()
                task_db.commit()
        except Exception as e:
            # Log first so the original error is never lost if the status write fails too
            logger.exception(f"CISA KEV sync failed: {e}")
            if not is_cancelled():
                try:
                    # A failed flush/statement leaves the session unusable until rollback
                    task_db.rollback()
                    # Get the job again in case it was modified
                    failed_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job_pk).first()
                    if failed_job:
                        failed_job.status = 'failed'
                        failed_job.error_message = str(e)
                        failed_job.completed_at = datetime.utcnow()
                        task_db.commit()
                except Exception:
                    logger.exception(f"Could not record failure for job {job_id}")
        finally:
            # Clean up tracking and close the task session
            running_jobs.pop(job_id, None)
            job_cancellation_flags.pop(job_id, None)
            task_db.close()

    background_tasks.add_task(sync_task)
    return {
        "message": f"CISA KEV sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
        "status": "started",
        "job_id": job_id,
        "cve_id": request.cve_id,
        "batch_size": request.batch_size
    }
@app.post("/api/sync-references")
async def sync_references(request: ReferenceSyncRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Start reference data synchronization.

    Registers an in-memory job entry in ``running_jobs`` (a plain dict; no
    BulkProcessingJob row is created here) and schedules the sync as a
    background task. The entry is updated with ``result`` / ``error``,
    ``status`` and ``completed_at`` when the task ends and is left in
    ``running_jobs`` afterwards.

    The background task builds its ``ReferenceClient`` on its own
    ``SessionLocal`` session because the request-scoped ``db`` may already be
    closed when the task runs.

    Raises:
        HTTPException: 500 if the job could not be set up.
    """
    try:
        # Imported here so an import failure is reported to the caller as a 500
        from reference_client import ReferenceClient

        # Create job ID
        job_id = str(uuid.uuid4())

        # Add job to tracking
        job_state = {
            'type': 'reference_sync',
            'status': 'running',
            'cve_id': request.cve_id,
            'batch_size': request.batch_size,
            'max_cves': request.max_cves,
            'force_resync': request.force_resync,
            'started_at': datetime.utcnow()
        }
        running_jobs[job_id] = job_state

        # Create cancellation flag
        job_cancellation_flags[job_id] = False

        async def sync_task():
            # Dedicated session for the lifetime of the background task
            task_db = SessionLocal()
            try:
                client = ReferenceClient(task_db)
                if request.cve_id:
                    # Single CVE sync
                    result = await client.sync_cve_references(request.cve_id)
                else:
                    # Bulk sync
                    result = await client.bulk_sync_references(
                        batch_size=request.batch_size,
                        max_cves=request.max_cves,
                        force_resync=request.force_resync,
                        cancellation_flag=lambda: job_cancellation_flags.get(job_id, False)
                    )
                job_state['result'] = result
                job_state['status'] = 'completed'
                job_state['completed_at'] = datetime.utcnow()
            except Exception as e:
                logger.exception(f"Reference sync task failed: {e}")
                job_state['status'] = 'failed'
                job_state['error'] = str(e)
                job_state['completed_at'] = datetime.utcnow()
            finally:
                # Clean up cancellation flag and release the task session
                job_cancellation_flags.pop(job_id, None)
                task_db.close()

        background_tasks.add_task(sync_task)
        return {
            "message": f"Reference sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
            "status": "started",
            "job_id": job_id,
            "cve_id": request.cve_id,
            "batch_size": request.batch_size,
            "max_cves": request.max_cves,
            "force_resync": request.force_resync
        }
    except Exception as e:
        logger.error(f"Failed to start reference sync: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to start reference sync: {str(e)}")
@app.get("/api/reference-stats")
async def get_reference_stats(db: Session = Depends(get_db)):
    """Get reference synchronization statistics.

    Combines the sync status reported by ``ReferenceClient`` with two
    histograms computed over every CVE whose ``reference_data`` contains a
    ``reference_analysis`` entry: a quality tier (share of high-confidence
    references) and the reference types seen.

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        from reference_client import ReferenceClient
        from sqlalchemy import text

        # Get sync status
        sync_status = await ReferenceClient(db).get_reference_sync_status()

        analysed_cves = db.query(CVE).filter(
            text("reference_data::text LIKE '%\"reference_analysis\"%'")
        ).all()

        # Lower bound of each tier, best first; anything below the last is 'poor'
        tier_floors = ((0.8, 'excellent'), (0.6, 'good'), (0.4, 'fair'))

        tier_counts = {}
        type_counts = {}
        for record in analysed_cves:
            payload = record.reference_data
            if not payload or 'reference_analysis' not in payload:
                continue
            analysis = payload['reference_analysis']

            # Quality tier: only counted when the CVE has at least one reference
            ref_total = analysis.get('reference_count', 0)
            if ref_total > 0:
                ratio = analysis.get('high_confidence_references', 0) / ref_total
                tier = next(
                    (label for floor, label in tier_floors if ratio >= floor),
                    'poor'
                )
                tier_counts[tier] = tier_counts.get(tier, 0) + 1

            # Reference type histogram
            for kind in analysis.get('reference_types', []):
                type_counts[kind] = type_counts.get(kind, 0) + 1

        return {
            'reference_sync_status': sync_status,
            'quality_distribution': tier_counts,
            'reference_type_distribution': type_counts,
            'total_with_reference_analysis': len(analysed_cves),
            'source': 'reference_extraction'
        }
    except Exception as e:
        logger.error(f"Failed to get reference stats: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to get reference stats: {str(e)}")
@app.get("/api/exploitdb-stats")
async def get_exploitdb_stats(db: Session = Depends(get_db)):
    """Get ExploitDB-related statistics.

    Returns the sync status from ``ExploitDBLocalClient`` together with
    per-exploit quality-tier and category histograms and overall totals,
    computed over every CVE whose ``poc_data`` contains an ``exploitdb`` key.

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        from exploitdb_client_local import ExploitDBLocalClient
        from sqlalchemy import text

        # Get sync status
        sync_status = await ExploitDBLocalClient(db).get_exploitdb_sync_status()

        exploitdb_cves = db.query(CVE).filter(
            text("poc_data::text LIKE '%\"exploitdb\"%'")
        ).all()

        tier_counts = {}
        category_counts = {}
        exploit_total = 0
        for record in exploitdb_cves:
            poc_payload = record.poc_data
            if not poc_payload or 'exploitdb' not in poc_payload:
                continue
            entries = poc_payload['exploitdb'].get('exploits', [])
            exploit_total += len(entries)
            for entry in entries:
                tier = entry.get('quality_analysis', {}).get('quality_tier', 'unknown')
                tier_counts[tier] = tier_counts.get(tier, 0) + 1

                category = entry.get('category', 'unknown')
                category_counts[category] = category_counts.get(category, 0) + 1

        return {
            "exploitdb_sync_status": sync_status,
            "quality_distribution": tier_counts,
            "category_distribution": category_counts,
            "total_exploitdb_cves": len(exploitdb_cves),
            "total_exploits": exploit_total
        }
    except Exception as e:
        logger.error(f"Error getting ExploitDB stats: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/github-poc-stats")
async def get_github_poc_stats(db: Session = Depends(get_db)):
    """Get GitHub PoC-related statistics.

    Returns the number of SIGMA rules sourced from GitHub PoCs, the number of
    CVEs whose first ``poc_data`` entry has source ``github_poc``, the quality
    tier distribution and the average quality score of those entries.

    The two optional aggregates (distribution, average) degrade to empty / 0
    on error. After such an error the session is rolled back, because a
    failed statement leaves the surrounding transaction aborted and every
    later query on the same session would fail as well.

    On an unexpected error the endpoint returns ``{"error": ...}`` rather
    than raising.
    """
    try:
        # Get basic statistics
        github_poc_rules = db.query(SigmaRule).filter(SigmaRule.poc_source == 'github_poc').count()

        # Shared criteria: poc_data exists and its first entry came from github_poc
        github_poc_criteria = (
            CVE.poc_data.isnot(None),
            func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc',
        )
        cves_with_github_pocs = db.query(CVE).filter(*github_poc_criteria).count()

        # Get quality distribution
        quality_distribution = {}
        try:
            quality_results = db.query(
                func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_tier').label('tier'),
                func.count().label('count')
            ).filter(*github_poc_criteria).group_by('tier').all()
            for tier, count in quality_results:
                if tier:
                    quality_distribution[tier] = count
        except Exception as e:
            logger.warning(f"Error getting quality distribution: {e}")
            db.rollback()  # clear the aborted transaction so later queries can run
            quality_distribution = {}

        # Calculate average quality score
        try:
            avg_quality = db.query(
                func.avg(func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_score').cast(Integer))
            ).filter(*github_poc_criteria).scalar() or 0
        except Exception as e:
            logger.warning(f"Error calculating average quality: {e}")
            db.rollback()  # leave the session usable for the rest of the request
            avg_quality = 0

        return {
            'github_poc_rules': github_poc_rules,
            'cves_with_github_pocs': cves_with_github_pocs,
            'quality_distribution': quality_distribution,
            'average_quality_score': float(avg_quality) if avg_quality else 0,
            'source': 'github_poc'
        }
    except Exception as e:
        logger.error(f"Error getting GitHub PoC stats: {e}")
        return {"error": str(e)}
@app.get("/api/github-poc-status")
async def get_github_poc_status(db: Session = Depends(get_db)):
    """Get GitHub PoC data availability status.

    Loads the PoC data through ``GitHubPoCClient`` and reports whether any
    was found, how many CVEs it covers, up to ten sample CVE ids, and the
    data path together with whether that path exists. On error the endpoint
    returns ``{"error": ...}`` rather than raising.
    """
    try:
        poc_client = GitHubPoCClient(db)

        # Check if GitHub PoC data is available
        poc_index = poc_client.load_github_poc_data()
        entry_count = len(poc_index)
        sample_ids = list(poc_index.keys())[:10]  # First 10 CVE IDs

        return {
            'github_poc_data_available': entry_count > 0,
            'total_cves_with_pocs': entry_count,
            'sample_cve_ids': sample_ids,
            'data_path': str(poc_client.github_poc_path),
            'path_exists': poc_client.github_poc_path.exists()
        }
    except Exception as e:
        logger.error(f"Error checking GitHub PoC status: {e}")
        return {"error": str(e)}
@app.get("/api/cisa-kev-stats")
async def get_cisa_kev_stats(db: Session = Depends(get_db)):
    """Get CISA KEV-related statistics.

    Returns the sync status from ``CISAKEVClient`` plus, over every CVE whose
    ``poc_data`` contains a ``cisa_kev`` key: threat-level and category
    histograms, ransomware usage counts and the average non-zero threat score.

    All aggregates are gathered in a single pass over the matching CVEs. A
    missing or null ``known_ransomware_use`` value is counted as unknown
    instead of failing the whole request.

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        from cisa_kev_client import CISAKEVClient
        from sqlalchemy import text

        client = CISAKEVClient(db)

        # Get sync status
        status = await client.get_kev_sync_status()

        cves_with_kev = db.query(CVE).filter(
            text("poc_data::text LIKE '%\"cisa_kev\"%'")
        ).all()

        threat_level_distribution = {}
        category_distribution = {}
        ransomware_stats = {'known': 0, 'unknown': 0}
        threat_scores = []

        for cve in cves_with_kev:
            if not (cve.poc_data and 'cisa_kev' in cve.poc_data):
                continue
            vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})

            # Threat level distribution
            threat_level = vuln_data.get('threat_level', 'unknown')
            threat_level_distribution[threat_level] = threat_level_distribution.get(threat_level, 0) + 1

            # Vulnerability category distribution
            category = vuln_data.get('vulnerability_category', 'unknown')
            category_distribution[category] = category_distribution.get(category, 0) + 1

            # Ransomware usage; `or` covers a key that is present but null
            ransomware_use = str(vuln_data.get('known_ransomware_use') or 'Unknown').lower()
            if ransomware_use == 'known':
                ransomware_stats['known'] += 1
            else:
                ransomware_stats['unknown'] += 1

            # Only truthy (non-zero, non-null) scores take part in the average
            threat_score = vuln_data.get('threat_score', 0)
            if threat_score:
                threat_scores.append(threat_score)

        avg_threat_score = sum(threat_scores) / len(threat_scores) if threat_scores else 0

        return {
            "cisa_kev_sync_status": status,
            "threat_level_distribution": threat_level_distribution,
            "category_distribution": category_distribution,
            "ransomware_stats": ransomware_stats,
            "average_threat_score": round(avg_threat_score, 2),
            "total_kev_cves": len(cves_with_kev),
            "total_with_threat_scores": len(threat_scores)
        }
    except Exception as e:
        logger.error(f"Error getting CISA KEV stats: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/bulk-jobs")
async def get_bulk_jobs(limit: int = 10, db: Session = Depends(get_db)):
    """Get bulk processing job status.

    Returns up to ``limit`` BulkProcessingJob rows, newest first (ordered by
    ``created_at`` descending), each flattened into a plain dict.
    """
    newest_first = BulkProcessingJob.created_at.desc()
    recent_jobs = (
        db.query(BulkProcessingJob)
        .order_by(newest_first)
        .limit(limit)
        .all()
    )

    return [
        {
            'id': str(entry.id),
            'job_type': entry.job_type,
            'status': entry.status,
            'year': entry.year,
            'total_items': entry.total_items,
            'processed_items': entry.processed_items,
            'failed_items': entry.failed_items,
            'error_message': entry.error_message,
            # The model attribute is job_metadata; the API exposes it as 'metadata'
            'metadata': entry.job_metadata,
            'started_at': entry.started_at,
            'completed_at': entry.completed_at,
            'created_at': entry.created_at
        }
        for entry in recent_jobs
    ]
@app.get("/api/bulk-status")
async def get_bulk_status(db: Session = Depends(get_db)):
    """Get comprehensive bulk processing status.

    Delegates to ``BulkSeeder.get_seeding_status`` and returns its result
    unchanged. On error the endpoint returns ``{"error": ...}`` rather than
    raising.
    """
    try:
        from bulk_seeder import BulkSeeder
        return await BulkSeeder(db).get_seeding_status()
    except Exception as exc:
        logger.error(f"Error getting bulk status: {exc}")
        return {"error": str(exc)}
@app.get("/api/poc-stats")
async def get_poc_stats(db: Session = Depends(get_db)):
    """Get PoC-related statistics.

    Starts from the sync status reported by ``NomiSecClient`` and adds two
    figures: the number of CVEs with PoCs whose first ``poc_data`` entry has
    a quality score above 60, and the average ``poc_count`` over CVEs that
    have at least one PoC. On error the endpoint returns ``{"error": ...}``
    rather than raising.
    """
    try:
        from nomi_sec_client import NomiSecClient

        poc_stats = await NomiSecClient(db).get_sync_status()

        # Additional PoC statistics
        first_poc_score = func.json_extract_path_text(
            CVE.poc_data, '0', 'quality_analysis', 'quality_score'
        ).cast(Integer)
        high_quality_total = db.query(CVE).filter(
            CVE.poc_count > 0,
            first_poc_score > 60
        ).count()

        mean_poc_count = db.query(func.avg(CVE.poc_count)).filter(CVE.poc_count > 0).scalar() or 0

        poc_stats.update({
            'high_quality_cves': high_quality_total,
            'avg_poc_count': mean_poc_count
        })
        return poc_stats
    except Exception as exc:
        logger.error(f"Error getting PoC stats: {exc}")
        return {"error": str(exc)}
@app.get("/api/cve2capec-stats")
async def get_cve2capec_stats():
    """Get CVE2CAPEC MITRE ATT&CK mapping statistics.

    Wraps ``CVE2CAPECClient.get_stats()`` in a success envelope. On error the
    endpoint returns ``{"error": ...}`` rather than raising.
    """
    try:
        mapping_stats = CVE2CAPECClient().get_stats()
    except Exception as exc:
        logger.error(f"Error getting CVE2CAPEC stats: {exc}")
        return {"error": str(exc)}

    return {
        "status": "success",
        "data": mapping_stats,
        "description": "CVE to MITRE ATT&CK technique mappings from CVE2CAPEC repository"
    }
@app.post("/api/regenerate-rules")
async def regenerate_sigma_rules(background_tasks: BackgroundTasks,
                                 request: RuleRegenRequest,
                                 db: Session = Depends(get_db)):
    """Regenerate SIGMA rules using enhanced nomi-sec data.

    Schedules a background task that walks every CVE with ``poc_count > 0``
    and regenerates its rule via ``EnhancedSigmaGenerator``. CVEs whose
    existing rule already has ``poc_source == 'nomi_sec'`` are skipped unless
    ``request.force`` is set.

    The task uses its own ``SessionLocal`` session rather than the
    request-scoped ``db``, which may already be closed when the task runs.
    """
    async def regenerate_task():
        # Dedicated session for the lifetime of the background task
        task_db = SessionLocal()
        try:
            from enhanced_sigma_generator import EnhancedSigmaGenerator
            generator = EnhancedSigmaGenerator(task_db)

            # Get CVEs with PoC data
            cves_with_pocs = task_db.query(CVE).filter(CVE.poc_count > 0).all()

            rules_generated = 0
            rules_updated = 0
            for cve in cves_with_pocs:
                # Check if we should regenerate
                existing_rule = task_db.query(SigmaRule).filter(
                    SigmaRule.cve_id == cve.cve_id
                ).first()
                if existing_rule and existing_rule.poc_source == 'nomi_sec' and not request.force:
                    continue

                # Generate enhanced rule
                result = await generator.generate_enhanced_rule(cve)
                # .get(): a result without a 'success' key counts as a failure
                # instead of aborting the whole run with KeyError
                if result.get('success'):
                    if existing_rule:
                        rules_updated += 1
                    else:
                        rules_generated += 1

            logger.info(f"Rule regeneration completed: {rules_generated} new, {rules_updated} updated")
        except Exception as e:
            # logger.exception records the traceback through the logging config
            logger.exception(f"Rule regeneration failed: {e}")
        finally:
            task_db.close()

    background_tasks.add_task(regenerate_task)
    return {
        "message": "SIGMA rule regeneration started",
        "status": "started",
        "force": request.force
    }
@app.post("/api/llm-enhanced-rules")
async def generate_llm_enhanced_rules(request: dict, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Generate SIGMA rules using LLM API for enhanced analysis.

    Request body keys (all optional):
        cve_id: restrict processing to one CVE (``CVE-YYYY-NNNN...`` format).
        force: also reprocess CVEs that already have an ``llm_*`` rule.
        provider / model: LLM selection; default to the ``LLM_PROVIDER`` /
            ``LLM_MODEL`` environment variables.

    The work runs in a background task on its own ``SessionLocal`` session,
    because the request-scoped ``db`` may already be closed when the task
    runs.

    Raises:
        HTTPException: 400 if ``cve_id`` is given but is not a well-formed
            CVE identifier (including non-string values).
    """
    # Parse request parameters
    cve_id = request.get('cve_id')
    force = request.get('force', False)
    llm_provider = request.get('provider', os.getenv('LLM_PROVIDER'))
    llm_model = request.get('model', os.getenv('LLM_MODEL'))

    # Validation; the isinstance guard keeps a non-string id from raising in re.match
    if cve_id and (not isinstance(cve_id, str) or not re.match(r'^CVE-\d{4}-\d{4,}$', cve_id)):
        raise HTTPException(status_code=400, detail="Invalid CVE ID format")

    async def llm_generation_task():
        """Background task for LLM-enhanced rule generation"""
        # Dedicated session for the lifetime of the background task
        task_db = SessionLocal()
        try:
            from enhanced_sigma_generator import EnhancedSigmaGenerator
            generator = EnhancedSigmaGenerator(task_db, llm_provider, llm_model)

            # Process specific CVE or all CVEs with PoC data
            if cve_id:
                cve = task_db.query(CVE).filter(CVE.cve_id == cve_id).first()
                if not cve:
                    logger.error(f"CVE {cve_id} not found")
                    return
                cves_to_process = [cve]
            else:
                # Process CVEs with PoC data that either have no rules or force update
                candidates = task_db.query(CVE).filter(CVE.poc_count > 0).all()
                if force:
                    cves_to_process = candidates
                else:
                    # Only process CVEs without existing LLM-generated rules;
                    # fetch just the cve_id column rather than whole rule rows
                    existing_cve_ids = {
                        row[0]
                        for row in task_db.query(SigmaRule.cve_id).filter(
                            SigmaRule.detection_type.like('llm_%')
                        ).all()
                    }
                    cves_to_process = [c for c in candidates if c.cve_id not in existing_cve_ids]

            logger.info(f"Processing {len(cves_to_process)} CVEs for LLM-enhanced rule generation using {llm_provider}")

            rules_generated = 0
            rules_updated = 0
            failures = 0
            for cve in cves_to_process:
                try:
                    # Check if CVE has sufficient PoC data
                    if not cve.poc_data or not cve.poc_count:
                        logger.debug(f"Skipping {cve.cve_id} - no PoC data")
                        continue

                    # Generate LLM-enhanced rule
                    result = await generator.generate_enhanced_rule(cve, use_llm=True)
                    if result.get('success'):
                        if result.get('updated'):
                            rules_updated += 1
                        else:
                            rules_generated += 1
                        logger.info(f"Successfully generated LLM-enhanced rule for {cve.cve_id}")
                    else:
                        failures += 1
                        logger.warning(f"Failed to generate LLM-enhanced rule for {cve.cve_id}: {result.get('error')}")
                except Exception as e:
                    # One bad CVE must not stop the rest of the batch
                    failures += 1
                    logger.error(f"Error generating LLM-enhanced rule for {cve.cve_id}: {e}")
                    continue

            logger.info(f"LLM-enhanced rule generation completed: {rules_generated} new, {rules_updated} updated, {failures} failures")
        except Exception as e:
            # logger.exception records the traceback through the logging config
            logger.exception(f"LLM-enhanced rule generation failed: {e}")
        finally:
            task_db.close()

    background_tasks.add_task(llm_generation_task)
    return {
        "message": "LLM-enhanced SIGMA rule generation started",
        "status": "started",
        "cve_id": cve_id,
        "force": force,
        "provider": llm_provider,
        "model": llm_model,
        "note": "Requires appropriate LLM API key to be set"
    }
@app.get("/api/llm-status")
async def get_llm_status():
    """Check LLM API availability status.

    Builds an ``LLMClient`` from the ``LLM_PROVIDER`` / ``LLM_MODEL``
    environment variables and reports its provider info, the list of
    available providers and a ``ready`` / ``unavailable`` status. Any failure
    is reported in the response body with status ``error`` rather than
    raised.
    """
    try:
        from llm_client import LLMClient

        # Get current provider configuration
        llm = LLMClient(
            provider=os.getenv('LLM_PROVIDER'),
            model=os.getenv('LLM_MODEL')
        )
        current = llm.get_provider_info()

        # Get all available providers
        providers = LLMClient.get_available_providers()

        if llm.is_available():
            availability = "ready"
        else:
            availability = "unavailable"

        return {
            "current_provider": current,
            "available_providers": providers,
            "status": availability
        }
    except Exception as exc:
        logger.error(f"Error checking LLM status: {exc}")
        return {
            "current_provider": {"provider": "unknown", "available": False},
            "available_providers": [],
            "status": "error",
            "error": str(exc)
        }
@app.post("/api/llm-switch")
async def switch_llm_provider(request: dict):
    """Switch LLM provider and model.

    Validates the requested provider, checks that a client built for it
    reports itself available, then stores the choice in the ``LLM_PROVIDER``
    (and, when a model is given, ``LLM_MODEL``) environment variables of the
    running process.

    Raises:
        HTTPException: 400 when the provider is missing, unsupported or not
            available; 500 on any other failure.
    """
    try:
        from llm_client import LLMClient

        new_provider = request.get('provider')
        new_model = request.get('model')

        if not new_provider:
            raise HTTPException(status_code=400, detail="Provider is required")

        # Validate provider
        supported = new_provider in LLMClient.SUPPORTED_PROVIDERS
        if not supported:
            raise HTTPException(status_code=400, detail=f"Unsupported provider: {new_provider}")

        # Test the new configuration
        candidate = LLMClient(provider=new_provider, model=new_model)
        if not candidate.is_available():
            raise HTTPException(status_code=400, detail=f"Provider {new_provider} is not available or not configured")

        # Update environment variables (note: this only affects the current session)
        os.environ['LLM_PROVIDER'] = new_provider
        if new_model:
            os.environ['LLM_MODEL'] = new_model

        return {
            "message": f"Switched to {new_provider}",
            "provider_info": candidate.get_provider_info(),
            "status": "success"
        }
    except HTTPException:
        # Let the deliberate 4xx responses through untouched
        raise
    except Exception as exc:
        logger.error(f"Error switching LLM provider: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
@app.post("/api/cancel-job/{job_id}")
async def cancel_job(job_id: str, db: Session = Depends(get_db)):
    """Cancel a running job.

    Looks the job up in the database, raises the in-process cancellation
    flag that the sync tasks poll, and marks the row as ``cancelled``.

    Raises:
        HTTPException: 400 if ``job_id`` is not a valid UUID or the job is
            not in a cancellable state (``pending`` / ``running``); 404 if no
            such job exists; 500 on any other failure.
    """
    # Reject malformed ids up front instead of letting the database fail on them
    try:
        job_key = str(uuid.UUID(job_id))
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid job ID format")

    try:
        # Find the job in the database (canonical form of the id)
        job = db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job_key).first()
        if not job:
            raise HTTPException(status_code=404, detail="Job not found")
        if job.status not in ['pending', 'running']:
            raise HTTPException(status_code=400, detail=f"Cannot cancel job with status: {job.status}")

        # Set cancellation flag under the same key the sync tasks register
        # (str(job.id)), so differently formatted ids still reach the task
        job_cancellation_flags[str(job.id)] = True

        # Update job status
        job.status = 'cancelled'
        job.cancelled_at = datetime.utcnow()
        job.error_message = "Job cancelled by user"
        db.commit()

        logger.info(f"Job {job_id} cancellation requested")
        return {
            "message": f"Job {job_id} cancellation requested",
            "status": "cancelled",
            "job_id": job_id
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error cancelling job {job_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/running-jobs")
async def get_running_jobs(db: Session = Depends(get_db)):
    """Get all currently running jobs.

    Returns every BulkProcessingJob whose status is ``pending`` or
    ``running``, newest first, each flattened into a plain dict.

    Raises:
        HTTPException: 500 on any failure.
    """
    active_states = ['pending', 'running']

    def describe(entry):
        # Flatten one job row into the response shape.
        return {
            'id': str(entry.id),
            'job_type': entry.job_type,
            'status': entry.status,
            'year': entry.year,
            'total_items': entry.total_items,
            'processed_items': entry.processed_items,
            'failed_items': entry.failed_items,
            'error_message': entry.error_message,
            'started_at': entry.started_at,
            'created_at': entry.created_at,
            'can_cancel': entry.status in active_states
        }

    try:
        active_jobs = (
            db.query(BulkProcessingJob)
            .filter(BulkProcessingJob.status.in_(active_states))
            .order_by(BulkProcessingJob.created_at.desc())
            .all()
        )
        return [describe(entry) for entry in active_jobs]
    except Exception as exc:
        logger.error(f"Error getting running jobs: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
@app.post("/api/ollama-pull-model")
async def pull_ollama_model(request: dict, background_tasks: BackgroundTasks):
    """Pull an Ollama model.

    Schedules the pull as a background task and returns immediately. The
    Ollama base URL comes from ``OLLAMA_BASE_URL`` (default
    ``http://localhost:11434``).

    Raises:
        HTTPException: 400 if the request has no ``model``; 500 on any other
            failure while scheduling the pull.
    """
    try:
        from llm_client import LLMClient

        model = request.get('model')
        if not model:
            raise HTTPException(status_code=400, detail="Model name is required")

        # Create a background task to pull the model
        def pull_model_task():
            try:
                client = LLMClient(provider='ollama', model=model)
                base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
                if client._pull_ollama_model(base_url, model):
                    logger.info(f"Successfully pulled Ollama model: {model}")
                else:
                    logger.error(f"Failed to pull Ollama model: {model}")
            except Exception as e:
                logger.error(f"Error in model pull task: {e}")

        background_tasks.add_task(pull_model_task)
        return {
            "message": f"Started pulling model {model}",
            "status": "started",
            "model": model
        }
    except HTTPException:
        # Without this the 400 above would be caught below and turned into a 500
        raise
    except Exception as e:
        logger.error(f"Error starting model pull: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/ollama-models")
async def get_ollama_models():
    """Report the models returned by the Ollama LLM client.

    Returns:
        A dict with the model collection, its length, and a fixed
        ``"success"`` status.

    Raises:
        HTTPException: 500 carrying the error text if the client cannot be
            created or the model lookup fails.
    """
    try:
        from llm_client import LLMClient
        models = LLMClient(provider='ollama')._get_ollama_available_models()
        model_count = len(models)
        return {
            "available_models": models,
            "total_models": model_count,
            "status": "success"
        }
    except Exception as e:
        logger.error(f"Error getting Ollama models: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# ============================================================================
# SCHEDULER ENDPOINTS
# ============================================================================
class SchedulerControlRequest(BaseModel):
    """Request body for the scheduler start/stop/restart endpoint."""

    # One of 'start', 'stop', 'restart'
    action: str
class JobControlRequest(BaseModel):
    """Request body for enabling, disabling or triggering a single job."""

    # Name of the scheduled job to act on
    job_name: str
    # One of 'enable', 'disable', 'trigger'
    action: str
class UpdateScheduleRequest(BaseModel):
    """Request body for changing the schedule of a single job."""

    # Name of the scheduled job to update
    job_name: str
    # Cron expression
    schedule: str
@app.get("/api/scheduler/status")
async def get_scheduler_status():
    """Return the scheduler's job status together with a UTC timestamp.

    Returns:
        A dict holding the value of ``scheduler.get_job_status()`` and the
        current ``datetime.utcnow()`` in ISO format.

    Raises:
        HTTPException: 500 carrying the error text on any failure.
    """
    try:
        from job_scheduler import get_scheduler
        job_overview = get_scheduler().get_job_status()
        now_iso = datetime.utcnow().isoformat()
        return {"scheduler_status": job_overview, "timestamp": now_iso}
    except Exception as e:
        logger.error(f"Error getting scheduler status: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/scheduler/control")
async def control_scheduler(request: SchedulerControlRequest):
    """Start, stop or restart the job scheduler.

    Args:
        request: Body whose ``action`` is 'start', 'stop' or 'restart'.

    Returns:
        A dict with a human-readable message, the action that was applied
        and a UTC ISO timestamp.

    Raises:
        HTTPException: 400 when ``action`` is not one of the supported
            values; 500 on any unexpected error from the scheduler.
    """
    try:
        from job_scheduler import get_scheduler
        scheduler = get_scheduler()
        if request.action == 'start':
            scheduler.start()
            message = "Scheduler started"
        elif request.action == 'stop':
            scheduler.stop()
            message = "Scheduler stopped"
        elif request.action == 'restart':
            # A restart is simply a stop followed by a start.
            scheduler.stop()
            scheduler.start()
            message = "Scheduler restarted"
        else:
            raise HTTPException(status_code=400, detail=f"Invalid action: {request.action}")
        return {
            "message": message,
            "action": request.action,
            "timestamp": datetime.utcnow().isoformat()
        }
    except HTTPException:
        # Keep the 400 for an invalid action; without this clause the generic
        # handler below would turn it into a 500.
        raise
    except Exception as e:
        logger.error(f"Error controlling scheduler: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/scheduler/job/control")
async def control_job(request: JobControlRequest):
    """Enable, disable or trigger a single scheduled job.

    Args:
        request: Body naming the job and the action ('enable', 'disable'
            or 'trigger') to apply to it.

    Returns:
        A dict with a message, the job name, the action, the scheduler's
        ``success`` result and a UTC ISO timestamp. A falsy result from the
        scheduler is reported in the body, not as an HTTP error.

    Raises:
        HTTPException: 400 when ``action`` is not one of the supported
            values; 500 on any unexpected error from the scheduler.
    """
    try:
        from job_scheduler import get_scheduler
        scheduler = get_scheduler()
        if request.action == 'enable':
            success = scheduler.enable_job(request.job_name)
            message = f"Job {request.job_name} enabled" if success else f"Job {request.job_name} not found"
        elif request.action == 'disable':
            success = scheduler.disable_job(request.job_name)
            message = f"Job {request.job_name} disabled" if success else f"Job {request.job_name} not found"
        elif request.action == 'trigger':
            success = scheduler.trigger_job(request.job_name)
            message = f"Job {request.job_name} triggered" if success else f"Failed to trigger job {request.job_name}"
        else:
            raise HTTPException(status_code=400, detail=f"Invalid action: {request.action}")
        return {
            "message": message,
            "job_name": request.job_name,
            "action": request.action,
            "success": success,
            "timestamp": datetime.utcnow().isoformat()
        }
    except HTTPException:
        # Keep the 400 for an invalid action; without this clause the generic
        # handler below would turn it into a 500.
        raise
    except Exception as e:
        logger.error(f"Error controlling job: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/scheduler/job/schedule")
async def update_job_schedule(request: UpdateScheduleRequest):
    """Change the schedule of a single job.

    Args:
        request: Body naming the job and its new schedule string.

    Returns:
        A dict confirming the update, including the ``next_run`` value read
        from the job's refreshed status and a UTC ISO timestamp.

    Raises:
        HTTPException: 400 when the scheduler reports that the update
            failed; 500 on any unexpected error.
    """
    try:
        from job_scheduler import get_scheduler
        scheduler = get_scheduler()
        success = scheduler.update_job_schedule(request.job_name, request.schedule)
        if not success:
            raise HTTPException(status_code=400, detail=f"Failed to update schedule for job {request.job_name}")
        # Re-read the job so the response reflects the updated schedule.
        job_status = scheduler.get_job_status(request.job_name)
        return {
            "message": f"Schedule updated for job {request.job_name}",
            "job_name": request.job_name,
            "new_schedule": request.schedule,
            "next_run": job_status.get("next_run"),
            "success": True,
            "timestamp": datetime.utcnow().isoformat()
        }
    except HTTPException:
        # Keep the 400 raised above; without this clause the generic handler
        # below would turn it into a 500.
        raise
    except Exception as e:
        logger.error(f"Error updating job schedule: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/scheduler/job/{job_name}")
async def get_job_status(job_name: str):
    """Return the scheduler's status entry for one job.

    Args:
        job_name: Name of the job, taken from the URL path.

    Returns:
        A dict with the job's status and a UTC ISO timestamp.

    Raises:
        HTTPException: 404 when the scheduler's status contains an
            ``"error"`` key (its value becomes the detail); 500 on any
            unexpected error.
    """
    try:
        from job_scheduler import get_scheduler
        job_info = get_scheduler().get_job_status(job_name)
        if "error" in job_info:
            raise HTTPException(status_code=404, detail=job_info["error"])
        now_iso = datetime.utcnow().isoformat()
        return {"job_status": job_info, "timestamp": now_iso}
    except HTTPException:
        # The 404 above is intentional; pass it through untouched.
        raise
    except Exception as e:
        logger.error(f"Error getting job status: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/scheduler/reload")
async def reload_scheduler_config():
    """Ask the scheduler to reload its configuration.

    Returns:
        A dict with a success message, ``success: True`` and a UTC ISO
        timestamp.

    Raises:
        HTTPException: 500 with detail "Failed to reload configuration"
            when ``scheduler.reload_config()`` returns a falsy value; 500
            carrying the error text on any unexpected error.
    """
    try:
        from job_scheduler import get_scheduler
        scheduler = get_scheduler()
        if not scheduler.reload_config():
            raise HTTPException(status_code=500, detail="Failed to reload configuration")
        return {
            "message": "Scheduler configuration reloaded successfully",
            "success": True,
            "timestamp": datetime.utcnow().isoformat()
        }
    except HTTPException:
        # Pass the deliberate error through so its detail text is not
        # replaced by str(e) and it is not logged as an unexpected failure.
        raise
    except Exception as e:
        logger.error(f"Error reloading scheduler config: {e}")
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")