import os
import fire
import torch
import time

from tqdm import tqdm
from copy import deepcopy
from datasets import load_dataset, load_from_disk
from typing import Literal
from transformers import AutoTokenizer
from math import comb
import numpy as np
from typing import List, Dict
import re, json, os, math
import warnings

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

from utils.openmathinst_utils import process_results, extract_answer
from utils.chat_template import CHAT_TEMPLATE
from utils.data_utils import write_jsonl, most_common_element

# Critique template used by `self_verify` to have a model grade one
# (question, response) pair. `__special_original_question__` and
# `__special_original_response__` are filled in with str.replace, and the
# model's rating is parsed back from its boxed answer by
# `assign_critique_score`: 1 = correct, -0.5 = incorrect,
# -1 = no boxed final answer in the response.
CRITIQUE_PROMPT = """Below you are presented with a question and a tentative response. Your task is to evaluate and assign a rating to the response based on the following clear criteria:

Rating Criteria:

1. Missing final answer enclosed in \\boxed{} at the end: assign \\boxed{-1}.
2. Correct response with the final answer enclosed in \\boxed{} at the end: assign \\boxed{1}.
3. Incorrect response with the final answer enclosed in \\boxed{} at the end: assign \\boxed{-0.5}.

### Question Begin ###
__special_original_question__
### Question End ###

### Response Begin ###
__special_original_response__
### Response End ###

Briefly summarize your analysis, then clearly state your final rating value enclosed in \\boxed{} at the end.
"""

def read_json(file):
    """Load and return the single JSON document stored at path ``file``."""
    with open(file) as handle:
        payload = json.load(handle)
    return payload
    
def read_jsonl(file):
    """Load a JSON Lines file into a list with one parsed object per line.

    Blank / whitespace-only lines (e.g. a stray empty line at the end of a
    hand-edited file) are skipped instead of making ``json.loads`` fail.
    """
    with open(file, "r") as f:
        return [json.loads(line) for line in f if line.strip()]

def truncate_repetition(text):
    """Cut ``text`` at its first "assistant" marker, then at its first "user" marker, and strip it.

    Used to drop text after a model starts emitting another chat turn.
    """
    before_assistant = text.partition("assistant")[0]
    before_user = before_assistant.partition("user")[0]
    return before_user.strip()

def extract_critique_score(config_file=None, generation_file=None, remove_repetition=False):
    """Re-parse stored critique texts into numeric scores and save them in place.

    The file processed is ``generation_file`` if given, otherwise the
    ``"generation_file"`` entry of the JSON config at ``config_file``. Every
    entry of ``g["critique_outputs"]`` may be either a list of critique texts
    or a single critique string; each entry yields one list of scores (via
    ``assign_critique_score``) in ``g["critique_scores"]``.

    Args:
        config_file: JSON config path, used only when ``generation_file`` is None.
        generation_file: JSONL generation file to read and overwrite.
        remove_repetition: if True, every critique text is first passed through
            ``truncate_repetition`` and the truncated texts are saved back,
            keeping each entry's original shape (list stays list, str stays str).

    Raises:
        ValueError: if a critique output is neither a list nor a string
            (skipping it would misalign scores with responses).
    """
    if generation_file is not None:
        gen_fp = generation_file
    else:
        config = read_json(config_file)
        gen_fp = config["generation_file"]
    generations = read_jsonl(gen_fp)
    for g in generations:
        cleaned_outputs = []
        scores = []
        for output in g["critique_outputs"]:
            is_single = isinstance(output, str)
            if not (is_single or isinstance(output, list)):
                raise ValueError(f"Invalid critique output type: {type(output)}")
            critiques = [output] if is_single else output
            if remove_repetition:
                # Truncate whole strings; iterating a str would yield characters.
                critiques = [truncate_repetition(critique) for critique in critiques]
            cleaned_outputs.append(critiques[0] if is_single else critiques)
            scores.append(assign_critique_score(critiques))
        if remove_repetition:
            g["critique_outputs"] = cleaned_outputs
        g["critique_scores"] = scores
    write_jsonl(gen_fp, generations)

# Turn float into [float]
def format_critique_score(config_file=None, generation_file=None):
    """Normalize ``g["critique_scores"]`` so every entry is a list, saving in place.

    A bare ``float``/``int`` score becomes a one-element list; lists are kept
    as-is. The file is ``generation_file`` if given, otherwise the
    ``"generation_file"`` entry of the JSON config at ``config_file``.

    Raises:
        ValueError: for any other score type (including ``None`` and ``bool``).
    """
    gen_fp = generation_file if generation_file is not None else read_json(config_file)["generation_file"]
    generations = read_jsonl(gen_fp)
    for g in generations:
        wrapped = []
        for score in g["critique_scores"]:
            kind = type(score)
            if kind in (float, int):
                wrapped.append([score])
            elif kind is list:
                wrapped.append(score)
            else:
                raise ValueError(f"Invalid score type: {type(score)}")
        g["critique_scores"] = wrapped
    write_jsonl(gen_fp, generations)

