Skip to content

LLM Connectivity Reference

Here we handle connections to various LLM services, proprietary and open source.

AnthropicConversation

Bases: Conversation

Source code in biochatter/llm_connect.py
class AnthropicConversation(Conversation):
    """Conversation backed by Anthropic's chat models through ``ChatAnthropic``."""

    def __init__(
        self,
        model_name: str,
        prompts: dict,
        correct: bool = False,
        split_correction: bool = False,
    ):
        """Connect to Anthropic's API and set up a conversation with the user.
        Also initialise a second conversational agent to provide corrections to
        the model output, if necessary.

        Args:
        ----
            model_name (str): The name of the model to use.

            prompts (dict): A dictionary of prompts to use for the conversation.

            correct (bool): Whether to correct the model output.

            split_correction (bool): Whether to correct the model output by
                splitting the output into sentences and correcting each
                sentence individually.

        """
        super().__init__(
            model_name=model_name,
            prompts=prompts,
            correct=correct,
            split_correction=split_correction,
        )

        # Model used by the correcting agent.
        # TODO make accessible by drop-down
        self.ca_model_name = "claude-3-5-sonnet-20240620"

    def set_api_key(self, api_key: str, user: str) -> bool:
        """Set the API key for the Anthropic API. If the key is valid, initialise the
        conversational agent. Set the user for usage statistics.

        Args:
        ----
            api_key (str): The API key for the Anthropic API.

            user (str): The user for usage statistics.

        Returns:
        -------
            bool: True if the API key is valid, False otherwise.

        """
        client = anthropic.Anthropic(
            api_key=api_key,
        )
        self.user = user

        try:
            # Test request; an invalid key surfaces as AuthenticationError.
            client.count_tokens("Test connection")
            self.chat = ChatAnthropic(
                model_name=self.model_name,
                temperature=0,
                api_key=api_key,
            )
            self.ca_chat = ChatAnthropic(
                model_name=self.ca_model_name,
                temperature=0,
                api_key=api_key,
            )
            if user == "community":
                self.usage_stats = get_stats(user=user)

            return True

        except anthropic._exceptions.AuthenticationError:
            return False

    def _primary_query(self):
        """Query the Anthropic API with the user's message and return the response
        using the message history (system messages, prior conversation) as
        context. On success the response is appended to the conversation as an
        AI message.

        Returns
        -------
            tuple: A tuple containing the response from the Anthropic API and
                the token usage. On an API error, the error text and ``None``
                are returned instead (``None`` token usage signals the error to
                ``Conversation.query``).

        """
        try:
            history = self._create_history()
            response = self.chat.generate([history])
        except (
            anthropic._exceptions.APIError,
            anthropic._exceptions.AnthropicError,
            anthropic._exceptions.ConflictError,
            anthropic._exceptions.NotFoundError,
            anthropic._exceptions.APIStatusError,
            anthropic._exceptions.RateLimitError,
            anthropic._exceptions.APITimeoutError,
            anthropic._exceptions.BadRequestError,
            anthropic._exceptions.APIConnectionError,
            anthropic._exceptions.AuthenticationError,
            anthropic._exceptions.InternalServerError,
            anthropic._exceptions.PermissionDeniedError,
            anthropic._exceptions.UnprocessableEntityError,
            anthropic._exceptions.APIResponseValidationError,
        ) as e:
            return str(e), None

        msg = response.generations[0][0].text
        token_usage = response.llm_output.get("token_usage")

        self.append_ai_message(msg)

        return msg, token_usage

    def _create_history(self):
        """Build the message list sent to Anthropic from ``self.messages``.

        All system messages are merged into one leading ``SystemMessage``. The
        remaining texts are collapsed around the most recent AI message into
        Human / AI / Human turns (or a single Human turn if there is no AI
        message yet). If the latest message carries an image, the final
        ``HumanMessage`` is rebuilt as multi-part content (text + image URL).

        Returns
        -------
            list: The messages to pass to ``self.chat.generate``.

        """
        history = []
        # Extract text components from message contents; image messages store
        # a list of parts whose first entry is the text part.
        msg_texts = [m.content[0]["text"] if isinstance(m.content, list) else m.content for m in self.messages]

        # Check whether the last message is an image message (text part
        # followed by an image_url part).
        last_content = self.messages[-1].content
        is_image_message = (
            isinstance(last_content, list)
            and len(last_content) > 1
            and last_content[1]["type"] == "image_url"
        )

        # Find the location of the last AI message (if any).
        last_ai_message = None
        for i, m in enumerate(self.messages):
            if isinstance(m, AIMessage):
                last_ai_message = i

        # Aggregate system messages into one message at the beginning.
        system_messages = [m.content for m in self.messages if isinstance(m, SystemMessage)]
        if system_messages:
            history.append(
                SystemMessage(content="\n".join(system_messages)),
            )

        if last_ai_message is not None:
            # Concatenate all messages before the last AI message into one
            # HumanMessage, then the AI message, then everything after it.
            history.append(
                HumanMessage(content="\n".join(msg_texts[:last_ai_message])),
            )
            history.append(
                AIMessage(content=msg_texts[last_ai_message]),
            )
            history.append(
                HumanMessage(content="\n".join(msg_texts[last_ai_message + 1 :])),
            )
        else:
            # No AI turn yet: send everything after the last system message as
            # one HumanMessage. -1 means "no system message", so the slice then
            # starts at index 0 instead of failing on None + 1.
            last_system_message = -1
            for i, m in enumerate(self.messages):
                if isinstance(m, SystemMessage):
                    last_system_message = i
            history.append(
                HumanMessage(content="\n".join(msg_texts[last_system_message + 1 :])),
            )

        # If the last message is an image message, rebuild the final human turn
        # as multi-part content. Message objects are not subscriptable, so a
        # new HumanMessage is created rather than assigning via ["content"].
        if is_image_message:
            history[-1] = HumanMessage(
                content=[
                    {"type": "text", "text": history[-1].content},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": last_content[1]["image_url"]["url"],
                        },
                    },
                ],
            )
        return history

    def _correct_response(self, msg: str):
        """Correct the response from the Anthropic API by sending it to a secondary
        language model, after the correcting-agent prompts in ``self.ca_messages``
        and followed by an instruction to answer just 'OK' when nothing needs
        correcting.

        Args:
        ----
            msg (str): The response from the Anthropic API.

        Returns:
        -------
            str: The corrected response (or OK if no correction necessary).

        """
        ca_messages = self.ca_messages.copy()
        ca_messages.append(
            HumanMessage(
                content=msg,
            ),
        )
        ca_messages.append(
            SystemMessage(
                content="If there is nothing to correct, please respond with just 'OK', and nothing else!",
            ),
        )

        response = self.ca_chat.generate([ca_messages])

        return response.generations[0][0].text

__init__(model_name, prompts, correct=False, split_correction=False)

Connect to Anthropic's API and set up a conversation with the user. Also initialise a second conversational agent to provide corrections to the model output, if necessary.


model_name (str): The name of the model to use.

prompts (dict): A dictionary of prompts to use for the conversation.

correct (bool): Whether to correct the model output.

split_correction (bool): Whether to correct the model output by
    splitting the output into sentences and correcting each
    sentence individually.
Source code in biochatter/llm_connect.py
def __init__(
    self,
    model_name: str,
    prompts: dict,
    correct: bool = False,
    split_correction: bool = False,
):
    """Set up a conversation with a user through Anthropic's API.

    A secondary (correcting) agent is configured too, so that the primary
    model's output can be checked when correction is enabled.

    Args:
    ----
        model_name (str): Name of the primary model.

        prompts (dict): Prompts used to drive the conversation.

        correct (bool): Whether the model output should be corrected.

        split_correction (bool): Whether correction works sentence by
            sentence instead of on the whole output.

    """
    base_kwargs = {
        "model_name": model_name,
        "prompts": prompts,
        "correct": correct,
        "split_correction": split_correction,
    }
    super().__init__(**base_kwargs)

    # Model used by the correcting agent.
    self.ca_model_name = "claude-3-5-sonnet-20240620"

set_api_key(api_key, user)

Set the API key for the Anthropic API. If the key is valid, initialise the conversational agent. Set the user for usage statistics.


api_key (str): The API key for the Anthropic API.

user (str): The user for usage statistics.

bool: True if the API key is valid, False otherwise.
Source code in biochatter/llm_connect.py
def set_api_key(self, api_key: str, user: str) -> bool:
    """Validate an Anthropic API key and, if accepted, create both chat agents.

    Also stores the user; for the "community" user, usage statistics are
    loaded.

    Args:
    ----
        api_key (str): Key for the Anthropic API.

        user (str): User name for usage statistics.

    Returns:
    -------
        bool: True when the key is accepted, False on an authentication error.

    """
    client = anthropic.Anthropic(api_key=api_key)
    self.user = user

    def _make_chat(name):
        # Both agents share the same deterministic settings and key.
        return ChatAnthropic(model_name=name, temperature=0, api_key=api_key)

    try:
        client.count_tokens("Test connection")
        self.chat = _make_chat(self.model_name)
        self.ca_chat = _make_chat(self.ca_model_name)
        if user == "community":
            self.usage_stats = get_stats(user=user)
    except anthropic._exceptions.AuthenticationError:
        return False
    return True

