# auto_sigma_rule_generator/backend/main.py
#
# 574 lines
# 22 KiB
# Python

from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from sqlalchemy import create_engine, Column, String, Text, DECIMAL, TIMESTAMP, Boolean, ARRAY
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.dialects.postgresql import UUID
import uuid
from datetime import datetime, timedelta
import requests
import json
import re
import os
from typing import List, Optional
from pydantic import BaseModel
import asyncio
from contextlib import asynccontextmanager
# Database setup
# The connection string is read from the DATABASE_URL environment variable;
# when it is unset, the fallback below points at a PostgreSQL server on localhost.
# NOTE(review): the fallback embeds a username and password — presumably meant
# for local development only; confirm it is never relied on in production.
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://cve_user:cve_password@localhost:5432/cve_sigma_db")
engine = create_engine(DATABASE_URL)
# Session factory: sessions neither autocommit nor autoflush, so every caller
# has to commit explicitly.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Declarative base class that the ORM models in this module inherit from.
Base = declarative_base()
# Database Models
class CVE(Base):
    """ORM model for one CVE record stored in the ``cves`` table."""

    __tablename__ = "cves"

    # Surrogate primary key, generated in Python with uuid4.
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # Public CVE identifier; unique and required.
    cve_id = Column(String(20), unique=True, nullable=False)
    description = Column(Text)
    cvss_score = Column(DECIMAL(3, 1))
    severity = Column(String(20))
    published_date = Column(TIMESTAMP)
    modified_date = Column(TIMESTAMP)
    # String arrays (PostgreSQL ARRAY columns).
    affected_products = Column(ARRAY(String))
    reference_urls = Column(ARRAY(String))
    created_at = Column(TIMESTAMP, default=datetime.utcnow)
    # ``default`` only fires on INSERT; ``onupdate`` refreshes the timestamp on
    # every ORM-issued UPDATE so the column reflects the last modification.
    updated_at = Column(TIMESTAMP, default=datetime.utcnow, onupdate=datetime.utcnow)
class SigmaRule(Base):
    """ORM model for one SIGMA rule stored in the ``sigma_rules`` table."""

    __tablename__ = "sigma_rules"

    # Surrogate primary key, generated in Python with uuid4.
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # CVE identifier the rule was generated for (plain string, no foreign key).
    cve_id = Column(String(20))
    rule_name = Column(String(255), nullable=False)
    # Full text of the rule.
    rule_content = Column(Text, nullable=False)
    detection_type = Column(String(50))
    log_source = Column(String(100))
    confidence_level = Column(String(20))
    auto_generated = Column(Boolean, default=True)
    created_at = Column(TIMESTAMP, default=datetime.utcnow)
    # ``default`` only fires on INSERT; ``onupdate`` refreshes the timestamp on
    # every ORM-issued UPDATE so the column reflects the last modification.
    updated_at = Column(TIMESTAMP, default=datetime.utcnow, onupdate=datetime.utcnow)
class RuleTemplate(Base):
    # ORM model for one SIGMA rule template stored in the ``rule_templates`` table.
    __tablename__ = "rule_templates"
    # Surrogate primary key, generated in Python with uuid4.
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # Template name; CVESigmaService matches substrings of it
    # ("Process Execution", "Network Connection", "File Modification").
    template_name = Column(String(255), nullable=False)
    # Template text; CVESigmaService fills it in with str.format().
    template_content = Column(Text, nullable=False)
    # NOTE(review): not read anywhere in this module — presumably product-name
    # patterns the template applies to; confirm against the seed data.
    applicable_product_patterns = Column(ARRAY(String))
    description = Column(Text)
    created_at = Column(TIMESTAMP, default=datetime.utcnow)
# Pydantic models
class CVEResponse(BaseModel):
    # API response schema for a CVE record.
    # ``id`` is a string: the endpoints convert the UUID primary key with str().
    id: str
    cve_id: str
    description: Optional[str] = None
    # The endpoints convert the stored DECIMAL value with float().
    cvss_score: Optional[float] = None
    severity: Optional[str] = None
    published_date: Optional[datetime] = None
    affected_products: Optional[List[str]] = None
    reference_urls: Optional[List[str]] = None

    class Config:
        # Pydantic setting that allows building the model from objects with attributes.
        from_attributes = True
