Reflexion Agent Reference

This module implements a LangGraph-based multi-agent system that performs reflexion on a user-defined task: an LLM drafts an answer, executes a tool on it, and iteratively revises the answer based on the tool results until an evaluation criterion is met.

The base module

ReflexionAgent

Bases: ABC

LLM agent reflexion framework:

start -> draft -> execute tool -> revise -> evaluation -> end
                    /|\                        |
                     ---------------------------

Source code in biochatter/langgraph_agent_base.py
class ReflexionAgent(ABC):
    """
    LLM agent reflexion framework:

    start -> draft -> execute tool -> revise -> evaluation -> end
                        /|\                        |
                         ---------------------------
    """

    RECURSION_LIMIT = 30

    def __init__(
        self,
        conversation_factory: Callable,
        max_steps: Optional[int] = 20,
        agent_logger: Optional[ReflexionAgentLogger] = ReflexionAgentLogger(),
    ):
        """
        Args:
          conversation_factory Callable: the callback to create Conversation
          max_steps int: max steps for the reflexion loop
        """
        if max_steps <= 0:
            max_steps = ReflexionAgent.RECURSION_LIMIT
        recursion_limit = ReflexionAgent.RECURSION_LIMIT
        if recursion_limit < max_steps:
            recursion_limit = max_steps
        self.initial_responder = None
        self.revise_responder = None
        self.max_steps = max_steps
        self.recursion_limit = recursion_limit
        self.conversation = conversation_factory()
        self.agent_logger = agent_logger

    def _should_continue(self, state: list[BaseMessage]):
        """
        Determine if we need to continue reflexion
        Args:
          state List[BaseMessage]: message history
        """
        num_steps = ReflexionAgent._get_num_iterations(state)
        if num_steps > self.max_steps:
            return END
        return EXECUTE_TOOL_NODE

    @abstractmethod
    def _tool_function(self, state: list[BaseMessage]) -> ToolMessage:
        """
        tool function, execute tool based on initial draft or revised answer
        Args:
          state List[BaseMessage]: message history
        Returns:
          ToolMessage
        """
        pass

    @abstractmethod
    def _create_initial_responder(
        self, prompt: Optional[str] = None
    ) -> ResponderWithRetries:
        """
        draft responder, draft initial answer
        Args:
          prompt str: prompt for LLM to draft initial answer
        """
        pass

    @abstractmethod
    def _create_revise_responder(
        self, prompt: Optional[str] = None
    ) -> ResponderWithRetries:
        """
        revise responder, revise answer according to tool function result
        Args:
          prompt str: prompt for LLM to draft initial answer
        """
        pass

    @abstractmethod
    def _parse_final_result(
        self, messages: list[BaseMessage]
    ) -> ReflexionAgentResult:
        """
        parse the result of the last step
        Args:
          output BaseMessage: last step message
        Returns:
          ReflexionAgentResult: the parsed reuslt of the last step
        """
        pass

    def get_logs(self):
        return self.agent_logger.logs

    @staticmethod
    def _get_num_iterations(state: list[BaseMessage]):
        """
        Calculate iteration number
        Args:
          state List[BaseMessage]: message history

        Returns:
          int: the iterations number
        """
        i = 0
        for m in state[::-1]:
            if not isinstance(m, (ToolMessage, AIMessage)):
                break
            i += 1
        return i

    @staticmethod
    def _get_user_question(state: list[BaseMessage]):
        """
        get user's question from messages array
        """
        for m in state:
            if not isinstance(m, HumanMessage):
                continue
            return m.content
        return None

    @staticmethod
    def _get_last_tool_result(messages: list[BaseMessage]):
        """
        get result of the last tool node
        """
        for m in messages[::-1]:
            if not isinstance(m, ToolMessage):
                continue
            content = json.loads(m.content)
            return content["result"]
        return None

    def _build_graph(self, prompt: Optional[str] = None):
        """
        Build Langgraph graph for execution of chained LLM processes.

        Args:
          prompt str: prompt for LLM

        Returns:
          CompiledGraph | None: a Langgraph graph or None in case of errors
        """
        try:
            self.initial_responder = self._create_initial_responder(prompt)
            self.revise_responder = self._create_revise_responder(prompt)
            builder = MessageGraph()
            builder.add_node(DRAFT_NODE, self.initial_responder.respond)
            builder.add_node(EXECUTE_TOOL_NODE, self._tool_function)
            builder.add_node(REVISE_NODE, self.revise_responder.respond)
            builder.add_edge(DRAFT_NODE, EXECUTE_TOOL_NODE)
            builder.add_edge(EXECUTE_TOOL_NODE, REVISE_NODE)

            builder.add_conditional_edges(REVISE_NODE, self._should_continue)
            builder.set_entry_point(DRAFT_NODE)
            graph = builder.compile()
            return graph
        except Exception as e:
            logger.error(e)
            return None

    def _execute_graph(
        self,
        graph: Optional[CompiledGraph] = None,
        question: Optional[str] = "",
    ) -> ReflexionAgentResult:
        """
        execute Langgraph graph
        Args:
          graph CompiledGraph: Langgraph graph
          question str: user question

        Returns:
          answer str | None: string answer parsed from Langgraph graph execution
        """
        if graph is None:
            return None
        if len(question) == 0:
            return None

        events = graph.stream(
            [HumanMessage(content=question)],
            {
                "recursion_limit": self.recursion_limit,
            },
        )
        messages = [HumanMessage(content=question)]
        for i, step in enumerate(events):
            if isinstance(step, list):
                node, output = (f"{i}", step[i])
            else:
                node, output = next(iter(step.items()))
            self.agent_logger.log_step_message(i + 1, node, output)
            messages.append(output)

        final_result = self._parse_final_result(messages)
        self.agent_logger.log_final_result(final_result)
        return final_result

    def execute(
        self, question: str, prompt: Optional[str] = None
    ) -> ReflexionAgentResult:
        """
        Execute ReflexionAgent. Wrapper for building a graph and executing it,
        returning the final answer.

        Args:
          question str: user question
          prompt str: user prompt

        Returns:
          ReflexionAgentResult | None: the parsed result (including the
            answer) if execution succeeds; otherwise None
        """
        if len(question) == 0:
            return None
        graph = self._build_graph(prompt)
        return self._execute_graph(graph, question)
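
To implement a concrete agent, subclass ReflexionAgent and supply the tool function, the two responders, and the result parser. A minimal sketch follows, assuming ReflexionAgent, ReflexionAgentResult, and ResponderWithRetries are importable from biochatter.langgraph_agent_base; the EchoToolAgent class and its trivial tool logic are illustrative only, not part of biochatter.

import json

from langchain_core.messages import BaseMessage, ToolMessage

from biochatter.langgraph_agent_base import (
    ReflexionAgent,
    ReflexionAgentResult,
    ResponderWithRetries,
)


class EchoToolAgent(ReflexionAgent):
    """Toy agent whose 'tool' simply echoes the current draft."""

    def _tool_function(self, state: list[BaseMessage]) -> ToolMessage:
        # Wrap the latest draft as a tool result; a real agent would run
        # an external tool here (e.g. a database query).
        return ToolMessage(
            content=json.dumps({"result": state[-1].content}),
            tool_call_id="echo",
        )

    def _create_initial_responder(self, prompt=None) -> ResponderWithRetries:
        ...  # build a prompt | llm chain and wrap it in ResponderWithRetries

    def _create_revise_responder(self, prompt=None) -> ResponderWithRetries:
        ...  # same chain, with a revision instruction

    def _parse_final_result(self, messages) -> ReflexionAgentResult:
        return ReflexionAgentResult(
            answer=messages[-1].content,
            tool_result=ReflexionAgent._get_last_tool_result(messages),
        )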

__init__(conversation_factory, max_steps=20, agent_logger=ReflexionAgentLogger())

Parameters:

  conversation_factory (Callable): the callback to create the Conversation. Required.
  max_steps (int): max steps for the reflexion loop. Default: 20.
  agent_logger (ReflexionAgentLogger): logger used to record step messages and the final result. Default: ReflexionAgentLogger().
Source code in biochatter/langgraph_agent_base.py
def __init__(
    self,
    conversation_factory: Callable,
    max_steps: Optional[int] = 20,
    agent_logger: Optional[ReflexionAgentLogger] = ReflexionAgentLogger(),
):
    """
    Args:
      conversation_factory Callable: the callback to create Conversation
      max_steps int: max steps for the reflexion loop
    """
    if max_steps <= 0:
        max_steps = ReflexionAgent.RECURSION_LIMIT
    recursion_limit = ReflexionAgent.RECURSION_LIMIT
    if recursion_limit < max_steps:
        recursion_limit = max_steps
    self.initial_responder = None
    self.revise_responder = None
    self.max_steps = max_steps
    self.recursion_limit = recursion_limit
    self.conversation = conversation_factory()
    self.agent_logger = agent_logger

execute(question, prompt=None)

Execute ReflexionAgent. Wrapper for building a graph and executing it, returning the final answer.

Parameters:

  question (str): user question. Required.
  prompt (str): user prompt. Default: None.

Returns:

  ReflexionAgentResult | None: the parsed result (including the answer to the question) if execution succeeds; otherwise None.

Source code in biochatter/langgraph_agent_base.py
def execute(
    self, question: str, prompt: Optional[str] = None
) -> ReflexionAgentResult:
    """
    Execute ReflexionAgent. Wrapper for building a graph and executing it,
    returning the final answer.

    Args:
      question str: user question
      prompt str: user prompt

    Returns:
      ReflexionAgentResult | None: the parsed result (including the
        answer) if execution succeeds; otherwise None
    """
    if len(question) == 0:
        return None
    graph = self._build_graph(prompt)
    return self._execute_graph(graph, question)
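
A usage sketch, reusing the illustrative EchoToolAgent subclass from above; make_conversation is a hypothetical factory returning a biochatter Conversation:

agent = EchoToolAgent(conversation_factory=make_conversation, max_steps=10)
result = agent.execute(question="What is the capital of France?")
if result is not None:
    print(result.answer)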

ReflexionAgentLogger

Source code in biochatter/langgraph_agent_base.py
class ReflexionAgentLogger:
    def __init__(self) -> None:
        self._logs: str = ""

    def log_step_message(
        self,
        step: int,
        node_name: str,
        output: BaseMessage,
    ):
        """
        log step message
        Args:
          step int: step index
          output BaseMessage: step message
        """
        pass

    def log_final_result(self, final_result: dict[str, Any]) -> None:
        """
        log final result
        Args:
          output BaseMessage: last step message
        """
        pass

    def _log_message(
        self,
        msg: str = "",
        level: Optional[Literal["info", "error", "warn"]] = "info",
    ):
        """
        Save log message

        Args:
            msg: the message to be logged

            level: the log level to write
        """
        logger_func = (
            logger.info
            if level == "info"
            else (logger.error if level == "error" else logger.warning)
        )
        logger_func(msg)
        self._logs = (
            self._logs
            + f"[{level}]"
            + f"{datetime.now().isoformat()} - {msg}\n"
        )

    @property
    def logs(self):
        return self._logs

log_final_result(final_result)

Log the final result.

Parameters:

  final_result: the final result to log.

Source code in biochatter/langgraph_agent_base.py
def log_final_result(self, final_result: dict[str, Any]) -> None:
    """
    log final result
    Args:
      output BaseMessage: last step message
    """
    pass

log_step_message(step, node_name, output)

Log a step message.

Parameters:

  step (int): step index.
  node_name (str): name of the graph node that produced the message.
  output (BaseMessage): step message.

Source code in biochatter/langgraph_agent_base.py
def log_step_message(
    self,
    step: int,
    node_name: str,
    output: BaseMessage,
):
    """
    log step message
    Args:
      step int: step index
      output BaseMessage: step message
    """
    pass
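
The base logger's hooks are no-ops; subclasses override them, typically via the _log_message helper. A minimal sketch (the VerboseLogger name is illustrative):

class VerboseLogger(ReflexionAgentLogger):
    def log_step_message(self, step: int, node_name: str, output) -> None:
        # Record which graph node produced each step's message.
        self._log_message(f"step {step} ({node_name}): {output.content}")

    def log_final_result(self, final_result) -> None:
        self._log_message(f"final result: {final_result}")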

ResponderWithRetries

Send a request to the LLM, retrying up to three times on validation errors.

Source code in biochatter/langgraph_agent_base.py
class ResponderWithRetries:
    """
    Raise request to LLM with 3 retries
    """

    def __init__(self, runnable, validator):
        """
        Args:
        runnable: LLM agent
        validator: used to validate response
        """
        self.runnable = runnable
        self.validator = validator

    @traceable
    def respond(self, state: list[BaseMessage]):
        """
        Invoke LLM agent, this function will be called by LangGraph
        Args:
        state List[BaseMessage]: message history
        """
        response = []
        for attempt in range(3):
            try:
                response = self.runnable.invoke({"messages": state})
                self.validator.invoke(response)
                return response
            except ValidationError as e:
                state = state + [HumanMessage(content=repr(e))]
        return response

__init__(runnable, validator)

Parameters:

  runnable: the LLM agent chain to invoke.
  validator: used to validate the response.

Source code in biochatter/langgraph_agent_base.py
def __init__(self, runnable, validator):
    """
    Args:
    runnable: LLM agent
    validator: used to validate response
    """
    self.runnable = runnable
    self.validator = validator

respond(state)

Invoke the LLM agent; this function is called by LangGraph.

Parameters:

  state (List[BaseMessage]): message history.

Source code in biochatter/langgraph_agent_base.py
@traceable
def respond(self, state: list[BaseMessage]):
    """
    Invoke LLM agent, this function will be called by LangGraph
    Args:
    state List[BaseMessage]: message history
    """
    response = []
    for attempt in range(3):
        try:
            response = self.runnable.invoke({"messages": state})
            self.validator.invoke(response)
            return response
        except ValidationError as e:
            state = state + [HumanMessage(content=repr(e))]
    return response
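
To assemble a responder, pair a prompt/LLM chain with a Pydantic tool parser, as the responder factories of KGQueryReflexionAgent below do. A sketch, assuming llm is a chat model supporting bind_tools (e.g. ChatOpenAI), GenerateQuery is the schema from the next section, and the import paths match recent LangChain versions:

from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "Answer the user's question."),  # placeholder instruction
        MessagesPlaceholder(variable_name="messages"),
    ]
)
chain = prompt | llm.bind_tools(tools=[GenerateQuery], tool_choice="GenerateQuery")
responder = ResponderWithRetries(
    runnable=chain,
    validator=PydanticToolsParser(tools=[GenerateQuery]),
)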

The KG-based reflexion agent

GenerateQuery

Bases: BaseModel

Generate the query.

Source code in biochatter/kg_langgraph_agent.py
class GenerateQuery(BaseModel):
    """Generate the query."""

    answer: str = Field(
        description="Cypher query for graph database according to user's question."
    )
    reflection: str = Field(
        description="Your reflection on the initial answer, critique of what to improve"
    )
    search_queries: list[str] = Field(description=SEARCH_QUERIES_DESCRIPTION)
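
When bound as a tool (see _create_initial_responder below), the LLM's tool-call arguments are parsed into this schema. An illustrative parsed payload, with all values invented for demonstration:

# Illustrative arguments of a parsed GenerateQuery tool call.
parsed_args = {
    "answer": "MATCH (g:Gene)-[:GENE_TO_DISEASE]->(d:Disease) RETURN g.name LIMIT 30",
    "reflection": "The relationship type is assumed; verify it against the schema.",
    "search_queries": ["MATCH (d:Disease) RETURN d.name LIMIT 30"],
}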

KGQueryReflexionAgent

Bases: ReflexionAgent

Source code in biochatter/kg_langgraph_agent.py
class KGQueryReflexionAgent(ReflexionAgent):
    def __init__(
        self,
        conversation_factory: Callable,
        connection_args: dict[str, str],
        query_lang: Optional[str] = "Cypher",
        max_steps: Optional[int] = 20,
    ):
        """
        LLM agent reflexion framework:

        start -> draft -> execute tool -> revise -> evaluation -> end
                            /|\                        |
                             ---------------------------

        Adapts base class to build and refine a knowledge graph query, default
        language Cypher. Currently hardcoded to connect to Neo4j for the KG
        query implementation.

        Args:
            conversation_factory: function to return the Conversation to use for
                the LLM connection

            connection_args: connection arguments for connecting to the database

            query_lang: graph query language to use

            max_steps: the maximum number of steps to execute in the graph

        """
        super().__init__(
            conversation_factory,
            max_steps,
            agent_logger=KGQueryReflexionAgentLogger(),
        )
        self.actor_prompt_template = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    (
                        "As a senior biomedical researcher and graph database expert, "
                        f"your task is to generate '{query_lang}' queries to extract data from our graph database based on the user's question. "
                        """Current time {time}. {instruction}"""
                    ),
                ),
                MessagesPlaceholder(variable_name="messages"),
                (
                    "system",
                    (
                        "Note: 1. Only generate query according to the user's question above.\n"
                        "2. Please limit the results to a maximum of 30 items"
                    ),
                ),
            ]
        ).partial(time=lambda: datetime.now().isoformat())
        self.parser = JsonOutputToolsParser(return_id=True)
        self.connection_args = connection_args
        self.neodriver = None

    def _connect_db(self):
        if self.neodriver is not None:
            return
        try:
            db_uri = (
                "bolt://"
                + self.connection_args.get("host")
                + ":"
                + self.connection_args.get("port")
            )
            self.neodriver = nu.Driver(
                db_name=self.connection_args.get("db_name") or "neo4j",
                db_uri=db_uri,
            )
        except Exception as e:
            logger.error(e)

    def _query_graph_database(self, query: str):
        """
        Try to execute the query in Neo4j and return the result.

        Args:
            query: the query string

        """
        self._connect_db()
        try:
            return self.neodriver.query(query)
        except Exception as e:
            logger.error(str(e))
            return []  # empty result

    def _create_initial_responder(
        self, prompt: Optional[str] = None
    ) -> ResponderWithRetries:
        llm: ChatOpenAI = self.conversation.chat
        initial_chain = self.actor_prompt_template.partial(
            instruction=prompt if prompt is not None else ""
        ) | llm.bind_tools(
            tools=[GenerateQuery],
            tool_choice="GenerateQuery",
        )
        validator = PydanticToolsParser(tools=[GenerateQuery])
        return ResponderWithRetries(runnable=initial_chain, validator=validator)

    def _create_revise_responder(
        self, prompt: str | None = None
    ) -> ResponderWithRetries:
        revision_instruction = """
        Revise your previous query using the query result and follow the guidelines:
        1. If you consistently obtain empty results, please consider removing constraints such as relationship constraints to try to obtain a result.
        2. You should use previous critique to improve your query.
        3. Only generate a query without returning any other text.
        """
        llm: ChatOpenAI = self.conversation.chat
        revision_chain = self.actor_prompt_template.partial(
            instruction=revision_instruction
        ) | llm.bind_tools(
            tools=[ReviseQuery],
            tool_choice="ReviseQuery",
        )
        validator = PydanticToolsParser(tools=[ReviseQuery])
        return ResponderWithRetries(
            runnable=revision_chain, validator=validator
        )

    def _tool_function(self, state: list[BaseMessage]):
        tool_message: AIMessage = state[-1]
        parsed_tool_messages = self.parser.invoke(tool_message)
        results = []
        for parsed_message in parsed_tool_messages:
            try:
                parsed_args = parsed_message["args"]
                query = parsed_args.get(
                    REVISED_QUERY,
                    parsed_args.get(REVISED_QUERY_DESCRIPTION),
                )
                if query is not None:
                    result = self._query_graph_database(query)
                    # guard against empty query results (avoids an IndexError)
                    results.append(
                        {
                            "query": query,
                            "result": result[0] if len(result) > 0 else [],
                        }
                    )
                    continue
                queries = (
                    parsed_args[SEARCH_QUERIES]
                    if SEARCH_QUERIES in parsed_args
                    else parsed_args[SEARCH_QUERIES_DESCRIPTION]
                )
                queries = queries if len(queries) > 0 else [parsed_args[ANSWER]]
                for query in queries:
                    result = self._query_graph_database(query)
                    results.append(
                        {
                            "query": query,
                            "result": result[0] if len(result) > 0 else [],
                        }
                    )
            except Exception as e:
                logger.error(f"Error occurred: {str(e)}")

        content = None
        if len(results) > 1:
            # If there are multiple results, we only return
            # the first non-empty result
            for res in results:
                if res["result"] and len(res["result"]) > 0:
                    content = json.dumps(res)
                    break
        if content is None:
            content = json.dumps(results[0]) if len(results) > 0 else ""
        return ToolMessage(
            content=content,
            tool_call_id=parsed_message["id"],
        )

    @staticmethod
    def _get_last_tool_results_num(state: list[BaseMessage]):
        i = 0
        for m in state[::-1]:
            if not isinstance(m, ToolMessage):
                continue
            message: ToolMessage = m
            logger.info(f"query result: {message.content}")
            results = (
                json.loads(message.content)
                if message.content is not None and len(message.content) > 0
                else {"result": []}
            )
            empty = True
            if len(results["result"]) > 0:
                # check if it is really not empty, remove the case: {"result": [{"c.name": None}]}
                for res in results["result"]:
                    for k in res.keys():
                        if res[k] is None:
                            continue
                        if isinstance(res[k], str) and (
                            res[k] == "None" or res[k] == "null"
                        ):
                            continue
                        empty = False
                        break
                    if not empty:
                        break
            return len(results["result"]) if not empty else 0

        return 0

    def _get_last_score(self, state: list[BaseMessage]) -> int | None:
        for m in state[::-1]:
            if not isinstance(m, AIMessage):
                continue
            message: AIMessage = m
            parsed_msg = self.parser.invoke(message)
            try:
                score = parsed_msg[0]["args"]["score"]
                return int(score)
            except Exception:
                return None
        return None

    def _should_continue(self, state: list[BaseMessage]):
        res = super()._should_continue(state)
        if res == END:
            return res
        score = self._get_last_score(state)
        if score is not None and score >= 7:
            return END
        query_results_num = KGQueryReflexionAgent._get_last_tool_results_num(
            state
        )
        return END if query_results_num > 0 else EXECUTE_TOOL_NODE

    def _parse_final_result(
        self, messages: list[BaseMessage]
    ) -> ReflexionAgentResult:
        output = messages[-1]
        result = self.parser.invoke(output)[0]["args"]
        tool_result = ReflexionAgent._get_last_tool_result(messages)
        return ReflexionAgentResult(
            answer=result["answer"] if "answer" in result else None,
            tool_result=tool_result,
        )

__init__(conversation_factory, connection_args, query_lang='Cypher', max_steps=20)

LLM agent reflexion framework:

start -> draft -> execute tool -> revise -> evaluation -> end
                    /|\                        |
                     ---------------------------

Adapts base class to build and refine a knowledge graph query, default language Cypher. Currently hardcoded to connect to Neo4j for the KG query implementation.

Parameters:

  conversation_factory (Callable): function to return the Conversation to use for the LLM connection. Required.
  connection_args (dict[str, str]): connection arguments for connecting to the database. Required.
  query_lang (Optional[str]): graph query language to use. Default: 'Cypher'.
  max_steps (Optional[int]): the maximum number of steps to execute in the graph. Default: 20.
Source code in biochatter/kg_langgraph_agent.py
def __init__(
    self,
    conversation_factory: Callable,
    connection_args: dict[str, str],
    query_lang: Optional[str] = "Cypher",
    max_steps: Optional[int] = 20,
):
    """
    LLM agent reflexion framework:

    start -> draft -> execute tool -> revise -> evaluation -> end
                        /|\                        |
                         ---------------------------

    Adapts base class to build and refine a knowledge graph query, default
    language Cypher. Currently hardcoded to connect to Neo4j for the KG
    query implementation.

    Args:
        conversation_factory: function to return the Conversation to use for
            the LLM connection

        connection_args: connection arguments for connecting to the database

        query_lang: graph query language to use

        max_steps: the maximum number of steps to execute in the graph

    """
    super().__init__(
        conversation_factory,
        max_steps,
        agent_logger=KGQueryReflexionAgentLogger(),
    )
    self.actor_prompt_template = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                (
                    "As a senior biomedical researcher and graph database expert, "
                    f"your task is to generate '{query_lang}' queries to extract data from our graph database based on the user's question. "
                    """Current time {time}. {instruction}"""
                ),
            ),
            MessagesPlaceholder(variable_name="messages"),
            (
                "system",
                (
                    "Note: 1. Only generate query according to the user's question above.\n"
                    "2. Please limit the results to a maximum of 30 items"
                ),
            ),
        ]
    ).partial(time=lambda: datetime.now().isoformat())
    self.parser = JsonOutputToolsParser(return_id=True)
    self.connection_args = connection_args
    self.neodriver = None
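
A usage sketch; make_conversation is a hypothetical factory returning a biochatter Conversation, and the connection details and question are placeholders:

agent = KGQueryReflexionAgent(
    conversation_factory=make_conversation,
    connection_args={"host": "localhost", "port": "7687", "db_name": "neo4j"},
)
result = agent.execute(
    question="Which genes are associated with cystic fibrosis?",
    prompt="The graph follows a BioCypher schema with Gene and Disease nodes.",
)
if result is not None:
    print(result.answer)       # the generated query
    print(result.tool_result)  # the result of running the query against Neo4j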

ReviseQuery

Bases: GenerateQuery

Revise your previous query according to your question.

Source code in biochatter/kg_langgraph_agent.py
class ReviseQuery(GenerateQuery):
    """Revise your previous query according to your question."""

    revised_query: str = Field(description=REVISED_QUERY_DESCRIPTION)
    score: str = Field(description=SCORE_DESCRIPTION)
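
For orientation, a parsed ReviseQuery tool call extends the GenerateQuery fields with the revised query and a self-assigned score (all values illustrative). The score feeds _should_continue, which ends the reflexion loop once it reaches 7 or the last tool execution returned non-empty results.

# Illustrative arguments of a parsed ReviseQuery tool call.
parsed_args = {
    "answer": "MATCH (g:Gene) RETURN g.name LIMIT 30",
    "reflection": "Removed the relationship constraint that produced empty results.",
    "search_queries": ["MATCH (g:Gene) RETURN g.name LIMIT 30"],
    "revised_query": "MATCH (g:Gene) RETURN g.name LIMIT 30",
    "score": "8",  # a score of 7 or higher ends the loop
}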