#!/usr/bin/env python3
"""
Setup script to pull the default Ollama model on startup
"""
|
|
|
|
import os
|
|
import sys
|
|
import time
|
|
import requests
|
|
import json
|
|
import logging
|
|
|
|
# Setup logging
|
|
logging.basicConfig(level=logging.INFO)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
def wait_for_ollama(base_url: str, max_retries: int = 30, delay: int = 2) -> bool:
    """Wait for the Ollama service to become ready.

    Polls ``{base_url}/api/tags`` until it answers with HTTP 200.

    Args:
        base_url: Base URL of the Ollama HTTP API.
        max_retries: Maximum number of polling attempts.
        delay: Seconds to sleep between attempts.

    Returns:
        True once the service responds with HTTP 200, False if it never
        becomes ready within ``max_retries`` attempts.
    """
    for attempt in range(1, max_retries + 1):
        try:
            response = requests.get(f"{base_url}/api/tags", timeout=5)
            if response.status_code == 200:
                logger.info("Ollama service is ready")
                return True
            # Service answered but is not healthy yet (non-200): log the
            # retry too, instead of silently looping.
            logger.info(f"Waiting for Ollama service... ({attempt}/{max_retries})")
        except requests.RequestException:
            # Connection refused / timeout while the container starts up.
            logger.info(f"Waiting for Ollama service... ({attempt}/{max_retries})")
        # Don't sleep after the final failed attempt.
        if attempt < max_retries:
            time.sleep(delay)

    logger.error("Ollama service is not ready after maximum retries")
    return False
|
|
|
|
def check_model_exists(base_url: str, model: str) -> bool:
    """Return True when *model* is already present in Ollama.

    Queries ``{base_url}/api/tags`` and matches either the exact model
    name or a tagged variant such as ``model:latest``.
    """
    try:
        resp = requests.get(f"{base_url}/api/tags", timeout=10)
        if resp.status_code == 200:
            installed = resp.json().get('models', [])
            for entry in installed:
                name = entry.get('name', '')
                # Match "llama3.2" against both "llama3.2" and "llama3.2:latest".
                if name == model or name.startswith(model + ':'):
                    logger.info(f"Model {model} already exists")
                    return True
            return False
    except Exception as e:
        logger.error(f"Error checking models: {e}")
    # Non-200 response or request failure: treat the model as missing.
    return False
|
|
|
|
def pull_model(base_url: str, model: str) -> bool:
    """Download *model* via the Ollama ``/api/pull`` endpoint.

    Streams the pull response and logs download progress as it arrives.

    Returns:
        True when the pull stream completes without an error entry,
        False on an HTTP failure, a reported pull error, or any exception.
    """
    try:
        logger.info(f"Pulling model {model}...")
        response = requests.post(
            f"{base_url}/api/pull",
            json={"name": model},
            timeout=1800,  # 30 minutes timeout for model download
            stream=True,
        )

        if response.status_code != 200:
            logger.error(f"Failed to pull model {model}: HTTP {response.status_code}")
            logger.error(f"Response: {response.text}")
            return False

        # Stream the newline-delimited JSON progress events.
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            try:
                event = json.loads(raw_line.decode('utf-8'))
            except json.JSONDecodeError:
                # Skip malformed/partial progress lines.
                continue
            status = event.get('status', '')
            lowered = status.lower()
            if 'pulling' in lowered or 'downloading' in lowered:
                logger.info(f"Ollama: {status}")
            elif event.get('error'):
                logger.error(f"Ollama pull error: {event.get('error')}")
                return False

        logger.info(f"Successfully pulled model {model}")
        return True

    except Exception as e:
        logger.error(f"Error pulling model {model}: {e}")
        return False
|
|
|
|
def main():
    """Entry point: ensure the configured Ollama model is available.

    Reads OLLAMA_BASE_URL and LLM_MODEL from the environment, waits for
    the Ollama service, and pulls the model unless it is already present.
    Exits with status 0 on success and 1 on failure.
    """
    base_url = os.getenv('OLLAMA_BASE_URL', 'http://ollama:11434')
    model = os.getenv('LLM_MODEL', 'llama3.2')

    logger.info(f"Setting up Ollama with model {model}")
    logger.info(f"Ollama URL: {base_url}")

    # Abort early when the service never becomes reachable.
    if not wait_for_ollama(base_url):
        logger.error("Ollama service is not available")
        sys.exit(1)

    # Nothing to do when the model has already been downloaded.
    if check_model_exists(base_url, model):
        logger.info(f"Model {model} is already available")
        sys.exit(0)

    # Download the model; failure is fatal for the container setup.
    if not pull_model(base_url, model):
        logger.error(f"Failed to pull model {model}")
        sys.exit(1)

    logger.info(f"Setup completed successfully - model {model} is ready")
    sys.exit(0)


if __name__ == "__main__":
    main()