import sys
# Raise the limit on digits allowed in int <-> decimal-string conversions (sys.set_int_max_str_digits, Python 3.11+)
sys.set_int_max_str_digits(10_000_000)

import numpy as np
import pandas as pd
from sympy import sympify, SympifyError
from sklearn.linear_model import LogisticRegression

def safe_sympify(s):
    """
    Parse ``s`` with sympy's ``sympify``; return None if parsing fails for any reason.

    SECURITY NOTE: ``sympify`` evaluates string input with ``eval``; do not feed
    it untrusted text (model outputs included) outside a sandbox.
    """
    try:
        return sympify(s)
    except Exception:
        # Deliberate best-effort parse: besides SympifyError / SyntaxError /
        # ValueError, sympify may raise e.g. TypeError, AttributeError or
        # RecursionError on malformed input, and the contract is "None on any
        # failure".
        return None

def is_math_equiv(ref, pred):
    """
    Test math equivalence of ``ref`` and ``pred`` using sympy.

    NOTE: this module later redefines ``is_math_equiv`` (math_verify-based),
    which rebinds the name once the module has loaded, so this sympy version is
    effectively unused at runtime.

    ``ref`` is parsed as given. ``pred`` is parsed as given and, if that fails
    and it is a string, again with ``\\(`` / ``\\)`` delimiters removed.

    Returns:
        True only when both sides parse and sympy reports them equal
        (structurally, or via ``Expr.equals``). False when either side fails to
        parse, when the comparison raises, or when ``equals`` cannot decide
        (returns None).
    """
    ref_expr = safe_sympify(ref)
    if ref_expr is None:
        return False

    pred_candidates = [pred]
    if isinstance(pred, str):
        pred_candidates.append(pred.replace("\\(", "").replace("\\)", ""))

    pred_expr = None
    for candidate in pred_candidates:
        pred_expr = safe_sympify(candidate)
        if pred_expr is not None:
            break
    if pred_expr is None:
        # Couldn't parse the prediction; assume "not equivalent".
        return False

    try:
        if ref_expr == pred_expr:
            return True
        # Expr.equals may return None when undecidable; treat that as "no".
        return ref_expr.equals(pred_expr) is True
    except Exception:
        return False

def fit_confidence_model(data, **logreg_kwargs):
    """
    Fit a one-feature logistic regression from ``num_tokens`` to ``correctness``.

    Args:
        data: DataFrame with a ``num_tokens`` column (the single feature) and a
            ``correctness`` column (the label).
        **logreg_kwargs: forwarded unchanged to ``LogisticRegression``.

    Returns:
        The fitted ``LogisticRegression``, with an extra ``predict_proba_tokens``
        attribute attached. That attribute is a callable: it takes one token
        count or a sequence of counts and returns a 1-D array holding column 1
        of ``predict_proba`` (the second entry of ``classes_``).
    """
    token_counts = data['num_tokens'].values.reshape(-1, 1)
    labels = data['correctness'].values
    classifier = LogisticRegression(**logreg_kwargs)
    classifier.fit(token_counts, labels)

    def predict_proba_tokens(x):
        # Accept a scalar or any sequence; sklearn expects a 2-D column.
        query = np.array(x).reshape(-1, 1)
        class_probs = classifier.predict_proba(query)
        return class_probs[:, 1]

    classifier.predict_proba_tokens = predict_proba_tokens
    return classifier

# Usage example: fit the token-count -> P(correct) confidence model and query it.
# read_json takes a single path; its later parameters (orient, ...) are
# keyword-only in recent pandas, so a second positional path is an error.
conf_data = pd.read_json(
    "./confidence_score_training/data/combined_confidence_score_data_tiny_s42_MATH500.jsonl",
    lines=True,
)

conf_predictor = fit_confidence_model(conf_data, solver='lbfgs', max_iter=1000)

# To get P(correctness=1) for a single token count. predict_proba_tokens always
# returns a 1-D array, so take its only element.
p = conf_predictor.predict_proba_tokens(4780)[0]
print(f"P(correct) at 4780 tokens ≈ {p}")

# Or batch:
probs = conf_predictor.predict_proba_tokens([100, 500, 1000])
print(probs)

from transformers import AutoTokenizer

