diff --git a/backend/llm_client.py b/backend/llm_client.py
index c892e26..4b2c9a9 100644
--- a/backend/llm_client.py
+++ b/backend/llm_client.py
@@ -94,6 +94,15 @@ class LLMClient:
         elif self.provider == 'ollama':
             base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
 
+            # Check if model is available, if not try to pull it
+            if not self._check_ollama_model_available(base_url, self.model):
+                logger.info(f"Model {self.model} not found, attempting to pull...")
+                if self._pull_ollama_model(base_url, self.model):
+                    logger.info(f"Successfully pulled model {self.model}")
+                else:
+                    logger.error(f"Failed to pull model {self.model}")
+                    return
+
             self.llm = Ollama(
                 model=self.model,
                 base_url=base_url,
@@ -116,11 +125,20 @@ class LLMClient:
     def get_provider_info(self) -> Dict[str, Any]:
         """Get information about the current provider and configuration."""
         provider_info = self.SUPPORTED_PROVIDERS.get(self.provider, {})
+
+        # For Ollama, get actually available models
+        available_models = provider_info.get('models', [])
+        if self.provider == 'ollama':
+            ollama_models = self._get_ollama_available_models()
+            if ollama_models:
+                available_models = ollama_models
+
         return {
             'provider': self.provider,
             'model': self.model,
             'available': self.is_available(),
             'supported_models': provider_info.get('models', []),
+            'available_models': available_models,
             'env_key': provider_info.get('env_key', ''),
             'api_key_configured': bool(os.getenv(provider_info.get('env_key', '')))
         }
@@ -395,4 +413,72 @@ Output ONLY the enhanced SIGMA rule in valid YAML format."""
 
         self.model = model or self._get_default_model(provider)
         self._initialize_llm()
-        logger.info(f"Switched to provider: {provider} with model: {self.model}")
\ No newline at end of file
+        logger.info(f"Switched to provider: {provider} with model: {self.model}")
+
+    def _check_ollama_model_available(self, base_url: str, model: str) -> bool:
+        """Check if an Ollama model is available locally"""
+        try:
+            import requests
+            response = requests.get(f"{base_url}/api/tags", timeout=10)
+            if response.status_code == 200:
+                data = response.json()
+                models = data.get('models', [])
+                for m in models:
+                    if m.get('name', '').startswith(model + ':') or m.get('name') == model:
+                        return True
+            return False
+        except Exception as e:
+            logger.error(f"Error checking Ollama models: {e}")
+            return False
+
+    def _pull_ollama_model(self, base_url: str, model: str) -> bool:
+        """Pull an Ollama model"""
+        try:
+            import requests
+            import json
+
+            # Use the pull API endpoint
+            payload = {"name": model}
+            response = requests.post(
+                f"{base_url}/api/pull",
+                json=payload,
+                timeout=300,  # 5 minutes timeout for model download
+                stream=True
+            )
+
+            if response.status_code == 200:
+                # Stream the response to monitor progress
+                for line in response.iter_lines():
+                    if line:
+                        try:
+                            data = json.loads(line.decode('utf-8'))
+                            if data.get('status'):
+                                logger.info(f"Ollama pull progress: {data.get('status')}")
+                            if data.get('error'):
+                                logger.error(f"Ollama pull error: {data.get('error')}")
+                                return False
+                        except json.JSONDecodeError:
+                            continue
+                return True
+            else:
+                logger.error(f"Failed to pull model {model}: HTTP {response.status_code}")
+                return False
+
+        except Exception as e:
+            logger.error(f"Error pulling Ollama model {model}: {e}")
+            return False
+
+    def _get_ollama_available_models(self) -> List[str]:
+        """Get list of available Ollama models"""
+        try:
+            import requests
+            base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
+            response = requests.get(f"{base_url}/api/tags", timeout=10)
+            if response.status_code == 200:
+                data = response.json()
+                models = data.get('models', [])
+                return [m.get('name', '') for m in models if m.get('name')]
+            return []
+        except Exception as e:
+            logger.error(f"Error getting Ollama models: {e}")
+            return []
\ No newline at end of file
diff --git a/backend/main.py b/backend/main.py
index f59bc3a..706497a 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -1461,6 +1461,7 @@ async def sync_references(request: ReferenceSyncRequest, background_tasks: Backg
             result = await client.bulk_sync_references(
                 batch_size=request.batch_size,
                 max_cves=request.max_cves,
+                force_resync=request.force_resync,
                 cancellation_flag=lambda: job_cancellation_flags.get(job_id, False)
             )
             running_jobs[job_id]['result'] = result
@@ -2088,6 +2089,60 @@ async def get_running_jobs(db: Session = Depends(get_db)):
         logger.error(f"Error getting running jobs: {e}")
         raise HTTPException(status_code=500, detail=str(e))
 
+@app.post("/api/ollama-pull-model")
+async def pull_ollama_model(request: dict, background_tasks: BackgroundTasks):
+    """Pull an Ollama model"""
+    try:
+        from llm_client import LLMClient
+
+        model = request.get('model')
+        if not model:
+            raise HTTPException(status_code=400, detail="Model name is required")
+
+        # Create a background task to pull the model
+        def pull_model_task():
+            try:
+                client = LLMClient(provider='ollama', model=model)
+                base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
+
+                if client._pull_ollama_model(base_url, model):
+                    logger.info(f"Successfully pulled Ollama model: {model}")
+                else:
+                    logger.error(f"Failed to pull Ollama model: {model}")
+            except Exception as e:
+                logger.error(f"Error in model pull task: {e}")
+
+        background_tasks.add_task(pull_model_task)
+
+        return {
+            "message": f"Started pulling model {model}",
+            "status": "started",
+            "model": model
+        }
+
+    except Exception as e:
+        logger.error(f"Error starting model pull: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.get("/api/ollama-models")
+async def get_ollama_models():
+    """Get available Ollama models"""
+    try:
+        from llm_client import LLMClient
+
+        client = LLMClient(provider='ollama')
+        available_models = client._get_ollama_available_models()
+
+        return {
+            "available_models": available_models,
+            "total_models": len(available_models),
+            "status": "success"
+        }
+
+    except Exception as e:
+        logger.error(f"Error getting Ollama models: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=8000)
diff --git a/backend/reference_client.py b/backend/reference_client.py
new file mode 100644
index 0000000..b43ecd8
--- /dev/null
+++ b/backend/reference_client.py
@@ -0,0 +1,603 @@
+"""
+Reference Data Extraction Client
+Extracts and analyzes text content from CVE references and KEV records
+"""
+
+import aiohttp
+import asyncio
+import json
+import logging
+import re
+from datetime import datetime, timedelta
+from typing import Dict, List, Optional, Tuple, Any
+from urllib.parse import urlparse, urljoin
+from sqlalchemy.orm import Session
+from sqlalchemy import text, func
+import hashlib
+from bs4 import BeautifulSoup
+import ssl
+import certifi
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class ReferenceClient:
+    """Client for extracting and analyzing reference content from CVE and KEV records"""
+
+    def __init__(self, db_session: Session):
+        self.db_session = db_session
+
+        # Rate limiting
+        self.rate_limit_delay = 2.0  # 2 seconds between requests
+        self.last_request_time = 0
+
+        # Cache for processed URLs
+        self.url_cache = {}
+        self.cache_ttl = 86400  # 24 hours cache
+
+        # SSL context for outbound requests, based on the certifi CA bundle
+        self.ssl_context = ssl.create_default_context(cafile=certifi.where())
+        # Certificate verification is disabled so references behind self-signed or broken certificates can still be fetched
+        self.ssl_context.check_hostname = False
+        self.ssl_context.verify_mode = ssl.CERT_NONE
+
+        # Common headers to avoid being blocked
+        self.headers = {
+            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
+            'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
+            'Accept-Language': 'en-US,en;q=0.5',
+            'Accept-Encoding': 'gzip, deflate, br',
+            'Connection': 'keep-alive',
+            'Upgrade-Insecure-Requests': '1',
+            'Sec-Fetch-Dest': 'document',
+            'Sec-Fetch-Mode': 'navigate',
+            'Sec-Fetch-Site': 'none'
+        }
+
+        # Supported reference types
+        self.reference_types = {
+            'security_advisory': ['security', 'advisory', 'bulletin', 'alert', 'cve', 'vulnerability'],
+            'patch': ['patch', 'fix', 'update', 'hotfix', 'security-update'],
+            'exploit': ['exploit', 'poc', 'proof-of-concept', 'github.com', 'exploit-db'],
+            'technical_analysis': ['analysis', 'research', 'technical', 'writeup', 'blog'],
+            'vendor_advisory': ['microsoft', 'apple', 'oracle', 'cisco', 'vmware', 'adobe'],
+            'cve_database': ['cve.mitre.org', 'nvd.nist.gov', 'cve.org']
+        }
+
+    async def _make_request(self, session: aiohttp.ClientSession, url: str) -> Optional[Tuple[str, str]]:
+        """Make a rate-limited request to fetch URL content"""
+        try:
+            # Rate limiting
+            current_time = asyncio.get_event_loop().time()
+            time_since_last = current_time - self.last_request_time
+            if time_since_last < self.rate_limit_delay:
+                await asyncio.sleep(self.rate_limit_delay - time_since_last)
+
+            # Check cache first
+            url_hash = hashlib.md5(url.encode()).hexdigest()
+            if url_hash in self.url_cache:
+                cache_entry = self.url_cache[url_hash]
+                if datetime.now().timestamp() - cache_entry['timestamp'] < self.cache_ttl:
+                    logger.info(f"Using cached content for {url}")
+                    return cache_entry['content'], cache_entry['content_type']
+
+            async with session.get(url, headers=self.headers) as response:
+                self.last_request_time = asyncio.get_event_loop().time()
+
+                if response.status == 200:
+                    content_type = response.headers.get('content-type', '').lower()
+
+                    # Only process text content
+                    if 'text/html' in content_type or 'text/plain' in content_type or 'application/json' in content_type:
+                        try:
+                            content = await response.text(encoding='utf-8', errors='ignore')
+                        except UnicodeDecodeError:
+                            # Fallback to binary content if text decode fails
+                            content_bytes = await response.read()
+                            content = content_bytes.decode('utf-8', errors='ignore')
+
+                        # Cache the result
+                        self.url_cache[url_hash] = {
+                            'content': content,
+                            'content_type': content_type,
+                            'timestamp': datetime.now().timestamp()
+                        }
+
+                        return content, content_type
+                    else:
+                        logger.warning(f"Unsupported content type {content_type} for {url}")
+                        return None, None
+                elif response.status in [301, 302, 303, 307, 308]:
+                    logger.info(f"Redirect response {response.status} for {url}")
+                    return None, None
+                else:
+                    logger.warning(f"Request failed: {response.status} for {url}")
+                    return None, None
+
+        except aiohttp.ClientError as e:
+            logger.warning(f"Client error fetching {url}: {e}")
+            return None, None
+        except asyncio.TimeoutError:
+            logger.warning(f"Timeout fetching {url}")
fetching {url}") + return None, None + except Exception as e: + logger.error(f"Unexpected error fetching {url}: {e}") + return None, None + + def _extract_text_from_html(self, html_content: str) -> str: + """Extract meaningful text from HTML content""" + try: + soup = BeautifulSoup(html_content, 'html.parser') + + # Remove script and style elements + for script in soup(["script", "style"]): + script.decompose() + + # Extract text from common content areas + content_selectors = [ + 'article', 'main', '.content', '#content', + '.post-content', '.entry-content', 'section' + ] + + text_content = "" + for selector in content_selectors: + elements = soup.select(selector) + if elements: + for element in elements: + text_content += element.get_text(separator=' ', strip=True) + '\n' + break + + # If no structured content found, get all text + if not text_content.strip(): + text_content = soup.get_text(separator=' ', strip=True) + + # Clean up the text + text_content = re.sub(r'\s+', ' ', text_content) + text_content = text_content.strip() + + return text_content + + except Exception as e: + logger.error(f"Error extracting text from HTML: {e}") + return "" + + def _analyze_reference_content(self, url: str, content: str) -> Dict[str, Any]: + """Analyze reference content to extract security-relevant information""" + analysis = { + 'url': url, + 'content_length': len(content), + 'reference_type': 'unknown', + 'security_keywords': [], + 'technical_indicators': [], + 'patch_information': [], + 'exploit_indicators': [], + 'cve_mentions': [], + 'severity_indicators': [], + 'mitigation_steps': [], + 'affected_products': [], + 'attack_vectors': [], + 'confidence_score': 0 + } + + if not content: + return analysis + + content_lower = content.lower() + + # Classify reference type + domain = urlparse(url).netloc.lower() + for ref_type, keywords in self.reference_types.items(): + if any(keyword in domain or keyword in content_lower for keyword in keywords): + analysis['reference_type'] = ref_type + break + + # Extract CVE mentions + cve_pattern = r'(CVE-\d{4}-\d{4,7})' + cve_matches = re.findall(cve_pattern, content, re.IGNORECASE) + analysis['cve_mentions'] = list(set(cve_matches)) + + # Security keywords + security_keywords = [ + 'vulnerability', 'exploit', 'attack', 'malware', 'backdoor', + 'privilege escalation', 'remote code execution', 'rce', + 'sql injection', 'xss', 'csrf', 'buffer overflow', + 'authentication bypass', 'authorization', 'injection', + 'denial of service', 'dos', 'ddos', 'ransomware' + ] + + for keyword in security_keywords: + if keyword in content_lower: + analysis['security_keywords'].append(keyword) + + # Technical indicators + technical_patterns = [ + r'\b(function|method|class|variable)\s+\w+', + r'\b(file|directory|path|folder)\s+[^\s]+', + r'\b(port|service|protocol)\s+\d+', + r'\b(registry|key|value)\s+[^\s]+', + r'\b(process|executable|binary)\s+[^\s]+', + r'\b(dll|exe|bat|ps1|sh|py|jar)\b', + r'\b(http|https|ftp|smb|tcp|udp)://[^\s]+', + r'\b(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})\b' + ] + + for pattern in technical_patterns: + matches = re.findall(pattern, content, re.IGNORECASE) + if matches: + analysis['technical_indicators'].extend(matches[:10]) # Limit to 10 per pattern + + # Patch information + patch_keywords = [ + 'patch', 'fix', 'update', 'hotfix', 'security update', + 'kb\d+', 'ms\d+-\d+', 'version \d+\.\d+', + 'download', 'install', 'upgrade' + ] + + for keyword in patch_keywords: + if re.search(keyword, content_lower): + analysis['patch_information'].append(keyword) + + # 
+        exploit_keywords = [
+            'proof of concept', 'poc', 'exploit code', 'payload',
+            'shellcode', 'reverse shell', 'metasploit', 'nmap',
+            'vulnerability assessment', 'penetration test', 'bypass'
+        ]
+
+        for keyword in exploit_keywords:
+            if keyword in content_lower:
+                analysis['exploit_indicators'].append(keyword)
+
+        # Severity indicators
+        severity_patterns = [
+            r'\b(critical|high|medium|low)\s+(severity|risk|priority)',
+            r'\b(cvss|score)\s*[:=]?\s*(\d+\.\d+|\d+)',
+            r'\b(exploitability|impact)\s*[:=]?\s*(high|medium|low)'
+        ]
+
+        for pattern in severity_patterns:
+            matches = re.findall(pattern, content, re.IGNORECASE)
+            if matches:
+                analysis['severity_indicators'].extend([' '.join(match) if isinstance(match, tuple) else match for match in matches])
+
+        # Mitigation steps
+        mitigation_keywords = [
+            'mitigation', 'workaround', 'prevention', 'remediation',
+            'disable', 'block', 'restrict', 'configure', 'setting'
+        ]
+
+        # Find sentences containing mitigation keywords
+        sentences = re.split(r'[.!?]+', content)
+        for sentence in sentences:
+            if any(keyword in sentence.lower() for keyword in mitigation_keywords):
+                if len(sentence.strip()) > 20:  # Avoid very short sentences
+                    analysis['mitigation_steps'].append(sentence.strip()[:200])  # Limit length
+
+        # Calculate confidence score
+        score = 0
+        score += min(len(analysis['security_keywords']) * 5, 25)
+        score += min(len(analysis['technical_indicators']) * 2, 20)
+        score += min(len(analysis['cve_mentions']) * 10, 30)
+        score += min(len(analysis['patch_information']) * 3, 15)
+        score += min(len(analysis['exploit_indicators']) * 4, 20)
+
+        if analysis['reference_type'] != 'unknown':
+            score += 10
+
+        analysis['confidence_score'] = min(score, 100)
+
+        # Clean up and deduplicate
+        for key in ['security_keywords', 'technical_indicators', 'patch_information',
+                    'exploit_indicators', 'severity_indicators', 'mitigation_steps']:
+            analysis[key] = list(set(analysis[key]))[:10]  # Limit to 10 items each
+
+        return analysis
+
+    async def extract_reference_content(self, url: str) -> Optional[Dict[str, Any]]:
+        """Extract and analyze content from a single reference URL"""
+        try:
+            connector = aiohttp.TCPConnector(
+                ssl=self.ssl_context,
+                limit=100,
+                limit_per_host=30,
+                ttl_dns_cache=300,
+                use_dns_cache=True,
+            )
+            timeout = aiohttp.ClientTimeout(total=60, connect=30)
+            async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
+                content, content_type = await self._make_request(session, url)
+
+                if not content:
+                    return None
+
+                # Extract text from HTML if needed
+                if content_type and 'text/html' in content_type:
+                    text_content = self._extract_text_from_html(content)
+                else:
+                    text_content = content
+
+                # Analyze the content
+                analysis = self._analyze_reference_content(url, text_content)
+
+                # Add metadata
+                analysis.update({
+                    'extracted_at': datetime.utcnow().isoformat(),
+                    'content_type': content_type,
+                    'text_length': len(text_content),
+                    'source': 'reference_extraction'
+                })
+
+                return analysis
+
+        except Exception as e:
+            logger.error(f"Error extracting reference content from {url}: {e}")
+            return None
+
+    async def sync_cve_references(self, cve_id: str) -> Dict[str, Any]:
+        """Sync reference data for a specific CVE"""
+        from main import CVE, SigmaRule
+
+        # Get existing CVE
+        cve = self.db_session.query(CVE).filter(CVE.cve_id == cve_id).first()
+        if not cve:
+            logger.warning(f"CVE {cve_id} not found in database")
+            return {"error": "CVE not found"}
+
+        if not cve.reference_urls:
+            logger.info(f"No reference URLs found for CVE {cve_id}")
+            return {"cve_id": cve_id, "references_processed": 0}
+
+        logger.info(f"Processing {len(cve.reference_urls)} references for CVE {cve_id}")
+
+        processed_references = []
+        successful_extractions = 0
+
+        for url in cve.reference_urls:
+            try:
+                # Extract reference content
+                ref_analysis = await self.extract_reference_content(url)
+
+                if ref_analysis:
+                    processed_references.append(ref_analysis)
+                    successful_extractions += 1
+                    logger.info(f"Successfully extracted content from {url}")
+                else:
+                    logger.warning(f"Failed to extract content from {url}")
+
+                # Small delay between requests
+                await asyncio.sleep(1)
+
+            except Exception as e:
+                logger.error(f"Error processing reference {url}: {e}")
+
+        # Update CVE with reference data
+        cve.reference_data = {
+            'reference_analysis': {
+                'references': processed_references,
+                'total_references': len(cve.reference_urls),
+                'successful_extractions': successful_extractions,
+                'extraction_rate': successful_extractions / len(cve.reference_urls) if cve.reference_urls else 0,
+                'extracted_at': datetime.utcnow().isoformat(),
+                'source': 'reference_extraction'
+            }
+        }
+
+        cve.reference_sync_status = 'completed' if successful_extractions > 0 else 'failed'
+        cve.reference_last_synced = datetime.utcnow()
+
+        cve.updated_at = datetime.utcnow()
+
+        # Update SIGMA rule with reference data
+        sigma_rule = self.db_session.query(SigmaRule).filter(
+            SigmaRule.cve_id == cve_id
+        ).first()
+
+        if sigma_rule:
+            # Aggregate indicators from all references
+            aggregated_indicators = {
+                'security_keywords': [],
+                'technical_indicators': [],
+                'exploit_indicators': [],
+                'patch_information': [],
+                'attack_vectors': [],
+                'mitigation_steps': []
+            }
+
+            for ref in processed_references:
+                for key in aggregated_indicators.keys():
+                    if key in ref:
+                        aggregated_indicators[key].extend(ref[key])
+
+            # Deduplicate
+            for key in aggregated_indicators:
+                aggregated_indicators[key] = list(set(aggregated_indicators[key]))
+
+            # Update rule with reference data
+            if not sigma_rule.nomi_sec_data:
+                sigma_rule.nomi_sec_data = {}
+
+            sigma_rule.nomi_sec_data['reference_analysis'] = {
+                'aggregated_indicators': aggregated_indicators,
+                'reference_count': len(processed_references),
+                'high_confidence_references': len([r for r in processed_references if r.get('confidence_score', 0) > 70]),
+                'reference_types': list(set([r.get('reference_type') for r in processed_references if r.get('reference_type') != 'unknown'])),
+                'source': 'reference_extraction'
+            }
+
+            # Update exploit indicators
+            if sigma_rule.exploit_indicators:
+                existing_indicators = json.loads(sigma_rule.exploit_indicators)
+            else:
+                existing_indicators = {}
+
+            for key, values in aggregated_indicators.items():
+                if key not in existing_indicators:
+                    existing_indicators[key] = []
+                existing_indicators[key].extend(values)
+                existing_indicators[key] = list(set(existing_indicators[key]))
+
+            sigma_rule.exploit_indicators = json.dumps(existing_indicators)
+            sigma_rule.updated_at = datetime.utcnow()
+
+        self.db_session.commit()
+
+        logger.info(f"Successfully synchronized reference data for {cve_id}")
+
+        return {
+            "cve_id": cve_id,
+            "references_processed": len(processed_references),
+            "successful_extractions": successful_extractions,
+            "extraction_rate": successful_extractions / len(cve.reference_urls) if cve.reference_urls else 0,
+            "high_confidence_references": len([r for r in processed_references if r.get('confidence_score', 0) > 70]),
+            "source": "reference_extraction"
+        }
+
+    async def bulk_sync_references(self, batch_size: int = 50, max_cves: int = None,
+                                   force_resync: bool = False, cancellation_flag: Optional[callable] = None) -> Dict[str, Any]:
+        """Bulk synchronize reference data for multiple CVEs"""
+        from main import CVE, BulkProcessingJob
+
+        # Create bulk processing job
+        job = BulkProcessingJob(
+            job_type='reference_sync',
+            status='running',
+            started_at=datetime.utcnow(),
+            job_metadata={'batch_size': batch_size, 'max_cves': max_cves}
+        )
+        self.db_session.add(job)
+        self.db_session.commit()
+
+        total_processed = 0
+        total_references = 0
+        successful_extractions = 0
+
+        try:
+            # Get CVEs that have reference URLs but no reference analysis
+            query = self.db_session.query(CVE).filter(
+                CVE.reference_urls.isnot(None),
+                func.array_length(CVE.reference_urls, 1) > 0
+            )
+
+            # Filter out CVEs that already have reference analysis (unless force_resync is True)
+            if not force_resync:
+                query = query.filter(
+                    CVE.reference_sync_status != 'completed'
+                )
+
+            if max_cves:
+                cves = query.limit(max_cves).all()
+            else:
+                cves = query.all()
+
+            job.total_items = len(cves)
+            self.db_session.commit()
+
+            logger.info(f"Starting bulk reference sync for {len(cves)} CVEs")
+
+            # Process in batches
+            for i in range(0, len(cves), batch_size):
+                # Check for cancellation
+                if cancellation_flag and cancellation_flag():
+                    logger.info("Bulk reference sync cancelled by user")
+                    job.status = 'cancelled'
+                    job.cancelled_at = datetime.utcnow()
+                    job.error_message = "Job cancelled by user"
+                    break
+
+                batch = cves[i:i + batch_size]
+
+                for cve in batch:
+                    # Check for cancellation
+                    if cancellation_flag and cancellation_flag():
+                        logger.info("Bulk reference sync cancelled by user")
+                        job.status = 'cancelled'
+                        job.cancelled_at = datetime.utcnow()
+                        job.error_message = "Job cancelled by user"
+                        break
+
+                    try:
+                        result = await self.sync_cve_references(cve.cve_id)
+
+                        if "error" not in result:
+                            total_processed += 1
+                            total_references += result.get("references_processed", 0)
+                            successful_extractions += result.get("successful_extractions", 0)
+                        else:
+                            job.failed_items += 1
+
+                        job.processed_items += 1
+
+                        # Longer delay for reference extraction to be respectful
+                        await asyncio.sleep(2)
+
+                    except Exception as e:
+                        logger.error(f"Error processing references for {cve.cve_id}: {e}")
+                        job.failed_items += 1
+
+                # Break out of outer loop if cancelled
+                if job.status == 'cancelled':
+                    break
+
+                # Commit after each batch
+                self.db_session.commit()
+                logger.info(f"Processed reference batch {i//batch_size + 1}/{(len(cves) + batch_size - 1)//batch_size}")
+
+            # Update job status
+            if job.status != 'cancelled':
+                job.status = 'completed'
+                job.completed_at = datetime.utcnow()
+
+            job.job_metadata.update({
+                'total_processed': total_processed,
+                'total_references': total_references,
+                'successful_extractions': successful_extractions,
+                'extraction_rate': successful_extractions / total_references if total_references > 0 else 0,
+                'source': 'reference_extraction'
+            })
+
+        except Exception as e:
+            job.status = 'failed'
+            job.error_message = str(e)
+            job.completed_at = datetime.utcnow()
+            logger.error(f"Bulk reference sync job failed: {e}")
+
+        finally:
+            self.db_session.commit()
+
+        return {
+            'job_id': str(job.id),
+            'status': job.status,
+            'total_processed': total_processed,
+            'total_references': total_references,
+            'successful_extractions': successful_extractions,
+            'extraction_rate': successful_extractions / total_references if total_references > 0 else 0,
+            'source': 'reference_extraction'
+        }
+
+    async def get_reference_sync_status(self) -> Dict[str, Any]:
+        """Get reference synchronization status"""
+        from main import CVE
+
+        # Count CVEs with reference URLs
+        total_cves = self.db_session.query(CVE).count()
+
+        cves_with_refs = self.db_session.query(CVE).filter(
+            CVE.reference_urls.isnot(None),
+            func.array_length(CVE.reference_urls, 1) > 0
+        ).count()
+
+        # Count CVEs with reference analysis
+        cves_with_analysis = self.db_session.query(CVE).filter(
+            CVE.reference_sync_status == 'completed'
+        ).count()
+
+        return {
+            'total_cves': total_cves,
+            'cves_with_references': cves_with_refs,
+            'cves_with_analysis': cves_with_analysis,
+            'reference_coverage': (cves_with_refs / total_cves * 100) if total_cves > 0 else 0,
+            'analysis_coverage': (cves_with_analysis / cves_with_refs * 100) if cves_with_refs > 0 else 0,
+            'sync_status': 'active' if cves_with_analysis > 0 else 'pending',
+            'source': 'reference_extraction'
+        }
\ No newline at end of file
diff --git a/backend/setup_ollama.py b/backend/setup_ollama.py
new file mode 100755
index 0000000..690fb20
--- /dev/null
+++ b/backend/setup_ollama.py
@@ -0,0 +1,114 @@
+#!/usr/bin/env python3
+"""
+Setup script to pull the default Ollama model on startup
+"""
+
+import os
+import sys
+import time
+import requests
+import json
+import logging
+
+# Setup logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+def wait_for_ollama(base_url: str, max_retries: int = 30, delay: int = 2) -> bool:
+    """Wait for Ollama service to be ready"""
+    for i in range(max_retries):
+        try:
+            response = requests.get(f"{base_url}/api/tags", timeout=5)
+            if response.status_code == 200:
+                logger.info("Ollama service is ready")
+                return True
+        except Exception:
+            logger.info(f"Waiting for Ollama service... ({i+1}/{max_retries})")
({i+1}/{max_retries})") + time.sleep(delay) + + logger.error("Ollama service is not ready after maximum retries") + return False + +def check_model_exists(base_url: str, model: str) -> bool: + """Check if a model exists in Ollama""" + try: + response = requests.get(f"{base_url}/api/tags", timeout=10) + if response.status_code == 200: + data = response.json() + models = data.get('models', []) + for m in models: + model_name = m.get('name', '') + if model_name.startswith(model + ':') or model_name == model: + logger.info(f"Model {model} already exists") + return True + return False + except Exception as e: + logger.error(f"Error checking models: {e}") + return False + +def pull_model(base_url: str, model: str) -> bool: + """Pull an Ollama model""" + try: + logger.info(f"Pulling model {model}...") + payload = {"name": model} + response = requests.post( + f"{base_url}/api/pull", + json=payload, + timeout=1800, # 30 minutes timeout for model download + stream=True + ) + + if response.status_code == 200: + # Stream the response to monitor progress + for line in response.iter_lines(): + if line: + try: + data = json.loads(line.decode('utf-8')) + status = data.get('status', '') + if 'pulling' in status.lower() or 'downloading' in status.lower(): + logger.info(f"Ollama: {status}") + elif data.get('error'): + logger.error(f"Ollama pull error: {data.get('error')}") + return False + except json.JSONDecodeError: + continue + + logger.info(f"Successfully pulled model {model}") + return True + else: + logger.error(f"Failed to pull model {model}: HTTP {response.status_code}") + logger.error(f"Response: {response.text}") + return False + + except Exception as e: + logger.error(f"Error pulling model {model}: {e}") + return False + +def main(): + """Main setup function""" + base_url = os.getenv('OLLAMA_BASE_URL', 'http://ollama:11434') + model = os.getenv('LLM_MODEL', 'llama3.2') + + logger.info(f"Setting up Ollama with model {model}") + logger.info(f"Ollama URL: {base_url}") + + # Wait for Ollama service to be ready + if not wait_for_ollama(base_url): + logger.error("Ollama service is not available") + sys.exit(1) + + # Check if model already exists + if check_model_exists(base_url, model): + logger.info(f"Model {model} is already available") + sys.exit(0) + + # Pull the model + if pull_model(base_url, model): + logger.info(f"Setup completed successfully - model {model} is ready") + sys.exit(0) + else: + logger.error(f"Failed to pull model {model}") + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 8758c71..d5a3592 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -32,6 +32,8 @@ services: depends_on: db: condition: service_healthy + ollama-setup: + condition: service_completed_successfully volumes: - ./backend:/app - ./github_poc_collector:/github_poc_collector @@ -67,6 +69,18 @@ services: - OLLAMA_HOST=0.0.0.0 restart: unless-stopped + ollama-setup: + build: ./backend + depends_on: + - ollama + environment: + OLLAMA_BASE_URL: http://ollama:11434 + LLM_MODEL: llama3.2 + volumes: + - ./backend:/app + command: python setup_ollama.py + restart: "no" + volumes: postgres_data: redis_data: