import time
import torch
from tqdm import tqdm
import logging
import numpy as np
import gc

from datasets import Dataset

from lm_polygraph import WhiteboxModel
from lm_polygraph.estimators import Estimator
from lm_polygraph.stat_calculators import StatCalculator
from utils import parse_ans

# Module-wide logger; every function in this file reports through it.
log = logging.getLogger("bestofn_eval")
# Configure the root logger at INFO level (basicConfig leaves an already
# configured root logger untouched).
logging.basicConfig(level=logging.INFO)


def is_correct_answer(generated_output: str, gold_answer: str) -> bool:
    """Decide whether a generated output matches the gold answer.

    Both strings are passed through the project helper ``parse_ans``
    (presumably it extracts a numeric answer -- verify against ``utils``).
    The parsed values are then compared with ``np.isclose``.

    Returns:
        The ``np.isclose`` result as a Python scalar. ``False`` when either
        side parses to ``None`` or when the comparison raises.
    """
    predicted = parse_ans(generated_output)
    expected = parse_ans(gold_answer)

    both_parsed = predicted is not None and expected is not None
    if not both_parsed:
        return False

    try:
        matches = np.isclose(predicted, expected)
        # .item() unwraps the NumPy result into a plain Python value and
        # raises for results that do not hold exactly one element.
        return matches.item()
    except Exception:
        # Values that cannot be compared are treated as a mismatch.
        return False


def process_stats(stats):
    """Return a new dict holding only the selected entries of *stats*.

    An entry is retained when its key is one of a fixed set of names, or
    when the key begins with ``prm_scores_`` or ``reasoneval_scores_`` (the
    dynamic PRM and ReasonEval keys). Retained entries keep the order they
    have in *stats*; *stats* itself is not modified.
    """
    fixed_names = frozenset((
        "greedy_texts",
        "greedy_tokens",
        "greedy_logprobs",
        "uncertainty_claim_logits",
        "reasoneval_scores",
        "prm_scores",  # backward-compat
    ))
    dynamic_prefixes = ("prm_scores_", "reasoneval_scores_")

    def wanted(name):
        return name in fixed_names or name.startswith(dynamic_prefixes)

    return {name: value for name, value in stats.items() if wanted(name)}


def _fresh_empty_lists(n: int) -> list[list]:
    """Return ``n`` independent empty lists.

    ``[[]] * n`` repeats a single list object ``n`` times, so appending to one
    placeholder would silently change all of them.
    """
    return [[] for _ in range(n)]


def _dynamic_prm_keys(estimators: list[Estimator]) -> list[str]:
    """Collect ``prm_scores_<model_id>`` keys for the PRM estimators present.

    Best-effort: if ``PRMEstimator`` cannot be imported, or an estimator
    cannot be inspected, that estimator simply contributes no key.
    """
    try:
        # Local import: only the fallback paths need this optional class.
        from bestofn.estimators.prm import PRMEstimator
    except Exception as exc:
        log.debug("PRMEstimator unavailable, skipping dynamic PRM keys: %s", exc)
        return []

    keys = []
    for est in estimators:
        try:
            model_id = getattr(est, "model_id", "") if isinstance(est, PRMEstimator) else ""
        except Exception as exc:
            log.debug("Could not read model_id from %s: %s", est, exc)
            continue
        if model_id:
            keys.append(f"prm_scores_{model_id}")
    return keys


def _fill_neutral_result(
        r: dict,
        stats: dict,
        estimators: list[Estimator],
        n: int,
        stats_text: str,
        sample_text: str,
) -> dict:
    """Overwrite *r* with placeholder stats and scores for a failed sample.

    Every completion gets ``stats_text`` as its generated text, empty lists
    for the token-level and claim-level stats, and ``correctness`` False.
    Estimators are first run on the placeholder stats; if that raises, every
    estimator is assigned the neutral score ``[0.5] * n`` instead.

    Args:
        r: Result record, updated in place.
        stats: Stats dict to fill with placeholders, modified in place.
        estimators: Estimators whose scores must be present in ``r["scores"]``.
        n: Number of completions per sample.
        stats_text: Placeholder stored in ``stats["greedy_texts"]``.
        sample_text: Placeholder stored in ``r["sample_texts"]``.

    Returns:
        The same ``r`` object.
    """
    stats["greedy_texts"] = [stats_text] * n
    placeholder_keys = (
        "greedy_tokens",
        "greedy_logprobs",
        "claims",  # needed by the claim-based (PRM/ReasonEval/UHead) estimators
        "prm_scores",  # backward-compat
        *_dynamic_prm_keys(estimators),
        "reasoneval_scores",
        "uncertainty_claim_logits",
    )
    for key in placeholder_keys:
        stats[key] = _fresh_empty_lists(n)

    r["stats"] = process_stats(stats)

    # setdefault so a record without "scores" cannot raise KeyError from
    # inside this error-recovery path.
    scores = r.setdefault("scores", {})
    try:
        # Built completely before updating: if any estimator fails, none of
        # the partial results are kept.
        scores.update({str(est): est(stats) for est in estimators})
    except Exception as exc:
        log.warning(f"Estimators failed on placeholder stats ({exc}); using neutral scores")
        scores.update({str(est): [0.5] * n for est in estimators})

    r.update({
        "sample_texts": [sample_text] * n,
        "correctness": [False] * n,  # mark as wrong
    })
    return r


