# Source listing metadata: 574 lines, 22 KiB, Python
from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse
|
|
from sqlalchemy import create_engine, Column, String, Text, DECIMAL, TIMESTAMP, Boolean, ARRAY
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
from sqlalchemy.orm import sessionmaker, Session
|
|
from sqlalchemy.dialects.postgresql import UUID
|
|
import uuid
|
|
from datetime import datetime, timedelta
|
|
import requests
|
|
import json
|
|
import re
|
|
import os
|
|
from typing import List, Optional
|
|
from pydantic import BaseModel
|
|
import asyncio
|
|
from contextlib import asynccontextmanager
|
|
|
|
# Database setup
# The connection string is read from the DATABASE_URL environment variable;
# the fallback below points at a PostgreSQL instance on localhost.
DATABASE_URL = os.getenv(
    "DATABASE_URL",
    "postgresql://cve_user:cve_password@localhost:5432/cve_sigma_db",
)
engine = create_engine(DATABASE_URL)
# Session factory bound to the engine; flushing and committing are explicit.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
# Declarative base class shared by the ORM models in this module.
Base = declarative_base()
|
|
|
|
# Database Models
|
|
class CVE(Base):
    """ORM model for a CVE record stored in the ``cves`` table."""

    __tablename__ = "cves"

    # Surrogate primary key, generated client-side with uuid4.
    id = Column(UUID(as_uuid=True), default=uuid.uuid4, primary_key=True)
    # CVE identifier string; required and unique across the table.
    cve_id = Column(String(20), nullable=False, unique=True)
    description = Column(Text)
    # Numeric score with one decimal place, plus its textual severity.
    cvss_score = Column(DECIMAL(3, 1))
    severity = Column(String(20))
    published_date = Column(TIMESTAMP)
    modified_date = Column(TIMESTAMP)
    # String arrays (PostgreSQL ARRAY columns).
    affected_products = Column(ARRAY(String))
    reference_urls = Column(ARRAY(String))
    # Both timestamps default to the insertion time (UTC, naive).
    created_at = Column(TIMESTAMP, default=datetime.utcnow)
    updated_at = Column(TIMESTAMP, default=datetime.utcnow)
|
|
|
|
class SigmaRule(Base):
    """ORM model for a SIGMA rule stored in the ``sigma_rules`` table."""

    __tablename__ = "sigma_rules"

    # Surrogate primary key, generated client-side with uuid4.
    id = Column(UUID(as_uuid=True), default=uuid.uuid4, primary_key=True)
    # Identifier of the CVE the rule refers to (plain string, no foreign key).
    cve_id = Column(String(20))
    rule_name = Column(String(255), nullable=False)
    rule_content = Column(Text, nullable=False)
    detection_type = Column(String(50))
    log_source = Column(String(100))
    confidence_level = Column(String(20))
    # Rows are flagged as auto-generated unless stated otherwise.
    auto_generated = Column(Boolean, default=True)
    # Both timestamps default to the insertion time (UTC, naive).
    created_at = Column(TIMESTAMP, default=datetime.utcnow)
    updated_at = Column(TIMESTAMP, default=datetime.utcnow)
|
|
|
|
class RuleTemplate(Base):
    """ORM model for a rule template stored in the ``rule_templates`` table."""

    __tablename__ = "rule_templates"

    # Surrogate primary key, generated client-side with uuid4.
    id = Column(UUID(as_uuid=True), default=uuid.uuid4, primary_key=True)
    template_name = Column(String(255), nullable=False)
    # Template text; required.
    template_content = Column(Text, nullable=False)
    applicable_product_patterns = Column(ARRAY(String))
    description = Column(Text)
    # Defaults to the insertion time (UTC, naive).
    created_at = Column(TIMESTAMP, default=datetime.utcnow)
|
|
|
|
# Pydantic models
|
|
class CVEResponse(BaseModel):
    """Response schema describing a single CVE record.

    Only ``id`` and ``cve_id`` are required; every other field defaults to
    ``None`` when it is not supplied.
    """

    # Required identifiers.
    id: str
    cve_id: str

    # Optional details.
    description: Optional[str] = None
    cvss_score: Optional[float] = None
    severity: Optional[str] = None
    published_date: Optional[datetime] = None
    affected_products: Optional[List[str]] = None
    reference_urls: Optional[List[str]] = None

    class Config:
        # Permit building the model from attribute-bearing objects.
        from_attributes = True
