fixed llm operations post refactor

parent a6fb367ed4 · commit e9a5f54d3a
7 changed files with 1089 additions and 2459 deletions

@@ -1,10 +1,20 @@
 from sqlalchemy import create_engine
 from sqlalchemy.orm import sessionmaker, Session
+from sqlalchemy.pool import QueuePool
 from .settings import settings
 
 
-# Database setup
-engine = create_engine(settings.DATABASE_URL)
+# Database setup with connection pooling
+engine = create_engine(
+    settings.DATABASE_URL,
+    poolclass=QueuePool,
+    pool_size=10,        # Number of connections to maintain in the pool
+    max_overflow=20,     # Additional connections that can be created on demand
+    pool_timeout=30,     # Timeout for getting connection from pool
+    pool_recycle=3600,   # Recycle connections after 1 hour
+    pool_pre_ping=True,  # Validate connections before use
+    echo=False           # Set to True for SQL query logging
+)
 SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
 
 
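Pooling only pays off if every request hands its connection back. The get_db dependency that the route files import from config.database later in this commit is not shown in the diff; a minimal sketch of the session-per-request pattern it presumably implements:

# Sketch only - get_db's body is not part of this diff; this is the usual pattern.
def get_db():
    """Yield a session and always return its connection to the QueuePool."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()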
@@ -4,7 +4,8 @@ Script to delete all SIGMA rules from the database
 This will clear existing rules so they can be regenerated with the improved LLM client
 """
 
-from models import SigmaRule, SessionLocal
+from models import SigmaRule
+from config.database import SessionLocal
 import logging
 
 # Setup logging
@@ -108,7 +108,10 @@ class LLMClient:
         self.llm = Ollama(
             model=self.model,
             base_url=base_url,
-            temperature=0.1
+            temperature=0.1,
+            num_ctx=4096,  # Context window size
+            top_p=0.9,
+            top_k=40
         )
 
         if self.llm:
@@ -186,9 +189,22 @@ class LLMClient:
         logger.info(f"CVE Description for {cve_id}: {cve_description[:200]}...")
         logger.info(f"PoC Content sample for {cve_id}: {poc_content[:200]}...")
 
-        # Generate the response
+        # Generate the response with timeout handling
         logger.info(f"Final prompt variables for {cve_id}: {list(input_data.keys())}")
-        response = await chain.ainvoke(input_data)
+        import asyncio
+        try:
+            # Add timeout wrapper around the LLM call
+            response = await asyncio.wait_for(
+                chain.ainvoke(input_data),
+                timeout=150  # 2.5 minutes total timeout
+            )
+        except asyncio.TimeoutError:
+            logger.error(f"LLM request timed out for {cve_id}")
+            return None
+        except Exception as llm_error:
+            logger.error(f"LLM generation error for {cve_id}: {llm_error}")
+            return None
 
         # Debug: Log raw LLM response
         logger.info(f"Raw LLM response for {cve_id}: {response[:200]}...")
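asyncio.wait_for cancels the awaited call when the deadline passes, which is what lets a hung Ollama request fall through to the None return instead of blocking the worker forever. A self-contained demonstration of the same pattern (slow_llm_call is a hypothetical stand-in for chain.ainvoke):

import asyncio

async def slow_llm_call() -> str:
    await asyncio.sleep(10)  # simulates a hung LLM request
    return "title: ..."

async def main():
    try:
        result = await asyncio.wait_for(slow_llm_call(), timeout=2)
    except asyncio.TimeoutError:
        result = None  # the pending call is cancelled, mirroring the handler above
    print(result)  # -> None

asyncio.run(main())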
@@ -228,9 +244,29 @@ class LLMClient:
 
         system_message = """You are a cybersecurity expert specializing in SIGMA rule creation following the official SIGMA specification.
 
-**CRITICAL: You must follow the exact SIGMA specification format:**
+**OFFICIAL SIGMA RULE SPECIFICATION JSON SCHEMA:**
+The official SIGMA rule specification (v2.0.0) defines these requirements:
 
-1. **YAML Structure Requirements:**
+**MANDATORY Fields (must include):**
+- title: Brief description (max 256 chars) - string
+- logsource: Log data source specification - object with category/product/service
+- detection: Search identifiers and conditions - object with selections and condition
+
+**RECOMMENDED Fields:**
+- id: Unique UUID (version 4) - string with UUID format
+- status: Rule state - enum: "stable", "test", "experimental", "deprecated", "unsupported"
+- description: Detailed explanation - string
+- author: Rule creator - string (use "AI Generated")
+- date: Creation date - string in YYYY/MM/DD format
+- modified: Last modification date - string in YYYY/MM/DD format
+- references: Sources for rule derivation - array of strings (URLs)
+- tags: MITRE ATT&CK techniques - array of strings
+- level: Rule severity - enum: "informational", "low", "medium", "high", "critical"
+- falsepositives: Known false positives - array of strings
+- fields: Related fields - array of strings
+- related: Related rules - array of objects with type and id
+
+**YAML Structure Requirements:**
 - Use UTF-8 encoding with LF line breaks
 - Indent with 4 spaces (no tabs)
 - Use lowercase keys only
@@ -238,26 +274,12 @@ class LLMClient:
 - No quotes for numeric values
 - Follow proper YAML syntax
 
-2. **MANDATORY Fields (must include):**
-- title: Brief description (max 256 chars)
-- logsource: Log data source specification
-- detection: Search identifiers and conditions
-- condition: How detection elements combine
-
-3. **RECOMMENDED Fields:**
-- id: Unique UUID
-- status: 'experimental' (for new rules)
-- description: Detailed explanation
-- author: 'AI Generated'
-- date: Current date (YYYY/MM/DD)
-- references: Array with CVE link
-- tags: MITRE ATT&CK techniques
-
-4. **Detection Structure:**
+**Detection Structure:**
 - Use selection blocks (selection, selection1, etc.)
 - Condition references these selections
 - Use proper field names (Image, CommandLine, ProcessName, etc.)
 - Support wildcards (*) and value lists
+- Condition can be string expression or object with keywords
 
 **ABSOLUTE REQUIREMENTS:**
 - Output ONLY valid YAML
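Since the model is only asked to honor the mandatory fields, a post-generation check is cheap insurance. A small illustrative validator for the requirements the prompt spells out (assumes PyYAML; not part of this commit):

import yaml

MANDATORY_FIELDS = ("title", "logsource", "detection")

def check_sigma_rule(rule_text: str) -> list:
    """Return a list of spec violations; an empty list means the basics pass."""
    try:
        rule = yaml.safe_load(rule_text)
    except yaml.YAMLError as e:
        return [f"invalid YAML: {e}"]
    if not isinstance(rule, dict):
        return ["rule must be a YAML mapping"]
    problems = [f"missing mandatory field: {f}" for f in MANDATORY_FIELDS if f not in rule]
    if isinstance(rule.get("detection"), dict) and "condition" not in rule["detection"]:
        problems.append("detection must include a condition")
    if len(str(rule.get("title", ""))) > 256:
        problems.append("title exceeds 256 characters")
    return problems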
@@ -1254,3 +1276,114 @@ Output ONLY the enhanced SIGMA rule in valid YAML format."""
         except Exception as e:
             logger.error(f"Error getting Ollama models: {e}")
             return []
+
+    async def test_connection(self) -> Dict[str, Any]:
+        """Test connection to the configured LLM provider."""
+        try:
+            if self.provider == 'openai':
+                api_key = os.getenv('OPENAI_API_KEY')
+                if not api_key:
+                    return {
+                        "available": False,
+                        "error": "OpenAI API key not configured",
+                        "models": [],
+                        "current_model": self.model,
+                        "has_api_key": False
+                    }
+
+                # Test OpenAI connection without actual API call to avoid timeouts
+                if self.llm:
+                    return {
+                        "available": True,
+                        "models": self.SUPPORTED_PROVIDERS['openai']['models'],
+                        "current_model": self.model,
+                        "has_api_key": True
+                    }
+                else:
+                    return {
+                        "available": False,
+                        "error": "OpenAI client not initialized",
+                        "models": [],
+                        "current_model": self.model,
+                        "has_api_key": True
+                    }
+
+            elif self.provider == 'anthropic':
+                api_key = os.getenv('ANTHROPIC_API_KEY')
+                if not api_key:
+                    return {
+                        "available": False,
+                        "error": "Anthropic API key not configured",
+                        "models": [],
+                        "current_model": self.model,
+                        "has_api_key": False
+                    }
+
+                # Test Anthropic connection without actual API call to avoid timeouts
+                if self.llm:
+                    return {
+                        "available": True,
+                        "models": self.SUPPORTED_PROVIDERS['anthropic']['models'],
+                        "current_model": self.model,
+                        "has_api_key": True
+                    }
+                else:
+                    return {
+                        "available": False,
+                        "error": "Anthropic client not initialized",
+                        "models": [],
+                        "current_model": self.model,
+                        "has_api_key": True
+                    }
+
+            elif self.provider == 'ollama':
+                base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
+
+                # Test Ollama connection
+                try:
+                    import requests
+                    response = requests.get(f"{base_url}/api/tags", timeout=10)
+                    if response.status_code == 200:
+                        available_models = self._get_ollama_available_models()
+                        # Check if model is available using proper model name matching
+                        model_available = self._check_ollama_model_available(base_url, self.model)
+
+                        return {
+                            "available": model_available,
+                            "models": available_models,
+                            "current_model": self.model,
+                            "base_url": base_url,
+                            "error": None if model_available else f"Model {self.model} not available"
+                        }
+                    else:
+                        return {
+                            "available": False,
+                            "error": f"Ollama server not responding (HTTP {response.status_code})",
+                            "models": [],
+                            "current_model": self.model,
+                            "base_url": base_url
+                        }
+                except Exception as e:
+                    return {
+                        "available": False,
+                        "error": f"Cannot connect to Ollama server: {str(e)}",
+                        "models": [],
+                        "current_model": self.model,
+                        "base_url": base_url
+                    }
+
+            else:
+                return {
+                    "available": False,
+                    "error": f"Unsupported provider: {self.provider}",
+                    "models": [],
+                    "current_model": self.model
+                }
+
+        except Exception as e:
+            return {
+                "available": False,
+                "error": f"Connection test failed: {str(e)}",
+                "models": [],
+                "current_model": self.model
+            }
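The new test_connection is what the reworked /llm-status route below calls once per provider. A hypothetical standalone usage, e.g. for a quick CLI health check:

import asyncio
from llm_client import LLMClient

async def health_check():
    for provider in ("ollama", "openai", "anthropic"):
        status = await LLMClient(provider=provider).test_connection()
        print(provider, status["available"], status.get("error"))

asyncio.run(health_check())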
File diff suppressed because it is too large
@@ -1,14 +1,23 @@
 from typing import List, Optional
 from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
 from sqlalchemy.orm import Session
+from sqlalchemy import func, text
+import uuid
+from datetime import datetime
+import logging
 
-from config.database import get_db
+from config.database import get_db, SessionLocal
 from models import BulkProcessingJob, CVE, SigmaRule
 from schemas import BulkSeedRequest, NomiSecSyncRequest, GitHubPoCSyncRequest, ExploitDBSyncRequest, CISAKEVSyncRequest, ReferenceSyncRequest
 from services import CVEService, SigmaRuleService
 
+# Import global job tracking from main.py
+import main
+
 router = APIRouter(prefix="/api", tags=["bulk-operations"])
 
+logger = logging.getLogger(__name__)
+
 
 @router.post("/bulk-seed")
 async def bulk_seed(request: BulkSeedRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
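import main gives these routes the process-wide running_jobs and job_cancellation_flags dicts. Because main.py in turn imports this router, a dedicated state module is a common way to sidestep the circular import; a sketch of that alternative (hypothetical, not part of this commit):

# job_state.py - importable by both main.py and the routers
from typing import Any, Dict

running_jobs: Dict[str, Any] = {}
job_cancellation_flags: Dict[str, bool] = {}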
@@ -118,3 +127,656 @@ async def get_poc_stats(db: Session = Depends(get_db)):
         "exploit_based_rules": exploit_based_rules,
         "exploit_based_percentage": round((exploit_based_rules / total_rules * 100), 2) if total_rules > 0 else 0
     }
+
+
@router.post("/sync-nomi-sec")
|
||||||
|
async def sync_nomi_sec(background_tasks: BackgroundTasks,
|
||||||
|
request: NomiSecSyncRequest,
|
||||||
|
db: Session = Depends(get_db)):
|
||||||
|
"""Synchronize nomi-sec PoC data"""
|
||||||
|
|
||||||
|
# Create job record
|
||||||
|
job = BulkProcessingJob(
|
||||||
|
job_type='nomi_sec_sync',
|
||||||
|
status='pending',
|
||||||
|
job_metadata={
|
||||||
|
'cve_id': request.cve_id,
|
||||||
|
'batch_size': request.batch_size
|
||||||
|
}
|
||||||
|
)
|
||||||
|
db.add(job)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(job)
|
||||||
|
|
||||||
|
job_id = str(job.id)
|
||||||
|
main.running_jobs[job_id] = job
|
||||||
|
main.job_cancellation_flags[job_id] = False
|
||||||
|
|
||||||
|
async def sync_task():
|
||||||
|
try:
|
||||||
|
job.status = 'running'
|
||||||
|
job.started_at = datetime.utcnow()
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
from nomi_sec_client import NomiSecClient
|
||||||
|
client = NomiSecClient(db)
|
||||||
|
|
||||||
|
if request.cve_id:
|
||||||
|
# Sync specific CVE
|
||||||
|
if main.job_cancellation_flags.get(job_id, False):
|
||||||
|
logger.info(f"Job {job_id} cancelled before starting")
|
||||||
|
return
|
||||||
|
|
||||||
|
result = await client.sync_cve_pocs(request.cve_id)
|
||||||
|
logger.info(f"Nomi-sec sync for {request.cve_id}: {result}")
|
||||||
|
else:
|
||||||
|
# Sync all CVEs with cancellation support
|
||||||
|
result = await client.bulk_sync_all_cves(
|
||||||
|
batch_size=request.batch_size,
|
||||||
|
cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False)
|
||||||
|
)
|
||||||
|
logger.info(f"Nomi-sec bulk sync completed: {result}")
|
||||||
|
|
||||||
|
# Update job status if not cancelled
|
||||||
|
if not main.job_cancellation_flags.get(job_id, False):
|
||||||
|
job.status = 'completed'
|
||||||
|
job.completed_at = datetime.utcnow()
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if not main.job_cancellation_flags.get(job_id, False):
|
||||||
|
job.status = 'failed'
|
||||||
|
job.error_message = str(e)
|
||||||
|
job.completed_at = datetime.utcnow()
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
logger.error(f"Nomi-sec sync failed: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
finally:
|
||||||
|
# Clean up tracking
|
||||||
|
main.running_jobs.pop(job_id, None)
|
||||||
|
main.job_cancellation_flags.pop(job_id, None)
|
||||||
|
|
||||||
|
background_tasks.add_task(sync_task)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"message": f"Nomi-sec sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
|
||||||
|
"status": "started",
|
||||||
|
"job_id": job_id,
|
||||||
|
"cve_id": request.cve_id,
|
||||||
|
"batch_size": request.batch_size
|
||||||
|
}
|
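This endpoint and the three sync endpoints that follow repeat the same create-job / run / mark-complete / clean-up scaffold. A sketch of how that lifecycle could be factored into one helper (hypothetical refactor, not part of this commit):

from datetime import datetime

async def run_tracked_job(job, job_id, db, work):
    """Run `work()` while keeping the BulkProcessingJob row and main's dicts in sync."""
    import main
    try:
        job.status = 'running'
        job.started_at = datetime.utcnow()
        db.commit()
        await work()
        if not main.job_cancellation_flags.get(job_id, False):
            job.status = 'completed'
            job.completed_at = datetime.utcnow()
            db.commit()
    except Exception as e:
        if not main.job_cancellation_flags.get(job_id, False):
            job.status = 'failed'
            job.error_message = str(e)
            job.completed_at = datetime.utcnow()
            db.commit()
    finally:
        main.running_jobs.pop(job_id, None)
        main.job_cancellation_flags.pop(job_id, None)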
+
+
+@router.post("/sync-github-pocs")
+async def sync_github_pocs(background_tasks: BackgroundTasks,
+                           request: GitHubPoCSyncRequest,
+                           db: Session = Depends(get_db)):
+    """Synchronize GitHub PoC data"""
+
+    # Create job record
+    job = BulkProcessingJob(
+        job_type='github_poc_sync',
+        status='pending',
+        job_metadata={
+            'cve_id': request.cve_id,
+            'batch_size': request.batch_size
+        }
+    )
+    db.add(job)
+    db.commit()
+    db.refresh(job)
+
+    job_id = str(job.id)
+    main.running_jobs[job_id] = job
+    main.job_cancellation_flags[job_id] = False
+
+    async def sync_task():
+        try:
+            job.status = 'running'
+            job.started_at = datetime.utcnow()
+            db.commit()
+
+            from mcdevitt_poc_client import GitHubPoCClient
+            client = GitHubPoCClient(db)
+
+            if request.cve_id:
+                # Sync specific CVE
+                if main.job_cancellation_flags.get(job_id, False):
+                    logger.info(f"Job {job_id} cancelled before starting")
+                    return
+
+                result = await client.sync_cve_pocs(request.cve_id)
+                logger.info(f"GitHub PoC sync for {request.cve_id}: {result}")
+            else:
+                # Sync all CVEs with cancellation support
+                result = await client.bulk_sync_all_cves(batch_size=request.batch_size)
+                logger.info(f"GitHub PoC bulk sync completed: {result}")
+
+            # Update job status if not cancelled
+            if not main.job_cancellation_flags.get(job_id, False):
+                job.status = 'completed'
+                job.completed_at = datetime.utcnow()
+                db.commit()
+
+        except Exception as e:
+            if not main.job_cancellation_flags.get(job_id, False):
+                job.status = 'failed'
+                job.error_message = str(e)
+                job.completed_at = datetime.utcnow()
+                db.commit()
+
+            logger.error(f"GitHub PoC sync failed: {e}")
+            import traceback
+            traceback.print_exc()
+        finally:
+            # Clean up tracking
+            main.running_jobs.pop(job_id, None)
+            main.job_cancellation_flags.pop(job_id, None)
+
+    background_tasks.add_task(sync_task)
+
+    return {
+        "message": f"GitHub PoC sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
+        "status": "started",
+        "job_id": job_id,
+        "cve_id": request.cve_id,
+        "batch_size": request.batch_size
+    }
+
+
@router.post("/sync-exploitdb")
|
||||||
|
async def sync_exploitdb(background_tasks: BackgroundTasks,
|
||||||
|
request: ExploitDBSyncRequest,
|
||||||
|
db: Session = Depends(get_db)):
|
||||||
|
"""Synchronize ExploitDB data from git mirror"""
|
||||||
|
|
||||||
|
# Create job record
|
||||||
|
job = BulkProcessingJob(
|
||||||
|
job_type='exploitdb_sync',
|
||||||
|
status='pending',
|
||||||
|
job_metadata={
|
||||||
|
'cve_id': request.cve_id,
|
||||||
|
'batch_size': request.batch_size
|
||||||
|
}
|
||||||
|
)
|
||||||
|
db.add(job)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(job)
|
||||||
|
|
||||||
|
job_id = str(job.id)
|
||||||
|
main.running_jobs[job_id] = job
|
||||||
|
main.job_cancellation_flags[job_id] = False
|
||||||
|
|
||||||
|
async def sync_task():
|
||||||
|
# Create a new database session for the background task
|
||||||
|
task_db = SessionLocal()
|
||||||
|
try:
|
||||||
|
# Get the job in the new session
|
||||||
|
task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
|
||||||
|
if not task_job:
|
||||||
|
logger.error(f"Job {job_id} not found in task session")
|
||||||
|
return
|
||||||
|
|
||||||
|
task_job.status = 'running'
|
||||||
|
task_job.started_at = datetime.utcnow()
|
||||||
|
task_db.commit()
|
||||||
|
|
||||||
|
from exploitdb_client_local import ExploitDBLocalClient
|
||||||
|
client = ExploitDBLocalClient(task_db)
|
||||||
|
|
||||||
|
if request.cve_id:
|
||||||
|
# Sync specific CVE
|
||||||
|
if main.job_cancellation_flags.get(job_id, False):
|
||||||
|
logger.info(f"Job {job_id} cancelled before starting")
|
||||||
|
return
|
||||||
|
|
||||||
|
result = await client.sync_cve_exploits(request.cve_id)
|
||||||
|
logger.info(f"ExploitDB sync for {request.cve_id}: {result}")
|
||||||
|
else:
|
||||||
|
# Sync all CVEs with cancellation support
|
||||||
|
result = await client.bulk_sync_exploitdb(
|
||||||
|
batch_size=request.batch_size,
|
||||||
|
cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False)
|
||||||
|
)
|
||||||
|
logger.info(f"ExploitDB bulk sync completed: {result}")
|
||||||
|
|
||||||
|
# Update job status if not cancelled
|
||||||
|
if not main.job_cancellation_flags.get(job_id, False):
|
||||||
|
task_job.status = 'completed'
|
||||||
|
task_job.completed_at = datetime.utcnow()
|
||||||
|
task_db.commit()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if not main.job_cancellation_flags.get(job_id, False):
|
||||||
|
# Get the job again in case it was modified
|
||||||
|
task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
|
||||||
|
if task_job:
|
||||||
|
task_job.status = 'failed'
|
||||||
|
task_job.error_message = str(e)
|
||||||
|
task_job.completed_at = datetime.utcnow()
|
||||||
|
task_db.commit()
|
||||||
|
|
||||||
|
logger.error(f"ExploitDB sync failed: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
finally:
|
||||||
|
# Clean up tracking and close the task session
|
||||||
|
main.running_jobs.pop(job_id, None)
|
||||||
|
main.job_cancellation_flags.pop(job_id, None)
|
||||||
|
task_db.close()
|
||||||
|
|
||||||
|
background_tasks.add_task(sync_task)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"message": f"ExploitDB sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
|
||||||
|
"status": "started",
|
||||||
|
"job_id": job_id,
|
||||||
|
"cve_id": request.cve_id,
|
||||||
|
"batch_size": request.batch_size
|
||||||
|
}
|
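Unlike the two endpoints above, this task opens its own SessionLocal() rather than reusing the request-scoped db: the dependency-injected session is closed once the HTTP response returns, so a background task that keeps writing job progress needs a session it owns. The same idea as a small reusable context manager (sketch, not part of this commit):

from contextlib import contextmanager
from config.database import SessionLocal

@contextmanager
def task_session():
    """Own a session for the lifetime of a background task."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()  # always hand the pooled connection back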
+
+
+@router.post("/sync-cisa-kev")
+async def sync_cisa_kev(background_tasks: BackgroundTasks,
+                        request: CISAKEVSyncRequest,
+                        db: Session = Depends(get_db)):
+    """Synchronize CISA Known Exploited Vulnerabilities data"""
+
+    # Create job record
+    job = BulkProcessingJob(
+        job_type='cisa_kev_sync',
+        status='pending',
+        job_metadata={
+            'cve_id': request.cve_id,
+            'batch_size': request.batch_size
+        }
+    )
+    db.add(job)
+    db.commit()
+    db.refresh(job)
+
+    job_id = str(job.id)
+    main.running_jobs[job_id] = job
+    main.job_cancellation_flags[job_id] = False
+
+    async def sync_task():
+        # Create a new database session for the background task
+        task_db = SessionLocal()
+        try:
+            # Get the job in the new session
+            task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
+            if not task_job:
+                logger.error(f"Job {job_id} not found in task session")
+                return
+
+            task_job.status = 'running'
+            task_job.started_at = datetime.utcnow()
+            task_db.commit()
+
+            from cisa_kev_client import CISAKEVClient
+            client = CISAKEVClient(task_db)
+
+            if request.cve_id:
+                # Sync specific CVE
+                if main.job_cancellation_flags.get(job_id, False):
+                    logger.info(f"Job {job_id} cancelled before starting")
+                    return
+
+                result = await client.sync_cve_kev_data(request.cve_id)
+                logger.info(f"CISA KEV sync for {request.cve_id}: {result}")
+            else:
+                # Sync all CVEs with cancellation support
+                result = await client.bulk_sync_kev_data(
+                    batch_size=request.batch_size,
+                    cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False)
+                )
+                logger.info(f"CISA KEV bulk sync completed: {result}")
+
+            # Update job status if not cancelled
+            if not main.job_cancellation_flags.get(job_id, False):
+                task_job.status = 'completed'
+                task_job.completed_at = datetime.utcnow()
+                task_db.commit()
+
+        except Exception as e:
+            if not main.job_cancellation_flags.get(job_id, False):
+                # Get the job again in case it was modified
+                task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
+                if task_job:
+                    task_job.status = 'failed'
+                    task_job.error_message = str(e)
+                    task_job.completed_at = datetime.utcnow()
+                    task_db.commit()
+
+            logger.error(f"CISA KEV sync failed: {e}")
+            import traceback
+            traceback.print_exc()
+        finally:
+            # Clean up tracking and close the task session
+            main.running_jobs.pop(job_id, None)
+            main.job_cancellation_flags.pop(job_id, None)
+            task_db.close()
+
+    background_tasks.add_task(sync_task)
+
+    return {
+        "message": f"CISA KEV sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
+        "status": "started",
+        "job_id": job_id,
+        "cve_id": request.cve_id,
+        "batch_size": request.batch_size
+    }
+
+
@router.post("/sync-references")
|
||||||
|
async def sync_references(request: ReferenceSyncRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
|
||||||
|
"""Start reference data synchronization"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
from reference_client import ReferenceClient
|
||||||
|
client = ReferenceClient(db)
|
||||||
|
|
||||||
|
# Create job ID
|
||||||
|
job_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Add job to tracking
|
||||||
|
main.running_jobs[job_id] = {
|
||||||
|
'type': 'reference_sync',
|
||||||
|
'status': 'running',
|
||||||
|
'cve_id': request.cve_id,
|
||||||
|
'batch_size': request.batch_size,
|
||||||
|
'max_cves': request.max_cves,
|
||||||
|
'force_resync': request.force_resync,
|
||||||
|
'started_at': datetime.utcnow()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create cancellation flag
|
||||||
|
main.job_cancellation_flags[job_id] = False
|
||||||
|
|
||||||
|
async def sync_task():
|
||||||
|
try:
|
||||||
|
if request.cve_id:
|
||||||
|
# Single CVE sync
|
||||||
|
result = await client.sync_cve_references(request.cve_id)
|
||||||
|
main.running_jobs[job_id]['result'] = result
|
||||||
|
main.running_jobs[job_id]['status'] = 'completed'
|
||||||
|
else:
|
||||||
|
# Bulk sync
|
||||||
|
result = await client.bulk_sync_references(
|
||||||
|
batch_size=request.batch_size,
|
||||||
|
max_cves=request.max_cves,
|
||||||
|
force_resync=request.force_resync,
|
||||||
|
cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False)
|
||||||
|
)
|
||||||
|
main.running_jobs[job_id]['result'] = result
|
||||||
|
main.running_jobs[job_id]['status'] = 'completed'
|
||||||
|
|
||||||
|
main.running_jobs[job_id]['completed_at'] = datetime.utcnow()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Reference sync task failed: {e}")
|
||||||
|
main.running_jobs[job_id]['status'] = 'failed'
|
||||||
|
main.running_jobs[job_id]['error'] = str(e)
|
||||||
|
main.running_jobs[job_id]['completed_at'] = datetime.utcnow()
|
||||||
|
finally:
|
||||||
|
# Clean up cancellation flag
|
||||||
|
main.job_cancellation_flags.pop(job_id, None)
|
||||||
|
|
||||||
|
background_tasks.add_task(sync_task)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"message": f"Reference sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
|
||||||
|
"status": "started",
|
||||||
|
"job_id": job_id,
|
||||||
|
"cve_id": request.cve_id,
|
||||||
|
"batch_size": request.batch_size,
|
||||||
|
"max_cves": request.max_cves,
|
||||||
|
"force_resync": request.force_resync
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to start reference sync: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"Failed to start reference sync: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
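Note that reference sync departs from the pattern above: it tracks state only in main.running_jobs (a plain dict keyed by a fresh uuid4) rather than a BulkProcessingJob row, so its progress is invisible to DB-backed views like /running-jobs below. A sketch of a status route that reads the in-memory record (hypothetical, not part of this commit):

@router.get("/reference-sync-status/{job_id}")
async def reference_sync_status(job_id: str):
    job = main.running_jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found (or already evicted)")
    return {"job_id": job_id, "status": job.get("status"), "started_at": job.get("started_at")}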
@router.get("/exploitdb-stats")
|
||||||
|
async def get_exploitdb_stats(db: Session = Depends(get_db)):
|
||||||
|
"""Get ExploitDB-related statistics"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
from exploitdb_client_local import ExploitDBLocalClient
|
||||||
|
client = ExploitDBLocalClient(db)
|
||||||
|
|
||||||
|
# Get sync status
|
||||||
|
status = await client.get_exploitdb_sync_status()
|
||||||
|
|
||||||
|
# Get quality distribution from ExploitDB data
|
||||||
|
quality_distribution = {}
|
||||||
|
cves_with_exploitdb = db.query(CVE).filter(
|
||||||
|
text("poc_data::text LIKE '%\"exploitdb\"%'")
|
||||||
|
).all()
|
||||||
|
|
||||||
|
for cve in cves_with_exploitdb:
|
||||||
|
if cve.poc_data and 'exploitdb' in cve.poc_data:
|
||||||
|
exploits = cve.poc_data['exploitdb'].get('exploits', [])
|
||||||
|
for exploit in exploits:
|
||||||
|
quality_tier = exploit.get('quality_analysis', {}).get('quality_tier', 'unknown')
|
||||||
|
quality_distribution[quality_tier] = quality_distribution.get(quality_tier, 0) + 1
|
||||||
|
|
||||||
|
# Get category distribution
|
||||||
|
category_distribution = {}
|
||||||
|
for cve in cves_with_exploitdb:
|
||||||
|
if cve.poc_data and 'exploitdb' in cve.poc_data:
|
||||||
|
exploits = cve.poc_data['exploitdb'].get('exploits', [])
|
||||||
|
for exploit in exploits:
|
||||||
|
category = exploit.get('category', 'unknown')
|
||||||
|
category_distribution[category] = category_distribution.get(category, 0) + 1
|
||||||
|
|
||||||
|
return {
|
||||||
|
"exploitdb_sync_status": status,
|
||||||
|
"quality_distribution": quality_distribution,
|
||||||
|
"category_distribution": category_distribution,
|
||||||
|
"total_exploitdb_cves": len(cves_with_exploitdb),
|
||||||
|
"total_exploits": sum(
|
||||||
|
len(cve.poc_data.get('exploitdb', {}).get('exploits', []))
|
||||||
|
for cve in cves_with_exploitdb
|
||||||
|
if cve.poc_data and 'exploitdb' in cve.poc_data
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting ExploitDB stats: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/github-poc-stats")
|
||||||
|
async def get_github_poc_stats(db: Session = Depends(get_db)):
|
||||||
|
"""Get GitHub PoC-related statistics"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get basic statistics
|
||||||
|
github_poc_rules = db.query(SigmaRule).filter(SigmaRule.poc_source == 'github_poc').count()
|
||||||
|
cves_with_github_pocs = db.query(CVE).filter(
|
||||||
|
CVE.poc_data.isnot(None), # Check if poc_data exists
|
||||||
|
func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc'
|
||||||
|
).count()
|
||||||
|
|
||||||
|
# Get quality distribution
|
||||||
|
quality_distribution = {}
|
||||||
|
try:
|
||||||
|
quality_results = db.query(
|
||||||
|
func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_tier').label('tier'),
|
||||||
|
func.count().label('count')
|
||||||
|
).filter(
|
||||||
|
CVE.poc_data.isnot(None),
|
||||||
|
func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc'
|
||||||
|
).group_by('tier').all()
|
||||||
|
|
||||||
|
for tier, count in quality_results:
|
||||||
|
if tier:
|
||||||
|
quality_distribution[tier] = count
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error getting quality distribution: {e}")
|
||||||
|
quality_distribution = {}
|
||||||
|
|
||||||
|
# Calculate average quality score
|
||||||
|
try:
|
||||||
|
from sqlalchemy import Integer
|
||||||
|
avg_quality = db.query(
|
||||||
|
func.avg(func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_score').cast(Integer))
|
||||||
|
).filter(
|
||||||
|
CVE.poc_data.isnot(None),
|
||||||
|
func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc'
|
||||||
|
).scalar() or 0
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error calculating average quality: {e}")
|
||||||
|
avg_quality = 0
|
||||||
|
|
||||||
|
return {
|
||||||
|
'github_poc_rules': github_poc_rules,
|
||||||
|
'cves_with_github_pocs': cves_with_github_pocs,
|
||||||
|
'quality_distribution': quality_distribution,
|
||||||
|
'average_quality_score': float(avg_quality) if avg_quality else 0,
|
||||||
|
'source': 'github_poc'
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting GitHub PoC stats: {e}")
|
||||||
|
return {"error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/cisa-kev-stats")
|
||||||
|
async def get_cisa_kev_stats(db: Session = Depends(get_db)):
|
||||||
|
"""Get CISA KEV-related statistics"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
from cisa_kev_client import CISAKEVClient
|
||||||
|
client = CISAKEVClient(db)
|
||||||
|
|
||||||
|
# Get sync status
|
||||||
|
status = await client.get_kev_sync_status()
|
||||||
|
|
||||||
|
# Get threat level distribution from CISA KEV data
|
||||||
|
threat_level_distribution = {}
|
||||||
|
cves_with_kev = db.query(CVE).filter(
|
||||||
|
text("poc_data::text LIKE '%\"cisa_kev\"%'")
|
||||||
|
).all()
|
||||||
|
|
||||||
|
for cve in cves_with_kev:
|
||||||
|
if cve.poc_data and 'cisa_kev' in cve.poc_data:
|
||||||
|
vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
|
||||||
|
threat_level = vuln_data.get('threat_level', 'unknown')
|
||||||
|
threat_level_distribution[threat_level] = threat_level_distribution.get(threat_level, 0) + 1
|
||||||
|
|
||||||
|
# Get vulnerability category distribution
|
||||||
|
category_distribution = {}
|
||||||
|
for cve in cves_with_kev:
|
||||||
|
if cve.poc_data and 'cisa_kev' in cve.poc_data:
|
||||||
|
vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
|
||||||
|
category = vuln_data.get('vulnerability_category', 'unknown')
|
||||||
|
category_distribution[category] = category_distribution.get(category, 0) + 1
|
||||||
|
|
||||||
|
# Get ransomware usage statistics
|
||||||
|
ransomware_stats = {'known': 0, 'unknown': 0}
|
||||||
|
for cve in cves_with_kev:
|
||||||
|
if cve.poc_data and 'cisa_kev' in cve.poc_data:
|
||||||
|
vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
|
||||||
|
ransomware_use = vuln_data.get('known_ransomware_use', 'Unknown').lower()
|
||||||
|
if ransomware_use == 'known':
|
||||||
|
ransomware_stats['known'] += 1
|
||||||
|
else:
|
||||||
|
ransomware_stats['unknown'] += 1
|
||||||
|
|
||||||
|
# Calculate average threat score
|
||||||
|
threat_scores = []
|
||||||
|
for cve in cves_with_kev:
|
||||||
|
if cve.poc_data and 'cisa_kev' in cve.poc_data:
|
||||||
|
vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
|
||||||
|
threat_score = vuln_data.get('threat_score', 0)
|
||||||
|
if threat_score:
|
||||||
|
threat_scores.append(threat_score)
|
||||||
|
|
||||||
|
avg_threat_score = sum(threat_scores) / len(threat_scores) if threat_scores else 0
|
||||||
|
|
||||||
|
return {
|
||||||
|
"cisa_kev_sync_status": status,
|
||||||
|
"threat_level_distribution": threat_level_distribution,
|
||||||
|
"category_distribution": category_distribution,
|
||||||
|
"ransomware_stats": ransomware_stats,
|
||||||
|
"average_threat_score": round(avg_threat_score, 2),
|
||||||
|
"total_kev_cves": len(cves_with_kev),
|
||||||
|
"total_with_threat_scores": len(threat_scores)
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting CISA KEV stats: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/cancel-job/{job_id}")
|
||||||
|
async def cancel_job(job_id: str, db: Session = Depends(get_db)):
|
||||||
|
"""Cancel a running job"""
|
||||||
|
try:
|
||||||
|
# Find the job in the database
|
||||||
|
job = db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job_id).first()
|
||||||
|
if not job:
|
||||||
|
raise HTTPException(status_code=404, detail="Job not found")
|
||||||
|
|
||||||
|
if job.status not in ['pending', 'running']:
|
||||||
|
raise HTTPException(status_code=400, detail=f"Cannot cancel job with status: {job.status}")
|
||||||
|
|
||||||
|
# Set cancellation flag
|
||||||
|
main.job_cancellation_flags[job_id] = True
|
||||||
|
|
||||||
|
# Update job status
|
||||||
|
job.status = 'cancelled'
|
||||||
|
job.cancelled_at = datetime.utcnow()
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
logger.info(f"Job {job_id} ({job.job_type}) cancelled by user")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": f"Job {job_id} cancelled successfully",
|
||||||
|
"job_id": job_id,
|
||||||
|
"job_type": job.job_type
|
||||||
|
}
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error cancelling job {job_id}: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
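Cancellation here is cooperative: the endpoint only flips main.job_cancellation_flags[job_id], and each worker polls it between units of work, which is what the cancellation_flag=lambda: ... arguments above wire up. A sketch of the worker side of that contract (method body assumed; the signature is taken from the call sites):

async def bulk_sync_all_cves(self, batch_size=50, cancellation_flag=None):
    """Process CVEs in batches, checking the flag between batches."""
    processed = 0
    for batch in self._iter_batches(batch_size):       # _iter_batches is hypothetical
        if cancellation_flag and cancellation_flag():
            return {"cancelled": True, "processed": processed}
        processed += await self._process_batch(batch)  # _process_batch is hypothetical
    return {"cancelled": False, "processed": processed}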
+
+
+@router.get("/running-jobs")
+async def get_running_jobs(db: Session = Depends(get_db)):
+    """Get currently running jobs"""
+    try:
+        # Get running jobs from database
+        running_jobs_db = db.query(BulkProcessingJob).filter(
+            BulkProcessingJob.status.in_(['pending', 'running'])
+        ).order_by(BulkProcessingJob.created_at.desc()).all()
+
+        result = []
+        for job in running_jobs_db:
+            result.append({
+                'id': str(job.id),
+                'job_type': job.job_type,
+                'status': job.status,
+                'year': job.year,
+                'total_items': job.total_items,
+                'processed_items': job.processed_items,
+                'failed_items': job.failed_items,
+                'error_message': job.error_message,
+                'started_at': job.started_at,
+                'created_at': job.created_at,
+                'can_cancel': job.status in ['pending', 'running']
+            })
+
+        return result
+
+    except Exception as e:
+        logger.error(f"Error getting running jobs: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
@@ -1,17 +1,20 @@
 from typing import Dict, Any
-from fastapi import APIRouter, HTTPException, Depends
+from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
 from sqlalchemy.orm import Session
 from pydantic import BaseModel
+import logging
 
 from config.database import get_db
 from models import CVE, SigmaRule
 
 router = APIRouter(prefix="/api", tags=["llm-operations"])
+logger = logging.getLogger(__name__)
 
 
 class LLMRuleRequest(BaseModel):
-    cve_id: str
+    cve_id: str = None  # Optional for bulk operations
     poc_content: str = ""
+    force: bool = False  # For bulk operations
 
 
 class LLMSwitchRequest(BaseModel):
@@ -20,12 +23,13 @@ class LLMSwitchRequest(BaseModel):
 
 
 @router.post("/llm-enhanced-rules")
-async def generate_llm_enhanced_rules(request: LLMRuleRequest, db: Session = Depends(get_db)):
+async def generate_llm_enhanced_rules(request: LLMRuleRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
     """Generate SIGMA rules using LLM AI analysis"""
     try:
        from enhanced_sigma_generator import EnhancedSigmaGenerator
 
-        # Get CVE
+        if request.cve_id:
+            # Single CVE operation
             cve = db.query(CVE).filter(CVE.cve_id == request.cve_id).first()
             if not cve:
                 raise HTTPException(status_code=404, detail="CVE not found")
@@ -47,6 +51,143 @@ async def generate_llm_enhanced_rules(request: LLMRuleRequest, db: Session = Depends(get_db)):
                 "error": result.get('error', 'Unknown error'),
                 "cve_id": request.cve_id
             }
+        else:
+            # Bulk operation - run in background with job tracking
+            from models import BulkProcessingJob
+            import uuid
+            from datetime import datetime
+            import main
+
+            # Create job record
+            job = BulkProcessingJob(
+                job_type='llm_rule_generation',
+                status='pending',
+                job_metadata={
+                    'force': request.force
+                }
+            )
+            db.add(job)
+            db.commit()
+            db.refresh(job)
+
+            job_id = str(job.id)
+            main.running_jobs[job_id] = job
+            main.job_cancellation_flags[job_id] = False
+
+            async def bulk_llm_generation_task():
+                # Create a new database session for the background task
+                from config.database import SessionLocal
+                task_db = SessionLocal()
+                try:
+                    # Get the job in the new session
+                    task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
+                    if not task_job:
+                        logger.error(f"Job {job_id} not found in task session")
+                        return
+
+                    task_job.status = 'running'
+                    task_job.started_at = datetime.utcnow()
+                    task_db.commit()
+
+                    generator = EnhancedSigmaGenerator(task_db)
+
+                    # Get CVEs with PoC data - limit to small batch initially
+                    if request.force:
+                        # Process all CVEs with PoC data - but limit to prevent system overload
+                        cves_to_process = task_db.query(CVE).filter(CVE.poc_count > 0).limit(50).all()
+                    else:
+                        # Only process CVEs without existing LLM-generated rules - small batch
+                        cves_to_process = task_db.query(CVE).filter(
+                            CVE.poc_count > 0
+                        ).limit(10).all()
+
+                        # Filter out CVEs that already have LLM-generated rules
+                        existing_llm_rules = task_db.query(SigmaRule).filter(
+                            SigmaRule.detection_type.like('llm_%')
+                        ).all()
+                        existing_cve_ids = {rule.cve_id for rule in existing_llm_rules}
+                        cves_to_process = [cve for cve in cves_to_process if cve.cve_id not in existing_cve_ids]
+
+                    # Update job with total items
+                    task_job.total_items = len(cves_to_process)
+                    task_db.commit()
+
+                    rules_generated = 0
+                    rules_updated = 0
+                    failures = 0
+
+                    logger.info(f"Starting bulk LLM rule generation for {len(cves_to_process)} CVEs (job {job_id})")
+
+                    for i, cve in enumerate(cves_to_process):
+                        # Check for cancellation
+                        if main.job_cancellation_flags.get(job_id, False):
+                            logger.info(f"Job {job_id} cancelled, stopping LLM generation")
+                            break
+
+                        try:
+                            logger.info(f"Processing CVE {i+1}/{len(cves_to_process)}: {cve.cve_id}")
+                            result = await generator.generate_enhanced_rule(cve, use_llm=True)
+
+                            if result.get('success'):
+                                if result.get('updated'):
+                                    rules_updated += 1
+                                else:
+                                    rules_generated += 1
+                                logger.info(f"Successfully generated rule for {cve.cve_id}")
+                            else:
+                                failures += 1
+                                logger.warning(f"Failed to generate rule for {cve.cve_id}: {result.get('error')}")
+
+                            # Update progress
+                            task_job.processed_items = i + 1
+                            task_job.failed_items = failures
+                            task_db.commit()
+
+                        except Exception as e:
+                            failures += 1
+                            logger.error(f"Error generating LLM-enhanced rule for {cve.cve_id}: {e}")
+
+                            # Update progress
+                            task_job.processed_items = i + 1
+                            task_job.failed_items = failures
+                            task_db.commit()
+
+                            # Continue with next CVE even if one fails
+                            continue
+
+                    # Update job status if not cancelled
+                    if not main.job_cancellation_flags.get(job_id, False):
+                        task_job.status = 'completed'
+                        task_job.completed_at = datetime.utcnow()
+                        task_db.commit()
+                        logger.info(f"Bulk LLM rule generation completed: {rules_generated} new, {rules_updated} updated, {failures} failures")
+
+                except Exception as e:
+                    if not main.job_cancellation_flags.get(job_id, False):
+                        # Get the job again in case it was modified
+                        task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
+                        if task_job:
+                            task_job.status = 'failed'
+                            task_job.error_message = str(e)
+                            task_job.completed_at = datetime.utcnow()
+                            task_db.commit()
+
+                    logger.error(f"Bulk LLM rule generation failed: {e}")
+                finally:
+                    # Clean up tracking and close the task session
+                    main.running_jobs.pop(job_id, None)
+                    main.job_cancellation_flags.pop(job_id, None)
+                    task_db.close()
+
+            background_tasks.add_task(bulk_llm_generation_task)
+
+            return {
+                "success": True,
+                "message": "Bulk LLM-enhanced rule generation started",
+                "status": "started",
+                "job_id": job_id,
+                "force": request.force
+            }
 
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error generating LLM-enhanced rule: {str(e)}")
@@ -58,68 +199,97 @@ async def get_llm_status():
     try:
         from llm_client import LLMClient
 
-        # Test all providers
-        providers_status = {}
+        # Get current configuration first
+        current_client = LLMClient()
+
+        # Build available providers list in the format frontend expects
+        available_providers = []
 
         # Test Ollama
         try:
             ollama_client = LLMClient(provider="ollama")
             ollama_status = await ollama_client.test_connection()
-            providers_status["ollama"] = {
+            available_providers.append({
+                "name": "ollama",
                 "available": ollama_status.get("available", False),
+                "default_model": ollama_status.get("current_model", "llama3.2"),
                 "models": ollama_status.get("models", []),
-                "current_model": ollama_status.get("current_model"),
                 "base_url": ollama_status.get("base_url")
-            }
+            })
         except Exception as e:
-            providers_status["ollama"] = {"available": False, "error": str(e)}
+            available_providers.append({
+                "name": "ollama",
+                "available": False,
+                "default_model": "llama3.2",
+                "models": [],
+                "error": str(e)
+            })
 
         # Test OpenAI
        try:
             openai_client = LLMClient(provider="openai")
             openai_status = await openai_client.test_connection()
-            providers_status["openai"] = {
+            available_providers.append({
+                "name": "openai",
                 "available": openai_status.get("available", False),
+                "default_model": openai_status.get("current_model", "gpt-4o-mini"),
                 "models": openai_status.get("models", []),
-                "current_model": openai_status.get("current_model"),
                 "has_api_key": openai_status.get("has_api_key", False)
-            }
+            })
         except Exception as e:
-            providers_status["openai"] = {"available": False, "error": str(e)}
+            available_providers.append({
+                "name": "openai",
+                "available": False,
+                "default_model": "gpt-4o-mini",
+                "models": [],
+                "has_api_key": False,
+                "error": str(e)
+            })
 
         # Test Anthropic
         try:
             anthropic_client = LLMClient(provider="anthropic")
             anthropic_status = await anthropic_client.test_connection()
-            providers_status["anthropic"] = {
+            available_providers.append({
+                "name": "anthropic",
                 "available": anthropic_status.get("available", False),
+                "default_model": anthropic_status.get("current_model", "claude-3-5-sonnet-20241022"),
                 "models": anthropic_status.get("models", []),
-                "current_model": anthropic_status.get("current_model"),
                 "has_api_key": anthropic_status.get("has_api_key", False)
-            }
+            })
         except Exception as e:
-            providers_status["anthropic"] = {"available": False, "error": str(e)}
+            available_providers.append({
+                "name": "anthropic",
+                "available": False,
+                "default_model": "claude-3-5-sonnet-20241022",
+                "models": [],
+                "has_api_key": False,
+                "error": str(e)
+            })
 
-        # Get current configuration
-        current_client = LLMClient()
-        current_config = {
-            "current_provider": current_client.provider,
-            "current_model": current_client.model,
-            "default_provider": "ollama"
-        }
+        # Determine overall status
+        any_available = any(p.get("available") for p in available_providers)
+        status = "ready" if any_available else "not_ready"
 
+        # Return in the format the frontend expects
         return {
-            "providers": providers_status,
-            "configuration": current_config,
-            "status": "operational" if any(p.get("available") for p in providers_status.values()) else "no_providers_available"
+            "status": status,
+            "current_provider": {
+                "provider": current_client.provider,
+                "model": current_client.model
+            },
+            "available_providers": available_providers
         }
 
     except Exception as e:
         return {
             "status": "error",
             "error": str(e),
-            "providers": {},
-            "configuration": {}
+            "current_provider": {
+                "provider": "unknown",
+                "model": "unknown"
+            },
+            "available_providers": []
         }
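For reference, the payload shape the reworked /llm-status now returns (reconstructed from the code above; field values are illustrative):

example_llm_status = {
    "status": "ready",  # or "not_ready" / "error"
    "current_provider": {"provider": "ollama", "model": "llama3.2"},
    "available_providers": [
        {"name": "ollama", "available": True, "default_model": "llama3.2",
         "models": ["llama3.2"], "base_url": "http://localhost:11434"},
        {"name": "openai", "available": False, "default_model": "gpt-4o-mini",
         "models": [], "has_api_key": False},
    ],
}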
@@ -52,6 +52,22 @@ function App() {
     return runningJobTypes.has(jobType);
   };
 
+  // Helper function to format job types for display
+  const formatJobType = (jobType) => {
+    const jobTypeMap = {
+      'llm_rule_generation': 'LLM Rule Generation',
+      'rule_regeneration': 'Rule Regeneration',
+      'bulk_seed': 'Bulk Seed',
+      'incremental_update': 'Incremental Update',
+      'nomi_sec_sync': 'Nomi-Sec Sync',
+      'github_poc_sync': 'GitHub PoC Sync',
+      'exploitdb_sync': 'ExploitDB Sync',
+      'cisa_kev_sync': 'CISA KEV Sync',
+      'reference_sync': 'Reference Sync'
+    };
+    return jobTypeMap[jobType] || jobType;
+  };
+
   const isBulkSeedRunning = () => {
     return isJobTypeRunning('nvd_bulk_seed') || isJobTypeRunning('bulk_seed');
   };
@@ -262,7 +278,7 @@ function App() {
     try {
       const response = await axios.post('http://localhost:8000/api/sync-references', {
         batch_size: 30,
-        max_cves: 100,
+        max_cves: null,
         force_resync: false
       });
       console.log('Reference sync response:', response.data);
@@ -291,9 +307,20 @@ function App() {
         force: force
       });
       console.log('LLM rule generation response:', response.data);
+
+      // For bulk operations, the job runs in background - no need to wait
+      if (response.data.status === 'started') {
+        console.log(`LLM rule generation started as background job: ${response.data.job_id}`);
+        alert(`LLM rule generation started in background. You can monitor progress in the Bulk Jobs tab.`);
+        // Refresh data to show the new job status
         fetchData();
+      } else {
+        // For single CVE operations, refresh data after completion
+        fetchData();
+      }
     } catch (error) {
       console.error('Error generating LLM-enhanced rules:', error);
+      alert('Error generating LLM-enhanced rules. Please check the console for details.');
     }
   };
@@ -1062,7 +1089,7 @@ function App() {
             <div className="flex items-center justify-between">
               <div className="flex-1">
                 <div className="flex items-center space-x-3">
-                  <h3 className="text-lg font-medium text-gray-900">{job.job_type}</h3>
+                  <h3 className="text-lg font-medium text-gray-900">{formatJobType(job.job_type)}</h3>
                   <span className={`inline-flex px-2 py-1 text-xs font-semibold rounded-full ${
                     job.status === 'running' ? 'bg-blue-100 text-blue-800' :
                     'bg-gray-100 text-gray-800'