# auto_sigma_rule_generator/backend/main.py
# 2025-07-08 17:50:01 -05:00
#
# 1319 lines
# 52 KiB
# Python

from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from sqlalchemy import create_engine, Column, String, Text, DECIMAL, TIMESTAMP, Boolean, ARRAY, Integer, JSON, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.dialects.postgresql import UUID
import uuid
from datetime import datetime, timedelta
import requests
import json
import re
import os
from typing import List, Optional
from pydantic import BaseModel
import asyncio
from contextlib import asynccontextmanager
import base64
from github import Github
from urllib.parse import urlparse
import hashlib
import logging
import threading
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global job tracking. These are plain module-level dicts, so they only describe
# jobs started by this Python process.
# Maps a job id string (str(BulkProcessingJob.id)) -> the BulkProcessingJob row object.
running_jobs = {}
# Maps a job id string -> bool; a True value asks the running job to stop
# (the nomi-sec sync task checks it before / while working).
job_cancellation_flags = {}
# Database setup
# NOTE(review): the fallback URL embeds development credentials; presumably only
# meant for local use -- make sure DATABASE_URL is set in real deployments.
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://cve_user:cve_password@localhost:5432/cve_sigma_db")
engine = create_engine(DATABASE_URL)
# Session factory used by get_db() and the background fetch loop. autoflush is
# off, so objects added to a session are not flushed before later queries run.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Declarative base shared by all ORM models defined below.
Base = declarative_base()
# Database Models
class CVE(Base):
    """ORM row for one CVE record (table ``cves``)."""

    __tablename__ = "cves"
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    cve_id = Column(String(20), unique=True, nullable=False)  # e.g. "CVE-2024-1234"
    description = Column(Text)
    cvss_score = Column(DECIMAL(3, 1))  # CVSS base score, one decimal place
    severity = Column(String(20))
    published_date = Column(TIMESTAMP)
    modified_date = Column(TIMESTAMP)
    affected_products = Column(ARRAY(String))  # CPE match criteria strings
    reference_urls = Column(ARRAY(String))
    # Bulk processing fields
    data_source = Column(String(20), default='nvd_api')  # 'nvd_api', 'nvd_bulk', 'manual'
    nvd_json_version = Column(String(10), default='2.0')
    bulk_processed = Column(Boolean, default=False)
    # nomi-sec PoC fields
    poc_count = Column(Integer, default=0)
    poc_data = Column(JSON)  # Store nomi-sec PoC metadata
    created_at = Column(TIMESTAMP, default=datetime.utcnow)
    # onupdate refreshes the timestamp on every ORM UPDATE; with only ``default``
    # it stayed equal to the insert time forever. An explicitly assigned value
    # still takes precedence.
    updated_at = Column(TIMESTAMP, default=datetime.utcnow, onupdate=datetime.utcnow)
class SigmaRule(Base):
    """ORM row for a generated SIGMA rule (table ``sigma_rules``)."""

    __tablename__ = "sigma_rules"
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    cve_id = Column(String(20))  # CVE this rule was generated for (nullable)
    rule_name = Column(String(255), nullable=False)
    rule_content = Column(Text, nullable=False)  # rendered SIGMA rule text
    detection_type = Column(String(50))
    log_source = Column(String(100))
    confidence_level = Column(String(20))  # 'low' / 'medium' / 'high' when set by the generator
    auto_generated = Column(Boolean, default=True)
    exploit_based = Column(Boolean, default=False)
    github_repos = Column(ARRAY(String))
    exploit_indicators = Column(Text)  # JSON string of extracted indicators
    # Enhanced fields for new data sources
    poc_source = Column(String(20), default='github_search')  # 'github_search', 'nomi_sec', 'manual'
    poc_quality_score = Column(Integer, default=0)  # Based on star count, activity, etc.
    nomi_sec_data = Column(JSON)  # Store nomi-sec PoC metadata
    created_at = Column(TIMESTAMP, default=datetime.utcnow)
    # onupdate refreshes the timestamp whenever the row is UPDATEd through the
    # ORM; previously it never moved past the insert time.
    updated_at = Column(TIMESTAMP, default=datetime.utcnow, onupdate=datetime.utcnow)
class RuleTemplate(Base):
    """Reusable SIGMA rule template (table ``rule_templates``).

    The rule generator picks a template by substring match on ``template_name``
    (e.g. "Process Execution", "Network Connection", "File Modification",
    "PowerShell") and renders ``template_content`` with ``str.format``.
    """
    __tablename__ = "rule_templates"
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    template_name = Column(String(255), nullable=False)
    # str.format template. The generator supplies: title, description, rule_id,
    # date, cve_url, cve_id, tags, suspicious_processes, suspicious_ports,
    # file_patterns, level. Literal braces in the template must be doubled.
    template_content = Column(Text, nullable=False)
    # NOTE(review): not read by the template-selection code visible in this file.
    applicable_product_patterns = Column(ARRAY(String))
    description = Column(Text)
    created_at = Column(TIMESTAMP, default=datetime.utcnow)
class BulkProcessingJob(Base):
    """Tracking row for a long-running background job (table ``bulk_processing_jobs``)."""
    __tablename__ = "bulk_processing_jobs"
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    job_type = Column(String(50), nullable=False)  # 'nvd_bulk_seed', 'nomi_sec_sync', 'incremental_update'
    status = Column(String(20), default='pending')  # 'pending', 'running', 'completed', 'failed', 'cancelled'
    year = Column(Integer)  # For year-based processing
    # Progress counters.
    total_items = Column(Integer, default=0)
    processed_items = Column(Integer, default=0)
    failed_items = Column(Integer, default=0)
    error_message = Column(Text)
    job_metadata = Column(JSON)  # Additional job-specific data (e.g. {'cve_id': ..., 'batch_size': ...} for nomi-sec sync)
    # Lifecycle timestamps; left NULL until the corresponding transition happens.
    started_at = Column(TIMESTAMP)
    completed_at = Column(TIMESTAMP)
    cancelled_at = Column(TIMESTAMP)
    created_at = Column(TIMESTAMP, default=datetime.utcnow)
# Pydantic models
class CVEResponse(BaseModel):
    """API representation of a stored CVE (a subset of the ``CVE`` columns).

    The CVE endpoints build it from a dict, converting the UUID primary key to
    ``str`` and the DECIMAL CVSS score to ``float``.
    """
    id: str
    cve_id: str
    description: Optional[str] = None
    cvss_score: Optional[float] = None
    severity: Optional[str] = None
    published_date: Optional[datetime] = None
    affected_products: Optional[List[str]] = None
    reference_urls: Optional[List[str]] = None
    class Config:
        # Allow construction from ORM objects via attribute access. This is the
        # pydantic v2 option name (v1 called it ``orm_mode``) -- NOTE(review):
        # confirm the installed pydantic major version.
        from_attributes = True
