import torch
from transformers import AutoModel, AutoTokenizer
import torch.nn.functional as F
import pandas as pd
import ast
import numpy as np
import argparse
import sys
from math_verify import parse, verify
from math import prod
from typing import List, Tuple
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import re

# Boundary strings for split_into_segments: the text is cut in front of each
# occurrence, so every segment after the first begins with one of them.
# Only blank-line paragraph breaks are enabled.  Reasoning-pivot phrases such
# as "But wait", "Wait", "Alternatively", "Is there another way to think
# about this?", "But let me double-check" and "But hold on" are alternative
# delimiters that can be added to this list.
DELIMITERS = ["\n\n"]

def is_math_equiv(ref, pred):
    r"""Return True if ``pred`` is equivalent to the reference answer ``ref``.

    Identical inputs match immediately.  Otherwise three math_verify
    comparisons are tried in order, stopping at the first success:

    1. both values wrapped in ``$...$`` so bare LaTeX is parsed as math;
    2. the raw values;
    3. the raw values with ``\(`` and ``\)`` removed from ``pred``.

    The lenient variants are meant to also handle answer choices, e.g.
    ``A`` vs. ``(A)``.  ``ref`` is always passed as the first argument of
    ``verify`` and ``pred`` as the second.

    An exception in one comparison (e.g. unparsable input, or ``pred``
    lacking ``.replace``) only rules out that comparison; the others are
    still tried, and False is returned if none succeeds.
    """
    if ref == pred:
        return True

    comparisons = (
        lambda: verify(parse(f"${ref}$"), parse(f"${pred}$")),
        lambda: verify(parse(ref), parse(pred)),
        lambda: verify(parse(ref), parse(pred.replace("\\(", "").replace("\\)", ""))),
    )
    for compare in comparisons:
        try:
            if compare():
                return True
        except Exception:
            # Best-effort matching: a failure in one form must not hide a
            # match in another (the old eager any([...]) gave up entirely).
            continue
    return False

def compute_answer_scores(answers, confidences):
    """
    Score each distinct answer, plus a ``None`` ("none of these") outcome.

    Answers are first grouped: an answer joins the first existing group whose
    representative (first member) is identical to it or ``is_math_equiv``
    to it; otherwise it starts a new group.  With U = ``len(set(answers))``
    (the number of distinct raw answers), for each group x:

      score(x) = (∏ c_i for answers i in x)
               * (∏ ((1 - c_j) / U) for answers j not in x)

    and for None:

      score(None) = ∏ ((1 - c_i) / U)  over all confidences.

    Finally the scores are normalized to sum to 1 (uniformly if all are 0).
    The products are accumulated as sums of logs, so long answer lists whose
    raw products would underflow to 0.0 still get meaningful relative scores.

    Args:
        answers: one hashable answer per response.
        confidences: matching confidences; assumed to lie in [0, 1]
            (a value that makes a factor <= 0 is treated as a factor of 0).

    Returns:
        Dict keyed by each group's representative answer, in first-seen
        order, followed by ``None``.  NOTE(review): an answer that is itself
        ``None`` shares the key with the None outcome and is overwritten by it.

    Raises:
        ValueError: if ``answers`` and ``confidences`` differ in length.
    """
    if len(answers) != len(confidences):
        raise ValueError("answers and confidences must be the same length")

    import math  # only math.prod is imported at module level

    def log_or_neg_inf(x):
        # math.log raises on 0; a zero factor should simply zero the score.
        return math.log(x) if x > 0 else -math.inf

    U = len(set(answers))  # number of distinct raw answers

    # Group equivalent answers; each group is (representative, [confidences]).
    equiv_groups: List[Tuple[str, List[float]]] = []
    for ans, conf in zip(answers, confidences):
        for rep, confs in equiv_groups:
            # The identity check is cheap and guarantees identical answers share
            # a group even when math_verify cannot parse them; otherwise they
            # would form separate groups whose dict keys collide.
            if rep == ans or is_math_equiv(rep, ans):
                confs.append(conf)
                break
        else:
            equiv_groups.append((ans, [conf]))

    # log of each group's penalty product  ∏ (1 - c) / U
    group_log_penalty = [
        sum(log_or_neg_inf((1 - c) / U) for c in confs) for _, confs in equiv_groups
    ]

    log_scores = {}
    for g, (rep, confs) in enumerate(equiv_groups):
        log_correct = sum(log_or_neg_inf(c) for c in confs)
        # exclude this group by position, not by comparing representatives
        log_others = sum(p for k, p in enumerate(group_log_penalty) if k != g)
        log_scores[rep] = log_correct + log_others

    # special None-case: penalty over *all* confidences
    log_scores[None] = sum(group_log_penalty)

    # normalize: softmax over log scores, shifted by the max for stability
    peak = max(log_scores.values())
    if peak == -math.inf:
        # every score is exactly zero: distribute uniformly
        uniform = 1.0 / len(log_scores)
        return {k: uniform for k in log_scores}
    unnormalized = {k: math.exp(v - peak) for k, v in log_scores.items()}
    total = sum(unnormalized.values())
    return {k: v / total for k, v in unnormalized.items()}