AzureGptConversation

Bases: GptConversation

Source code in biochatter/llm_connect.py
class AzureGptConversation(GptConversation):
    """GPT conversation served through an Azure OpenAI deployment.

    Extends GptConversation; usage statistics are not tracked on Azure.
    """

    def __init__(
        self,
        deployment_name: str,
        model_name: str,
        prompts: dict,
        correct: bool = False,
        split_correction: bool = False,
        version: str | None = None,
        base_url: str | None = None,
    ):
        """Connect to Azure's GPT API and set up a conversation with the user.

        Extends GptConversation.

        Args:
        ----
            deployment_name (str): The name of the Azure deployment to use.

            model_name (str): The name of the model to use. This is distinct
                from the deployment name.

            prompts (dict): A dictionary of prompts to use for the conversation.

            correct (bool): Whether to correct the model output.

            split_correction (bool): Whether to correct the model output by
                splitting the output into sentences and correcting each
                sentence individually.

            version (str | None): The version of the Azure API to use.

            base_url (str | None): The base URL of the Azure API to use.

        """
        super().__init__(
            model_name=model_name,
            prompts=prompts,
            correct=correct,
            split_correction=split_correction,
        )

        self.version = version
        self.base_url = base_url
        self.deployment_name = deployment_name

    def set_api_key(self, api_key: str, user: str | None = None) -> bool:
        """Set the API key for the Azure API.

        If the key is valid, initialise the conversational agent. No user stats
        on Azure.

        Args:
        ----
            api_key (str): The API key for the Azure API.

            user (str | None): Accepted for compatibility with the
                ``Conversation.set_api_key`` interface; ignored, since usage
                statistics are not tracked on Azure.

        Returns:
        -------
            bool: True if the API key is valid, False otherwise.

        """
        # Primary and correcting agents use identical connection settings.
        azure_kwargs = {
            "deployment_name": self.deployment_name,
            "model_name": self.model_name,
            "openai_api_version": self.version,
            "azure_endpoint": self.base_url,
            "openai_api_key": api_key,
            "temperature": 0,
        }
        try:
            self.chat = AzureChatOpenAI(**azure_kwargs)
            # TODO this is the same model as the primary one; refactor to be
            # able to use any model for correction
            self.ca_chat = AzureChatOpenAI(**azure_kwargs)

            # Send a minimal request so that an invalid key raises
            # AuthenticationError here; the response itself is not needed.
            self.chat.generate([[HumanMessage(content="Hello")]])
        except openai._exceptions.AuthenticationError:
            return False
        return True

    def _update_usage_stats(self, model: str, token_usage: dict):
        """We do not track usage stats for Azure."""
        return

__init__(deployment_name, model_name, prompts, correct=False, split_correction=False, version=None, base_url=None)

Connect to Azure's GPT API and set up a conversation with the user.

Extends GptConversation.


deployment_name (str): The name of the Azure deployment to use.

model_name (str): The name of the model to use. This is distinct
    from the deployment name.

prompts (dict): A dictionary of prompts to use for the conversation.

correct (bool): Whether to correct the model output.

split_correction (bool): Whether to correct the model output by
    splitting the output into sentences and correcting each
    sentence individually.

version (str): The version of the Azure API to use.

base_url (str): The base URL of the Azure API to use.
Source code in biochatter/llm_connect.py
def __init__(
    self,
    deployment_name: str,
    model_name: str,
    prompts: dict,
    correct: bool = False,
    split_correction: bool = False,
    version: str | None = None,
    base_url: str | None = None,
):
    """Set up a conversation with a user through Azure's GPT API.

    Builds on GptConversation and additionally records the Azure
    deployment, API version and endpoint.

    Args:
    ----
        deployment_name (str): Azure deployment to talk to.

        model_name (str): Model name; not the same thing as the deployment
            name.

        prompts (dict): Prompts used to drive the conversation.

        correct (bool): Whether the model output should be corrected.

        split_correction (bool): Whether correction works sentence by
            sentence instead of on the whole output.

        version (str | None): Azure API version.

        base_url (str | None): Base URL (endpoint) of the Azure API.

    """
    super().__init__(
        model_name=model_name,
        prompts=prompts,
        correct=correct,
        split_correction=split_correction,
    )

    # Azure-specific connection settings.
    self.deployment_name = deployment_name
    self.base_url = base_url
    self.version = version

set_api_key(api_key)

Set the API key for the Azure API.

If the key is valid, initialise the conversational agent. No user stats on Azure.


api_key (str): The API key for the Azure API.

bool: True if the API key is valid, False otherwise.
Source code in biochatter/llm_connect.py
def set_api_key(self, api_key: str, user: str | None = None) -> bool:
    """Set the API key for the Azure API.

    If the key is valid, initialise the conversational agent. No user stats
    on Azure.

    Args:
    ----
        api_key (str): The API key for the Azure API.

        user (str | None): Accepted for compatibility with the
            ``Conversation.set_api_key`` interface; ignored, since usage
            statistics are not tracked on Azure.

    Returns:
    -------
        bool: True if the API key is valid, False otherwise.

    """
    # Primary and correcting agents use identical connection settings.
    azure_kwargs = {
        "deployment_name": self.deployment_name,
        "model_name": self.model_name,
        "openai_api_version": self.version,
        "azure_endpoint": self.base_url,
        "openai_api_key": api_key,
        "temperature": 0,
    }
    try:
        self.chat = AzureChatOpenAI(**azure_kwargs)
        # TODO this is the same model as the primary one; refactor to be
        # able to use any model for correction
        self.ca_chat = AzureChatOpenAI(**azure_kwargs)

        # Send a minimal request so that an invalid key raises
        # AuthenticationError here; the response itself is not needed.
        self.chat.generate([[HumanMessage(content="Hello")]])
    except openai._exceptions.AuthenticationError:
        return False
    return True

BloomConversation

Bases: Conversation

Source code in biochatter/llm_connect.py
class BloomConversation(Conversation):
    """DEPRECATED: plain-text conversation with a HuggingFace Hub model.

    Superseded by XinferenceConversation.
    """

    def __init__(
        self,
        model_name: str,
        prompts: dict,
        split_correction: bool,
    ):
        """DEPRECATED: Superseded by XinferenceConversation."""
        super().__init__(
            model_name=model_name,
            prompts=prompts,
            split_correction=split_correction,
        )

        self.messages = []

    def set_api_key(self, api_key: str, user: str | None = None):
        """Create the HuggingFace Hub client and check it with a test prompt.

        Args:
        ----
            api_key (str): HuggingFace Hub API token.

            user (str | None): Unused; accepted for interface compatibility.

        Returns:
        -------
            bool: True if the test generation succeeds, False if it raises
                ValueError.

        """
        self.chat = HuggingFaceHub(
            repo_id=self.model_name,
            model_kwargs={"temperature": 1.0},  # "regular sampling"
            # as per https://huggingface.co/docs/api-inference/detailed_parameters
            huggingfacehub_api_token=api_key,
        )

        try:
            self.chat.generate(["Hello, I am a biomedical researcher."])
            return True
        except ValueError:
            return False

    def _cast_messages(self, messages):
        """Render the different roles of the chat-based conversation as plain text.

        Each message becomes one ``"<Role>: <content>\\n"`` line.

        Raises
        ------
            ValueError: If a message is not a System, Human or AI message.

        """
        lines = []
        for m in messages:
            if isinstance(m, SystemMessage):
                role = "System"
            elif isinstance(m, HumanMessage):
                role = "Human"
            elif isinstance(m, AIMessage):
                role = "AI"
            else:
                raise ValueError(f"Unknown message type: {type(m)}")
            lines.append(f"{role}: {m.content}\n")

        # join once instead of repeated string concatenation
        return "".join(lines)

    def _primary_query(self):
        """Send the flattened conversation to the model and record the reply.

        Returns
        -------
            tuple: The generated text and a token-usage dict with all counts
                set to 0 (usage is not reported for this backend).

        """
        response = self.chat.generate([self._cast_messages(self.messages)])

        msg = response.generations[0][0].text
        token_usage = {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        }

        self.append_ai_message(msg)

        return msg, token_usage

    def _correct_response(self, msg: str):
        """Correction is not supported; always reports nothing to correct."""
        return "ok"

__init__(model_name, prompts, split_correction)

DEPRECATED: Superseded by XinferenceConversation.

Source code in biochatter/llm_connect.py
def __init__(
    self,
    model_name: str,
    prompts: dict,
    split_correction: bool,
):
    """DEPRECATED: Superseded by XinferenceConversation."""
    parent_kwargs = {
        "model_name": model_name,
        "prompts": prompts,
        "split_correction": split_correction,
    }
    super().__init__(**parent_kwargs)

    # Start from an empty message list.
    self.messages = []

Conversation

Bases: ABC

Use this class to set up a connection to an LLM API. Can be used to set the user name and API key, append specific messages for system, user, and AI roles (if available), set up the general context as well as manual and tool-based data inputs, and finally to query the API with prompts made by the user.

The conversation class is expected to have a messages attribute to store the conversation, and a history attribute, which is a list of messages in a specific format for logging / printing.