|
|
|
|
class SigmaRuleResponse(BaseModel):
    """Response schema describing a single SIGMA rule."""

    # Required identification and content.
    id: str
    cve_id: str
    rule_name: str
    rule_content: str

    # Optional classification metadata.
    detection_type: Optional[str] = None
    log_source: Optional[str] = None
    confidence_level: Optional[str] = None

    auto_generated: bool = True
    created_at: datetime

    class Config:
        # Permit building the model from attribute-bearing objects.
        from_attributes = True
|
|
|
|
# CVE and SIGMA Rule Generator Service
|
|
class CVESigmaService:
    """Fetch CVE records from the NVD API and derive SIGMA rules from them.

    The service works on a caller-supplied SQLAlchemy session. New ``CVE``
    rows are committed by :meth:`fetch_recent_cves`; ``SigmaRule`` rows
    created by :meth:`generate_sigma_rule` are only added to the session, so
    the caller is responsible for committing them.
    """

    _NVD_CVE_URL = "https://services.nvd.nist.gov/rest/json/cves/2.0"
    # CVSS metric blocks to try, most preferred first.
    _CVSS_METRIC_KEYS = ("cvssMetricV31", "cvssMetricV30", "cvssMetricV2")
    # Patterns used to mine indicators out of free-text descriptions.
    _EXE_RE = re.compile(r'(\w+\.exe)', re.IGNORECASE)
    _PORT_RE = re.compile(r'port\s+(\d+)', re.IGNORECASE)
    _FILE_RE = re.compile(r'(\w+\.\w{3,4})', re.IGNORECASE)

    def __init__(self, db: "Session"):
        """Store the session and read the optional ``NVD_API_KEY`` variable."""
        self.db = db
        self.nvd_api_key = os.getenv("NVD_API_KEY")

    async def fetch_recent_cves(self, days_back: int = 7):
        """Fetch recent CVEs from NVD API and store the ones not yet known.

        Only one page of up to 100 results is requested; there is no
        pagination.

        Args:
            days_back: Size of the publication-date window, in days, ending
                at the current UTC time.

        Returns:
            The list of newly created and committed ``CVE`` objects. Any
            error is logged, the session is rolled back, and an empty list
            is returned; this method does not raise.
        """
        end_date = datetime.utcnow()
        start_date = end_date - timedelta(days=days_back)

        params = {
            "pubStartDate": start_date.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z",
            "pubEndDate": end_date.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z",
            "resultsPerPage": 100,
        }

        headers = {}
        if self.nvd_api_key:
            headers["apiKey"] = self.nvd_api_key

        try:
            # requests is blocking; run it in the default executor so the
            # event loop keeps serving other coroutines during the call.
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None,
                lambda: requests.get(
                    self._NVD_CVE_URL, params=params, headers=headers, timeout=30
                ),
            )
            response.raise_for_status()
            data = response.json()

            entries = [
                vuln.get("cve", {}) for vuln in data.get("vulnerabilities", [])
            ]

            # Look up all already-stored ids with a single query.
            candidate_ids = {
                entry.get("id") for entry in entries if entry.get("id")
            }
            known_ids = set()
            if candidate_ids:
                rows = (
                    self.db.query(CVE.cve_id)
                    .filter(CVE.cve_id.in_(list(candidate_ids)))
                    .all()
                )
                known_ids = {row[0] for row in rows}

            new_cves = []
            for cve_data in entries:
                cve_id = cve_data.get("id")
                # Skip entries without an id (the column is NOT NULL) and
                # anything already stored or already seen in this batch.
                if not cve_id or cve_id in known_ids:
                    continue
                known_ids.add(cve_id)

                cve_obj = self._build_cve(cve_data)
                self.db.add(cve_obj)
                new_cves.append(cve_obj)

            self.db.commit()
            return new_cves

        except Exception as e:
            print(f"Error fetching CVEs: {str(e)}")
            # Discard pending rows so the session stays usable; keep the
            # "never raises" contract even if the rollback itself fails.
            try:
                self.db.rollback()
            except Exception as rollback_error:
                print(f"Error rolling back session: {str(rollback_error)}")
            return []

    def _build_cve(self, cve_data: dict) -> "CVE":
        """Translate one NVD ``cve`` JSON object into an unsaved ``CVE`` row."""
        description = self._pick_description(cve_data.get("descriptions") or [])
        cvss_score, severity = self._extract_cvss(cve_data.get("metrics") or {})

        # Collect the criteria string of every match flagged as vulnerable.
        affected_products = []
        for config in cve_data.get("configurations") or []:
            for node in config.get("nodes", []):
                for cpe_match in node.get("cpeMatch", []):
                    if cpe_match.get("vulnerable"):
                        affected_products.append(cpe_match.get("criteria", ""))

        reference_urls = [
            ref.get("url", "") for ref in cve_data.get("references") or []
        ]

        return CVE(
            cve_id=cve_data.get("id"),
            description=description,
            cvss_score=cvss_score,
            severity=severity,
            published_date=self._parse_nvd_timestamp(cve_data.get("published")),
            modified_date=self._parse_nvd_timestamp(cve_data.get("lastModified")),
            affected_products=affected_products,
            reference_urls=reference_urls,
        )

    @staticmethod
    def _pick_description(descriptions: list) -> str:
        """Return the English description text, else the first entry's text."""
        if not descriptions:
            return ""
        for entry in descriptions:
            if entry.get("lang") == "en":
                return entry.get("value", "")
        return descriptions[0].get("value", "")

    @classmethod
    def _extract_cvss(cls, metrics: dict):
        """Return ``(base_score, severity)`` from the first available metric.

        Severity is read from ``cvssData`` and, failing that, from the metric
        entry itself. Returns ``(None, None)`` when no known metric is present.
        """
        for key in cls._CVSS_METRIC_KEYS:
            metric_entries = metrics.get(key)
            if not metric_entries:
                continue
            metric = metric_entries[0]
            cvss_data = metric.get("cvssData", {})
            severity = cvss_data.get("baseSeverity") or metric.get("baseSeverity")
            return cvss_data.get("baseScore"), severity
        return None, None

    @staticmethod
    def _parse_nvd_timestamp(value) -> Optional[datetime]:
        """Parse an ISO-8601 timestamp; return None if missing or invalid."""
        if not value:
            return None
        try:
            return datetime.fromisoformat(value.replace("Z", "+00:00"))
        except ValueError:
            return None

    def generate_sigma_rule(self, cve: "CVE") -> Optional["SigmaRule"]:
        """Generate SIGMA rule based on CVE data.

        The rule is added to the session but not committed.

        Returns:
            The new ``SigmaRule``, or None when the CVE has no description,
            no template can be selected, or the template cannot be populated.
        """
        if not cve.description:
            return None

        # Analyze CVE to determine appropriate template
        description_lower = cve.description.lower()
        affected_products = [p.lower() for p in (cve.affected_products or [])]

        template = self._select_template(description_lower, affected_products)
        if not template:
            return None

        # Generate rule content
        rule_content = self._populate_template(cve, template)
        if not rule_content:
            return None

        # Rebuild the name from the id parts; fall back to the raw id when it
        # does not have the expected three dash-separated parts.
        id_parts = cve.cve_id.split("-")
        if len(id_parts) >= 3:
            rule_name = f"CVE-{id_parts[1]}-{id_parts[2]} Detection"
        else:
            rule_name = f"{cve.cve_id} Detection"

        sigma_rule = SigmaRule(
            cve_id=cve.cve_id,
            rule_name=rule_name,
            rule_content=rule_content,
            detection_type=self._determine_detection_type(description_lower),
            log_source=template.template_name.lower().replace(" ", "_"),
            confidence_level=self._calculate_confidence(cve),
            auto_generated=True,
        )

        self.db.add(sigma_rule)
        return sigma_rule

    def _select_template(self, description: str, affected_products: List[str]):
        """Select appropriate SIGMA rule template.

        Args:
            description: Lower-cased CVE description.
            affected_products: Lower-cased affected product strings.

        Returns:
            The first stored template whose name contains the chosen label,
            or None when no such template exists.
        """
        templates = self.db.query(RuleTemplate).all()

        def find(label: str):
            return next((t for t in templates if label in t.template_name), None)

        # Simple template selection logic
        if any("windows" in p or "microsoft" in p for p in affected_products):
            if "process" in description or "execution" in description:
                return find("Process Execution")
            if "network" in description or "remote" in description:
                return find("Network Connection")
            if "file" in description or "write" in description:
                return find("File Modification")

        # Default to process execution template
        return find("Process Execution")

    def _populate_template(self, cve: "CVE", template: "RuleTemplate") -> Optional[str]:
        """Populate template with CVE-specific data.

        Indicators found in the description are used where available;
        otherwise fixed placeholder values are substituted.

        Returns:
            The formatted rule text, or None when formatting fails (the error
            is logged).
        """
        try:
            # Extract suspicious indicators from description
            suspicious_processes = self._extract_suspicious_indicators(cve.description, "process")
            suspicious_ports = self._extract_suspicious_indicators(cve.description, "port")
            file_patterns = self._extract_suspicious_indicators(cve.description, "file")

            # Determine severity level
            score = cve.cvss_score
            level = "high" if score is not None and score >= 7.0 else "medium"

            description = cve.description
            if len(description) > 200:
                description = description[:200] + "..."

            return template.template_content.format(
                # cve_id already carries the "CVE-" prefix; do not add another.
                title=f"{cve.cve_id} Exploitation Attempt",
                description=description,
                rule_id=str(uuid.uuid4()),
                date=datetime.utcnow().strftime("%Y/%m/%d"),
                cve_url=f"https://nvd.nist.gov/vuln/detail/{cve.cve_id}",
                cve_id=cve.cve_id.lower(),
                suspicious_processes=suspicious_processes or ["suspicious.exe", "malware.exe"],
                suspicious_ports=suspicious_ports or [4444, 8080, 9999],
                file_patterns=file_patterns or ["temp", "malware", "exploit"],
                level=level,
            )

        except Exception as e:
            print(f"Error populating template: {str(e)}")
            return None

    @staticmethod
    def _dedupe(items) -> list:
        """Return *items* as a list without duplicates, keeping first-seen order."""
        return list(dict.fromkeys(items))

    def _extract_suspicious_indicators(self, description: str, indicator_type: str) -> Optional[List]:
        """Extract suspicious indicators from CVE description.

        Args:
            description: Free-text CVE description.
            indicator_type: One of ``"process"``, ``"port"`` or ``"file"``.

        Returns:
            Up to five executable names, up to three port numbers within
            1-65535, or up to five file-like names, without duplicates.
            None when nothing matches or the type is unknown.
        """
        if not description:
            return None

        if indicator_type == "process":
            # Look for executable names or process patterns
            executables = self._dedupe(self._EXE_RE.findall(description))
            return executables[:5] or None

        if indicator_type == "port":
            # Look for port numbers, ignoring values outside the valid range
            ports = self._dedupe(int(p) for p in self._PORT_RE.findall(description))
            valid_ports = [p for p in ports if 0 < p <= 65535]
            return valid_ports[:3] or None

        if indicator_type == "file":
            # Look for file extensions or paths
            files = self._dedupe(self._FILE_RE.findall(description))
            return files[:5] or None

        return None

    def _determine_detection_type(self, description: str) -> str:
        """Determine detection type based on CVE description.

        Returns ``"network"``, ``"process"``, ``"file"`` or ``"general"``;
        the first matching keyword group wins, in that order.
        """
        if "remote" in description or "network" in description:
            return "network"
        if "process" in description or "execution" in description:
            return "process"
        # "filesystem" contains "file", so one substring test covers both.
        if "file" in description:
            return "file"
        return "general"

    def _calculate_confidence(self, cve: "CVE") -> str:
        """Calculate confidence level for the generated rule.

        Returns ``"high"`` for scores of 9.0 and above, ``"medium"`` for 7.0
        and above, and ``"low"`` otherwise (including a missing score).
        """
        score = cve.cvss_score
        if not score:
            return "low"
        if score >= 9.0:
            return "high"
        if score >= 7.0:
            return "medium"
        return "low"