def split_into_segments(text: str, delimiters: List[str]) -> List[str]:
    """Split ``text`` in front of every occurrence of any delimiter.

    The delimiters are kept: each segment after the first begins with the
    delimiter that triggered the split.  Empty or whitespace-only pieces are
    dropped, so joining the result may not reproduce ``text`` exactly.

    Args:
        text: text to split.
        delimiters: literal (non-regex) boundary strings; empty strings are
            ignored because they would match at every position.

    Returns:
        The non-blank segments in order.  With no usable delimiter the whole
        text is a single segment (or there are none if it is blank).
    """
    usable = [d for d in delimiters if d]
    if not usable:
        # An empty alternation matches everywhere and would split the text
        # into single characters.
        return [text] if text.strip() else []
    pattern = "|".join(re.escape(d) for d in usable)
    # Zero-width lookahead keeps each delimiter at the start of the next piece.
    parts = re.split(f"(?=(?:{pattern}))", text)
    return [p for p in parts if p.strip()]


def make_masked_texts(full: str, segments: List[str]) -> List[str]:
    """Build one leave-one-out variant of the text per segment.

    Variant ``k`` is the concatenation of every segment except
    ``segments[k]``.  ``full`` is not consulted; the variants are rebuilt
    from ``segments`` alone.
    """
    return [
        "".join(segments[:k]) + "".join(segments[k + 1:])
        for k in range(len(segments))
    ]

