auto_sigma_rule_generator/backend/routers/bulk_operations.py

from typing import List, Optional
from datetime import datetime
import logging
import uuid

from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from sqlalchemy.orm import Session
from sqlalchemy import func, text

from config.database import get_db, SessionLocal
from models import BulkProcessingJob, CVE, SigmaRule
from schemas import (
    BulkSeedRequest,
    NomiSecSyncRequest,
    GitHubPoCSyncRequest,
    ExploitDBSyncRequest,
    CISAKEVSyncRequest,
    ReferenceSyncRequest,
)
from services import CVEService, SigmaRuleService

# Import global job tracking from main.py
import main

router = APIRouter(prefix="/api", tags=["bulk-operations"])
logger = logging.getLogger(__name__)
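
# Job state lives in two places: persistent BulkProcessingJob rows in the database, and
# the in-process dictionaries main.running_jobs / main.job_cancellation_flags used for
# cooperative cancellation. Because those dictionaries are in-process, a cancellation
# flag is only visible to the worker process that registered the job.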


@router.post("/bulk-seed")
async def bulk_seed(request: BulkSeedRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Start bulk seeding operation"""
    from bulk_seeder import BulkSeeder

    async def run_bulk_seed():
        try:
            seeder = BulkSeeder(db)
            result = await seeder.full_bulk_seed(
                start_year=request.start_year,
                end_year=request.end_year,
                skip_nvd=request.skip_nvd,
                skip_nomi_sec=request.skip_nomi_sec
            )
            logger.info(f"Bulk seed completed: {result}")
        except Exception as e:
            logger.error(f"Bulk seed failed: {e}")

    background_tasks.add_task(run_bulk_seed)
    return {"message": "Bulk seeding started", "status": "running"}
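
# Example request (illustrative only; assumes the API is served locally on port 8000 and
# uses example field values, while the field names match BulkSeedRequest usage above):
#   curl -X POST http://localhost:8000/api/bulk-seed \
#        -H "Content-Type: application/json" \
#        -d '{"start_year": 2020, "end_year": 2024, "skip_nvd": false, "skip_nomi_sec": false}'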


@router.post("/incremental-update")
async def incremental_update(background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Start incremental update using NVD modified/recent feeds"""
    from nvd_bulk_processor import NVDBulkProcessor

    async def run_incremental_update():
        try:
            processor = NVDBulkProcessor(db)
            result = await processor.incremental_update()
            logger.info(f"Incremental update completed: {result}")
        except Exception as e:
            logger.error(f"Incremental update failed: {e}")

    background_tasks.add_task(run_incremental_update)
    return {"message": "Incremental update started", "status": "running"}
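
# Example request (illustrative; assumes the API is served locally on port 8000):
#   curl -X POST http://localhost:8000/api/incremental-update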


@router.get("/bulk-jobs")
async def get_bulk_jobs(db: Session = Depends(get_db)):
    """Get the most recent bulk processing jobs (latest 20)"""
    jobs = db.query(BulkProcessingJob).order_by(BulkProcessingJob.created_at.desc()).limit(20).all()
    result = []
    for job in jobs:
        job_dict = {
            'id': str(job.id),
            'job_type': job.job_type,
            'status': job.status,
            'year': job.year,
            'total_items': job.total_items,
            'processed_items': job.processed_items,
            'failed_items': job.failed_items,
            'error_message': job.error_message,
            'job_metadata': job.job_metadata,
            'started_at': job.started_at,
            'completed_at': job.completed_at,
            'cancelled_at': job.cancelled_at,
            'created_at': job.created_at
        }
        result.append(job_dict)
    return result


@router.get("/bulk-status")
async def get_bulk_status(db: Session = Depends(get_db)):
    """Get comprehensive bulk processing status"""
    from bulk_seeder import BulkSeeder
    seeder = BulkSeeder(db)
    status = await seeder.get_seeding_status()
    return status


@router.get("/poc-stats")
async def get_poc_stats(db: Session = Depends(get_db)):
    """Get PoC-related statistics"""
    total_cves = db.query(CVE).count()
    cves_with_pocs = db.query(CVE).filter(CVE.poc_count > 0).count()

    # Get PoC quality distribution
    quality_distribution = db.execute(text("""
        SELECT
            COUNT(*) as total,
            AVG(poc_count) as avg_poc_count,
            MAX(poc_count) as max_poc_count
        FROM cves
        WHERE poc_count > 0
    """)).fetchone()

    # Get rules with PoC data
    total_rules = db.query(SigmaRule).count()
    exploit_based_rules = db.query(SigmaRule).filter(SigmaRule.exploit_based == True).count()

    return {
        "total_cves": total_cves,
        "cves_with_pocs": cves_with_pocs,
        "poc_coverage_percentage": round((cves_with_pocs / total_cves * 100), 2) if total_cves > 0 else 0,
        "average_pocs_per_cve": round(quality_distribution.avg_poc_count, 2) if quality_distribution.avg_poc_count else 0,
        "max_pocs_for_single_cve": quality_distribution.max_poc_count or 0,
        "total_rules": total_rules,
        "exploit_based_rules": exploit_based_rules,
        "exploit_based_percentage": round((exploit_based_rules / total_rules * 100), 2) if total_rules > 0 else 0
    }
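
# The four sync endpoints below (/sync-nomi-sec, /sync-github-pocs, /sync-exploitdb,
# /sync-cisa-kev) follow the same lifecycle: create a BulkProcessingJob row, register it
# in main.running_jobs together with a cancellation flag, run the relevant client's sync
# in a background task, and clean up the in-process tracking in a finally block.
# Cancellation is cooperative: the bulk sync clients receive a callable that reports
# whether the flag has been set and are expected to check it between batches.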


@router.post("/sync-nomi-sec")
async def sync_nomi_sec(background_tasks: BackgroundTasks,
                        request: NomiSecSyncRequest,
                        db: Session = Depends(get_db)):
    """Synchronize nomi-sec PoC data"""
    # Create job record
    job = BulkProcessingJob(
        job_type='nomi_sec_sync',
        status='pending',
        job_metadata={
            'cve_id': request.cve_id,
            'batch_size': request.batch_size
        }
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    job_id = str(job.id)
    main.running_jobs[job_id] = job
    main.job_cancellation_flags[job_id] = False

    async def sync_task():
        try:
            job.status = 'running'
            job.started_at = datetime.utcnow()
            db.commit()

            from nomi_sec_client import NomiSecClient
            client = NomiSecClient(db)

            if request.cve_id:
                # Sync specific CVE
                if main.job_cancellation_flags.get(job_id, False):
                    logger.info(f"Job {job_id} cancelled before starting")
                    return
                result = await client.sync_cve_pocs(request.cve_id)
                logger.info(f"Nomi-sec sync for {request.cve_id}: {result}")
            else:
                # Sync all CVEs with cancellation support
                result = await client.bulk_sync_all_cves(
                    batch_size=request.batch_size,
                    cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False)
                )
                logger.info(f"Nomi-sec bulk sync completed: {result}")

            # Update job status if not cancelled
            if not main.job_cancellation_flags.get(job_id, False):
                job.status = 'completed'
                job.completed_at = datetime.utcnow()
                db.commit()
        except Exception as e:
            if not main.job_cancellation_flags.get(job_id, False):
                job.status = 'failed'
                job.error_message = str(e)
                job.completed_at = datetime.utcnow()
                db.commit()
            logger.error(f"Nomi-sec sync failed: {e}")
            import traceback
            traceback.print_exc()
        finally:
            # Clean up tracking
            main.running_jobs.pop(job_id, None)
            main.job_cancellation_flags.pop(job_id, None)

    background_tasks.add_task(sync_task)

    return {
        "message": "Nomi-sec sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
        "status": "started",
        "job_id": job_id,
        "cve_id": request.cve_id,
        "batch_size": request.batch_size
    }
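
# Example request (illustrative only; assumes the API is served locally on port 8000 and
# that cve_id may be omitted to sync all CVEs, as the handler above implies):
#   curl -X POST http://localhost:8000/api/sync-nomi-sec \
#        -H "Content-Type: application/json" \
#        -d '{"cve_id": "CVE-2021-44228", "batch_size": 50}'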


@router.post("/sync-github-pocs")
async def sync_github_pocs(background_tasks: BackgroundTasks,
                           request: GitHubPoCSyncRequest,
                           db: Session = Depends(get_db)):
    """Synchronize GitHub PoC data"""
    # Create job record
    job = BulkProcessingJob(
        job_type='github_poc_sync',
        status='pending',
        job_metadata={
            'cve_id': request.cve_id,
            'batch_size': request.batch_size
        }
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    job_id = str(job.id)
    main.running_jobs[job_id] = job
    main.job_cancellation_flags[job_id] = False

    async def sync_task():
        try:
            job.status = 'running'
            job.started_at = datetime.utcnow()
            db.commit()

            from mcdevitt_poc_client import GitHubPoCClient
            client = GitHubPoCClient(db)

            if request.cve_id:
                # Sync specific CVE
                if main.job_cancellation_flags.get(job_id, False):
                    logger.info(f"Job {job_id} cancelled before starting")
                    return
                result = await client.sync_cve_pocs(request.cve_id)
                logger.info(f"GitHub PoC sync for {request.cve_id}: {result}")
            else:
                # Sync all CVEs
                result = await client.bulk_sync_all_cves(batch_size=request.batch_size)
                logger.info(f"GitHub PoC bulk sync completed: {result}")

            # Update job status if not cancelled
            if not main.job_cancellation_flags.get(job_id, False):
                job.status = 'completed'
                job.completed_at = datetime.utcnow()
                db.commit()
        except Exception as e:
            if not main.job_cancellation_flags.get(job_id, False):
                job.status = 'failed'
                job.error_message = str(e)
                job.completed_at = datetime.utcnow()
                db.commit()
            logger.error(f"GitHub PoC sync failed: {e}")
            import traceback
            traceback.print_exc()
        finally:
            # Clean up tracking
            main.running_jobs.pop(job_id, None)
            main.job_cancellation_flags.pop(job_id, None)

    background_tasks.add_task(sync_task)

    return {
        "message": "GitHub PoC sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
        "status": "started",
        "job_id": job_id,
        "cve_id": request.cve_id,
        "batch_size": request.batch_size
    }
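
# Note: the two endpoints above run their background tasks against the request-scoped
# session (`db`) and the `job` instance bound to it, while /sync-exploitdb and
# /sync-cisa-kev below open a dedicated SessionLocal session inside the task and re-fetch
# the job row, which avoids relying on the request session still being usable after the
# response has been returned.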


@router.post("/sync-exploitdb")
async def sync_exploitdb(background_tasks: BackgroundTasks,
                         request: ExploitDBSyncRequest,
                         db: Session = Depends(get_db)):
    """Synchronize ExploitDB data from git mirror"""
    # Create job record
    job = BulkProcessingJob(
        job_type='exploitdb_sync',
        status='pending',
        job_metadata={
            'cve_id': request.cve_id,
            'batch_size': request.batch_size
        }
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    job_id = str(job.id)
    main.running_jobs[job_id] = job
    main.job_cancellation_flags[job_id] = False

    async def sync_task():
        # Create a new database session for the background task
        task_db = SessionLocal()
        try:
            # Get the job in the new session
            task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
            if not task_job:
                logger.error(f"Job {job_id} not found in task session")
                return

            task_job.status = 'running'
            task_job.started_at = datetime.utcnow()
            task_db.commit()

            from exploitdb_client_local import ExploitDBLocalClient
            client = ExploitDBLocalClient(task_db)

            if request.cve_id:
                # Sync specific CVE
                if main.job_cancellation_flags.get(job_id, False):
                    logger.info(f"Job {job_id} cancelled before starting")
                    return
                result = await client.sync_cve_exploits(request.cve_id)
                logger.info(f"ExploitDB sync for {request.cve_id}: {result}")
            else:
                # Sync all CVEs with cancellation support
                result = await client.bulk_sync_exploitdb(
                    batch_size=request.batch_size,
                    cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False)
                )
                logger.info(f"ExploitDB bulk sync completed: {result}")

            # Update job status if not cancelled
            if not main.job_cancellation_flags.get(job_id, False):
                task_job.status = 'completed'
                task_job.completed_at = datetime.utcnow()
                task_db.commit()
        except Exception as e:
            if not main.job_cancellation_flags.get(job_id, False):
                # Get the job again in case it was modified
                task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
                if task_job:
                    task_job.status = 'failed'
                    task_job.error_message = str(e)
                    task_job.completed_at = datetime.utcnow()
                    task_db.commit()
            logger.error(f"ExploitDB sync failed: {e}")
            import traceback
            traceback.print_exc()
        finally:
            # Clean up tracking and close the task session
            main.running_jobs.pop(job_id, None)
            main.job_cancellation_flags.pop(job_id, None)
            task_db.close()

    background_tasks.add_task(sync_task)

    return {
        "message": "ExploitDB sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
        "status": "started",
        "job_id": job_id,
        "cve_id": request.cve_id,
        "batch_size": request.batch_size
    }


@router.post("/sync-cisa-kev")
async def sync_cisa_kev(background_tasks: BackgroundTasks,
                        request: CISAKEVSyncRequest,
                        db: Session = Depends(get_db)):
    """Synchronize CISA Known Exploited Vulnerabilities data"""
    # Create job record
    job = BulkProcessingJob(
        job_type='cisa_kev_sync',
        status='pending',
        job_metadata={
            'cve_id': request.cve_id,
            'batch_size': request.batch_size
        }
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    job_id = str(job.id)
    main.running_jobs[job_id] = job
    main.job_cancellation_flags[job_id] = False

    async def sync_task():
        # Create a new database session for the background task
        task_db = SessionLocal()
        try:
            # Get the job in the new session
            task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
            if not task_job:
                logger.error(f"Job {job_id} not found in task session")
                return

            task_job.status = 'running'
            task_job.started_at = datetime.utcnow()
            task_db.commit()

            from cisa_kev_client import CISAKEVClient
            client = CISAKEVClient(task_db)

            if request.cve_id:
                # Sync specific CVE
                if main.job_cancellation_flags.get(job_id, False):
                    logger.info(f"Job {job_id} cancelled before starting")
                    return
                result = await client.sync_cve_kev_data(request.cve_id)
                logger.info(f"CISA KEV sync for {request.cve_id}: {result}")
            else:
                # Sync all CVEs with cancellation support
                result = await client.bulk_sync_kev_data(
                    batch_size=request.batch_size,
                    cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False)
                )
                logger.info(f"CISA KEV bulk sync completed: {result}")

            # Update job status if not cancelled
            if not main.job_cancellation_flags.get(job_id, False):
                task_job.status = 'completed'
                task_job.completed_at = datetime.utcnow()
                task_db.commit()
        except Exception as e:
            if not main.job_cancellation_flags.get(job_id, False):
                # Get the job again in case it was modified
                task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
                if task_job:
                    task_job.status = 'failed'
                    task_job.error_message = str(e)
                    task_job.completed_at = datetime.utcnow()
                    task_db.commit()
            logger.error(f"CISA KEV sync failed: {e}")
            import traceback
            traceback.print_exc()
        finally:
            # Clean up tracking and close the task session
            main.running_jobs.pop(job_id, None)
            main.job_cancellation_flags.pop(job_id, None)
            task_db.close()

    background_tasks.add_task(sync_task)

    return {
        "message": "CISA KEV sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
        "status": "started",
        "job_id": job_id,
        "cve_id": request.cve_id,
        "batch_size": request.batch_size
    }


@router.post("/sync-references")
async def sync_references(request: ReferenceSyncRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Start reference data synchronization"""
    try:
        from reference_client import ReferenceClient
        client = ReferenceClient(db)

        # Create job ID
        job_id = str(uuid.uuid4())

        # Add job to tracking
        main.running_jobs[job_id] = {
            'type': 'reference_sync',
            'status': 'running',
            'cve_id': request.cve_id,
            'batch_size': request.batch_size,
            'max_cves': request.max_cves,
            'force_resync': request.force_resync,
            'started_at': datetime.utcnow()
        }

        # Create cancellation flag
        main.job_cancellation_flags[job_id] = False

        async def sync_task():
            try:
                if request.cve_id:
                    # Single CVE sync
                    result = await client.sync_cve_references(request.cve_id)
                    main.running_jobs[job_id]['result'] = result
                    main.running_jobs[job_id]['status'] = 'completed'
                else:
                    # Bulk sync
                    result = await client.bulk_sync_references(
                        batch_size=request.batch_size,
                        max_cves=request.max_cves,
                        force_resync=request.force_resync,
                        cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False)
                    )
                    main.running_jobs[job_id]['result'] = result
                    main.running_jobs[job_id]['status'] = 'completed'

                main.running_jobs[job_id]['completed_at'] = datetime.utcnow()
            except Exception as e:
                logger.error(f"Reference sync task failed: {e}")
                main.running_jobs[job_id]['status'] = 'failed'
                main.running_jobs[job_id]['error'] = str(e)
                main.running_jobs[job_id]['completed_at'] = datetime.utcnow()
            finally:
                # Clean up cancellation flag
                main.job_cancellation_flags.pop(job_id, None)

        background_tasks.add_task(sync_task)

        return {
            "message": "Reference sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
            "status": "started",
            "job_id": job_id,
            "cve_id": request.cve_id,
            "batch_size": request.batch_size,
            "max_cves": request.max_cves,
            "force_resync": request.force_resync
        }
    except Exception as e:
        logger.error(f"Failed to start reference sync: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to start reference sync: {str(e)}")
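
# Example request (illustrative only; assumes the API is served locally on port 8000 and
# uses example values, while the field names match ReferenceSyncRequest usage above):
#   curl -X POST http://localhost:8000/api/sync-references \
#        -H "Content-Type: application/json" \
#        -d '{"batch_size": 30, "max_cves": 100, "force_resync": false}'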


@router.get("/exploitdb-stats")
async def get_exploitdb_stats(db: Session = Depends(get_db)):
    """Get ExploitDB-related statistics"""
    try:
        from exploitdb_client_local import ExploitDBLocalClient
        client = ExploitDBLocalClient(db)

        # Get sync status
        status = await client.get_exploitdb_sync_status()

        # Get quality distribution from ExploitDB data
        quality_distribution = {}
        cves_with_exploitdb = db.query(CVE).filter(
            text("poc_data::text LIKE '%\"exploitdb\"%'")
        ).all()
        for cve in cves_with_exploitdb:
            if cve.poc_data and 'exploitdb' in cve.poc_data:
                exploits = cve.poc_data['exploitdb'].get('exploits', [])
                for exploit in exploits:
                    quality_tier = exploit.get('quality_analysis', {}).get('quality_tier', 'unknown')
                    quality_distribution[quality_tier] = quality_distribution.get(quality_tier, 0) + 1

        # Get category distribution
        category_distribution = {}
        for cve in cves_with_exploitdb:
            if cve.poc_data and 'exploitdb' in cve.poc_data:
                exploits = cve.poc_data['exploitdb'].get('exploits', [])
                for exploit in exploits:
                    category = exploit.get('category', 'unknown')
                    category_distribution[category] = category_distribution.get(category, 0) + 1

        return {
            "exploitdb_sync_status": status,
            "quality_distribution": quality_distribution,
            "category_distribution": category_distribution,
            "total_exploitdb_cves": len(cves_with_exploitdb),
            "total_exploits": sum(
                len(cve.poc_data.get('exploitdb', {}).get('exploits', []))
                for cve in cves_with_exploitdb
                if cve.poc_data and 'exploitdb' in cve.poc_data
            )
        }
    except Exception as e:
        logger.error(f"Error getting ExploitDB stats: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/github-poc-stats")
async def get_github_poc_stats(db: Session = Depends(get_db)):
    """Get GitHub PoC-related statistics"""
    try:
        # Get basic statistics
        github_poc_rules = db.query(SigmaRule).filter(SigmaRule.poc_source == 'github_poc').count()
        cves_with_github_pocs = db.query(CVE).filter(
            CVE.poc_data.isnot(None),  # Check if poc_data exists
            func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc'
        ).count()

        # Get quality distribution
        quality_distribution = {}
        try:
            quality_results = db.query(
                func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_tier').label('tier'),
                func.count().label('count')
            ).filter(
                CVE.poc_data.isnot(None),
                func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc'
            ).group_by('tier').all()
            for tier, count in quality_results:
                if tier:
                    quality_distribution[tier] = count
        except Exception as e:
            logger.warning(f"Error getting quality distribution: {e}")
            quality_distribution = {}

        # Calculate average quality score
        try:
            from sqlalchemy import Integer
            avg_quality = db.query(
                func.avg(func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_score').cast(Integer))
            ).filter(
                CVE.poc_data.isnot(None),
                func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc'
            ).scalar() or 0
        except Exception as e:
            logger.warning(f"Error calculating average quality: {e}")
            avg_quality = 0

        return {
            'github_poc_rules': github_poc_rules,
            'cves_with_github_pocs': cves_with_github_pocs,
            'quality_distribution': quality_distribution,
            'average_quality_score': float(avg_quality) if avg_quality else 0,
            'source': 'github_poc'
        }
    except Exception as e:
        logger.error(f"Error getting GitHub PoC stats: {e}")
        return {"error": str(e)}
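
# Note: json_extract_path_text (used above) is a PostgreSQL function, so these statistics
# queries assume a PostgreSQL backend, consistent with the poc_data::text casts used
# elsewhere in this module.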


@router.get("/cisa-kev-stats")
async def get_cisa_kev_stats(db: Session = Depends(get_db)):
    """Get CISA KEV-related statistics"""
    try:
        from cisa_kev_client import CISAKEVClient
        client = CISAKEVClient(db)

        # Get sync status
        status = await client.get_kev_sync_status()

        # Get threat level distribution from CISA KEV data
        threat_level_distribution = {}
        cves_with_kev = db.query(CVE).filter(
            text("poc_data::text LIKE '%\"cisa_kev\"%'")
        ).all()
        for cve in cves_with_kev:
            if cve.poc_data and 'cisa_kev' in cve.poc_data:
                vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
                threat_level = vuln_data.get('threat_level', 'unknown')
                threat_level_distribution[threat_level] = threat_level_distribution.get(threat_level, 0) + 1

        # Get vulnerability category distribution
        category_distribution = {}
        for cve in cves_with_kev:
            if cve.poc_data and 'cisa_kev' in cve.poc_data:
                vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
                category = vuln_data.get('vulnerability_category', 'unknown')
                category_distribution[category] = category_distribution.get(category, 0) + 1

        # Get ransomware usage statistics
        ransomware_stats = {'known': 0, 'unknown': 0}
        for cve in cves_with_kev:
            if cve.poc_data and 'cisa_kev' in cve.poc_data:
                vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
                ransomware_use = vuln_data.get('known_ransomware_use', 'Unknown').lower()
                if ransomware_use == 'known':
                    ransomware_stats['known'] += 1
                else:
                    ransomware_stats['unknown'] += 1

        # Calculate average threat score
        threat_scores = []
        for cve in cves_with_kev:
            if cve.poc_data and 'cisa_kev' in cve.poc_data:
                vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
                threat_score = vuln_data.get('threat_score', 0)
                if threat_score:
                    threat_scores.append(threat_score)
        avg_threat_score = sum(threat_scores) / len(threat_scores) if threat_scores else 0

        return {
            "cisa_kev_sync_status": status,
            "threat_level_distribution": threat_level_distribution,
            "category_distribution": category_distribution,
            "ransomware_stats": ransomware_stats,
            "average_threat_score": round(avg_threat_score, 2),
            "total_kev_cves": len(cves_with_kev),
            "total_with_threat_scores": len(threat_scores)
        }
    except Exception as e:
        logger.error(f"Error getting CISA KEV stats: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/cancel-job/{job_id}")
async def cancel_job(job_id: str, db: Session = Depends(get_db)):
    """Cancel a running job"""
    try:
        # Find the job in the database
        job = db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job_id).first()
        if not job:
            raise HTTPException(status_code=404, detail="Job not found")
        if job.status not in ['pending', 'running']:
            raise HTTPException(status_code=400, detail=f"Cannot cancel job with status: {job.status}")

        # Set cancellation flag
        main.job_cancellation_flags[job_id] = True

        # Update job status
        job.status = 'cancelled'
        job.cancelled_at = datetime.utcnow()
        db.commit()

        logger.info(f"Job {job_id} ({job.job_type}) cancelled by user")

        return {
            "success": True,
            "message": f"Job {job_id} cancelled successfully",
            "job_id": job_id,
            "job_type": job.job_type
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error cancelling job {job_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e))
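
# Cancellation is cooperative: setting main.job_cancellation_flags[job_id] only takes
# effect when the corresponding sync task checks the flag (directly, or via the
# cancellation_flag callable passed to the bulk sync clients), so a long-running batch
# may finish before the job actually stops.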


@router.get("/running-jobs")
async def get_running_jobs(db: Session = Depends(get_db)):
    """Get currently running jobs"""
    try:
        # Get running jobs from database
        running_jobs_db = db.query(BulkProcessingJob).filter(
            BulkProcessingJob.status.in_(['pending', 'running'])
        ).order_by(BulkProcessingJob.created_at.desc()).all()

        result = []
        for job in running_jobs_db:
            result.append({
                'id': str(job.id),
                'job_type': job.job_type,
                'status': job.status,
                'year': job.year,
                'total_items': job.total_items,
                'processed_items': job.processed_items,
                'failed_items': job.failed_items,
                'error_message': job.error_message,
                'started_at': job.started_at,
                'created_at': job.created_at,
                'can_cancel': job.status in ['pending', 'running']
            })
        return result
    except Exception as e:
        logger.error(f"Error getting running jobs: {e}")
        raise HTTPException(status_code=500, detail=str(e))
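
# This router is expected to be mounted on the FastAPI app defined in main.py; a minimal
# sketch, assuming the conventional include_router wiring (main.py is not shown here):
#   from routers import bulk_operations
#   app.include_router(bulk_operations.router)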