from typing import Dict, Any, Optional
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from sqlalchemy.orm import Session
from pydantic import BaseModel
import logging

from config.database import get_db
from models import CVE, SigmaRule

router = APIRouter(prefix="/api", tags=["llm-operations"])
logger = logging.getLogger(__name__)


class LLMRuleRequest(BaseModel):
    cve_id: Optional[str] = None  # Optional for bulk operations
    poc_content: str = ""
    force: bool = False  # For bulk operations


class LLMSwitchRequest(BaseModel):
    provider: str
    model: str = ""


@router.post("/llm-enhanced-rules")
async def generate_llm_enhanced_rules(request: LLMRuleRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Generate SIGMA rules using LLM AI analysis"""
    try:
        from enhanced_sigma_generator import EnhancedSigmaGenerator

        if request.cve_id:
            # Single CVE operation
            cve = db.query(CVE).filter(CVE.cve_id == request.cve_id).first()
            if not cve:
                raise HTTPException(status_code=404, detail="CVE not found")

            # Generate enhanced rule using LLM
            generator = EnhancedSigmaGenerator(db)
            result = await generator.generate_enhanced_rule(cve, use_llm=True)

            if result.get('success'):
                return {
                    "success": True,
                    "message": f"Generated LLM-enhanced rule for {request.cve_id}",
                    "rule_id": result.get('rule_id'),
                    "generation_method": "llm_enhanced"
                }
            else:
                return {
                    "success": False,
                    "error": result.get('error', 'Unknown error'),
                    "cve_id": request.cve_id
                }
        else:
            # Bulk operation - run in background with job tracking
            from models import BulkProcessingJob
            import uuid
            from datetime import datetime
            import main

            # Create job record
            job = BulkProcessingJob(
                job_type='llm_rule_generation',
                status='pending',
                job_metadata={
                    'force': request.force
                }
            )
            db.add(job)
            db.commit()
            db.refresh(job)
            job_id = str(job.id)
            main.running_jobs[job_id] = job
            main.job_cancellation_flags[job_id] = False

            async def bulk_llm_generation_task():
                # Create a new database session for the background task
                from config.database import SessionLocal
                task_db = SessionLocal()
                try:
                    # Get the job in the new session
                    task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
                    if not task_job:
                        logger.error(f"Job {job_id} not found in task session")
                        return

                    task_job.status = 'running'
                    task_job.started_at = datetime.utcnow()
                    task_db.commit()

                    generator = EnhancedSigmaGenerator(task_db)

                    # Get CVEs with PoC data - limit to small batch initially
                    if request.force:
                        # Process all CVEs with PoC data - but limit to prevent system overload
                        cves_to_process = task_db.query(CVE).filter(CVE.poc_count > 0).limit(50).all()
                    else:
                        # Only process CVEs without existing LLM-generated rules - small batch
                        cves_to_process = task_db.query(CVE).filter(
                            CVE.poc_count > 0
                        ).limit(10).all()

                        # Filter out CVEs that already have LLM-generated rules
                        existing_llm_rules = task_db.query(SigmaRule).filter(
                            SigmaRule.detection_type.like('llm_%')
                        ).all()
                        existing_cve_ids = {rule.cve_id for rule in existing_llm_rules}
                        cves_to_process = [cve for cve in cves_to_process if cve.cve_id not in existing_cve_ids]

                    # Update job with total items
                    task_job.total_items = len(cves_to_process)
                    task_db.commit()

                    rules_generated = 0
                    rules_updated = 0
                    failures = 0

                    logger.info(f"Starting bulk LLM rule generation for {len(cves_to_process)} CVEs (job {job_id})")

                    for i, cve in enumerate(cves_to_process):
                        # Check for cancellation
                        if main.job_cancellation_flags.get(job_id, False):
                            logger.info(f"Job {job_id} cancelled, stopping LLM generation")
                            break

                        try:
                            logger.info(f"Processing CVE {i+1}/{len(cves_to_process)}: {cve.cve_id}")

                            result = await generator.generate_enhanced_rule(cve, use_llm=True)

                            if result.get('success'):
                                if result.get('updated'):
                                    rules_updated += 1
                                else:
                                    rules_generated += 1
                                logger.info(f"Successfully generated rule for {cve.cve_id}")
                            else:
                                failures += 1
                                logger.warning(f"Failed to generate rule for {cve.cve_id}: {result.get('error')}")

                            # Update progress
                            task_job.processed_items = i + 1
                            task_job.failed_items = failures
                            task_db.commit()

                        except Exception as e:
                            failures += 1
                            logger.error(f"Error generating LLM-enhanced rule for {cve.cve_id}: {e}")

                            # Update progress
                            task_job.processed_items = i + 1
                            task_job.failed_items = failures
                            task_db.commit()

                            # Continue with next CVE even if one fails
                            continue

                    # Update job status if not cancelled
                    if not main.job_cancellation_flags.get(job_id, False):
                        task_job.status = 'completed'
                        task_job.completed_at = datetime.utcnow()
                        task_db.commit()
                        logger.info(f"Bulk LLM rule generation completed: {rules_generated} new, {rules_updated} updated, {failures} failures")

                except Exception as e:
                    if not main.job_cancellation_flags.get(job_id, False):
                        # Get the job again in case it was modified
                        task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
                        if task_job:
                            task_job.status = 'failed'
                            task_job.error_message = str(e)
                            task_job.completed_at = datetime.utcnow()
                            task_db.commit()
                    logger.error(f"Bulk LLM rule generation failed: {e}")
                finally:
                    # Clean up tracking and close the task session
                    main.running_jobs.pop(job_id, None)
                    main.job_cancellation_flags.pop(job_id, None)
                    task_db.close()

            background_tasks.add_task(bulk_llm_generation_task)

            return {
                "success": True,
                "message": "Bulk LLM-enhanced rule generation started",
                "status": "started",
                "job_id": job_id,
                "force": request.force
            }

    except HTTPException:
        # Re-raise intentional HTTP errors (e.g. the 404 above) instead of converting them to 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating LLM-enhanced rule: {str(e)}")


@router.get("/llm-status")
async def get_llm_status():
    """Check LLM API availability and configuration for all providers"""
    try:
        from llm_client import LLMClient

        # Get current configuration first
        current_client = LLMClient()

        # Build available providers list in the format frontend expects
        available_providers = []

        # Test Ollama
        try:
            ollama_client = LLMClient(provider="ollama")
            ollama_status = await ollama_client.test_connection()
            available_providers.append({
                "name": "ollama",
                "available": ollama_status.get("available", False),
                "default_model": ollama_status.get("current_model", "llama3.2"),
                "models": ollama_status.get("models", []),
                "base_url": ollama_status.get("base_url")
            })
        except Exception as e:
            available_providers.append({
                "name": "ollama",
                "available": False,
                "default_model": "llama3.2",
                "models": [],
                "error": str(e)
            })

        # Test OpenAI
        try:
            openai_client = LLMClient(provider="openai")
            openai_status = await openai_client.test_connection()
            available_providers.append({
                "name": "openai",
                "available": openai_status.get("available", False),
                "default_model": openai_status.get("current_model", "gpt-4o-mini"),
                "models": openai_status.get("models", []),
                "has_api_key": openai_status.get("has_api_key", False)
            })
        except Exception as e:
            available_providers.append({
                "name": "openai",
                "available": False,
                "default_model": "gpt-4o-mini",
                "models": [],
                "has_api_key": False,
                "error": str(e)
            })

        # Test Anthropic
        try:
            anthropic_client = LLMClient(provider="anthropic")
            anthropic_status = await anthropic_client.test_connection()
            available_providers.append({
                "name": "anthropic",
                "available": anthropic_status.get("available", False),
"default_model": anthropic_status.get("current_model", "claude-3-5-sonnet-20241022"), "models": anthropic_status.get("models", []), "has_api_key": anthropic_status.get("has_api_key", False) }) except Exception as e: available_providers.append({ "name": "anthropic", "available": False, "default_model": "claude-3-5-sonnet-20241022", "models": [], "has_api_key": False, "error": str(e) }) # Determine overall status any_available = any(p.get("available") for p in available_providers) status = "ready" if any_available else "not_ready" # Return in the format the frontend expects return { "status": status, "current_provider": { "provider": current_client.provider, "model": current_client.model }, "available_providers": available_providers } except Exception as e: return { "status": "error", "error": str(e), "current_provider": { "provider": "unknown", "model": "unknown" }, "available_providers": [] } @router.post("/llm-switch") async def switch_llm_provider(request: LLMSwitchRequest): """Switch between LLM providers and models""" try: from llm_client import LLMClient # Test the new provider/model test_client = LLMClient(provider=request.provider, model=request.model) connection_test = await test_client.test_connection() if not connection_test.get("available"): raise HTTPException( status_code=400, detail=f"Provider {request.provider} with model {request.model} is not available" ) # Switch to new configuration (this would typically involve updating environment variables # or configuration files, but for now we'll just confirm the switch) return { "success": True, "message": f"Switched to {request.provider}" + (f" with model {request.model}" if request.model else ""), "provider": request.provider, "model": request.model or connection_test.get("current_model"), "available": True } except Exception as e: raise HTTPException(status_code=500, detail=f"Error switching LLM provider: {str(e)}") @router.post("/ollama-pull-model") async def pull_ollama_model(model: str = "llama3.2"): """Pull a model in Ollama""" try: import aiohttp import os ollama_url = os.getenv("OLLAMA_BASE_URL", "http://ollama:11434") async with aiohttp.ClientSession() as session: async with session.post(f"{ollama_url}/api/pull", json={"name": model}) as response: if response.status == 200: return { "success": True, "message": f"Successfully pulled model {model}", "model": model } else: raise HTTPException(status_code=500, detail=f"Failed to pull model: {response.status}") except Exception as e: raise HTTPException(status_code=500, detail=f"Error pulling Ollama model: {str(e)}") @router.get("/ollama-models") async def get_ollama_models(): """Get available Ollama models""" try: import aiohttp import os ollama_url = os.getenv("OLLAMA_BASE_URL", "http://ollama:11434") async with aiohttp.ClientSession() as session: async with session.get(f"{ollama_url}/api/tags") as response: if response.status == 200: data = await response.json() models = [model["name"] for model in data.get("models", [])] return { "models": models, "total_models": len(models), "ollama_url": ollama_url } else: return { "models": [], "total_models": 0, "error": f"Ollama not available (status: {response.status})" } except Exception as e: return { "models": [], "total_models": 0, "error": f"Error connecting to Ollama: {str(e)}" }