def assign_critique_score(critique_list):
    """Parse each critique's boxed value into a float score.

    The boxed answer is extracted with ``extract_answer(..., extract_from_boxed=True)``.
    A critique yields ``None`` when nothing is extracted or the extracted text
    is not a valid float.

    Returns:
        A list of ``float | None``, aligned with ``critique_list``.
    """
    def parse_one(critique):
        boxed = extract_answer(string=critique, extract_from_boxed=True)
        if not boxed:
            return None
        try:
            return float(boxed)
        except ValueError:
            return None

    return [parse_one(critique) for critique in critique_list]

def compute_acc(gens):
    """Return the mean of the ``"correct"`` field over ``gens`` (raises on empty input)."""
    total_correct = sum(item["correct"] for item in gens)
    return total_correct / len(gens)

def avg_on_custom_key(gens, key):
    """Return the arithmetic mean of ``g[key]`` over ``gens`` (raises on empty input)."""
    values = [item[key] for item in gens]
    return sum(values) / len(values)

def compute_pass_ratio(generations):
    """Store, per generation, the fraction of correct responses under ``"pass_ratio"``.

    Reads the 0/1 ``"correct_list"``; mutates and returns ``generations``.
    """
    for item in generations:
        flags = item["correct_list"]
        item["pass_ratio"] = sum(flags) / len(flags)
    return generations

# used for greedy computation only
def compute_greedy(generations):
    """Mark ``"greedy"`` True when the first response is correct (``correct_list[0] == 1``)."""
    for item in generations:
        first_flag = item["correct_list"][0]
        item["greedy"] = first_flag == 1
    return generations

def _pass_at_k(n: int, c: int, k: int) -> float:
    """Calculates 1 - comb(n - c, k) / comb(n, k)."""
    if n - c < k:
        return 1.0
    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

def compute_pass_at_k(generations, k_list):
    """Store ``pass@k`` for every ``k`` in ``k_list`` on each generation.

    Uses ``_pass_at_k`` with n = len(correct_list) and c = sum(correct_list);
    the stored key keeps ``k`` as given, while the estimator receives ``int(k)``.
    """
    for k in k_list:
        metric_key = f"pass@{k}"
        for item in generations:
            flags = item["correct_list"]
            item[metric_key] = _pass_at_k(len(flags), sum(flags), int(k))
    return generations

def compute_sc(generations, sc_list):
    """Self-consistency: correctness of the majority-vote answer among the first ``sc`` responses.

    The majority answer comes from ``most_common_element`` (its tie-breaking
    rule is defined in utils.data_utils); the stored value is the correctness
    of that answer's first occurrence, under key ``sc@{sc}``.
    """
    for sc in sc_list:
        for item in generations:
            answers = item["answer_list"][:sc]
            flags = item["correct_list"][:sc]
            assert len(answers) == len(flags), "Number of answers and correctness should be one-to-one mapping"
            winner = most_common_element(answers)
            first_pos = answers.index(winner)
            item[f"sc@{sc}"] = flags[first_pos]
    return generations

def _bon_response_score(score_per_prob, mode, normalize):
    """Collapse the critique scores of one response into a single ranking value.

    ``None`` entries (unparseable critiques) are ignored. With ``normalize``,
    each remaining score becomes 1 if positive else 0. A response with no
    usable score ranks as 0. ``mode="score"`` averages the scores;
    ``mode="maj"`` takes their most common value.
    """
    valid = [s for s in score_per_prob if s is not None]
    if normalize:
        valid = [1 if s > 0 else 0 for s in valid]
    if not valid:
        return 0
    if mode == "score":
        return sum(valid) / len(valid)
    return most_common_element(valid)


def compute_bon(generations, bon_list, mode='score', normalize=True):
    """Best-of-N: record the correctness of the highest-ranked of the first N responses.

    For every N in ``bon_list`` and every generation, each of the first N
    responses is ranked by ``_bon_response_score`` over its critique scores,
    and ``g[f"bon@{N}"]`` is set to the 0/1 correctness of the best one.

    Args:
        generations: dicts with aligned ``"critique_scores"`` (list of score
            lists) and ``"correct_list"``; mutated in place.
        bon_list: the N values to evaluate.
        mode: ``"score"`` (mean of scores) or ``"maj"`` (most common score).
        normalize: binarize scores (> 0 -> 1, else 0) before aggregating.

    Returns:
        ``generations``.

    Raises:
        ValueError: if ``mode`` is not ``"score"`` or ``"maj"``.
    """
    if mode not in ("score", "maj"):
        raise ValueError("Invalid mode. Must be 'score' or 'maj'.")
    for bon in bon_list:
        print(f"Computing bon@{bon}...")
        for g in generations:
            critique_scores = g["critique_scores"][:bon]
            correct_list = g["correct_list"][:bon]
            assert len(critique_scores) == len(correct_list), "Number of answers and correctness should be one-to-one mapping"
            candidates = [_bon_response_score(s, mode, normalize) for s in critique_scores]
            # list.index returns the first maximum, so ties go to the earliest response.
            best_index = candidates.index(max(candidates))
            g[f"bon@{bon}"] = correct_list[best_index]
    return generations


