import torch
from transformers import AutoModel, AutoTokenizer
import torch.nn.functional as F
import pandas as pd
import ast
import numpy as np
import argparse
import sys
from math_verify import parse, verify
from math import prod
from typing import List, Tuple

def make_step_rewards(logits, token_masks):
    """Extract the positive-class probability at every masked token position.

    Args:
        logits: tensor of shape (batch, seq_len, num_labels); index 1 of the
            last axis is read as the positive class.
        token_masks: (batch, seq_len) tensor that is true / non-zero at the
            positions to score (e.g. step-separator tokens).

    Returns:
        A list with one entry per batch item: the positive-class probabilities
        at that item's masked positions, in sequence order, as Python floats.
    """
    probabilities = F.softmax(logits, dim=-1)  # bs, seq_len, num_labels
    masks = token_masks.bool()

    all_scores_res = []
    for sample_probs, sample_mask in zip(probabilities, masks):
        # Select rows with the mask itself rather than "probability != 0": a
        # softmax output can underflow to exactly 0, which would drop entries
        # and break any reshape into (label) pairs.
        positive_probs = sample_probs[sample_mask][:, 1]  # valid_tokens
        all_scores_res.append(positive_probs.cpu().tolist())
    return all_scores_res

def is_math_equiv(ref, pred):
    """Return True if ``math_verify`` judges ``ref`` and ``pred`` equivalent.

    Three comparisons are attempted in order, stopping at the first success:
      1. both values wrapped in ``$...$`` so they are parsed as inline LaTeX,
      2. both values as given,
      3. ``pred`` with literal ``\\(`` and ``\\)`` delimiters removed.
    This is intended to also let answer choices such as ``A`` and ``(A)``
    match.

    Each attempt is isolated. If one raises (for example a parse failure, or
    ``pred`` not being a string in the third form), the next is still tried.
    Interrupts such as KeyboardInterrupt are not swallowed.

    Returns:
        True if any attempt verifies as equivalent, otherwise False.
    """
    attempts = (
        lambda: verify(parse(f"${ref}$"), parse(f"${pred}$")),
        lambda: verify(parse(ref), parse(pred)),
        lambda: verify(parse(ref), parse(pred.replace("\\(", "").replace("\\)", ""))),
    )
    for attempt in attempts:
        try:
            if attempt():
                return True
        except Exception:
            # Best-effort: a failing representation just moves on to the next.
            continue
    return False

def compute_answer_scores(answers, confidences):
    """Turn per-sample answers and confidences into normalized answer scores.

    Answers are first grouped into equivalence classes. Identical values always
    share a class; otherwise ``is_math_equiv`` against each class's first
    member (its representative) decides. For each class x:

      score(x) = (∏ c_i for samples i in x)
               * (∏ ((1 - c_j) / U) for samples j not in x)

    And for None ("no sampled answer is right"):

      score(None) = ∏ ((1 - c_i) / U)  over all confidences,

    where U = len(set(answers)), the number of distinct raw answer values (not
    the number of equivalence classes).

    Finally, normalize so that all scores sum to 1. If every score is zero,
    distribute uniformly instead.

    Args:
        answers: one answer per sample (must be hashable).
        confidences: one confidence per sample, aligned with ``answers``.

    Returns:
        Dict mapping each class representative, plus ``None``, to its score.

    Raises:
        ValueError: if ``answers`` and ``confidences`` differ in length.
    """
    if len(answers) != len(confidences):
        raise ValueError("answers and confidences must be the same length")

    U = len(set(answers))  # number of distinct raw answers

    # Each group: (representative answer, confidences of its members).
    equiv_groups: List[Tuple[str, List[float]]] = []
    for ans, conf in zip(answers, confidences):
        for rep, confs in equiv_groups:
            # Identical answers are trivially equivalent. Checking == first
            # keeps representatives distinct even when is_math_equiv fails on
            # this input, and skips an expensive verify call.
            if rep == ans or is_math_equiv(rep, ans):
                confs.append(conf)
                break
        else:
            equiv_groups.append((ans, [conf]))

    scores = {}
    for g, (rep, confs) in enumerate(equiv_groups):
        # Confidences of every sample outside this group, identified by
        # position rather than by comparing representatives.
        others = [c for h, (_, other) in enumerate(equiv_groups) if h != g for c in other]
        prod_correct = prod(confs)
        prod_penalty = prod((1 - c) / U for c in others) if others else 1.0
        scores[rep] = prod_correct * prod_penalty

    # Special None case: penalty over *all* confidences.
    # NOTE(review): if a sampled answer is literally None, its group score is
    # overwritten here. Confirm this is intended.
    scores[None] = prod((1 - c) / U for c in confidences) if confidences else 1.0

    # normalize
    total = sum(scores.values())
    if total > 0:
        for k in scores:
            scores[k] /= total
    else:
        # if somehow all zero, distribute uniformly
        uniform = 1.0 / len(scores)
        for k in scores:
            scores[k] = uniform

    return scores

