import torch
from transformers import AutoModel, AutoTokenizer
import torch.nn.functional as F
import pandas as pd
import ast
import numpy as np
import argparse
import sys
from math_verify import parse, verify
from math import prod
from typing import List, Tuple

def make_step_rewards(logits, token_masks):
    """Collect the label-1 probability at every masked token position.

    Args:
        logits: Tensor indexed as (bs, seq_len, num_labels); softmax is taken
            over the last dimension.
        token_masks: (bs, seq_len) mask whose truthy entries mark the
            positions to keep.

    Returns:
        One list per batch element holding, in sequence order, the softmax
        probability of label index 1 (the "positive" label in this file) at
        each kept position. A row with no kept position yields ``[]``.
    """
    probabilities = F.softmax(logits, dim=-1)  # bs, seq_len, num_labels
    keep = token_masks.bool().to(probabilities.device)  # bs, seq_len

    all_scores_res = []
    for i in range(probabilities.size(0)):
        # Select rows with the mask itself instead of filtering on "!= 0":
        # a probability that underflows to exactly 0 at a kept position would
        # otherwise be dropped and corrupt (or crash) the reshape.
        positive_probs = probabilities[i][keep[i]][:, 1]  # valid_tokens
        all_scores_res.append(positive_probs.cpu().tolist())
    return all_scores_res

def is_math_equiv(ref, pred):
    """Return True if ``verify`` judges ``ref`` and ``pred`` equivalent.

    Three comparisons are tried in order and the first success wins:
      1. both values wrapped in ``$...$`` before parsing,
      2. both values parsed as given,
      3. ``pred`` parsed after removing its LaTeX inline-math delimiters.

    Per the original author's note this also matches answer choices written
    differently, e.g. ``A`` vs. ``(A)``.

    Args:
        ref: Reference answer (normally a string).
        pred: Answer to compare against ``ref`` (normally a string).

    Returns:
        bool: True when any comparison succeeds; False when none does or
        when every comparison raises.
    """
    checks = (
        lambda: verify(parse(f"${ref}$"), parse(f"${pred}$")),
        lambda: verify(parse(ref), parse(pred)),
        lambda: verify(parse(ref), parse(pred.replace("\\(", "").replace("\\)", ""))),
    )
    for check in checks:
        try:
            if check():
                return True
        except Exception:
            # Isolate each form: a failure in one comparison must not discard
            # a match found by another. ``Exception`` (not a bare except)
            # lets KeyboardInterrupt / SystemExit propagate.
            continue
    return False

def compute_answer_scores(answers, confidences):
    """
    Score each distinct answer from per-sample confidences.

    Answers are first merged into equivalence groups (exact equality, or
    ``is_math_equiv`` against the group's first member, its representative).

    For each group representative x:
      score(x) = (∏ c_i for i in the group of x)
               * (∏ ((1 - c_j) / U) for j in every other group)

    And for None ("none of the answers"):
      score(None) = ∏ ((1 - c_i) / U)  over all confidences.

    Finally, normalize so that all scores sum to 1. If every raw score is
    zero, the mass is spread uniformly instead.

    U is the number of distinct raw answer values (``len(set(answers))``),
    which can exceed the number of equivalence groups.

    Args:
        answers: Sequence of (hashable) answers, one per sample.
        confidences: Sequence of confidences, aligned with ``answers``.

    Returns:
        dict mapping each group representative, plus the key ``None``, to
        its normalized score.

    Raises:
        ValueError: if the two sequences differ in length.
    """
    if len(answers) != len(confidences):
        raise ValueError("answers and confidences must be the same length")

    U = len(set(answers))  # number of unique raw answers

    equiv_groups: List[Tuple[str, List[float]]] = []
    for ans, conf in zip(answers, confidences):
        for rep, confs in equiv_groups:
            # Exact equality first: identical answers must share a group even
            # when the math verifier cannot parse them (it would say False).
            if rep == ans or is_math_equiv(rep, ans):
                confs.append(conf)
                break
        else:
            equiv_groups.append((ans, [conf]))

    scores = {}
    for g, (rep, confs) in enumerate(equiv_groups):
        # "Others" are identified by group position, not by comparing
        # representatives, so every other group always contributes a penalty.
        others = [
            c
            for h, (_rep, _confs) in enumerate(equiv_groups)
            if h != g
            for c in _confs
        ]
        prod_penalty = prod((1 - c) / U for c in others) if others else 1.0
        scores[rep] = prod(confs) * prod_penalty

    # special None-case: penalty over *all* confidences
    # (this replaces the entry of a group whose representative is itself None)
    scores[None] = prod((1 - c) / U for c in confidences) if confidences else 1.0

    # normalize
    total = sum(scores.values())
    if total > 0:
        for k in scores:
            scores[k] /= total
    else:
        # if somehow all zero, distribute uniformly
        uniform = 1.0 / len(scores)
        for k in scores:
            scores[k] = uniform

    return scores