class SigmaRuleResponse(BaseModel):
    """API representation of a stored SIGMA rule.

    The endpoints stringify the UUID primary key and default ``exploit_based``
    to False and ``github_repos`` to [] when the columns are NULL.
    """
    id: str
    # NOTE(review): SigmaRule.cve_id is nullable in the ORM model, but this field
    # is required -- a rule stored without a cve_id would fail validation here.
    cve_id: str
    rule_name: str
    rule_content: str
    detection_type: Optional[str] = None
    log_source: Optional[str] = None
    confidence_level: Optional[str] = None
    auto_generated: bool = True
    exploit_based: bool = False
    github_repos: Optional[List[str]] = None
    exploit_indicators: Optional[str] = None  # JSON string, passed through as stored
    created_at: datetime
    class Config:
        # pydantic v2 name for building from ORM attributes (v1: ``orm_mode``).
        from_attributes = True
# GitHub Exploit Analysis Service
class GitHubExploitAnalyzer:
    """Search GitHub for public exploit / PoC code for a CVE and mine it for indicators.

    Searching requires a ``GITHUB_TOKEN`` environment variable; without one every
    search returns an empty list. The PyGithub calls made from the ``async``
    methods are synchronous, so they block the event loop while they run.
    """

    # Indicator categories collected per file / repository. 'urls' is kept for
    # shape compatibility but URL matches are recorded under 'network'.
    _INDICATOR_KEYS = ('processes', 'files', 'registry', 'network', 'commands', 'powershell', 'urls')

    def __init__(self):
        self.github_token = os.getenv("GITHUB_TOKEN")
        self.github = Github(self.github_token) if self.github_token else None

    async def search_exploits_for_cve(self, cve_id: str) -> List[dict]:
        """Search GitHub for exploit code related to a CVE.

        Runs the first two search queries, inspects up to 5 of the most recently
        updated repositories per query, and returns at most 10 repository
        summaries (see ``_analyze_repository``). Errors are printed and never
        raised; a failing query is skipped.
        """
        if not self.github:
            print(f"No GitHub token configured, skipping exploit search for {cve_id}")
            return []
        try:
            print(f"Searching GitHub for exploits for {cve_id}")
            # Search queries to find exploit code
            search_queries = [
                f"{cve_id} exploit",
                f"{cve_id} poc",
                f"{cve_id} vulnerability",
                f'"{cve_id}" exploit code',
                f"{cve_id.replace('-', '_')} exploit"
            ]
            exploits = []
            seen_repos = set()
            for query in search_queries[:2]:  # Limit to 2 queries to avoid rate limits
                try:
                    repos = self.github.search_repositories(
                        query=query,
                        sort="updated",
                        order="desc"
                    )
                    # Get top 5 results per query
                    for repo in repos[:5]:
                        if repo.full_name in seen_repos:
                            continue
                        seen_repos.add(repo.full_name)
                        exploit_info = await self._analyze_repository(repo, cve_id)
                        if exploit_info:
                            exploits.append(exploit_info)
                        if len(exploits) >= 10:  # Limit total exploits
                            break
                    if len(exploits) >= 10:
                        break
                except Exception as e:
                    print(f"Error searching GitHub with query '{query}': {str(e)}")
                    continue
            print(f"Found {len(exploits)} potential exploits for {cve_id}")
            return exploits
        except Exception as e:
            print(f"Error searching GitHub for {cve_id}: {str(e)}")
            return []

    async def _analyze_repository(self, repo, cve_id: str) -> Optional[dict]:
        """Summarise one repository, or return None if it does not look relevant.

        The repository must mention the CVE (``-`` or ``_`` separated) in its name
        or description, and at least one of its first 20 top-level entries must be
        an exploit-looking file that references the CVE and yields indicators.
        """
        try:
            repo_text = f"{repo.name} {repo.description or ''}".lower()
            if cve_id.lower() not in repo_text and cve_id.replace('-', '_').lower() not in repo_text:
                return None
            exploit_files = []
            indicators = {key: set() for key in self._INDICATOR_KEYS}
            try:
                contents = repo.get_contents("")
                for content in contents[:20]:  # Limit files to analyze
                    if content.type == "file" and self._is_exploit_file(content.name):
                        file_analysis = await self._analyze_file_content(repo, content, cve_id)
                        if file_analysis:
                            exploit_files.append(file_analysis)
                            # Merge this file's indicators into the repository-wide sets.
                            for key, values in file_analysis.get('indicators', {}).items():
                                if key in indicators:
                                    indicators[key].update(values)
            except Exception as e:
                print(f"Error analyzing repo contents for {repo.full_name}: {str(e)}")
            if not exploit_files:
                return None
            return {
                'repo_name': repo.full_name,
                'repo_url': repo.html_url,
                'description': repo.description,
                'language': repo.language,
                'stars': repo.stargazers_count,
                # Guard so a missing timestamp does not discard an otherwise valid result.
                'updated': repo.updated_at.isoformat() if repo.updated_at else None,
                'files': exploit_files,
                # Sorted so the stored lists are identical across runs.
                'indicators': {k: sorted(v) for k, v in indicators.items()}
            }
        except Exception as e:
            print(f"Error analyzing repository {repo.full_name}: {str(e)}")
            return None

    def _is_exploit_file(self, filename: str) -> bool:
        """Return True for source files whose name suggests exploit code.

        The file needs a code extension and a name containing an exploit-related
        term or 'cve' (case-insensitive).
        """
        exploit_extensions = ['.py', '.ps1', '.sh', '.c', '.cpp', '.js', '.rb', '.pl', '.php', '.java']
        exploit_names = ['exploit', 'poc', 'payload', 'shell', 'reverse', 'bind', 'attack']
        filename_lower = filename.lower()
        if not any(filename_lower.endswith(ext) for ext in exploit_extensions):
            return False
        return any(term in filename_lower for term in exploit_names) or 'cve' in filename_lower

    async def _analyze_file_content(self, repo, file_content, cve_id: str) -> Optional[dict]:
        """Extract indicators from one repository file.

        Returns None for files over 100 KB, files that do not mention the CVE,
        files with no indicators, or on any error.
        """
        try:
            if file_content.size > 100000:  # Skip files larger than 100KB
                return None
            content = file_content.decoded_content.decode('utf-8', errors='ignore')
            # Check if file actually mentions the CVE
            if cve_id.lower() not in content.lower() and cve_id.replace('-', '_').lower() not in content.lower():
                return None
            indicators = self._extract_indicators_from_code(content, file_content.name)
            if not any(indicators.values()):
                return None
            return {
                'filename': file_content.name,
                'path': file_content.path,
                'size': file_content.size,
                'indicators': indicators
            }
        except Exception as e:
            print(f"Error analyzing file {file_content.name}: {str(e)}")
            return None

    def _extract_indicators_from_code(self, content: str, filename: str) -> dict:
        """Extract security indicators from exploit source code using regexes.

        Returns a dict mapping each category that matched ('processes', 'files',
        'registry', 'network', 'commands', 'powershell') to a sorted list of at
        most 10 whitespace-stripped strings of 3-199 characters. ``filename`` is
        currently unused.
        """
        indicators = {key: set() for key in self._INDICATOR_KEYS}
        # Process patterns
        process_patterns = [
            r'CreateProcess[AW]?\s*\(\s*["\']([^"\']+)["\']',
            r'ShellExecute[AW]?\s*\([^,]*,\s*["\']([^"\']+)["\']',
            r'system\s*\(\s*["\']([^"\']+)["\']',
            r'exec\s*\(\s*["\']([^"\']+)["\']',
            r'subprocess\.(?:call|run|Popen)\s*\(\s*["\']([^"\']+)["\']'
        ]
        # File patterns
        file_patterns = [
            r'(?:fopen|CreateFile|WriteFile|ReadFile)\s*\(\s*["\']([^"\']+\.[a-zA-Z0-9]+)["\']',
            r'(?:copy|move|del|rm)\s+["\']?([^\s"\']+\.[a-zA-Z0-9]+)["\']?',
            r'\\\\[^\\]+\\[^\\]+\\([^\\]+\.[a-zA-Z0-9]+)',
            r'[C-Z]:\\\\[^\\]+\\\\([^\\]+\.[a-zA-Z0-9]+)'
        ]
        # Registry patterns
        registry_patterns = [
            r'(?:RegOpenKey|RegSetValue|RegCreateKey)\s*\([^,]*,\s*["\']([^"\']+)["\']',
            r'HKEY_[A-Z_]+\\\\([^"\'\\]+)',
            r'reg\s+add\s+["\']?([^"\'\\]+\\\\[^"\']+)["\']?'
        ]
        # Network patterns
        network_patterns = [
            r'(?:connect|bind|listen)\s*\([^,]*,\s*(\d+)',
            r'socket\.connect\s*\(\s*\(["\']?([^"\']+)["\']?,\s*(\d+)\)',
            r'(?:http|https|ftp)://([^\s"\'<>]+)',
            r'(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)'
        ]
        # PowerShell patterns
        powershell_patterns = [
            r'(?:powershell|pwsh)\s+(?:-[a-zA-Z]+\s+)*["\']?([^"\']+)["\']?',
            r'Invoke-(?:Expression|Command|WebRequest|RestMethod)\s+["\']?([^"\']+)["\']?',
            r'Start-Process\s+["\']?([^"\']+)["\']?',
            r'Get-Process\s+["\']?([^"\']+)["\']?'
        ]
        # Command patterns
        command_patterns = [
            r'(?:cmd|command)\s+(?:/[a-zA-Z]+\s+)*["\']?([^"\']+)["\']?',
            r'(?:ping|nslookup|netstat|tasklist|wmic)\s+([^\s"\']+)',
            r'(?:net|sc|schtasks)\s+[a-zA-Z]+\s+([^\s"\']+)'
        ]
        patterns = {
            'processes': process_patterns,
            'files': file_patterns,
            'registry': registry_patterns,
            'powershell': powershell_patterns,
            'commands': command_patterns
        }
        for category, pattern_list in patterns.items():
            for pattern in pattern_list:
                for match in re.findall(pattern, content, re.IGNORECASE | re.MULTILINE):
                    # Multi-group patterns yield tuples; the first group is the indicator.
                    indicators[category].add(match[0] if isinstance(match, tuple) else match)
        # Network patterns may capture (host, port); record those as "host:port".
        for pattern in network_patterns:
            for match in re.findall(pattern, content, re.IGNORECASE):
                if isinstance(match, tuple):
                    indicators['network'].add(f"{match[0]}:{match[1]}" if len(match) >= 2 else match[0])
                else:
                    indicators['network'].add(match)
        cleaned_indicators = {}
        for key, values in indicators.items():
            # Strip whitespace picked up by the greedy [^"']+ captures, drop very
            # short/long values, and sort so the 10-item cap keeps the same items on
            # every run (set order of str varies between processes).
            cleaned_values = sorted({v.strip() for v in values if v and 2 < len(v.strip()) < 200})
            if cleaned_values:
                cleaned_indicators[key] = cleaned_values[:10]  # Limit to 10 per category
        return cleaned_indicators
