Source code for avise.connectors.languagemodel.ollama

"""Connector for Ollama API communication using the ollama library."""

import logging
from typing import List

import ollama

from .base import BaseLMConnector, Message
from ...registry import connector_registry
from ...utils import ansi_colors

logger = logging.getLogger(__name__)


@connector_registry.register("ollama-lm")
class OllamaLMConnector(BaseLMConnector):
    """Connector for communicating with the Ollama API.

    Used by Security Evaluation Tests for sending prompts to target Ollama
    models and collecting their responses.
    """

    name = "ollama-lm"

    def __init__(
        self,
        config: dict,
        evaluation: bool = False,
    ):
        """Initialize the Ollama connector.

        Args:
            config: Dictionary containing data from Connector configuration
                JSON. Must contain a "target_model" section (or an
                "eval_model" section when ``evaluation`` is True) with the
                keys "name" and "api_url", and optionally "max_tokens" and
                "api_key".
            evaluation: Boolean flag indicating whether initializing a
                connector for the target model or the evaluation model.

        Raises:
            KeyError: If the selected section, or its "name" / "api_url"
                key, is missing from config.
        """
        # Both roles share the same schema; only the section name differs.
        # Reading every setting from the one selected section guarantees the
        # evaluation connector never consults the target model's settings.
        section = config["eval_model" if evaluation else "target_model"]

        self.model = section["name"]
        self.base_url = section["api_url"]

        # A missing or null "max_tokens" falls back to 512.
        max_tokens = section.get("max_tokens")
        self.max_tokens = max_tokens if max_tokens is not None else 512

        # A missing or null "api_key" means an unauthenticated client.
        self.api_key = section.get("api_key")
        if self.api_key is not None:
            # Configure client with authentication headers
            self.client = ollama.Client(
                host=self.base_url,
                headers={"Authorization": f"Bearer {self.api_key}"},
            )
        else:
            self.client = ollama.Client(host=self.base_url)

        logger.info(" Ollama Connector Initialized")
        logger.info(f" Base URL: {self.base_url}")
        logger.info(f" Model: {self.model}")
        if self.api_key:
            # Only the last four characters of the key are ever logged.
            key_tail = self.api_key[-4:] if len(self.api_key) > 4 else "****"
            logger.info(f" API Key: {'*' * 8}...{key_tail}")

    def generate(self, data: dict, multi_turn: bool = False) -> dict:
        """Generate a response from the target model via the Ollama API.

        Arguments:
            data: Dictionary containing data required for the generation API
                request. It is modified in place: "temperature" is filled in
                when absent and "max_tokens" is set from the connector.
                Valid Keys:
                    - prompt : str
                        Prompt for single turn generation. Required for
                        single turn conversation.
                    - messages : list[Message]
                        List of Message objects representing the conversation
                        history. Message objects contain 'role' and 'content'
                        attributes. Required for multi-turn conversation.
                    - system_prompt : str
                        Optional system prompt.
                    - temperature : float [0, 1]
                        Optional temperature setting for the target model.
                        Defaults to 0.5 if not set.
                    - max_tokens : int
                        Maximum generated tokens. Set by this method to the
                        connector's configured value (512 when the
                        configuration does not provide one).
            multi_turn: Boolean flag to indicate if engaging in a multi turn
                conversation with the target model. Default False.

        Returns:
            API response as a dictionary of the form {"response": str}.

        Raises:
            KeyError: If a required key is missing from data.
            ValueError: If a value in data is of a wrong type.
            RuntimeError: If the API call fails.
        """
        if "temperature" not in data:
            data["temperature"] = 0.5
        # NOTE(review): the connector's configured limit is always applied
        # here, replacing any "max_tokens" supplied by the caller - confirm
        # this is the intended precedence.
        data["max_tokens"] = self.max_tokens

        if "system_prompt" in data and not isinstance(data["system_prompt"], str):
            raise ValueError(
                'If using "system_prompt" in data, it needs to be a string.'
            )

        if multi_turn:
            # Built with implicit concatenation so that no source indentation
            # leaks into the runtime message.
            multi_turn_error = (
                'Multi-turn conversation requires a "messages" key in '
                "data variable, which contains a List of Message objects "
                "representing the conversation history."
            )
            if "messages" not in data:
                raise KeyError(multi_turn_error)
            if not isinstance(data["messages"], list):
                raise ValueError(multi_turn_error)
            if not all(isinstance(msg, Message) for msg in data["messages"]):
                raise ValueError(multi_turn_error)
            return self._multi_turn(data=data)

        single_turn_error = (
            'Single-turn conversation requires a "prompt" key in '
            "data variable, which contains a prompt as a string."
        )
        if "prompt" not in data:
            raise KeyError(single_turn_error)
        if not isinstance(data["prompt"], str):
            raise ValueError(single_turn_error)
        return self._single_turn(data=data)

    def _multi_turn(self, data: dict) -> dict:
        """Make a multi-turn generation.

        Arguments:
            data: Dictionary with required data for API request. Reads the
                keys "messages", "temperature", "max_tokens" and, when
                present, "system_prompt".

        Returns:
            {"response": str}

        Raises:
            RuntimeError: If the chat request fails for any reason.
        """
        # Convert Message objects to Ollama's expected format
        ollama_messages = [
            {"role": msg.role, "content": msg.content} for msg in data["messages"]
        ]
        if "system_prompt" in data:
            # The system prompt goes first so it precedes the whole history.
            ollama_messages.insert(
                0, {"role": "system", "content": data["system_prompt"]}
            )

        try:
            response = self.client.chat(
                model=self.model,
                messages=ollama_messages,
                options={
                    "temperature": data["temperature"],
                    "num_predict": data["max_tokens"],
                },
            )
            return {"response": response["message"]["content"]}
        except Exception as e:
            logger.error(
                f"{ansi_colors['red']}ERROR during chat with model: {e}{ansi_colors['reset']}"
            )
            raise RuntimeError("Failed to chat with model.") from e

    def _single_turn(self, data: dict) -> dict:
        """Make a single-turn generation.

        Arguments:
            data: Dictionary with required data for API request. Reads the
                keys "prompt", "temperature", "max_tokens" and, when present,
                "system_prompt".

        Returns:
            {"response": str}

        Raises:
            RuntimeError: If the generate request fails for any reason.
        """
        try:
            # The request is assembled inside the try block so that any
            # failure while preparing it is reported as RuntimeError too.
            request = {
                "model": self.model,
                "prompt": data["prompt"],
                "options": {
                    "temperature": data["temperature"],
                    "num_predict": data["max_tokens"],
                },
            }
            # The system prompt is optional; only send it when provided.
            if "system_prompt" in data:
                request["system"] = data["system_prompt"]
            response = self.client.generate(**request)
        except Exception as e:
            logger.error(
                f"{ansi_colors['red']}ERROR while generating response from model: {e}{ansi_colors['reset']}"
            )
            raise RuntimeError(
                "Failed to generate a response from model due to an error."
            ) from e
        return {"response": response.response}

    def _match_model(self, model_name: str, available_models: List[str]) -> bool:
        """Check if a model name exists in the list of available models.

        Arguments:
            model_name: Name of the target model.
            available_models: List of available models to scan for target model.

        Returns:
            True if model_name found in available models,
            False if model_name not found in available models.
        """
        # Exact match only: no tag or prefix normalisation is applied.
        return model_name in available_models

    def status_check(self) -> bool:
        """Check if the connector can reach the Ollama API and the target model is available.

        Returns:
            True if API is reachable and the target model exists.

        Raises:
            ConnectionError: If the API is not reachable.
            ValueError: If the model is not found.
        """
        # Step 1: Check backend connectivity and get available models
        try:
            model_names = self._list_models()
        except Exception as e:
            # Chain the original error so the root cause stays in the traceback.
            raise ConnectionError(
                f"Cannot connect to Ollama backend at {self.base_url}: {e}"
            ) from e

        # Step 2: Check if model exists
        logger.info(f"Available models found: {model_names}")
        if self._match_model(self.model, model_names):
            logger.info(f"Model '{self.model}' found.")
            return True

        raise ValueError(
            f"Model '{self.model}' not found in Ollama backend. "
            f"Available models: {model_names}"
        )

    def _list_models(self) -> List[str]:
        """Helper method, used by status_check() to verify model availability.

        Returns:
            List of model names. Entries without a "model" value are skipped.

        Raises:
            Exception: If the API is not reachable.
        """
        response = self.client.list()
        candidate_names = (
            entry.get("model") for entry in response.get("models", [])
        )
        return [model_name for model_name in candidate_names if model_name]