from typing import Optional
import fire
import json
import os
from .llama import Llama
from .path_consistency import LlamaModel, PathConsistency
from .util import TqdmPrintWrapper, is_equal, get_math_prompt, get_current_time_string, read_file

def main(
    # Model control parameters
    dataset: str,
    ckpt_dir: str = "/path/to/the/model",
    tokenizer_path: str = "/path/to/the/tokenizer",
    # Generation control parameters
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_seq_len: int = 1024,
    max_batch_size: int = 1,
    max_gen_len: Optional[int] = None,
    # SC control parameters
    max_branch: int = 20,
    ans_type: str = "float",
    # Prefix control parameters
    confidence_thres: float = 0.8,
    max_level: int = 3
):
    """
    Run PathConsistency inference over a dataset and record per-example results.

    Args:
        dataset (str): Dataset identifier. It is passed to ``read_file`` and
            ``get_math_prompt`` and is also used to build the output path.
        ckpt_dir (str, optional): Directory path to the model checkpoint. Defaults to "/path/to/the/model".
        tokenizer_path (str, optional): Path to the tokenizer file. Defaults to "/path/to/the/tokenizer".
        temperature (float, optional): Sampling temperature for generation. Defaults to 0.6.
        top_p (float, optional): Nucleus sampling probability threshold. Defaults to 0.9.
        max_seq_len (int, optional): Maximum sequence length passed to ``Llama.build``. Defaults to 1024.
        max_batch_size (int, optional): Maximum batch size passed to ``Llama.build``. Defaults to 1.
        max_gen_len (Optional[int], optional): Maximum length for generation. Defaults to None.
        max_branch (int, optional): Maximum number of branches for path consistency. Defaults to 20.
        ans_type (str, optional): Answer type forwarded to ``PathConsistency`` and ``is_equal``
            (e.g. "float"). Defaults to "float".
        confidence_thres (float, optional): Confidence threshold for prefix selection. Defaults to 0.8.
        max_level (int, optional): Maximum level of prefix branching. Defaults to 3.

    The function builds a model using the Llama class, loads the dataset, and runs
    ``PathConsistency.inference`` on every example. One JSON object per example is
    written to a timestamped ``.jsonl`` file under ``outputs/<dataset>/``; each record
    carries an ``id``, a ``result`` flag (1 if the answer matched the target, else 0),
    and the fields returned by the inference call. A trailing ``Accuracy:<value>`` line
    is appended when at least one example was scored.

    Returns:
        None
    """
    # Build the generator using the Llama model
    generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
    )
    model = LlamaModel(generator)

    # Define the output path for saving results
    output_path = f'outputs/{dataset}/{dataset}@{max_branch}_prefix_L{max_level}_c{confidence_thres}_t{temperature}_p{top_p}_{get_current_time_string()}.jsonl'
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # Load examples and the prompt template before opening the output file, so a
    # failed load does not leave an empty results file behind.
    examples = read_file(dataset)
    prompt = get_math_prompt(dataset)
    scores = []

    # The context manager guarantees the file is flushed and closed even if
    # inference raises part-way through the dataset.
    with open(output_path, 'w', encoding='utf-8') as f:
        # Iterate over every example so the progress-bar total matches the
        # number of items actually processed.
        pbar = TqdmPrintWrapper(examples, total=len(examples))
        for i, e in enumerate(pbar):
            question = e['input']

            # A fresh PathConsistency object is created for each example.
            pc = PathConsistency(
                model,
                max_branch=max_branch,
                max_level=max_level,
                confidence_threshold=confidence_thres,
                ans_type=ans_type,
            )
            # Perform inference and generate answers
            info = pc.inference(
                prompt.format(question=question),
                max_gen_len=max_gen_len,
                temperature=temperature,
                top_p=top_p,
            )

            # Compare the generated answer with the target; a missing answer
            # counts as incorrect.
            target = e['target']
            answer = info['answer']
            correct = answer is not None and is_equal(answer, target, ans_type)
            result = 1 if correct else 0
            scores.append(result)

            # Write the result to the output file ("id" and "result" first,
            # followed by the fields returned by inference).
            record = {"id": i, "result": result, **info}
            f.write(json.dumps(record) + "\n")

        # Calculate and log accuracy; skipped for an empty dataset to avoid
        # dividing by zero.
        if scores:
            accuracy = sum(scores) / len(scores)
            f.write(f"Accuracy:{accuracy}\n")

# Script entry point: python-fire exposes the parameters of `main` as
# command-line arguments.
if __name__ == "__main__":
    fire.Fire(main)