fixed llm operations post refactor
This commit is contained in:
parent
a6fb367ed4
commit
e9a5f54d3a
7 changed files with 1089 additions and 2459 deletions
|
@ -1,10 +1,20 @@
|
|||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from sqlalchemy.pool import QueuePool
|
||||
from .settings import settings
|
||||
|
||||
|
||||
# Database setup
|
||||
engine = create_engine(settings.DATABASE_URL)
|
||||
# Database setup with connection pooling
|
||||
engine = create_engine(
|
||||
settings.DATABASE_URL,
|
||||
poolclass=QueuePool,
|
||||
pool_size=10, # Number of connections to maintain in the pool
|
||||
max_overflow=20, # Additional connections that can be created on demand
|
||||
pool_timeout=30, # Timeout for getting connection from pool
|
||||
pool_recycle=3600, # Recycle connections after 1 hour
|
||||
pool_pre_ping=True, # Validate connections before use
|
||||
echo=False # Set to True for SQL query logging
|
||||
)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
|
|
|
@ -4,7 +4,8 @@ Script to delete all SIGMA rules from the database
|
|||
This will clear existing rules so they can be regenerated with the improved LLM client
|
||||
"""
|
||||
|
||||
from models import SigmaRule, SessionLocal
|
||||
from models import SigmaRule
|
||||
from config.database import SessionLocal
|
||||
import logging
|
||||
|
||||
# Setup logging
|
||||
|
|
|
@ -108,7 +108,10 @@ class LLMClient:
|
|||
self.llm = Ollama(
|
||||
model=self.model,
|
||||
base_url=base_url,
|
||||
temperature=0.1
|
||||
temperature=0.1,
|
||||
num_ctx=4096, # Context window size
|
||||
top_p=0.9,
|
||||
top_k=40
|
||||
)
|
||||
|
||||
if self.llm:
|
||||
|
@ -186,9 +189,22 @@ class LLMClient:
|
|||
logger.info(f"CVE Description for {cve_id}: {cve_description[:200]}...")
|
||||
logger.info(f"PoC Content sample for {cve_id}: {poc_content[:200]}...")
|
||||
|
||||
# Generate the response
|
||||
# Generate the response with timeout handling
|
||||
logger.info(f"Final prompt variables for {cve_id}: {list(input_data.keys())}")
|
||||
response = await chain.ainvoke(input_data)
|
||||
|
||||
import asyncio
|
||||
try:
|
||||
# Add timeout wrapper around the LLM call
|
||||
response = await asyncio.wait_for(
|
||||
chain.ainvoke(input_data),
|
||||
timeout=150 # 2.5 minutes total timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"LLM request timed out for {cve_id}")
|
||||
return None
|
||||
except Exception as llm_error:
|
||||
logger.error(f"LLM generation error for {cve_id}: {llm_error}")
|
||||
return None
|
||||
|
||||
# Debug: Log raw LLM response
|
||||
logger.info(f"Raw LLM response for {cve_id}: {response[:200]}...")
|
||||
|
@ -228,36 +244,42 @@ class LLMClient:
|
|||
|
||||
system_message = """You are a cybersecurity expert specializing in SIGMA rule creation following the official SIGMA specification.
|
||||
|
||||
**CRITICAL: You must follow the exact SIGMA specification format:**
|
||||
**OFFICIAL SIGMA RULE SPECIFICATION JSON SCHEMA:**
|
||||
The official SIGMA rule specification (v2.0.0) defines these requirements:
|
||||
|
||||
1. **YAML Structure Requirements:**
|
||||
- Use UTF-8 encoding with LF line breaks
|
||||
- Indent with 4 spaces (no tabs)
|
||||
- Use lowercase keys only
|
||||
- Use single quotes for string values
|
||||
- No quotes for numeric values
|
||||
- Follow proper YAML syntax
|
||||
**MANDATORY Fields (must include):**
|
||||
- title: Brief description (max 256 chars) - string
|
||||
- logsource: Log data source specification - object with category/product/service
|
||||
- detection: Search identifiers and conditions - object with selections and condition
|
||||
|
||||
2. **MANDATORY Fields (must include):**
|
||||
- title: Brief description (max 256 chars)
|
||||
- logsource: Log data source specification
|
||||
- detection: Search identifiers and conditions
|
||||
- condition: How detection elements combine
|
||||
**RECOMMENDED Fields:**
|
||||
- id: Unique UUID (version 4) - string with UUID format
|
||||
- status: Rule state - enum: "stable", "test", "experimental", "deprecated", "unsupported"
|
||||
- description: Detailed explanation - string
|
||||
- author: Rule creator - string (use "AI Generated")
|
||||
- date: Creation date - string in YYYY/MM/DD format
|
||||
- modified: Last modification date - string in YYYY/MM/DD format
|
||||
- references: Sources for rule derivation - array of strings (URLs)
|
||||
- tags: MITRE ATT&CK techniques - array of strings
|
||||
- level: Rule severity - enum: "informational", "low", "medium", "high", "critical"
|
||||
- falsepositives: Known false positives - array of strings
|
||||
- fields: Related fields - array of strings
|
||||
- related: Related rules - array of objects with type and id
|
||||
|
||||
3. **RECOMMENDED Fields:**
|
||||
- id: Unique UUID
|
||||
- status: 'experimental' (for new rules)
|
||||
- description: Detailed explanation
|
||||
- author: 'AI Generated'
|
||||
- date: Current date (YYYY/MM/DD)
|
||||
- references: Array with CVE link
|
||||
- tags: MITRE ATT&CK techniques
|
||||
**YAML Structure Requirements:**
|
||||
- Use UTF-8 encoding with LF line breaks
|
||||
- Indent with 4 spaces (no tabs)
|
||||
- Use lowercase keys only
|
||||
- Use single quotes for string values
|
||||
- No quotes for numeric values
|
||||
- Follow proper YAML syntax
|
||||
|
||||
4. **Detection Structure:**
|
||||
- Use selection blocks (selection, selection1, etc.)
|
||||
- Condition references these selections
|
||||
- Use proper field names (Image, CommandLine, ProcessName, etc.)
|
||||
- Support wildcards (*) and value lists
|
||||
**Detection Structure:**
|
||||
- Use selection blocks (selection, selection1, etc.)
|
||||
- Condition references these selections
|
||||
- Use proper field names (Image, CommandLine, ProcessName, etc.)
|
||||
- Support wildcards (*) and value lists
|
||||
- Condition can be string expression or object with keywords
|
||||
|
||||
**ABSOLUTE REQUIREMENTS:**
|
||||
- Output ONLY valid YAML
|
||||
|
@ -1253,4 +1275,115 @@ Output ONLY the enhanced SIGMA rule in valid YAML format."""
|
|||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Ollama models: {e}")
|
||||
return []
|
||||
return []
|
||||
|
||||
async def test_connection(self) -> Dict[str, Any]:
|
||||
"""Test connection to the configured LLM provider."""
|
||||
try:
|
||||
if self.provider == 'openai':
|
||||
api_key = os.getenv('OPENAI_API_KEY')
|
||||
if not api_key:
|
||||
return {
|
||||
"available": False,
|
||||
"error": "OpenAI API key not configured",
|
||||
"models": [],
|
||||
"current_model": self.model,
|
||||
"has_api_key": False
|
||||
}
|
||||
|
||||
# Test OpenAI connection without actual API call to avoid timeouts
|
||||
if self.llm:
|
||||
return {
|
||||
"available": True,
|
||||
"models": self.SUPPORTED_PROVIDERS['openai']['models'],
|
||||
"current_model": self.model,
|
||||
"has_api_key": True
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"available": False,
|
||||
"error": "OpenAI client not initialized",
|
||||
"models": [],
|
||||
"current_model": self.model,
|
||||
"has_api_key": True
|
||||
}
|
||||
|
||||
elif self.provider == 'anthropic':
|
||||
api_key = os.getenv('ANTHROPIC_API_KEY')
|
||||
if not api_key:
|
||||
return {
|
||||
"available": False,
|
||||
"error": "Anthropic API key not configured",
|
||||
"models": [],
|
||||
"current_model": self.model,
|
||||
"has_api_key": False
|
||||
}
|
||||
|
||||
# Test Anthropic connection without actual API call to avoid timeouts
|
||||
if self.llm:
|
||||
return {
|
||||
"available": True,
|
||||
"models": self.SUPPORTED_PROVIDERS['anthropic']['models'],
|
||||
"current_model": self.model,
|
||||
"has_api_key": True
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"available": False,
|
||||
"error": "Anthropic client not initialized",
|
||||
"models": [],
|
||||
"current_model": self.model,
|
||||
"has_api_key": True
|
||||
}
|
||||
|
||||
elif self.provider == 'ollama':
|
||||
base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
|
||||
|
||||
# Test Ollama connection
|
||||
try:
|
||||
import requests
|
||||
response = requests.get(f"{base_url}/api/tags", timeout=10)
|
||||
if response.status_code == 200:
|
||||
available_models = self._get_ollama_available_models()
|
||||
# Check if model is available using proper model name matching
|
||||
model_available = self._check_ollama_model_available(base_url, self.model)
|
||||
|
||||
return {
|
||||
"available": model_available,
|
||||
"models": available_models,
|
||||
"current_model": self.model,
|
||||
"base_url": base_url,
|
||||
"error": None if model_available else f"Model {self.model} not available"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"available": False,
|
||||
"error": f"Ollama server not responding (HTTP {response.status_code})",
|
||||
"models": [],
|
||||
"current_model": self.model,
|
||||
"base_url": base_url
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"available": False,
|
||||
"error": f"Cannot connect to Ollama server: {str(e)}",
|
||||
"models": [],
|
||||
"current_model": self.model,
|
||||
"base_url": base_url
|
||||
}
|
||||
|
||||
else:
|
||||
return {
|
||||
"available": False,
|
||||
"error": f"Unsupported provider: {self.provider}",
|
||||
"models": [],
|
||||
"current_model": self.model
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"available": False,
|
||||
"error": f"Connection test failed: {str(e)}",
|
||||
"models": [],
|
||||
"current_model": self.model
|
||||
}
|
File diff suppressed because it is too large
Load diff
|
@ -1,14 +1,23 @@
|
|||
from typing import List, Optional
|
||||
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func, text
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
from config.database import get_db
|
||||
from config.database import get_db, SessionLocal
|
||||
from models import BulkProcessingJob, CVE, SigmaRule
|
||||
from schemas import BulkSeedRequest, NomiSecSyncRequest, GitHubPoCSyncRequest, ExploitDBSyncRequest, CISAKEVSyncRequest, ReferenceSyncRequest
|
||||
from services import CVEService, SigmaRuleService
|
||||
|
||||
# Import global job tracking from main.py
|
||||
import main
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["bulk-operations"])
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.post("/bulk-seed")
|
||||
async def bulk_seed(request: BulkSeedRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
|
||||
|
@ -117,4 +126,657 @@ async def get_poc_stats(db: Session = Depends(get_db)):
|
|||
"total_rules": total_rules,
|
||||
"exploit_based_rules": exploit_based_rules,
|
||||
"exploit_based_percentage": round((exploit_based_rules / total_rules * 100), 2) if total_rules > 0 else 0
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.post("/sync-nomi-sec")
|
||||
async def sync_nomi_sec(background_tasks: BackgroundTasks,
|
||||
request: NomiSecSyncRequest,
|
||||
db: Session = Depends(get_db)):
|
||||
"""Synchronize nomi-sec PoC data"""
|
||||
|
||||
# Create job record
|
||||
job = BulkProcessingJob(
|
||||
job_type='nomi_sec_sync',
|
||||
status='pending',
|
||||
job_metadata={
|
||||
'cve_id': request.cve_id,
|
||||
'batch_size': request.batch_size
|
||||
}
|
||||
)
|
||||
db.add(job)
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
|
||||
job_id = str(job.id)
|
||||
main.running_jobs[job_id] = job
|
||||
main.job_cancellation_flags[job_id] = False
|
||||
|
||||
async def sync_task():
|
||||
try:
|
||||
job.status = 'running'
|
||||
job.started_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
from nomi_sec_client import NomiSecClient
|
||||
client = NomiSecClient(db)
|
||||
|
||||
if request.cve_id:
|
||||
# Sync specific CVE
|
||||
if main.job_cancellation_flags.get(job_id, False):
|
||||
logger.info(f"Job {job_id} cancelled before starting")
|
||||
return
|
||||
|
||||
result = await client.sync_cve_pocs(request.cve_id)
|
||||
logger.info(f"Nomi-sec sync for {request.cve_id}: {result}")
|
||||
else:
|
||||
# Sync all CVEs with cancellation support
|
||||
result = await client.bulk_sync_all_cves(
|
||||
batch_size=request.batch_size,
|
||||
cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False)
|
||||
)
|
||||
logger.info(f"Nomi-sec bulk sync completed: {result}")
|
||||
|
||||
# Update job status if not cancelled
|
||||
if not main.job_cancellation_flags.get(job_id, False):
|
||||
job.status = 'completed'
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
except Exception as e:
|
||||
if not main.job_cancellation_flags.get(job_id, False):
|
||||
job.status = 'failed'
|
||||
job.error_message = str(e)
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
logger.error(f"Nomi-sec sync failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
# Clean up tracking
|
||||
main.running_jobs.pop(job_id, None)
|
||||
main.job_cancellation_flags.pop(job_id, None)
|
||||
|
||||
background_tasks.add_task(sync_task)
|
||||
|
||||
return {
|
||||
"message": f"Nomi-sec sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
|
||||
"status": "started",
|
||||
"job_id": job_id,
|
||||
"cve_id": request.cve_id,
|
||||
"batch_size": request.batch_size
|
||||
}
|
||||
|
||||
|
||||
@router.post("/sync-github-pocs")
|
||||
async def sync_github_pocs(background_tasks: BackgroundTasks,
|
||||
request: GitHubPoCSyncRequest,
|
||||
db: Session = Depends(get_db)):
|
||||
"""Synchronize GitHub PoC data"""
|
||||
|
||||
# Create job record
|
||||
job = BulkProcessingJob(
|
||||
job_type='github_poc_sync',
|
||||
status='pending',
|
||||
job_metadata={
|
||||
'cve_id': request.cve_id,
|
||||
'batch_size': request.batch_size
|
||||
}
|
||||
)
|
||||
db.add(job)
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
|
||||
job_id = str(job.id)
|
||||
main.running_jobs[job_id] = job
|
||||
main.job_cancellation_flags[job_id] = False
|
||||
|
||||
async def sync_task():
|
||||
try:
|
||||
job.status = 'running'
|
||||
job.started_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
from mcdevitt_poc_client import GitHubPoCClient
|
||||
client = GitHubPoCClient(db)
|
||||
|
||||
if request.cve_id:
|
||||
# Sync specific CVE
|
||||
if main.job_cancellation_flags.get(job_id, False):
|
||||
logger.info(f"Job {job_id} cancelled before starting")
|
||||
return
|
||||
|
||||
result = await client.sync_cve_pocs(request.cve_id)
|
||||
logger.info(f"GitHub PoC sync for {request.cve_id}: {result}")
|
||||
else:
|
||||
# Sync all CVEs with cancellation support
|
||||
result = await client.bulk_sync_all_cves(batch_size=request.batch_size)
|
||||
logger.info(f"GitHub PoC bulk sync completed: {result}")
|
||||
|
||||
# Update job status if not cancelled
|
||||
if not main.job_cancellation_flags.get(job_id, False):
|
||||
job.status = 'completed'
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
except Exception as e:
|
||||
if not main.job_cancellation_flags.get(job_id, False):
|
||||
job.status = 'failed'
|
||||
job.error_message = str(e)
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
logger.error(f"GitHub PoC sync failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
# Clean up tracking
|
||||
main.running_jobs.pop(job_id, None)
|
||||
main.job_cancellation_flags.pop(job_id, None)
|
||||
|
||||
background_tasks.add_task(sync_task)
|
||||
|
||||
return {
|
||||
"message": f"GitHub PoC sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
|
||||
"status": "started",
|
||||
"job_id": job_id,
|
||||
"cve_id": request.cve_id,
|
||||
"batch_size": request.batch_size
|
||||
}
|
||||
|
||||
|
||||
@router.post("/sync-exploitdb")
|
||||
async def sync_exploitdb(background_tasks: BackgroundTasks,
|
||||
request: ExploitDBSyncRequest,
|
||||
db: Session = Depends(get_db)):
|
||||
"""Synchronize ExploitDB data from git mirror"""
|
||||
|
||||
# Create job record
|
||||
job = BulkProcessingJob(
|
||||
job_type='exploitdb_sync',
|
||||
status='pending',
|
||||
job_metadata={
|
||||
'cve_id': request.cve_id,
|
||||
'batch_size': request.batch_size
|
||||
}
|
||||
)
|
||||
db.add(job)
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
|
||||
job_id = str(job.id)
|
||||
main.running_jobs[job_id] = job
|
||||
main.job_cancellation_flags[job_id] = False
|
||||
|
||||
async def sync_task():
|
||||
# Create a new database session for the background task
|
||||
task_db = SessionLocal()
|
||||
try:
|
||||
# Get the job in the new session
|
||||
task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
|
||||
if not task_job:
|
||||
logger.error(f"Job {job_id} not found in task session")
|
||||
return
|
||||
|
||||
task_job.status = 'running'
|
||||
task_job.started_at = datetime.utcnow()
|
||||
task_db.commit()
|
||||
|
||||
from exploitdb_client_local import ExploitDBLocalClient
|
||||
client = ExploitDBLocalClient(task_db)
|
||||
|
||||
if request.cve_id:
|
||||
# Sync specific CVE
|
||||
if main.job_cancellation_flags.get(job_id, False):
|
||||
logger.info(f"Job {job_id} cancelled before starting")
|
||||
return
|
||||
|
||||
result = await client.sync_cve_exploits(request.cve_id)
|
||||
logger.info(f"ExploitDB sync for {request.cve_id}: {result}")
|
||||
else:
|
||||
# Sync all CVEs with cancellation support
|
||||
result = await client.bulk_sync_exploitdb(
|
||||
batch_size=request.batch_size,
|
||||
cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False)
|
||||
)
|
||||
logger.info(f"ExploitDB bulk sync completed: {result}")
|
||||
|
||||
# Update job status if not cancelled
|
||||
if not main.job_cancellation_flags.get(job_id, False):
|
||||
task_job.status = 'completed'
|
||||
task_job.completed_at = datetime.utcnow()
|
||||
task_db.commit()
|
||||
|
||||
except Exception as e:
|
||||
if not main.job_cancellation_flags.get(job_id, False):
|
||||
# Get the job again in case it was modified
|
||||
task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
|
||||
if task_job:
|
||||
task_job.status = 'failed'
|
||||
task_job.error_message = str(e)
|
||||
task_job.completed_at = datetime.utcnow()
|
||||
task_db.commit()
|
||||
|
||||
logger.error(f"ExploitDB sync failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
# Clean up tracking and close the task session
|
||||
main.running_jobs.pop(job_id, None)
|
||||
main.job_cancellation_flags.pop(job_id, None)
|
||||
task_db.close()
|
||||
|
||||
background_tasks.add_task(sync_task)
|
||||
|
||||
return {
|
||||
"message": f"ExploitDB sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
|
||||
"status": "started",
|
||||
"job_id": job_id,
|
||||
"cve_id": request.cve_id,
|
||||
"batch_size": request.batch_size
|
||||
}
|
||||
|
||||
|
||||
@router.post("/sync-cisa-kev")
|
||||
async def sync_cisa_kev(background_tasks: BackgroundTasks,
|
||||
request: CISAKEVSyncRequest,
|
||||
db: Session = Depends(get_db)):
|
||||
"""Synchronize CISA Known Exploited Vulnerabilities data"""
|
||||
|
||||
# Create job record
|
||||
job = BulkProcessingJob(
|
||||
job_type='cisa_kev_sync',
|
||||
status='pending',
|
||||
job_metadata={
|
||||
'cve_id': request.cve_id,
|
||||
'batch_size': request.batch_size
|
||||
}
|
||||
)
|
||||
db.add(job)
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
|
||||
job_id = str(job.id)
|
||||
main.running_jobs[job_id] = job
|
||||
main.job_cancellation_flags[job_id] = False
|
||||
|
||||
async def sync_task():
|
||||
# Create a new database session for the background task
|
||||
task_db = SessionLocal()
|
||||
try:
|
||||
# Get the job in the new session
|
||||
task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
|
||||
if not task_job:
|
||||
logger.error(f"Job {job_id} not found in task session")
|
||||
return
|
||||
|
||||
task_job.status = 'running'
|
||||
task_job.started_at = datetime.utcnow()
|
||||
task_db.commit()
|
||||
|
||||
from cisa_kev_client import CISAKEVClient
|
||||
client = CISAKEVClient(task_db)
|
||||
|
||||
if request.cve_id:
|
||||
# Sync specific CVE
|
||||
if main.job_cancellation_flags.get(job_id, False):
|
||||
logger.info(f"Job {job_id} cancelled before starting")
|
||||
return
|
||||
|
||||
result = await client.sync_cve_kev_data(request.cve_id)
|
||||
logger.info(f"CISA KEV sync for {request.cve_id}: {result}")
|
||||
else:
|
||||
# Sync all CVEs with cancellation support
|
||||
result = await client.bulk_sync_kev_data(
|
||||
batch_size=request.batch_size,
|
||||
cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False)
|
||||
)
|
||||
logger.info(f"CISA KEV bulk sync completed: {result}")
|
||||
|
||||
# Update job status if not cancelled
|
||||
if not main.job_cancellation_flags.get(job_id, False):
|
||||
task_job.status = 'completed'
|
||||
task_job.completed_at = datetime.utcnow()
|
||||
task_db.commit()
|
||||
|
||||
except Exception as e:
|
||||
if not main.job_cancellation_flags.get(job_id, False):
|
||||
# Get the job again in case it was modified
|
||||
task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
|
||||
if task_job:
|
||||
task_job.status = 'failed'
|
||||
task_job.error_message = str(e)
|
||||
task_job.completed_at = datetime.utcnow()
|
||||
task_db.commit()
|
||||
|
||||
logger.error(f"CISA KEV sync failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
# Clean up tracking and close the task session
|
||||
main.running_jobs.pop(job_id, None)
|
||||
main.job_cancellation_flags.pop(job_id, None)
|
||||
task_db.close()
|
||||
|
||||
background_tasks.add_task(sync_task)
|
||||
|
||||
return {
|
||||
"message": f"CISA KEV sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
|
||||
"status": "started",
|
||||
"job_id": job_id,
|
||||
"cve_id": request.cve_id,
|
||||
"batch_size": request.batch_size
|
||||
}
|
||||
|
||||
|
||||
@router.post("/sync-references")
|
||||
async def sync_references(request: ReferenceSyncRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
|
||||
"""Start reference data synchronization"""
|
||||
|
||||
try:
|
||||
from reference_client import ReferenceClient
|
||||
client = ReferenceClient(db)
|
||||
|
||||
# Create job ID
|
||||
job_id = str(uuid.uuid4())
|
||||
|
||||
# Add job to tracking
|
||||
main.running_jobs[job_id] = {
|
||||
'type': 'reference_sync',
|
||||
'status': 'running',
|
||||
'cve_id': request.cve_id,
|
||||
'batch_size': request.batch_size,
|
||||
'max_cves': request.max_cves,
|
||||
'force_resync': request.force_resync,
|
||||
'started_at': datetime.utcnow()
|
||||
}
|
||||
|
||||
# Create cancellation flag
|
||||
main.job_cancellation_flags[job_id] = False
|
||||
|
||||
async def sync_task():
|
||||
try:
|
||||
if request.cve_id:
|
||||
# Single CVE sync
|
||||
result = await client.sync_cve_references(request.cve_id)
|
||||
main.running_jobs[job_id]['result'] = result
|
||||
main.running_jobs[job_id]['status'] = 'completed'
|
||||
else:
|
||||
# Bulk sync
|
||||
result = await client.bulk_sync_references(
|
||||
batch_size=request.batch_size,
|
||||
max_cves=request.max_cves,
|
||||
force_resync=request.force_resync,
|
||||
cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False)
|
||||
)
|
||||
main.running_jobs[job_id]['result'] = result
|
||||
main.running_jobs[job_id]['status'] = 'completed'
|
||||
|
||||
main.running_jobs[job_id]['completed_at'] = datetime.utcnow()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Reference sync task failed: {e}")
|
||||
main.running_jobs[job_id]['status'] = 'failed'
|
||||
main.running_jobs[job_id]['error'] = str(e)
|
||||
main.running_jobs[job_id]['completed_at'] = datetime.utcnow()
|
||||
finally:
|
||||
# Clean up cancellation flag
|
||||
main.job_cancellation_flags.pop(job_id, None)
|
||||
|
||||
background_tasks.add_task(sync_task)
|
||||
|
||||
return {
|
||||
"message": f"Reference sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
|
||||
"status": "started",
|
||||
"job_id": job_id,
|
||||
"cve_id": request.cve_id,
|
||||
"batch_size": request.batch_size,
|
||||
"max_cves": request.max_cves,
|
||||
"force_resync": request.force_resync
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start reference sync: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to start reference sync: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/exploitdb-stats")
|
||||
async def get_exploitdb_stats(db: Session = Depends(get_db)):
|
||||
"""Get ExploitDB-related statistics"""
|
||||
|
||||
try:
|
||||
from exploitdb_client_local import ExploitDBLocalClient
|
||||
client = ExploitDBLocalClient(db)
|
||||
|
||||
# Get sync status
|
||||
status = await client.get_exploitdb_sync_status()
|
||||
|
||||
# Get quality distribution from ExploitDB data
|
||||
quality_distribution = {}
|
||||
cves_with_exploitdb = db.query(CVE).filter(
|
||||
text("poc_data::text LIKE '%\"exploitdb\"%'")
|
||||
).all()
|
||||
|
||||
for cve in cves_with_exploitdb:
|
||||
if cve.poc_data and 'exploitdb' in cve.poc_data:
|
||||
exploits = cve.poc_data['exploitdb'].get('exploits', [])
|
||||
for exploit in exploits:
|
||||
quality_tier = exploit.get('quality_analysis', {}).get('quality_tier', 'unknown')
|
||||
quality_distribution[quality_tier] = quality_distribution.get(quality_tier, 0) + 1
|
||||
|
||||
# Get category distribution
|
||||
category_distribution = {}
|
||||
for cve in cves_with_exploitdb:
|
||||
if cve.poc_data and 'exploitdb' in cve.poc_data:
|
||||
exploits = cve.poc_data['exploitdb'].get('exploits', [])
|
||||
for exploit in exploits:
|
||||
category = exploit.get('category', 'unknown')
|
||||
category_distribution[category] = category_distribution.get(category, 0) + 1
|
||||
|
||||
return {
|
||||
"exploitdb_sync_status": status,
|
||||
"quality_distribution": quality_distribution,
|
||||
"category_distribution": category_distribution,
|
||||
"total_exploitdb_cves": len(cves_with_exploitdb),
|
||||
"total_exploits": sum(
|
||||
len(cve.poc_data.get('exploitdb', {}).get('exploits', []))
|
||||
for cve in cves_with_exploitdb
|
||||
if cve.poc_data and 'exploitdb' in cve.poc_data
|
||||
)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting ExploitDB stats: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/github-poc-stats")
|
||||
async def get_github_poc_stats(db: Session = Depends(get_db)):
|
||||
"""Get GitHub PoC-related statistics"""
|
||||
|
||||
try:
|
||||
# Get basic statistics
|
||||
github_poc_rules = db.query(SigmaRule).filter(SigmaRule.poc_source == 'github_poc').count()
|
||||
cves_with_github_pocs = db.query(CVE).filter(
|
||||
CVE.poc_data.isnot(None), # Check if poc_data exists
|
||||
func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc'
|
||||
).count()
|
||||
|
||||
# Get quality distribution
|
||||
quality_distribution = {}
|
||||
try:
|
||||
quality_results = db.query(
|
||||
func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_tier').label('tier'),
|
||||
func.count().label('count')
|
||||
).filter(
|
||||
CVE.poc_data.isnot(None),
|
||||
func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc'
|
||||
).group_by('tier').all()
|
||||
|
||||
for tier, count in quality_results:
|
||||
if tier:
|
||||
quality_distribution[tier] = count
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting quality distribution: {e}")
|
||||
quality_distribution = {}
|
||||
|
||||
# Calculate average quality score
|
||||
try:
|
||||
from sqlalchemy import Integer
|
||||
avg_quality = db.query(
|
||||
func.avg(func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_score').cast(Integer))
|
||||
).filter(
|
||||
CVE.poc_data.isnot(None),
|
||||
func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc'
|
||||
).scalar() or 0
|
||||
except Exception as e:
|
||||
logger.warning(f"Error calculating average quality: {e}")
|
||||
avg_quality = 0
|
||||
|
||||
return {
|
||||
'github_poc_rules': github_poc_rules,
|
||||
'cves_with_github_pocs': cves_with_github_pocs,
|
||||
'quality_distribution': quality_distribution,
|
||||
'average_quality_score': float(avg_quality) if avg_quality else 0,
|
||||
'source': 'github_poc'
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting GitHub PoC stats: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
@router.get("/cisa-kev-stats")
|
||||
async def get_cisa_kev_stats(db: Session = Depends(get_db)):
|
||||
"""Get CISA KEV-related statistics"""
|
||||
|
||||
try:
|
||||
from cisa_kev_client import CISAKEVClient
|
||||
client = CISAKEVClient(db)
|
||||
|
||||
# Get sync status
|
||||
status = await client.get_kev_sync_status()
|
||||
|
||||
# Get threat level distribution from CISA KEV data
|
||||
threat_level_distribution = {}
|
||||
cves_with_kev = db.query(CVE).filter(
|
||||
text("poc_data::text LIKE '%\"cisa_kev\"%'")
|
||||
).all()
|
||||
|
||||
for cve in cves_with_kev:
|
||||
if cve.poc_data and 'cisa_kev' in cve.poc_data:
|
||||
vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
|
||||
threat_level = vuln_data.get('threat_level', 'unknown')
|
||||
threat_level_distribution[threat_level] = threat_level_distribution.get(threat_level, 0) + 1
|
||||
|
||||
# Get vulnerability category distribution
|
||||
category_distribution = {}
|
||||
for cve in cves_with_kev:
|
||||
if cve.poc_data and 'cisa_kev' in cve.poc_data:
|
||||
vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
|
||||
category = vuln_data.get('vulnerability_category', 'unknown')
|
||||
category_distribution[category] = category_distribution.get(category, 0) + 1
|
||||
|
||||
# Get ransomware usage statistics
|
||||
ransomware_stats = {'known': 0, 'unknown': 0}
|
||||
for cve in cves_with_kev:
|
||||
if cve.poc_data and 'cisa_kev' in cve.poc_data:
|
||||
vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
|
||||
ransomware_use = vuln_data.get('known_ransomware_use', 'Unknown').lower()
|
||||
if ransomware_use == 'known':
|
||||
ransomware_stats['known'] += 1
|
||||
else:
|
||||
ransomware_stats['unknown'] += 1
|
||||
|
||||
# Calculate average threat score
|
||||
threat_scores = []
|
||||
for cve in cves_with_kev:
|
||||
if cve.poc_data and 'cisa_kev' in cve.poc_data:
|
||||
vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
|
||||
threat_score = vuln_data.get('threat_score', 0)
|
||||
if threat_score:
|
||||
threat_scores.append(threat_score)
|
||||
|
||||
avg_threat_score = sum(threat_scores) / len(threat_scores) if threat_scores else 0
|
||||
|
||||
return {
|
||||
"cisa_kev_sync_status": status,
|
||||
"threat_level_distribution": threat_level_distribution,
|
||||
"category_distribution": category_distribution,
|
||||
"ransomware_stats": ransomware_stats,
|
||||
"average_threat_score": round(avg_threat_score, 2),
|
||||
"total_kev_cves": len(cves_with_kev),
|
||||
"total_with_threat_scores": len(threat_scores)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting CISA KEV stats: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/cancel-job/{job_id}")
|
||||
async def cancel_job(job_id: str, db: Session = Depends(get_db)):
|
||||
"""Cancel a running job"""
|
||||
try:
|
||||
# Find the job in the database
|
||||
job = db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job_id).first()
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
if job.status not in ['pending', 'running']:
|
||||
raise HTTPException(status_code=400, detail=f"Cannot cancel job with status: {job.status}")
|
||||
|
||||
# Set cancellation flag
|
||||
main.job_cancellation_flags[job_id] = True
|
||||
|
||||
# Update job status
|
||||
job.status = 'cancelled'
|
||||
job.cancelled_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Job {job_id} ({job.job_type}) cancelled by user")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Job {job_id} cancelled successfully",
|
||||
"job_id": job_id,
|
||||
"job_type": job.job_type
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error cancelling job {job_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/running-jobs")
|
||||
async def get_running_jobs(db: Session = Depends(get_db)):
|
||||
"""Get currently running jobs"""
|
||||
try:
|
||||
# Get running jobs from database
|
||||
running_jobs_db = db.query(BulkProcessingJob).filter(
|
||||
BulkProcessingJob.status.in_(['pending', 'running'])
|
||||
).order_by(BulkProcessingJob.created_at.desc()).all()
|
||||
|
||||
result = []
|
||||
for job in running_jobs_db:
|
||||
result.append({
|
||||
'id': str(job.id),
|
||||
'job_type': job.job_type,
|
||||
'status': job.status,
|
||||
'year': job.year,
|
||||
'total_items': job.total_items,
|
||||
'processed_items': job.processed_items,
|
||||
'failed_items': job.failed_items,
|
||||
'error_message': job.error_message,
|
||||
'started_at': job.started_at,
|
||||
'created_at': job.created_at,
|
||||
'can_cancel': job.status in ['pending', 'running']
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting running jobs: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
|
@ -1,17 +1,20 @@
|
|||
from typing import Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
|
||||
from config.database import get_db
|
||||
from models import CVE, SigmaRule
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["llm-operations"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMRuleRequest(BaseModel):
|
||||
cve_id: str
|
||||
cve_id: str = None # Optional for bulk operations
|
||||
poc_content: str = ""
|
||||
force: bool = False # For bulk operations
|
||||
|
||||
|
||||
class LLMSwitchRequest(BaseModel):
|
||||
|
@ -20,32 +23,170 @@ class LLMSwitchRequest(BaseModel):
|
|||
|
||||
|
||||
@router.post("/llm-enhanced-rules")
|
||||
async def generate_llm_enhanced_rules(request: LLMRuleRequest, db: Session = Depends(get_db)):
|
||||
async def generate_llm_enhanced_rules(request: LLMRuleRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
|
||||
"""Generate SIGMA rules using LLM AI analysis"""
|
||||
try:
|
||||
from enhanced_sigma_generator import EnhancedSigmaGenerator
|
||||
|
||||
# Get CVE
|
||||
cve = db.query(CVE).filter(CVE.cve_id == request.cve_id).first()
|
||||
if not cve:
|
||||
raise HTTPException(status_code=404, detail="CVE not found")
|
||||
|
||||
# Generate enhanced rule using LLM
|
||||
generator = EnhancedSigmaGenerator(db)
|
||||
result = await generator.generate_enhanced_rule(cve, use_llm=True)
|
||||
|
||||
if result.get('success'):
|
||||
if request.cve_id:
|
||||
# Single CVE operation
|
||||
cve = db.query(CVE).filter(CVE.cve_id == request.cve_id).first()
|
||||
if not cve:
|
||||
raise HTTPException(status_code=404, detail="CVE not found")
|
||||
|
||||
# Generate enhanced rule using LLM
|
||||
generator = EnhancedSigmaGenerator(db)
|
||||
result = await generator.generate_enhanced_rule(cve, use_llm=True)
|
||||
|
||||
if result.get('success'):
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Generated LLM-enhanced rule for {request.cve_id}",
|
||||
"rule_id": result.get('rule_id'),
|
||||
"generation_method": "llm_enhanced"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": result.get('error', 'Unknown error'),
|
||||
"cve_id": request.cve_id
|
||||
}
|
||||
else:
|
||||
# Bulk operation - run in background with job tracking
|
||||
from models import BulkProcessingJob
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
import main
|
||||
|
||||
# Create job record
|
||||
job = BulkProcessingJob(
|
||||
job_type='llm_rule_generation',
|
||||
status='pending',
|
||||
job_metadata={
|
||||
'force': request.force
|
||||
}
|
||||
)
|
||||
db.add(job)
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
|
||||
job_id = str(job.id)
|
||||
main.running_jobs[job_id] = job
|
||||
main.job_cancellation_flags[job_id] = False
|
||||
|
||||
async def bulk_llm_generation_task():
|
||||
# Create a new database session for the background task
|
||||
from config.database import SessionLocal
|
||||
task_db = SessionLocal()
|
||||
try:
|
||||
# Get the job in the new session
|
||||
task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
|
||||
if not task_job:
|
||||
logger.error(f"Job {job_id} not found in task session")
|
||||
return
|
||||
|
||||
task_job.status = 'running'
|
||||
task_job.started_at = datetime.utcnow()
|
||||
task_db.commit()
|
||||
|
||||
generator = EnhancedSigmaGenerator(task_db)
|
||||
|
||||
# Get CVEs with PoC data - limit to small batch initially
|
||||
if request.force:
|
||||
# Process all CVEs with PoC data - but limit to prevent system overload
|
||||
cves_to_process = task_db.query(CVE).filter(CVE.poc_count > 0).limit(50).all()
|
||||
else:
|
||||
# Only process CVEs without existing LLM-generated rules - small batch
|
||||
cves_to_process = task_db.query(CVE).filter(
|
||||
CVE.poc_count > 0
|
||||
).limit(10).all()
|
||||
|
||||
# Filter out CVEs that already have LLM-generated rules
|
||||
existing_llm_rules = task_db.query(SigmaRule).filter(
|
||||
SigmaRule.detection_type.like('llm_%')
|
||||
).all()
|
||||
existing_cve_ids = {rule.cve_id for rule in existing_llm_rules}
|
||||
cves_to_process = [cve for cve in cves_to_process if cve.cve_id not in existing_cve_ids]
|
||||
|
||||
# Update job with total items
|
||||
task_job.total_items = len(cves_to_process)
|
||||
task_db.commit()
|
||||
|
||||
rules_generated = 0
|
||||
rules_updated = 0
|
||||
failures = 0
|
||||
|
||||
logger.info(f"Starting bulk LLM rule generation for {len(cves_to_process)} CVEs (job {job_id})")
|
||||
|
||||
for i, cve in enumerate(cves_to_process):
|
||||
# Check for cancellation
|
||||
if main.job_cancellation_flags.get(job_id, False):
|
||||
logger.info(f"Job {job_id} cancelled, stopping LLM generation")
|
||||
break
|
||||
|
||||
try:
|
||||
logger.info(f"Processing CVE {i+1}/{len(cves_to_process)}: {cve.cve_id}")
|
||||
result = await generator.generate_enhanced_rule(cve, use_llm=True)
|
||||
|
||||
if result.get('success'):
|
||||
if result.get('updated'):
|
||||
rules_updated += 1
|
||||
else:
|
||||
rules_generated += 1
|
||||
logger.info(f"Successfully generated rule for {cve.cve_id}")
|
||||
else:
|
||||
failures += 1
|
||||
logger.warning(f"Failed to generate rule for {cve.cve_id}: {result.get('error')}")
|
||||
|
||||
# Update progress
|
||||
task_job.processed_items = i + 1
|
||||
task_job.failed_items = failures
|
||||
task_db.commit()
|
||||
|
||||
except Exception as e:
|
||||
failures += 1
|
||||
logger.error(f"Error generating LLM-enhanced rule for {cve.cve_id}: {e}")
|
||||
|
||||
# Update progress
|
||||
task_job.processed_items = i + 1
|
||||
task_job.failed_items = failures
|
||||
task_db.commit()
|
||||
|
||||
# Continue with next CVE even if one fails
|
||||
continue
|
||||
|
||||
# Update job status if not cancelled
|
||||
if not main.job_cancellation_flags.get(job_id, False):
|
||||
task_job.status = 'completed'
|
||||
task_job.completed_at = datetime.utcnow()
|
||||
task_db.commit()
|
||||
logger.info(f"Bulk LLM rule generation completed: {rules_generated} new, {rules_updated} updated, {failures} failures")
|
||||
|
||||
except Exception as e:
|
||||
if not main.job_cancellation_flags.get(job_id, False):
|
||||
# Get the job again in case it was modified
|
||||
task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
|
||||
if task_job:
|
||||
task_job.status = 'failed'
|
||||
task_job.error_message = str(e)
|
||||
task_job.completed_at = datetime.utcnow()
|
||||
task_db.commit()
|
||||
|
||||
logger.error(f"Bulk LLM rule generation failed: {e}")
|
||||
finally:
|
||||
# Clean up tracking and close the task session
|
||||
main.running_jobs.pop(job_id, None)
|
||||
main.job_cancellation_flags.pop(job_id, None)
|
||||
task_db.close()
|
||||
|
||||
background_tasks.add_task(bulk_llm_generation_task)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Generated LLM-enhanced rule for {request.cve_id}",
|
||||
"rule_id": result.get('rule_id'),
|
||||
"generation_method": "llm_enhanced"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": result.get('error', 'Unknown error'),
|
||||
"cve_id": request.cve_id
|
||||
"message": "Bulk LLM-enhanced rule generation started",
|
||||
"status": "started",
|
||||
"job_id": job_id,
|
||||
"force": request.force
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
@ -58,68 +199,97 @@ async def get_llm_status():
|
|||
try:
|
||||
from llm_client import LLMClient
|
||||
|
||||
# Test all providers
|
||||
providers_status = {}
|
||||
# Get current configuration first
|
||||
current_client = LLMClient()
|
||||
|
||||
# Build available providers list in the format frontend expects
|
||||
available_providers = []
|
||||
|
||||
# Test Ollama
|
||||
try:
|
||||
ollama_client = LLMClient(provider="ollama")
|
||||
ollama_status = await ollama_client.test_connection()
|
||||
providers_status["ollama"] = {
|
||||
available_providers.append({
|
||||
"name": "ollama",
|
||||
"available": ollama_status.get("available", False),
|
||||
"default_model": ollama_status.get("current_model", "llama3.2"),
|
||||
"models": ollama_status.get("models", []),
|
||||
"current_model": ollama_status.get("current_model"),
|
||||
"base_url": ollama_status.get("base_url")
|
||||
}
|
||||
})
|
||||
except Exception as e:
|
||||
providers_status["ollama"] = {"available": False, "error": str(e)}
|
||||
available_providers.append({
|
||||
"name": "ollama",
|
||||
"available": False,
|
||||
"default_model": "llama3.2",
|
||||
"models": [],
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
# Test OpenAI
|
||||
try:
|
||||
openai_client = LLMClient(provider="openai")
|
||||
openai_status = await openai_client.test_connection()
|
||||
providers_status["openai"] = {
|
||||
available_providers.append({
|
||||
"name": "openai",
|
||||
"available": openai_status.get("available", False),
|
||||
"default_model": openai_status.get("current_model", "gpt-4o-mini"),
|
||||
"models": openai_status.get("models", []),
|
||||
"current_model": openai_status.get("current_model"),
|
||||
"has_api_key": openai_status.get("has_api_key", False)
|
||||
}
|
||||
})
|
||||
except Exception as e:
|
||||
providers_status["openai"] = {"available": False, "error": str(e)}
|
||||
available_providers.append({
|
||||
"name": "openai",
|
||||
"available": False,
|
||||
"default_model": "gpt-4o-mini",
|
||||
"models": [],
|
||||
"has_api_key": False,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
# Test Anthropic
|
||||
try:
|
||||
anthropic_client = LLMClient(provider="anthropic")
|
||||
anthropic_status = await anthropic_client.test_connection()
|
||||
providers_status["anthropic"] = {
|
||||
available_providers.append({
|
||||
"name": "anthropic",
|
||||
"available": anthropic_status.get("available", False),
|
||||
"default_model": anthropic_status.get("current_model", "claude-3-5-sonnet-20241022"),
|
||||
"models": anthropic_status.get("models", []),
|
||||
"current_model": anthropic_status.get("current_model"),
|
||||
"has_api_key": anthropic_status.get("has_api_key", False)
|
||||
}
|
||||
})
|
||||
except Exception as e:
|
||||
providers_status["anthropic"] = {"available": False, "error": str(e)}
|
||||
available_providers.append({
|
||||
"name": "anthropic",
|
||||
"available": False,
|
||||
"default_model": "claude-3-5-sonnet-20241022",
|
||||
"models": [],
|
||||
"has_api_key": False,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
# Get current configuration
|
||||
current_client = LLMClient()
|
||||
current_config = {
|
||||
"current_provider": current_client.provider,
|
||||
"current_model": current_client.model,
|
||||
"default_provider": "ollama"
|
||||
}
|
||||
# Determine overall status
|
||||
any_available = any(p.get("available") for p in available_providers)
|
||||
status = "ready" if any_available else "not_ready"
|
||||
|
||||
# Return in the format the frontend expects
|
||||
return {
|
||||
"providers": providers_status,
|
||||
"configuration": current_config,
|
||||
"status": "operational" if any(p.get("available") for p in providers_status.values()) else "no_providers_available"
|
||||
"status": status,
|
||||
"current_provider": {
|
||||
"provider": current_client.provider,
|
||||
"model": current_client.model
|
||||
},
|
||||
"available_providers": available_providers
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"providers": {},
|
||||
"configuration": {}
|
||||
"current_provider": {
|
||||
"provider": "unknown",
|
||||
"model": "unknown"
|
||||
},
|
||||
"available_providers": []
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -52,6 +52,22 @@ function App() {
|
|||
return runningJobTypes.has(jobType);
|
||||
};
|
||||
|
||||
// Helper function to format job types for display
|
||||
const formatJobType = (jobType) => {
|
||||
const jobTypeMap = {
|
||||
'llm_rule_generation': 'LLM Rule Generation',
|
||||
'rule_regeneration': 'Rule Regeneration',
|
||||
'bulk_seed': 'Bulk Seed',
|
||||
'incremental_update': 'Incremental Update',
|
||||
'nomi_sec_sync': 'Nomi-Sec Sync',
|
||||
'github_poc_sync': 'GitHub PoC Sync',
|
||||
'exploitdb_sync': 'ExploitDB Sync',
|
||||
'cisa_kev_sync': 'CISA KEV Sync',
|
||||
'reference_sync': 'Reference Sync'
|
||||
};
|
||||
return jobTypeMap[jobType] || jobType;
|
||||
};
|
||||
|
||||
const isBulkSeedRunning = () => {
|
||||
return isJobTypeRunning('nvd_bulk_seed') || isJobTypeRunning('bulk_seed');
|
||||
};
|
||||
|
@ -262,7 +278,7 @@ function App() {
|
|||
try {
|
||||
const response = await axios.post('http://localhost:8000/api/sync-references', {
|
||||
batch_size: 30,
|
||||
max_cves: 100,
|
||||
max_cves: null,
|
||||
force_resync: false
|
||||
});
|
||||
console.log('Reference sync response:', response.data);
|
||||
|
@ -291,9 +307,20 @@ function App() {
|
|||
force: force
|
||||
});
|
||||
console.log('LLM rule generation response:', response.data);
|
||||
fetchData();
|
||||
|
||||
// For bulk operations, the job runs in background - no need to wait
|
||||
if (response.data.status === 'started') {
|
||||
console.log(`LLM rule generation started as background job: ${response.data.job_id}`);
|
||||
alert(`LLM rule generation started in background. You can monitor progress in the Bulk Jobs tab.`);
|
||||
// Refresh data to show the new job status
|
||||
fetchData();
|
||||
} else {
|
||||
// For single CVE operations, refresh data after completion
|
||||
fetchData();
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error generating LLM-enhanced rules:', error);
|
||||
alert('Error generating LLM-enhanced rules. Please check the console for details.');
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1062,7 +1089,7 @@ function App() {
|
|||
<div className="flex items-center justify-between">
|
||||
<div className="flex-1">
|
||||
<div className="flex items-center space-x-3">
|
||||
<h3 className="text-lg font-medium text-gray-900">{job.job_type}</h3>
|
||||
<h3 className="text-lg font-medium text-gray-900">{formatJobType(job.job_type)}</h3>
|
||||
<span className={`inline-flex px-2 py-1 text-xs font-semibold rounded-full ${
|
||||
job.status === 'running' ? 'bg-blue-100 text-blue-800' :
|
||||
'bg-gray-100 text-gray-800'
|
||||
|
|
Loading…
Add table
Reference in a new issue