def generate_all_rewards(df, model, tokenizer, is_mc_question):
    """Score every prediction in ``df`` with a step-level reward model.

    For each row, the ``all_predictions`` cell is parsed with
    ``ast.literal_eval`` into a list of response strings. Each response is
    processed as follows:
      - wrapped in a system/user/assistant chat, with ``<extra_0>`` appended to
        the assistant turn;
      - rendered with the tokenizer's chat template and run through ``model``;
      - scored with ``make_step_rewards``, which reads the positive-class
        probability at each ``<extra_0>`` position.

    Args:
        df: dataframe with ``question`` and ``all_predictions`` columns.
        model: model whose forward output exposes ``.logits``. Presumably
            (batch, seq_len, 2) for a two-label reward model; TODO confirm.
        tokenizer: tokenizer providing ``apply_chat_template`` and ``encode``.
        is_mc_question: use the multiple-choice system prompt instead of the
            free-form boxed-answer prompt.

    Returns:
        Nested list indexed ``[row][prediction]``; each entry is the list of
        rewards at that response's separator positions.
    """
    if is_mc_question:
        system_prompt = ("Provide your step-by-step reasoning first, and then print \"The answer is (X)\", "
                         "where X is the answer choice (one capital letter), at the end of your response.")
    else:
        # This is a plain (non-f) string, so braces are literal: write single
        # braces so the model sees \boxed{X}, not \boxed{{X}}.
        system_prompt = ("Provide your step-by-step reasoning first, "
                         "and then print \"The answer is \\boxed{X}\", "
                         "where X is the final answer, at the end of your response.")

    # The separator token id does not depend on the response; look it up once.
    step_sep_id = tokenizer.encode("<extra_0>")[0]

    all_rewards = []
    for question, raw_preds in zip(df['question'], df['all_predictions']):
        preds = ast.literal_eval(raw_preds)  # list of generated response strings
        row_rewards = []

        for response in preds:
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": question},
                {"role": "assistant", "content": response + "<extra_0>"},
            ]
            conversation_str = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False
            )
            input_ids = tokenizer.encode(
                conversation_str,
                return_tensors="pt"
            ).to(model.device)

            with torch.no_grad():
                outputs = model(input_ids=input_ids)

            # Score only the <extra_0> separator positions.
            token_masks = (input_ids == step_sep_id)
            row_rewards.append(make_step_rewards(outputs.logits, token_masks)[0])

        all_rewards.append(row_rewards)

    return all_rewards