# Tokenizer used by main() to count tokens in each stored prediction.
# Alternative checkpoint: DeepSeek-R1-Distill-Qwen-1.5B under /datasets/ai/deepseek/hub/.
model_id = (
    "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-1.5B-Instruct/"
    "snapshots/989aa7980e4cf806f80c7fef2b1adb7bc71aa306/"
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# Experiment settings read by main().
csv_file = (
    "./Results/skills_2/MATH500/"
    "self_con_seed42_budget16_acc65.8_models[tinyQwen]_v2.csv"
)
# PRM checkpoint; only referenced by the disabled PRM scoring path.
model_path = (
    "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-Math-PRM-72B/"
    "snapshots/9df429b02adb5f764cd6e30e76a0cca16d501ae1/"
)
thresholds = [0.8, 0.85, 0.9, 0.95, 0.99, 0.999]
llm_call_limit_list = [4, 8, 16]

import torch
from transformers import AutoModel, AutoTokenizer
import torch.nn.functional as F
import pandas as pd
import ast
import numpy as np
import argparse
import sys
from math_verify import parse, verify
from math import prod
from typing import List, Tuple

def make_step_rewards(logits, token_masks):
    """
    Collect per-step probabilities from a process reward model's logits.

    Args:
        logits: tensor shaped (batch, seq_len, num_labels); softmax is taken
            over the last dimension.
        token_masks: (batch, seq_len) tensor that is True / nonzero at the
            positions to score (the step-separator tokens).

    Returns:
        One list per batch row: the softmax probability of label index 1 at
        each masked position, in sequence order, as Python floats. Assumes
        label 1 is the "good step" label, as in the original code; confirm
        against the PRM's config.
    """
    probabilities = F.softmax(logits, dim=-1)
    step_masks = token_masks.bool()

    all_scores_res = []
    for sample_probs, sample_mask in zip(probabilities, step_masks):
        # Select positions by the mask itself rather than by filtering non-zero
        # probabilities. A probability that underflows to exactly 0 (or
        # num_labels != 2) would otherwise break the old (-1, 2) reshape.
        positive_probs = sample_probs[sample_mask][:, 1]
        all_scores_res.append(positive_probs.cpu().tolist())
    return all_scores_res

def is_math_equiv(ref, pred):
    """
    Return True if ``ref`` and ``pred`` verify as mathematically equivalent.

    Uses math_verify's ``parse`` / ``verify`` with ``ref`` as the first
    (reference) argument. Three renderings are tried in order, stopping at the
    first that verifies:
      1. both wrapped in ``$...$``,
      2. both as given,
      3. ``pred`` with ``\\(`` / ``\\)`` delimiters removed.
    Per the original author's note, this can also handle answer choices,
    e.g. ``A`` vs ``(A)``.

    An exception while parsing or verifying one rendering counts as a miss for
    that rendering only; the remaining renderings are still tried.

    Args:
        ref: the reference answer (gold answer, or an existing group's
            representative).
        pred: the candidate answer.
    """
    attempts = (
        lambda: verify(parse(f"${ref}$"), parse(f"${pred}$")),
        lambda: verify(parse(ref), parse(pred)),
        lambda: verify(parse(ref), parse(pred.replace("\\(", "").replace("\\)", ""))),
    )
    for attempt in attempts:
        try:
            if attempt():
                return True
        except Exception:
            # Unparseable / unsupported input for this rendering; try the next.
            continue
    return False

def compute_answer_scores(answers, confidences):
    """
    Score each distinct answer (up to math equivalence) plus a "None" hypothesis.

    Answers are grouped into equivalence classes. An answer joins the first
    existing group whose representative (its first member) is identical to it
    or ``is_math_equiv`` to it; otherwise it starts a new group. Then, for each
    group with representative x:

      score(x)    = (∏ c_i over answers in x's group)
                  * (∏ ((1 - c_j) / U) over answers in all other groups)
      score(None) = ∏ ((1 - c_i) / U) over all confidences

    Here U = len(set(answers)) is the number of distinct raw answer values
    (exact matches, counted before equivalence grouping). Finally the scores
    are normalized to sum to 1, or made uniform if they are all zero.

    Args:
        answers: hashable answer values (presumably strings), one per sample.
        confidences: one confidence per answer (presumably in [0, 1]).

    Returns:
        dict mapping each group representative, plus the key None, to its
        normalized score. A group whose representative is itself None is
        overwritten by the None hypothesis.

    Raises:
        ValueError: if ``answers`` and ``confidences`` differ in length.
    """
    if len(answers) != len(confidences):
        raise ValueError("answers and confidences must be the same length")

    U = len(set(answers))  # number of distinct raw answers

    equiv_groups: List[Tuple[object, List[float]]] = []
    for ans, conf in zip(answers, confidences):
        for rep, confs in equiv_groups:
            # Exact match first: is_math_equiv returns False whenever parsing
            # fails, so identical unparseable answers would otherwise form
            # separate groups that collide on the same dict key.
            if rep == ans or is_math_equiv(rep, ans):
                confs.append(conf)
                break
        else:
            equiv_groups.append((ans, [conf]))

    scores = {}
    for gi, (rep, confs) in enumerate(equiv_groups):
        # Penalty uses every answer outside this group, selected by group
        # position rather than by comparing representatives.
        others = [
            c
            for gj, (_, other_confs) in enumerate(equiv_groups)
            if gj != gi
            for c in other_confs
        ]
        prod_penalty = prod((1 - c) / U for c in others) if others else 1.0
        scores[rep] = prod(confs) * prod_penalty

    # Special None case: penalty over *all* confidences.
    scores[None] = prod((1 - c) / U for c in confidences) if confidences else 1.0

    # Normalize.
    total = sum(scores.values())
    if total > 0:
        for k in scores:
            scores[k] /= total
    else:
        # If somehow all zero, distribute uniformly.
        uniform = 1.0 / len(scores)
        for k in scores:
            scores[k] = uniform

    return scores

def generate_all_rewards(df, model, tokenizer):
    """
    Score every stored prediction with a process reward model (PRM).

    Each response becomes the assistant turn of a chat (system prompt, user
    question, response) and is terminated with one ``<extra_0>`` separator.
    The PRM's probabilities at separator positions are collected via
    make_step_rewards.

    Args:
        df: DataFrame with ``question`` and ``all_predictions`` columns. Each
            ``all_predictions`` cell is a Python-literal list of response
            strings (parsed with ``ast.literal_eval``). Rows are read in
            positional order, independent of the DataFrame index.
        model: PRM invoked as ``model(input_ids=...)``; must expose ``.device``
            and return an object with ``.logits``.
        tokenizer: provides ``apply_chat_template`` and ``encode``.

    Returns:
        Nested list where ``all_rewards[q][k]`` is the reward list for
        prediction ``k`` of question ``q``: one float per ``<extra_0>`` token
        in the encoded conversation.
    """
    # Plain (non-f) string, so single braces are literal: the model sees \boxed{X}.
    system_prompt = (
        "Provide your step-by-step reasoning first, "
        "and then print \"The answer is \\boxed{X}\", "
        "where X is the final answer, at the end of your response."
    )
    # Separator token id does not depend on the response; compute it once.
    step_sep_id = tokenizer.encode("<extra_0>")[0]

    all_rewards = []
    for idx, (question, predictions_cell) in enumerate(zip(df['question'], df['all_predictions'])):
        print(f"Question Number: {idx+1}")
        preds = ast.literal_eval(predictions_cell)  # list of generated strings
        row_rewards = []

        for response in preds:
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": question},
                {"role": "assistant", "content": response + "<extra_0>"},
            ]
            conversation_str = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False,
            )
            input_ids = tokenizer.encode(
                conversation_str,
                return_tensors="pt",
            ).to(model.device)

            with torch.no_grad():
                outputs = model(input_ids=input_ids)

            # Score only the <extra_0> separator positions.
            token_masks = (input_ids == step_sep_id)
            row_rewards.append(make_step_rewards(outputs.logits, token_masks)[0])

        all_rewards.append(row_rewards)

    return all_rewards