Source code in biochatter/llm_connect.py
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
class Conversation(ABC):
    """Use this class to set up a connection to an LLM API. Can be used to set the
    user name and API key, append specific messages for system, user, and AI
    roles (if available), set up the general context as well as manual and
    tool-based data inputs, and finally to query the API with prompts made by
    the user.

    The conversation class is expected to have a `messages` attribute to store
    the conversation, and a `history` attribute, which is a list of messages in
    a specific format for logging / printing.

    """

    def __init__(
        self,
        model_name: str,
        prompts: dict,
        correct: bool = False,
        split_correction: bool = False,
        use_ragagent_selector: bool = False,
    ):
        """Initialise an empty conversation.

        Args:
        ----
            model_name (str): The name of the model to use.

            prompts (dict): Prompt sets, read via the keys
                "primary_model_prompts", "correcting_agent_prompts",
                "tool_prompts" and "rag_agent_prompts".

            correct (bool): Whether to correct the model output.

            split_correction (bool): Whether to correct the output sentence
                by sentence.

            use_ragagent_selector (bool): Whether a RagAgentSelector picks the
                RAG agent that supplies context.

        """
        super().__init__()
        self.model_name = model_name
        self.prompts = prompts
        self.correct = correct
        self.split_correction = split_correction
        self.rag_agents: list[RagAgent] = []
        self.history = []
        self.messages = []
        self.ca_messages = []
        self.current_statements = []
        self._use_ragagent_selector = use_ragagent_selector

    @property
    def use_ragagent_selector(self):
        """Whether RAG context is chosen by a RagAgentSelector."""
        return self._use_ragagent_selector

    @use_ragagent_selector.setter
    def use_ragagent_selector(self, val: bool):
        self._use_ragagent_selector = val

    def set_user_name(self, user_name: str):
        self.user_name = user_name

    def set_rag_agent(self, agent: RagAgent):
        """Update or insert rag_agent: if the rag_agent with the same mode already
        exists, it will be updated. Otherwise, the new rag_agent will be inserted.
        """
        i, _ = self.find_rag_agent(agent.mode)
        if i < 0:
            # insert
            self.rag_agents.append(agent)
        else:
            # update
            self.rag_agents[i] = agent

    def find_rag_agent(self, mode: str) -> tuple[int, RagAgent | None]:
        """Return ``(index, agent)`` for the first agent with ``mode``, or
        ``(-1, None)`` if no such agent is registered.
        """
        for i, val in enumerate(self.rag_agents):
            if val.mode == mode:
                return i, val
        return -1, None

    @abstractmethod
    def set_api_key(self, api_key: str, user: str | None = None):
        pass

    def get_prompts(self):
        return self.prompts

    def set_prompts(self, prompts: dict):
        self.prompts = prompts

    def append_ai_message(self, message: str) -> None:
        """Add a message from the AI to the conversation.

        Args:
        ----
            message (str): The message from the AI.

        """
        self.messages.append(
            AIMessage(
                content=message,
            ),
        )

    def append_system_message(self, message: str) -> None:
        """Add a system message to the conversation.

        Args:
        ----
            message (str): The system message.

        """
        self.messages.append(
            SystemMessage(
                content=message,
            ),
        )

    def append_ca_message(self, message: str) -> None:
        """Add a message to the correcting agent conversation.

        Args:
        ----
            message (str): The message to the correcting agent.

        """
        self.ca_messages.append(
            SystemMessage(
                content=message,
            ),
        )

    def append_user_message(self, message: str) -> None:
        """Add a message from the user to the conversation.

        Args:
        ----
            message (str): The message from the user.

        """
        self.messages.append(
            HumanMessage(
                content=message,
            ),
        )

    def append_image_message(
        self,
        message: str,
        image_url: str,
        local: bool = False,
    ) -> None:
        """Add a user message with an image to the conversation. Also checks, in
        addition to the `local` flag, if the image URL is a local file path.
        If it is local, the image will be encoded as a base64 string to be
        passed to the LLM.

        Args:
        ----
            message (str): The message from the user.
            image_url (str): The URL of the image.
            local (bool): Whether the image is local or not. If local, it will
                be encoded as a base64 string to be passed to the LLM.

        """
        parsed_url = urllib.parse.urlparse(image_url)
        # A URL without a network location is treated as a local file path.
        if local or not parsed_url.netloc:
            image_url = f"data:image/jpeg;base64,{encode_image(image_url)}"
        else:
            image_url = f"data:image/jpeg;base64,{encode_image_from_url(image_url)}"

        self.messages.append(
            HumanMessage(
                content=[
                    {"type": "text", "text": message},
                    {"type": "image_url", "image_url": {"url": image_url}},
                ],
            ),
        )

    def setup(self, context: str):
        """Set up the conversation with general prompts and a context."""
        for msg in self.prompts["primary_model_prompts"]:
            if msg:
                self.append_system_message(msg)

        for msg in self.prompts["correcting_agent_prompts"]:
            if msg:
                self.append_ca_message(msg)

        self.context = context
        msg = f"The topic of the research is {context}."
        self.append_system_message(msg)

    def setup_data_input_manual(self, data_input: str):
        """Store a manual description of the data input and add it as a system message."""
        self.data_input = data_input
        msg = f"The user has given information on the data input: {data_input}."
        self.append_system_message(msg)

    def setup_data_input_tool(self, df, input_file_name: str):
        """Store tool output and add the matching tool prompt(s) as system messages.

        Every key of ``prompts["tool_prompts"]`` that occurs in
        ``input_file_name`` contributes its prompt, formatted with ``df``.
        """
        self.data_input_tool = df

        for tool_name in self.prompts["tool_prompts"]:
            if tool_name in input_file_name:
                msg = self.prompts["tool_prompts"][tool_name].format(df=df)
                self.append_system_message(msg)

    def query(self, text: str, image_url: str | None = None) -> tuple[str, dict, str]:
        """The main workflow for querying the LLM API. Appends the most recent
        query to the conversation, optionally injects context from the RAG
        agent, and runs the primary query method of the child class.

        Args:
        ----
            text (str): The user query.

            image_url (str | None): The URL of an image to include in the
                conversation. Optional and only supported for models with
                vision capabilities.

        Returns:
        -------
            tuple: A tuple containing the response from the API, the token usage
                information, and the correction if necessary/desired.

        """
        if not image_url:
            self.append_user_message(text)
        else:
            self.append_image_message(text, image_url)

        self._inject_context(text)

        msg, token_usage = self._primary_query()

        if not token_usage:
            # indicates error
            return (msg, token_usage, None)

        if not self.correct:
            return (msg, token_usage, None)

        cor_msg = "Correcting (using single sentences) ..." if self.split_correction else "Correcting ..."

        # Correction checks the model's answer (msg), not the user's query.
        if st:
            with st.spinner(cor_msg):
                corrections = self._correct_query(msg)
        else:
            corrections = self._correct_query(msg)

        if not corrections:
            return (msg, token_usage, None)

        correction = "\n".join(corrections)
        return (msg, token_usage, correction)

    def _correct_query(self, msg: str):
        """Run the correcting agent over ``msg`` (whole, or sentence by sentence
        when ``split_correction`` is set).

        Returns
        -------
            list: Corrections returned by ``_correct_response``; replies of
                "ok" / "ok." (case-insensitive) mean nothing to correct and
                are dropped.

        """
        if self.split_correction:
            nltk.download("punkt")
            tokenizer = nltk.data.load("tokenizers/punkt/english.pickle")
            chunks = tokenizer.tokenize(msg)
        else:
            chunks = [msg]

        corrections = []
        for chunk in chunks:
            correction = self._correct_response(chunk)
            if str(correction).lower() not in ["ok", "ok."]:
                corrections.append(correction)

        return corrections

    @abstractmethod
    def _primary_query(self, text: str):
        """Query the model; ``query`` calls this with no arguments and expects
        ``(message, token_usage)``, where falsy token usage signals an error.
        """

    @abstractmethod
    def _correct_response(self, msg: str):
        pass

    def _inject_context_by_ragagent_selector(self, text: str):
        """Inject the context generated by RagAgentSelector, which will choose appropriate
        rag agent to generate context according to user's question

        Args:
        ----
            text (str): The user query to be used for choosing rag agent

        Returns:
        -------
            The selector's tool result if non-empty; otherwise the responses
            of the selected agent, or None if no agent matches the selector's
            answer.

        """
        rag_agents: list[RagAgent] = [agent for agent in self.rag_agents if agent.use_prompt]
        decider_agent = RagAgentSelector(
            rag_agents=rag_agents,
            conversation_factory=lambda: self,
        )
        result = decider_agent.execute(text)
        if result.tool_result is not None and len(result.tool_result) > 0:
            return result.tool_result
        # Find the selected rag agent. next() needs an iterator, so a generator
        # expression is used (next() on a list raises TypeError).
        rag_agent = next(
            (agent for agent in rag_agents if agent.mode == result.answer),
            None,
        )
        if rag_agent is None:
            return None
        # NOTE(review): this returns raw agent responses, while
        # _collect_rag_statements keeps only doc[0] per document — confirm
        # the two paths are meant to differ.
        return rag_agent.generate_responses(text)

    def _collect_rag_statements(self, text: str):
        """Gather context statements for ``text`` from the RAG agents.

        Uses the RagAgentSelector when enabled; otherwise queries every agent
        and keeps the first element of each returned document. Agents raising
        ValueError are logged and skipped.
        """
        if self.use_ragagent_selector:
            return self._inject_context_by_ragagent_selector(text)

        statements = []
        for agent in self.rag_agents:
            try:
                docs = agent.generate_responses(text)
                statements = statements + [doc[0] for doc in docs]
            except ValueError as e:
                logger.warning(e)
        return statements

    def _inject_context(self, text: str):
        """Inject the context received from the RAG agent into the prompt. The RAG
        agent will find the most similar n text fragments and add them to the
        message history object for usage in the next prompt. Uses the document
        summarisation prompt set to inject the context. The ultimate prompt
        should include the placeholder for the statements, `{statements}` (used
        for formatting the string).

        Args:
        ----
            text (str): The user query to be used for similarity search.

        """
        sim_msg = "Performing similarity search to inject fragments ..."

        # Same collection logic either way; the spinner is UI-only.
        if st:
            with st.spinner(sim_msg):
                statements = self._collect_rag_statements(text)
        else:
            statements = self._collect_rag_statements(text)

        if not statements:
            return

        prompts = self.prompts["rag_agent_prompts"]
        self.current_statements = statements
        for i, prompt in enumerate(prompts):
            # if last prompt, format the statements into the prompt
            if i == len(prompts) - 1:
                self.append_system_message(
                    prompt.format(statements=statements),
                )
            else:
                self.append_system_message(prompt)

    def get_last_injected_context(self) -> list[dict]:
        """Get a formatted list of the last context injected into the
        conversation. Contains one dictionary for each RAG mode.

        Returns
        -------
            List[dict]: A list of dictionaries containing the mode and context
            for each RAG agent.

        """
        last_context = []
        for agent in self.rag_agents:
            last_context.append(
                {"mode": agent.mode, "context": agent.last_response},
            )
        return last_context

    def get_msg_json(self) -> str:
        """Return a JSON representation (of a list of dicts) of the messages in
        the conversation. The keys of the dicts are the roles, the values are
        the messages.

        Returns
        -------
            str: A JSON representation of the messages in the conversation.

        """
        d = []
        for msg in self.messages:
            if isinstance(msg, SystemMessage):
                role = "system"
            elif isinstance(msg, HumanMessage):
                role = "user"
            elif isinstance(msg, AIMessage):
                role = "ai"
            else:
                raise ValueError(f"Unknown message type: {type(msg)}")

            d.append({role: msg.content})

        return json.dumps(d)

    def reset(self):
        """Resets the conversation to the initial state."""
        self.history = []
        self.messages = []
        self.ca_messages = []
        self.current_statements = []