class SigmaRuleResponse(BaseModel):
    # API response schema for a SIGMA rule.
    # ``id`` is a string: the endpoints convert the UUID primary key with str().
    id: str
    # Required here, although the ``sigma_rules.cve_id`` column does not declare nullable=False.
    cve_id: str
    rule_name: str
    rule_content: str
    detection_type: Optional[str] = None
    log_source: Optional[str] = None
    confidence_level: Optional[str] = None
    auto_generated: bool = True
    created_at: datetime

    class Config:
        # Pydantic setting that allows building the model from objects with attributes.
        from_attributes = True
# CVE and SIGMA Rule Generator Service
class CVESigmaService:
    """Fetch CVEs from the NVD API and derive SIGMA rules from them.

    The service works on the SQLAlchemy session passed to the constructor.
    ``fetch_recent_cves`` commits the CVE rows it adds; ``generate_sigma_rule``
    only adds the rule to the session and leaves the commit to the caller.
    """

    NVD_CVE_API_URL = "https://services.nvd.nist.gov/rest/json/cves/2.0"

    # Metric blocks consulted for a CVSS score. v3.1 stays first so CVEs that
    # carry it are scored exactly as before; the others are fallbacks.
    _CVSS_METRIC_KEYS = ("cvssMetricV31", "cvssMetricV40", "cvssMetricV30", "cvssMetricV2")

    def __init__(self, db: "Session"):
        """Store the session and read the optional NVD_API_KEY environment variable."""
        self.db = db
        self.nvd_api_key = os.getenv("NVD_API_KEY")

    async def fetch_recent_cves(self, days_back: int = 7):
        """Fetch CVEs published in the last ``days_back`` days and store the new ones.

        Returns:
            The list of newly added (and committed) ``CVE`` objects. On any
            error the session is rolled back, the error is printed and an
            empty list is returned.
        """
        end_date = datetime.utcnow()
        start_date = end_date - timedelta(days=days_back)
        # NOTE(review): only the first page (100 results) is requested and
        # ``startIndex`` is never sent, so larger result sets are truncated —
        # confirm whether pagination is needed.
        params = {
            "pubStartDate": self._format_nvd_date(start_date),
            "pubEndDate": self._format_nvd_date(end_date),
            "resultsPerPage": 100,
        }
        headers = {}
        if self.nvd_api_key:
            headers["apiKey"] = self.nvd_api_key

        try:
            # requests is blocking; run it in the default executor so the
            # event loop keeps serving other coroutines during the HTTP call.
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None,
                lambda: requests.get(
                    self.NVD_CVE_API_URL, params=params, headers=headers, timeout=30
                ),
            )
            response.raise_for_status()
            data = response.json()

            # Key by CVE id: drops entries without an id and repeated ids
            # (a repeated id would violate the unique constraint on commit).
            candidates = {}
            for vuln in data.get("vulnerabilities", []):
                cve_data = vuln.get("cve", {})
                cve_id = cve_data.get("id")
                if cve_id and cve_id not in candidates:
                    candidates[cve_id] = cve_data
            if not candidates:
                return []

            # One query for all ids instead of one query per CVE.
            existing_ids = {
                row[0]
                for row in self.db.query(CVE.cve_id)
                .filter(CVE.cve_id.in_(list(candidates)))
                .all()
            }

            new_cves = []
            for cve_id, cve_data in candidates.items():
                if cve_id in existing_ids:
                    continue
                cve_obj = self._build_cve(cve_id, cve_data)
                self.db.add(cve_obj)
                new_cves.append(cve_obj)
            self.db.commit()
            return new_cves
        except Exception as e:
            # Discard anything added before the failure so a later commit on
            # the same session cannot persist a half-processed batch.
            self.db.rollback()
            print(f"Error fetching CVEs: {str(e)}")
            return []

    @staticmethod
    def _format_nvd_date(value):
        """Format a datetime as ``YYYY-MM-DDTHH:MM:SS.mmmZ`` (millisecond precision)."""
        return value.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"

    @staticmethod
    def _parse_nvd_timestamp(value):
        """Parse an ISO-8601 timestamp; return None when it is missing or malformed."""
        if not value:
            return None
        try:
            return datetime.fromisoformat(value.replace("Z", "+00:00"))
        except ValueError:
            return None

    @staticmethod
    def _extract_description(cve_data):
        """Return the English description, else the first one, else an empty string."""
        descriptions = cve_data.get("descriptions") or []
        for entry in descriptions:
            if entry.get("lang") == "en":
                return entry.get("value", "")
        return descriptions[0].get("value", "") if descriptions else ""

    @classmethod
    def _extract_cvss(cls, cve_data):
        """Return ``(base_score, base_severity)`` from the first available metric block."""
        metrics = cve_data.get("metrics") or {}
        for key in cls._CVSS_METRIC_KEYS:
            entries = metrics.get(key)
            if entries:
                metric = entries[0]
                cvss_data = metric.get("cvssData", {})
                # CVSS v2 entries carry baseSeverity on the metric itself
                # rather than inside cvssData.
                severity = cvss_data.get("baseSeverity") or metric.get("baseSeverity")
                return cvss_data.get("baseScore"), severity
        return None, None

    def _build_cve(self, cve_id, cve_data):
        """Build an unsaved ``CVE`` object from one NVD ``cve`` JSON object."""
        cvss_score, severity = self._extract_cvss(cve_data)

        affected_products = []
        for config in cve_data.get("configurations") or []:
            for node in config.get("nodes", []):
                for cpe_match in node.get("cpeMatch", []):
                    if cpe_match.get("vulnerable"):
                        affected_products.append(cpe_match.get("criteria", ""))

        reference_urls = [ref.get("url", "") for ref in cve_data.get("references") or []]

        return CVE(
            cve_id=cve_id,
            description=self._extract_description(cve_data),
            cvss_score=cvss_score,
            severity=severity,
            published_date=self._parse_nvd_timestamp(cve_data.get("published")),
            modified_date=self._parse_nvd_timestamp(cve_data.get("lastModified")),
            affected_products=affected_products,
            reference_urls=reference_urls,
        )

    def generate_sigma_rule(self, cve: "CVE") -> Optional["SigmaRule"]:
        """Generate a SIGMA rule for ``cve`` and add it to the session.

        Returns:
            The new ``SigmaRule`` (not yet committed), or None when the CVE
            has no description, no template matches, or the template cannot
            be populated.
        """
        if not cve.description:
            return None

        description_lower = cve.description.lower()
        affected_products = [p.lower() for p in (cve.affected_products or [])]

        template = self._select_template(description_lower, affected_products)
        if not template:
            return None

        rule_content = self._populate_template(cve, template)
        if not rule_content:
            return None

        sigma_rule = SigmaRule(
            cve_id=cve.cve_id,
            # cve_id already has the form "CVE-<year>-<number>"; using it
            # directly avoids an IndexError on ids with fewer dash-separated parts.
            rule_name=f"{cve.cve_id} Detection",
            rule_content=rule_content,
            detection_type=self._determine_detection_type(description_lower),
            log_source=template.template_name.lower().replace(" ", "_"),
            confidence_level=self._calculate_confidence(cve),
            auto_generated=True,
        )
        self.db.add(sigma_rule)
        return sigma_rule

    def _select_template(self, description: str, affected_products: List[str]):
        """Pick a rule template from the lower-cased description and product list.

        Returns the first template whose name contains the chosen fragment, or
        None when no such template exists.
        """
        templates = self.db.query(RuleTemplate).all()

        def find(name_fragment):
            return next((t for t in templates if name_fragment in t.template_name), None)

        if any("windows" in p or "microsoft" in p for p in affected_products):
            if "process" in description or "execution" in description:
                return find("Process Execution")
            if "network" in description or "remote" in description:
                return find("Network Connection")
            if "file" in description or "write" in description:
                return find("File Modification")

        # Default to the process execution template.
        return find("Process Execution")

    def _populate_template(self, cve: "CVE", template: "RuleTemplate") -> Optional[str]:
        """Fill ``template.template_content`` with CVE-specific values.

        Returns the rendered rule text, or None when formatting fails (the
        error is printed).
        """
        try:
            suspicious_processes = self._extract_suspicious_indicators(cve.description, "process")
            suspicious_ports = self._extract_suspicious_indicators(cve.description, "port")
            file_patterns = self._extract_suspicious_indicators(cve.description, "file")

            is_high = cve.cvss_score is not None and cve.cvss_score >= 7.0
            level = "high" if is_high else "medium"

            description = cve.description
            if len(description) > 200:
                description = description[:200] + "..."

            return template.template_content.format(
                # cve_id already starts with "CVE-"; prefixing it again
                # produced titles such as "CVE-CVE-2024-1234 ...".
                title=f"{cve.cve_id} Exploitation Attempt",
                description=description,
                rule_id=str(uuid.uuid4()),
                date=datetime.utcnow().strftime("%Y/%m/%d"),
                cve_url=f"https://nvd.nist.gov/vuln/detail/{cve.cve_id}",
                cve_id=cve.cve_id.lower(),
                suspicious_processes=suspicious_processes or ["suspicious.exe", "malware.exe"],
                suspicious_ports=suspicious_ports or [4444, 8080, 9999],
                file_patterns=file_patterns or ["temp", "malware", "exploit"],
                level=level,
            )
        except Exception as e:
            # Deliberately best-effort: a template that cannot be rendered
            # must not abort processing of the remaining CVEs.
            print(f"Error populating template: {str(e)}")
            return None

    @staticmethod
    def _dedupe(items):
        """Return ``items`` without repeats, keeping first occurrences (strings compare case-insensitively)."""
        seen = set()
        unique = []
        for item in items:
            key = item.casefold() if isinstance(item, str) else item
            if key not in seen:
                seen.add(key)
                unique.append(item)
        return unique

    def _extract_suspicious_indicators(self, description: str, indicator_type: str) -> Optional[List]:
        """Extract indicators of ``indicator_type`` from a CVE description.

        Args:
            description: Free-text CVE description.
            indicator_type: ``"process"`` (up to 5 ``*.exe`` names), ``"port"``
                (up to 3 port numbers in 1..65535) or ``"file"`` (up to 5
                ``name.ext`` tokens with a 3-4 character extension).

        Returns:
            A list without repeated entries, or None when nothing is found or
            the type is unknown.
        """
        if not description:
            return None

        if indicator_type == "process":
            # \b keeps "cmd.execute" from being reported as "cmd.exe".
            names = re.findall(r'(\w+\.exe)\b', description, re.IGNORECASE)
            return self._dedupe(names)[:5] or None

        if indicator_type == "port":
            numbers = [int(p) for p in re.findall(r'port\s+(\d+)', description, re.IGNORECASE)]
            valid_ports = [p for p in self._dedupe(numbers) if 0 < p <= 65535]
            return valid_ports[:3] or None

        if indicator_type == "file":
            files = re.findall(r'(\w+\.\w{3,4})', description, re.IGNORECASE)
            return self._dedupe(files)[:5] or None

        return None

    def _determine_detection_type(self, description: str) -> str:
        """Classify a lower-cased description as network, process, file or general."""
        if "remote" in description or "network" in description:
            return "network"
        if "process" in description or "execution" in description:
            return "process"
        if "file" in description or "filesystem" in description:
            return "file"
        return "general"

    def _calculate_confidence(self, cve: "CVE") -> str:
        """Map the CVSS score to a confidence level: >= 9.0 high, >= 7.0 medium, else low."""
        score = cve.cvss_score
        if score is None:
            return "low"
        if score >= 9.0:
            return "high"
        if score >= 7.0:
            return "medium"
        return "low"
