from typing import List, Optional
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from sqlalchemy.orm import Session
from sqlalchemy import func, text
import uuid
from datetime import datetime
import logging

from config.database import get_db, SessionLocal
from models import BulkProcessingJob, CVE, SigmaRule
from schemas import BulkSeedRequest, NomiSecSyncRequest, GitHubPoCSyncRequest, ExploitDBSyncRequest, CISAKEVSyncRequest, ReferenceSyncRequest
from services import CVEService, SigmaRuleService

# Import global job tracking from main.py
import main

router = APIRouter(prefix="/api", tags=["bulk-operations"])

logger = logging.getLogger(__name__)
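
# main.running_jobs and main.job_cancellation_flags are plain module-level dicts
# shared with main.py: running_jobs maps a job id to its tracking record, and
# job_cancellation_flags holds a boolean that background tasks poll to support
# cooperative cancellation (see the cancel-job endpoint below). This tracking is
# in-process only; it does not survive a restart or span multiple workers.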


@router.post("/bulk-seed")
async def bulk_seed(request: BulkSeedRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Start bulk seeding operation"""
    from bulk_seeder import BulkSeeder

    async def run_bulk_seed():
        try:
            seeder = BulkSeeder(db)
            result = await seeder.full_bulk_seed(
                start_year=request.start_year,
                end_year=request.end_year,
                skip_nvd=request.skip_nvd,
                skip_nomi_sec=request.skip_nomi_sec
            )
            logger.info(f"Bulk seed completed: {result}")
        except Exception as e:
            logger.error(f"Bulk seed failed: {str(e)}")

    background_tasks.add_task(run_bulk_seed)
    return {"message": "Bulk seeding started", "status": "running"}


@router.post("/incremental-update")
async def incremental_update(background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Start incremental update using NVD modified/recent feeds"""
    from nvd_bulk_processor import NVDBulkProcessor

    async def run_incremental_update():
        try:
            processor = NVDBulkProcessor(db)
            result = await processor.incremental_update()
            logger.info(f"Incremental update completed: {result}")
        except Exception as e:
            logger.error(f"Incremental update failed: {str(e)}")

    background_tasks.add_task(run_incremental_update)
    return {"message": "Incremental update started", "status": "running"}


@router.get("/bulk-jobs")
async def get_bulk_jobs(db: Session = Depends(get_db)):
    """Get all bulk processing jobs"""
    jobs = db.query(BulkProcessingJob).order_by(BulkProcessingJob.created_at.desc()).limit(20).all()

    result = []
    for job in jobs:
        job_dict = {
            'id': str(job.id),
            'job_type': job.job_type,
            'status': job.status,
            'year': job.year,
            'total_items': job.total_items,
            'processed_items': job.processed_items,
            'failed_items': job.failed_items,
            'error_message': job.error_message,
            'job_metadata': job.job_metadata,
            'started_at': job.started_at,
            'completed_at': job.completed_at,
            'cancelled_at': job.cancelled_at,
            'created_at': job.created_at
        }
        result.append(job_dict)

    return result


@router.get("/bulk-status")
async def get_bulk_status(db: Session = Depends(get_db)):
    """Get comprehensive bulk processing status"""
    from bulk_seeder import BulkSeeder

    seeder = BulkSeeder(db)
    status = await seeder.get_seeding_status()
    return status


@router.get("/poc-stats")
async def get_poc_stats(db: Session = Depends(get_db)):
    """Get PoC-related statistics"""
    total_cves = db.query(CVE).count()
    cves_with_pocs = db.query(CVE).filter(CVE.poc_count > 0).count()

    # Get PoC quality distribution
    quality_distribution = db.execute(text("""
        SELECT
            COUNT(*) as total,
            AVG(poc_count) as avg_poc_count,
            MAX(poc_count) as max_poc_count
        FROM cves
        WHERE poc_count > 0
    """)).fetchone()

    # Get rules with PoC data
    total_rules = db.query(SigmaRule).count()
    exploit_based_rules = db.query(SigmaRule).filter(SigmaRule.exploit_based == True).count()

    return {
        "total_cves": total_cves,
        "cves_with_pocs": cves_with_pocs,
        "poc_coverage_percentage": round((cves_with_pocs / total_cves * 100), 2) if total_cves > 0 else 0,
        "average_pocs_per_cve": round(quality_distribution.avg_poc_count, 2) if quality_distribution.avg_poc_count else 0,
        "max_pocs_for_single_cve": quality_distribution.max_poc_count or 0,
        "total_rules": total_rules,
        "exploit_based_rules": exploit_based_rules,
        "exploit_based_percentage": round((exploit_based_rules / total_rules * 100), 2) if total_rules > 0 else 0
    }


@router.post("/sync-nomi-sec")
async def sync_nomi_sec(background_tasks: BackgroundTasks,
                        request: NomiSecSyncRequest,
                        db: Session = Depends(get_db)):
    """Synchronize nomi-sec PoC data"""

    # Create job record
    job = BulkProcessingJob(
        job_type='nomi_sec_sync',
        status='pending',
        job_metadata={
            'cve_id': request.cve_id,
            'batch_size': request.batch_size
        }
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    job_id = str(job.id)
    main.running_jobs[job_id] = job
    main.job_cancellation_flags[job_id] = False

    async def sync_task():
        try:
            job.status = 'running'
            job.started_at = datetime.utcnow()
            db.commit()

            from nomi_sec_client import NomiSecClient
            client = NomiSecClient(db)

            if request.cve_id:
                # Sync specific CVE
                if main.job_cancellation_flags.get(job_id, False):
                    logger.info(f"Job {job_id} cancelled before starting")
                    return

                result = await client.sync_cve_pocs(request.cve_id)
                logger.info(f"Nomi-sec sync for {request.cve_id}: {result}")
            else:
                # Sync all CVEs with cancellation support
                result = await client.bulk_sync_all_cves(
                    batch_size=request.batch_size,
                    cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False)
                )
                logger.info(f"Nomi-sec bulk sync completed: {result}")

            # Update job status if not cancelled
            if not main.job_cancellation_flags.get(job_id, False):
                job.status = 'completed'
                job.completed_at = datetime.utcnow()
                db.commit()

        except Exception as e:
            if not main.job_cancellation_flags.get(job_id, False):
                job.status = 'failed'
                job.error_message = str(e)
                job.completed_at = datetime.utcnow()
                db.commit()

            logger.error(f"Nomi-sec sync failed: {e}")
            import traceback
            traceback.print_exc()
        finally:
            # Clean up tracking
            main.running_jobs.pop(job_id, None)
            main.job_cancellation_flags.pop(job_id, None)

    background_tasks.add_task(sync_task)

    return {
        "message": "Nomi-sec sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
        "status": "started",
        "job_id": job_id,
        "cve_id": request.cve_id,
        "batch_size": request.batch_size
    }
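
# Cooperative cancellation: the background task is never killed outright. The
# cancel-job endpoint below sets main.job_cancellation_flags[job_id] = True, and
# the sync clients receive a cancellation_flag callable (the lambda above) that
# they are expected to check between batches and stop early once it returns True.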


@router.post("/sync-github-pocs")
async def sync_github_pocs(background_tasks: BackgroundTasks,
                           request: GitHubPoCSyncRequest,
                           db: Session = Depends(get_db)):
    """Synchronize GitHub PoC data"""

    # Create job record
    job = BulkProcessingJob(
        job_type='github_poc_sync',
        status='pending',
        job_metadata={
            'cve_id': request.cve_id,
            'batch_size': request.batch_size
        }
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    job_id = str(job.id)
    main.running_jobs[job_id] = job
    main.job_cancellation_flags[job_id] = False

    async def sync_task():
        try:
            job.status = 'running'
            job.started_at = datetime.utcnow()
            db.commit()

            from mcdevitt_poc_client import GitHubPoCClient
            client = GitHubPoCClient(db)

            if request.cve_id:
                # Sync specific CVE
                if main.job_cancellation_flags.get(job_id, False):
                    logger.info(f"Job {job_id} cancelled before starting")
                    return

                result = await client.sync_cve_pocs(request.cve_id)
                logger.info(f"GitHub PoC sync for {request.cve_id}: {result}")
            else:
                # Sync all CVEs
                result = await client.bulk_sync_all_cves(batch_size=request.batch_size)
                logger.info(f"GitHub PoC bulk sync completed: {result}")

            # Update job status if not cancelled
            if not main.job_cancellation_flags.get(job_id, False):
                job.status = 'completed'
                job.completed_at = datetime.utcnow()
                db.commit()

        except Exception as e:
            if not main.job_cancellation_flags.get(job_id, False):
                job.status = 'failed'
                job.error_message = str(e)
                job.completed_at = datetime.utcnow()
                db.commit()

            logger.error(f"GitHub PoC sync failed: {e}")
            import traceback
            traceback.print_exc()
        finally:
            # Clean up tracking
            main.running_jobs.pop(job_id, None)
            main.job_cancellation_flags.pop(job_id, None)

    background_tasks.add_task(sync_task)

    return {
        "message": "GitHub PoC sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
        "status": "started",
        "job_id": job_id,
        "cve_id": request.cve_id,
        "batch_size": request.batch_size
    }


@router.post("/sync-exploitdb")
async def sync_exploitdb(background_tasks: BackgroundTasks,
                         request: ExploitDBSyncRequest,
                         db: Session = Depends(get_db)):
    """Synchronize ExploitDB data from git mirror"""

    # Create job record
    job = BulkProcessingJob(
        job_type='exploitdb_sync',
        status='pending',
        job_metadata={
            'cve_id': request.cve_id,
            'batch_size': request.batch_size
        }
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    job_id = str(job.id)
    main.running_jobs[job_id] = job
    main.job_cancellation_flags[job_id] = False

    async def sync_task():
        # Create a new database session for the background task
        task_db = SessionLocal()
        try:
            # Get the job in the new session
            task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
            if not task_job:
                logger.error(f"Job {job_id} not found in task session")
                return

            task_job.status = 'running'
            task_job.started_at = datetime.utcnow()
            task_db.commit()

            from exploitdb_client_local import ExploitDBLocalClient
            client = ExploitDBLocalClient(task_db)

            if request.cve_id:
                # Sync specific CVE
                if main.job_cancellation_flags.get(job_id, False):
                    logger.info(f"Job {job_id} cancelled before starting")
                    return

                result = await client.sync_cve_exploits(request.cve_id)
                logger.info(f"ExploitDB sync for {request.cve_id}: {result}")
            else:
                # Sync all CVEs with cancellation support
                result = await client.bulk_sync_exploitdb(
                    batch_size=request.batch_size,
                    cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False)
                )
                logger.info(f"ExploitDB bulk sync completed: {result}")

            # Update job status if not cancelled
            if not main.job_cancellation_flags.get(job_id, False):
                task_job.status = 'completed'
                task_job.completed_at = datetime.utcnow()
                task_db.commit()

        except Exception as e:
            if not main.job_cancellation_flags.get(job_id, False):
                # Get the job again in case it was modified
                task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
                if task_job:
                    task_job.status = 'failed'
                    task_job.error_message = str(e)
                    task_job.completed_at = datetime.utcnow()
                    task_db.commit()

            logger.error(f"ExploitDB sync failed: {e}")
            import traceback
            traceback.print_exc()
        finally:
            # Clean up tracking and close the task session
            main.running_jobs.pop(job_id, None)
            main.job_cancellation_flags.pop(job_id, None)
            task_db.close()

    background_tasks.add_task(sync_task)

    return {
        "message": "ExploitDB sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
        "status": "started",
        "job_id": job_id,
        "cve_id": request.cve_id,
        "batch_size": request.batch_size
    }
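
# Note: this endpoint (and sync-cisa-kev below) opens its own SessionLocal inside
# the background task instead of reusing the request-scoped session, since the
# request session is tied to the request lifecycle and may no longer be usable by
# the time the task runs. The job row is re-fetched in that fresh session.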


@router.post("/sync-cisa-kev")
async def sync_cisa_kev(background_tasks: BackgroundTasks,
                        request: CISAKEVSyncRequest,
                        db: Session = Depends(get_db)):
    """Synchronize CISA Known Exploited Vulnerabilities data"""

    # Create job record
    job = BulkProcessingJob(
        job_type='cisa_kev_sync',
        status='pending',
        job_metadata={
            'cve_id': request.cve_id,
            'batch_size': request.batch_size
        }
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    job_id = str(job.id)
    main.running_jobs[job_id] = job
    main.job_cancellation_flags[job_id] = False

    async def sync_task():
        # Create a new database session for the background task
        task_db = SessionLocal()
        try:
            # Get the job in the new session
            task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
            if not task_job:
                logger.error(f"Job {job_id} not found in task session")
                return

            task_job.status = 'running'
            task_job.started_at = datetime.utcnow()
            task_db.commit()

            from cisa_kev_client import CISAKEVClient
            client = CISAKEVClient(task_db)

            if request.cve_id:
                # Sync specific CVE
                if main.job_cancellation_flags.get(job_id, False):
                    logger.info(f"Job {job_id} cancelled before starting")
                    return

                result = await client.sync_cve_kev_data(request.cve_id)
                logger.info(f"CISA KEV sync for {request.cve_id}: {result}")
            else:
                # Sync all CVEs with cancellation support
                result = await client.bulk_sync_kev_data(
                    batch_size=request.batch_size,
                    cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False)
                )
                logger.info(f"CISA KEV bulk sync completed: {result}")

            # Update job status if not cancelled
            if not main.job_cancellation_flags.get(job_id, False):
                task_job.status = 'completed'
                task_job.completed_at = datetime.utcnow()
                task_db.commit()

        except Exception as e:
            if not main.job_cancellation_flags.get(job_id, False):
                # Get the job again in case it was modified
                task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
                if task_job:
                    task_job.status = 'failed'
                    task_job.error_message = str(e)
                    task_job.completed_at = datetime.utcnow()
                    task_db.commit()

            logger.error(f"CISA KEV sync failed: {e}")
            import traceback
            traceback.print_exc()
        finally:
            # Clean up tracking and close the task session
            main.running_jobs.pop(job_id, None)
            main.job_cancellation_flags.pop(job_id, None)
            task_db.close()

    background_tasks.add_task(sync_task)

    return {
        "message": "CISA KEV sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
        "status": "started",
        "job_id": job_id,
        "cve_id": request.cve_id,
        "batch_size": request.batch_size
    }


@router.post("/sync-references")
async def sync_references(request: ReferenceSyncRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Start reference data synchronization"""

    try:
        from reference_client import ReferenceClient
        client = ReferenceClient(db)

        # Create job ID
        job_id = str(uuid.uuid4())

        # Add job to tracking
        main.running_jobs[job_id] = {
            'type': 'reference_sync',
            'status': 'running',
            'cve_id': request.cve_id,
            'batch_size': request.batch_size,
            'max_cves': request.max_cves,
            'force_resync': request.force_resync,
            'started_at': datetime.utcnow()
        }

        # Create cancellation flag
        main.job_cancellation_flags[job_id] = False

        async def sync_task():
            try:
                if request.cve_id:
                    # Single CVE sync
                    result = await client.sync_cve_references(request.cve_id)
                    main.running_jobs[job_id]['result'] = result
                    main.running_jobs[job_id]['status'] = 'completed'
                else:
                    # Bulk sync
                    result = await client.bulk_sync_references(
                        batch_size=request.batch_size,
                        max_cves=request.max_cves,
                        force_resync=request.force_resync,
                        cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False)
                    )
                    main.running_jobs[job_id]['result'] = result
                    main.running_jobs[job_id]['status'] = 'completed'

                main.running_jobs[job_id]['completed_at'] = datetime.utcnow()

            except Exception as e:
                logger.error(f"Reference sync task failed: {e}")
                main.running_jobs[job_id]['status'] = 'failed'
                main.running_jobs[job_id]['error'] = str(e)
                main.running_jobs[job_id]['completed_at'] = datetime.utcnow()
            finally:
                # Clean up cancellation flag
                main.job_cancellation_flags.pop(job_id, None)

        background_tasks.add_task(sync_task)

        return {
            "message": "Reference sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
            "status": "started",
            "job_id": job_id,
            "cve_id": request.cve_id,
            "batch_size": request.batch_size,
            "max_cves": request.max_cves,
            "force_resync": request.force_resync
        }

    except Exception as e:
        logger.error(f"Failed to start reference sync: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to start reference sync: {str(e)}")


@router.get("/exploitdb-stats")
async def get_exploitdb_stats(db: Session = Depends(get_db)):
    """Get ExploitDB-related statistics"""

    try:
        from exploitdb_client_local import ExploitDBLocalClient
        client = ExploitDBLocalClient(db)

        # Get sync status
        status = await client.get_exploitdb_sync_status()

        # Get quality distribution from ExploitDB data
        quality_distribution = {}
        cves_with_exploitdb = db.query(CVE).filter(
            text("poc_data::text LIKE '%\"exploitdb\"%'")
        ).all()

        for cve in cves_with_exploitdb:
            if cve.poc_data and 'exploitdb' in cve.poc_data:
                exploits = cve.poc_data['exploitdb'].get('exploits', [])
                for exploit in exploits:
                    quality_tier = exploit.get('quality_analysis', {}).get('quality_tier', 'unknown')
                    quality_distribution[quality_tier] = quality_distribution.get(quality_tier, 0) + 1

        # Get category distribution
        category_distribution = {}
        for cve in cves_with_exploitdb:
            if cve.poc_data and 'exploitdb' in cve.poc_data:
                exploits = cve.poc_data['exploitdb'].get('exploits', [])
                for exploit in exploits:
                    category = exploit.get('category', 'unknown')
                    category_distribution[category] = category_distribution.get(category, 0) + 1

        return {
            "exploitdb_sync_status": status,
            "quality_distribution": quality_distribution,
            "category_distribution": category_distribution,
            "total_exploitdb_cves": len(cves_with_exploitdb),
            "total_exploits": sum(
                len(cve.poc_data.get('exploitdb', {}).get('exploits', []))
                for cve in cves_with_exploitdb
                if cve.poc_data and 'exploitdb' in cve.poc_data
            )
        }

    except Exception as e:
        logger.error(f"Error getting ExploitDB stats: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/github-poc-stats")
async def get_github_poc_stats(db: Session = Depends(get_db)):
    """Get GitHub PoC-related statistics"""

    try:
        # Get basic statistics
        github_poc_rules = db.query(SigmaRule).filter(SigmaRule.poc_source == 'github_poc').count()
        cves_with_github_pocs = db.query(CVE).filter(
            CVE.poc_data.isnot(None),  # Check if poc_data exists
            func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc'
        ).count()

        # Get quality distribution
        quality_distribution = {}
        try:
            quality_results = db.query(
                func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_tier').label('tier'),
                func.count().label('count')
            ).filter(
                CVE.poc_data.isnot(None),
                func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc'
            ).group_by('tier').all()

            for tier, count in quality_results:
                if tier:
                    quality_distribution[tier] = count
        except Exception as e:
            logger.warning(f"Error getting quality distribution: {e}")
            quality_distribution = {}

        # Calculate average quality score
        try:
            from sqlalchemy import Integer
            avg_quality = db.query(
                func.avg(func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_score').cast(Integer))
            ).filter(
                CVE.poc_data.isnot(None),
                func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc'
            ).scalar() or 0
        except Exception as e:
            logger.warning(f"Error calculating average quality: {e}")
            avg_quality = 0

        return {
            'github_poc_rules': github_poc_rules,
            'cves_with_github_pocs': cves_with_github_pocs,
            'quality_distribution': quality_distribution,
            'average_quality_score': float(avg_quality) if avg_quality else 0,
            'source': 'github_poc'
        }
    except Exception as e:
        logger.error(f"Error getting GitHub PoC stats: {e}")
        return {"error": str(e)}


@router.get("/cisa-kev-stats")
async def get_cisa_kev_stats(db: Session = Depends(get_db)):
    """Get CISA KEV-related statistics"""

    try:
        from cisa_kev_client import CISAKEVClient
        client = CISAKEVClient(db)

        # Get sync status
        status = await client.get_kev_sync_status()

        # Get threat level distribution from CISA KEV data
        threat_level_distribution = {}
        cves_with_kev = db.query(CVE).filter(
            text("poc_data::text LIKE '%\"cisa_kev\"%'")
        ).all()

        for cve in cves_with_kev:
            if cve.poc_data and 'cisa_kev' in cve.poc_data:
                vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
                threat_level = vuln_data.get('threat_level', 'unknown')
                threat_level_distribution[threat_level] = threat_level_distribution.get(threat_level, 0) + 1

        # Get vulnerability category distribution
        category_distribution = {}
        for cve in cves_with_kev:
            if cve.poc_data and 'cisa_kev' in cve.poc_data:
                vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
                category = vuln_data.get('vulnerability_category', 'unknown')
                category_distribution[category] = category_distribution.get(category, 0) + 1

        # Get ransomware usage statistics
        ransomware_stats = {'known': 0, 'unknown': 0}
        for cve in cves_with_kev:
            if cve.poc_data and 'cisa_kev' in cve.poc_data:
                vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
                ransomware_use = vuln_data.get('known_ransomware_use', 'Unknown').lower()
                if ransomware_use == 'known':
                    ransomware_stats['known'] += 1
                else:
                    ransomware_stats['unknown'] += 1

        # Calculate average threat score
        threat_scores = []
        for cve in cves_with_kev:
            if cve.poc_data and 'cisa_kev' in cve.poc_data:
                vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
                threat_score = vuln_data.get('threat_score', 0)
                if threat_score:
                    threat_scores.append(threat_score)

        avg_threat_score = sum(threat_scores) / len(threat_scores) if threat_scores else 0

        return {
            "cisa_kev_sync_status": status,
            "threat_level_distribution": threat_level_distribution,
            "category_distribution": category_distribution,
            "ransomware_stats": ransomware_stats,
            "average_threat_score": round(avg_threat_score, 2),
            "total_kev_cves": len(cves_with_kev),
            "total_with_threat_scores": len(threat_scores)
        }

    except Exception as e:
        logger.error(f"Error getting CISA KEV stats: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/cancel-job/{job_id}")
async def cancel_job(job_id: str, db: Session = Depends(get_db)):
    """Cancel a running job"""
    try:
        # Find the job in the database
        job = db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job_id).first()
        if not job:
            raise HTTPException(status_code=404, detail="Job not found")

        if job.status not in ['pending', 'running']:
            raise HTTPException(status_code=400, detail=f"Cannot cancel job with status: {job.status}")

        # Set cancellation flag
        main.job_cancellation_flags[job_id] = True

        # Update job status
        job.status = 'cancelled'
        job.cancelled_at = datetime.utcnow()
        db.commit()

        logger.info(f"Job {job_id} ({job.job_type}) cancelled by user")

        return {
            "success": True,
            "message": f"Job {job_id} cancelled successfully",
            "job_id": job_id,
            "job_type": job.job_type
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error cancelling job {job_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e))
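
# Cancellation here is cooperative: the job row is marked 'cancelled' and its flag
# is flipped, but the background task is not terminated; it stops at its next flag
# check. Jobs tracked only in main.running_jobs (such as reference syncs, which
# have no BulkProcessingJob row) will return 404 from this endpoint.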


@router.get("/running-jobs")
async def get_running_jobs(db: Session = Depends(get_db)):
    """Get currently running jobs"""
    try:
        # Get running jobs from database
        running_jobs_db = db.query(BulkProcessingJob).filter(
            BulkProcessingJob.status.in_(['pending', 'running'])
        ).order_by(BulkProcessingJob.created_at.desc()).all()

        result = []
        for job in running_jobs_db:
            result.append({
                'id': str(job.id),
                'job_type': job.job_type,
                'status': job.status,
                'year': job.year,
                'total_items': job.total_items,
                'processed_items': job.processed_items,
                'failed_items': job.failed_items,
                'error_message': job.error_message,
                'started_at': job.started_at,
                'created_at': job.created_at,
                'can_cancel': job.status in ['pending', 'running']
            })

        return result

    except Exception as e:
        logger.error(f"Error getting running jobs: {e}")
        raise HTTPException(status_code=500, detail=str(e))