class CVESigmaService:
    """Fetch CVEs from the NVD API and generate template-based SIGMA rules for them.

    All persistence goes through the SQLAlchemy session passed to the constructor.
    """

    def __init__(self, db: Session):
        self.db = db
        self.nvd_api_key = os.getenv("NVD_API_KEY")

    @staticmethod
    def _parse_nvd_timestamp(value: Optional[str]) -> Optional[datetime]:
        """Parse an NVD ISO-8601 timestamp, returning None when the field is missing/empty.

        A trailing ``Z`` is rewritten to ``+00:00`` because ``datetime.fromisoformat``
        only accepts ``Z`` from Python 3.11 on.
        """
        if not value:
            return None
        return datetime.fromisoformat(value.replace("Z", "+00:00"))

    @staticmethod
    def _pick_description(descriptions: Optional[list]) -> str:
        """Return the English description text, falling back to the first entry (or "")."""
        if not descriptions:
            return ""
        chosen = next((d for d in descriptions if d.get("lang") == "en"), descriptions[0])
        return chosen.get("value", "")

    async def fetch_recent_cves(self, days_back: int = 7):
        """Fetch CVEs published in the last ``days_back`` days and store the new ones.

        Only the first page (up to 100 results) of the NVD response is processed.
        CVEs whose ``cve_id`` already exists are skipped. New ``CVE`` rows are
        committed and returned; on any error the session is rolled back, the error
        is printed and an empty list is returned.
        """
        end_date = datetime.utcnow()
        start_date = end_date - timedelta(days=days_back)
        url = "https://services.nvd.nist.gov/rest/json/cves/2.0"
        params = {
            "pubStartDate": start_date.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z",
            "pubEndDate": end_date.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z",
            "resultsPerPage": 100
        }
        headers = {}
        if self.nvd_api_key:
            headers["apiKey"] = self.nvd_api_key
        try:
            response = requests.get(url, params=params, headers=headers, timeout=30)
            response.raise_for_status()
            data = response.json()
            new_cves = []
            seen_ids = set()
            for vuln in data.get("vulnerabilities", []):
                cve_data = vuln.get("cve", {})
                cve_id = cve_data.get("id")
                # Skip id-less entries (cve_id is NOT NULL) and ids already handled in
                # this response: rows added below may not be flushed yet, so the DB
                # lookup would miss them and the duplicate would fail the whole commit.
                if not cve_id or cve_id in seen_ids:
                    continue
                seen_ids.add(cve_id)
                # Check if CVE already exists
                existing = self.db.query(CVE).filter(CVE.cve_id == cve_id).first()
                if existing:
                    continue
                description = self._pick_description(cve_data.get("descriptions"))
                cvss_score = None
                severity = None
                if cve_data.get("metrics", {}).get("cvssMetricV31"):
                    cvss_data = cve_data["metrics"]["cvssMetricV31"][0]
                    cvss_score = cvss_data.get("cvssData", {}).get("baseScore")
                    severity = cvss_data.get("cvssData", {}).get("baseSeverity")
                affected_products = [
                    cpe_match.get("criteria", "")
                    for config in cve_data.get("configurations") or []
                    for node in config.get("nodes", [])
                    for cpe_match in node.get("cpeMatch", [])
                    if cpe_match.get("vulnerable")
                ]
                reference_urls = [ref.get("url", "") for ref in cve_data.get("references") or []]
                cve_obj = CVE(
                    cve_id=cve_id,
                    description=description,
                    cvss_score=cvss_score,
                    severity=severity,
                    published_date=self._parse_nvd_timestamp(cve_data.get("published")),
                    modified_date=self._parse_nvd_timestamp(cve_data.get("lastModified")),
                    affected_products=affected_products,
                    reference_urls=reference_urls
                )
                self.db.add(cve_obj)
                new_cves.append(cve_obj)
            self.db.commit()
            return new_cves
        except Exception as e:
            # Discard rows added before the failure so the session stays usable.
            self.db.rollback()
            print(f"Error fetching CVEs: {str(e)}")
            return []

    def generate_sigma_rule(self, cve: CVE) -> Optional[SigmaRule]:
        """Generate a SIGMA rule for ``cve`` and add it to the session (not committed).

        Returns None when the CVE has no description, no template matches, or the
        template cannot be rendered.
        """
        if not cve.description:
            return None
        description_lower = cve.description.lower()
        affected_products = [p.lower() for p in (cve.affected_products or [])]
        template = self._select_template(description_lower, affected_products)
        if not template:
            return None
        rule_content = self._populate_template(cve, template)
        if not rule_content:
            return None
        detection_type = self._determine_detection_type(description_lower)
        confidence_level = self._calculate_confidence(cve)
        sigma_rule = SigmaRule(
            cve_id=cve.cve_id,
            rule_name=f"CVE-{cve.cve_id.split('-')[1]}-{cve.cve_id.split('-')[2]} Detection",
            rule_content=rule_content,
            detection_type=detection_type,
            log_source=template.template_name.lower().replace(" ", "_"),
            confidence_level=confidence_level,
            auto_generated=True
        )
        self.db.add(sigma_rule)
        return sigma_rule

    def _select_template(self, description: str, affected_products: List[str], exploit_indicators: Optional[dict] = None):
        """Pick a RuleTemplate by exploit indicators first, then by CVE text.

        ``description`` and ``affected_products`` are expected lower-cased.
        Returns None if no template with a matching name exists.
        """
        templates = self.db.query(RuleTemplate).all()

        def find(name_fragment):
            return next((t for t in templates if name_fragment in t.template_name), None)

        if exploit_indicators:
            # Checked in priority order; an indicator whose template is missing falls through.
            for key, name_fragment in (('powershell', "PowerShell"),
                                       ('network', "Network Connection"),
                                       ('files', "File Modification")):
                if exploit_indicators.get(key):
                    template = find(name_fragment)
                    if template:
                        return template
            if exploit_indicators.get('processes') or exploit_indicators.get('commands'):
                template = find("Process Execution")
                if template:
                    return template
        # Fallback to original logic
        if any("windows" in p or "microsoft" in p for p in affected_products):
            if "process" in description or "execution" in description:
                return find("Process Execution")
            elif "network" in description or "remote" in description:
                return find("Network Connection")
            elif "file" in description or "write" in description:
                return find("File Modification")
        # Default to process execution template
        return find("Process Execution")

    def _populate_template(self, cve: CVE, template: RuleTemplate, exploit_indicators: Optional[dict] = None) -> Optional[str]:
        """Render ``template`` for ``cve``; return the rule text, or None on error.

        Indicators come from ``exploit_indicators`` when given, otherwise from
        regexes over the CVE description; placeholder values are used when none
        are found.
        """
        try:
            if exploit_indicators:
                suspicious_processes = exploit_indicators.get('processes', []) + exploit_indicators.get('commands', [])
                suspicious_ports = []
                file_patterns = exploit_indicators.get('files', [])
                # Network indicators look like "host:port"; keep the numeric ports.
                for net_indicator in exploit_indicators.get('network', []):
                    if ':' in str(net_indicator):
                        try:
                            suspicious_ports.append(int(str(net_indicator).split(':')[-1]))
                        except ValueError:
                            pass
            else:
                suspicious_processes = self._extract_suspicious_indicators(cve.description, "process")
                suspicious_ports = self._extract_suspicious_indicators(cve.description, "port")
                file_patterns = self._extract_suspicious_indicators(cve.description, "file")
            level = "high" if cve.cvss_score and cve.cvss_score >= 7.0 else "medium"
            enhanced_description = cve.description[:200] + "..." if len(cve.description) > 200 else cve.description
            if exploit_indicators:
                enhanced_description += " [Enhanced with GitHub exploit analysis]"
            tags = [f"attack.{self._get_mitre_technique(cve.description, exploit_indicators)}", cve.cve_id.lower()]
            if exploit_indicators:
                tags.append("exploit.github")
            return template.template_content.format(
                # cve_id already carries the "CVE-" prefix (e.g. "CVE-2024-1234");
                # prefixing it again produced titles like "CVE-CVE-2024-1234 Detection".
                title=f"{cve.cve_id} {'Exploit-Based ' if exploit_indicators else ''}Detection",
                description=enhanced_description,
                rule_id=str(uuid.uuid4()),
                date=datetime.utcnow().strftime("%Y/%m/%d"),
                cve_url=f"https://nvd.nist.gov/vuln/detail/{cve.cve_id}",
                cve_id=cve.cve_id.lower(),
                tags="\n - ".join(tags),
                suspicious_processes=suspicious_processes or ["suspicious.exe", "malware.exe"],
                suspicious_ports=suspicious_ports or [4444, 8080, 9999],
                file_patterns=file_patterns or ["temp", "malware", "exploit"],
                level=level
            )
        except Exception as e:
            print(f"Error populating template: {str(e)}")
            return None

    def _get_mitre_technique(self, description: str, exploit_indicators: Optional[dict] = None) -> str:
        """Map exploit indicators (preferred) or description keywords to a MITRE ATT&CK id."""
        desc_lower = description.lower()
        if exploit_indicators:
            if exploit_indicators.get('powershell'):
                return "t1059.001"  # PowerShell
            elif exploit_indicators.get('commands'):
                return "t1059.003"  # Windows Command Shell
            elif exploit_indicators.get('network'):
                return "t1071.001"  # Web Protocols
            elif exploit_indicators.get('files'):
                return "t1105"  # Ingress Tool Transfer
            elif exploit_indicators.get('processes'):
                return "t1106"  # Native API
        if "powershell" in desc_lower:
            return "t1059.001"
        elif "command" in desc_lower or "cmd" in desc_lower:
            return "t1059.003"
        elif "network" in desc_lower or "remote" in desc_lower:
            return "t1071.001"
        elif "file" in desc_lower or "upload" in desc_lower:
            return "t1105"
        elif "process" in desc_lower or "execution" in desc_lower:
            return "t1106"
        else:
            return "execution"  # Generic

    def _extract_suspicious_indicators(self, description: str, indicator_type: str) -> Optional[List]:
        """Regex-extract up to 5 .exe names / 3 ports / 5 file names; None when nothing matches."""
        if indicator_type == "process":
            exe_pattern = re.findall(r'(\w+\.exe)', description, re.IGNORECASE)
            return exe_pattern[:5] if exe_pattern else None
        elif indicator_type == "port":
            port_pattern = re.findall(r'port\s+(\d+)', description, re.IGNORECASE)
            return [int(p) for p in port_pattern[:3]] if port_pattern else None
        elif indicator_type == "file":
            file_pattern = re.findall(r'(\w+\.\w{3,4})', description, re.IGNORECASE)
            return file_pattern[:5] if file_pattern else None
        return None

    def _determine_detection_type(self, description: str, exploit_indicators: Optional[dict] = None) -> str:
        """Classify the rule as 'powershell', 'network', 'process', 'file' or 'general'."""
        if exploit_indicators:
            if exploit_indicators.get('powershell'):
                return "powershell"
            elif exploit_indicators.get('network'):
                return "network"
            elif exploit_indicators.get('files'):
                return "file"
            elif exploit_indicators.get('processes') or exploit_indicators.get('commands'):
                return "process"
        if "remote" in description or "network" in description:
            return "network"
        elif "process" in description or "execution" in description:
            return "process"
        elif "file" in description or "filesystem" in description:
            return "file"
        else:
            return "general"

    def _calculate_confidence(self, cve: CVE, exploit_based: bool = False) -> str:
        """Score confidence from CVSS (+1/+2/+3) and exploit evidence (+2); >=4 high, >=2 medium."""
        base_confidence = 0
        if cve.cvss_score:
            if cve.cvss_score >= 9.0:
                base_confidence += 3
            elif cve.cvss_score >= 7.0:
                base_confidence += 2
            else:
                base_confidence += 1
        if exploit_based:
            base_confidence += 2
        if base_confidence >= 4:
            return "high"
        elif base_confidence >= 2:
            return "medium"
        else:
            return "low"