def evaluate_threshold(df, all_rewards, threshold, llm_call_limit):
    """
    Simulate threshold-based early stopping over stored answers.

    Every question starts with its first answer. While a question's highest
    normalized score (from compute_answer_scores, including the None
    hypothesis) is below ``threshold``, its next stored answer is added and its
    scores are recomputed. This continues up to ``llm_call_limit`` answers per
    question. A question with no further stored answers stops, and no call is
    counted for it.

    Args:
        df: DataFrame with ``all_answers`` (Python-literal list per row) and
            ``gold_answer`` columns, read in positional order.
        all_rewards: ``all_rewards[q][k]`` is the reward sequence for answer
            ``k`` of question ``q``; its first element is used as confidence.
        threshold: stop adding answers once the top score reaches this value.
        llm_call_limit: maximum number of answers used per question.

    Returns:
        ``(average answers used per question, accuracy in percent rounded to
        2 decimals)``. Each question's prediction is the non-None answer with
        the highest score. A question with no non-None candidate counts as
        incorrect.

    Raises:
        ValueError: if ``df`` is empty.
    """
    num_questions = len(df)
    if num_questions == 0:
        raise ValueError("df must contain at least one question")

    confidences = [[rewards[0] for rewards in row] for row in all_rewards]
    all_final_ans = [ast.literal_eval(cell) for cell in df['all_answers']]
    gold_answers = [str(gold) for gold in df['gold_answer']]

    # First call: every question uses its first answer.
    choices_2_scores = [
        compute_answer_scores(final_ans[:1], conf[:1])
        for final_ans, conf in zip(all_final_ans, confidences)
    ]
    num_llm_calls = num_questions

    for prefix_len in range(2, llm_call_limit + 1):
        # Undecided questions that still have an unused stored answer.
        pending = [
            k
            for k, scores in enumerate(choices_2_scores)
            if max(scores.values()) < threshold and len(all_final_ans[k]) >= prefix_len
        ]
        if not pending:
            break
        num_llm_calls += len(pending)
        for k in pending:
            choices_2_scores[k] = compute_answer_scores(
                all_final_ans[k][:prefix_len], confidences[k][:prefix_len]
            )

    preds = []
    for answer_scores in choices_2_scores:
        candidates = [(key, score) for key, score in answer_scores.items() if key is not None]
        if not candidates:
            preds.append(None)
            continue
        best = np.argmax([score for _, score in candidates])
        preds.append(candidates[best][0])

    # Gold first: is_math_equiv's first parameter is the reference answer.
    correctness = [
        pred is not None and is_math_equiv(gold, pred)
        for pred, gold in zip(preds, gold_answers)
    ]
    acc = round(sum(correctness) / len(correctness) * 100, 2)

    return num_llm_calls / num_questions, acc