# Dependency
def get_db():
    """FastAPI dependency: yield a fresh database session and close it afterwards."""
    session = SessionLocal()
    try:
        yield session
    finally:
        # Runs whether or not the consumer raised.
        session.close()
# Background task to fetch CVEs and generate rules
async def background_cve_fetch():
    """Periodically fetch new CVEs and generate SIGMA rules for them, forever.

    Each cycle opens its own session and always closes it, also on failure and
    before any sleep. After a successful cycle the loop waits one hour. After
    a failed cycle it waits 5 minutes, or 30 minutes once ``max_retries``
    consecutive failures have occurred (the counter then restarts).
    """
    retry_count = 0
    max_retries = 3
    # Start with a long window to backfill; narrowed to 7 days after the first
    # cycle that actually stored CVEs. fetch_recent_cves returns [] on errors
    # too, so an empty result is not treated as proof of a successful fetch.
    days_back = 30

    while True:
        failed = False
        db = None
        try:
            db = SessionLocal()
            service = CVESigmaService(db)
            current_time = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
            print(f"[{current_time}] Starting CVE fetch cycle...")

            new_cves = await service.fetch_recent_cves(days_back=days_back)
            if new_cves:
                print(f"Found {len(new_cves)} new CVEs, generating SIGMA rules...")
                rules_generated = 0
                for cve in new_cves:
                    try:
                        sigma_rule = service.generate_sigma_rule(cve)
                        if sigma_rule:
                            rules_generated += 1
                            print(f"Generated SIGMA rule for {cve.cve_id}")
                        else:
                            print(f"Could not generate rule for {cve.cve_id} - insufficient data")
                    except Exception as e:
                        # One bad CVE must not stop rule generation for the rest.
                        print(f"Error generating rule for {cve.cve_id}: {str(e)}")
                db.commit()
                print(f"Successfully generated {rules_generated} SIGMA rules")
                if days_back != 7:
                    print("Switching to 7-day lookback for future runs...")
                    days_back = 7
            else:
                print("No new CVEs found in this cycle")

            # Any cycle that completes without raising ends a failure streak.
            retry_count = 0
        except Exception as e:
            failed = True
            retry_count += 1
            print(f"Background task error (attempt {retry_count}/{max_retries}): {str(e)}")
        finally:
            # Release the connection before sleeping, on success and failure alike.
            if db is not None:
                db.close()

        if failed:
            if retry_count >= max_retries:
                print("Max retries reached, waiting longer before next attempt...")
                await asyncio.sleep(1800)  # Wait 30 minutes on repeated failures
                retry_count = 0
            else:
                await asyncio.sleep(300)  # Wait 5 minutes before retry
            continue

        # Wait 1 hour before the next fetch.
        wait_time = 3600
        print(f"Next CVE fetch in {wait_time//60} minutes...")
        await asyncio.sleep(wait_time)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the background CVE fetch task for the lifetime of the application.

    The task is started before the application begins serving. On shutdown,
    including when the application body raises, it is cancelled and awaited
    so it has finished unwinding before the event loop closes.
    """
    task = asyncio.create_task(background_cve_fetch())
    try:
        yield
    finally:
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            # Expected outcome of cancelling the task.
            pass
        except Exception as e:
            # The task had already died with an error; report it instead of
            # letting it break shutdown.
            print(f"Background task ended with error: {str(e)}")
# FastAPI app
# The lifespan handler starts the background CVE fetch task at startup and
# cancels it at shutdown.
app = FastAPI(title="CVE-SIGMA Auto Generator", lifespan=lifespan)
# CORS: only http://localhost:3000 is allowed as an origin (presumably the
# local frontend — confirm); for it, credentials and all methods and headers
# are permitted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def _cve_to_response(cve) -> CVEResponse:
    """Convert a ``CVE`` row into its API schema (UUID -> str, DECIMAL -> float)."""
    return CVEResponse(
        id=str(cve.id),
        cve_id=cve.cve_id,
        description=cve.description,
        # ``is not None`` rather than truthiness: a stored score of 0.0 is a
        # real value and must not be reported as missing.
        cvss_score=float(cve.cvss_score) if cve.cvss_score is not None else None,
        severity=cve.severity,
        published_date=cve.published_date,
        affected_products=cve.affected_products,
        reference_urls=cve.reference_urls,
    )


@app.get("/api/cves", response_model=List[CVEResponse])
async def get_cves(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
    """List CVEs, newest ``published_date`` first, with offset/limit paging.

    Raises:
        HTTPException: 422 when ``skip`` or ``limit`` is negative (a negative
            OFFSET/LIMIT would otherwise surface as a database error).
    """
    if skip < 0 or limit < 0:
        raise HTTPException(status_code=422, detail="skip and limit must be non-negative")
    cves = db.query(CVE).order_by(CVE.published_date.desc()).offset(skip).limit(limit).all()
    return [_cve_to_response(cve) for cve in cves]


@app.get("/api/cves/{cve_id}", response_model=CVEResponse)
async def get_cve(cve_id: str, db: Session = Depends(get_db)):
    """Return the CVE with the given ``cve_id``.

    Raises:
        HTTPException: 404 when no such CVE is stored.
    """
    cve = db.query(CVE).filter(CVE.cve_id == cve_id).first()
    if not cve:
        raise HTTPException(status_code=404, detail="CVE not found")
    return _cve_to_response(cve)
def _sigma_rule_to_response(rule) -> SigmaRuleResponse:
    """Convert a ``SigmaRule`` row into its API schema (UUID -> str)."""
    return SigmaRuleResponse(
        id=str(rule.id),
        cve_id=rule.cve_id,
        rule_name=rule.rule_name,
        rule_content=rule.rule_content,
        detection_type=rule.detection_type,
        log_source=rule.log_source,
        confidence_level=rule.confidence_level,
        auto_generated=rule.auto_generated,
        created_at=rule.created_at,
    )


@app.get("/api/sigma-rules", response_model=List[SigmaRuleResponse])
async def get_sigma_rules(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
    """List SIGMA rules, newest ``created_at`` first, with offset/limit paging.

    Raises:
        HTTPException: 422 when ``skip`` or ``limit`` is negative (a negative
            OFFSET/LIMIT would otherwise surface as a database error).
    """
    if skip < 0 or limit < 0:
        raise HTTPException(status_code=422, detail="skip and limit must be non-negative")
    rules = db.query(SigmaRule).order_by(SigmaRule.created_at.desc()).offset(skip).limit(limit).all()
    return [_sigma_rule_to_response(rule) for rule in rules]


@app.get("/api/sigma-rules/{cve_id}", response_model=List[SigmaRuleResponse])
async def get_sigma_rules_by_cve(cve_id: str, db: Session = Depends(get_db)):
    """Return every SIGMA rule stored for ``cve_id`` (an empty list when there is none)."""
    rules = db.query(SigmaRule).filter(SigmaRule.cve_id == cve_id).all()
    return [_sigma_rule_to_response(rule) for rule in rules]
@app.post("/api/fetch-cves")
async def manual_fetch_cves(background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Schedule a one-off CVE fetch (30-day lookback) plus rule generation.

    The work runs as a background task after the response has been sent, so
    the response only confirms that the fetch was started.

    ``db`` is kept in the signature for compatibility but is not used by the
    task: the request-scoped session belongs to the request and, depending on
    the FastAPI version, is already closed when background tasks run. The
    task therefore opens and closes its own session.
    """
    async def fetch_task():
        task_db = SessionLocal()
        try:
            service = CVESigmaService(task_db)
            print("Manual CVE fetch initiated...")
            # Use 30 days for manual fetch to get more results
            new_cves = await service.fetch_recent_cves(days_back=30)
            rules_generated = 0
            for cve in new_cves:
                sigma_rule = service.generate_sigma_rule(cve)
                if sigma_rule:
                    rules_generated += 1
            task_db.commit()
            print(f"Manual fetch complete: {len(new_cves)} CVEs, {rules_generated} rules generated")
        except Exception as e:
            # Drop any rules added before the failure instead of leaving them pending.
            task_db.rollback()
            print(f"Manual fetch error: {str(e)}")
            import traceback
            traceback.print_exc()
        finally:
            task_db.close()

    background_tasks.add_task(fetch_task)
    return {"message": "CVE fetch initiated (30-day lookback)", "status": "started"}