def _update_sample(
        r: dict,
        model: WhiteboxModel,
        estimators: list[Estimator],
        stat_calculators: list[StatCalculator],
        n: int,
        max_new_tokens: int = 100,
        verbose: bool = True,
):
    """Run stat calculators and estimators for one sample and record the outcome.

    The input text ``r["input"]`` is repeated ``n`` times. Each stat
    calculator is called in order and its output is merged into
    ``r["stats"]``. Each estimator is then called on the full stats dict.
    Finally ``sample_texts`` and per-completion ``correctness`` (checked
    against ``r["gold_answer"]``) are stored on ``r``.

    Failures never propagate to the caller as ordinary exceptions:

    * A CUDA out-of-memory error in a stat calculator clears the GPU cache
      and yields a neutral placeholder result.
    * Any other ``Exception`` is logged with its traceback and also yields a
      neutral placeholder result.

    Args:
        r: Result record with at least ``input``, ``gold_answer``, ``stats``
            and ``scores``. Updated in place.
        model: Model handed to every stat calculator.
        estimators: Callables taking the stats dict; keyed by ``str(est)``.
        stat_calculators: Callables invoked as
            ``calc(stats, texts, model, max_new_tokens=...)``.
        n: Number of completions per sample.
        max_new_tokens: Forwarded to every stat calculator.
        verbose: When True, log progress and timing per stat calculator.

    Returns:
        The same ``r`` object.
    """
    try:
        stats = r["stats"]
        texts = [r["input"]] * n

        for stat_calc in stat_calculators:
            start_time = time.time()
            if verbose:
                log.info(f"Calculating {stat_calc}...")

            oom_message = None
            try:
                stats.update(stat_calc(stats, texts, model, max_new_tokens=max_new_tokens))
            except torch.cuda.OutOfMemoryError as e:
                # Only remember the message here. Recovery runs after the
                # except clause, because a live exception keeps the failing
                # frames (and their tensors) referenced, which would stop the
                # cache clear from releasing that memory.
                oom_message = str(e)

            if oom_message is not None:
                log.warning(f"OOM error in stat calculator {stat_calc}: {oom_message}")
                log.warning("Clearing GPU cache and returning neutral result...")
                torch.cuda.empty_cache()
                gc.collect()
                # NOTE(review): the two placeholders differ ("[OOM ERROR]" vs
                # "OOM ERROR"); kept as-is in case downstream code matches on them.
                return _fill_neutral_result(
                    r, stats, estimators, n,
                    stats_text="[OOM ERROR]",
                    sample_text="OOM ERROR",
                )

            if verbose:
                log.info(f"Done calculating in {round(time.time() - start_time, 2)} seconds")

        # Only the filtered stats are stored; estimators see the full dict.
        r["stats"] = process_stats(stats)
        r["scores"].update({str(est): est(stats) for est in estimators})

        sample_texts = r["stats"]["greedy_texts"]
        r.update({
            "sample_texts": sample_texts,
            "correctness": [is_correct_answer(t, r["gold_answer"]) for t in sample_texts],
        })

        # Clear GPU cache after processing each sample
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return r

    except Exception as e:
        # log.exception keeps the traceback, which log.error would drop.
        log.exception(f"Unexpected error in _update_sample: {e}")
        return _fill_neutral_result(
            r, r.get("stats", {}), estimators, n,
            stats_text="[ERROR]",
            sample_text="[ERROR]",
        )


