Monte Carlo Tree Self-Refine

Unlocking Advanced Mathematical Reasoning with Monte Carlo Tree Self-Refine

Inference
Author

Matthias De Paolis

Published

September 30, 2024


The rapid advancement of artificial intelligence has ushered in a new era of large language models (LLMs) like GPT-4 and LLaMA. These models have demonstrated remarkable abilities in natural language understanding and generation, and have even exhibited emergent properties such as reasoning and in-context learning. They have been deployed across various domains, from generating coherent narratives to assisting in programming tasks. However, when it comes to complex mathematical reasoning, especially at the level of mathematical Olympiads, these models often fall short.

Despite their impressive linguistic capabilities, LLMs struggle with the precision and logical rigor required for high-level mathematics. They are prone to generating “hallucinations” - plausible-sounding but incorrect or irrelevant outputs - which can be particularly problematic in mathematical contexts where accuracy is paramount. Traditional methods to mitigate these issues, such as self-refinement techniques, provide some relief but do not fully bridge the gap.

Enter the Monte Carlo Tree Self-Refine (MCTSr) algorithm - a novel integration of LLMs with Monte Carlo Tree Search (MCTS). This innovative approach systematically explores the solution space and employs heuristic self-refinement mechanisms to enhance decision-making within LLMs. By constructing a Monte Carlo search tree through iterative processes of selection, self-refinement, self-evaluation, and backpropagation, MCTSr optimizes the balance between exploration and exploitation. The result is a significant improvement in the LLM’s ability to solve complex mathematical problems, reaching success rates comparable to GPT-4 on Olympiad-level benchmarks.

In this article, we delve into the theoretical underpinnings of the MCTSr technique, explore its implementation through code examples, and demonstrate its effectiveness in enhancing mathematical reasoning in LLMs. Whether you’re an AI researcher, a data scientist, or simply curious about the intersection of machine learning and mathematics, this exploration offers valuable insights into pushing the boundaries of what LLMs can achieve.

Theoretical Foundations of Monte Carlo Tree Self-Refine

To understand the Monte Carlo Tree Self-Refine (MCTSr) technique, it’s essential to grasp the fundamentals of both Monte Carlo Tree Search (MCTS) and the limitations of LLMs in mathematical reasoning. MCTSr marries these two domains to overcome the challenges inherent in each.

Monte Carlo Tree Search (MCTS) Primer

MCTS is a heuristic search algorithm used extensively in decision-making processes, particularly in game playing (like Go and chess) and complex problem-solving scenarios. The algorithm incrementally builds a search tree, guided by the outcomes of random simulations (also known as rollouts), to evaluate the potential of different actions.

MCTS operates through four primary phases:

  1. Selection: Starting from the root node, the algorithm selects child nodes based on a policy (e.g., the Upper Confidence Bound for Trees, or UCT) until it reaches a leaf node.
  2. Expansion: If the leaf node is not a terminal state, the algorithm adds one or more child nodes, representing possible future states.
  3. Simulation (Rollout): From the new node, the algorithm performs a simulation to the end of the game or task, using a default policy to make decisions.
  4. Backpropagation: The results of the simulation are propagated back up the tree, updating the statistics of the nodes involved (e.g., win/loss records).

The key to MCTS’s success is balancing exploration (trying new, untested actions) and exploitation (choosing actions known to yield good results). This balance is often achieved using the Upper Confidence Bound for Trees algorithm (UCT), which selects child nodes that maximize the following equation:

UCT_i = (w_i / n_i) + c · √( ln(N) / n_i )

where w_i is the accumulated reward of child i, n_i is the number of times child i has been visited, N is the visit count of the parent node, and c is the exploration constant (commonly √2).
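To make the formula concrete, here is a minimal Python sketch of the classic UCT score (the function name and arguments are illustrative, not from the article):

import math

def uct_score(total_reward, visits, parent_visits, c=math.sqrt(2)):
    # UCT = average reward (exploitation) + exploration bonus
    if visits == 0:
        return float('inf')  # prioritize unvisited children
    exploitation = total_reward / visits
    exploration = c * math.sqrt(math.log(parent_visits) / visits)
    return exploitation + exploration

Unvisited children receive an infinite score so that every action is tried at least once before the average-reward term starts to dominate.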

Challenges in Applying MCTS to LLMs

While MCTS is powerful, applying it directly to LLMs isn't straightforward. LLMs generate free-form text, so the space of possible outputs is vast and effectively unbounded, whereas traditional MCTS assumes a discrete, finite action space. Moreover, LLMs can produce inconsistent outputs due to their generative nature, complicating the evaluation of states and actions.

Introducing Monte Carlo Tree Self-Refine (MCTSr)

MCTSr addresses these challenges by integrating self-refinement and self-evaluation mechanisms into the MCTS framework. It adapts the traditional MCTS algorithm to operate effectively within the context of LLMs tackling mathematical problems.

Key Components and Notations

Before diving into the algorithm, let’s define the key components and notations:

  • Problem Instance (P): The mathematical problem to be solved.
  • Answer Nodes (A): Each node represents a potential solution to P.
  • Actions (M): Possible self-refinement modifications to an answer.
  • Reward Function (R): Samples self-rewards based on the quality of the modifications.
  • Reward Set (Ra): Stores all reward samples for node a.
  • Quality Function (Q(a)): Estimates the value of an answer node a, derived from accumulated rewards.
  • Upper Confidence Bound (U(a)): Balances exploration and exploitation in node selection.
  • Parent Node (Father(a)): The node from which a was derived.
  • Child Nodes (Children(a)): Nodes derived from a through actions m ∈ M.
  • Visit Count (N(a)): The number of times node a has been visited.

Algorithm Workflow

  1. Initialization: Start with a root node representing an initial answer to P. This could be a naive solution or even a placeholder response like “I don’t know.”

  2. Selection:

  • Use the quality function Q to evaluate all expandable nodes.
  • Select the node with the highest U(a) value for further exploration.

  3. Self-Refinement:

  • Apply a self-refinement action m to the selected node a.
  • The LLM generates feedback or criticism about a and produces an improved answer a’ guided by this feedback.

  4. Self-Evaluation:

  • The LLM evaluates a’ to assign a reward, sampling multiple times to reduce variance.
  • Apply constraints to ensure meaningful rewards:
    • Prompt Constraint: Instruct the model to adhere to strict evaluation standards.
    • Full Score Suppression: Discourage perfect scores to maintain discernment.
    • Repeated Sampling: Collect multiple reward samples to improve reliability.

  5. Backpropagation:

  • Update the Q values of a and its ancestors based on the rewards obtained.
  • The updated Q value of a node a considers both its own reward and the maximum Q value among its children.

  6. UCT Update:

  • Recalculate U(a) for all candidate nodes.

  7. Termination:

  • The process repeats until a termination condition is met (e.g., maximum iterations, satisfactory solution quality).
  • Upon termination, select the best answer based on the highest Q(a) value.

Detailed Explanation of Components

Self-Refinement

Self-refinement leverages the LLM’s capability to critique and improve upon its own outputs. The model generates feedback m on the current answer a and then refines a into a’ based on this feedback. This process simulates a human iteratively improving a solution by self-critique.

Self-Evaluation

In self-evaluation, the LLM assesses the quality of the refined answer a’ and assigns a reward. Since LLMs tend to produce overly generous evaluations, constraints are necessary:

  • Prompt Constraint: The model is instructed to be strict and critical in its evaluation.
  • Full Score Suppression: Perfect scores are penalized to encourage meaningful differentiation.
  • Repeated Sampling: Multiple evaluations are averaged to reduce bias.

The quality function Q(a) is computed as:

Q(a) = ½ · ( min(Ra) + mean(Ra) ),  where mean(Ra) = (1/|Ra|) Σ r over r ∈ Ra

This formula balances the worst-case scenario (min Ra) with the average reward, providing a more robust estimate of the answer’s quality.
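As a quick sanity check, suppose an answer node has reward samples Ra = [0.60, 0.75, 0.90] (rewards are normalized to [0, 1], as in the implementation later in this article):

# Worked example of Q(a) = (min(Ra) + mean(Ra)) / 2
Ra = [0.60, 0.75, 0.90]
Q = 0.5 * (min(Ra) + sum(Ra) / len(Ra))
print(Q)  # 0.5 * (0.60 + 0.75) = 0.675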

Backpropagation

After updating the Q(a) of a, the algorithm backpropagates this information up the tree. The Q value of each parent node is updated based on its own Q(a) and the maximum Q value among its children:

Q′(a) = ½ · ( Q(a) + max Q(i) ),  i ∈ Children(a)

This ensures that the parent nodes are aware of the best potential outcomes from their descendants.
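As a sketch, the update can be written as a short walk from a refined node back to the root. Here propagate_q is a hypothetical helper name, and nodes are assumed to expose Q, parent, and children attributes, as in the TreeNode class defined later in this article:

def propagate_q(node):
    # Apply Q'(a) = (Q(a) + max Q over Children(a)) / 2 at each ancestor
    current = node.parent
    while current is not None:
        best_child_q = max(child.Q for child in current.children)
        current.Q = 0.5 * (current.Q + best_child_q)
        current = current.parent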

UCT Update and Selection

The updated UCT value for each candidate node guides the selection of the next node to explore. In MCTSr this takes the form U(a) = Q(a) + c · √( ln(N(Father(a)) + 1) / (N(a) + ε) ), where c is the exploration constant and ε is a small constant that avoids division by zero for unvisited nodes. By balancing the estimated quality of the node (Q(a)) against the need to explore less-visited nodes, the algorithm efficiently navigates the search space.

Termination Conditions

The algorithm can terminate based on various criteria:

  • Early Stopping: If improvements become negligible or solutions converge.
  • Resource Constraints: Maximum number of iterations or time limits.
  • Solution Quality: If a solution meets a predefined quality threshold.
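Combined, these criteria reduce to a simple check run once per iteration. The threshold values below are illustrative assumptions, not taken from the paper:

def should_terminate(root, iteration, max_iterations=10, quality_threshold=0.90):
    # Resource constraint: cap the number of rollouts
    if iteration >= max_iterations:
        return True
    # Solution quality: stop once any answer is good enough
    best_q = max((child.Q for child in root.children), default=0.0)
    return best_q >= quality_threshold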

Advantages of MCTSr

  • Systematic Exploration: By structuring the search as a tree, the algorithm systematically explores possible refinements.
  • Balanced Decision-Making: The UCT formula ensures a balance between exploiting known good solutions and exploring new possibilities.
  • Enhanced Accuracy: Self-evaluation and refinement lead to higher-quality solutions, reducing errors common in LLM outputs.
  • Scalability: The framework can adapt to various problem complexities by adjusting parameters like the exploration constant c and termination conditions.

Practical Implementation of MCTSr

To bring the theoretical concepts of MCTSr to life, let’s delve into a practical implementation using Python code. This implementation demonstrates how the algorithm can be applied to improve the performance of an LLM in solving mathematical problems.

Setting Up Seed Answers

# Seed answers to initiate the MCTS
seed_answers = [
    "I don't know the answer",
    "I'm not sure",
    "I can't say",
]

Critiquing an Answer

The critique_answer function prompts the LLM to analyze a given answer and provide a detailed critique. This critique will guide the refinement process.

# Get Critique
def critique_answer(question, answer):
    prompt = (
        f"Question: {question}\n"
        f"Answer Attempt: {answer}\n"
        "Please analyze the answer above. "
        "Identify any inaccuracies or areas lacking detail. "
        "Provide a thorough critique, highlighting specific flaws and suggesting improvements. "
        "Your critique should be detailed and step-by-step. "
        "Do not provide a revised answer."
    )
    # Request critique from the language model
    return chat_completion_request(prompt)

Explanation: This function constructs a prompt that includes the question and the current answer. It instructs the LLM to provide a detailed critique without offering a revised answer.
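Note that chat_completion_request is left abstract throughout this article (see the note at the end). As one possibility, a minimal implementation against the OpenAI chat completions API might look like the sketch below; the model name is purely an illustrative assumption:

# Hypothetical implementation of the chat_completion_request placeholder
# using the OpenAI Python client; swap in whichever LLM backend you use.
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

def chat_completion_request(prompt):
    response = client.chat.completions.create(
        model="gpt-4o-mini",  # illustrative model choice, not from the article
        messages=[{"role": "user", "content": prompt}],
    )
    return response.choices[0].message.content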

Refining the Answer

Using the critique, the refine_answer function prompts the LLM to generate an improved answer.

# Improve the answer
def refine_answer(question, answer, critique):
    prompt = (
        f"Question: {question}\n"
        f"Current Answer: {answer}\n"
        f"Feedback: {critique}\n\n"
        "Based on the feedback, refine the answer to address all the points raised. "
        "Ensure the new answer is accurate, detailed, and well-structured. "
        "Present your reasoning process and verification steps before providing the final answer."
    )
    # Request refined answer from the language model
    return chat_completion_request(prompt)

Explanation: This function constructs a prompt that includes the question, the current answer, and the critique. It instructs the LLM to refine the answer based on the feedback.

Evaluating the Answer

The evaluate_answer function asks the LLM to assess the refined answer, provide a critique, and assign a numerical rating.

import re

def evaluate_answer(question, answer):
    prompt = (
        f"Question: {question}\n"
        f"Answer: {answer}\n"
        "As an expert, assess the answer above for correctness and completeness. "
        "Provide a detailed critique, pointing out any issues. "
        "Then, assign a rating between 0 and 100, where 100 represents a perfect answer. "
        "Format:\n"
        "Critique: <Your detailed critique>\n"
        "Rating: <Numerical rating>"
    )
    # Request evaluation from the language model
    evaluation = chat_completion_request(prompt)
    
    # Extract the rating from the evaluation
    try:
        match = re.search(r'Rating:\s*(\d+\.?\d*)', evaluation)
        if match:
            rating = float(match.group(1))
            rating = min(rating, 95)  # Cap the rating at 95
            rating /= 100.0
        else:
            raise ValueError("Rating not found in the evaluation.")
    except Exception as e:
        print(f"Error extracting rating: {e}")
        print(f"Evaluation response: {evaluation}")
        rating = 0.0
    
    print(f"\nEvaluation Response:\n{evaluation}")
    return rating

Explanation: The function prompts the LLM to critique the answer and assign a rating. It then parses the LLM’s response to extract the numerical rating, capping it at 95 to prevent overconfidence and normalizing it to a value between 0 and 1.

Defining the Tree Node Structure

We define a TreeNode class to represent nodes in the MCTS tree. Each node contains an answer and references to its parent and children.

import math
import random
import re

# Define the maximum number of children per node
MAX_CHILDREN = 3

class TreeNode:
    def __init__(self, question, answer, parent=None):
        self.question = question
        self.answer = answer
        self.parent = parent
        self.children = []
        self.visits = 0
        self.value = 0.0
        self.Ra = []  # List to store all reward samples
        self.Q = 0.0  # Quality value Q(a)

    def fully_expanded(self):
        return len(self.children) >= MAX_CHILDREN

    def select_promising_child(self, exploration_param=1.41):
        best_score = float('-inf')
        best_child = None
        for child in self.children:
            if child.visits == 0:
                ucb_score = float('inf')
            else:
                # Compute Q(a) using the exact formula
                min_ra = min(child.Ra)
                avg_ra = child.value / child.visits
                child.Q = 0.5 * (min_ra + avg_ra)
                
                exploration = exploration_param * math.sqrt(2 * math.log(self.visits) / child.visits)
                ucb_score = child.Q + exploration
            if ucb_score > best_score:
                best_score = ucb_score
                best_child = child
        return best_child

    def most_visited_child(self):
        return max(self.children, key=lambda c: c.visits, default=None)

    def add_child(self, child_node):
        self.children.append(child_node)

Explanation:

  • fully_expanded: Checks whether the node has reached the maximum number of child nodes.
  • select_promising_child: Implements the UCT formula to select the most promising child node.
  • most_visited_child: Retrieves the child node with the highest visit count.
  • add_child: Adds a new child node to the current node.
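The MonteCarloTreeSearch driver used in the next step is not shown in the article itself. Below is a minimal sketch consistent with the TreeNode class and the critique_answer, refine_answer, and evaluate_answer helpers above; the seeding strategy and internal method names (select, expand, backpropagate, best_answer) are assumptions:

import random  # TreeNode and the helper functions are defined above

class MonteCarloTreeSearch:
    def __init__(self, question, seed_answers, iterations=10):
        self.question = question
        self.iterations = iterations
        # Root the tree at a deliberately weak seed answer such as "I don't know"
        self.root = TreeNode(question, random.choice(seed_answers))

    def perform_search(self):
        for _ in range(self.iterations):
            # Selection: descend via UCT to a node that can still be expanded
            node = self.select(self.root)
            # Self-refinement: critique the current answer and improve it
            if not node.fully_expanded():
                node = self.expand(node)
            # Self-evaluation: sample a reward for the refined answer
            reward = evaluate_answer(self.question, node.answer)
            # Backpropagation: update visit counts and Q values up to the root
            self.backpropagate(node, reward)
        return self.best_answer()

    def select(self, node):
        while node.fully_expanded() and node.children:
            node = node.select_promising_child()
        return node

    def expand(self, node):
        critique = critique_answer(self.question, node.answer)
        refined = refine_answer(self.question, node.answer, critique)
        child = TreeNode(self.question, refined, parent=node)
        node.add_child(child)
        return child

    def backpropagate(self, node, reward):
        # Record the reward, then apply Q'(a) = (Q(a) + max child Q) / 2 upward
        node.visits += 1
        node.value += reward
        node.Ra.append(reward)
        node.Q = 0.5 * (min(node.Ra) + node.value / node.visits)
        parent = node.parent
        while parent is not None:
            parent.visits += 1
            parent.Q = 0.5 * (parent.Q + max(c.Q for c in parent.children))
            parent = parent.parent

    def best_answer(self):
        # Return the answer of the node with the highest Q(a) anywhere in the tree
        best, stack = self.root, [self.root]
        while stack:
            node = stack.pop()
            if node.Q > best.Q:
                best = node
            stack.extend(node.children)
        return best.answer

Each iteration runs one pass of selection, self-refinement, self-evaluation, and backpropagation, mirroring the workflow described in the theory section.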

Running the MCTS

Finally, we initialize the MCTS with a question and perform the search to find the best answer.

mcts = MonteCarloTreeSearch(question, seed_answers, iterations=10)
best_answer = mcts.perform_search()

Explanation: We create an instance of MonteCarloTreeSearch with the question and seed answers. We specify the number of iterations (rollouts) for the search.

Applying MCTSr to a Mathematical Problem

Suppose we have the following mathematical question:

Question: “Calculate the sum of the interior angles of a 12-sided polygon.”

We can use the MCTSr implementation to find an accurate answer.

question = "Calculate the sum of the interior angles of a 12-sided polygon."
mcts = MonteCarloTreeSearch(question, seed_answers, iterations=10)
best_answer = mcts.perform_search()
print(f"\nBest Answer:\n{best_answer}")

Output:

(Snippet showing only the final answer)

Best Answer: Refined Answer: Calculating the sum of interior angles of a 12-sided polygon

Step-by-step calculation:

A 12-sided polygon is called a dodecagon. Using the formula for the sum of interior angles of a polygon, we can calculate the sum:

Sum of interior angles = (n-2) × 180

Where n is the number of sides.

In this case, n = 12, so: Sum of interior angles = (12 - 2) × 180 = 10 × 180 = 1800

Reasoning and Verification:

The formula for the sum of interior angles of a polygon is derived from the fact that each interior angle is supplementary to its adjacent exterior angle. By applying this formula to a dodecagon, we can calculate the sum of its interior angles.

To verify the answer, we can also use a different approach. The sum of interior angles of a polygon is also equal to (n-2) × 180, where n is the number of sides. Since we have already shown that n = 12, we can substitute this value into the formula:

Sum of interior angles = (12 - 2) × 180 = 10 × 180 = 1800

Final Answer: The sum of the interior angles of a 12-sided polygon is 1800 degrees.

This refined answer demonstrates a clear understanding of the problem and provides a step-by-step calculation using relevant mathematical concepts. It also includes a verification step to ensure the answer is accurate.

Explanation: The MCTS starts with a seed answer and iteratively refines it. After ten iterations, it arrives at a correct and well-explained solution.

Advantages of the Implementation

  • Iterative Improvement: The algorithm systematically improves the answer through self-critique and refinement.
  • Balanced Exploration: Using UCT, the search balances exploring new refinements and exploiting known good answers.
  • Automatic Evaluation: The self-evaluation step allows the model to assess the quality of answers without external input.

Limitations

  • Computational Resources: Each iteration involves multiple calls to the LLM, which can be time-consuming and resource-intensive.
  • Model Dependence: The quality of the final answer heavily relies on the LLM’s capability to critique and refine effectively.
  • Parameter Tuning: Parameters like MAX_CHILDREN and the exploration constant need to be tuned for optimal performance.

Conclusion

The Monte Carlo Tree Self-Refine technique represents a significant step forward in enhancing LLMs’ ability to tackle complex reasoning tasks. By integrating MCTS with self-refinement strategies, MCTSr offers a robust framework for improving decision-making and solution quality in AI applications. The practical implementation demonstrates how theoretical concepts can be applied to achieve tangible improvements in mathematical problem-solving.

As research continues, we can expect further refinements and broader applications of this innovative approach, potentially extending beyond mathematics to other domains requiring complex reasoning and decision-making.

Note: The code examples provided use placeholder functions like chat_completion_request, which should be implemented using the appropriate API calls to the language model you’re interfacing with. The full code used in this article is available on Google Colab and in the LLM Tutorial.

References

Zhang, D., Huang, X., Zhou, D., Li, Y., & Ouyang, W. (2024). Accessing GPT-4 level Mathematical Olympiad Solutions via Monte Carlo Tree Self-refine with LLaMa-3 8B. arXiv preprint arXiv:2406.07394. https://doi.org/10.48550/arXiv.2406.07394