def generate_all_rewards(df, model, tokenizer, is_mc_question):
    """Score every stored prediction of every row with the reward model.

    Each response is wrapped in a (system, user, assistant) chat with a
    single ``<extra_0>`` separator appended to the response, the model is
    run once on it, and the label-1 probability at each ``<extra_0>`` token
    is collected via ``make_step_rewards``.

    Args:
        df: DataFrame with a 'question' column and an 'all_predictions'
            column holding the string form of a Python list of responses.
        model: Model called as ``model(input_ids=...)`` whose output exposes
            ``.logits``; also provides ``.device``.
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``encode``.
        is_mc_question: Selects the multiple-choice system prompt when truthy.

    Returns:
        One list per row, containing one reward list per prediction.
    """
    # The system prompt depends only on the question type, so build it once.
    # NOTE(review): this is a plain (non-f) string, so "{{X}}" reaches the
    # model with literal double braces - confirm this matches the prompt used
    # when the predictions were generated.
    if not is_mc_question:
        system_prompt = ("Provide your step-by-step reasoning first, "
                         "and then print \"The answer is \\boxed{{X}}\", "
                         "where X is the final answer, at the end of your response.")
    else:
        system_prompt = ("Provide your step-by-step reasoning first, and then print \"The answer is (X)\", "
                         "where X is the answer choice (one capital letter), at the end of your response.")

    # The separator id is loop-invariant; look it up once.
    step_sep_id = tokenizer.encode("<extra_0>")[0]

    all_rewards = []

    # Iterate both columns in lockstep (positionally) so the result does not
    # depend on the DataFrame's index labels.
    for question, raw_preds in zip(df['question'], df['all_predictions']):
        preds = ast.literal_eval(raw_preds)  # list of generated strings
        row_rewards = []

        for response in preds:
            # prepare the chat prompt
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user",   "content": question},
                {"role": "assistant", "content": response + "<extra_0>"}
            ]
            conversation_str = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False
            )
            input_ids = tokenizer.encode(
                conversation_str,
                return_tensors="pt"
            ).to(model.device)

            # run model
            with torch.no_grad():
                outputs = model(input_ids=input_ids)

            # mask out everything except the <extra_0> separator tokens
            token_masks = (input_ids == step_sep_id)

            # compute and collect the reward sequence for this prediction
            reward_list = make_step_rewards(outputs.logits, token_masks)[0]
            row_rewards.append(reward_list)

        all_rewards.append(row_rewards)

    return all_rewards

def evaluate_threshold(df, all_rewards, threshold, llm_call_limit):
    """Simulate confidence-thresholded repeated sampling and measure accuracy.

    Every sample starts with its first stored prediction. In each following
    round, a sample whose best normalized score (including the ``None``
    entry) is still below ``threshold`` consumes its next stored prediction
    and is re-scored with ``compute_answer_scores``. A sample stops once it
    reaches the threshold, runs out of stored predictions, or has used
    ``llm_call_limit`` calls.

    Args:
        df: DataFrame with an 'all_answers' column (string form of a Python
            list of answers per row) and a 'gold_answer' column.
        all_rewards: Per row, per prediction, a list of step rewards; only
            the first reward of each prediction is used.
        threshold: Score a sample must reach to stop sampling.
        llm_call_limit: Maximum number of predictions used per sample.

    Returns:
        Tuple of
          - average number of LLM calls per row,
          - accuracy in percent, rounded to 2 decimals,
          - number of samples that made a call in each round,
          - predicted answer per row (None if a row has no candidate),
          - score dictionary per row,
          - number of LLM calls per row.
    """
    # Only the first step reward of each prediction is used as its confidence.
    all_rewards_v1 = [[step_rewards[0] for step_rewards in row] for row in all_rewards]
    all_final_ans = [ast.literal_eval(raw) for raw in df['all_answers']]

    # How many (answer, reward) pairs actually exist for each sample.
    available = [min(len(ans), len(rew)) for ans, rew in zip(all_final_ans, all_rewards_v1)]

    choices_2_scores = [
        compute_answer_scores(final_ans[:1], reward_score[:1])
        for final_ans, reward_score in zip(all_final_ans, all_rewards_v1)
    ]

    num_llm_calls = len(df)
    num_llm_calls_dist = [len(df)]
    each_sample_num_llm_calls = [1 for _ in range(len(df))]

    for _ in range(llm_call_limit - 1):
        # A sample escalates only while it is below the threshold AND still
        # has an unused stored prediction. Without the second condition a
        # sample that never reaches the threshold would be re-scored on the
        # same data and counted as a new LLM call in every remaining round.
        pending = [
            idx
            for idx, scores in enumerate(choices_2_scores)
            if max(scores.values()) < threshold
            and each_sample_num_llm_calls[idx] < available[idx]
        ]
        if not pending:
            break
        num_llm_calls += len(pending)
        num_llm_calls_dist.append(len(pending))
        for idx in pending:
            used = each_sample_num_llm_calls[idx] + 1
            choices_2_scores[idx] = compute_answer_scores(
                all_final_ans[idx][:used], all_rewards_v1[idx][:used]
            )
            each_sample_num_llm_calls[idx] = used

    preds = []
    all_candidates = []
    for answer_scores in choices_2_scores:
        # The None entry means "no listed answer"; it is never a prediction.
        not_none_choices = [(key, score) for key, score in answer_scores.items() if key is not None]
        if not_none_choices:
            values = [score for _key, score in not_none_choices]
            preds.append(not_none_choices[int(np.argmax(values))][0])
        else:
            # Row without any candidate answer: nothing to predict.
            preds.append(None)
        all_candidates.append(answer_scores)

    correctness = [
        final_answer is not None and is_math_equiv(final_answer, str(gold))
        for final_answer, gold in zip(preds, df['gold_answer'])
    ]
    acc = round(sum(correctness) / len(correctness) * 100, 2) if correctness else 0.0
    calls_per_sample = num_llm_calls / len(df) if len(df) else 0.0

    return calls_per_sample, acc, num_llm_calls_dist, preds, all_candidates, each_sample_num_llm_calls