def _bestofn(
        dataset: Dataset,
        model: WhiteboxModel,
        estimators: list[Estimator],
        stat_calculators: list[StatCalculator],
        save_path: str,
        save_frequency: int | None,
        n: int,
        max_new_tokens,
        results: list[dict],
        verbose: bool = True,
):
    """Process every result record in order and checkpoint to *save_path*.

    Each entry of *results* is replaced by the return value of
    ``_update_sample``. The whole *results* list is written with
    ``torch.save`` after every ``save_frequency`` samples and always after
    the last one.

    Args:
        dataset: Dataset the results belong to; only its length is used here.
        model: Forwarded to ``_update_sample``.
        estimators: Forwarded to ``_update_sample``.
        stat_calculators: Forwarded to ``_update_sample``.
        save_path: Destination passed to ``torch.save``.
        save_frequency: Save every this many samples. ``None`` or ``0``
            disables periodic saving (the final save still happens).
        n: Number of completions per sample.
        max_new_tokens: Forwarded to ``_update_sample``.
        results: One record per dataset sample; modified in place.
        verbose: When True, log each save and per-sample progress.

    Raises:
        ValueError: If *dataset* and *results* differ in length.
    """
    total = len(results)
    # Explicit check rather than ``assert``, which is stripped under ``-O``.
    if len(dataset) != total:
        raise ValueError(
            f"dataset has {len(dataset)} samples but results has {total} entries"
        )
    log.info(f"Processing {total} samples with {n} completions each...")

    # The dataset rows are not needed here (each record already carries its
    # input and gold answer), so only the records are iterated.
    for i, r in tqdm(enumerate(results), total=total):
        results[i] = _update_sample(r, model, estimators, stat_calculators, n, max_new_tokens, verbose)

        done = i + 1
        # Truthiness check: a frequency of 0 would otherwise raise
        # ZeroDivisionError in the modulo below.
        periodic_save = bool(save_frequency) and done % save_frequency == 0
        if periodic_save or done == total:
            if verbose:
                log.info(f"Saving results to {save_path}")
            torch.save(results, save_path)

    log.info("Done.")


def bestofn(
        dataset: Dataset,
        model: WhiteboxModel,
        estimators: list[Estimator],
        stat_calculators: list[StatCalculator],
        save_path: str,
        save_frequency: int | None,
        n: int,
        max_new_tokens: int = 100,
        verbose: bool = True,
):
    """Start a fresh best-of-N run over *dataset*.

    One result record is created per sample, taking ``input`` from
    ``sample["question"]`` and ``gold_answer`` from ``sample["answer"]``,
    with empty ``scores`` and ``stats``. All remaining work (processing and
    saving to *save_path*) is delegated to ``_bestofn``. Nothing is returned.
    """
    fresh_records = []
    for row in dataset:
        record = {
            "input": row["question"],
            "gold_answer": row["answer"],
            "scores": {},
            "stats": {},
        }
        fresh_records.append(record)

    _bestofn(
        dataset=dataset,
        model=model,
        estimators=estimators,
        stat_calculators=stat_calculators,
        save_path=save_path,
        save_frequency=save_frequency,
        n=n,
        max_new_tokens=max_new_tokens,
        results=fresh_records,
        verbose=verbose,
    )


def update_bestofn(
        dataset: Dataset,
        model: WhiteboxModel,
        estimators: list[Estimator],
        stat_calculators: list[StatCalculator],
        save_path: str,
        save_frequency: int | None,
        verbose: bool = True,
        max_new_tokens: int = 100,
):
    """Re-run processing on results previously saved at *save_path*.

    The saved result list is loaded and the number of completions per sample
    is taken from the length of the first record's ``sample_texts``. The
    records are then passed to ``_bestofn``, which writes the updated list
    back to the same *save_path*.

    Args:
        dataset: Dataset the saved results belong to.
        model: Forwarded to ``_bestofn``.
        estimators: Forwarded to ``_bestofn``.
        stat_calculators: Forwarded to ``_bestofn``.
        save_path: File to load the results from and save them back to.
        save_frequency: Forwarded to ``_bestofn``.
        verbose: Forwarded to ``_bestofn``.
        max_new_tokens: Forwarded to ``_bestofn``. Defaults to 100, the
            value that used to be hard-coded here.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which
    # can execute code. Only load checkpoints produced by a trusted run.
    results = torch.load(save_path, weights_only=False)

    # An empty checkpoint has no first record to inspect; with no records
    # there is nothing to process, so the completion count is irrelevant.
    n = len(results[0]["sample_texts"]) if results else 0

    _bestofn(
        dataset, model, estimators, stat_calculators,
        save_path, save_frequency,
        n, max_new_tokens=max_new_tokens, results=results,
        verbose=verbose,
    )