append_ai_message(message)

Add a message from the AI to the conversation.


message (str): The message from the AI.
Source code in biochatter/llm_connect.py
def append_ai_message(self, message: str) -> None:
    """Record a message produced by the AI in the conversation history.

    Args:
    ----
        message (str): Text returned by the model.

    """
    ai_turn = AIMessage(content=message)
    self.messages.append(ai_turn)

append_ca_message(message)

Add a message to the correcting agent conversation.


message (str): The message to the correcting agent.
Source code in biochatter/llm_connect.py
def append_ca_message(self, message: str) -> None:
    """Add a system message to the correcting agent's conversation.

    Args:
    ----
        message (str): Instruction text for the correcting agent.

    """
    instruction = SystemMessage(content=message)
    self.ca_messages.append(instruction)

append_image_message(message, image_url, local=False)

Add a user message with an image to the conversation. In addition to the local flag, the method checks whether the image URL is a local file path. Local images are read from disk and remote images are downloaded; either way, the image is passed to the LLM as a base64-encoded string.


message (str): The message from the user.
image_url (str): The URL of the image.
local (bool): Whether the image is local or not. If local, it will
    be encoded as a base64 string to be passed to the LLM.
Source code in biochatter/llm_connect.py
def append_image_message(
    self,
    message: str,
    image_url: str,
    local: bool = False,
) -> None:
    """Append a user message that carries both text and an image.

    The image is always embedded as a base64 ``data:`` URL. It is read from
    disk when ``local`` is set or when ``image_url`` has no network location
    (i.e. it looks like a file path); otherwise it is downloaded from the URL.

    Args:
    ----
        message (str): The text of the user's message.
        image_url (str): Location of the image, either a URL or a local path.
        local (bool): Force treating ``image_url`` as a local file, even if
            it parses as a URL.

    """
    parsed = urllib.parse.urlparse(image_url)
    from_disk = local or not parsed.netloc
    encoder = encode_image if from_disk else encode_image_from_url
    data_url = f"data:image/jpeg;base64,{encoder(image_url)}"

    text_part = {"type": "text", "text": message}
    image_part = {"type": "image_url", "image_url": {"url": data_url}}
    self.messages.append(HumanMessage(content=[text_part, image_part]))

append_system_message(message)

Add a system message to the conversation.


message (str): The system message.
Source code in biochatter/llm_connect.py
def append_system_message(self, message: str) -> None:
    """Add a system message to the primary conversation.

    Args:
    ----
        message (str): The system prompt text.

    """
    system_turn = SystemMessage(content=message)
    self.messages.append(system_turn)

append_user_message(message)

Add a message from the user to the conversation.


message (str): The message from the user.
Source code in biochatter/llm_connect.py
def append_user_message(self, message: str) -> None:
    """Record a message written by the user in the conversation history.

    Args:
    ----
        message (str): Text entered by the user.

    """
    user_turn = HumanMessage(content=message)
    self.messages.append(user_turn)

get_last_injected_context()

Get a formatted list of the last context injected into the conversation. Contains one dictionary for each RAG mode.

Returns
List[dict]: A list of dictionaries containing the mode and context
for each RAG agent.
Source code in biochatter/llm_connect.py
def get_last_injected_context(self) -> list[dict]:
    """Summarise the context each RAG agent injected most recently.

    Returns
    -------
        List[dict]: One ``{"mode": ..., "context": ...}`` dictionary per RAG
        agent, in registration order; ``context`` is the agent's
        ``last_response``.

    """

    def describe(rag_agent) -> dict:
        return {"mode": rag_agent.mode, "context": rag_agent.last_response}

    return list(map(describe, self.rag_agents))

get_msg_json()

Return a JSON representation (of a list of dicts) of the messages in the conversation. The keys of the dicts are the roles, the values are the messages.

Returns
str: A JSON representation of the messages in the conversation.
Source code in biochatter/llm_connect.py
def get_msg_json(self) -> str:
    """Return a JSON representation (of a list of dicts) of the messages in
    the conversation. The keys of the dicts are the roles, the values are
    the messages.

    Returns
    -------
        str: A JSON representation of the messages in the conversation.

    """
    d = []
    for msg in self.messages:
        if isinstance(msg, SystemMessage):
            role = "system"
        elif isinstance(msg, HumanMessage):
            role = "user"
        elif isinstance(msg, AIMessage):
            role = "ai"
        else:
            raise ValueError(f"Unknown message type: {type(msg)}")

        d.append({role: msg.content})

    return json.dumps(d)

query(text, image_url=None)

The main workflow for querying the LLM API. Appends the most recent query to the conversation, optionally injects context from the RAG agent, and runs the primary query method of the child class.


Parameters:

text (str): The user query.

image_url (str): The URL of an image to include in the conversation.
    Optional and only supported for models with vision capabilities.

Returns:

tuple: A tuple containing the response from the API, the token usage
    information, and the correction if necessary/desired.
Source code in biochatter/llm_connect.py
def query(self, text: str, image_url: str = None) -> tuple[str, dict, str]:
    """The main workflow for querying the LLM API. Appends the most recent
    query to the conversation, optionally injects context from the RAG
    agent, and runs the primary query method of the child class.

    Args:
    ----
        text (str): The user query.

        image_url (str): The URL of an image to include in the conversation.
            Optional and only supported for models with vision capabilities.

    Returns:
    -------
        tuple: A tuple containing the response from the API, the token usage
            information, and the correction if necessary/desired.

    """
    if not image_url:
        self.append_user_message(text)
    else:
        self.append_image_message(text, image_url)

    self._inject_context(text)

    msg, token_usage = self._primary_query()

    if not token_usage:
        # indicates error
        return (msg, token_usage, None)

    if not self.correct:
        return (msg, token_usage, None)

    cor_msg = "Correcting (using single sentences) ..." if self.split_correction else "Correcting ..."

    if st:
        with st.spinner(cor_msg):
            corrections = self._correct_query(text)
    else:
        corrections = self._correct_query(text)

    if not corrections:
        return (msg, token_usage, None)

    correction = "\n".join(corrections)
    return (msg, token_usage, correction)

reset()

Resets the conversation to the initial state.

Source code in biochatter/llm_connect.py
def reset(self):
    """Resets the conversation to the initial state.

    The history, both message lists and the current statements are each
    replaced by a new, independent empty list.
    """
    self.history, self.messages, self.ca_messages, self.current_statements = [], [], [], []

set_rag_agent(agent)

Update or insert rag_agent: if the rag_agent with the same mode already exists, it will be updated. Otherwise, the new rag_agent will be inserted.