# Dependency
def get_db():
    """FastAPI dependency: yield one SQLAlchemy session and close it when the request ends."""
    session = SessionLocal()
    try:
        yield session
    finally:
        # Runs whether the endpoint succeeded or raised.
        session.close()
# Background task to fetch CVEs and generate rules
async def background_cve_fetch():
    """Periodically fetch recent CVEs from NVD and generate SIGMA rules for new ones.

    Starts with a 30-day lookback and narrows it to 7 days once a cycle has
    stored new CVEs. After an exception the cycle is retried in 5 minutes; after
    ``max_retries`` consecutive failures it waits 30 minutes instead. Completed
    cycles repeat hourly. Runs until the task is cancelled.
    """
    retry_count = 0
    max_retries = 3
    days_back = 30  # Use a longer initial period (30 days) to find CVEs
    while True:
        db = SessionLocal()
        failed = False
        try:
            service = CVESigmaService(db)
            current_time = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
            print(f"[{current_time}] Starting CVE fetch cycle...")
            new_cves = await service.fetch_recent_cves(days_back=days_back)
            if new_cves:
                print(f"Found {len(new_cves)} new CVEs, generating SIGMA rules...")
                rules_generated = 0
                for cve in new_cves:
                    try:
                        sigma_rule = service.generate_sigma_rule(cve)
                        if sigma_rule:
                            rules_generated += 1
                            print(f"Generated SIGMA rule for {cve.cve_id}")
                        else:
                            print(f"Could not generate rule for {cve.cve_id} - insufficient data")
                    except Exception as e:
                        print(f"Error generating rule for {cve.cve_id}: {str(e)}")
                db.commit()
                print(f"Successfully generated {rules_generated} SIGMA rules")
                # fetch_recent_cves() also returns [] when the request fails, so only a
                # non-empty result proves the initial lookback worked; narrow it then.
                if days_back != 7:
                    print("Switching to 7-day lookback for future runs...")
                    days_back = 7
            else:
                print("No new CVEs found in this cycle")
            # Any cycle that completes without raising ends the failure streak.
            retry_count = 0
        except Exception as e:
            failed = True
            retry_count += 1
            print(f"Background task error (attempt {retry_count}/{max_retries}): {str(e)}")
        finally:
            # Previously the session leaked on every failed cycle. close() also
            # discards uncommitted work, and runs on cancellation too.
            db.close()
        if failed:
            if retry_count >= max_retries:
                print(f"Max retries reached, waiting longer before next attempt...")
                await asyncio.sleep(1800)  # Wait 30 minutes on repeated failures
                retry_count = 0
            else:
                await asyncio.sleep(300)  # Wait 5 minutes before retry
            continue
        # Completed cycles always reset retry_count, so the regular interval applies.
        wait_time = 3600
        print(f"Next CVE fetch in {wait_time//60} minutes...")
        await asyncio.sleep(wait_time)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the periodic CVE fetch loop for as long as the application is up."""
    task = asyncio.create_task(background_cve_fetch())
    try:
        yield
    finally:
        # Cancel the loop and wait for it to unwind so its cleanup (closing the DB
        # session) finishes before shutdown, instead of leaving a pending task behind.
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
# FastAPI app
# The lifespan handler starts the background CVE fetch loop on startup.
app = FastAPI(title="CVE-SIGMA Auto Generator", lifespan=lifespan)
# CORS: only the frontend served from http://localhost:3000 may call this API
# from a browser, with credentials and any method/header.
# NOTE(review): the origin is hard-coded; deployments on another host/port will
# need this list changed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/api/cves", response_model=List[CVEResponse])
async def get_cves(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
    """Return a page of stored CVEs, most recently published first.

    CVEs without a published date are listed after dated ones (PostgreSQL would
    otherwise sort NULLs first under DESC).
    """
    cves = (
        db.query(CVE)
        .order_by(CVE.published_date.desc().nullslast())
        .offset(skip)
        .limit(limit)
        .all()
    )
    result = []
    for cve in cves:
        cve_dict = {
            'id': str(cve.id),
            'cve_id': cve.cve_id,
            'description': cve.description,
            # Compare with None: a CVSS score of 0.0 is valid and must not become null.
            'cvss_score': float(cve.cvss_score) if cve.cvss_score is not None else None,
            'severity': cve.severity,
            'published_date': cve.published_date,
            'affected_products': cve.affected_products,
            'reference_urls': cve.reference_urls
        }
        result.append(CVEResponse(**cve_dict))
    return result
@app.get("/api/cves/{cve_id}", response_model=CVEResponse)
async def get_cve(cve_id: str, db: Session = Depends(get_db)):
    """Return one CVE by its id (e.g. "CVE-2024-1234"); 404 if it is not stored."""
    cve = db.query(CVE).filter(CVE.cve_id == cve_id).first()
    if not cve:
        raise HTTPException(status_code=404, detail="CVE not found")
    cve_dict = {
        'id': str(cve.id),
        'cve_id': cve.cve_id,
        'description': cve.description,
        # Compare with None: a CVSS score of 0.0 is valid and must not become null.
        'cvss_score': float(cve.cvss_score) if cve.cvss_score is not None else None,
        'severity': cve.severity,
        'published_date': cve.published_date,
        'affected_products': cve.affected_products,
        'reference_urls': cve.reference_urls
    }
    return CVEResponse(**cve_dict)
@app.get("/api/sigma-rules", response_model=List[SigmaRuleResponse])
async def get_sigma_rules(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
    """List stored SIGMA rules, newest first, paginated with ``skip``/``limit``."""
    newest_first = db.query(SigmaRule).order_by(SigmaRule.created_at.desc())
    page = newest_first.offset(skip).limit(limit).all()
    # Build responses explicitly so the UUID becomes a string and NULL
    # exploit_based / github_repos columns get their API defaults.
    return [
        SigmaRuleResponse(
            id=str(r.id),
            cve_id=r.cve_id,
            rule_name=r.rule_name,
            rule_content=r.rule_content,
            detection_type=r.detection_type,
            log_source=r.log_source,
            confidence_level=r.confidence_level,
            auto_generated=r.auto_generated,
            exploit_based=r.exploit_based or False,
            github_repos=r.github_repos or [],
            exploit_indicators=r.exploit_indicators,
            created_at=r.created_at,
        )
        for r in page
    ]
@app.get("/api/sigma-rules/{cve_id}", response_model=List[SigmaRuleResponse])
async def get_sigma_rules_by_cve(cve_id: str, db: Session = Depends(get_db)):
    """Return every SIGMA rule generated for ``cve_id``, newest first (empty list if none)."""
    rules = (
        db.query(SigmaRule)
        .filter(SigmaRule.cve_id == cve_id)
        # Same ordering as /api/sigma-rules; without it the row order is unspecified.
        .order_by(SigmaRule.created_at.desc())
        .all()
    )
    result = []
    for rule in rules:
        rule_dict = {
            'id': str(rule.id),
            'cve_id': rule.cve_id,
            'rule_name': rule.rule_name,
            'rule_content': rule.rule_content,
            'detection_type': rule.detection_type,
            'log_source': rule.log_source,
            'confidence_level': rule.confidence_level,
            'auto_generated': rule.auto_generated,
            'exploit_based': rule.exploit_based or False,
            'github_repos': rule.github_repos or [],
            'exploit_indicators': rule.exploit_indicators,
            'created_at': rule.created_at
        }
        result.append(SigmaRuleResponse(**rule_dict))
    return result
@app.post("/api/fetch-cves")
async def manual_fetch_cves(background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Start a one-off 30-day CVE fetch plus rule generation in the background.

    ``db`` stays in the signature for compatibility but the task does not use it.
    get_db() closes that request-scoped session, and FastAPI (>= 0.106) runs
    that teardown before background tasks execute. The task therefore opens and
    closes its own session.
    """
    async def fetch_task():
        task_db = SessionLocal()
        try:
            service = CVESigmaService(task_db)
            print("Manual CVE fetch initiated...")
            # Use 30 days for manual fetch to get more results
            new_cves = await service.fetch_recent_cves(days_back=30)
            rules_generated = 0
            for cve in new_cves:
                sigma_rule = service.generate_sigma_rule(cve)
                if sigma_rule:
                    rules_generated += 1
            task_db.commit()
            print(f"Manual fetch complete: {len(new_cves)} CVEs, {rules_generated} rules generated")
        except Exception as e:
            print(f"Manual fetch error: {str(e)}")
            import traceback
            traceback.print_exc()
        finally:
            # close() also discards anything left uncommitted after an error.
            task_db.close()

    background_tasks.add_task(fetch_task)
    return {"message": "CVE fetch initiated (30-day lookback)", "status": "started"}
