""" SIGMA rule generation tasks for Celery """ import asyncio import logging from typing import Dict, Any, List, Optional from celery import current_task from celery_config import celery_app, get_db_session from enhanced_sigma_generator import EnhancedSigmaGenerator from llm_client import LLMClient logger = logging.getLogger(__name__) @celery_app.task(bind=True, name='sigma_tasks.generate_enhanced_rules') def generate_enhanced_rules_task(self, cve_ids: Optional[List[str]] = None) -> Dict[str, Any]: """ Celery task for enhanced SIGMA rule generation Args: cve_ids: Optional list of specific CVE IDs to process Returns: Dictionary containing generation results """ db_session = get_db_session() try: # Import here to avoid circular imports import sys import os sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from main import CVE # Update task progress self.update_state( state='PROGRESS', meta={ 'stage': 'generating_rules', 'progress': 0, 'message': 'Starting enhanced SIGMA rule generation' } ) logger.info(f"Starting enhanced rule generation task for CVEs: {cve_ids}") # Create generator instance generator = EnhancedSigmaGenerator(db_session) # Get CVEs to process if cve_ids: cves = db_session.query(CVE).filter(CVE.cve_id.in_(cve_ids)).all() else: cves = db_session.query(CVE).filter(CVE.poc_count > 0).all() total_cves = len(cves) processed_cves = 0 successful_rules = 0 failed_rules = 0 results = [] # Process each CVE for i, cve in enumerate(cves): try: # Update progress progress = int((i / total_cves) * 100) self.update_state( state='PROGRESS', meta={ 'stage': 'generating_rules', 'progress': progress, 'message': f'Processing CVE {cve.cve_id}', 'current_cve': cve.cve_id, 'processed': processed_cves, 'total': total_cves } ) # Generate rule using asyncio loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: result = loop.run_until_complete( generator.generate_enhanced_rule(cve) ) if result.get('success', False): successful_rules += 1 else: failed_rules += 1 results.append({ 'cve_id': cve.cve_id, 'success': result.get('success', False), 'message': result.get('message', 'No message'), 'rule_id': result.get('rule_id') }) finally: loop.close() processed_cves += 1 except Exception as e: logger.error(f"Error processing CVE {cve.cve_id}: {e}") failed_rules += 1 results.append({ 'cve_id': cve.cve_id, 'success': False, 'message': f'Error: {str(e)}', 'rule_id': None }) # Final results final_result = { 'total_processed': processed_cves, 'successful_rules': successful_rules, 'failed_rules': failed_rules, 'results': results } # Update final progress self.update_state( state='SUCCESS', meta={ 'stage': 'completed', 'progress': 100, 'message': f'Generated {successful_rules} rules from {processed_cves} CVEs', 'results': final_result } ) logger.info(f"Enhanced rule generation task completed: {final_result}") return final_result except Exception as e: logger.error(f"Enhanced rule generation task failed: {e}") self.update_state( state='FAILURE', meta={ 'stage': 'error', 'progress': 0, 'message': f'Task failed: {str(e)}', 'error': str(e) } ) raise finally: db_session.close() @celery_app.task(bind=True, name='sigma_tasks.llm_enhanced_generation') def llm_enhanced_generation_task(self, cve_id: str, provider: str = 'ollama', model: Optional[str] = None) -> Dict[str, Any]: """ Celery task for LLM-enhanced rule generation Args: cve_id: CVE identifier provider: LLM provider (openai, anthropic, ollama, finetuned) model: Specific model to use Returns: Dictionary containing generation result """ db_session = get_db_session() try: # Import here to avoid circular imports import sys import os sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from main import CVE # Update task progress self.update_state( state='PROGRESS', meta={ 'stage': 'llm_generation', 'progress': 10, 'message': f'Starting LLM rule generation for {cve_id}', 'cve_id': cve_id, 'provider': provider, 'model': model } ) logger.info(f"Starting LLM rule generation for {cve_id} using {provider}") # Get CVE from database cve = db_session.query(CVE).filter(CVE.cve_id == cve_id).first() if not cve: raise ValueError(f"CVE {cve_id} not found in database") # Update progress self.update_state( state='PROGRESS', meta={ 'stage': 'llm_generation', 'progress': 25, 'message': f'Initializing LLM client ({provider})', 'cve_id': cve_id } ) # Create LLM client llm_client = LLMClient(provider=provider, model=model) if not llm_client.is_available(): raise ValueError(f"LLM client not available for provider: {provider}") # Update progress self.update_state( state='PROGRESS', meta={ 'stage': 'llm_generation', 'progress': 50, 'message': f'Generating rule with LLM for {cve_id}', 'cve_id': cve_id } ) # Generate rule using asyncio loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: rule_content = loop.run_until_complete( llm_client.generate_sigma_rule( cve_id=cve.cve_id, poc_content=cve.poc_data or '', cve_description=cve.description or '' ) ) finally: loop.close() # Update progress self.update_state( state='PROGRESS', meta={ 'stage': 'llm_generation', 'progress': 75, 'message': f'Validating generated rule for {cve_id}', 'cve_id': cve_id } ) # Validate the generated rule is_valid = False if rule_content: is_valid = llm_client.validate_sigma_rule(rule_content, cve_id) # Prepare result result = { 'cve_id': cve_id, 'rule_content': rule_content, 'is_valid': is_valid, 'provider': provider, 'model': model or llm_client.model, 'success': bool(rule_content and is_valid) } # Update final progress self.update_state( state='SUCCESS', meta={ 'stage': 'completed', 'progress': 100, 'message': f'LLM rule generation completed for {cve_id}', 'cve_id': cve_id, 'success': result['success'], 'result': result } ) logger.info(f"LLM rule generation task completed for {cve_id}: {result['success']}") return result except Exception as e: logger.error(f"LLM rule generation task failed for {cve_id}: {e}") self.update_state( state='FAILURE', meta={ 'stage': 'error', 'progress': 0, 'message': f'Task failed for {cve_id}: {str(e)}', 'cve_id': cve_id, 'error': str(e) } ) raise finally: db_session.close() @celery_app.task(bind=True, name='sigma_tasks.batch_llm_generation') def batch_llm_generation_task(self, cve_ids: List[str], provider: str = 'ollama', model: Optional[str] = None) -> Dict[str, Any]: """ Celery task for batch LLM rule generation Args: cve_ids: List of CVE identifiers provider: LLM provider (openai, anthropic, ollama, finetuned) model: Specific model to use Returns: Dictionary containing batch generation results """ db_session = get_db_session() try: # Update task progress self.update_state( state='PROGRESS', meta={ 'stage': 'batch_llm_generation', 'progress': 0, 'message': f'Starting batch LLM generation for {len(cve_ids)} CVEs', 'total_cves': len(cve_ids), 'provider': provider, 'model': model } ) logger.info(f"Starting batch LLM generation for {len(cve_ids)} CVEs using {provider}") # Initialize results results = [] successful_rules = 0 failed_rules = 0 # Process each CVE for i, cve_id in enumerate(cve_ids): try: # Update progress progress = int((i / len(cve_ids)) * 100) self.update_state( state='PROGRESS', meta={ 'stage': 'batch_llm_generation', 'progress': progress, 'message': f'Processing CVE {cve_id} ({i+1}/{len(cve_ids)})', 'current_cve': cve_id, 'processed': i, 'total': len(cve_ids) } ) # Generate rule for this CVE result = llm_enhanced_generation_task.apply( args=[cve_id, provider, model] ).get() if result.get('success', False): successful_rules += 1 else: failed_rules += 1 results.append(result) except Exception as e: logger.error(f"Error processing CVE {cve_id} in batch: {e}") failed_rules += 1 results.append({ 'cve_id': cve_id, 'success': False, 'error': str(e), 'provider': provider, 'model': model }) # Final results final_result = { 'total_processed': len(cve_ids), 'successful_rules': successful_rules, 'failed_rules': failed_rules, 'provider': provider, 'model': model, 'results': results } # Update final progress self.update_state( state='SUCCESS', meta={ 'stage': 'completed', 'progress': 100, 'message': f'Batch generation completed: {successful_rules} successful, {failed_rules} failed', 'results': final_result } ) logger.info(f"Batch LLM generation task completed: {final_result}") return final_result except Exception as e: logger.error(f"Batch LLM generation task failed: {e}") self.update_state( state='FAILURE', meta={ 'stage': 'error', 'progress': 0, 'message': f'Batch task failed: {str(e)}', 'error': str(e) } ) raise finally: db_session.close()