def compute_segment_spans(
    text: str,
    segments: List[str],
    tokenizer: AutoTokenizer
) -> List[Tuple[int,int]]:
    """
    Map each segment of ``text`` to an inclusive range of token indices.

    Segments are located in order with ``str.find``; each search starts just
    after the previous match.  ``text`` is then tokenized without special
    tokens and with character offsets.  For a segment covering characters
    ``cs..ce`` (inclusive):

    * the start token is the first token whose end offset is > ``cs``, i.e.
      the first token overlapping or following the segment start;
    * the end token is the last token whose start offset is <= ``ce``, i.e.
      the last token overlapping or preceding the segment end.

    Every returned span contains at least one token; otherwise RuntimeError.

    Args:
        text: the full text the segments were cut from.
        segments: substrings of ``text`` in order of appearance.
        tokenizer: must support ``return_offsets_mapping`` (presumably a
            Hugging Face "fast" tokenizer — confirm).

    Returns:
        One ``(token_start, token_end)`` pair per segment, both inclusive,
        indexing the tokenization of ``text``.

    Raises:
        ValueError: if a segment is not found after the previous one.
        RuntimeError: if no token can be assigned to a segment.
    """
    # 1) Character-level bounds (start, inclusive end) of each segment.
    char_bounds = []
    search_pos = 0
    for seg in segments:
        idx = text.find(seg, search_pos)
        if idx < 0:
            raise ValueError(f"Segment not found in text: {seg!r}")
        char_bounds.append((idx, idx + len(seg) - 1))
        # continue searching after this segment so repeated text maps in order
        search_pos = idx + len(seg)

    # 2) Tokenize the whole text once, keeping (char_start, char_end) offsets.
    enc = tokenizer(
        text,
        add_special_tokens=False,
        return_offsets_mapping=True,
    )
    offsets = enc["offset_mapping"]

    # 3) Translate character bounds into token bounds.
    spans = []
    for cs, ce in char_bounds:
        # first token that ends after the segment start
        ts = next((i for i, (_s, e) in enumerate(offsets) if e > cs), None)
        # last token that starts at or before the segment end
        te = max((i for i, (s, _e) in enumerate(offsets) if s <= ce), default=None)
        # No matching token (e.g. the tokenizer emitted nothing for this
        # stretch) used to escape as a bare StopIteration/ValueError.
        if ts is None or te is None or te < ts:
            raise RuntimeError(f"Computed empty span for segment at chars {cs}-{ce}")
        spans.append((ts, te))

    return spans

def compute_weights(
    question: str,
    full_response: str,
    masked_texts: List[str],
    embedder: SentenceTransformer
) -> np.ndarray:
    r"""
    Weight each leave-one-out text by how far it drifts from the full context.

    The reference embedding is of ``question + "\n" + full_response``.  Each
    entry of ``masked_texts`` is embedded on its own and compared with it by
    cosine similarity.  The raw weight is ``1 - cosine_sim``, clipped at 0 to
    absorb round-off that can push the similarity slightly above 1.  The
    weights are normalized to sum to 1.  If every raw weight is 0 (or the
    total is not a positive number), uniform weights are returned instead of
    dividing by zero.

    NOTE(review): in this module the caller passes masked texts built from
    the response only, while the reference also contains the question —
    confirm this asymmetry is intended.

    Returns:
        1-D array with one weight per masked text.
    """
    base_emb = embedder.encode(question + "\n" + full_response, convert_to_numpy=True)
    mask_embs = embedder.encode(masked_texts, convert_to_numpy=True)
    sims = cosine_similarity(mask_embs, base_emb.reshape(1, -1)).reshape(-1)
    weights = np.clip(1.0 - sims, 0.0, None)
    total = weights.sum()
    if total > 0:
        return weights / total
    # Degenerate case: no masked text differs from the reference.
    return np.full_like(weights, 1.0 / len(weights))

def make_weighted_rewards(
    logprobs: List[float],
    spans: List[Tuple[int,int]],
    weights: np.ndarray
) -> float:
    """
    Combine per-token probabilities into one segment-weighted confidence.

    For each inclusive token span ``(ts, te)``, the segment score is the mean
    of ``log(p)`` over its tokens, i.e. the log of their geometric mean.  The
    result is ``exp(sum_i weights[i] * score_i)``.  This is a weighted
    geometric mean of the per-segment geometric means.  It lies in [0, 1]
    when the inputs are probabilities and the weights are non-negative and
    sum to 1.

    NOTE(review): despite the parameter name, values go through ``np.log``
    here, so they must be probabilities in [0, 1], not log-probabilities —
    confirm against whatever produces the ``logprobs`` column.

    Raises:
        ValueError: if a span selects no values, i.e. it starts past the end
            of ``logprobs``.  That means the spans came from a tokenization
            that does not match these probabilities.  Previously this
            silently produced NaN.
    """
    seg_scores = []
    for ts, te in spans:
        segment_probs = logprobs[ts:te + 1]
        if len(segment_probs) == 0:
            raise ValueError(
                f"Span ({ts}, {te}) selects no tokens from "
                f"{len(logprobs)} token probabilities"
            )
        # log of the segment's geometric mean
        seg_scores.append(float(np.mean(np.log(segment_probs))))

    return float(np.exp(np.dot(weights, seg_scores)))


