- Add Celery, Flower, and related dependencies to requirements.txt
- Update bulk_seeder.py with progress callback support for Celery integration
- Clean up fine-tuned model dependencies (now served through Ollama)
- Update setup_ollama.py for enhanced configuration

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
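The bulk_seeder.py change wires the seeder's progress reporting into Celery task state. Below is a minimal sketch of that pattern, for reference only: `seed_task`, the year-range loop, and the broker URL are hypothetical stand-ins, not the actual bulk_seeder.py interface.

```python
from celery import Celery

app = Celery("seeder", broker="redis://redis:6379/0")  # broker URL is an assumption


@app.task(bind=True)
def seed_task(self, start_year: int, end_year: int) -> dict:
    """Hypothetical bulk-seeding task that surfaces progress via Celery state."""
    total = end_year - start_year + 1

    def progress_callback(done: int) -> None:
        # Report progress as Celery task state so Flower (or a polling
        # client) can display it.
        self.update_state(state="PROGRESS", meta={"done": done, "total": total})

    for done, year in enumerate(range(start_year, end_year + 1), start=1):
        # ...seed CVE data for `year` here...
        progress_callback(done)

    return {"done": total, "total": total}
```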
setup_ollama.py · 159 lines (no EOL) · 5.3 KiB · Python · Executable file
#!/usr/bin/env python3
"""
Setup script to pull the default Ollama model on startup
"""

import json
import logging
import os
import sys
import time

import requests

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def wait_for_ollama(base_url: str, max_retries: int = 30, delay: int = 2) -> bool:
    """Wait for the Ollama service to become ready."""
    for i in range(max_retries):
        try:
            response = requests.get(f"{base_url}/api/tags", timeout=5)
            if response.status_code == 200:
                logger.info("Ollama service is ready")
                return True
        except Exception:
            logger.info(f"Waiting for Ollama service... ({i + 1}/{max_retries})")
        # Sleep between attempts whether the request errored or answered non-200.
        time.sleep(delay)

    logger.error("Ollama service is not ready after maximum retries")
    return False
def check_model_exists(base_url: str, model: str) -> bool:
    """Check if a model exists in Ollama."""
    try:
        response = requests.get(f"{base_url}/api/tags", timeout=10)
        if response.status_code == 200:
            data = response.json()
            models = data.get('models', [])
            for m in models:
                model_name = m.get('name', '')
                # Ollama lists models as "name:tag"; match either form.
                if model_name.startswith(model + ':') or model_name == model:
                    logger.info(f"Model {model} already exists")
                    return True
        return False
    except Exception as e:
        logger.error(f"Error checking models: {e}")
        return False
def pull_model(base_url: str, model: str) -> bool:
    """Pull an Ollama model"""
    try:
        logger.info(f"Pulling model {model}...")
        payload = {"name": model}
        response = requests.post(
            f"{base_url}/api/pull",
            json=payload,
            timeout=1800,  # 30 minutes timeout for model download
            stream=True
        )

        if response.status_code == 200:
            # Stream the response to monitor progress
            for line in response.iter_lines():
                if line:
                    try:
                        data = json.loads(line.decode('utf-8'))
                        status = data.get('status', '')
                        if 'pulling' in status.lower() or 'downloading' in status.lower():
                            logger.info(f"Ollama: {status}")
                        elif data.get('error'):
                            logger.error(f"Ollama pull error: {data.get('error')}")
                            return False
                    except json.JSONDecodeError:
                        continue

            logger.info(f"Successfully pulled model {model}")
            return True
        else:
            logger.error(f"Failed to pull model {model}: HTTP {response.status_code}")
            logger.error(f"Response: {response.text}")
            return False

    except Exception as e:
        logger.error(f"Error pulling model {model}: {e}")
        return False
def setup_custom_sigma_model(base_url: str):
    """Setup custom SIGMA model if fine-tuned model exists."""
    model_path = "/app/models/sigma_llama_finetuned"

    if not os.path.exists(model_path):
        logger.info("Fine-tuned model not found, skipping custom model setup")
        return False

    logger.info("Fine-tuned model found, setting up custom Ollama model...")

    # Create a simple Modelfile for the custom model
    modelfile_content = """FROM llama3.2

TEMPLATE \"\"\"### Instruction:
{{ .Prompt }}

### Response:
\"\"\"

PARAMETER temperature 0.1
PARAMETER top_p 0.9
PARAMETER stop "### Instruction:"
PARAMETER stop "### Response:"

SYSTEM \"\"\"You are a cybersecurity expert specializing in SIGMA rule creation. Generate valid SIGMA rules in YAML format based on the provided CVE and exploit information.\"\"\"
"""

    try:
        # Write Modelfile
        with open("/tmp/Modelfile.sigma", "w") as f:
            f.write(modelfile_content)

        # Creating the custom model would need the ollama CLI in the container;
        # for now, just log that it would be created (see the sketch after this file).
        logger.info("Custom SIGMA model configuration prepared")
        return True

    except Exception as e:
        logger.error(f"Error setting up custom model: {e}")
        return False
def main():
    """Main setup function"""
    base_url = os.getenv('OLLAMA_BASE_URL', 'http://ollama:11434')
    model = os.getenv('LLM_MODEL', 'llama3.2')

    logger.info(f"Setting up Ollama with model {model}")
    logger.info(f"Ollama URL: {base_url}")

    # Wait for Ollama service to be ready
    if not wait_for_ollama(base_url):
        logger.error("Ollama service is not available")
        sys.exit(1)

    # Check if model already exists
    if check_model_exists(base_url, model):
        logger.info(f"Model {model} is already available")
    else:
        # Pull the base model
        if pull_model(base_url, model):
            logger.info(f"Successfully pulled model {model}")
        else:
            logger.error(f"Failed to pull model {model}")
            sys.exit(1)

    # Setup custom SIGMA model if fine-tuned model exists
    setup_custom_sigma_model(base_url)

    logger.info("Setup completed successfully")
    sys.exit(0)


if __name__ == "__main__":
    main()
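`setup_custom_sigma_model` above stops after writing the Modelfile. One way the creation step could be completed without the ollama CLI is Ollama's `/api/create` endpoint; the sketch below assumes an Ollama version whose create API accepts a raw `modelfile` string in the JSON payload (the payload shape has changed across releases, so verify against the version you run):

```python
import requests


def create_custom_model(base_url: str, name: str, modelfile: str) -> bool:
    """Create a named Ollama model from a Modelfile string via the REST API."""
    response = requests.post(
        f"{base_url}/api/create",
        json={"name": name, "modelfile": modelfile},  # payload shape is an assumption
        timeout=600,  # model creation can take a while
    )
    return response.status_code == 200
```

Calling `create_custom_model(base_url, "sigma-llama", modelfile_content)` from the `try` block would replace the log-only placeholder.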