|
|
|
|
# Dependency
|
|
def get_db():
    """Yield a new database session and close it when the caller is done.

    The session is closed in a ``finally`` clause, so it is released even if
    the consumer raises while holding it.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
|
|
|
|
# Background task to fetch CVEs and generate rules
|
|
async def background_cve_fetch():
    """Periodically fetch CVEs and generate SIGMA rules for the new ones.

    Runs forever. Each cycle opens its own session, fetches CVEs, generates
    one rule per new CVE and commits. The session is always closed at the end
    of the cycle, including when the cycle fails.

    The first cycles look back 30 days. Once a cycle finds nothing new while
    no error is outstanding, later cycles look back 7 days.

    After a failed cycle the task waits 5 minutes before retrying, or 30
    minutes once ``max_retries`` failures have accumulated. After a completed
    cycle it waits 1 hour, or 30 minutes if the failure counter is non-zero.
    """
    retry_count = 0
    max_retries = 3
    lookback_days = 30

    while True:
        cycle_error = None
        try:
            db = SessionLocal()
            try:
                service = CVESigmaService(db)
                current_time = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
                print(f"[{current_time}] Starting CVE fetch cycle...")

                new_cves = await service.fetch_recent_cves(days_back=lookback_days)

                if new_cves:
                    print(f"Found {len(new_cves)} new CVEs, generating SIGMA rules...")
                    rules_generated = 0
                    for cve in new_cves:
                        # One bad CVE must not abort the rest of the batch.
                        try:
                            sigma_rule = service.generate_sigma_rule(cve)
                            if sigma_rule:
                                rules_generated += 1
                                print(f"Generated SIGMA rule for {cve.cve_id}")
                            else:
                                print(f"Could not generate rule for {cve.cve_id} - insufficient data")
                        except Exception as e:
                            print(f"Error generating rule for {cve.cve_id}: {str(e)}")

                    db.commit()
                    print(f"Successfully generated {rules_generated} SIGMA rules")
                    retry_count = 0  # Reset retry count on success
                else:
                    print("No new CVEs found in this cycle")
                    # Nothing new in the long window: use the short window
                    # from now on. Announce the switch only when it happens.
                    if retry_count == 0 and lookback_days != 7:
                        print("Switching to 7-day lookback for future runs...")
                        lookback_days = 7
            finally:
                # Release the session before any sleep, on every path.
                db.close()

        except Exception as e:
            cycle_error = e

        if cycle_error is not None:
            retry_count += 1
            print(f"Background task error (attempt {retry_count}/{max_retries}): {str(cycle_error)}")
            if retry_count >= max_retries:
                print("Max retries reached, waiting longer before next attempt...")
                await asyncio.sleep(1800)  # Wait 30 minutes on repeated failures
                retry_count = 0
            else:
                await asyncio.sleep(300)  # Wait 5 minutes before retry
            continue

        # Wait 1 hour before next fetch (or 30 minutes if there were errors)
        wait_time = 3600 if retry_count == 0 else 1800
        print(f"Next CVE fetch in {wait_time//60} minutes...")
        await asyncio.sleep(wait_time)
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the background CVE fetch task for the lifetime of the application.

    The task is started before the application begins serving. On exit,
    including when the wrapped body raises, the task is cancelled and awaited
    so its cleanup code finishes before shutdown completes.
    """
    # Start background task
    task = asyncio.create_task(background_cve_fetch())
    try:
        yield
    finally:
        # Clean up
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            # Expected outcome of cancelling the task.
            pass