@app.get("/api/test-nvd")
async def test_nvd_connection():
    """Test endpoint to check NVD API connectivity.

    Requests 5 CVEs modified in the last 30 days and reports the status code,
    headers and result counts. On a 404 it retries without date filters. Errors
    are returned in the payload rather than raised.
    """
    try:
        end_date = datetime.utcnow()
        start_date = end_date - timedelta(days=30)
        url = "https://services.nvd.nist.gov/rest/json/cves/2.0/"
        params = {
            "lastModStartDate": start_date.strftime("%Y-%m-%dT%H:%M:%S.000+00:00"),
            "lastModEndDate": end_date.strftime("%Y-%m-%dT%H:%M:%S.000+00:00"),
            "resultsPerPage": 5,
            "startIndex": 0
        }
        headers = {
            "User-Agent": "CVE-SIGMA-Generator/1.0",
            "Accept": "application/json"
        }
        nvd_api_key = os.getenv("NVD_API_KEY")
        if nvd_api_key:
            headers["apiKey"] = nvd_api_key
        # Never echo the API key into logs; mask it in the printed copy only.
        loggable_headers = {k: ("***" if k == "apiKey" else v) for k, v in headers.items()}
        print(f"Testing NVD API with URL: {url}")
        print(f"Test params: {params}")
        print(f"Test headers: {loggable_headers}")
        response = requests.get(url, params=params, headers=headers, timeout=15)
        result = {
            "status": "success" if response.status_code == 200 else "error",
            "status_code": response.status_code,
            "has_api_key": bool(nvd_api_key),
            "request_url": f"{url}?{requests.compat.urlencode(params)}",
            "response_headers": dict(response.headers)
        }
        if response.status_code == 200:
            data = response.json()
            result.update({
                "total_results": data.get("totalResults", 0),
                "results_per_page": data.get("resultsPerPage", 0),
                "vulnerabilities_returned": len(data.get("vulnerabilities", [])),
                "message": "NVD API is accessible and returning data"
            })
        else:
            result.update({
                "error_message": response.text[:200],
                "message": f"NVD API returned {response.status_code}"
            })
            # Try fallback without date filters if we get 404
            if response.status_code == 404:
                print("Trying fallback without date filters...")
                fallback_params = {
                    "resultsPerPage": 5,
                    "startIndex": 0
                }
                fallback_response = requests.get(url, params=fallback_params, headers=headers, timeout=15)
                result["fallback_status_code"] = fallback_response.status_code
                if fallback_response.status_code == 200:
                    fallback_data = fallback_response.json()
                    result.update({
                        "fallback_success": True,
                        "fallback_total_results": fallback_data.get("totalResults", 0),
                        "message": "NVD API works without date filters"
                    })
        return result
    except Exception as e:
        print(f"NVD API test error: {str(e)}")
        return {
            "status": "error",
            "message": f"Failed to connect to NVD API: {str(e)}"
        }
