from typing import List, Optional
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from sqlalchemy.orm import Session
from sqlalchemy import func, text
import uuid
from datetime import datetime
import logging

from config.database import get_db, SessionLocal
from models import BulkProcessingJob, CVE, SigmaRule
from schemas import (
    BulkSeedRequest,
    NomiSecSyncRequest,
    GitHubPoCSyncRequest,
    ExploitDBSyncRequest,
    CISAKEVSyncRequest,
    ReferenceSyncRequest,
)
from services import CVEService, SigmaRuleService

# Import global job tracking from main.py
import main

router = APIRouter(prefix="/api", tags=["bulk-operations"])

logger = logging.getLogger(__name__)


@router.post("/bulk-seed")
async def bulk_seed(request: BulkSeedRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Start bulk seeding operation"""
    from bulk_seeder import BulkSeeder

    async def run_bulk_seed():
        try:
            seeder = BulkSeeder(db)
            result = await seeder.full_bulk_seed(
                start_year=request.start_year,
                end_year=request.end_year,
                skip_nvd=request.skip_nvd,
                skip_nomi_sec=request.skip_nomi_sec
            )
            logger.info(f"Bulk seed completed: {result}")
        except Exception as e:
            logger.error(f"Bulk seed failed: {e}")

    background_tasks.add_task(run_bulk_seed)

    return {"message": "Bulk seeding started", "status": "running"}


@router.post("/incremental-update")
async def incremental_update(background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Start incremental update using NVD modified/recent feeds"""
    from nvd_bulk_processor import NVDBulkProcessor

    async def run_incremental_update():
        try:
            processor = NVDBulkProcessor(db)
            result = await processor.incremental_update()
            logger.info(f"Incremental update completed: {result}")
        except Exception as e:
            logger.error(f"Incremental update failed: {e}")

    background_tasks.add_task(run_incremental_update)

    return {"message": "Incremental update started", "status": "running"}


@router.get("/bulk-jobs")
async def get_bulk_jobs(db: Session = Depends(get_db)):
    """Get all bulk processing jobs"""
    jobs = db.query(BulkProcessingJob).order_by(BulkProcessingJob.created_at.desc()).limit(20).all()

    result = []
    for job in jobs:
        job_dict = {
            'id': str(job.id),
            'job_type': job.job_type,
            'status': job.status,
            'year': job.year,
            'total_items': job.total_items,
            'processed_items': job.processed_items,
            'failed_items': job.failed_items,
            'error_message': job.error_message,
            'job_metadata': job.job_metadata,
            'started_at': job.started_at,
            'completed_at': job.completed_at,
            'cancelled_at': job.cancelled_at,
            'created_at': job.created_at
        }
        result.append(job_dict)

    return result


@router.get("/bulk-status")
async def get_bulk_status(db: Session = Depends(get_db)):
    """Get comprehensive bulk processing status"""
    from bulk_seeder import BulkSeeder

    seeder = BulkSeeder(db)
    status = await seeder.get_seeding_status()
    return status


@router.get("/poc-stats")
async def get_poc_stats(db: Session = Depends(get_db)):
    """Get PoC-related statistics"""
    total_cves = db.query(CVE).count()
    cves_with_pocs = db.query(CVE).filter(CVE.poc_count > 0).count()

    # Get PoC quality distribution
    quality_distribution = db.execute(text("""
        SELECT
            COUNT(*) as total,
            AVG(poc_count) as avg_poc_count,
            MAX(poc_count) as max_poc_count
        FROM cves
        WHERE poc_count > 0
    """)).fetchone()

    # Get rules with PoC data
    total_rules = db.query(SigmaRule).count()
    exploit_based_rules = db.query(SigmaRule).filter(SigmaRule.exploit_based == True).count()

    return {
        "total_cves": total_cves,
        "cves_with_pocs": cves_with_pocs,
"poc_coverage_percentage": round((cves_with_pocs / total_cves * 100), 2) if total_cves > 0 else 0, "average_pocs_per_cve": round(quality_distribution.avg_poc_count, 2) if quality_distribution.avg_poc_count else 0, "max_pocs_for_single_cve": quality_distribution.max_poc_count or 0, "total_rules": total_rules, "exploit_based_rules": exploit_based_rules, "exploit_based_percentage": round((exploit_based_rules / total_rules * 100), 2) if total_rules > 0 else 0 } @router.post("/sync-nomi-sec") async def sync_nomi_sec(background_tasks: BackgroundTasks, request: NomiSecSyncRequest, db: Session = Depends(get_db)): """Synchronize nomi-sec PoC data""" # Create job record job = BulkProcessingJob( job_type='nomi_sec_sync', status='pending', job_metadata={ 'cve_id': request.cve_id, 'batch_size': request.batch_size } ) db.add(job) db.commit() db.refresh(job) job_id = str(job.id) main.running_jobs[job_id] = job main.job_cancellation_flags[job_id] = False async def sync_task(): try: job.status = 'running' job.started_at = datetime.utcnow() db.commit() from nomi_sec_client import NomiSecClient client = NomiSecClient(db) if request.cve_id: # Sync specific CVE if main.job_cancellation_flags.get(job_id, False): logger.info(f"Job {job_id} cancelled before starting") return result = await client.sync_cve_pocs(request.cve_id) logger.info(f"Nomi-sec sync for {request.cve_id}: {result}") else: # Sync all CVEs with cancellation support result = await client.bulk_sync_all_cves( batch_size=request.batch_size, cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False) ) logger.info(f"Nomi-sec bulk sync completed: {result}") # Update job status if not cancelled if not main.job_cancellation_flags.get(job_id, False): job.status = 'completed' job.completed_at = datetime.utcnow() db.commit() except Exception as e: if not main.job_cancellation_flags.get(job_id, False): job.status = 'failed' job.error_message = str(e) job.completed_at = datetime.utcnow() db.commit() logger.error(f"Nomi-sec sync failed: {e}") import traceback traceback.print_exc() finally: # Clean up tracking main.running_jobs.pop(job_id, None) main.job_cancellation_flags.pop(job_id, None) background_tasks.add_task(sync_task) return { "message": f"Nomi-sec sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"), "status": "started", "job_id": job_id, "cve_id": request.cve_id, "batch_size": request.batch_size } @router.post("/sync-github-pocs") async def sync_github_pocs(background_tasks: BackgroundTasks, request: GitHubPoCSyncRequest, db: Session = Depends(get_db)): """Synchronize GitHub PoC data""" # Create job record job = BulkProcessingJob( job_type='github_poc_sync', status='pending', job_metadata={ 'cve_id': request.cve_id, 'batch_size': request.batch_size } ) db.add(job) db.commit() db.refresh(job) job_id = str(job.id) main.running_jobs[job_id] = job main.job_cancellation_flags[job_id] = False async def sync_task(): try: job.status = 'running' job.started_at = datetime.utcnow() db.commit() from mcdevitt_poc_client import GitHubPoCClient client = GitHubPoCClient(db) if request.cve_id: # Sync specific CVE if main.job_cancellation_flags.get(job_id, False): logger.info(f"Job {job_id} cancelled before starting") return result = await client.sync_cve_pocs(request.cve_id) logger.info(f"GitHub PoC sync for {request.cve_id}: {result}") else: # Sync all CVEs with cancellation support result = await client.bulk_sync_all_cves(batch_size=request.batch_size) logger.info(f"GitHub PoC bulk sync completed: {result}") # Update job 
            if not main.job_cancellation_flags.get(job_id, False):
                job.status = 'completed'
                job.completed_at = datetime.utcnow()
                db.commit()

        except Exception as e:
            if not main.job_cancellation_flags.get(job_id, False):
                job.status = 'failed'
                job.error_message = str(e)
                job.completed_at = datetime.utcnow()
                db.commit()

            logger.error(f"GitHub PoC sync failed: {e}")
            import traceback
            traceback.print_exc()
        finally:
            # Clean up tracking
            main.running_jobs.pop(job_id, None)
            main.job_cancellation_flags.pop(job_id, None)

    background_tasks.add_task(sync_task)

    return {
        "message": "GitHub PoC sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
        "status": "started",
        "job_id": job_id,
        "cve_id": request.cve_id,
        "batch_size": request.batch_size
    }


@router.post("/sync-exploitdb")
async def sync_exploitdb(background_tasks: BackgroundTasks, request: ExploitDBSyncRequest, db: Session = Depends(get_db)):
    """Synchronize ExploitDB data from git mirror"""
    # Create job record
    job = BulkProcessingJob(
        job_type='exploitdb_sync',
        status='pending',
        job_metadata={
            'cve_id': request.cve_id,
            'batch_size': request.batch_size
        }
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    job_id = str(job.id)
    main.running_jobs[job_id] = job
    main.job_cancellation_flags[job_id] = False

    async def sync_task():
        # Create a new database session for the background task
        task_db = SessionLocal()
        try:
            # Get the job in the new session
            task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
            if not task_job:
                logger.error(f"Job {job_id} not found in task session")
                return

            task_job.status = 'running'
            task_job.started_at = datetime.utcnow()
            task_db.commit()

            from exploitdb_client_local import ExploitDBLocalClient
            client = ExploitDBLocalClient(task_db)

            if request.cve_id:
                # Sync specific CVE
                if main.job_cancellation_flags.get(job_id, False):
                    logger.info(f"Job {job_id} cancelled before starting")
                    return
                result = await client.sync_cve_exploits(request.cve_id)
                logger.info(f"ExploitDB sync for {request.cve_id}: {result}")
            else:
                # Sync all CVEs with cancellation support
                result = await client.bulk_sync_exploitdb(
                    batch_size=request.batch_size,
                    cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False)
                )
                logger.info(f"ExploitDB bulk sync completed: {result}")

            # Update job status if not cancelled
            if not main.job_cancellation_flags.get(job_id, False):
                task_job.status = 'completed'
                task_job.completed_at = datetime.utcnow()
                task_db.commit()

        except Exception as e:
            if not main.job_cancellation_flags.get(job_id, False):
                # Get the job again in case it was modified
                task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
                if task_job:
                    task_job.status = 'failed'
                    task_job.error_message = str(e)
                    task_job.completed_at = datetime.utcnow()
                    task_db.commit()

            logger.error(f"ExploitDB sync failed: {e}")
            import traceback
            traceback.print_exc()
        finally:
            # Clean up tracking and close the task session
            main.running_jobs.pop(job_id, None)
            main.job_cancellation_flags.pop(job_id, None)
            task_db.close()

    background_tasks.add_task(sync_task)

    return {
        "message": "ExploitDB sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
        "status": "started",
        "job_id": job_id,
        "cve_id": request.cve_id,
        "batch_size": request.batch_size
    }


@router.post("/sync-cisa-kev")
async def sync_cisa_kev(background_tasks: BackgroundTasks, request: CISAKEVSyncRequest, db: Session = Depends(get_db)):
    """Synchronize CISA Known Exploited Vulnerabilities data"""
    # Create job record
    job = BulkProcessingJob(
        job_type='cisa_kev_sync',
        status='pending',
        job_metadata={
            'cve_id': request.cve_id,
            'batch_size': request.batch_size
        }
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    job_id = str(job.id)
    main.running_jobs[job_id] = job
    main.job_cancellation_flags[job_id] = False

    async def sync_task():
        # Create a new database session for the background task
        task_db = SessionLocal()
        try:
            # Get the job in the new session
            task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
            if not task_job:
                logger.error(f"Job {job_id} not found in task session")
                return

            task_job.status = 'running'
            task_job.started_at = datetime.utcnow()
            task_db.commit()

            from cisa_kev_client import CISAKEVClient
            client = CISAKEVClient(task_db)

            if request.cve_id:
                # Sync specific CVE
                if main.job_cancellation_flags.get(job_id, False):
                    logger.info(f"Job {job_id} cancelled before starting")
                    return
                result = await client.sync_cve_kev_data(request.cve_id)
                logger.info(f"CISA KEV sync for {request.cve_id}: {result}")
            else:
                # Sync all CVEs with cancellation support
                result = await client.bulk_sync_kev_data(
                    batch_size=request.batch_size,
                    cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False)
                )
                logger.info(f"CISA KEV bulk sync completed: {result}")

            # Update job status if not cancelled
            if not main.job_cancellation_flags.get(job_id, False):
                task_job.status = 'completed'
                task_job.completed_at = datetime.utcnow()
                task_db.commit()

        except Exception as e:
            if not main.job_cancellation_flags.get(job_id, False):
                # Get the job again in case it was modified
                task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
                if task_job:
                    task_job.status = 'failed'
                    task_job.error_message = str(e)
                    task_job.completed_at = datetime.utcnow()
                    task_db.commit()

            logger.error(f"CISA KEV sync failed: {e}")
            import traceback
            traceback.print_exc()
        finally:
            # Clean up tracking and close the task session
            main.running_jobs.pop(job_id, None)
            main.job_cancellation_flags.pop(job_id, None)
            task_db.close()

    background_tasks.add_task(sync_task)

    return {
        "message": "CISA KEV sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
        "status": "started",
        "job_id": job_id,
        "cve_id": request.cve_id,
        "batch_size": request.batch_size
    }


@router.post("/sync-references")
async def sync_references(request: ReferenceSyncRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Start reference data synchronization"""
    try:
        from reference_client import ReferenceClient
        client = ReferenceClient(db)

        # Create job ID
        job_id = str(uuid.uuid4())

        # Add job to tracking
        main.running_jobs[job_id] = {
            'type': 'reference_sync',
            'status': 'running',
            'cve_id': request.cve_id,
            'batch_size': request.batch_size,
            'max_cves': request.max_cves,
            'force_resync': request.force_resync,
            'started_at': datetime.utcnow()
        }

        # Create cancellation flag
        main.job_cancellation_flags[job_id] = False

        async def sync_task():
            try:
                if request.cve_id:
                    # Single CVE sync
                    result = await client.sync_cve_references(request.cve_id)
                    main.running_jobs[job_id]['result'] = result
                    main.running_jobs[job_id]['status'] = 'completed'
                else:
                    # Bulk sync
                    result = await client.bulk_sync_references(
                        batch_size=request.batch_size,
                        max_cves=request.max_cves,
                        force_resync=request.force_resync,
                        cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False)
                    )
                    main.running_jobs[job_id]['result'] = result
                    main.running_jobs[job_id]['status'] = 'completed'

                main.running_jobs[job_id]['completed_at'] = datetime.utcnow()
            except Exception as e:
                logger.error(f"Reference sync task failed: {e}")
                main.running_jobs[job_id]['status'] = 'failed'
                main.running_jobs[job_id]['error'] = str(e)
                main.running_jobs[job_id]['completed_at'] = datetime.utcnow()
            finally:
                # Clean up cancellation flag
                main.job_cancellation_flags.pop(job_id, None)

        background_tasks.add_task(sync_task)

        return {
            "message": "Reference sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
            "status": "started",
            "job_id": job_id,
            "cve_id": request.cve_id,
            "batch_size": request.batch_size,
            "max_cves": request.max_cves,
            "force_resync": request.force_resync
        }

    except Exception as e:
        logger.error(f"Failed to start reference sync: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to start reference sync: {str(e)}")


@router.get("/exploitdb-stats")
async def get_exploitdb_stats(db: Session = Depends(get_db)):
    """Get ExploitDB-related statistics"""
    try:
        from exploitdb_client_local import ExploitDBLocalClient
        client = ExploitDBLocalClient(db)

        # Get sync status
        status = await client.get_exploitdb_sync_status()

        # Get quality distribution from ExploitDB data
        quality_distribution = {}
        cves_with_exploitdb = db.query(CVE).filter(
            text("poc_data::text LIKE '%\"exploitdb\"%'")
        ).all()

        for cve in cves_with_exploitdb:
            if cve.poc_data and 'exploitdb' in cve.poc_data:
                exploits = cve.poc_data['exploitdb'].get('exploits', [])
                for exploit in exploits:
                    quality_tier = exploit.get('quality_analysis', {}).get('quality_tier', 'unknown')
                    quality_distribution[quality_tier] = quality_distribution.get(quality_tier, 0) + 1

        # Get category distribution
        category_distribution = {}
        for cve in cves_with_exploitdb:
            if cve.poc_data and 'exploitdb' in cve.poc_data:
                exploits = cve.poc_data['exploitdb'].get('exploits', [])
                for exploit in exploits:
                    category = exploit.get('category', 'unknown')
                    category_distribution[category] = category_distribution.get(category, 0) + 1

        return {
            "exploitdb_sync_status": status,
            "quality_distribution": quality_distribution,
            "category_distribution": category_distribution,
            "total_exploitdb_cves": len(cves_with_exploitdb),
            "total_exploits": sum(
                len(cve.poc_data.get('exploitdb', {}).get('exploits', []))
                for cve in cves_with_exploitdb
                if cve.poc_data and 'exploitdb' in cve.poc_data
            )
        }

    except Exception as e:
        logger.error(f"Error getting ExploitDB stats: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/github-poc-stats")
async def get_github_poc_stats(db: Session = Depends(get_db)):
    """Get GitHub PoC-related statistics"""
    try:
        # Get basic statistics
        github_poc_rules = db.query(SigmaRule).filter(SigmaRule.poc_source == 'github_poc').count()
        cves_with_github_pocs = db.query(CVE).filter(
            CVE.poc_data.isnot(None),  # Check if poc_data exists
            func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc'
        ).count()

        # Get quality distribution
        quality_distribution = {}
        try:
            quality_results = db.query(
                func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_tier').label('tier'),
                func.count().label('count')
            ).filter(
                CVE.poc_data.isnot(None),
                func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc'
            ).group_by('tier').all()

            for tier, count in quality_results:
                if tier:
                    quality_distribution[tier] = count
        except Exception as e:
            logger.warning(f"Error getting quality distribution: {e}")
            quality_distribution = {}

        # Calculate average quality score
        try:
            from sqlalchemy import Integer
            avg_quality = db.query(
                func.avg(
                    func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_score').cast(Integer)
                )
            ).filter(
                CVE.poc_data.isnot(None),
                func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc'
            ).scalar() or 0
        except Exception as e:
            logger.warning(f"Error calculating average quality: {e}")
            avg_quality = 0

        return {
            'github_poc_rules': github_poc_rules,
            'cves_with_github_pocs': cves_with_github_pocs,
            'quality_distribution': quality_distribution,
            'average_quality_score': float(avg_quality) if avg_quality else 0,
            'source': 'github_poc'
        }

    except Exception as e:
        logger.error(f"Error getting GitHub PoC stats: {e}")
        return {"error": str(e)}


@router.get("/cisa-kev-stats")
async def get_cisa_kev_stats(db: Session = Depends(get_db)):
    """Get CISA KEV-related statistics"""
    try:
        from cisa_kev_client import CISAKEVClient
        client = CISAKEVClient(db)

        # Get sync status
        status = await client.get_kev_sync_status()

        # Get threat level distribution from CISA KEV data
        threat_level_distribution = {}
        cves_with_kev = db.query(CVE).filter(
            text("poc_data::text LIKE '%\"cisa_kev\"%'")
        ).all()

        for cve in cves_with_kev:
            if cve.poc_data and 'cisa_kev' in cve.poc_data:
                vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
                threat_level = vuln_data.get('threat_level', 'unknown')
                threat_level_distribution[threat_level] = threat_level_distribution.get(threat_level, 0) + 1

        # Get vulnerability category distribution
        category_distribution = {}
        for cve in cves_with_kev:
            if cve.poc_data and 'cisa_kev' in cve.poc_data:
                vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
                category = vuln_data.get('vulnerability_category', 'unknown')
                category_distribution[category] = category_distribution.get(category, 0) + 1

        # Get ransomware usage statistics
        ransomware_stats = {'known': 0, 'unknown': 0}
        for cve in cves_with_kev:
            if cve.poc_data and 'cisa_kev' in cve.poc_data:
                vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
                ransomware_use = vuln_data.get('known_ransomware_use', 'Unknown').lower()
                if ransomware_use == 'known':
                    ransomware_stats['known'] += 1
                else:
                    ransomware_stats['unknown'] += 1

        # Calculate average threat score
        threat_scores = []
        for cve in cves_with_kev:
            if cve.poc_data and 'cisa_kev' in cve.poc_data:
                vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
                threat_score = vuln_data.get('threat_score', 0)
                if threat_score:
                    threat_scores.append(threat_score)

        avg_threat_score = sum(threat_scores) / len(threat_scores) if threat_scores else 0

        return {
            "cisa_kev_sync_status": status,
            "threat_level_distribution": threat_level_distribution,
            "category_distribution": category_distribution,
            "ransomware_stats": ransomware_stats,
            "average_threat_score": round(avg_threat_score, 2),
            "total_kev_cves": len(cves_with_kev),
            "total_with_threat_scores": len(threat_scores)
        }

    except Exception as e:
        logger.error(f"Error getting CISA KEV stats: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/cancel-job/{job_id}")
async def cancel_job(job_id: str, db: Session = Depends(get_db)):
    """Cancel a running job"""
    try:
        # Find the job in the database
        job = db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job_id).first()

        if not job:
            raise HTTPException(status_code=404, detail="Job not found")

        if job.status not in ['pending', 'running']:
            raise HTTPException(status_code=400, detail=f"Cannot cancel job with status: {job.status}")

        # Set cancellation flag
        main.job_cancellation_flags[job_id] = True

        # Update job status
        job.status = 'cancelled'
        job.cancelled_at = datetime.utcnow()
        db.commit()

        logger.info(f"Job {job_id} ({job.job_type}) cancelled by user")
cancelled by user") return { "success": True, "message": f"Job {job_id} cancelled successfully", "job_id": job_id, "job_type": job.job_type } except HTTPException: raise except Exception as e: logger.error(f"Error cancelling job {job_id}: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.get("/running-jobs") async def get_running_jobs(db: Session = Depends(get_db)): """Get currently running jobs""" try: # Get running jobs from database running_jobs_db = db.query(BulkProcessingJob).filter( BulkProcessingJob.status.in_(['pending', 'running']) ).order_by(BulkProcessingJob.created_at.desc()).all() result = [] for job in running_jobs_db: result.append({ 'id': str(job.id), 'job_type': job.job_type, 'status': job.status, 'year': job.year, 'total_items': job.total_items, 'processed_items': job.processed_items, 'failed_items': job.failed_items, 'error_message': job.error_message, 'started_at': job.started_at, 'created_at': job.created_at, 'can_cancel': job.status in ['pending', 'running'] }) return result except Exception as e: logger.error(f"Error getting running jobs: {e}") raise HTTPException(status_code=500, detail=str(e))