def compute_critique_acc(generations, threshold=0):
    """Per-question accuracy of the binarized critique verdicts against ground truth.

    Every non-``None`` critique score of a response is turned into 1 if it is
    greater than ``threshold`` else 0 and compared with that response's 0/1
    ``correct_list`` entry. The accuracy is stored under ``"critique_acc_ms"``
    when ``threshold == 0.5`` (the threshold used for math-shepherd scores) and
    under ``"critique_acc"`` otherwise. A question without any valid score
    gets 0.0 and a warning is printed.
    """
    acc_key = "critique_acc_ms" if threshold == 0.5 else "critique_acc"
    for g in generations:
        critique_scores = g["critique_scores"]
        correct_list = g["correct_list"]
        assert len(critique_scores) == len(correct_list), "Answer list and correctness list length mismatch"
        predictions = []
        targets = []
        for per_prob, truth in zip(critique_scores, correct_list):
            for score in per_prob:
                if score is None:
                    continue
                predictions.append(1 if score > threshold else 0)
                targets.append(truth)
        if not predictions:
            g[acc_key] = 0.0
            print("[Warning]: No valid critique scores found for this generation.")
            continue
        hits = sum(p == t for p, t in zip(predictions, targets))
        g[acc_key] = hits / len(predictions)
    return generations


def compute_weighted_sc(generations, sc_list, mode="score", a=2, b=4):
    """Critique-weighted self-consistency over the first ``sc`` responses.

    Each response's critique scores are binarized (> 0 -> 1, else 0; ``None``
    dropped). Every answer accumulates, over its responses that have at least
    one usable score, either ``(sum + a) / (count + b)`` (``mode="score"``;
    a smoothed mean that pulls sparsely-critiqued responses toward ``a / b``)
    or the most common binarized score (``mode="maj"``). The answer with the
    highest total wins, ties going to the answer that appeared first, and
    ``g[f"weighted_sc@{sc}"]`` is the correctness of its first occurrence.

    Raises:
        ValueError: if ``mode`` is not ``"score"`` or ``"maj"``.
    """
    if mode not in ("score", "maj"):
        raise ValueError("Invalid mode. Must be 'score' or 'maj'.")
    for sc in sc_list:
        for g in generations:
            answers = g["answer_list"][:sc]
            per_response_scores = g["critique_scores"][:sc]
            correctness = g["correct_list"][:sc]
            assert len(answers) == len(per_response_scores) == len(correctness), "Mismatched lengths"

            binarized = [
                [1 if s > 0 else 0 for s in scores if s is not None]
                for scores in per_response_scores
            ]

            totals = {}
            first_seen = {}
            for pos, (ans, votes) in enumerate(zip(answers, binarized)):
                first_seen.setdefault(ans, pos)
                totals.setdefault(ans, 0.0)
                if not votes:
                    continue
                if mode == "score":
                    totals[ans] += (sum(votes) + a) / (len(votes) + b)  # Laplace smoothing
                else:
                    totals[ans] += most_common_element(votes)

            top = max(totals.values())
            # dicts keep insertion (= first-seen) order, so the first answer
            # reaching `top` is the earliest-appearing one among the ties.
            winner = next(ans for ans, total in totals.items() if total == top)
            g[f"weighted_sc@{sc}"] = correctness[first_seen[winner]]
    return generations

def compute_response_len(generations, tokenizer):
    """Store the mean and standard deviation of response token lengths per generation.

    ``g["response"]`` (the list of response texts) is tokenized as one batch
    without special tokens, padding or truncation; results go to
    ``"response_len_mean"`` and ``"response_len_std"`` (population std, numpy default).
    """
    for item in generations:
        token_ids = tokenizer(
            item["response"],
            add_special_tokens=False,
            return_attention_mask=False,
            padding=False,
            truncation=False,
        ).input_ids
        lengths = [len(ids) for ids in token_ids]
        item["response_len_mean"] = sum(lengths) / len(lengths)
        item["response_len_std"] = np.std(lengths)
    return generations


def _build_critique_chat_prompt(tokenizer, system_prompt, user_content):
    """Render a system + user conversation as text ending with the assistant generation prompt."""
    return tokenizer.apply_chat_template(
        conversation=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_content},
        ],
        tokenize=False,
        add_generation_prompt=True,
    )


def _truncate_critique_prompt(tokenizer, prompt, max_len):
    """Round-trip ``prompt`` through the tokenizer, keeping at most ``max_len`` tokens."""
    # NOTE(review): uses the tokenizer's configured truncation side (usually
    # right), so an over-long prompt loses its tail, including the assistant
    # header added by the chat template.
    token_ids = tokenizer.encode(prompt, max_length=max_len, truncation=True)
    # Keep special tokens so chat-template markers survive the round trip.
    return tokenizer.decode(token_ids, skip_special_tokens=False)