@app.get("/api/stats")
async def get_stats(db: Session = Depends(get_db)):
    """Return dashboard counters plus PoC / nomi-sec coverage percentages (0 when empty)."""
    def percent(part, whole):
        return (part / whole * 100) if whole > 0 else 0

    cve_query = db.query(CVE)
    rule_query = db.query(SigmaRule)
    total_cves = cve_query.count()
    total_rules = rule_query.count()
    week_ago = datetime.utcnow() - timedelta(days=7)
    recent_cves = cve_query.filter(CVE.published_date >= week_ago).count()
    # Bulk-processing / PoC breakdowns. `== True` builds a SQL comparison, not a Python one.
    bulk_processed_cves = cve_query.filter(CVE.bulk_processed == True).count()  # noqa: E712
    cves_with_pocs = cve_query.filter(CVE.poc_count > 0).count()
    nomi_sec_rules = rule_query.filter(SigmaRule.poc_source == 'nomi_sec').count()
    stats = {}
    stats["total_cves"] = total_cves
    stats["total_sigma_rules"] = total_rules
    stats["recent_cves_7_days"] = recent_cves
    stats["bulk_processed_cves"] = bulk_processed_cves
    stats["cves_with_pocs"] = cves_with_pocs
    stats["nomi_sec_rules"] = nomi_sec_rules
    stats["poc_coverage"] = percent(cves_with_pocs, total_cves)
    stats["nomi_sec_coverage"] = percent(nomi_sec_rules, total_rules)
    return stats