def _token_count_confidences(preds):
    """
    Return ``[[p_1], [p_2], ...]``, one entry per prediction string.

    Each ``p_i`` is the confidence model's P(correct) for the prediction's
    token count under the module-level ``tokenizer``. The one-element inner
    lists mirror the per-prediction reward lists that evaluate_threshold reads
    with ``[0]``.
    """
    if not preds:
        return []
    token_counts = [len(ids) for ids in tokenizer(list(preds)).input_ids]
    # One batched sklearn call instead of one call per prediction; store floats,
    # not 1-element arrays.
    probs = conf_predictor.predict_proba_tokens(token_counts)
    return [[float(p)] for p in probs]


def main():
    """
    Score stored predictions with the token-count confidence model, then sweep
    every (llm_call_limit, threshold) pair and print the LLM-call ratio and
    accuracy for each.

    Settings come from the module-level ``csv_file``, ``thresholds`` and
    ``llm_call_limit_list``; ``tokenizer`` and ``conf_predictor`` are also
    module-level.
    """
    # Load CSV
    print(f"Loading CSV file: {csv_file}")
    df = pd.read_csv(csv_file)

    print("Loading model and tokenizer...")
    # The PRM scoring path is disabled. To use it, load the PRM from model_path
    # with AutoModel.from_pretrained(..., device_map="auto",
    # torch_dtype=torch.bfloat16, trust_remote_code=True).eval() and use
    # generate_all_rewards(df, model, tokenizer) instead of the token-count path.

    print("Generating all_rewards...")
    all_rewards = [
        _token_count_confidences(ast.literal_eval(cell))
        for cell in df['all_predictions']
    ]

    # Evaluate each threshold
    print(f"\nEvaluating thresholds: {thresholds}")
    print("Threshold\tLLM_Calls_Ratio\tAccuracy")
    print("-" * 50)

    for llm_call_limit in llm_call_limit_list:
        for threshold in thresholds:
            llm_calls_ratio, accuracy = evaluate_threshold(df, all_rewards, threshold, llm_call_limit)
            print(f"llm_call_limit: {llm_call_limit}\tthreshold: {threshold:.3f}\t\t{llm_calls_ratio:.3f}\t\t{accuracy:.2f}%")


if __name__ == "__main__":
    main()