import sys

import torch
import transformers

# Expose torch's AdamW under the name ``transformers.AdamW``: once as a
# ``sys.modules`` entry and once as an attribute on the ``transformers``
# package. The attribute is what the ``from transformers import AdamW``
# statement further down in this file picks up.
# NOTE(review): presumably a shim for transformers releases that no longer
# ship AdamW while some dependency still imports it — confirm. Keep these two
# lines ahead of any import that expects ``transformers.AdamW`` to exist.
sys.modules["transformers.AdamW"] = torch.optim.AdamW
setattr(transformers, "AdamW", torch.optim.AdamW)
import os
import pickle
import sys
from collections import Counter
from pathlib import Path

import numpy as np
import torch.optim
from seml.experiment import Experiment
from tqdm import tqdm
from transformers import AdamW

from structured_llmuq.utils.minimum_bayes_action import (
    ArgminMinimumBayesAction,
    FBetaMinimumBayesAction,
    HammingMinimumBayesAction,
)

# Set TORCHDYNAMO_DISABLE=1 for this process as a workaround for a caching
# error seen with Gemma 3.
# NOTE(review): this assignment runs after ``import torch`` above — confirm
# the variable is still read late enough for the workaround to take effect.
os.environ["TORCHDYNAMO_DISABLE"] = "1"  # Gemma 3 caching error

# Experiment object whose ``automain`` decorator wraps ``main`` below.
experiment = Experiment()


# NOTE: the helpers are defined *before* the decorated entry point on purpose.
# Sacred-style ``automain`` may invoke ``main`` at decoration time when the
# file is run as a script, so anything ``main`` calls has to exist already.
def _join_edges(edges):
    """Render an iterable of edge strings as one text: ``"a. b. c."``.

    An empty iterable yields just ``"."``, which is what the plain
    ``". ".join(edges) + "."`` expression produces as well.
    """
    return ". ".join(edges) + "."


def _selected_edges(indicator, edge_order):
    """Return the edges whose entry in ``indicator`` equals 1, in column order."""
    return [edge_order[col] for col, val in enumerate(indicator) if val == 1]


