import logging
import os
import sys

import torch
import transformers

# Expose ``torch.optim.AdamW`` under the ``transformers`` namespace, both as a
# ``sys.modules`` entry and as a module attribute.
# NOTE(review): presumably a compatibility shim for a dependency imported
# further down that still expects ``transformers.AdamW`` to exist -- confirm
# which import needs it. It must stay above those imports to have any effect.
sys.modules["transformers.AdamW"] = torch.optim.AdamW
setattr(transformers, "AdamW", torch.optim.AdamW)

import pickle
from typing import List

import numpy as np
import wandb
from seml.experiment import Experiment

from structured_llmuq.model.build import build_model
from structured_llmuq.utils.uq_baselines import (
    CocoaEstimator,
    KLEEstimator,
    PTrueEstimator,
    SAREstimator,
    SemanticEntropyEstimator,
)

# Set the TORCHDYNAMO_DISABLE switch; the original author added it to work
# around a caching error seen with Gemma 3.
# NOTE(review): this runs after torch and the project modules were imported --
# confirm the variable is still read late enough for the switch to apply.
os.environ["TORCHDYNAMO_DISABLE"] = "1"  # Gemma 3 caching error

# seml experiment object; its ``automain`` decorator wraps ``main`` below.
experiment = Experiment()


@experiment.automain
def main(
    estimators: List[str],
    out_path: str | None = None,
    run_name: str | None = None,
    p_true_config: dict | None = None,
):
    """Compute uncertainty baselines for stored generations and save them back.

    The pickled results file is loaded, every requested estimator is evaluated
    on each sample, the result is stored in the sample under the estimator's
    name, and the whole list is written back to the same file.

    Args:
        estimators: Names of the estimators to run. Recognised names are
            ``"se"``, ``"sar"``, ``"p_true"``, ``"p_good"``, ``"kle"`` and
            ``"cocoa"``; other names are ignored with a warning.
        out_path: Path of the pickled results file. When omitted it is derived
            from the configuration of the W&B run ``run_name``.
        run_name: W&B run path passed to ``wandb.Api().run``. Required when
            ``out_path`` is omitted, and whenever a model-based estimator
            (``"sar"``, ``"p_true"``, ``"p_good"``) is requested, because the
            model is built from the run's ``config["model"]``.
        p_true_config: Accepted for configuration compatibility but currently
            not forwarded to any estimator.
            NOTE(review): presumably meant as extra keyword arguments for
            ``PTrueEstimator`` -- confirm before wiring it through.

    Raises:
        ValueError: If neither ``out_path`` nor ``run_name`` is given, or if a
            model-based estimator is requested without ``run_name``.
    """
    known_names = ("se", "sar", "p_true", "p_good", "kle", "cocoa")
    model_based = ("sar", "p_true", "p_good")

    config = None
    if out_path is None:
        if run_name is None:
            raise ValueError("Either out_path or run_name must be provided.")
        run = wandb.Api().run(run_name)
        config = run.config
        experiment_name = config["experiment"]
        # The configured output path is used as a plain string prefix.
        out_path = (
            config["output"]["path"]
            + f"experiment_{experiment_name}/{run.name}/results.pkl"
        )
    logging.info("Loading results from %s", out_path)

    # SECURITY: unpickling executes arbitrary code; only load result files
    # that come from a trusted source.
    with open(out_path, "rb") as file:
        run_data = pickle.load(file)

    unknown = [name for name in estimators if name not in known_names]
    if unknown:
        logging.warning("Ignoring unknown estimators: %s", unknown)

    # (result key, estimator) pairs. Carrying the key explicitly keeps
    # "p_true" and "p_good" apart even though both use PTrueEstimator.
    estimator_objs = []
    if "se" in estimators:
        estimator_objs.append(("se", SemanticEntropyEstimator()))

    if any(name in estimators for name in model_based):
        if config is None:
            # out_path was given explicitly, so the run config has not been
            # fetched yet; it is still needed to build the model.
            if run_name is None:
                raise ValueError(
                    "run_name must be provided to build the model for the "
                    "'sar', 'p_true' and 'p_good' estimators."
                )
            config = wandb.Api().run(run_name).config
        model = build_model(**config["model"])
        if "p_true" in estimators:
            estimator_objs.append(("p_true", PTrueEstimator(model)))
        if "p_good" in estimators:
            estimator_objs.append(
                (
                    "p_good",
                    PTrueEstimator(
                        model,
                        prompt="{q}\n:{a}\nIs the summary:\n-Good\n-Bad\nThe summary is:",
                        expected="Good",
                        name="p_good",
                    ),
                )
            )
        if "sar" in estimators:
            estimator_objs.append(("sar", SAREstimator(model)))

    if "kle" in estimators:
        estimator_objs.append(("kle", KLEEstimator()))

    if "cocoa" in estimators:
        estimator_objs.append(("cocoa", CocoaEstimator()))

    for i, sample in enumerate(run_data):
        generations = sample["generations"]
        answers = generations["tokens_decoded_generated_truncated"]
        log_likelihood_truncated = generations["log_likelihood_truncated"]

        for key, estimator in estimator_objs:
            if key == "se":
                # One probability per answer: exp of its mean log-likelihood.
                probs = [np.exp(llk.mean()) for llk in log_likelihood_truncated]
                result = estimator(answers=answers, question=None, probs=probs)
            elif key == "sar":
                result = estimator(
                    answers=answers,
                    tokens=generations["tokens_truncated"],
                    log_likelihoods=log_likelihood_truncated,
                )
            elif key == "kle":
                # NOTE: the original author flagged OOM errors at this call.
                result = estimator(answers=answers)
            elif key == "cocoa":
                beam_search = sample["beam_search_answer"]
                result = estimator(
                    answers=answers,
                    greedy_answer=beam_search["tokens_decoded_generated_truncated"][0],
                    greedy_log_likelihoods=beam_search["log_likelihood_truncated"][0],
                )
            else:  # "p_true" / "p_good"
                beam_search_text = sample["beam_search_answer"][
                    "tokens_decoded_generated_truncated"
                ][0]
                result = estimator(answer=beam_search_text, question=sample["prompt"])

            # Each estimator writes under its own name, so "p_good" no longer
            # overwrites "p_true" when both are requested.
            sample[key] = result

        torch.cuda.empty_cache()  # release cached CUDA memory between samples

        logging.info("Computed estimators for sample with idx %s", i)

    # A bare filename has an empty dirname, which os.makedirs rejects.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # Save back to the same file
    with open(out_path, "wb") as file:
        pickle.dump(run_data, file)

    computed = ", ".join(key for key, _ in estimator_objs)
    print(f"Saved {computed}-computed results to {out_path}\n")