def generate_all_rewards_weighted(
    df: pd.DataFrame,
    tokenizer: AutoTokenizer,
    embedder: SentenceTransformer
) -> List[List[List[float]]]:
    """
    Compute the segment-weighted confidence of every candidate response.

    Each row of ``df`` supplies ``question`` plus ``all_predictions`` and
    ``logprobs``.  The latter two are Python-literal strings parsed with
    ``ast.literal_eval``.  Responses are paired with their token values
    positionally; pairing stops at the shorter of the two lists.

    Returns:
        Nested list indexed ``[question][candidate][0]``.  Each reward is
        wrapped in a one-element list, matching the ``[[score]]`` layout that
        ``evaluate_threshold`` reads via ``val[0]``.
    """
    def score_response(question, response, token_probs):
        # segment -> leave-one-out texts -> weights -> token spans -> reward
        segments = split_into_segments(response, DELIMITERS)
        masked = make_masked_texts(response, segments)
        weights = compute_weights(question, response, masked, embedder)
        spans = compute_segment_spans(response, segments, tokenizer)
        return make_weighted_rewards(token_probs, spans, weights)

    rewards = []
    for _, row in df.iterrows():
        question = row["question"]
        responses = ast.literal_eval(row["all_predictions"])
        token_prob_lists = ast.literal_eval(row["logprobs"])
        rewards.append([
            [score_response(question, response, token_probs)]
            for response, token_probs in zip(responses, token_prob_lists)
        ])
    return rewards


def evaluate_threshold(df, all_rewards, threshold, llm_call_limit):
    """Simulate adaptive sampling with early stopping; report cost and accuracy.

    Every sample starts with its first recorded response.  At each further
    step, a sample consumes one more response while two conditions hold:

    (a) none of its normalized scores, including the ``None`` outcome,
        reaches ``threshold``;
    (b) it still has an unused response in ``all_answers``.

    Its scores are then recomputed with ``compute_answer_scores`` over the
    responses used so far.  At most ``llm_call_limit`` responses are used
    per sample.

    Args:
        df: needs ``all_answers`` (Python-literal list strings) and
            ``gold_answer`` columns; rows are read positionally.
        all_rewards: ``[sample][response][0]`` confidences aligned with the
            answers in ``all_answers``.
        threshold: a sample stops once its top score is >= this value.
        llm_call_limit: cap on responses used per sample.

    Returns:
        ``(responses used per row of df, accuracy in percent,
        per-sample response counts, predicted answers, per-sample score dicts)``
    """
    rewards = [[val[0] for val in sample] for sample in all_rewards]
    final_answers = [ast.literal_eval(raw) for raw in df['all_answers'].tolist()]
    gold_answers = df['gold_answer'].tolist()

    choices_2_scores = [
        compute_answer_scores(answers[:1], scores[:1])
        for answers, scores in zip(final_answers, rewards)
    ]
    num_llm_calls = len(df)
    each_sample_num_llm_calls = [1] * len(df)

    # Samples still sampling.  A sample that stops never resumes, because its
    # scores are only recomputed while it is active.
    active = list(range(len(choices_2_scores)))
    for step in range(llm_call_limit - 1):
        n_used = step + 2  # responses a continuing sample has after this step
        active = [
            idx for idx in active
            if max(choices_2_scores[idx].values()) < threshold
            # A sample with no unused response cannot make another call;
            # counting one would inflate the reported cost.
            and len(final_answers[idx]) >= n_used
        ]
        if not active:
            break
        num_llm_calls += len(active)
        for idx in active:
            choices_2_scores[idx] = compute_answer_scores(
                final_answers[idx][:n_used], rewards[idx][:n_used]
            )
            each_sample_num_llm_calls[idx] += 1

    preds = []
    all_candidates = []
    for answer_scores in choices_2_scores:
        candidates = [(key, score) for key, score in answer_scores.items() if key is not None]
        if candidates:
            best = int(np.argmax([score for _, score in candidates]))
            preds.append(candidates[best][0])
        else:
            # only the None outcome exists (e.g. the sample has no answers)
            preds.append(None)
        all_candidates.append(answer_scores)

    # is_math_equiv(ref, pred): the gold answer is the reference.
    correctness = [
        pred is not None and is_math_equiv(str(gold), str(pred))
        for pred, gold in zip(preds, gold_answers)
    ]
    acc = round(sum(correctness) / len(correctness) * 100, 2)

    return num_llm_calls / len(df), acc, each_sample_num_llm_calls, preds, all_candidates