@app.get("/api/test-nvd")
async def test_nvd_connection():
    """Test endpoint to check NVD API connectivity.

    Sends a small request (5 results, last-modified within the past 30 days)
    and reports status code, response headers and result counts. On a 404 it
    retries once without the date filters. Never raises: any exception is
    reported as ``{"status": "error", ...}``.
    """
    try:
        end_date = datetime.utcnow()
        start_date = end_date - timedelta(days=30)
        url = "https://services.nvd.nist.gov/rest/json/cves/2.0/"
        params = {
            "lastModStartDate": start_date.strftime("%Y-%m-%dT%H:%M:%S.000+00:00"),
            "lastModEndDate": end_date.strftime("%Y-%m-%dT%H:%M:%S.000+00:00"),
            "resultsPerPage": 5,
            "startIndex": 0
        }
        headers = {
            "User-Agent": "CVE-SIGMA-Generator/1.0",
            "Accept": "application/json"
        }
        nvd_api_key = os.getenv("NVD_API_KEY")
        if nvd_api_key:
            headers["apiKey"] = nvd_api_key

        # Never write the API key to the logs; show a placeholder instead.
        logged_headers = {
            name: ("***" if name == "apiKey" else value) for name, value in headers.items()
        }
        print(f"Testing NVD API with URL: {url}")
        print(f"Test params: {params}")
        print(f"Test headers: {logged_headers}")

        # requests is blocking; run it in the default executor so the event
        # loop is not stalled for up to the 15-second timeout.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: requests.get(url, params=params, headers=headers, timeout=15),
        )

        result = {
            "status": "success" if response.status_code == 200 else "error",
            "status_code": response.status_code,
            "has_api_key": bool(nvd_api_key),
            "request_url": f"{url}?{requests.compat.urlencode(params)}",
            "response_headers": dict(response.headers)
        }

        if response.status_code == 200:
            data = response.json()
            result.update({
                "total_results": data.get("totalResults", 0),
                "results_per_page": data.get("resultsPerPage", 0),
                "vulnerabilities_returned": len(data.get("vulnerabilities", [])),
                "message": "NVD API is accessible and returning data"
            })
        else:
            result.update({
                "error_message": response.text[:200],
                "message": f"NVD API returned {response.status_code}"
            })
            # Try fallback without date filters if we get 404
            if response.status_code == 404:
                print("Trying fallback without date filters...")
                fallback_params = {
                    "resultsPerPage": 5,
                    "startIndex": 0
                }
                fallback_response = await loop.run_in_executor(
                    None,
                    lambda: requests.get(url, params=fallback_params, headers=headers, timeout=15),
                )
                result["fallback_status_code"] = fallback_response.status_code
                if fallback_response.status_code == 200:
                    fallback_data = fallback_response.json()
                    result.update({
                        "fallback_success": True,
                        "fallback_total_results": fallback_data.get("totalResults", 0),
                        "message": "NVD API works without date filters"
                    })
        return result
    except Exception as e:
        print(f"NVD API test error: {str(e)}")
        return {
            "status": "error",
            "message": f"Failed to connect to NVD API: {str(e)}"
        }
@app.get("/api/stats")
async def get_stats(db: Session = Depends(get_db)):
    """Return row counts: all CVEs, all SIGMA rules, and CVEs published in the last 7 days."""
    cve_count = db.query(CVE).count()
    rule_count = db.query(SigmaRule).count()
    # Counted last, as before; the cutoff is taken at query time.
    published_last_week = (
        db.query(CVE)
        .filter(CVE.published_date >= datetime.utcnow() - timedelta(days=7))
        .count()
    )
    stats = {}
    stats["total_cves"] = cve_count
    stats["total_sigma_rules"] = rule_count
    stats["recent_cves_7_days"] = published_last_week
    return stats
# Script entry point: serve the app with uvicorn when the file is executed directly.
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the file is run as a script.
    import uvicorn
    # Listens on all interfaces (0.0.0.0), port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)