from .util import extract_answer, integrate_answer, get_prefix, sample_prefix
from .wrapper import *
import time

class PathConsistency:
    """
    Path-consistency inference: iteratively sample completions from a model, periodically
    refresh a pool of reasoning prefixes from the answers gathered so far, and integrate all
    answers into a final one.

    The ``max_branch`` completions are split into ``max_level + 1`` stages of (near-)equal
    size. After each stage boundary, ``get_prefix`` rebuilds the prefix pool. Subsequent
    branches prepend a prefix drawn from that pool (via ``sample_prefix``) to the prompt.

    Args:
        model (CompletionModel): Object exposing ``completion_function(prompt, **kwargs)``,
            expected to return the generated text as a string.
        max_branch (int, optional): Total number of completions (branches) to sample.
            Defaults to 20.
        max_level (int, optional): Number of prefix-refresh stages before the final one; the
            branches are divided into ``max_level + 1`` stages. Negative values are treated as 0.
            Defaults to 3.
        confidence_threshold (float, optional): Threshold forwarded to ``get_prefix``;
            presumably the confidence an answer needs before its reasoning is reused as a
            prefix (confirm against ``get_prefix``). Defaults to 0.8.
        ans_type (str, optional): Answer type forwarded to ``extract_answer`` when parsing
            each generation. Defaults to 'float'.

    Attributes:
        model (CompletionModel): The completion model instance.
        max_branch (int): Total number of completions to sample.
        max_level (int): Number of prefix-refresh stages before the final one.
        confidence_thres (float): Confidence threshold forwarded to ``get_prefix``.
        ans_type (str): Answer type forwarded to ``extract_answer``.

    Methods:
        inference(prompt: str, **kwargs):
            Run the sampling loop for ``prompt`` and return a dict with the integrated
            answer, the per-branch answers, latencies, raw generations and reasoning paths.
    """
    def __init__(self, 
                 model, 
                 max_branch : int = 20, 
                 max_level : int = 3, 
                 confidence_threshold : float = 0.8, 
                 ans_type : str = 'float', ) -> None:
        self.model : CompletionModel = model
        self.max_branch = max_branch
        self.max_level = max_level
        self.confidence_thres = confidence_threshold
        self.ans_type = ans_type

    @staticmethod
    def _refresh_points(max_branch: int, max_level: int) -> frozenset:
        """Return the 1-based branch counts after which the prefix pool is refreshed.

        The ``max_branch`` branches are split into ``max_level + 1`` stages using integer
        arithmetic, so the boundaries are exact even when ``max_branch`` is not divisible by
        the number of stages. When it is divisible, the result is every multiple of
        ``max_branch // (max_level + 1)`` up to ``max_branch``. A negative ``max_level`` is
        treated as 0, i.e. a single refresh after the last branch.
        """
        num_stages = max(max_level, 0) + 1
        points = {(stage * max_branch) // num_stages for stage in range(1, num_stages + 1)}
        # A boundary of 0 (possible when max_branch < num_stages) can never match a
        # 1-based branch count, so drop it for clarity.
        points.discard(0)
        return frozenset(points)

    def inference(self, prompt: str, **kwargs):
        """
        Sample ``self.max_branch`` completions for ``prompt``, refreshing the reasoning-prefix
        pool at each stage boundary, and integrate the extracted answers.

        Args:
            prompt (str): The text prompt; each branch sends ``prompt + prefix`` to the model.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``self.model.completion_function``.

        Returns:
            dict: A dictionary containing:
                  - 'answer': The result of ``integrate_answer`` over all extracted answers.
                  - 'answers': The answer extracted from each branch, in sampling order.
                  - 'latency': Seconds spent in each ``completion_function`` call.
                  - 'generations': The raw text returned by the model for each branch.
                  - 'reasoning': Each branch's full reasoning path (sampled prefix + generation).
        """
        answers = []
        prefix_level = 0
        prefix_list = []
        reasoning = []
        times = []
        generations = []
        # Integer stage boundaries; a float modulo test here would miss boundaries whenever
        # max_branch is not divisible by max_level + 1.
        refresh_points = self._refresh_points(self.max_branch, self.max_level)

        for branch_id in range(self.max_branch):
            # Sample a prefix from the current prefix pool.
            prefix = sample_prefix(prefix_list)

            # Time only the model call; perf_counter is monotonic and high-resolution.
            start_time = time.perf_counter()
            generation = self.model.completion_function(prompt + prefix, **kwargs)
            times.append(time.perf_counter() - start_time)

            # The full reasoning path is the sampled prefix followed by the new text.
            reasoning.append(prefix + generation)

            answers.append(extract_answer(generation, self.ans_type))
            generations.append(generation)

            # At each stage boundary, rebuild the prefix pool from everything gathered so far.
            if branch_id + 1 in refresh_points:
                prefix_list, prefix_level = get_prefix(
                    answers, self.confidence_thres, reasoning, prefix_list, prefix_level
                )

        final_answer = integrate_answer(answers)
        info = {
            'answer': final_answer,
            'answers': answers,
            'latency': times,
            'generations': generations,
            'reasoning': reasoning,
        }
        return info