# New bulk processing endpoints
@app.post("/api/bulk-seed")
async def start_bulk_seed(background_tasks: BackgroundTasks,
                          start_year: int = 2002,
                          end_year: Optional[int] = None,
                          skip_nvd: bool = False,
                          skip_nomi_sec: bool = False,
                          db: Session = Depends(get_db)):
    """Start the bulk seeding process in the background and return immediately.

    ``db`` is kept for signature compatibility only. The request-scoped session
    is torn down by get_db() before background tasks run (FastAPI >= 0.106), so
    the seeder gets its own session that the task closes.
    """
    async def bulk_seed_task():
        task_db = SessionLocal()
        try:
            from bulk_seeder import BulkSeeder
            seeder = BulkSeeder(task_db)
            result = await seeder.full_bulk_seed(
                start_year=start_year,
                end_year=end_year,
                skip_nvd=skip_nvd,
                skip_nomi_sec=skip_nomi_sec
            )
            logger.info(f"Bulk seed completed: {result}")
        except Exception as e:
            logger.error(f"Bulk seed failed: {e}")
            import traceback
            traceback.print_exc()
        finally:
            task_db.close()

    background_tasks.add_task(bulk_seed_task)
    return {
        "message": "Bulk seeding process started",
        "status": "started",
        "start_year": start_year,
        "end_year": end_year or datetime.now().year,
        "skip_nvd": skip_nvd,
        "skip_nomi_sec": skip_nomi_sec
    }
