Fixed LLM operations post-refactor

This commit is contained in:
Brendan McDevitt 2025-07-15 12:43:27 -05:00
parent a6fb367ed4
commit e9a5f54d3a
7 changed files with 1089 additions and 2459 deletions

View file

@ -1,10 +1,20 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import QueuePool
from .settings import settings
# Database setup
engine = create_engine(settings.DATABASE_URL)
# Database setup with connection pooling
engine = create_engine(
settings.DATABASE_URL,
poolclass=QueuePool,
pool_size=10, # Number of connections to maintain in the pool
max_overflow=20, # Additional connections that can be created on demand
pool_timeout=30, # Timeout for getting connection from pool
pool_recycle=3600, # Recycle connections after 1 hour
pool_pre_ping=True, # Validate connections before use
echo=False # Set to True for SQL query logging
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

View file

@ -4,7 +4,8 @@ Script to delete all SIGMA rules from the database
This will clear existing rules so they can be regenerated with the improved LLM client
"""
from models import SigmaRule, SessionLocal
from models import SigmaRule
from config.database import SessionLocal
import logging
# Setup logging

View file

@ -108,7 +108,10 @@ class LLMClient:
self.llm = Ollama(
model=self.model,
base_url=base_url,
temperature=0.1
temperature=0.1,
num_ctx=4096, # Context window size
top_p=0.9,
top_k=40
)
if self.llm:
@ -186,9 +189,22 @@ class LLMClient:
logger.info(f"CVE Description for {cve_id}: {cve_description[:200]}...")
logger.info(f"PoC Content sample for {cve_id}: {poc_content[:200]}...")
# Generate the response
# Generate the response with timeout handling
logger.info(f"Final prompt variables for {cve_id}: {list(input_data.keys())}")
response = await chain.ainvoke(input_data)
import asyncio
try:
# Add timeout wrapper around the LLM call
response = await asyncio.wait_for(
chain.ainvoke(input_data),
timeout=150 # 2.5 minutes total timeout
)
except asyncio.TimeoutError:
logger.error(f"LLM request timed out for {cve_id}")
return None
except Exception as llm_error:
logger.error(f"LLM generation error for {cve_id}: {llm_error}")
return None
# Debug: Log raw LLM response
logger.info(f"Raw LLM response for {cve_id}: {response[:200]}...")
@ -228,36 +244,42 @@ class LLMClient:
system_message = """You are a cybersecurity expert specializing in SIGMA rule creation following the official SIGMA specification.
**CRITICAL: You must follow the exact SIGMA specification format:**
**OFFICIAL SIGMA RULE SPECIFICATION JSON SCHEMA:**
The official SIGMA rule specification (v2.0.0) defines these requirements:
1. **YAML Structure Requirements:**
- Use UTF-8 encoding with LF line breaks
- Indent with 4 spaces (no tabs)
- Use lowercase keys only
- Use single quotes for string values
- No quotes for numeric values
- Follow proper YAML syntax
**MANDATORY Fields (must include):**
- title: Brief description (max 256 chars) - string
- logsource: Log data source specification - object with category/product/service
- detection: Search identifiers and conditions - object with selections and condition
2. **MANDATORY Fields (must include):**
- title: Brief description (max 256 chars)
- logsource: Log data source specification
- detection: Search identifiers and conditions
- condition: How detection elements combine
**RECOMMENDED Fields:**
- id: Unique UUID (version 4) - string with UUID format
- status: Rule state - enum: "stable", "test", "experimental", "deprecated", "unsupported"
- description: Detailed explanation - string
- author: Rule creator - string (use "AI Generated")
- date: Creation date - string in YYYY/MM/DD format
- modified: Last modification date - string in YYYY/MM/DD format
- references: Sources for rule derivation - array of strings (URLs)
- tags: MITRE ATT&CK techniques - array of strings
- level: Rule severity - enum: "informational", "low", "medium", "high", "critical"
- falsepositives: Known false positives - array of strings
- fields: Related fields - array of strings
- related: Related rules - array of objects with type and id
3. **RECOMMENDED Fields:**
- id: Unique UUID
- status: 'experimental' (for new rules)
- description: Detailed explanation
- author: 'AI Generated'
- date: Current date (YYYY/MM/DD)
- references: Array with CVE link
- tags: MITRE ATT&CK techniques
**YAML Structure Requirements:**
- Use UTF-8 encoding with LF line breaks
- Indent with 4 spaces (no tabs)
- Use lowercase keys only
- Use single quotes for string values
- No quotes for numeric values
- Follow proper YAML syntax
4. **Detection Structure:**
- Use selection blocks (selection, selection1, etc.)
- Condition references these selections
- Use proper field names (Image, CommandLine, ProcessName, etc.)
- Support wildcards (*) and value lists
**Detection Structure:**
- Use selection blocks (selection, selection1, etc.)
- Condition references these selections
- Use proper field names (Image, CommandLine, ProcessName, etc.)
- Support wildcards (*) and value lists
- Condition can be string expression or object with keywords
**ABSOLUTE REQUIREMENTS:**
- Output ONLY valid YAML
@ -1253,4 +1275,115 @@ Output ONLY the enhanced SIGMA rule in valid YAML format."""
return []
except Exception as e:
logger.error(f"Error getting Ollama models: {e}")
return []
return []
async def test_connection(self) -> Dict[str, Any]:
"""Test connection to the configured LLM provider."""
try:
if self.provider == 'openai':
api_key = os.getenv('OPENAI_API_KEY')
if not api_key:
return {
"available": False,
"error": "OpenAI API key not configured",
"models": [],
"current_model": self.model,
"has_api_key": False
}
# Test OpenAI connection without actual API call to avoid timeouts
if self.llm:
return {
"available": True,
"models": self.SUPPORTED_PROVIDERS['openai']['models'],
"current_model": self.model,
"has_api_key": True
}
else:
return {
"available": False,
"error": "OpenAI client not initialized",
"models": [],
"current_model": self.model,
"has_api_key": True
}
elif self.provider == 'anthropic':
api_key = os.getenv('ANTHROPIC_API_KEY')
if not api_key:
return {
"available": False,
"error": "Anthropic API key not configured",
"models": [],
"current_model": self.model,
"has_api_key": False
}
# Test Anthropic connection without actual API call to avoid timeouts
if self.llm:
return {
"available": True,
"models": self.SUPPORTED_PROVIDERS['anthropic']['models'],
"current_model": self.model,
"has_api_key": True
}
else:
return {
"available": False,
"error": "Anthropic client not initialized",
"models": [],
"current_model": self.model,
"has_api_key": True
}
elif self.provider == 'ollama':
base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
# Test Ollama connection
try:
import requests
response = requests.get(f"{base_url}/api/tags", timeout=10)
if response.status_code == 200:
available_models = self._get_ollama_available_models()
# Check if model is available using proper model name matching
model_available = self._check_ollama_model_available(base_url, self.model)
return {
"available": model_available,
"models": available_models,
"current_model": self.model,
"base_url": base_url,
"error": None if model_available else f"Model {self.model} not available"
}
else:
return {
"available": False,
"error": f"Ollama server not responding (HTTP {response.status_code})",
"models": [],
"current_model": self.model,
"base_url": base_url
}
except Exception as e:
return {
"available": False,
"error": f"Cannot connect to Ollama server: {str(e)}",
"models": [],
"current_model": self.model,
"base_url": base_url
}
else:
return {
"available": False,
"error": f"Unsupported provider: {self.provider}",
"models": [],
"current_model": self.model
}
except Exception as e:
return {
"available": False,
"error": f"Connection test failed: {str(e)}",
"models": [],
"current_model": self.model
}

File diff suppressed because it is too large Load diff

View file

@ -1,14 +1,23 @@
from typing import List, Optional
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from sqlalchemy.orm import Session
from sqlalchemy import func, text
import uuid
from datetime import datetime
import logging
from config.database import get_db
from config.database import get_db, SessionLocal
from models import BulkProcessingJob, CVE, SigmaRule
from schemas import BulkSeedRequest, NomiSecSyncRequest, GitHubPoCSyncRequest, ExploitDBSyncRequest, CISAKEVSyncRequest, ReferenceSyncRequest
from services import CVEService, SigmaRuleService
# Import global job tracking from main.py
import main
router = APIRouter(prefix="/api", tags=["bulk-operations"])
logger = logging.getLogger(__name__)
@router.post("/bulk-seed")
async def bulk_seed(request: BulkSeedRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
@ -117,4 +126,657 @@ async def get_poc_stats(db: Session = Depends(get_db)):
"total_rules": total_rules,
"exploit_based_rules": exploit_based_rules,
"exploit_based_percentage": round((exploit_based_rules / total_rules * 100), 2) if total_rules > 0 else 0
}
}
@router.post("/sync-nomi-sec")
async def sync_nomi_sec(background_tasks: BackgroundTasks,
request: NomiSecSyncRequest,
db: Session = Depends(get_db)):
"""Synchronize nomi-sec PoC data"""
# Create job record
job = BulkProcessingJob(
job_type='nomi_sec_sync',
status='pending',
job_metadata={
'cve_id': request.cve_id,
'batch_size': request.batch_size
}
)
db.add(job)
db.commit()
db.refresh(job)
job_id = str(job.id)
main.running_jobs[job_id] = job
main.job_cancellation_flags[job_id] = False
async def sync_task():
try:
job.status = 'running'
job.started_at = datetime.utcnow()
db.commit()
from nomi_sec_client import NomiSecClient
client = NomiSecClient(db)
if request.cve_id:
# Sync specific CVE
if main.job_cancellation_flags.get(job_id, False):
logger.info(f"Job {job_id} cancelled before starting")
return
result = await client.sync_cve_pocs(request.cve_id)
logger.info(f"Nomi-sec sync for {request.cve_id}: {result}")
else:
# Sync all CVEs with cancellation support
result = await client.bulk_sync_all_cves(
batch_size=request.batch_size,
cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False)
)
logger.info(f"Nomi-sec bulk sync completed: {result}")
# Update job status if not cancelled
if not main.job_cancellation_flags.get(job_id, False):
job.status = 'completed'
job.completed_at = datetime.utcnow()
db.commit()
except Exception as e:
if not main.job_cancellation_flags.get(job_id, False):
job.status = 'failed'
job.error_message = str(e)
job.completed_at = datetime.utcnow()
db.commit()
logger.error(f"Nomi-sec sync failed: {e}")
import traceback
traceback.print_exc()
finally:
# Clean up tracking
main.running_jobs.pop(job_id, None)
main.job_cancellation_flags.pop(job_id, None)
background_tasks.add_task(sync_task)
return {
"message": f"Nomi-sec sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
"status": "started",
"job_id": job_id,
"cve_id": request.cve_id,
"batch_size": request.batch_size
}
@router.post("/sync-github-pocs")
async def sync_github_pocs(background_tasks: BackgroundTasks,
request: GitHubPoCSyncRequest,
db: Session = Depends(get_db)):
"""Synchronize GitHub PoC data"""
# Create job record
job = BulkProcessingJob(
job_type='github_poc_sync',
status='pending',
job_metadata={
'cve_id': request.cve_id,
'batch_size': request.batch_size
}
)
db.add(job)
db.commit()
db.refresh(job)
job_id = str(job.id)
main.running_jobs[job_id] = job
main.job_cancellation_flags[job_id] = False
async def sync_task():
try:
job.status = 'running'
job.started_at = datetime.utcnow()
db.commit()
from mcdevitt_poc_client import GitHubPoCClient
client = GitHubPoCClient(db)
if request.cve_id:
# Sync specific CVE
if main.job_cancellation_flags.get(job_id, False):
logger.info(f"Job {job_id} cancelled before starting")
return
result = await client.sync_cve_pocs(request.cve_id)
logger.info(f"GitHub PoC sync for {request.cve_id}: {result}")
else:
# Sync all CVEs with cancellation support
result = await client.bulk_sync_all_cves(batch_size=request.batch_size)
logger.info(f"GitHub PoC bulk sync completed: {result}")
# Update job status if not cancelled
if not main.job_cancellation_flags.get(job_id, False):
job.status = 'completed'
job.completed_at = datetime.utcnow()
db.commit()
except Exception as e:
if not main.job_cancellation_flags.get(job_id, False):
job.status = 'failed'
job.error_message = str(e)
job.completed_at = datetime.utcnow()
db.commit()
logger.error(f"GitHub PoC sync failed: {e}")
import traceback
traceback.print_exc()
finally:
# Clean up tracking
main.running_jobs.pop(job_id, None)
main.job_cancellation_flags.pop(job_id, None)
background_tasks.add_task(sync_task)
return {
"message": f"GitHub PoC sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
"status": "started",
"job_id": job_id,
"cve_id": request.cve_id,
"batch_size": request.batch_size
}
@router.post("/sync-exploitdb")
async def sync_exploitdb(background_tasks: BackgroundTasks,
request: ExploitDBSyncRequest,
db: Session = Depends(get_db)):
"""Synchronize ExploitDB data from git mirror"""
# Create job record
job = BulkProcessingJob(
job_type='exploitdb_sync',
status='pending',
job_metadata={
'cve_id': request.cve_id,
'batch_size': request.batch_size
}
)
db.add(job)
db.commit()
db.refresh(job)
job_id = str(job.id)
main.running_jobs[job_id] = job
main.job_cancellation_flags[job_id] = False
async def sync_task():
# Create a new database session for the background task
task_db = SessionLocal()
try:
# Get the job in the new session
task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
if not task_job:
logger.error(f"Job {job_id} not found in task session")
return
task_job.status = 'running'
task_job.started_at = datetime.utcnow()
task_db.commit()
from exploitdb_client_local import ExploitDBLocalClient
client = ExploitDBLocalClient(task_db)
if request.cve_id:
# Sync specific CVE
if main.job_cancellation_flags.get(job_id, False):
logger.info(f"Job {job_id} cancelled before starting")
return
result = await client.sync_cve_exploits(request.cve_id)
logger.info(f"ExploitDB sync for {request.cve_id}: {result}")
else:
# Sync all CVEs with cancellation support
result = await client.bulk_sync_exploitdb(
batch_size=request.batch_size,
cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False)
)
logger.info(f"ExploitDB bulk sync completed: {result}")
# Update job status if not cancelled
if not main.job_cancellation_flags.get(job_id, False):
task_job.status = 'completed'
task_job.completed_at = datetime.utcnow()
task_db.commit()
except Exception as e:
if not main.job_cancellation_flags.get(job_id, False):
# Get the job again in case it was modified
task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
if task_job:
task_job.status = 'failed'
task_job.error_message = str(e)
task_job.completed_at = datetime.utcnow()
task_db.commit()
logger.error(f"ExploitDB sync failed: {e}")
import traceback
traceback.print_exc()
finally:
# Clean up tracking and close the task session
main.running_jobs.pop(job_id, None)
main.job_cancellation_flags.pop(job_id, None)
task_db.close()
background_tasks.add_task(sync_task)
return {
"message": f"ExploitDB sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
"status": "started",
"job_id": job_id,
"cve_id": request.cve_id,
"batch_size": request.batch_size
}
@router.post("/sync-cisa-kev")
async def sync_cisa_kev(background_tasks: BackgroundTasks,
request: CISAKEVSyncRequest,
db: Session = Depends(get_db)):
"""Synchronize CISA Known Exploited Vulnerabilities data"""
# Create job record
job = BulkProcessingJob(
job_type='cisa_kev_sync',
status='pending',
job_metadata={
'cve_id': request.cve_id,
'batch_size': request.batch_size
}
)
db.add(job)
db.commit()
db.refresh(job)
job_id = str(job.id)
main.running_jobs[job_id] = job
main.job_cancellation_flags[job_id] = False
async def sync_task():
# Create a new database session for the background task
task_db = SessionLocal()
try:
# Get the job in the new session
task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
if not task_job:
logger.error(f"Job {job_id} not found in task session")
return
task_job.status = 'running'
task_job.started_at = datetime.utcnow()
task_db.commit()
from cisa_kev_client import CISAKEVClient
client = CISAKEVClient(task_db)
if request.cve_id:
# Sync specific CVE
if main.job_cancellation_flags.get(job_id, False):
logger.info(f"Job {job_id} cancelled before starting")
return
result = await client.sync_cve_kev_data(request.cve_id)
logger.info(f"CISA KEV sync for {request.cve_id}: {result}")
else:
# Sync all CVEs with cancellation support
result = await client.bulk_sync_kev_data(
batch_size=request.batch_size,
cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False)
)
logger.info(f"CISA KEV bulk sync completed: {result}")
# Update job status if not cancelled
if not main.job_cancellation_flags.get(job_id, False):
task_job.status = 'completed'
task_job.completed_at = datetime.utcnow()
task_db.commit()
except Exception as e:
if not main.job_cancellation_flags.get(job_id, False):
# Get the job again in case it was modified
task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
if task_job:
task_job.status = 'failed'
task_job.error_message = str(e)
task_job.completed_at = datetime.utcnow()
task_db.commit()
logger.error(f"CISA KEV sync failed: {e}")
import traceback
traceback.print_exc()
finally:
# Clean up tracking and close the task session
main.running_jobs.pop(job_id, None)
main.job_cancellation_flags.pop(job_id, None)
task_db.close()
background_tasks.add_task(sync_task)
return {
"message": f"CISA KEV sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
"status": "started",
"job_id": job_id,
"cve_id": request.cve_id,
"batch_size": request.batch_size
}
@router.post("/sync-references")
async def sync_references(request: ReferenceSyncRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
"""Start reference data synchronization"""
try:
from reference_client import ReferenceClient
client = ReferenceClient(db)
# Create job ID
job_id = str(uuid.uuid4())
# Add job to tracking
main.running_jobs[job_id] = {
'type': 'reference_sync',
'status': 'running',
'cve_id': request.cve_id,
'batch_size': request.batch_size,
'max_cves': request.max_cves,
'force_resync': request.force_resync,
'started_at': datetime.utcnow()
}
# Create cancellation flag
main.job_cancellation_flags[job_id] = False
async def sync_task():
try:
if request.cve_id:
# Single CVE sync
result = await client.sync_cve_references(request.cve_id)
main.running_jobs[job_id]['result'] = result
main.running_jobs[job_id]['status'] = 'completed'
else:
# Bulk sync
result = await client.bulk_sync_references(
batch_size=request.batch_size,
max_cves=request.max_cves,
force_resync=request.force_resync,
cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False)
)
main.running_jobs[job_id]['result'] = result
main.running_jobs[job_id]['status'] = 'completed'
main.running_jobs[job_id]['completed_at'] = datetime.utcnow()
except Exception as e:
logger.error(f"Reference sync task failed: {e}")
main.running_jobs[job_id]['status'] = 'failed'
main.running_jobs[job_id]['error'] = str(e)
main.running_jobs[job_id]['completed_at'] = datetime.utcnow()
finally:
# Clean up cancellation flag
main.job_cancellation_flags.pop(job_id, None)
background_tasks.add_task(sync_task)
return {
"message": f"Reference sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
"status": "started",
"job_id": job_id,
"cve_id": request.cve_id,
"batch_size": request.batch_size,
"max_cves": request.max_cves,
"force_resync": request.force_resync
}
except Exception as e:
logger.error(f"Failed to start reference sync: {e}")
raise HTTPException(status_code=500, detail=f"Failed to start reference sync: {str(e)}")
@router.get("/exploitdb-stats")
async def get_exploitdb_stats(db: Session = Depends(get_db)):
"""Get ExploitDB-related statistics"""
try:
from exploitdb_client_local import ExploitDBLocalClient
client = ExploitDBLocalClient(db)
# Get sync status
status = await client.get_exploitdb_sync_status()
# Get quality distribution from ExploitDB data
quality_distribution = {}
cves_with_exploitdb = db.query(CVE).filter(
text("poc_data::text LIKE '%\"exploitdb\"%'")
).all()
for cve in cves_with_exploitdb:
if cve.poc_data and 'exploitdb' in cve.poc_data:
exploits = cve.poc_data['exploitdb'].get('exploits', [])
for exploit in exploits:
quality_tier = exploit.get('quality_analysis', {}).get('quality_tier', 'unknown')
quality_distribution[quality_tier] = quality_distribution.get(quality_tier, 0) + 1
# Get category distribution
category_distribution = {}
for cve in cves_with_exploitdb:
if cve.poc_data and 'exploitdb' in cve.poc_data:
exploits = cve.poc_data['exploitdb'].get('exploits', [])
for exploit in exploits:
category = exploit.get('category', 'unknown')
category_distribution[category] = category_distribution.get(category, 0) + 1
return {
"exploitdb_sync_status": status,
"quality_distribution": quality_distribution,
"category_distribution": category_distribution,
"total_exploitdb_cves": len(cves_with_exploitdb),
"total_exploits": sum(
len(cve.poc_data.get('exploitdb', {}).get('exploits', []))
for cve in cves_with_exploitdb
if cve.poc_data and 'exploitdb' in cve.poc_data
)
}
except Exception as e:
logger.error(f"Error getting ExploitDB stats: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/github-poc-stats")
async def get_github_poc_stats(db: Session = Depends(get_db)):
"""Get GitHub PoC-related statistics"""
try:
# Get basic statistics
github_poc_rules = db.query(SigmaRule).filter(SigmaRule.poc_source == 'github_poc').count()
cves_with_github_pocs = db.query(CVE).filter(
CVE.poc_data.isnot(None), # Check if poc_data exists
func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc'
).count()
# Get quality distribution
quality_distribution = {}
try:
quality_results = db.query(
func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_tier').label('tier'),
func.count().label('count')
).filter(
CVE.poc_data.isnot(None),
func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc'
).group_by('tier').all()
for tier, count in quality_results:
if tier:
quality_distribution[tier] = count
except Exception as e:
logger.warning(f"Error getting quality distribution: {e}")
quality_distribution = {}
# Calculate average quality score
try:
from sqlalchemy import Integer
avg_quality = db.query(
func.avg(func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_score').cast(Integer))
).filter(
CVE.poc_data.isnot(None),
func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc'
).scalar() or 0
except Exception as e:
logger.warning(f"Error calculating average quality: {e}")
avg_quality = 0
return {
'github_poc_rules': github_poc_rules,
'cves_with_github_pocs': cves_with_github_pocs,
'quality_distribution': quality_distribution,
'average_quality_score': float(avg_quality) if avg_quality else 0,
'source': 'github_poc'
}
except Exception as e:
logger.error(f"Error getting GitHub PoC stats: {e}")
return {"error": str(e)}
@router.get("/cisa-kev-stats")
async def get_cisa_kev_stats(db: Session = Depends(get_db)):
"""Get CISA KEV-related statistics"""
try:
from cisa_kev_client import CISAKEVClient
client = CISAKEVClient(db)
# Get sync status
status = await client.get_kev_sync_status()
# Get threat level distribution from CISA KEV data
threat_level_distribution = {}
cves_with_kev = db.query(CVE).filter(
text("poc_data::text LIKE '%\"cisa_kev\"%'")
).all()
for cve in cves_with_kev:
if cve.poc_data and 'cisa_kev' in cve.poc_data:
vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
threat_level = vuln_data.get('threat_level', 'unknown')
threat_level_distribution[threat_level] = threat_level_distribution.get(threat_level, 0) + 1
# Get vulnerability category distribution
category_distribution = {}
for cve in cves_with_kev:
if cve.poc_data and 'cisa_kev' in cve.poc_data:
vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
category = vuln_data.get('vulnerability_category', 'unknown')
category_distribution[category] = category_distribution.get(category, 0) + 1
# Get ransomware usage statistics
ransomware_stats = {'known': 0, 'unknown': 0}
for cve in cves_with_kev:
if cve.poc_data and 'cisa_kev' in cve.poc_data:
vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
ransomware_use = vuln_data.get('known_ransomware_use', 'Unknown').lower()
if ransomware_use == 'known':
ransomware_stats['known'] += 1
else:
ransomware_stats['unknown'] += 1
# Calculate average threat score
threat_scores = []
for cve in cves_with_kev:
if cve.poc_data and 'cisa_kev' in cve.poc_data:
vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
threat_score = vuln_data.get('threat_score', 0)
if threat_score:
threat_scores.append(threat_score)
avg_threat_score = sum(threat_scores) / len(threat_scores) if threat_scores else 0
return {
"cisa_kev_sync_status": status,
"threat_level_distribution": threat_level_distribution,
"category_distribution": category_distribution,
"ransomware_stats": ransomware_stats,
"average_threat_score": round(avg_threat_score, 2),
"total_kev_cves": len(cves_with_kev),
"total_with_threat_scores": len(threat_scores)
}
except Exception as e:
logger.error(f"Error getting CISA KEV stats: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/cancel-job/{job_id}")
async def cancel_job(job_id: str, db: Session = Depends(get_db)):
"""Cancel a running job"""
try:
# Find the job in the database
job = db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job_id).first()
if not job:
raise HTTPException(status_code=404, detail="Job not found")
if job.status not in ['pending', 'running']:
raise HTTPException(status_code=400, detail=f"Cannot cancel job with status: {job.status}")
# Set cancellation flag
main.job_cancellation_flags[job_id] = True
# Update job status
job.status = 'cancelled'
job.cancelled_at = datetime.utcnow()
db.commit()
logger.info(f"Job {job_id} ({job.job_type}) cancelled by user")
return {
"success": True,
"message": f"Job {job_id} cancelled successfully",
"job_id": job_id,
"job_type": job.job_type
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error cancelling job {job_id}: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/running-jobs")
async def get_running_jobs(db: Session = Depends(get_db)):
"""Get currently running jobs"""
try:
# Get running jobs from database
running_jobs_db = db.query(BulkProcessingJob).filter(
BulkProcessingJob.status.in_(['pending', 'running'])
).order_by(BulkProcessingJob.created_at.desc()).all()
result = []
for job in running_jobs_db:
result.append({
'id': str(job.id),
'job_type': job.job_type,
'status': job.status,
'year': job.year,
'total_items': job.total_items,
'processed_items': job.processed_items,
'failed_items': job.failed_items,
'error_message': job.error_message,
'started_at': job.started_at,
'created_at': job.created_at,
'can_cancel': job.status in ['pending', 'running']
})
return result
except Exception as e:
logger.error(f"Error getting running jobs: {e}")
raise HTTPException(status_code=500, detail=str(e))

View file

@ -1,17 +1,20 @@
from typing import Any, Dict, Optional
from fastapi import APIRouter, HTTPException, Depends
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from sqlalchemy.orm import Session
from pydantic import BaseModel
import logging
from config.database import get_db
from models import CVE, SigmaRule
router = APIRouter(prefix="/api", tags=["llm-operations"])
logger = logging.getLogger(__name__)
class LLMRuleRequest(BaseModel):
    """Request body for LLM-enhanced rule generation.

    ``cve_id`` is optional: when omitted (None), the endpoint starts a bulk
    generation job instead of processing a single CVE.
    """
    cve_id: Optional[str] = None  # target CVE; None triggers bulk mode
    poc_content: str = ""
    force: bool = False  # bulk mode: regenerate rules that already exist
class LLMSwitchRequest(BaseModel):
@ -20,32 +23,170 @@ class LLMSwitchRequest(BaseModel):
@router.post("/llm-enhanced-rules")
async def generate_llm_enhanced_rules(request: LLMRuleRequest, db: Session = Depends(get_db)):
async def generate_llm_enhanced_rules(request: LLMRuleRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
"""Generate SIGMA rules using LLM AI analysis"""
try:
from enhanced_sigma_generator import EnhancedSigmaGenerator
# Get CVE
cve = db.query(CVE).filter(CVE.cve_id == request.cve_id).first()
if not cve:
raise HTTPException(status_code=404, detail="CVE not found")
# Generate enhanced rule using LLM
generator = EnhancedSigmaGenerator(db)
result = await generator.generate_enhanced_rule(cve, use_llm=True)
if result.get('success'):
if request.cve_id:
# Single CVE operation
cve = db.query(CVE).filter(CVE.cve_id == request.cve_id).first()
if not cve:
raise HTTPException(status_code=404, detail="CVE not found")
# Generate enhanced rule using LLM
generator = EnhancedSigmaGenerator(db)
result = await generator.generate_enhanced_rule(cve, use_llm=True)
if result.get('success'):
return {
"success": True,
"message": f"Generated LLM-enhanced rule for {request.cve_id}",
"rule_id": result.get('rule_id'),
"generation_method": "llm_enhanced"
}
else:
return {
"success": False,
"error": result.get('error', 'Unknown error'),
"cve_id": request.cve_id
}
else:
# Bulk operation - run in background with job tracking
from models import BulkProcessingJob
import uuid
from datetime import datetime
import main
# Create job record
job = BulkProcessingJob(
job_type='llm_rule_generation',
status='pending',
job_metadata={
'force': request.force
}
)
db.add(job)
db.commit()
db.refresh(job)
job_id = str(job.id)
main.running_jobs[job_id] = job
main.job_cancellation_flags[job_id] = False
async def bulk_llm_generation_task():
# Create a new database session for the background task
from config.database import SessionLocal
task_db = SessionLocal()
try:
# Get the job in the new session
task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
if not task_job:
logger.error(f"Job {job_id} not found in task session")
return
task_job.status = 'running'
task_job.started_at = datetime.utcnow()
task_db.commit()
generator = EnhancedSigmaGenerator(task_db)
# Get CVEs with PoC data - limit to small batch initially
if request.force:
# Process all CVEs with PoC data - but limit to prevent system overload
cves_to_process = task_db.query(CVE).filter(CVE.poc_count > 0).limit(50).all()
else:
# Only process CVEs without existing LLM-generated rules - small batch
cves_to_process = task_db.query(CVE).filter(
CVE.poc_count > 0
).limit(10).all()
# Filter out CVEs that already have LLM-generated rules
existing_llm_rules = task_db.query(SigmaRule).filter(
SigmaRule.detection_type.like('llm_%')
).all()
existing_cve_ids = {rule.cve_id for rule in existing_llm_rules}
cves_to_process = [cve for cve in cves_to_process if cve.cve_id not in existing_cve_ids]
# Update job with total items
task_job.total_items = len(cves_to_process)
task_db.commit()
rules_generated = 0
rules_updated = 0
failures = 0
logger.info(f"Starting bulk LLM rule generation for {len(cves_to_process)} CVEs (job {job_id})")
for i, cve in enumerate(cves_to_process):
# Check for cancellation
if main.job_cancellation_flags.get(job_id, False):
logger.info(f"Job {job_id} cancelled, stopping LLM generation")
break
try:
logger.info(f"Processing CVE {i+1}/{len(cves_to_process)}: {cve.cve_id}")
result = await generator.generate_enhanced_rule(cve, use_llm=True)
if result.get('success'):
if result.get('updated'):
rules_updated += 1
else:
rules_generated += 1
logger.info(f"Successfully generated rule for {cve.cve_id}")
else:
failures += 1
logger.warning(f"Failed to generate rule for {cve.cve_id}: {result.get('error')}")
# Update progress
task_job.processed_items = i + 1
task_job.failed_items = failures
task_db.commit()
except Exception as e:
failures += 1
logger.error(f"Error generating LLM-enhanced rule for {cve.cve_id}: {e}")
# Update progress
task_job.processed_items = i + 1
task_job.failed_items = failures
task_db.commit()
# Continue with next CVE even if one fails
continue
# Update job status if not cancelled
if not main.job_cancellation_flags.get(job_id, False):
task_job.status = 'completed'
task_job.completed_at = datetime.utcnow()
task_db.commit()
logger.info(f"Bulk LLM rule generation completed: {rules_generated} new, {rules_updated} updated, {failures} failures")
except Exception as e:
if not main.job_cancellation_flags.get(job_id, False):
# Get the job again in case it was modified
task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
if task_job:
task_job.status = 'failed'
task_job.error_message = str(e)
task_job.completed_at = datetime.utcnow()
task_db.commit()
logger.error(f"Bulk LLM rule generation failed: {e}")
finally:
# Clean up tracking and close the task session
main.running_jobs.pop(job_id, None)
main.job_cancellation_flags.pop(job_id, None)
task_db.close()
background_tasks.add_task(bulk_llm_generation_task)
return {
"success": True,
"message": f"Generated LLM-enhanced rule for {request.cve_id}",
"rule_id": result.get('rule_id'),
"generation_method": "llm_enhanced"
}
else:
return {
"success": False,
"error": result.get('error', 'Unknown error'),
"cve_id": request.cve_id
"message": "Bulk LLM-enhanced rule generation started",
"status": "started",
"job_id": job_id,
"force": request.force
}
except Exception as e:
@ -58,68 +199,97 @@ async def get_llm_status():
try:
from llm_client import LLMClient
# Test all providers
providers_status = {}
# Get current configuration first
current_client = LLMClient()
# Build available providers list in the format frontend expects
available_providers = []
# Test Ollama
try:
ollama_client = LLMClient(provider="ollama")
ollama_status = await ollama_client.test_connection()
providers_status["ollama"] = {
available_providers.append({
"name": "ollama",
"available": ollama_status.get("available", False),
"default_model": ollama_status.get("current_model", "llama3.2"),
"models": ollama_status.get("models", []),
"current_model": ollama_status.get("current_model"),
"base_url": ollama_status.get("base_url")
}
})
except Exception as e:
providers_status["ollama"] = {"available": False, "error": str(e)}
available_providers.append({
"name": "ollama",
"available": False,
"default_model": "llama3.2",
"models": [],
"error": str(e)
})
# Test OpenAI
try:
openai_client = LLMClient(provider="openai")
openai_status = await openai_client.test_connection()
providers_status["openai"] = {
available_providers.append({
"name": "openai",
"available": openai_status.get("available", False),
"default_model": openai_status.get("current_model", "gpt-4o-mini"),
"models": openai_status.get("models", []),
"current_model": openai_status.get("current_model"),
"has_api_key": openai_status.get("has_api_key", False)
}
})
except Exception as e:
providers_status["openai"] = {"available": False, "error": str(e)}
available_providers.append({
"name": "openai",
"available": False,
"default_model": "gpt-4o-mini",
"models": [],
"has_api_key": False,
"error": str(e)
})
# Test Anthropic
try:
anthropic_client = LLMClient(provider="anthropic")
anthropic_status = await anthropic_client.test_connection()
providers_status["anthropic"] = {
available_providers.append({
"name": "anthropic",
"available": anthropic_status.get("available", False),
"default_model": anthropic_status.get("current_model", "claude-3-5-sonnet-20241022"),
"models": anthropic_status.get("models", []),
"current_model": anthropic_status.get("current_model"),
"has_api_key": anthropic_status.get("has_api_key", False)
}
})
except Exception as e:
providers_status["anthropic"] = {"available": False, "error": str(e)}
available_providers.append({
"name": "anthropic",
"available": False,
"default_model": "claude-3-5-sonnet-20241022",
"models": [],
"has_api_key": False,
"error": str(e)
})
# Get current configuration
current_client = LLMClient()
current_config = {
"current_provider": current_client.provider,
"current_model": current_client.model,
"default_provider": "ollama"
}
# Determine overall status
any_available = any(p.get("available") for p in available_providers)
status = "ready" if any_available else "not_ready"
# Return in the format the frontend expects
return {
"providers": providers_status,
"configuration": current_config,
"status": "operational" if any(p.get("available") for p in providers_status.values()) else "no_providers_available"
"status": status,
"current_provider": {
"provider": current_client.provider,
"model": current_client.model
},
"available_providers": available_providers
}
except Exception as e:
return {
"status": "error",
"error": str(e),
"providers": {},
"configuration": {}
"current_provider": {
"provider": "unknown",
"model": "unknown"
},
"available_providers": []
}

View file

@ -52,6 +52,22 @@ function App() {
return runningJobTypes.has(jobType);
};
// Helper function to format job types for display
const formatJobType = (jobType) => {
const jobTypeMap = {
'llm_rule_generation': 'LLM Rule Generation',
'rule_regeneration': 'Rule Regeneration',
'bulk_seed': 'Bulk Seed',
'incremental_update': 'Incremental Update',
'nomi_sec_sync': 'Nomi-Sec Sync',
'github_poc_sync': 'GitHub PoC Sync',
'exploitdb_sync': 'ExploitDB Sync',
'cisa_kev_sync': 'CISA KEV Sync',
'reference_sync': 'Reference Sync'
};
return jobTypeMap[jobType] || jobType;
};
const isBulkSeedRunning = () => {
return isJobTypeRunning('nvd_bulk_seed') || isJobTypeRunning('bulk_seed');
};
@ -262,7 +278,7 @@ function App() {
try {
const response = await axios.post('http://localhost:8000/api/sync-references', {
batch_size: 30,
max_cves: 100,
max_cves: null,
force_resync: false
});
console.log('Reference sync response:', response.data);
@ -291,9 +307,20 @@ function App() {
force: force
});
console.log('LLM rule generation response:', response.data);
fetchData();
// For bulk operations, the job runs in background - no need to wait
if (response.data.status === 'started') {
console.log(`LLM rule generation started as background job: ${response.data.job_id}`);
alert(`LLM rule generation started in background. You can monitor progress in the Bulk Jobs tab.`);
// Refresh data to show the new job status
fetchData();
} else {
// For single CVE operations, refresh data after completion
fetchData();
}
} catch (error) {
console.error('Error generating LLM-enhanced rules:', error);
alert('Error generating LLM-enhanced rules. Please check the console for details.');
}
};
@ -1062,7 +1089,7 @@ function App() {
<div className="flex items-center justify-between">
<div className="flex-1">
<div className="flex items-center space-x-3">
<h3 className="text-lg font-medium text-gray-900">{job.job_type}</h3>
<h3 className="text-lg font-medium text-gray-900">{formatJobType(job.job_type)}</h3>
<span className={`inline-flex px-2 py-1 text-xs font-semibold rounded-full ${
job.status === 'running' ? 'bg-blue-100 text-blue-800' :
'bg-gray-100 text-gray-800'