|
|
|
|
# FastAPI app
# The lifespan handler is attached here so it runs with the application.
app = FastAPI(lifespan=lifespan, title="CVE-SIGMA Auto Generator")

# CORS: a single allowed origin, with credentials, any method, any header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
)
|
|
|
|
@app.get("/api/cves", response_model=List[CVEResponse])
async def get_cves(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
    """Return a page of stored CVEs, most recently published first.

    Args:
        skip: Number of rows to skip.
        limit: Maximum number of rows to return.
        db: Request-scoped database session.
    """
    cves = (
        db.query(CVE)
        .order_by(CVE.published_date.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )
    return [
        CVEResponse(
            # Convert UUID to string for each CVE
            id=str(cve.id),
            cve_id=cve.cve_id,
            description=cve.description,
            # Test against None so a score of 0.0 is reported as 0.0.
            cvss_score=float(cve.cvss_score) if cve.cvss_score is not None else None,
            severity=cve.severity,
            published_date=cve.published_date,
            affected_products=cve.affected_products,
            reference_urls=cve.reference_urls,
        )
        for cve in cves
    ]
|
|
|
|
@app.get("/api/cves/{cve_id}", response_model=CVEResponse)
async def get_cve(cve_id: str, db: Session = Depends(get_db)):
    """Return the stored CVE whose ``cve_id`` matches the path parameter.

    Raises:
        HTTPException: 404 when no such CVE is stored.
    """
    cve = db.query(CVE).filter(CVE.cve_id == cve_id).first()
    if not cve:
        raise HTTPException(status_code=404, detail="CVE not found")

    return CVEResponse(
        id=str(cve.id),  # UUID converted to string
        cve_id=cve.cve_id,
        description=cve.description,
        # Test against None so a score of 0.0 is reported as 0.0.
        cvss_score=float(cve.cvss_score) if cve.cvss_score is not None else None,
        severity=cve.severity,
        published_date=cve.published_date,
        affected_products=cve.affected_products,
        reference_urls=cve.reference_urls,
    )
|
|
|
|
@app.get("/api/sigma-rules", response_model=List[SigmaRuleResponse])
async def get_sigma_rules(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
    """Return a page of stored SIGMA rules, newest first by creation time.

    Args:
        skip: Number of rows to skip.
        limit: Maximum number of rows to return.
        db: Request-scoped database session.
    """
    newest_first = db.query(SigmaRule).order_by(SigmaRule.created_at.desc())
    page = newest_first.offset(skip).limit(limit).all()
    # The UUID primary key is converted to a string for each rule.
    return [
        SigmaRuleResponse(
            id=str(row.id),
            cve_id=row.cve_id,
            rule_name=row.rule_name,
            rule_content=row.rule_content,
            detection_type=row.detection_type,
            log_source=row.log_source,
            confidence_level=row.confidence_level,
            auto_generated=row.auto_generated,
            created_at=row.created_at,
        )
        for row in page
    ]
|
|
|
|
@app.get("/api/sigma-rules/{cve_id}", response_model=List[SigmaRuleResponse])
async def get_sigma_rules_by_cve(cve_id: str, db: Session = Depends(get_db)):
    """Return every stored SIGMA rule whose ``cve_id`` matches the path value.

    An empty list is returned when no rule matches.
    """
    matching = db.query(SigmaRule).filter(SigmaRule.cve_id == cve_id).all()
    # The UUID primary key is converted to a string for each rule.
    return [
        SigmaRuleResponse(
            id=str(row.id),
            cve_id=row.cve_id,
            rule_name=row.rule_name,
            rule_content=row.rule_content,
            detection_type=row.detection_type,
            log_source=row.log_source,
            confidence_level=row.confidence_level,
            auto_generated=row.auto_generated,
            created_at=row.created_at,
        )
        for row in matching
    ]
|
|
|
|
@app.post("/api/fetch-cves")
async def manual_fetch_cves(background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Schedule a one-off CVE fetch and rule generation run in the background.

    The response is returned immediately; the work runs afterwards as a
    background task. The ``db`` parameter is kept for interface compatibility
    but is not used by the task: the request-scoped session is torn down when
    the request ends, so the task opens and closes its own session instead.
    """

    async def fetch_task():
        task_db = SessionLocal()
        try:
            service = CVESigmaService(task_db)
            print("Manual CVE fetch initiated...")
            # Use 30 days for manual fetch to get more results
            new_cves = await service.fetch_recent_cves(days_back=30)

            rules_generated = 0
            for cve in new_cves:
                sigma_rule = service.generate_sigma_rule(cve)
                if sigma_rule:
                    rules_generated += 1

            task_db.commit()
            print(f"Manual fetch complete: {len(new_cves)} CVEs, {rules_generated} rules generated")
        except Exception as e:
            print(f"Manual fetch error: {str(e)}")
            import traceback
            traceback.print_exc()
        finally:
            # Always release the task's own session.
            task_db.close()

    background_tasks.add_task(fetch_task)
    return {"message": "CVE fetch initiated (30-day lookback)", "status": "started"}
|
|
|
|
@app.get("/api/test-nvd")
async def test_nvd_connection():
    """Test endpoint to check NVD API connectivity.

    Sends a small query for CVEs modified in the last 30 days. If that
    returns 404, a second query without date filters is attempted.

    Returns:
        A dict with the request outcome: status, status code, whether an API
        key was configured, the request URL, the response headers, and either
        result counts or the first 200 characters of the error body. Any
        exception is reported as ``{"status": "error", "message": ...}``
        rather than raised.
    """
    try:
        # Test with a simple request using current date
        end_date = datetime.utcnow()
        start_date = end_date - timedelta(days=30)

        url = "https://services.nvd.nist.gov/rest/json/cves/2.0/"
        params = {
            "lastModStartDate": start_date.strftime("%Y-%m-%dT%H:%M:%S.000+00:00"),
            "lastModEndDate": end_date.strftime("%Y-%m-%dT%H:%M:%S.000+00:00"),
            "resultsPerPage": 5,
            "startIndex": 0,
        }

        headers = {
            "User-Agent": "CVE-SIGMA-Generator/1.0",
            "Accept": "application/json",
        }

        nvd_api_key = os.getenv("NVD_API_KEY")
        if nvd_api_key:
            headers["apiKey"] = nvd_api_key

        # The API key is a secret: mask it before the headers are logged.
        loggable_headers = {
            name: ("***" if name == "apiKey" else value)
            for name, value in headers.items()
        }

        print(f"Testing NVD API with URL: {url}")
        print(f"Test params: {params}")
        print(f"Test headers: {loggable_headers}")

        response = requests.get(url, params=params, headers=headers, timeout=15)

        result = {
            "status": "success" if response.status_code == 200 else "error",
            "status_code": response.status_code,
            "has_api_key": bool(nvd_api_key),
            "request_url": f"{url}?{requests.compat.urlencode(params)}",
            "response_headers": dict(response.headers),
        }

        if response.status_code == 200:
            data = response.json()
            result.update({
                "total_results": data.get("totalResults", 0),
                "results_per_page": data.get("resultsPerPage", 0),
                "vulnerabilities_returned": len(data.get("vulnerabilities", [])),
                "message": "NVD API is accessible and returning data",
            })
        else:
            result.update({
                "error_message": response.text[:200],
                "message": f"NVD API returned {response.status_code}",
            })

        # Try fallback without date filters if we get 404
        if response.status_code == 404:
            print("Trying fallback without date filters...")
            fallback_params = {
                "resultsPerPage": 5,
                "startIndex": 0,
            }
            fallback_response = requests.get(url, params=fallback_params, headers=headers, timeout=15)
            result["fallback_status_code"] = fallback_response.status_code

            if fallback_response.status_code == 200:
                fallback_data = fallback_response.json()
                result.update({
                    "fallback_success": True,
                    "fallback_total_results": fallback_data.get("totalResults", 0),
                    "message": "NVD API works without date filters",
                })

        return result

    except Exception as e:
        print(f"NVD API test error: {str(e)}")
        return {
            "status": "error",
            "message": f"Failed to connect to NVD API: {str(e)}",
        }
|
|
|
|
@app.get("/api/stats")
async def get_stats(db: Session = Depends(get_db)):
    """Return row counts for CVEs, SIGMA rules, and CVEs published in the last 7 days."""
    cve_count = db.query(CVE).count()
    rule_count = db.query(SigmaRule).count()

    # CVEs whose published date falls within the last seven days (UTC).
    week_ago = datetime.utcnow() - timedelta(days=7)
    recent_count = db.query(CVE).filter(CVE.published_date >= week_ago).count()

    stats = {
        "total_cves": cve_count,
        "total_sigma_rules": rule_count,
        "recent_cves_7_days": recent_count,
    }
    return stats
|
|
|
|
if __name__ == "__main__":
    # Serve the app with uvicorn on all interfaces, port 8000, when this
    # module is executed directly.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")
|