@app.post("/api/incremental-update")
async def start_incremental_update(background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Schedule ``BulkSeeder.incremental_update`` as a background task.

    Args:
        background_tasks: FastAPI task queue the update job is added to.
        db: Request-scoped session. It is kept only so the endpoint signature
            is unchanged; the background task opens its own session because
            the request session may already be closed when the task runs
            (FastAPI >= 0.106 runs yield-dependency cleanup first).

    Returns:
        A static acknowledgement dict.
    """
    async def incremental_update_task():
        task_db = SessionLocal()
        try:
            from bulk_seeder import BulkSeeder
            seeder = BulkSeeder(task_db)
            result = await seeder.incremental_update()
            logger.info(f"Incremental update completed: {result}")
        except Exception as e:
            # logger.exception records the traceback through logging.
            logger.exception(f"Incremental update failed: {e}")
        finally:
            task_db.close()

    background_tasks.add_task(incremental_update_task)
    return {
        "message": "Incremental update process started",
        "status": "started"
    }
@app.post("/api/sync-nomi-sec")
async def sync_nomi_sec(background_tasks: BackgroundTasks,
                        cve_id: Optional[str] = None,
                        batch_size: int = 50,
                        db: Session = Depends(get_db)):
    """Create a ``nomi_sec_sync`` job record and run the sync in the background.

    Args:
        background_tasks: FastAPI task queue the sync job is added to.
        cve_id: When given, only this CVE is synced via
            ``NomiSecClient.sync_cve_pocs``; otherwise
            ``bulk_sync_all_cves`` runs over all CVEs.
        batch_size: Forwarded to ``bulk_sync_all_cves``.
        db: Request-scoped session, used only to create the job row.

    Returns:
        Acknowledgement dict including the new ``job_id``, which can be
        passed to ``/api/cancel-job/{job_id}``.

    The job is tracked in ``running_jobs`` / ``job_cancellation_flags`` until
    the task finishes. The task uses its own session because the request
    session may already be closed when it runs (FastAPI >= 0.106).
    """
    # Create job record
    job = BulkProcessingJob(
        job_type='nomi_sec_sync',
        status='pending',
        job_metadata={
            'cve_id': cve_id,
            'batch_size': batch_size
        }
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    # Capture the primary key now so the task never touches the request session.
    job_uuid = job.id
    job_id = str(job_uuid)
    running_jobs[job_id] = job
    job_cancellation_flags[job_id] = False

    def is_cancelled():
        return job_cancellation_flags.get(job_id, False)

    async def sync_task():
        task_db = SessionLocal()
        task_job = None
        try:
            # Check the flag BEFORE writing 'running'. cancel_job commits
            # status='cancelled'; writing 'running' after that would overwrite
            # it and leave the job stuck as running after the early return.
            if is_cancelled():
                logger.info(f"Job {job_id} cancelled before starting")
                return

            task_job = task_db.query(BulkProcessingJob).filter(
                BulkProcessingJob.id == job_uuid
            ).first()
            if task_job is None:
                logger.error(f"Job {job_id} not found; nomi-sec sync aborted")
                return

            task_job.status = 'running'
            task_job.started_at = datetime.utcnow()
            task_db.commit()

            from nomi_sec_client import NomiSecClient
            client = NomiSecClient(task_db)

            if cve_id:
                # Sync specific CVE
                result = await client.sync_cve_pocs(cve_id)
                logger.info(f"Nomi-sec sync for {cve_id}: {result}")
            else:
                # Sync all CVEs; the client polls is_cancelled between batches.
                result = await client.bulk_sync_all_cves(
                    batch_size=batch_size,
                    cancellation_flag=is_cancelled
                )
                logger.info(f"Nomi-sec bulk sync completed: {result}")

            # A cancelled job keeps the status written by cancel_job.
            if not is_cancelled():
                task_job.status = 'completed'
                task_job.completed_at = datetime.utcnow()
                task_db.commit()
        except Exception as e:
            logger.exception(f"Nomi-sec sync failed: {e}")
            if task_job is not None and not is_cancelled():
                # Clear any failed transaction before recording the failure.
                task_db.rollback()
                task_job.status = 'failed'
                task_job.error_message = str(e)
                task_job.completed_at = datetime.utcnow()
                task_db.commit()
        finally:
            # Clean up tracking
            running_jobs.pop(job_id, None)
            job_cancellation_flags.pop(job_id, None)
            task_db.close()

    background_tasks.add_task(sync_task)
    return {
        "message": "Nomi-sec sync started" + (f" for {cve_id}" if cve_id else " for all CVEs"),
        "status": "started",
        "job_id": job_id,
        "cve_id": cve_id,
        "batch_size": batch_size
    }
@app.get("/api/bulk-jobs")
async def get_bulk_jobs(limit: int = 10, db: Session = Depends(get_db)):
    """List the ``limit`` most recently created bulk jobs, newest first.

    Each job row is flattened into a plain dict. The ``job_metadata`` column
    is exposed under the ``'metadata'`` key and the UUID id is stringified.
    """
    newest_first = BulkProcessingJob.created_at.desc()
    recent = db.query(BulkProcessingJob).order_by(newest_first).limit(limit).all()

    return [
        {
            'id': str(entry.id),
            'job_type': entry.job_type,
            'status': entry.status,
            'year': entry.year,
            'total_items': entry.total_items,
            'processed_items': entry.processed_items,
            'failed_items': entry.failed_items,
            'error_message': entry.error_message,
            'metadata': entry.job_metadata,
            'started_at': entry.started_at,
            'completed_at': entry.completed_at,
            'created_at': entry.created_at,
        }
        for entry in recent
    ]
@app.get("/api/bulk-status")
async def get_bulk_status(db: Session = Depends(get_db)):
    """Return ``BulkSeeder.get_seeding_status()`` for the current database.

    On any failure, including a failed import of ``bulk_seeder``, the error
    is logged and ``{"error": <message>}`` is returned instead of raising.
    """
    try:
        from bulk_seeder import BulkSeeder
        return await BulkSeeder(db).get_seeding_status()
    except Exception as exc:
        logger.error(f"Error getting bulk status: {exc}")
        return {"error": str(exc)}
@app.get("/api/poc-stats")
async def get_poc_stats(db: Session = Depends(get_db)):
    """Return nomi-sec sync status enriched with PoC statistics.

    Adds two keys to the status returned by ``NomiSecClient.get_sync_status``:

    * ``high_quality_cves``: CVEs with PoCs whose first PoC entry has
      ``quality_analysis.quality_score`` > 60.
    * ``avg_poc_count``: mean ``poc_count`` over CVEs with at least one PoC
      (0 when there are none).

    Any exception is logged and reported as ``{"error": <message>}``.
    """
    try:
        from nomi_sec_client import NomiSecClient
        stats = await NomiSecClient(db).get_sync_status()

        has_pocs = CVE.poc_count > 0
        # Reads poc_data[0].quality_analysis.quality_score as text and casts it
        # to INTEGER. NOTE(review): a non-integer score (e.g. "65.5") would make
        # Postgres reject the cast — confirm the stored format.
        first_poc_score = func.json_extract_path_text(
            CVE.poc_data, '0', 'quality_analysis', 'quality_score'
        ).cast(Integer)

        high_quality = db.query(CVE).filter(has_pocs, first_poc_score > 60).count()
        average = db.query(func.avg(CVE.poc_count)).filter(has_pocs).scalar()

        stats.update(high_quality_cves=high_quality, avg_poc_count=average or 0)
        return stats
    except Exception as exc:
        logger.error(f"Error getting PoC stats: {exc}")
        return {"error": str(exc)}
@app.post("/api/regenerate-rules")
async def regenerate_sigma_rules(background_tasks: BackgroundTasks,
                                 force: bool = False,
                                 db: Session = Depends(get_db)):
    """Schedule SIGMA rule regeneration for every CVE with PoC data.

    Args:
        background_tasks: FastAPI task queue the regeneration job is added to.
        force: When false, CVEs whose existing rule already has
            ``poc_source == 'nomi_sec'`` are skipped.
        db: Request-scoped session. It is kept only so the endpoint signature
            is unchanged; the task opens its own session because the request
            session may already be closed when it runs (FastAPI >= 0.106).

    Returns:
        An acknowledgement dict echoing ``force``.
    """
    async def regenerate_task():
        task_db = SessionLocal()
        try:
            from enhanced_sigma_generator import EnhancedSigmaGenerator
            generator = EnhancedSigmaGenerator(task_db)

            # Get CVEs with PoC data
            cves_with_pocs = task_db.query(CVE).filter(CVE.poc_count > 0).all()
            rules_generated = 0
            rules_updated = 0

            for cve in cves_with_pocs:
                existing_rule = task_db.query(SigmaRule).filter(
                    SigmaRule.cve_id == cve.cve_id
                ).first()
                if existing_rule and existing_rule.poc_source == 'nomi_sec' and not force:
                    continue

                result = await generator.generate_enhanced_rule(cve)
                # Counts are classified by whether a rule existed before this
                # generation call, not by what the generator reports.
                if result['success']:
                    if existing_rule:
                        rules_updated += 1
                    else:
                        rules_generated += 1

            logger.info(f"Rule regeneration completed: {rules_generated} new, {rules_updated} updated")
        except Exception as e:
            # logger.exception records the traceback through logging.
            logger.exception(f"Rule regeneration failed: {e}")
        finally:
            task_db.close()

    background_tasks.add_task(regenerate_task)
    return {
        "message": "SIGMA rule regeneration started",
        "status": "started",
        "force": force
    }
@app.post("/api/cancel-job/{job_id}")
async def cancel_job(job_id: str, db: Session = Depends(get_db)):
    """Request cancellation of a pending or running bulk job.

    Sets the in-process cancellation flag polled by background tasks, then
    marks the job row as ``'cancelled'``.

    Args:
        job_id: The job's UUID, in any spelling ``uuid.UUID`` accepts.
        db: Request-scoped session.

    Returns:
        An acknowledgement dict.

    Raises:
        HTTPException: 404 if ``job_id`` is not a valid UUID or no such job
            exists; 400 if the job is not pending/running; 500 on any other
            error.
    """
    # Parse up front. An invalid string would otherwise reach Postgres and
    # surface as a 500 instead of "not found".
    try:
        job_uuid = uuid.UUID(job_id)
    except ValueError:
        raise HTTPException(status_code=404, detail="Job not found") from None
    # Background tasks key their flags by str(job.id), i.e. the canonical
    # lowercase hyphenated form. Use that form so uppercase or unhyphenated
    # input still reaches the running task.
    flag_key = str(job_uuid)

    try:
        # Find the job in the database
        job = db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job_uuid).first()
        if not job:
            raise HTTPException(status_code=404, detail="Job not found")
        if job.status not in ['pending', 'running']:
            raise HTTPException(status_code=400, detail=f"Cannot cancel job with status: {job.status}")

        # Set cancellation flag
        job_cancellation_flags[flag_key] = True

        # Update job status
        job.status = 'cancelled'
        job.cancelled_at = datetime.utcnow()
        job.error_message = "Job cancelled by user"
        db.commit()

        logger.info(f"Job {job_id} cancellation requested")
        return {
            "message": f"Job {job_id} cancellation requested",
            "status": "cancelled",
            "job_id": job_id
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error cancelling job {job_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/running-jobs")
async def get_running_jobs(db: Session = Depends(get_db)):
    """List bulk jobs whose status is ``'pending'`` or ``'running'``, newest first.

    Raises:
        HTTPException: 500 if the query or serialisation fails.
    """
    active_states = ['pending', 'running']
    try:
        active = (
            db.query(BulkProcessingJob)
            .filter(BulkProcessingJob.status.in_(active_states))
            .order_by(BulkProcessingJob.created_at.desc())
            .all()
        )
        return [
            {
                'id': str(entry.id),
                'job_type': entry.job_type,
                'status': entry.status,
                'year': entry.year,
                'total_items': entry.total_items,
                'processed_items': entry.processed_items,
                'failed_items': entry.failed_items,
                'error_message': entry.error_message,
                'started_at': entry.started_at,
                'created_at': entry.created_at,
                'can_cancel': entry.status in active_states,
            }
            for entry in active
        ]
    except Exception as exc:
        logger.error(f"Error getting running jobs: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
if __name__ == "__main__":
    # Development entry point: serve the app on all interfaces, port 8000.
    from uvicorn import run as serve
    serve(app, host="0.0.0.0", port=8000)