def main():
    """Command-line entry point.

    Loads the CSV and computes a segment-weighted reward for every response.
    Then prints the LLM-call ratio and accuracy for every combination of
    ``--llm_call_limit_list`` and ``--thresholds``.  If ``--output_file`` is
    given, the predictions, candidate scores and call counts of the LAST
    combination evaluated are added as columns and written there.
    """
    parser = argparse.ArgumentParser(description='Weighted Token-based Early Stopping')
    parser.add_argument('csv_file', help='Path to the CSV file')
    parser.add_argument('--thresholds', nargs='+', type=float, default=[0.999],
                        help='List of thresholds to evaluate (default: 0.999)')
    parser.add_argument('--llm_call_limit_list', nargs='+', default=[100000], type=int,
                        help='List of LLM call limits')
    parser.add_argument('--output_file', type=str, default=None,
                        help='Path to the output CSV file')
    parser.add_argument('--model_name', type=str, default='all-MiniLM-L6-v2',
                        help='Sentence transformer model name')
    # SentenceTransformer expands bare model names into the
    # "sentence-transformers/" hub namespace; AutoTokenizer does not, so the
    # default spells the namespace out.
    # NOTE(review): the token spans index into the per-token probabilities,
    # so this presumably should be the tokenizer of the model that generated
    # the responses — confirm.
    parser.add_argument('--tokenizer_name', type=str,
                        default='sentence-transformers/all-MiniLM-L6-v2',
                        help='Tokenizer name')
    args = parser.parse_args()

    # Load CSV
    print(f"Loading CSV file: {args.csv_file}")
    df = pd.read_csv(args.csv_file)

    # Load sentence transformer and tokenizer
    print(f"Loading sentence transformer: {args.model_name}")
    embedder = SentenceTransformer(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)

    # Generate weighted rewards
    print("Generating weighted rewards...")
    all_rewards = generate_all_rewards_weighted(df, tokenizer, embedder)

    # Evaluate each (limit, threshold) combination
    print(f"\nEvaluating thresholds: {args.thresholds}")
    print("Threshold\tLLM_Calls_Ratio\tAccuracy")
    print("-" * 50)

    last_result = None
    for llm_call_limit in args.llm_call_limit_list:
        for threshold in args.thresholds:
            last_result = evaluate_threshold(df, all_rewards, threshold, llm_call_limit)
            llm_calls_ratio, accuracy = last_result[:2]
            print(f"llm_call_limit: {llm_call_limit}\tthreshold: {threshold}\t\t{llm_calls_ratio:.3f}\t\t{accuracy:.2f}%")

    # Save per-sample results of the last combination only.
    if args.output_file and last_result is not None:
        _, _, each_sample_num_llm_calls, preds, all_candidates = last_result
        df['preds'] = preds
        df['all_candidates'] = all_candidates
        df['each_sample_num_llm_calls'] = each_sample_num_llm_calls
        df.to_csv(args.output_file, index=False)

# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main() 