auto_sigma_rule_generator/backend/routers/llm_operations.py


"""API routes for LLM-enhanced SIGMA rule generation and LLM provider management."""

from typing import Any, Dict, Optional

import logging

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session

from config.database import get_db
from models import CVE, SigmaRule

router = APIRouter(prefix="/api", tags=["llm-operations"])
logger = logging.getLogger(__name__)


class LLMRuleRequest(BaseModel):
    cve_id: Optional[str] = None  # Optional; omit for bulk operations
    poc_content: str = ""
    force: bool = False  # For bulk operations


class LLMSwitchRequest(BaseModel):
    provider: str
    model: str = ""
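

# Illustrative request payloads for the models above (the CVE ID and flag values are
# made-up examples, not data from this project):
#
#   LLMRuleRequest (single CVE):  {"cve_id": "CVE-2024-1234", "poc_content": "..."}
#   LLMRuleRequest (bulk):        {"force": true}
#   LLMSwitchRequest:             {"provider": "ollama", "model": "llama3.2"}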
@router.post("/llm-enhanced-rules")
async def generate_llm_enhanced_rules(request: LLMRuleRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
"""Generate SIGMA rules using LLM AI analysis"""
try:
from enhanced_sigma_generator import EnhancedSigmaGenerator
if request.cve_id:
# Single CVE operation
cve = db.query(CVE).filter(CVE.cve_id == request.cve_id).first()
if not cve:
raise HTTPException(status_code=404, detail="CVE not found")
# Generate enhanced rule using LLM
generator = EnhancedSigmaGenerator(db)
result = await generator.generate_enhanced_rule(cve, use_llm=True)
if result.get('success'):
return {
"success": True,
"message": f"Generated LLM-enhanced rule for {request.cve_id}",
"rule_id": result.get('rule_id'),
"generation_method": "llm_enhanced"
}
else:
return {
"success": False,
"error": result.get('error', 'Unknown error'),
"cve_id": request.cve_id
}
else:
# Bulk operation - run in background with job tracking
from models import BulkProcessingJob
import uuid
from datetime import datetime
import main
# Create job record
job = BulkProcessingJob(
job_type='llm_rule_generation',
status='pending',
job_metadata={
'force': request.force
}
)
db.add(job)
db.commit()
db.refresh(job)
job_id = str(job.id)
main.running_jobs[job_id] = job
main.job_cancellation_flags[job_id] = False
async def bulk_llm_generation_task():
# Create a new database session for the background task
from config.database import SessionLocal
task_db = SessionLocal()
try:
# Get the job in the new session
task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
if not task_job:
logger.error(f"Job {job_id} not found in task session")
return
task_job.status = 'running'
task_job.started_at = datetime.utcnow()
task_db.commit()
generator = EnhancedSigmaGenerator(task_db)
# Get CVEs with PoC data - limit to small batch initially
if request.force:
# Process all CVEs with PoC data - but limit to prevent system overload
cves_to_process = task_db.query(CVE).filter(CVE.poc_count > 0).limit(50).all()
else:
# Only process CVEs without existing LLM-generated rules - small batch
cves_to_process = task_db.query(CVE).filter(
CVE.poc_count > 0
).limit(10).all()
# Filter out CVEs that already have LLM-generated rules
existing_llm_rules = task_db.query(SigmaRule).filter(
SigmaRule.detection_type.like('llm_%')
).all()
existing_cve_ids = {rule.cve_id for rule in existing_llm_rules}
cves_to_process = [cve for cve in cves_to_process if cve.cve_id not in existing_cve_ids]
# Update job with total items
task_job.total_items = len(cves_to_process)
task_db.commit()
rules_generated = 0
rules_updated = 0
failures = 0
logger.info(f"Starting bulk LLM rule generation for {len(cves_to_process)} CVEs (job {job_id})")
for i, cve in enumerate(cves_to_process):
# Check for cancellation
if main.job_cancellation_flags.get(job_id, False):
logger.info(f"Job {job_id} cancelled, stopping LLM generation")
break
try:
logger.info(f"Processing CVE {i+1}/{len(cves_to_process)}: {cve.cve_id}")
result = await generator.generate_enhanced_rule(cve, use_llm=True)
if result.get('success'):
if result.get('updated'):
rules_updated += 1
else:
rules_generated += 1
logger.info(f"Successfully generated rule for {cve.cve_id}")
else:
failures += 1
logger.warning(f"Failed to generate rule for {cve.cve_id}: {result.get('error')}")
# Update progress
task_job.processed_items = i + 1
task_job.failed_items = failures
task_db.commit()
except Exception as e:
failures += 1
logger.error(f"Error generating LLM-enhanced rule for {cve.cve_id}: {e}")
# Update progress
task_job.processed_items = i + 1
task_job.failed_items = failures
task_db.commit()
# Continue with next CVE even if one fails
continue
# Update job status if not cancelled
if not main.job_cancellation_flags.get(job_id, False):
task_job.status = 'completed'
task_job.completed_at = datetime.utcnow()
task_db.commit()
logger.info(f"Bulk LLM rule generation completed: {rules_generated} new, {rules_updated} updated, {failures} failures")
except Exception as e:
if not main.job_cancellation_flags.get(job_id, False):
# Get the job again in case it was modified
task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
if task_job:
task_job.status = 'failed'
task_job.error_message = str(e)
task_job.completed_at = datetime.utcnow()
task_db.commit()
logger.error(f"Bulk LLM rule generation failed: {e}")
finally:
# Clean up tracking and close the task session
main.running_jobs.pop(job_id, None)
main.job_cancellation_flags.pop(job_id, None)
task_db.close()
background_tasks.add_task(bulk_llm_generation_task)
return {
"success": True,
"message": "Bulk LLM-enhanced rule generation started",
"status": "started",
"job_id": job_id,
"force": request.force
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error generating LLM-enhanced rule: {str(e)}")
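

# Example calls for the endpoint above (illustrative; the host/port assume a local
# deployment and the CVE ID is a made-up example):
#
#   # Single CVE (synchronous response):
#   curl -X POST http://localhost:8000/api/llm-enhanced-rules \
#        -H "Content-Type: application/json" -d '{"cve_id": "CVE-2024-1234"}'
#
#   # Bulk generation (returns a job_id and runs in the background):
#   curl -X POST http://localhost:8000/api/llm-enhanced-rules \
#        -H "Content-Type: application/json" -d '{"force": false}'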
@router.get("/llm-status")
async def get_llm_status():
"""Check LLM API availability and configuration for all providers"""
try:
from llm_client import LLMClient
# Get current configuration first
current_client = LLMClient()
# Build available providers list in the format frontend expects
available_providers = []
# Test Ollama
try:
ollama_client = LLMClient(provider="ollama")
ollama_status = await ollama_client.test_connection()
available_providers.append({
"name": "ollama",
"available": ollama_status.get("available", False),
"default_model": ollama_status.get("current_model", "llama3.2"),
"models": ollama_status.get("models", []),
"base_url": ollama_status.get("base_url")
})
except Exception as e:
available_providers.append({
"name": "ollama",
"available": False,
"default_model": "llama3.2",
"models": [],
"error": str(e)
})
# Test OpenAI
try:
openai_client = LLMClient(provider="openai")
openai_status = await openai_client.test_connection()
available_providers.append({
"name": "openai",
"available": openai_status.get("available", False),
"default_model": openai_status.get("current_model", "gpt-4o-mini"),
"models": openai_status.get("models", []),
"has_api_key": openai_status.get("has_api_key", False)
})
except Exception as e:
available_providers.append({
"name": "openai",
"available": False,
"default_model": "gpt-4o-mini",
"models": [],
"has_api_key": False,
"error": str(e)
})
# Test Anthropic
try:
anthropic_client = LLMClient(provider="anthropic")
anthropic_status = await anthropic_client.test_connection()
available_providers.append({
"name": "anthropic",
"available": anthropic_status.get("available", False),
"default_model": anthropic_status.get("current_model", "claude-3-5-sonnet-20241022"),
"models": anthropic_status.get("models", []),
"has_api_key": anthropic_status.get("has_api_key", False)
})
except Exception as e:
available_providers.append({
"name": "anthropic",
"available": False,
"default_model": "claude-3-5-sonnet-20241022",
"models": [],
"has_api_key": False,
"error": str(e)
})
# Determine overall status
any_available = any(p.get("available") for p in available_providers)
status = "ready" if any_available else "not_ready"
# Return in the format the frontend expects
return {
"status": status,
"current_provider": {
"provider": current_client.provider,
"model": current_client.model
},
"available_providers": available_providers
}
except Exception as e:
return {
"status": "error",
"error": str(e),
"current_provider": {
"provider": "unknown",
"model": "unknown"
},
"available_providers": []
}
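

# Shape of a typical /api/llm-status response built above (values are illustrative):
#
#   {
#     "status": "ready",
#     "current_provider": {"provider": "ollama", "model": "llama3.2"},
#     "available_providers": [
#       {"name": "ollama", "available": true, "default_model": "llama3.2",
#        "models": ["llama3.2"], "base_url": "http://ollama:11434"},
#       {"name": "openai", "available": false, "default_model": "gpt-4o-mini",
#        "models": [], "has_api_key": false},
#       {"name": "anthropic", "available": false,
#        "default_model": "claude-3-5-sonnet-20241022", "models": [], "has_api_key": false}
#     ]
#   }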
@router.post("/llm-switch")
async def switch_llm_provider(request: LLMSwitchRequest):
"""Switch between LLM providers and models"""
try:
from llm_client import LLMClient
# Test the new provider/model
test_client = LLMClient(provider=request.provider, model=request.model)
connection_test = await test_client.test_connection()
if not connection_test.get("available"):
raise HTTPException(
status_code=400,
detail=f"Provider {request.provider} with model {request.model} is not available"
)
# Switch to new configuration (this would typically involve updating environment variables
# or configuration files, but for now we'll just confirm the switch)
return {
"success": True,
"message": f"Switched to {request.provider}" + (f" with model {request.model}" if request.model else ""),
"provider": request.provider,
"model": request.model or connection_test.get("current_model"),
"available": True
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error switching LLM provider: {str(e)}")
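

# Example call (illustrative host/port; "model" may be left empty to fall back to the
# default reported by test_connection()):
#
#   curl -X POST http://localhost:8000/api/llm-switch \
#        -H "Content-Type: application/json" \
#        -d '{"provider": "anthropic", "model": "claude-3-5-sonnet-20241022"}'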
@router.post("/ollama-pull-model")
async def pull_ollama_model(model: str = "llama3.2"):
"""Pull a model in Ollama"""
try:
import aiohttp
import os
ollama_url = os.getenv("OLLAMA_BASE_URL", "http://ollama:11434")
async with aiohttp.ClientSession() as session:
async with session.post(f"{ollama_url}/api/pull", json={"name": model}) as response:
if response.status == 200:
return {
"success": True,
"message": f"Successfully pulled model {model}",
"model": model
}
else:
raise HTTPException(status_code=500, detail=f"Failed to pull model: {response.status}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error pulling Ollama model: {str(e)}")
@router.get("/ollama-models")
async def get_ollama_models():
"""Get available Ollama models"""
try:
import aiohttp
import os
ollama_url = os.getenv("OLLAMA_BASE_URL", "http://ollama:11434")
async with aiohttp.ClientSession() as session:
async with session.get(f"{ollama_url}/api/tags") as response:
if response.status == 200:
data = await response.json()
models = [model["name"] for model in data.get("models", [])]
return {
"models": models,
"total_models": len(models),
"ollama_url": ollama_url
}
else:
return {
"models": [],
"total_models": 0,
"error": f"Ollama not available (status: {response.status})"
}
except Exception as e:
return {
"models": [],
"total_models": 0,
"error": f"Error connecting to Ollama: {str(e)}"
}
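

# Example calls for the two Ollama helpers (illustrative host/port). Because `model` is a
# plain scalar parameter on the POST route, FastAPI exposes it as a query parameter:
#
#   curl -X POST "http://localhost:8000/api/ollama-pull-model?model=llama3.2"
#   curl http://localhost:8000/api/ollama-models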