def evaluate_threshold(df, all_rewards, threshold, llm_call_limit):
    """Simulate confidence-based early stopping for one threshold and call limit.

    Every row starts with its first sampled answer. While the highest normalized
    score in a row (the ``None`` outcome included) is below ``threshold``, the
    next sample is revealed and scores are recomputed with
    ``compute_answer_scores``. A row stops when any of these holds:
      - its top score reaches the threshold;
      - it has no unrevealed samples left;
      - ``llm_call_limit`` samples have been used.

    Args:
        df: dataframe with ``all_answers`` (string-encoded lists, parsed with
            ``ast.literal_eval``) and ``gold_answer`` columns.
        all_rewards: per row, per sample, a list whose first element is used as
            that sample's confidence.
        threshold: stopping threshold on a row's maximum normalized score.
        llm_call_limit: maximum number of samples consumed per row.

    Returns:
        Tuple ``(avg_calls_per_row, accuracy_percent, calls_per_row, preds,
        candidates)``. ``preds`` holds the best non-None answer per row (None
        if the row had no answers). ``candidates`` holds the final score dict
        per row.
    """
    all_rewards_v1 = [[val[0] for val in sample_rewards] for sample_rewards in all_rewards]
    all_final_ans = [ast.literal_eval(cell) for cell in df['all_answers'].tolist()]
    gold_answers = df['gold_answer'].tolist()
    num_rows = len(df)

    # Samples each row can actually use; answers and rewards must stay aligned.
    available = [min(len(ans), len(rew)) for ans, rew in zip(all_final_ans, all_rewards_v1)]

    choices_2_scores = [
        compute_answer_scores(final_ans[:1], reward_score[:1])
        for final_ans, reward_score in zip(all_final_ans, all_rewards_v1)
    ]
    num_llm_calls = num_rows
    each_sample_num_llm_calls = [1] * num_rows

    for step in range(llm_call_limit - 1):
        next_count = step + 2  # samples a continuing row would use after this step
        # A row continues only if it is still uncertain AND has another sample
        # left; otherwise calls would be counted that never happen.
        indices = [
            idx for idx, scores in enumerate(choices_2_scores)
            if max(scores.values()) < threshold and available[idx] >= next_count
        ]
        if not indices:
            break
        num_llm_calls += len(indices)

        for index in indices:
            choices_2_scores[index] = compute_answer_scores(
                all_final_ans[index][:next_count], all_rewards_v1[index][:next_count]
            )
            each_sample_num_llm_calls[index] += 1

    preds = []
    all_candidates = []
    for answer_scores in choices_2_scores:
        candidates = [(key, score) for key, score in answer_scores.items() if key is not None]
        # max() keeps the first of tied maxima, like np.argmax. A row with no
        # answers has no candidate and predicts None.
        preds.append(max(candidates, key=lambda kv: kv[1])[0] if candidates else None)
        all_candidates.append(answer_scores)

    correctness = [
        pred is not None and is_math_equiv(pred, str(gold))
        for pred, gold in zip(preds, gold_answers)
    ]
    acc = round(sum(correctness) / len(correctness) * 100, 2) if correctness else 0.0
    calls_ratio = num_llm_calls / num_rows if num_rows else 0.0

    return calls_ratio, acc, each_sample_num_llm_calls, preds, all_candidates

def main():
    """Command-line entry point: build confidences, then sweep limits and thresholds.

    Steps:
      1. Read the CSV named on the command line.
      2. Collapse each sample's list of per-token probabilities into one
         confidence, their geometric mean.
      3. Run ``evaluate_threshold`` for every (llm_call_limit, threshold) pair
         and print one result line per pair.
      4. Attach the predictions, candidate scores and per-row call counts from
         the LAST pair evaluated to the dataframe. If ``--output_file`` is
         given, write the dataframe out as CSV.
    """
    arg_parser = argparse.ArgumentParser(description='Probability-based Early Stopping')
    arg_parser.add_argument('csv_file', help='Path to the CSV file')
    arg_parser.add_argument('--thresholds', nargs='+', type=float, default=[0.999],
                            help='List of thresholds to evaluate (default: 0.999)')
    arg_parser.add_argument('--llm_call_limit_list', nargs='+', default=[100000], type=int,
                            help='List of LLM call limits')
    arg_parser.add_argument('--output_file', type=str, default=None,
                            help='Path to the output CSV file')
    cli = arg_parser.parse_args()

    print(f"Loading CSV file: {cli.csv_file}")
    df = pd.read_csv(cli.csv_file)

    def geometric_mean(values):
        # exp(mean(log p)) is the geometric mean, computed in log space.
        return np.exp(np.mean([np.log(v) for v in values]))

    print("Generating Probs...")
    # NOTE(review): despite the column name, values are treated as
    # probabilities (their log is taken); if they are already log-probs this
    # yields NaN. Confirm against the producer of the CSV.
    parsed_rows = [ast.literal_eval(df['logprobs'][row]) for row in range(len(df))]
    # One single-element list per sample: evaluate_threshold reads element [0].
    all_rewards = [
        [[geometric_mean(sample_probs)] for sample_probs in query_probs]
        for query_probs in parsed_rows
    ]

    print(f"\nEvaluating thresholds: {cli.thresholds}")
    print("Threshold\tLLM_Calls_Ratio\tAccuracy")
    print("-" * 50)

    for limit in cli.llm_call_limit_list:
        for thr in cli.thresholds:
            ratio, acc, per_row_calls, preds, candidates = evaluate_threshold(df, all_rewards, thr, limit)
            print(f"llm_call_limit: {limit}\tthreshold: {thr}\t\t{ratio:.3f}\t\t{acc:.2f}%")

    # Only the final (limit, threshold) combination's outputs are kept.
    df['preds'] = preds
    df['all_candidates'] = candidates
    df['each_sample_num_llm_calls'] = per_row_calls
    if cli.output_file:
        df.to_csv(cli.output_file, index=False)

# Run the command-line evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main() 