auto_sigma_rule_generator/backend/nvd_bulk_processor.py

"""
NVD JSON Dataset Bulk Processor
Downloads and processes NVD JSON data feeds for comprehensive CVE seeding
"""
import requests
import json
import gzip
import zipfile
import os
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
import asyncio
import aiohttp
from pathlib import Path
import hashlib
import time

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class NVDBulkProcessor:
    """Handles bulk downloading and processing of NVD JSON data feeds"""

    def __init__(self, db_session: Session, data_dir: str = "./nvd_data"):
        self.db_session = db_session
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(exist_ok=True)
        self.api_key = os.getenv("NVD_API_KEY")

        # NVD JSON 1.1 data feed URLs
        self.base_url = "https://nvd.nist.gov/feeds/json/cve/1.1"
        self.feed_urls = {
            "modified": f"{self.base_url}/nvdcve-1.1-modified.json.gz",
            "recent": f"{self.base_url}/nvdcve-1.1-recent.json.gz"
        }

        # Rate limiting
        self.rate_limit_delay = 0.6  # 600 ms between requests
        self.last_request_time = 0

    def get_year_feed_url(self, year: int) -> str:
        """Get the URL for a specific year's CVE feed"""
        return f"{self.base_url}/nvdcve-1.1-{year}.json.gz"

    def get_meta_url(self, feed_url: str) -> str:
        """Get the metadata URL for a feed"""
        return feed_url.replace(".json.gz", ".meta")

    async def download_file(self, session: aiohttp.ClientSession, url: str,
                            destination: Path, check_meta: bool = True) -> bool:
        """Download a file with metadata checking"""
        try:
            # Check if we should download based on metadata
            if check_meta:
                meta_url = self.get_meta_url(url)
                should_download = await self._should_download_file(session, meta_url, destination)
                if not should_download:
                    logger.info(f"Skipping {url} - file is up to date")
                    return True

            # Rate limiting
            current_time = time.time()
            time_since_last = current_time - self.last_request_time
            if time_since_last < self.rate_limit_delay:
                await asyncio.sleep(self.rate_limit_delay - time_since_last)

            # Download the file
            headers = {}
            if self.api_key:
                headers["apiKey"] = self.api_key

            async with session.get(url, headers=headers,
                                   timeout=aiohttp.ClientTimeout(total=30)) as response:
                if response.status == 200:
                    content = await response.read()
                    destination.write_bytes(content)
                    logger.info(f"Downloaded {url} -> {destination}")
                    self.last_request_time = time.time()
                    return True
                else:
                    logger.error(f"Failed to download {url}: HTTP {response.status}")
                    return False
        except Exception as e:
            logger.error(f"Error downloading {url}: {e}")
            return False

    async def _should_download_file(self, session: aiohttp.ClientSession,
                                    meta_url: str, destination: Path) -> bool:
        """Check if the file should be downloaded based on its .meta metadata"""
        try:
            # Download metadata
            async with session.get(meta_url, timeout=aiohttp.ClientTimeout(total=10)) as response:
                if response.status != 200:
                    return True  # Download if we can't get metadata
                meta_content = await response.text()

            # Parse metadata (colon-separated key/value pairs)
            meta_data = {}
            for line in meta_content.strip().split('\n'):
                if ':' in line:
                    key, value = line.split(':', 1)
                    meta_data[key.strip()] = value.strip()

            # Check if the local file exists and matches the advertised metadata.
            # The .meta file lists 'gzSize' for the compressed artifact and a
            # 'sha256' of the uncompressed JSON, so compare against those.
            if destination.exists():
                local_size = destination.stat().st_size
                remote_size = int(meta_data.get('gzSize', meta_data.get('size', 0)))
                remote_sha256 = meta_data.get('sha256', '').lower()
                if local_size == remote_size and remote_sha256:
                    # Verify SHA256 of the decompressed content if available
                    local_sha256 = self._calculate_sha256(destination, decompress_gzip=True)
                    if local_sha256 == remote_sha256:
                        return False  # File is up to date
            return True  # Download needed
        except Exception as e:
            logger.warning(f"Error checking metadata for {meta_url}: {e}")
            return True  # Download if metadata check fails
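
    # For reference, a .meta companion file is a small plain-text document of
    # colon-separated pairs roughly along these lines (values are placeholders):
    #
    #   lastModifiedDate:2024-01-01T03:00:01-05:00
    #   size:12345678
    #   zipSize:1234567
    #   gzSize:1234567
    #   sha256:ABCDEF0123456789...
    #
    # The published sha256 covers the uncompressed JSON, which is why
    # _should_download_file() hashes the decompressed .gz content above.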

    def _calculate_sha256(self, file_path: Path, decompress_gzip: bool = False) -> str:
        """Calculate the SHA256 hash of a file (optionally of its gzip-decompressed content)"""
        sha256_hash = hashlib.sha256()
        opener = gzip.open if decompress_gzip else open
        with opener(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                sha256_hash.update(chunk)
        return sha256_hash.hexdigest()

    async def download_all_feeds(self, start_year: int = 2002,
                                 end_year: Optional[int] = None) -> List[Path]:
        """Download all NVD JSON feeds"""
        if end_year is None:
            end_year = datetime.now().year

        downloaded_files = []
        async with aiohttp.ClientSession() as session:
            # Download year-based feeds
            for year in range(start_year, end_year + 1):
                url = self.get_year_feed_url(year)
                filename = f"nvdcve-1.1-{year}.json.gz"
                destination = self.data_dir / filename
                if await self.download_file(session, url, destination):
                    downloaded_files.append(destination)

            # Download modified and recent feeds
            for feed_name, url in self.feed_urls.items():
                filename = f"nvdcve-1.1-{feed_name}.json.gz"
                destination = self.data_dir / filename
                if await self.download_file(session, url, destination):
                    downloaded_files.append(destination)
        return downloaded_files

    def extract_json_file(self, compressed_file: Path) -> Path:
        """Extract JSON from compressed file"""
        # Strip the trailing compression suffix, e.g. nvdcve-1.1-2023.json.gz -> nvdcve-1.1-2023.json
        json_file = compressed_file.with_suffix('')
        try:
            if compressed_file.suffix == '.gz':
                with gzip.open(compressed_file, 'rt', encoding='utf-8') as f_in:
                    with open(json_file, 'w', encoding='utf-8') as f_out:
                        f_out.write(f_in.read())
            elif compressed_file.suffix == '.zip':
                with zipfile.ZipFile(compressed_file, 'r') as zip_ref:
                    zip_ref.extractall(self.data_dir)
            else:
                # File is already uncompressed
                return compressed_file
            logger.info(f"Extracted {compressed_file} -> {json_file}")
            return json_file
        except Exception as e:
            logger.error(f"Error extracting {compressed_file}: {e}")
            raise

    def process_json_file(self, json_file: Path) -> Tuple[int, int]:
        """Process a single JSON file and return (processed, failed) counts"""
        from main import CVE, BulkProcessingJob

        processed_count = 0
        failed_count = 0
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            cve_items = data.get('CVE_Items', [])
            logger.info(f"Processing {len(cve_items)} CVEs from {json_file}")

            for cve_item in cve_items:
                try:
                    cve_data = self._extract_cve_data(cve_item)
                    if cve_data:
                        self._store_cve_data(cve_data)
                        processed_count += 1
                    else:
                        failed_count += 1
                except Exception as e:
                    logger.error(f"Error processing CVE item: {e}")
                    failed_count += 1

            # Commit changes
            self.db_session.commit()
            logger.info(f"Processed {processed_count} CVEs, failed: {failed_count}")
        except Exception as e:
            logger.error(f"Error processing {json_file}: {e}")
            self.db_session.rollback()
            raise
        return processed_count, failed_count

    def _extract_cve_data(self, cve_item: dict) -> Optional[dict]:
        """Extract CVE data from JSON item"""
        try:
            cve = cve_item.get('cve', {})
            impact = cve_item.get('impact', {})

            cve_id = cve.get('CVE_data_meta', {}).get('ID', '')
            if not cve_id:
                return None

            # Description
            description_data = cve.get('description', {}).get('description_data', [])
            description = ''
            if description_data:
                description = description_data[0].get('value', '')

            # CVSS score (prefer v3 metrics, fall back to v2)
            cvss_score = None
            severity = None
            if 'baseMetricV3' in impact:
                cvss_v3 = impact['baseMetricV3'].get('cvssV3', {})
                cvss_score = cvss_v3.get('baseScore')
                severity = cvss_v3.get('baseSeverity', '').lower()
            elif 'baseMetricV2' in impact:
                cvss_v2 = impact['baseMetricV2'].get('cvssV2', {})
                cvss_score = cvss_v2.get('baseScore')
                severity = impact['baseMetricV2'].get('severity', '').lower()

            # Dates
            published_date = None
            modified_date = None
            if 'publishedDate' in cve_item:
                published_date = datetime.fromisoformat(
                    cve_item['publishedDate'].replace('Z', '+00:00')
                )
            if 'lastModifiedDate' in cve_item:
                modified_date = datetime.fromisoformat(
                    cve_item['lastModifiedDate'].replace('Z', '+00:00')
                )

            # Affected products (from CPE data; top-level nodes only,
            # nested 'children' nodes are not traversed)
            affected_products = []
            configurations = cve_item.get('configurations', {})
            for node in configurations.get('nodes', []):
                for cpe_match in node.get('cpe_match', []):
                    if cpe_match.get('vulnerable', False):
                        cpe_uri = cpe_match.get('cpe23Uri', '')
                        if cpe_uri:
                            affected_products.append(cpe_uri)

            # Reference URLs
            reference_urls = []
            references = cve.get('references', {}).get('reference_data', [])
            for ref in references:
                url = ref.get('url', '')
                if url:
                    reference_urls.append(url)

            return {
                'cve_id': cve_id,
                'description': description,
                'cvss_score': cvss_score,
                'severity': severity,
                'published_date': published_date,
                'modified_date': modified_date,
                'affected_products': affected_products,
                'reference_urls': reference_urls,
                'data_source': 'nvd_bulk',
                'nvd_json_version': '1.1',
                'bulk_processed': True
            }
        except Exception as e:
            logger.error(f"Error extracting CVE data: {e}")
            return None
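
    # The parser above expects each entry of CVE_Items to look roughly like
    # this (abridged to the fields that are actually read):
    #
    #   {
    #     "cve": {
    #       "CVE_data_meta": {"ID": "CVE-YYYY-NNNN"},
    #       "description": {"description_data": [{"value": "..."}]},
    #       "references": {"reference_data": [{"url": "..."}]}
    #     },
    #     "impact": {"baseMetricV3": {"cvssV3": {"baseScore": 9.8, "baseSeverity": "CRITICAL"}}},
    #     "configurations": {"nodes": [{"cpe_match": [{"vulnerable": true, "cpe23Uri": "cpe:2.3:..."}]}]},
    #     "publishedDate": "YYYY-MM-DDTHH:MMZ",
    #     "lastModifiedDate": "YYYY-MM-DDTHH:MMZ"
    #   }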

    def _store_cve_data(self, cve_data: dict):
        """Store CVE data in database"""
        from main import CVE

        # Check if CVE already exists
        existing_cve = self.db_session.query(CVE).filter(
            CVE.cve_id == cve_data['cve_id']
        ).first()

        if existing_cve:
            # Update existing CVE
            for key, value in cve_data.items():
                setattr(existing_cve, key, value)
            existing_cve.updated_at = datetime.utcnow()
            logger.debug(f"Updated CVE {cve_data['cve_id']}")
        else:
            # Create new CVE
            new_cve = CVE(**cve_data)
            self.db_session.add(new_cve)
            logger.debug(f"Created new CVE {cve_data['cve_id']}")

    async def bulk_seed_database(self, start_year: int = 2002,
                                 end_year: Optional[int] = None) -> dict:
        """Perform complete bulk seeding of the database"""
        from main import BulkProcessingJob

        if end_year is None:
            end_year = datetime.now().year

        # Create bulk processing job
        job = BulkProcessingJob(
            job_type='nvd_bulk_seed',
            status='running',
            started_at=datetime.utcnow(),
            job_metadata={
                'start_year': start_year,
                'end_year': end_year,
                'total_years': end_year - start_year + 1
            }
        )
        self.db_session.add(job)
        self.db_session.commit()

        total_processed = 0
        total_failed = 0
        results = []
        try:
            # Download all feeds
            logger.info(f"Starting bulk seed from {start_year} to {end_year}")
            downloaded_files = await self.download_all_feeds(start_year, end_year)
            job.total_items = len(downloaded_files)
            self.db_session.commit()

            # Process each file
            for file_path in downloaded_files:
                try:
                    # Extract JSON file
                    json_file = self.extract_json_file(file_path)

                    # Process the JSON file
                    processed, failed = self.process_json_file(json_file)
                    total_processed += processed
                    total_failed += failed
                    job.processed_items += 1
                    results.append({
                        'file': file_path.name,
                        'processed': processed,
                        'failed': failed
                    })

                    # Clean up extracted file if it's different from original
                    if json_file != file_path:
                        json_file.unlink()
                    self.db_session.commit()
                except Exception as e:
                    logger.error(f"Error processing {file_path}: {e}")
                    job.failed_items += 1
                    total_failed += 1
                    self.db_session.commit()

            # Update job status
            job.status = 'completed'
            job.completed_at = datetime.utcnow()
            job.job_metadata.update({
                'total_processed': total_processed,
                'total_failed': total_failed,
                'results': results
            })
        except Exception as e:
            job.status = 'failed'
            job.error_message = str(e)
            job.completed_at = datetime.utcnow()
            logger.error(f"Bulk seed job failed: {e}")
        finally:
            self.db_session.commit()

        return {
            'job_id': str(job.id),
            'status': job.status,
            'total_processed': total_processed,
            'total_failed': total_failed,
            'results': results
        }

    async def incremental_update(self) -> dict:
        """Perform incremental update using modified and recent feeds"""
        from main import BulkProcessingJob

        # Create incremental update job
        job = BulkProcessingJob(
            job_type='incremental_update',
            status='running',
            started_at=datetime.utcnow(),
            job_metadata={'feeds': ['modified', 'recent']}
        )
        self.db_session.add(job)
        self.db_session.commit()

        total_processed = 0
        total_failed = 0
        results = []
        try:
            # Download modified and recent feeds
            async with aiohttp.ClientSession() as session:
                for feed_name, url in self.feed_urls.items():
                    filename = f"nvdcve-1.1-{feed_name}.json.gz"
                    destination = self.data_dir / filename

                    if await self.download_file(session, url, destination):
                        try:
                            json_file = self.extract_json_file(destination)
                            processed, failed = self.process_json_file(json_file)
                            total_processed += processed
                            total_failed += failed
                            results.append({
                                'feed': feed_name,
                                'processed': processed,
                                'failed': failed
                            })

                            # Clean up
                            if json_file != destination:
                                json_file.unlink()
                        except Exception as e:
                            logger.error(f"Error processing {feed_name} feed: {e}")
                            total_failed += 1

            job.status = 'completed'
            job.completed_at = datetime.utcnow()
            job.job_metadata.update({
                'total_processed': total_processed,
                'total_failed': total_failed,
                'results': results
            })
        except Exception as e:
            job.status = 'failed'
            job.error_message = str(e)
            job.completed_at = datetime.utcnow()
            logger.error(f"Incremental update job failed: {e}")
        finally:
            self.db_session.commit()

        return {
            'job_id': str(job.id),
            'status': job.status,
            'total_processed': total_processed,
            'total_failed': total_failed,
            'results': results
        }
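

# A minimal usage sketch (not part of the processor itself). It assumes the
# application's `main` module exposes a SQLAlchemy session factory named
# `SessionLocal`; adjust the import to match how this project actually
# creates database sessions.
if __name__ == "__main__":
    from main import SessionLocal  # assumed session factory name

    async def _run() -> None:
        session = SessionLocal()
        try:
            processor = NVDBulkProcessor(session)
            # Seed a recent slice first; widen the year range as needed.
            summary = await processor.bulk_seed_database(start_year=2023)
            logger.info(f"Bulk seed finished: {summary}")
            # The modified/recent feeds then keep the data current.
            update_summary = await processor.incremental_update()
            logger.info(f"Incremental update finished: {update_summary}")
        finally:
            session.close()

    asyncio.run(_run())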