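"""API routes for LLM-assisted SIGMA rule generation and LLM provider management.

Endpoints defined here:
- POST /api/llm-enhanced-rules: generate an LLM-enhanced rule for a single CVE,
  or start a bulk background job with progress tracking
- GET  /api/llm-status: report availability of the Ollama, OpenAI, and Anthropic providers
- POST /api/llm-switch: validate and switch the active LLM provider/model
- POST /api/ollama-pull-model: pull a model into the local Ollama instance
- GET  /api/ollama-models: list models available in the local Ollama instance
"""
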
from typing import Optional

from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from sqlalchemy.orm import Session
from pydantic import BaseModel
import logging

from config.database import get_db
from models import CVE, SigmaRule

router = APIRouter(prefix="/api", tags=["llm-operations"])
logger = logging.getLogger(__name__)


class LLMRuleRequest(BaseModel):
    cve_id: Optional[str] = None  # Optional for bulk operations
    poc_content: str = ""
    force: bool = False  # For bulk operations
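
    # Example payloads (illustrative only):
    #   single CVE: {"cve_id": "CVE-2024-1234", "poc_content": "<exploit snippet>"}
    #   bulk run:   {"force": true}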


class LLMSwitchRequest(BaseModel):
    provider: str
    model: str = ""


@router.post("/llm-enhanced-rules")
async def generate_llm_enhanced_rules(request: LLMRuleRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Generate SIGMA rules using LLM AI analysis"""
    try:
        from enhanced_sigma_generator import EnhancedSigmaGenerator

        if request.cve_id:
            # Single CVE operation
            cve = db.query(CVE).filter(CVE.cve_id == request.cve_id).first()
            if not cve:
                raise HTTPException(status_code=404, detail="CVE not found")

            # Generate enhanced rule using LLM
            generator = EnhancedSigmaGenerator(db)
            result = await generator.generate_enhanced_rule(cve, use_llm=True)

            if result.get('success'):
                return {
                    "success": True,
                    "message": f"Generated LLM-enhanced rule for {request.cve_id}",
                    "rule_id": result.get('rule_id'),
                    "generation_method": "llm_enhanced"
                }
            else:
                return {
                    "success": False,
                    "error": result.get('error', 'Unknown error'),
                    "cve_id": request.cve_id
                }
        else:
            # Bulk operation - run in background with job tracking
            from models import BulkProcessingJob
            import uuid
            from datetime import datetime
            import main

            # Create job record
            job = BulkProcessingJob(
                job_type='llm_rule_generation',
                status='pending',
                job_metadata={
                    'force': request.force
                }
            )
            db.add(job)
            db.commit()
            db.refresh(job)

            job_id = str(job.id)
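            # Register the job in main's in-process tracking structures so other
            # endpoints can report progress and request cancellation of this run.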
            main.running_jobs[job_id] = job
            main.job_cancellation_flags[job_id] = False

            async def bulk_llm_generation_task():
                # Create a new database session for the background task
                from config.database import SessionLocal
                task_db = SessionLocal()
                try:
                    # Get the job in the new session
                    task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
                    if not task_job:
                        logger.error(f"Job {job_id} not found in task session")
                        return

                    task_job.status = 'running'
                    task_job.started_at = datetime.utcnow()
                    task_db.commit()

                    generator = EnhancedSigmaGenerator(task_db)

                    # Get CVEs with PoC data - limit to small batch initially
                    if request.force:
                        # Process all CVEs with PoC data - but limit to prevent system overload
                        cves_to_process = task_db.query(CVE).filter(CVE.poc_count > 0).limit(50).all()
                    else:
                        # Only process CVEs without existing LLM-generated rules - small batch
                        cves_to_process = task_db.query(CVE).filter(
                            CVE.poc_count > 0
                        ).limit(10).all()

                        # Filter out CVEs that already have LLM-generated rules
                        existing_llm_rules = task_db.query(SigmaRule).filter(
                            SigmaRule.detection_type.like('llm_%')
                        ).all()
                        existing_cve_ids = {rule.cve_id for rule in existing_llm_rules}
                        cves_to_process = [cve for cve in cves_to_process if cve.cve_id not in existing_cve_ids]

                    # Update job with total items
                    task_job.total_items = len(cves_to_process)
                    task_db.commit()

                    rules_generated = 0
                    rules_updated = 0
                    failures = 0

                    logger.info(f"Starting bulk LLM rule generation for {len(cves_to_process)} CVEs (job {job_id})")

                    for i, cve in enumerate(cves_to_process):
                        # Check for cancellation
                        if main.job_cancellation_flags.get(job_id, False):
                            logger.info(f"Job {job_id} cancelled, stopping LLM generation")
                            break

                        try:
                            logger.info(f"Processing CVE {i+1}/{len(cves_to_process)}: {cve.cve_id}")
                            result = await generator.generate_enhanced_rule(cve, use_llm=True)

                            if result.get('success'):
                                if result.get('updated'):
                                    rules_updated += 1
                                else:
                                    rules_generated += 1
                                logger.info(f"Successfully generated rule for {cve.cve_id}")
                            else:
                                failures += 1
                                logger.warning(f"Failed to generate rule for {cve.cve_id}: {result.get('error')}")

                            # Update progress
                            task_job.processed_items = i + 1
                            task_job.failed_items = failures
                            task_db.commit()

                        except Exception as e:
                            failures += 1
                            logger.error(f"Error generating LLM-enhanced rule for {cve.cve_id}: {e}")

                            # Update progress
                            task_job.processed_items = i + 1
                            task_job.failed_items = failures
                            task_db.commit()

                            # Continue with next CVE even if one fails
                            continue

                    # Update job status if not cancelled
                    if not main.job_cancellation_flags.get(job_id, False):
                        task_job.status = 'completed'
                        task_job.completed_at = datetime.utcnow()
                        task_db.commit()
                        logger.info(f"Bulk LLM rule generation completed: {rules_generated} new, {rules_updated} updated, {failures} failures")

                except Exception as e:
                    if not main.job_cancellation_flags.get(job_id, False):
                        # Get the job again in case it was modified
                        task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
                        if task_job:
                            task_job.status = 'failed'
                            task_job.error_message = str(e)
                            task_job.completed_at = datetime.utcnow()
                            task_db.commit()

                    logger.error(f"Bulk LLM rule generation failed: {e}")
                finally:
                    # Clean up tracking and close the task session
                    main.running_jobs.pop(job_id, None)
                    main.job_cancellation_flags.pop(job_id, None)
                    task_db.close()

            background_tasks.add_task(bulk_llm_generation_task)

            return {
                "success": True,
                "message": "Bulk LLM-enhanced rule generation started",
                "status": "started",
                "job_id": job_id,
                "force": request.force
            }

    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 404 above) propagate unchanged
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating LLM-enhanced rule: {str(e)}")


@router.get("/llm-status")
async def get_llm_status():
    """Check LLM API availability and configuration for all providers"""
    try:
        from llm_client import LLMClient

        # Get current configuration first
        current_client = LLMClient()

        # Build available providers list in the format frontend expects
        available_providers = []

        # Test Ollama
        try:
            ollama_client = LLMClient(provider="ollama")
            ollama_status = await ollama_client.test_connection()
            available_providers.append({
                "name": "ollama",
                "available": ollama_status.get("available", False),
                "default_model": ollama_status.get("current_model", "llama3.2"),
                "models": ollama_status.get("models", []),
                "base_url": ollama_status.get("base_url")
            })
        except Exception as e:
            available_providers.append({
                "name": "ollama",
                "available": False,
                "default_model": "llama3.2",
                "models": [],
                "error": str(e)
            })

        # Test OpenAI
        try:
            openai_client = LLMClient(provider="openai")
            openai_status = await openai_client.test_connection()
            available_providers.append({
                "name": "openai",
                "available": openai_status.get("available", False),
                "default_model": openai_status.get("current_model", "gpt-4o-mini"),
                "models": openai_status.get("models", []),
                "has_api_key": openai_status.get("has_api_key", False)
            })
        except Exception as e:
            available_providers.append({
                "name": "openai",
                "available": False,
                "default_model": "gpt-4o-mini",
                "models": [],
                "has_api_key": False,
                "error": str(e)
            })

        # Test Anthropic
        try:
            anthropic_client = LLMClient(provider="anthropic")
            anthropic_status = await anthropic_client.test_connection()
            available_providers.append({
                "name": "anthropic",
                "available": anthropic_status.get("available", False),
                "default_model": anthropic_status.get("current_model", "claude-3-5-sonnet-20241022"),
                "models": anthropic_status.get("models", []),
                "has_api_key": anthropic_status.get("has_api_key", False)
            })
        except Exception as e:
            available_providers.append({
                "name": "anthropic",
                "available": False,
                "default_model": "claude-3-5-sonnet-20241022",
                "models": [],
                "has_api_key": False,
                "error": str(e)
            })

        # Determine overall status
        any_available = any(p.get("available") for p in available_providers)
        status = "ready" if any_available else "not_ready"

        # Return in the format the frontend expects
        return {
            "status": status,
            "current_provider": {
                "provider": current_client.provider,
                "model": current_client.model
            },
            "available_providers": available_providers
        }

    except Exception as e:
        return {
            "status": "error",
            "error": str(e),
            "current_provider": {
                "provider": "unknown",
                "model": "unknown"
            },
            "available_providers": []
        }


@router.post("/llm-switch")
async def switch_llm_provider(request: LLMSwitchRequest):
    """Switch between LLM providers and models"""
    try:
        from llm_client import LLMClient

        # Test the new provider/model
        test_client = LLMClient(provider=request.provider, model=request.model)
        connection_test = await test_client.test_connection()

        if not connection_test.get("available"):
            raise HTTPException(
                status_code=400,
                detail=f"Provider {request.provider} with model {request.model} is not available"
            )

        # Switch to new configuration (this would typically involve updating environment variables
        # or configuration files, but for now we'll just confirm the switch)
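        # A minimal sketch of how the new configuration might be persisted (an
        # assumption: LLMClient reads LLM_PROVIDER / LLM_MODEL from the environment;
        # adjust to the project's actual configuration source):
        #
        #     import os
        #     os.environ["LLM_PROVIDER"] = request.provider
        #     if request.model:
        #         os.environ["LLM_MODEL"] = request.model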
        return {
            "success": True,
            "message": f"Switched to {request.provider}" + (f" with model {request.model}" if request.model else ""),
            "provider": request.provider,
            "model": request.model or connection_test.get("current_model"),
            "available": True
        }

    except HTTPException:
        # Preserve the 400 raised above instead of converting it into a 500
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error switching LLM provider: {str(e)}")


@router.post("/ollama-pull-model")
async def pull_ollama_model(model: str = "llama3.2"):
    """Pull a model in Ollama"""
    try:
        import aiohttp
        import os

        ollama_url = os.getenv("OLLAMA_BASE_URL", "http://ollama:11434")

        async with aiohttp.ClientSession() as session:
            async with session.post(f"{ollama_url}/api/pull", json={"name": model}) as response:
                if response.status == 200:
                    return {
                        "success": True,
                        "message": f"Successfully pulled model {model}",
                        "model": model
                    }
                else:
                    raise HTTPException(status_code=500, detail=f"Failed to pull model: {response.status}")

    except HTTPException:
        # Re-raise the failure above without re-wrapping its detail message
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error pulling Ollama model: {str(e)}")


@router.get("/ollama-models")
async def get_ollama_models():
    """Get available Ollama models"""
    try:
        import aiohttp
        import os

        ollama_url = os.getenv("OLLAMA_BASE_URL", "http://ollama:11434")

        async with aiohttp.ClientSession() as session:
            async with session.get(f"{ollama_url}/api/tags") as response:
                if response.status == 200:
                    data = await response.json()
                    models = [model["name"] for model in data.get("models", [])]
                    return {
                        "models": models,
                        "total_models": len(models),
                        "ollama_url": ollama_url
                    }
                else:
                    return {
                        "models": [],
                        "total_models": 0,
                        "error": f"Ollama not available (status: {response.status})"
                    }

    except Exception as e:
        return {
            "models": [],
            "total_models": 0,
            "error": f"Error connecting to Ollama: {str(e)}"
        }
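

# Example of wiring this router into the FastAPI app (illustrative sketch; the
# actual application setup lives elsewhere, e.g. in main.py):
#
#     from fastapi import FastAPI
#     app = FastAPI()
#     app.include_router(router)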