def self_verify(
    # required
    base_model: str = None,
    chat_template_name: str = None,
    output_file: str = None,

    # model
    bf16: bool = True,
    fp16: bool = False,
    tensor_parallel_size: int = 8,

    # data
    question_key: str = "problem",
    response_key: str = "response",
    
    # gen
    critique_prompt: str = CRITIQUE_PROMPT,
    enforce_eager: bool = False,
    num_scheduler_steps: int = 1,
    system_prompt: str = 'Please reason step by step, and put your final answer within \\boxed{}.',
    gen_n_critiques: int = 1,
    max_model_len: int = 4096,
    max_generation_tokens: int = 3000,
    gen_per_question: int = 8,
    temperature: float = 0.0,
    top_p: float = 1.0,
    max_prompt_len: int = None,
    
    # reproducibility
    seed: int = 42):
    """Have ``base_model`` critique every stored response and save the results in place.

    For each generation in ``output_file`` (JSONL), every response in
    ``g[response_key]`` is inserted together with ``g[question_key]`` into
    ``critique_prompt``, wrapped in a system + user chat turn and sent to vLLM.
    The following fields are written back to ``output_file``:

    * ``critique_prompts``: the filled critique prompts, one per response;
    * ``critique_outputs``: per response, the ``gen_n_critiques`` critique texts;
    * ``critique_scores``: per response, the scores parsed by
      ``assign_critique_score`` (``None`` when unparseable).

    Args:
        base_model: critic model name or path.
        output_file: generation JSONL to read and overwrite.
        bf16 / fp16: model dtype (float32 if neither is set).
        question_key / response_key: fields holding the question and the list
            of responses.
        critique_prompt: template with ``__special_original_question__`` and
            ``__special_original_response__`` placeholders.
        enforce_eager, max_model_len, seed: forwarded to ``vllm.LLM``.
        system_prompt: system message of every critique conversation.
        gen_n_critiques: critiques sampled per response.
        max_generation_tokens, temperature, top_p: critique sampling settings.
        gen_per_question: expected responses per question. The actual number
            of responses of each question is used; a warning is printed when
            some question differs from this value.
        max_prompt_len: if set, every rendered prompt is truncated to this
            many tokens.
        chat_template_name, num_scheduler_steps: accepted but unused here.
        tensor_parallel_size: ignored; the number of visible GPUs is used,
            capped at 4 for Qwen models other than the 32B/72B variants.
    """
    generations = read_jsonl(output_file)

    tensor_parallel_size = torch.cuda.device_count()
    model_name = base_model.lower()
    if "qwen" in model_name and not ('32b' in model_name or '72b' in model_name):
        tensor_parallel_size = min(tensor_parallel_size, 4)
    print('Tensor parallel size: ', tensor_parallel_size)

    critique_llm = LLM(
        model=base_model,
        tensor_parallel_size=tensor_parallel_size,
        dtype=torch.bfloat16 if bf16 else (torch.float16 if fp16 else torch.float32),
        seed=seed,
        max_model_len=max_model_len,
        enforce_eager=enforce_eager,
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

    # Fill the critique template once per (question, response) pair.
    for g in generations:
        question = g[question_key]
        g["critique_prompts"] = [
            critique_prompt.replace("__special_original_question__", question).replace("__special_original_response__", resp)
            for resp in g[response_key]
        ]

    if max_prompt_len is not None:
        print(f"[Warning] Using max prompt length {max_prompt_len}")
    # Flattened in (question, response) order; the regrouping loop below relies on it.
    prompts = [
        _build_critique_chat_prompt(tokenizer, system_prompt, p)
        for g in generations for p in g["critique_prompts"]
    ]
    if max_prompt_len is not None:
        prompts = [_truncate_critique_prompt(tokenizer, p, int(max_prompt_len)) for p in prompts]

    # Generate critique responses
    critique_outputs = critique_llm.generate(
        prompts,
        sampling_params=SamplingParams(
            n=gen_n_critiques,
            temperature=temperature,
            max_tokens=max_generation_tokens,
            top_p=top_p,
            stop_token_ids=[tokenizer.eos_token_id],
        ),
    )

    # vLLM returns one RequestOutput per prompt in prompt order: slice them back
    # per question using that question's actual number of responses.
    offset = 0
    mismatched = 0
    for g in generations:
        n_responses = len(g["critique_prompts"])
        if n_responses != gen_per_question:
            mismatched += 1
        request_outputs = critique_outputs[offset:offset + n_responses]
        offset += n_responses
        g["critique_outputs"] = [[r.text for r in ro.outputs] for ro in request_outputs]
        g["critique_scores"] = [assign_critique_score(texts) for texts in g["critique_outputs"]]
    if mismatched:
        print(f"[Warning] {mismatched} question(s) do not have gen_per_question={gen_per_question} responses; their actual response counts were used.")

    write_jsonl(output_file, generations)
    print("="*50)
    print(f"Self-critique completed and updated to {output_file}")
    print("="*50)


def _judge(llm_judge, generations, generation_file, gen_per_question, bf16=False, fp16=False, seed=42, enforce_eager=False, num_scheduler_steps=1, extract_from_solution=True, ques_key="problem", ans_key="answer"):
    """Mark every response in ``generations`` as correct/incorrect and save the result.

    Two modes:

    * ``llm_judge`` falsy: rule-based matching via ``process_results``. Each
      generation gets ``answer_list`` (boxed answers extracted from the
      responses) and ``correct_list`` (0/1 per response).
    * ``llm_judge`` truthy: the "KbsdJames/Omni-Judge" model grades each
      response. Each generation gets ``judge_res`` (the parsed judge output,
      or the string "Failed to parse") and ``correct_list``. This branch does
      NOT set ``answer_list``.

    Args:
        llm_judge: use the LLM judge instead of rule-based matching.
        generations: list of dicts, each with a ``"response"`` list; mutated in place.
        generation_file: path the updated generations are written to.
        gen_per_question: expected number of responses per question (asserted).
        bf16 / fp16: judge model dtype (float32 if neither is set).
        seed, enforce_eager, num_scheduler_steps: forwarded to ``vllm.LLM``.
        extract_from_solution: LLM mode only. If True, the reference answer is
            extracted from ``g["solution"]`` and the question is read from
            ``g["problem"]`` (``ques_key`` / ``ans_key`` are then not used).
        ques_key: question field (LLM mode, ``extract_from_solution=False``).
        ans_key: reference-answer field (rule-based mode, and LLM mode with
            ``extract_from_solution=False``).

    Raises:
        AssertionError: if a generation's number of responses differs from
            ``gen_per_question``.
    """
    if not llm_judge:
        # This is to allow two output formats
        # MetaMathQA: `The answer is: 123`
        # OpenMathInstruct: `\boxed{123}`
        for g in tqdm(generations, desc="Computing correctness (rule-based)", total=len(generations)):
            correctness_list = []
            extracted_answers = [extract_answer(string=resp, extract_from_boxed=True) for resp in g["response"]]
            for res in g["response"]:
                # First try the boxed answer; if that does not match, fall back
                # to the text captured by the "answer is ... ." regex.
                correctness = process_results(
                    response=res,
                    extracted_answer=g[ans_key],
                    response_extract_from_boxed=True,
                    answer_extract_from_boxed=False,
                ) or process_results(
                    response=res,
                    extracted_answer=g[ans_key],
                    response_extract_from_boxed=False,
                    response_extract_regex=r"answer is(.*?)\.",
                    answer_extract_from_boxed=False,
                )
                correctness_list.append(correctness)
            g["answer_list"] = extracted_answers
            g["correct_list"] = [1 if c else 0 for c in correctness_list]
            assert len(g["correct_list"]) == gen_per_question, f"Len of correctness_list ({len(g['correct_list'])}) should be equal to gen_per_question ({gen_per_question})"

    # LLM as a judge
    else:
        judge_llm = LLM(
            model="KbsdJames/Omni-Judge",
            tensor_parallel_size=torch.cuda.device_count(),
            dtype=torch.bfloat16 if bf16 else (torch.float16 if fp16 else torch.float32),
            seed=seed,
            max_model_len=8 * 1024, # should be enough
            enforce_eager=enforce_eager,
            num_scheduler_steps=num_scheduler_steps,
            trust_remote_code=True,
        )
        # NOTE(review): get_context / parse_response used below are presumably
        # custom methods supplied by the Omni-Judge tokenizer's remote code
        # (trust_remote_code=True) -- confirm against the model repository.
        tokenizer = AutoTokenizer.from_pretrained("KbsdJames/Omni-Judge", trust_remote_code=True)

        # flatten all generations
        # One judge prompt per (question, response), in generation order; the
        # offset-based regrouping further down relies on this ordering.
        if extract_from_solution: # MATH dataset
            # NOTE(review): reads g["problem"] directly rather than g[ques_key].
            judge_prompts = [tokenizer.get_context(g["problem"], extract_answer(g["solution"]), r) for g in generations for r in g["response"]]
        else:
            judge_prompts = [tokenizer.get_context(g[ques_key], g[ans_key], r) for g in generations for r in g["response"]]
        
        # Give the user a few seconds to abort before previous judge results are
        # overwritten. NOTE(review): generations[0] raises IndexError on an empty list.
        if ('judge_res' in generations[0].keys()) or ('correct_list' in generations[0].keys()):
            warnings.warn("The generations may already have been judged, will overwrite if not aborted ...")
            time.sleep(5)
        
        # Deterministic (temperature 0) judging; "<|eot_id|>" is added as an
        # extra stop token alongside the tokenizer's EOS.
        judge_outputs = judge_llm.generate(
            judge_prompts,
            sampling_params=SamplingParams(
                temperature=0,
                max_tokens=300,
                stop_token_ids=[
                    tokenizer.eos_token_id,
                    tokenizer.convert_tokens_to_ids("<|eot_id|>")
                ]
            ),
        )
        # store the judge results back to the generations
        offset = 0
        for g in tqdm(generations, desc="Computing correctness (llm-based)", total=len(generations)):
            num_responses = len(g["response"])
            gen_judge_outputs = judge_outputs[offset : offset + num_responses]
            offset += num_responses

            g["judge_res"] = []
            correctness_list = []

            # Judgement for the same question but different responses
            for jo in gen_judge_outputs:
                try:
                    judge_res = tokenizer.parse_response(jo.outputs[0].text)
                except Exception:
                    judge_res = "Failed to parse"
                # An unparseable judge output counts as incorrect; a parsed one
                # is correct only if its "judgement" field equals "TRUE".
                if judge_res != "Failed to parse":
                    is_correct = (judge_res["judgement"] == "TRUE")
                else:
                    is_correct = False
                g["judge_res"].append(judge_res)
                correctness_list.append(is_correct)

            g["correct_list"] = [1 if c else 0 for c in correctness_list]
            assert len(g["correct_list"]) == gen_per_question, f"Len of correctness_list ({len(g['correct_list'])}) should be equal to gen_per_question ({gen_per_question})"
        
    print("="*50)
    print("Judge completed")
    print("="*50)
    
    # save generations
    write_jsonl(generation_file, generations)
    print(f'Generation file saved at {generation_file}')


def _build_generation_prompt(tokenizer, system_prompt, user_content):
    """Render one chat prompt; ``system_prompt == "default"`` omits the system turn."""
    conversation = [{"role": "user", "content": user_content}]
    if system_prompt != "default":
        conversation.insert(0, {"role": "system", "content": system_prompt})
    return tokenizer.apply_chat_template(
        conversation=conversation,
        tokenize=False,
        add_generation_prompt=True,
    )


def generate(
    # required
    base_model: str = None,
    chat_template_name: str = None,
    output_dir: str = None,

    # model
    bf16: bool = False,
    fp16: bool = False,
    tensor_parallel_size: int = 8,

    # data
    data_dir: str = None, # If provided, the data will loaded from data_dir/DATASET_ID
    dataset_id: str = None,
    split: Literal["train", "test", "math500", "math4500"] = "test",
    data_file: str = None, # If provided, the data will loaded from data_file
    question_key: str = "problem",
    
    # gen
    enforce_eager: bool = False,
    num_scheduler_steps: int = 1,
    add_prompt: str = None,
    add_prompt_prefix: str = None,
    system_prompt: str = '',
    max_model_len: int = 4096,
    max_generation_tokens: int = 3000,
    gen_per_question: int = 1,
    greedy: str = "true",
    temperature: float = 0.75,
    top_p: float = 0.95,

    # reproducibility
    seed: int = 42,
):
    """Sample responses for every question of a dataset and save them with a config.

    Writes ``<output_dir>/generation.jsonl`` (each input row plus ``"response"``,
    the list of generated texts, and ``"prompt"``) and
    ``<output_dir>/config.json`` (settings consumed by ``judge`` and
    ``compute_metrics``; ``gen_per_question`` there is the number of responses
    actually generated per question).

    Args:
        base_model: model name or path for vLLM and the tokenizer.
        chat_template_name: key into ``CHAT_TEMPLATE``; ``"default"`` keeps the
            tokenizer's own template.
        output_dir: output directory (created if missing).
        bf16 / fp16: model dtype (float32 if neither is set).
        tensor_parallel_size: ignored; the number of visible GPUs is used,
            capped at 4 for Qwen models other than the 32B/72B variants.
        data_dir / dataset_id / split / data_file: data source. ``data_file``
            (JSONL) takes precedence; otherwise ``load_dataset(dataset_id)`` or
            ``load_from_disk(data_dir/dataset_id)`` is indexed by ``split``.
        question_key: row field holding the question text.
        enforce_eager, max_model_len, seed: forwarded to ``vllm.LLM``.
        num_scheduler_steps: only recorded in the config.
        add_prompt / add_prompt_prefix: text appended / prepended to each question.
        system_prompt: system message; ``"default"`` omits the system turn.
        max_generation_tokens: generation length limit.
        gen_per_question: samples per question when not greedy.
        greedy: ``"true"`` or ``True`` (case-insensitive) for one greedy sample
            per question; anything else enables sampling.
        temperature / top_p: sampling parameters when not greedy.
    """
    # Fire passes `--greedy` / `--greedy=True` as a bool but `--greedy=true` as
    # a string, so normalize both spellings.
    greedy = str(greedy).strip().lower() == "true"
    temperature = float(temperature)
    top_p = float(top_p)
    gen_per_question = int(gen_per_question)
    # Greedy decoding always yields exactly one response per question; the
    # judge asserts this count, so it is what gets recorded in the config.
    num_responses = 1 if greedy else gen_per_question

    # Path
    os.makedirs(output_dir, exist_ok=True)
    generation_file = os.path.join(output_dir, "generation.jsonl")
    config_file = os.path.join(output_dir, "config.json")

    tensor_parallel_size = torch.cuda.device_count()
    model_name = base_model.lower()
    if "qwen" in model_name and not ('32b' in model_name or '72b' in model_name):
        tensor_parallel_size = min(tensor_parallel_size, 4)
    print('Tensor parallel size: ', tensor_parallel_size)

    # load model
    llm = LLM(
        model=base_model,
        tensor_parallel_size=tensor_parallel_size,
        dtype=torch.bfloat16 if bf16 else (torch.float16 if fp16 else torch.float32),
        seed=seed,
        max_model_len=max_model_len, # should be enough for MATH dataset
        enforce_eager=enforce_eager,
    )
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    if chat_template_name != "default":
        tokenizer.chat_template = CHAT_TEMPLATE[chat_template_name]

    # Load data
    if data_file is not None:
        test_dataset = read_jsonl(data_file)
    else:
        if data_dir is None:
            test_dataset = load_dataset(dataset_id)
        else:
            test_dataset = load_from_disk(os.path.join(data_dir, dataset_id))
        test_dataset = test_dataset[split]

    additional_prompt = add_prompt if add_prompt is not None else ""
    prefix_prompt = add_prompt_prefix if add_prompt_prefix is not None else ""

    print('-'*50)
    print(f'Using system_prompt: {system_prompt}')
    print(f'Using add_prompt: {additional_prompt}')
    print(f'Using chat_template_name: {chat_template_name}')
    print('-'*50)
    print('\n')

    prompts = [
        _build_generation_prompt(tokenizer, system_prompt, prefix_prompt + td[question_key] + additional_prompt)
        for td in test_dataset
    ]

    # generate
    if greedy:
        if gen_per_question != 1:
            print(f"[Warning] Greedy decoding produces 1 response per question; gen_per_question={gen_per_question} is ignored.")
        sampling_params = SamplingParams(n=1, temperature=0, max_tokens=max_generation_tokens, top_p=1)
    else:
        sampling_params = SamplingParams(n=gen_per_question, temperature=temperature, max_tokens=max_generation_tokens, top_p=top_p, stop_token_ids=[tokenizer.eos_token_id])
    
    print('-'*50)
    print('Sampling_params: ', sampling_params)
    print('-'*50)

    outputs = llm.generate(
        prompts,
        sampling_params,
    )
    assert len(outputs) == len(prompts)

    generations = []
    for td, o, p in zip(test_dataset, outputs, prompts):
        new_td = deepcopy(td)
        new_td["response"] = [r.text for r in o.outputs]
        new_td["prompt"] = p
        generations.append(new_td)

    # save generations
    write_jsonl(generation_file, generations)
    
    # save config
    config = {
        "base_model": base_model,
        "dataset_id": dataset_id,
        "test_split": split,
        "question_key": question_key,
        "generation_file": generation_file,
        "bf16": bf16,
        "fp16": fp16,
        "seed": seed,
        "enforce_eager": enforce_eager,
        "num_scheduler_steps": num_scheduler_steps,
        "greedy": greedy,
        "temperature": 0 if greedy else temperature,
        "top_p": 1 if greedy else top_p,
        "gen_per_question": num_responses,
    }

    write_jsonl(config_file, config)
    
    print("="*50)
    print("Generation completed")
    print("="*50)
    print(f'Generation file saved at {generation_file}')
    print(f'Generation config file saved at {config_file}')


def judge(config_file, llm_judge=False, extract_from_solution=False, expected_ans_key="answer"):
    """Judge the responses referenced by a generation config and record the judge settings.

    ``llm_judge == 'llm_as_a_judge'`` selects the Omni-Judge LLM grader; any
    other value selects rule-based answer matching. The judge settings are
    written back into ``config_file`` first; ``_judge`` then annotates and
    rewrites the generation file named in the config.
    """
    config = read_json(config_file)
    use_llm_judge = llm_judge == 'llm_as_a_judge'
    config.update(
        judge_method=llm_judge,
        extract_from_solution=extract_from_solution,
        expected_ans_key=expected_ans_key,
    )

    generation_file = config["generation_file"]
    dataset = config.get("dataset_id", "unknown")
    split = config.get("test_split", "unknown")
    judge_options = {
        "bf16": config.get("bf16", None),
        "fp16": config.get("fp16", None),
        "seed": config.get("seed", 666),
        "enforce_eager": config.get("enforce_eager", False),
        "num_scheduler_steps": config.get("num_scheduler_steps", 1),
        "extract_from_solution": extract_from_solution,
        "ques_key": config.get("question_key", "problem"),
        "ans_key": expected_ans_key,
    }
    gen_per_question = config["gen_per_question"]

    banner = '-' * 50
    print(banner)
    print('Set use_llm_judge: ', use_llm_judge)
    print('Set extract_from_solution: ', extract_from_solution)
    print('Set expected_ans_key: ', expected_ans_key)
    print(f"Evaluating the {split} of {dataset} dataset.")
    print(banner)

    # Persist the judge settings, then judge the stored generations.
    write_jsonl(config_file, config)
    generations = read_jsonl(generation_file)
    _judge(use_llm_judge, generations, generation_file, gen_per_question, **judge_options)


def _sorted_unique(values):
    """Return the distinct ``values`` in sorted order.

    Falls back to ordering by (type name, text) when the values are not
    mutually comparable, e.g. numeric levels mixed with the "unknown"
    placeholder string, which would otherwise make ``sorted`` raise TypeError.
    """
    unique = set(values)
    try:
        return sorted(unique)
    except TypeError:
        return sorted(unique, key=lambda v: (type(v).__name__, str(v)))


def _write_metric_log(result_file, generations, out_key, groups, percentage):
    """Write per-group and overall averages of ``out_key`` to ``result_file`` and print the overall one.

    ``groups`` is a sequence of ``(field, labels)`` pairs; one ``label: value``
    line is written per label in order, followed by an ``Overall`` line.
    Values are multiplied by 100 when ``percentage`` is true, and always
    formatted with two decimals.
    """
    def fmt(value):
        return f"{value * 100:.2f}" if percentage else f"{value:.2f}"

    with open(result_file, "w") as f:
        for field, labels in groups:
            for label in labels:
                subset = [g for g in generations if g[field] == label]
                f.write(f"{label}: {fmt(avg_on_custom_key(subset, out_key))}\n")
        overall = fmt(avg_on_custom_key(generations, out_key))
        f.write(f"Overall: {overall}\n")
        print(f"{out_key}: {overall}")


def _compute_metrics(config_file, **metric_kwargs):
    """Compute the requested metrics, write ``result_<metric>.log`` files and save generations.

    ``metric_kwargs`` keys (a missing or falsy key skips that metric):
    ``greedy``, ``pass_ratio``, ``critique_acc``, ``critique_acc_ms``,
    ``response_len`` (booleans) and ``pass@k``, ``sc``, ``bon``,
    ``weighted_sc`` (lists of sample sizes). The logs are written next to the
    generation file named in ``config_file``, broken down by the ``type``,
    ``level`` and ``difficulty`` fields ("unknown" when absent). Every metric
    except the response-length ones is reported as a percentage.
    """
    # Load config
    config = read_json(config_file)
    generation_file = config["generation_file"]
    base_model = config['base_model']
    output_folder = os.path.dirname(generation_file)
    generations = read_jsonl(generation_file)

    metric_list = []
    # here we want to update the same batch of generations iteratively
    if metric_kwargs.get("greedy"):
        generations = compute_greedy(generations)
        metric_list.append(["greedy"])
    if metric_kwargs.get("pass_ratio"):
        generations = compute_pass_ratio(generations)
        metric_list.append(["pass_ratio"])
    if metric_kwargs.get("pass@k"):
        k_list = metric_kwargs["pass@k"]
        generations = compute_pass_at_k(generations, k_list)
        metric_list.append([f"pass@{k}" for k in k_list])
    if metric_kwargs.get("sc"):
        sc_list = metric_kwargs["sc"]
        generations = compute_sc(generations, sc_list)
        metric_list.append([f"sc@{sc}" for sc in sc_list])
    if metric_kwargs.get("bon"):
        bon_list = metric_kwargs["bon"]
        generations = compute_bon(generations, bon_list)
        metric_list.append([f"bon@{bon}" for bon in bon_list])
    if metric_kwargs.get("weighted_sc"):
        wsc_list = metric_kwargs["weighted_sc"]
        generations = compute_weighted_sc(generations, wsc_list)
        metric_list.append([f"weighted_sc@{wsc}" for wsc in wsc_list])
    if metric_kwargs.get("critique_acc"):
        generations = compute_critique_acc(generations)
        metric_list.append(["critique_acc"])
    if metric_kwargs.get("critique_acc_ms"):
        generations = compute_critique_acc(generations, threshold=0.5)
        metric_list.append(["critique_acc_ms"])
    if metric_kwargs.get("response_len"):
        print(f"Using {os.path.basename(base_model)} tokenizer to compute response length")
        tokenizer = AutoTokenizer.from_pretrained(base_model)
        generations = compute_response_len(generations, tokenizer)
        metric_list.append(["response_len_mean", "response_len_std"])

    # difficulty: OmniMath, category: MMLU-Pro
    for g in generations:
        for field in ("type", "level", "difficulty", "category"):
            g.setdefault(field, "unknown")

    # save evaluation results
    groups = [
        ("type", _sorted_unique(g["type"] for g in generations)),
        ("level", _sorted_unique(g["level"] for g in generations)),
        ("difficulty", _sorted_unique(g["difficulty"] for g in generations)),
    ]
    for out_key_list in metric_list:
        for out_key in out_key_list:
            percentage = not out_key.startswith("response_len")
            result_file = os.path.join(output_folder, f"result_{out_key}.log")
            _write_metric_log(result_file, generations, out_key, groups, percentage)

    print(f"Result files saved at {output_folder}")
    write_jsonl(generation_file, generations)
    print("="*50)
    print("Metrics computing completed")
    print("="*50)
    print(f'Generation file updated at {generation_file}')


def compute_metrics(config_file, greedy=False, pass_at_k=False, sc=False, bon=False, wsc=False, pass_ratio=False, c_acc=False,
                    response_len=False, c_acc_ms=False):
    """Translate CLI flags into metric settings and run ``_compute_metrics``.

    Sample-size sweeps are powers of two derived from the config's
    ``gen_per_question`` (G): pass@k uses 1, 2, 4, ... up to the largest power
    of two not exceeding G/2; sc, bon and weighted_sc use 4, 8, ... up to the
    largest power of two not exceeding G.
    """
    gen = read_json(config_file)["gen_per_question"]

    def powers_of_two(first_exponent, upper):
        # 2**first_exponent, ..., 2**floor(log2(upper))
        return [2**i for i in range(first_exponent, int(math.log2(upper)) + 1)]

    k_list = powers_of_two(0, gen / 2) if pass_at_k else []
    sc_list = powers_of_two(2, gen) if sc else []
    bon_list = powers_of_two(2, gen) if bon else []
    wsc_list = powers_of_two(2, gen) if wsc else []

    metric_kwargs = {"greedy": greedy, "pass@k": k_list, "sc": sc_list, "bon": bon_list, "weighted_sc": wsc_list,
                     "critique_acc": c_acc, "critique_acc_ms": c_acc_ms, "pass_ratio": pass_ratio, "response_len": response_len}

    print(f"Metrics to compute: {metric_kwargs}")
    _compute_metrics(config_file, **metric_kwargs)


if __name__ == "__main__":
    # With no argument, fire exposes this module's top-level functions
    # (generate, judge, self_verify, compute_metrics, ...) as CLI subcommands.
    fire.Fire()