Source code in biochatter/llm_connect.py
def set_rag_agent(self, agent: RagAgent):
    """Register a RAG agent, replacing any existing agent with the same mode.

    Args:
    ----
        agent (RagAgent): The agent to insert, or to use as the replacement
            for the agent already registered under ``agent.mode``.

    """
    index, _ = self.find_rag_agent(agent.mode)
    if index >= 0:
        # An agent with this mode is already registered: replace it in place.
        self.rag_agents[index] = agent
        return
    self.rag_agents.append(agent)

setup(context)

Set up the conversation with general prompts and a context.

Source code in biochatter/llm_connect.py
def setup(self, context: str):
    """Set up the conversation with general prompts and a context."""
    for msg in self.prompts["primary_model_prompts"]:
        if msg:
            self.append_system_message(msg)

    for msg in self.prompts["correcting_agent_prompts"]:
        if msg:
            self.append_ca_message(msg)

    self.context = context
    msg = f"The topic of the research is {context}."
    self.append_system_message(msg)

GptConversation

Bases: Conversation

Source code in biochatter/llm_connect.py
class GptConversation(Conversation):
    def __init__(
        self,
        model_name: str,
        prompts: dict,
        correct: bool = False,
        split_correction: bool = False,
        base_url: str | None = None,
    ):
        """Connect to OpenAI's GPT API and set up a conversation with the user.
        Also initialise a second conversational agent to provide corrections to
        the model output, if necessary.

        Args:
        ----
            model_name (str): The name of the model to use.

            prompts (dict): A dictionary of prompts to use for the conversation.

            correct (bool): Whether to correct the model output.

            split_correction (bool): Whether to correct the model output by
                splitting the output into sentences and correcting each
                sentence individually.

            base_url (str | None): Optional OpenAI base_url value to use a
                custom endpoint URL instead of the default.

        """
        super().__init__(
            model_name=model_name,
            prompts=prompts,
            correct=correct,
            split_correction=split_correction,
        )
        self.base_url = base_url
        self.ca_model_name = "gpt-3.5-turbo"
        # TODO make accessible by drop-down

    def set_api_key(self, api_key: str, user: str) -> bool:
        """Set the API key for the OpenAI API. If the key is valid, initialise the
        conversational agent. Set the user for usage statistics.

        Args:
        ----
            api_key (str): The API key for the OpenAI API.

            user (str): The user for usage statistics.

        Returns:
        -------
            bool: True if the API key is valid, False otherwise.

        """
        client = openai.OpenAI(
            api_key=api_key,
            base_url=self.base_url,
        )
        self.user = user

        try:
            # Listing models is used to validate the key: an invalid key
            # raises AuthenticationError here.
            client.models.list()
            self.chat = ChatOpenAI(
                model_name=self.model_name,
                temperature=0,
                openai_api_key=api_key,
                base_url=self.base_url,
            )
            self.ca_chat = ChatOpenAI(
                model_name=self.ca_model_name,
                temperature=0,
                openai_api_key=api_key,
                base_url=self.base_url,
            )
            if user == "community":
                self.usage_stats = get_stats(user=user)

            return True

        except openai._exceptions.AuthenticationError:
            return False

    def _primary_query(self):
        """Query the OpenAI API with the conversation so far (system prompts
        and prior messages) and record the response.

        Returns
        -------
            tuple: ``(message, token_usage)``. On an OpenAI error, the message
                is the error text and ``token_usage`` is ``None``.

        """
        try:
            response = self.chat.generate([self.messages])
        except openai._exceptions.OpenAIError as e:
            # OpenAIError is the base class of every specific OpenAI exception
            # (APIError, RateLimitError, AuthenticationError, ...), so catching
            # it covers all of them.
            return str(e), None

        msg = response.generations[0][0].text
        # ``llm_output`` can be missing or lack token usage (e.g. with some
        # custom ``base_url`` endpoints); treat that as "no usage reported".
        token_usage = (response.llm_output or {}).get("token_usage")

        self._update_usage_stats(self.model_name, token_usage)

        self.append_ai_message(msg)

        return msg, token_usage

    def _correct_response(self, msg: str):
        """Ask the correcting agent to review a (part of a) model response.

        Splitting into sentences, if enabled, happens before this method is
        called. Usage statistics are updated with the agent's token usage.

        Args:
        ----
            msg (str): The text to review.

        Returns:
        -------
            str: The correction, or ``OK`` if no correction is necessary.

        """
        ca_messages = self.ca_messages.copy()
        ca_messages.append(
            HumanMessage(
                content=msg,
            ),
        )
        ca_messages.append(
            SystemMessage(
                content="If there is nothing to correct, please respond with just 'OK', and nothing else!",
            ),
        )

        response = self.ca_chat.generate([ca_messages])

        correction = response.generations[0][0].text
        token_usage = (response.llm_output or {}).get("token_usage")

        self._update_usage_stats(self.ca_model_name, token_usage)

        return correction

    def _update_usage_stats(self, model: str, token_usage: dict | None):
        """Update redis database with token usage statistics using the usage_stats
        object with the increment method.

        Only the ``community`` user is tracked. Nothing is recorded when the
        API reported no token usage (``None`` or empty), which previously
        raised ``AttributeError``.

        Args:
        ----
            model (str): The model name.

            token_usage (dict | None): The token usage statistics.

        """
        if self.user != "community" or not token_usage:
            return
        self.usage_stats.increment(
            "usage:[date]:[user]",
            {f"{k}:{model}": v for k, v in token_usage.items()},
        )

__init__(model_name, prompts, correct=False, split_correction=False, base_url=None)

Connect to OpenAI's GPT API and set up a conversation with the user. Also initialise a second conversational agent to provide corrections to the model output, if necessary.


model_name (str): The name of the model to use.

prompts (dict): A dictionary of prompts to use for the conversation.

split_correction (bool): Whether to correct the model output by
    splitting the output into sentences and correcting each
    sentence individually.

base_url (str): Optional OpenAI base_url value to use custom
    endpoint URL instead of default
Source code in biochatter/llm_connect.py
def __init__(
    self,
    model_name: str,
    prompts: dict,
    correct: bool = False,
    split_correction: bool = False,
    base_url: str = None,
):
    """Set up a conversation backed by OpenAI's GPT API.

    A second, correcting agent (``gpt-3.5-turbo``) is configured to review
    the primary model's output when correction is enabled.

    Args:
    ----
        model_name (str): Name of the primary model.

        prompts (dict): Prompts to use for the conversation.

        correct (bool): Whether to run the correcting agent on responses.

        split_correction (bool): Whether to correct the output sentence by
            sentence rather than as a whole.

        base_url (str): Optional custom endpoint URL for the OpenAI API; the
            default endpoint is used when omitted.

    """
    parent_kwargs = {
        "model_name": model_name,
        "prompts": prompts,
        "correct": correct,
        "split_correction": split_correction,
    }
    super().__init__(**parent_kwargs)
    self.base_url = base_url
    self.ca_model_name = "gpt-3.5-turbo"

set_api_key(api_key, user)

Set the API key for the OpenAI API. If the key is valid, initialise the conversational agent. Set the user for usage statistics.


Parameters:

api_key (str): The API key for the OpenAI API.

user (str): The user for usage statistics.

Returns:

bool: True if the API key is valid, False otherwise.
Source code in biochatter/llm_connect.py
def set_api_key(self, api_key: str, user: str) -> bool:
    """Validate an OpenAI API key and, if it is accepted, create the chat models.

    The user is stored for usage statistics; for the ``community`` user a
    usage-statistics handle is also created.

    Args:
    ----
        api_key (str): The API key for the OpenAI API.

        user (str): The user for usage statistics.

    Returns:
    -------
        bool: ``True`` if the key was accepted, ``False`` if authentication
            failed.

    """
    client = openai.OpenAI(api_key=api_key, base_url=self.base_url)
    self.user = user

    try:
        # An invalid key makes this request raise AuthenticationError.
        client.models.list()
        chat_options = {
            "temperature": 0,
            "openai_api_key": api_key,
            "base_url": self.base_url,
        }
        self.chat = ChatOpenAI(model_name=self.model_name, **chat_options)
        self.ca_chat = ChatOpenAI(model_name=self.ca_model_name, **chat_options)
        if user == "community":
            self.usage_stats = get_stats(user=user)
    except openai._exceptions.AuthenticationError:
        return False
    return True

OllamaConversation

Bases: Conversation