def _str2bool(value):
    """Parse a command-line boolean such as 'true' / 'False' / '1' / 'no'."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Command-line entry point: compute (or load) rewards and evaluate thresholds.

    Reads the input CSV, obtains per-prediction rewards either by running the
    reward model or, with ``--existing_rewards``, from the 'rewards' column of
    ``--output_file``, then evaluates every (llm_call_limit, threshold)
    combination and prints the results.

    The per-row columns written at the end ('preds', 'all_candidates',
    'each_sample_num_llm_calls') come from the LAST evaluated combination.
    The augmented frame is saved next to ``--output_file`` (or next to the
    input CSV when no output file is given) with a '_with_preds.csv' suffix.
    """
    parser = argparse.ArgumentParser(description='Analyze thresholds for model switching')
    parser.add_argument('csv_file', help='Path to the CSV file')
    parser.add_argument('--thresholds', nargs='+', type=float, default=[0.999], 
                       help='List of thresholds to evaluate (default: 0.999)')
    parser.add_argument('--model_path', default="/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-Math-PRM-72B/snapshots/9df429b02adb5f764cd6e30e76a0cca16d501ae1/",
                       help='Path to the model')
    # type=bool would turn ANY non-empty string (including "False") into
    # True, so boolean flags are parsed explicitly.
    parser.add_argument('--is_mc_question', default=False, type=_str2bool,
                       help='Whether the questions are multiple choice')
    parser.add_argument('--existing_rewards', default=False, type=_str2bool,
                       help='Whether the rewards are already computed')
    parser.add_argument('--llm_call_limit_list', nargs='+', default=[100000], type=int,
                        help='List of LLM call limits')
    parser.add_argument('--output_file', default=None, type=str,
                        help='Path to the output file')
    args = parser.parse_args()

    # Fail fast instead of crashing inside pandas with a None path.
    if args.existing_rewards and not args.output_file:
        parser.error("--existing_rewards requires --output_file (the CSV holding the 'rewards' column)")

    # Load CSV
    print(f"Loading CSV file: {args.csv_file}")
    df = pd.read_csv(args.csv_file)

    # Load model and tokenizer
    if not args.existing_rewards:
        print("Loading model and tokenizer...")
        model = AutoModel.from_pretrained(
            args.model_path, 
            device_map="auto", 
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        ).eval()

        tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

        # Generate all_rewards
        print("Generating all_rewards...")
        all_rewards = generate_all_rewards(df, model, tokenizer, args.is_mc_question)

    else:
        # Rewards were stored as the string form of nested Python lists.
        df_rewards = pd.read_csv(args.output_file)
        all_rewards = [ast.literal_eval(raw) for raw in df_rewards['rewards']]

    # Evaluate each threshold
    print(f"\nEvaluating thresholds: {args.thresholds}")
    print("Threshold\tLLM_Calls_Ratio\tAccuracy")
    print("-" * 50)

    for llm_call_limit in args.llm_call_limit_list:
        for threshold in args.thresholds:
            llm_calls_ratio, accuracy, num_llm_calls_dist, preds, all_candidates, each_sample_num_llm_calls = evaluate_threshold(df, all_rewards, threshold, llm_call_limit)
            print(f"llm_call_limit: {llm_call_limit}\tthreshold: {threshold:.3f}\t\t{llm_calls_ratio:.3f}\t\t{accuracy:.2f}%")
            print(f"Num LLM calls distribution: {num_llm_calls_dist}")

    # Only the results of the last (llm_call_limit, threshold) pair are kept.
    df['preds'] = preds
    df['all_candidates'] = all_candidates
    df['each_sample_num_llm_calls'] = each_sample_num_llm_calls

    if not args.existing_rewards:
        df['rewards'] = all_rewards
        if args.output_file:
            df.to_csv(args.output_file, index=False)

    # --output_file is optional; fall back to the input path so the computed
    # results are not lost to an AttributeError on None at the very end.
    preds_base = args.output_file if args.output_file else args.csv_file
    if preds_base.endswith('.csv'):
        preds_base = preds_base[:-len('.csv')]
    df.to_csv(preds_base + '_with_preds.csv', index=False)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main() 