@experiment.automain
def main(
    # Define your configuration parameters here
    experiment_results_path: str,
    threshold: float,
) -> Path:
    """Add AlignScore evaluations to a pickled list of experiment results.

    For every result that carries sampled generations, several candidate
    answer texts are built (thresholded edge set, minimum-Bayes-action
    variants, closest sample, beam search, most common sample, ...) and scored
    with AlignScore in both directions against ``result["reference_answer"]``
    and against ``result["question"]``. Results without (or with zero)
    generations are passed through unchanged.

    Edges are joined into text in first-seen order across the generations, so
    the scored texts do not depend on set iteration order.

    Args:
        experiment_results_path: Path to the pickle file holding the results
            (an iterable of dict-like records — schema as accessed below).
        threshold: Minimum empirical edge frequency for an edge to be included
            in the thresholded ``"bayes"`` edge set.

    Returns:
        Path of the written ``<stem>_with_alignscore.pkl`` file, located next
        to the input file.
    """
    results_path = Path(experiment_results_path)
    # SECURITY NOTE: unpickling can execute arbitrary code. Only point this at
    # result files produced by trusted runs.
    with open(results_path, "rb") as f:
        results = pickle.load(f)
    # Kept as a function-local import, as in the original layout.
    from lm_polygraph.generation_metrics.alignscore import AlignScore

    align_score = AlignScore()
    mba_hamming = HammingMinimumBayesAction()
    mba_f1 = FBetaMinimumBayesAction(beta=1.0)
    # "Sampled" variants: losses handed to ArgminMinimumBayesAction.
    mbs_hamming = ArgminMinimumBayesAction(lambda s1, s2: np.sum(s1 != s2))
    mbs_f1 = ArgminMinimumBayesAction(
        # Negated F1-style overlap (so that a smaller value means more overlap);
        # the epsilon guards against division by zero for two empty vectors.
        lambda s1, s2: -(
            2 * np.sum((s1 == 1) & (s2 == 1)) / (np.sum(s1) + np.sum(s2) + 1e-8)
        )
    )

    metrics = []
    # Wrap the sequence itself (not the enumerate object) so tqdm can read its
    # length and display a total / ETA.
    for idx, result in enumerate(tqdm(results)):
        print(idx)
        # An empty generation list is treated like a missing one: there is
        # nothing to aggregate, and min()/most_common() below would raise.
        if (
            "generations" not in result["metrics"]
            or len(result["metrics"]["generations"]) == 0
        ):
            print("No generations found, skipping...")
            metrics.append(result)
            continue

        generations = result["metrics"]["generations"]
        generation_texts = result["generations"]["tokens_decoded_generated_truncated"]
        num_generations = len(generations)

        # Empirical edge frequencies. Counter keeps first-seen order, which is
        # reused below as the canonical (deterministic) edge order.
        cnts = Counter(edge for generation in generations for edge in generation)
        probabilities = {
            edge: count / num_generations for edge, count in cnts.items()
        }
        # TODO: recalibrate probabilities?
        bayes_ordered = [
            edge for edge, prob in probabilities.items() if prob >= threshold
        ]
        bayes = set(bayes_ordered)
        # Index of the sample whose edge set differs least from ``bayes``
        # (ties resolve to the earliest sample).
        closest_generation_idx = min(
            range(num_generations),
            key=lambda gen_idx: len(
                bayes.symmetric_difference(set(generations[gen_idx]))
            ),
        )

        # Binary generation-by-edge indicator matrix.
        edge_order = list(cnts)
        edge_to_idx = {edge: col for col, edge in enumerate(edge_order)}
        generations_z = np.zeros((num_generations, len(edge_order)), dtype=int)
        for gen_idx, generation in enumerate(generations):
            for edge in generation:
                generations_z[gen_idx, edge_to_idx[edge]] = 1

        bayes_hamming = _selected_edges(mba_hamming(generations_z), edge_order)
        bayes_f1 = _selected_edges(mba_f1(generations_z), edge_order)
        bayes_hamming_sampled = _selected_edges(
            mbs_hamming(generations_z), edge_order
        )
        bayes_f1_sampled = _selected_edges(mbs_f1(generations_z), edge_order)

        # Most frequent edge *set* among the samples, rendered in canonical
        # edge order rather than in frozenset iteration order.
        most_common_latent = Counter(
            frozenset(generation) for generation in generations
        ).most_common(1)[0][0]
        most_common_latent_ordered = sorted(
            most_common_latent, key=edge_to_idx.__getitem__
        )

        beam = result["beam_search_answer"]["tokens_decoded_generated_truncated"][0]

        # Key order matters: scores are zipped back onto these keys below.
        inputs = {
            "bayes": _join_edges(bayes_ordered),
            "beam": beam,
            "beam_reassembled": _join_edges(
                result["metrics"]["beam_search"][0].keys()
            ),
            "closest_to_bayes_reassembled": _join_edges(
                generations[closest_generation_idx]
            ),
            "closest_to_bayes": generation_texts[closest_generation_idx],
            "bayes_hamming": _join_edges(bayes_hamming),
            "bayes_f1": _join_edges(bayes_f1),
            "bayes_hamming_sampled": _join_edges(bayes_hamming_sampled),
            "bayes_f1_sampled": _join_edges(bayes_f1_sampled),
            "most_common_sequence": Counter(generation_texts).most_common(1)[0][0],
            "most_common_latent": _join_edges(most_common_latent_ordered),
        }

        names = list(inputs.keys())
        texts = list(inputs.values())
        answers = [result["reference_answer"]] * len(texts)
        questions = [result["question"]] * len(texts)

        # Candidate texts vs. result["reference_answer"], both directions
        # (stored under "scores_gnd" / "scores_inv_gnd").
        scores = align_score({"greedy_texts": texts}, answers)
        scores_inv = align_score({"greedy_texts": answers}, texts)

        # Candidate texts vs. result["question"], both directions
        # (stored under "scores_reference" / "scores_inv_reference").
        scores_reference = align_score({"greedy_texts": texts}, questions)
        scores_inv_reference = align_score({"greedy_texts": questions}, texts)

        metrics.append(
            result
            | {
                "align_score": {
                    "scores_gnd": dict(zip(names, scores)),
                    "scores_inv_gnd": dict(zip(names, scores_inv)),
                    "scores_reference": dict(zip(names, scores_reference)),
                    "scores_inv_reference": dict(zip(names, scores_inv_reference)),
                    "texts": inputs,
                    "edge_probabilities": probabilities,
                }
            }
        )

    output_path = results_path.parent / f"{results_path.stem}_with_alignscore.pkl"
    with open(output_path, "wb") as f:
        pickle.dump(metrics, f)
    print(f"Saved results with align scores to {output_path}\n")
    return output_path