Source code in biochatter/llm_connect.py
class OllamaConversation(Conversation):
    def set_api_key(self, api_key: str, user: str | None = None) -> bool:
        """No API key is needed to talk to the Ollama instance.

        Returns ``True``, like ``WasmConversation.set_api_key``, so callers
        that check the boolean result treat the conversation as ready.

        Args:
        ----
            api_key (str): Ignored.

            user (str | None): Ignored.

        Returns:
        -------
            bool: Always ``True``.

        """
        return True

    def __init__(
        self,
        base_url: str,
        prompts: dict,
        model_name: str = "llama3",
        correct: bool = False,
        split_correction: bool = False,
    ):
        """Connect to an Ollama LLM via the Ollama/Langchain library and set
        up a conversation with the user. Also initialise a second
        conversational agent to provide corrections to the model output, if
        necessary.

        Args:
        ----
            base_url (str): The base URL of the Ollama instance.

            prompts (dict): A dictionary of prompts to use for the conversation.

            model_name (str): The name of the model to use. Can be any model
                name available in your Ollama instance.

            correct (bool): Whether to correct the model output.

            split_correction (bool): Whether to correct the model output by
                splitting the output into sentences and correcting each
                sentence individually.

        """
        super().__init__(
            model_name=model_name,
            prompts=prompts,
            correct=correct,
            split_correction=split_correction,
        )
        self.model_name = model_name
        self.model = ChatOllama(
            base_url=base_url,
            model=self.model_name,
            temperature=0.0,
        )

        self.ca_model_name = "mixtral:latest"

        # ChatOllama selects its model through the ``model`` field (as for the
        # primary model above); passing ``model_name`` does not select
        # ``ca_model_name`` for the correcting agent.
        self.ca_model = ChatOllama(
            base_url=base_url,
            model=self.ca_model_name,
            temperature=0.0,
        )

    @staticmethod
    def _merge_system_message(messages: list, message: str) -> None:
        """Merge ``message`` into the first system message of ``messages``,
        or append it as a new system message if there is none.
        """
        for existing in messages:
            if isinstance(existing, SystemMessage):
                existing.content += f"\n{message}"
                return
        messages.append(SystemMessage(content=message))

    def append_system_message(self, message: str):
        """We override the system message addition because Ollama does not
        accept multiple system messages. If the conversation already has a
        system message, the new text is appended to it on a new line.

        Args:
        ----
            message (str): The message to append.

        """
        self._merge_system_message(self.messages, message)

    def append_ca_message(self, message: str):
        """We also override the system message addition for the correcting agent,
        likewise because Ollama does not accept multiple system messages. If
        the correcting agent's conversation already has a system message, the
        new text is appended to it on a new line.

        Args:
        ----
            message (str): The message to append.

        """
        self._merge_system_message(self.ca_messages, message)

    def _primary_query(self):
        """Query the Ollama model with the conversation so far (system prompts
        and prior messages) and record the response.

        Returns
        -------
            tuple: ``(message, token_usage)`` where ``token_usage`` is the
            ``eval_count`` from the response metadata. On a caught error, the
            message is the error text and ``token_usage`` is ``None``.

        """
        try:
            messages = self._create_history(self.messages)
            response = self.model.invoke(messages)
        except openai._exceptions.OpenAIError as e:
            # OpenAIError is the base of every OpenAI exception the original
            # listed. NOTE(review): ChatOllama is not an OpenAI client, so its
            # errors are presumably not OpenAIError subclasses -- confirm
            # which exceptions should be caught here.
            return str(e), None
        response_dict = response.dict()
        msg = response_dict["content"]
        token_usage = response_dict["response_metadata"]["eval_count"]

        self._update_usage_stats(self.model_name, token_usage)

        self.append_ai_message(msg)

        return msg, token_usage

    def _create_history(self, messages):
        """Rebuild ``messages`` as plain AI/human/system messages.

        Only the ``content`` of each message is copied; messages of any other
        type are dropped.
        """
        history = []
        for message in messages:
            if isinstance(message, AIMessage):
                history.append(AIMessage(content=message.content))
            elif isinstance(message, HumanMessage):
                history.append(HumanMessage(content=message.content))
            elif isinstance(message, SystemMessage):
                history.append(SystemMessage(content=message.content))

        return history

    def _correct_response(self, msg: str):
        """Ask the correcting model to review a (part of a) model response.

        The correcting agent's own messages plus ``msg`` are sent to the
        correcting model. Usage statistics are updated with its token usage.

        Args:
        ----
            msg (str): The text to review.

        Returns:
        -------
            str: The correction, or ``OK`` if no correction is necessary.

        """
        ca_messages = self.ca_messages.copy()
        ca_messages.append(
            HumanMessage(
                content=msg,
            ),
        )
        # NOTE(review): this adds a second system message, although the class
        # merges system messages because Ollama does not accept several.
        ca_messages.append(
            SystemMessage(
                content="If there is nothing to correct, please respond with just 'OK', and nothing else!",
            ),
        )
        # Send the correcting agent's conversation (not the primary history)
        # as the positional input, as in ``_primary_query``.
        response = self.ca_model.invoke(
            self._create_history(ca_messages),
        ).dict()
        correction = response["content"]
        # Same location as in ``_primary_query``.
        token_usage = response["response_metadata"]["eval_count"]

        self._update_usage_stats(self.ca_model_name, token_usage)

        return correction

    def _update_usage_stats(self, model: str, token_usage):
        """Hook for recording token usage statistics.

        Currently a no-op: usage statistics are not recorded for Ollama
        models.

        Args:
        ----
            model (str): The model name.

            token_usage: The ``eval_count`` value from the Ollama response.

        """

__init__(base_url, prompts, model_name='llama3', correct=False, split_correction=False)

Connect to an Ollama LLM via the Ollama/Langchain library and set up a conversation with the user. Also initialise a second conversational agent to provide corrections to the model output, if necessary.


base_url (str): The base URL of the Ollama instance.

prompts (dict): A dictionary of prompts to use for the conversation.

model_name (str): The name of the model to use. Can be any model name available in your Ollama instance.

correct (bool): Whether to correct the model output.

split_correction (bool): Whether to correct the model output by
    splitting the output into sentences and correcting each sentence
    individually.
Source code in biochatter/llm_connect.py
def __init__(
    self,
    base_url: str,
    prompts: dict,
    model_name: str = "llama3",
    correct: bool = False,
    split_correction: bool = False,
):
    """Connect to an Ollama LLM via the Ollama/Langchain library and set
    up a conversation with the user. Also initialise a second
    conversational agent to provide corrections to the model output, if
    necessary.

    Args:
    ----
        base_url (str): The base URL of the Ollama instance.

        prompts (dict): A dictionary of prompts to use for the conversation.

        model_name (str): The name of the model to use. Can be any model
            name available in your Ollama instance.

        correct (bool): Whether to correct the model output.

        split_correction (bool): Whether to correct the model output by
            splitting the output into sentences and correcting each
            sentence individually.

    """
    super().__init__(
        model_name=model_name,
        prompts=prompts,
        correct=correct,
        split_correction=split_correction,
    )
    self.model_name = model_name
    self.model = ChatOllama(
        base_url=base_url,
        model=self.model_name,
        temperature=0.0,
    )

    self.ca_model_name = "mixtral:latest"

    # ChatOllama selects its model through the ``model`` field (as for the
    # primary model above); passing ``model_name`` does not select
    # ``ca_model_name`` for the correcting agent.
    self.ca_model = ChatOllama(
        base_url=base_url,
        model=self.ca_model_name,
        temperature=0.0,
    )

append_ca_message(message)

We also override the system message addition for the correcting agent, likewise because Ollama does not accept multiple system messages. We concatenate them if there are multiple.

TODO this currently assumes that the correcting agent is the same model as the primary one.


message (str): The message to append.
Source code in biochatter/llm_connect.py
def append_ca_message(self, message: str):
    """Add a system message to the correcting agent, merging if one exists.

    Ollama does not accept multiple system messages, so when
    ``self.ca_messages`` already contains a system message, ``message`` is
    appended to the first such message on a new line. Otherwise a new system
    message is added.

    TODO this currently assumes that the correcting agent is the same model
    as the primary one.

    Args:
    ----
        message (str): The message to append.

    """
    existing = next(
        (entry for entry in self.ca_messages if isinstance(entry, SystemMessage)),
        None,
    )
    if existing is None:
        self.ca_messages.append(SystemMessage(content=message))
    else:
        existing.content += f"\n{message}"

append_system_message(message)

We override the system message addition because Ollama does not accept multiple system messages. We concatenate them if there are multiple.


message (str): The message to append.
Source code in biochatter/llm_connect.py
def append_system_message(self, message: str):
    """Add a system message to the conversation, merging if one exists.

    Ollama does not accept multiple system messages, so when
    ``self.messages`` already contains a system message, ``message`` is
    appended to the first such message on a new line. Otherwise a new system
    message is added.

    Args:
    ----
        message (str): The message to append.

    """
    for entry in self.messages:
        if isinstance(entry, SystemMessage):
            entry.content += f"\n{message}"
            break
    else:
        # No system message yet: start one.
        self.messages.append(SystemMessage(content=message))

WasmConversation

Bases: Conversation

Source code in biochatter/llm_connect.py
class WasmConversation(Conversation):
    def __init__(
        self,
        model_name: str,
        prompts: dict,
        correct: bool = False,
        split_correction: bool = False,
    ):
        """Conversation whose queries are rendered for a wasm model in the frontend.

        No API is called. The message history is maintained like in the other
        conversation classes, and ``query`` returns the complete prompt as a
        plain string in the first tuple element. The second and third elements
        are ``None`` because the wasm model has no token usage or correction.

        """
        settings = {
            "model_name": model_name,
            "prompts": prompts,
            "correct": correct,
            "split_correction": split_correction,
        }
        super().__init__(**settings)

    def query(self, text: str) -> tuple:
        """Add the user query, inject RAG context and render the full prompt.

        Args:
        ----
            text (str): The user query.

        Returns:
        -------
            tuple: ``(prompt, None, None)`` where ``prompt`` is the whole
                message history joined into one string.

        """
        self.append_user_message(text)
        self._inject_context(text)
        full_prompt = self._primary_query()
        return (full_prompt, None, None)

    def _primary_query(self):
        """Join the contents of all messages with newlines.

        Message roles (system, user, ...) are not included in the output.
        """
        return "\n".join(message.content for message in self.messages)

    def _correct_response(self, msg: str):
        """Unused for the wasm model; always reports nothing to correct."""
        return "ok"

    def set_api_key(self, api_key: str, user: str | None = None):
        """Unused for the wasm model; any key is accepted."""
        return True

__init__(model_name, prompts, correct=False, split_correction=False)

This class is used to return the complete query as a string to be used in the frontend running the wasm model. It does not call the API itself, but updates the message history similarly to the other conversation classes. It overrides the query method from the Conversation class to return a plain string that contains the entire message for the model as the first element of the tuple. The second and third elements are None as there is no token usage or correction for the wasm model.

Source code in biochatter/llm_connect.py
def __init__(
    self,
    model_name: str,
    prompts: dict,
    correct: bool = False,
    split_correction: bool = False,
):
    """Set up a conversation whose prompt is executed by a wasm model in the frontend.

    This class never calls a model API. It keeps the message history up to
    date like the other conversation classes, and its `query` override
    returns a tuple whose first element is a plain string holding the entire
    message for the model. The second and third elements are `None`, since
    the wasm model reports neither token usage nor corrections.

    """
    conversation_settings = {
        "model_name": model_name,
        "prompts": prompts,
        "correct": correct,
        "split_correction": split_correction,
    }
    super().__init__(**conversation_settings)

query(text)

Return the entire message history as a single string. This is the message that is sent to the wasm model.


Parameters:
text (str): The user query.

Returns:
tuple: A tuple containing the message history as a single string,
    and `None` for the second and third elements of the tuple.
Source code in biochatter/llm_connect.py
def query(self, text: str) -> tuple:
    """Record the user query and build the flattened prompt for the wasm model.

    No model is called here: the query is added to the message history,
    context is injected for it, and the whole history is returned as one
    string for the frontend to run.

    Args:
    ----
        text (str): The user query.

    Returns:
    -------
        tuple: ``(prompt, None, None)`` where ``prompt`` is the message
            history joined into a single string; the two ``None`` slots
            stand in for token usage and correction.

    """
    self.append_user_message(text)
    self._inject_context(text)
    prompt = self._primary_query()
    return prompt, None, None

set_api_key(api_key, user=None)

This method is not used for the wasm model.

Source code in biochatter/llm_connect.py
def set_api_key(self, api_key: str, user: str | None = None):
    """This method is not used for the wasm model."""
    return True

XinferenceConversation

Bases: Conversation

Source code in biochatter/llm_connect.py
class XinferenceConversation(Conversation):
    """Conversation backed by models served through a Xinference instance.

    Both the primary model and the correcting agent (CA) are resolved by name
    from the models listed by the Xinference client.
    """

    def __init__(
        self,
        base_url: str,
        prompts: dict,
        model_name: str = "auto",
        correct: bool = False,
        split_correction: bool = False,
    ):
        """Connect to an open-source LLM via the Xinference client library and set
        up a conversation with the user. Also initialise a second
        conversational agent to provide corrections to the model output, if
        necessary.

        Args:
        ----
            base_url (str): The base URL of the Xinference instance (should not
            include the /v1 part).

            prompts (dict): A dictionary of prompts to use for the conversation.

            model_name (str): The name of the model to use. Will be mapped to
            the according uid from the list of available models. Can be set to
            "auto" to use the first available model.

            correct (bool): Whether to correct the model output.

            split_correction (bool): Whether to correct the model output by
            splitting the output into sentences and correcting each sentence
            individually.

        """
        # Shaohong: Please keep this xinference importing code here, so that,
        # we don't need to depend on xinference if we dont need it (xinference
        # is expensive to install)
        from xinference.client import Client

        super().__init__(
            model_name=model_name,
            prompts=prompts,
            correct=correct,
            split_correction=split_correction,
        )
        self.client = Client(base_url=base_url)

        # Maps model name -> model description dict (uid stored under "id").
        self.models = {}
        self.load_models()

        # The correcting agent starts out with the same model as the
        # conversation.
        # TODO make accessible by drop-down
        self.ca_model_name = model_name

        self.set_api_key()

    def load_models(self):
        """Populate ``self.models`` from the models listed by the Xinference client.

        Entries are keyed by each model's ``"model_name"``; the Xinference uid
        is written into the entry under ``"id"`` so it can later be passed to
        ``client.get_model``.
        """
        for model_uid, model in self.client.list_models().items():
            model["id"] = model_uid
            self.models[model["model_name"]] = model

    def append_system_message(self, message: str):
        """Add a system message, merging it into an existing one if present.

        Xinference does not accept multiple system messages, so instead of
        appending a second ``SystemMessage`` the text is concatenated
        (newline-separated) onto the first system message in ``self.messages``.

        Args:
        ----
            message (str): The message to append.

        """
        existing = next(
            (m for m in self.messages if isinstance(m, SystemMessage)),
            None,
        )
        if existing is None:
            self.messages.append(SystemMessage(content=message))
        else:
            # This method never creates a second system message, so the first
            # one found is normally the only one.
            existing.content += f"\n{message}"

    def append_ca_message(self, message: str):
        """Add a system message for the correcting agent, merging like above.

        Xinference does not accept multiple system messages, so the text is
        concatenated (newline-separated) onto the first system message in
        ``self.ca_messages`` if one exists.

        TODO this currently assumes that the correcting agent is the same model
        as the primary one.

        Args:
        ----
            message (str): The message to append.

        """
        existing = next(
            (m for m in self.ca_messages if isinstance(m, SystemMessage)),
            None,
        )
        if existing is None:
            self.ca_messages.append(SystemMessage(content=message))
        else:
            existing.content += f"\n{message}"

    def _primary_query(self):
        """Query the Xinference model with the user's message and return the
        response, using the message history (flattened system messages and
        prior conversation) as context.

        LLaMA2 architecture does not accept separate system messages, so we
        concatenate the system message with the user message to form the prompt.
        'LLaMA enforces a strict rule that chats should alternate
        user/assistant/user/assistant, and the system message, if present,
        should be embedded into the first user message.' (from
        https://discuss.huggingface.co/t/issue-with-llama-2-chat-template-and-out-of-date-documentation/61645/3)

        Returns
        -------
            tuple: ``(message, token_usage)`` taken from the response's first
            choice and ``"usage"`` field. If an OpenAI client exception is
            raised, ``(str(error), None)`` is returned instead.

        """
        history = self._create_history()
        # TODO this is for LLaMA2 arch, may be different for newer models
        # The last entry of the history is the current prompt.
        prompt = history.pop()
        try:
            response = self.model.chat(
                prompt=prompt["content"],
                chat_history=history,
                generate_config={"max_tokens": 2048, "temperature": 0},
            )
        except openai._exceptions.OpenAIError as e:
            # OpenAIError is the base class of every openai exception that was
            # previously enumerated here (APIError, RateLimitError, ...).
            return str(e), None

        msg = response["choices"][0]["message"]["content"]
        token_usage = response["usage"]

        self._update_usage_stats(self.model_name, token_usage)

        self.append_ai_message(msg)

        return msg, token_usage

    def _create_history(self):
        """Flatten ``self.messages`` into the chat-history format used by Xinference.

        If there is an AI message, all messages before the last one are joined
        into one user turn, followed by that AI message as an assistant turn,
        followed by all later messages joined into one user turn. Otherwise all
        messages are joined into a single user turn. If the last message
        carries an image (list content whose second element has type
        ``"image_url"``), the final user turn becomes a text + image list.

        Returns
        -------
            list[dict]: ``{"role": ..., "content": ...}`` dicts; the last entry
            is the current prompt.

        """
        history = []
        # Extract the text component of each message; list-valued content
        # keeps the text in its first element.
        msg_texts = [
            m.content[0]["text"] if isinstance(m.content, list) else m.content
            for m in self.messages
        ]

        # Check whether the last message is an image message; guard the
        # length so single-element list content does not raise IndexError.
        last_content = self.messages[-1].content
        is_image_message = (
            isinstance(last_content, list)
            and len(last_content) > 1
            and last_content[1]["type"] == "image_url"
        )

        # Index of the last AI message, or None if there is none.
        last_ai_message = None
        for i, m in enumerate(self.messages):
            if isinstance(m, AIMessage):
                last_ai_message = i

        # Compare against None explicitly: an AI message at index 0 is falsy
        # but still an AI message.
        if last_ai_message is not None:
            # Everything before the last AI message becomes one user turn.
            history.append(
                {
                    "role": "user",
                    "content": "\n".join(msg_texts[:last_ai_message]),
                },
            )
            # Then the last AI message itself.
            history.append(
                {
                    "role": "assistant",
                    "content": msg_texts[last_ai_message],
                },
            )
            # Then everything after it as one user turn.
            history.append(
                {
                    "role": "user",
                    "content": "\n".join(msg_texts[last_ai_message + 1 :]),
                },
            )
        else:
            # No AI message: concatenate all messages into one user turn.
            history.append(
                {
                    "role": "user",
                    "content": "\n".join(msg_texts),
                },
            )

        # If the last message is an image message, attach the image to the
        # final turn.
        if is_image_message:
            history[-1]["content"] = [
                {"type": "text", "text": history[-1]["content"]},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": last_content[1]["image_url"]["url"],
                    },
                },
            ]
        return history

    def _correct_response(self, msg: str):
        """Ask the correcting model to review ``msg`` and return its verdict.

        The correcting agent's messages (``self.ca_messages``), the response to
        check, and an instruction to answer 'OK' when nothing needs correcting
        are sent to ``self.ca_model``. Usage stats are updated. This method
        corrects ``msg`` as a single unit; any sentence splitting controlled by
        ``split_correction`` is presumably done by the caller.

        Args:
        ----
            msg (str): The response from the model.

        Returns:
        -------
            str: The corrected response (or OK if no correction necessary).

        """
        ca_messages = self.ca_messages.copy()
        ca_messages.append(
            HumanMessage(
                content=msg,
            ),
        )
        ca_messages.append(
            SystemMessage(
                content="If there is nothing to correct, please respond with just 'OK', and nothing else!",
            ),
        )
        # Build the history from the correcting agent's messages (previously
        # self.messages was used by mistake, so `msg` was never sent).
        history = []
        for m in ca_messages:
            if isinstance(m, SystemMessage):
                history.append({"role": "system", "content": m.content})
            elif isinstance(m, HumanMessage):
                history.append({"role": "user", "content": m.content})
            elif isinstance(m, AIMessage):
                history.append({"role": "assistant", "content": m.content})
        prompt = history.pop()
        response = self.ca_model.chat(
            prompt=prompt["content"],
            chat_history=history,
            generate_config={"max_tokens": 2048, "temperature": 0},
        )

        correction = response["choices"][0]["message"]["content"]
        token_usage = response["usage"]

        self._update_usage_stats(self.ca_model_name, token_usage)

        return correction

    def _update_usage_stats(self, model: str, token_usage: dict):
        """Record token usage statistics; currently a no-op for Xinference.

        The hook is kept so that `_primary_query` and `_correct_response` can
        call it uniformly, but nothing is recorded for Xinference models.

        Args:
        ----
            model (str): The model name.

            token_usage (dict): The token usage statistics.

        """

    def set_api_key(self) -> bool:
        """Resolve and load the primary and correcting models from the client.

        A model name of ``None`` or ``"auto"`` is replaced by the first model
        reported as supporting ``"chat"``. Each name is then mapped to its uid
        and loaded with ``client.get_model``.

        Returns
        -------
            bool: True if both models were loaded, False if ``get_model``
            raised a RuntimeError.

        Note: a model name missing from ``self.models`` (KeyError) or the
        absence of any chat model (IndexError) is not caught and propagates.

        """
        try:
            if self.model_name is None or self.model_name == "auto":
                self.model_name = self.list_models_by_type("chat")[0]
            self.model = self.client.get_model(
                self.models[self.model_name]["id"],
            )

            if self.ca_model_name is None or self.ca_model_name == "auto":
                self.ca_model_name = self.list_models_by_type("chat")[0]
            self.ca_model = self.client.get_model(
                self.models[self.ca_model_name]["id"],
            )
            return True

        except RuntimeError:
            # TODO handle error, log?
            return False

    def list_models_by_type(self, type: str):
        """Return the names of loaded models matching ``type``.

        A model matches if its ``"model_ability"`` list contains the ability;
        models without ``"model_ability"`` match on ``"model_type"`` instead.
        ``"embed"`` and ``"embedding"`` are synonyms: they match the
        ``"embed"`` ability or the ``"embedding"`` model type.

        Args:
        ----
            type (str): The ability / model type to filter by, e.g. "chat".

        Returns:
        -------
            list[str]: Matching model names, in ``self.models`` order.

        """
        if type == "embed" or type == "embedding":
            ability, model_type = "embed", "embedding"
        else:
            ability, model_type = type, type

        names = []
        for name, model in self.models.items():
            if "model_ability" in model:
                if ability in model["model_ability"]:
                    names.append(name)
            elif model["model_type"] == model_type:
                names.append(name)
        return names

__init__(base_url, prompts, model_name='auto', correct=False, split_correction=False)

Connect to an open-source LLM via the Xinference client library and set up a conversation with the user. Also initialise a second conversational agent to provide corrections to the model output, if necessary.


Parameters:
base_url (str): The base URL of the Xinference instance (should not
include the /v1 part).

prompts (dict): A dictionary of prompts to use for the conversation.

model_name (str): The name of the model to use. Will be mapped to
the according uid from the list of available models. Can be set to
"auto" to use the first available model.

correct (bool): Whether to correct the model output.

split_correction (bool): Whether to correct the model output by
splitting the output into sentences and correcting each sentence
individually.
Source code in biochatter/llm_connect.py
def __init__(
    self,
    base_url: str,
    prompts: dict,
    model_name: str = "auto",
    correct: bool = False,
    split_correction: bool = False,
):
    """Open a conversation with a model served by a Xinference instance.

    A Xinference client is created for ``base_url``, the available models
    are loaded, and the primary model and the correcting agent (initially
    the same model) are resolved via `set_api_key`.

    Args:
    ----
        base_url (str): The base URL of the Xinference instance (should not
        include the /v1 part).

        prompts (dict): A dictionary of prompts to use for the conversation.

        model_name (str): The name of the model to use. Will be mapped to
        the according uid from the list of available models. Can be set to
        "auto" to use the first available model.

        correct (bool): Whether to correct the model output.

        split_correction (bool): Whether to correct the model output by
        splitting the output into sentences and correcting each sentence
        individually.

    """
    # Shaohong: Please keep this xinference importing code here, so that,
    # we don't need to depend on xinference if we dont need it (xinference
    # is expensive to install)
    from xinference.client import Client

    conversation_settings = {
        "model_name": model_name,
        "prompts": prompts,
        "correct": correct,
        "split_correction": split_correction,
    }
    super().__init__(**conversation_settings)

    self.client = Client(base_url=base_url)
    self.models = {}
    self.load_models()

    # The correcting agent starts out with the same model as the conversation.
    self.ca_model_name = model_name
    self.set_api_key()

append_ca_message(message)

We also override the system message addition for the correcting agent, likewise because Xinference does not accept multiple system messages. We concatenate them if there are multiple.

TODO this currently assumes that the correcting agent is the same model as the primary one.


Parameters:
message (str): The message to append.
Source code in biochatter/llm_connect.py
def append_ca_message(self, message: str):
    """Add a system message for the correcting agent, merging it into an existing one.

    Xinference does not accept multiple system messages, so instead of
    appending a second ``SystemMessage`` the text is concatenated
    (newline-separated) onto the first system message in ``self.ca_messages``.

    TODO this currently assumes that the correcting agent is the same model
    as the primary one.

    Args:
    ----
        message (str): The message to append.

    """
    # Look for an existing system message in the correcting agent's history.
    existing = next(
        (m for m in self.ca_messages if isinstance(m, SystemMessage)),
        None,
    )
    if existing is None:
        self.ca_messages.append(SystemMessage(content=message))
    else:
        # Extend the first system message found; this method never creates a
        # second one.
        existing.content += f"\n{message}"

append_system_message(message)

We override the system message addition because Xinference does not accept multiple system messages. We concatenate them if there are multiple.


Parameters:
message (str): The message to append.
Source code in biochatter/llm_connect.py
def append_system_message(self, message: str):
    """Add a system message, merging it into an existing one if present.

    Xinference does not accept multiple system messages, so instead of
    appending a second ``SystemMessage`` the text is concatenated
    (newline-separated) onto the first system message in ``self.messages``.

    Args:
    ----
        message (str): The message to append.

    """
    existing = next(
        (m for m in self.messages if isinstance(m, SystemMessage)),
        None,
    )
    if existing is None:
        self.messages.append(SystemMessage(content=message))
    else:
        # Extend the first system message found; this method never creates a
        # second one.
        existing.content += f"\n{message}"

set_api_key()

Try to get the Xinference model from the client API. If the model is found, initialise the conversational agent. If the model is not found, get_model will raise a RuntimeError.

Returns
bool: True if the model is found, False otherwise.
Source code in biochatter/llm_connect.py
def set_api_key(self) -> bool:
    """Resolve and load the primary and correcting models from the client.

    A model name of ``None`` or ``"auto"`` is replaced by the first model
    listed as supporting ``"chat"``; each name is then mapped to its uid and
    loaded via ``client.get_model``, which raises a RuntimeError when the
    model cannot be found.

    Returns
    -------
        bool: True if the model is found, False otherwise.

    """
    try:
        if self.model_name is None or self.model_name == "auto":
            self.model_name = self.list_models_by_type("chat")[0]
        primary_uid = self.models[self.model_name]["id"]
        self.model = self.client.get_model(primary_uid)

        if self.ca_model_name is None or self.ca_model_name == "auto":
            self.ca_model_name = self.list_models_by_type("chat")[0]
        correcting_uid = self.models[self.ca_model_name]["id"]
        self.ca_model = self.client.get_model(correcting_uid)
    except RuntimeError:
        # TODO handle error, log?
        return False
    return True