import argparse
from os import truncate
from pathlib import Path
import random

from hashlib import sha256

from bert_score import BERTScorer
from datasets import load_dataset, load_metric
from infomet import InfoMet, against_uniform, get_measure_fn, measures as Measures
from infomet.infomet import renyi_div
import numpy as np
import pandas as pd
from sacrebleu import BLEU
from sklearn.covariance import MinCovDet
import torch
from torch.utils.data import random_split
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM

from utils.datasets import prep_dataset, prep_inputs, prep_model
from utils.scores import (
    compute_mahalanobis,
    mk_score_function,
    mk_iproj_score_function,
    temperatures,
    mk_score_function_nostep,
)


def parse_args(argv=None):
    """Parse the command-line options of the scoring script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace with the attributes ``dataset_name``,
        ``dataset_config``, ``dataset_split``, ``compute_dist_mahalanobis``,
        ``compute_dist_set``, ``model_name``, ``output_dir``,
        ``shuffle_input``, ``switch_lang``, ``size`` and ``num_beams``.
    """
    # The first positional argument of ArgumentParser is ``prog``, not the
    # description, so the text has to be passed by keyword.
    parser = argparse.ArgumentParser(
        description=(
            "Generate beam-search hypotheses for a dataset and store "
            "per-step and per-sequence scores as CSV files."
        )
    )

    # Task configuration
    parser.add_argument(
        "--dataset_name", type=str, help="Dataset name", default="wmt16"
    )
    parser.add_argument(
        "--dataset_config", type=str, help="Dataset config", default="de-en"
    )
    parser.add_argument(
        "--dataset_split", type=str, help="Train, validation or test", default="test"
    )

    parser.add_argument(
        "--compute_dist_mahalanobis",
        type=str,
        default=None,
        help="Path to the directory containing the pickled mahalanobis reference files.",
    )
    parser.add_argument(
        "--compute_dist_set",
        type=str,
        default=None,
        help="Path to the pickled reference set file used for the set distance.",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        help="Huggingface model name",
        default="Helsinki-NLP/opus-mt-de-en",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        help="Where to store the results.",
        default="data/",
    )
    parser.add_argument(
        "--shuffle_input",
        action="store_true",
        default=False,
        help="Randomly shuffle the space-separated words of each input before generation.",
    )
    parser.add_argument(
        "--switch_lang",
        action="store_true",
        default=False,
        help="Whether to swap target lang and source lang",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=10**10,
        help="Maximum size of the dataset to use.",
    )
    # Decoding parameters
    parser.add_argument(
        "--num_beams",
        type=int,
        default=8,
        help="Size of the beam search. Has to be > 1 since the point of this is to select the best hyps among those.",
    )

    return parser.parse_args(argv)


def _shuffle_words(text):
    """Return ``text`` with its space-separated words in a random order."""
    words = text.split(" ")
    # random.shuffle works in place and returns None, so its result must not
    # be fed to str.join directly.
    random.shuffle(words)
    return " ".join(words)


def main():
    """Generate beam-search hypotheses for a dataset and dump their scores.

    For every ``(x, y)`` pair of the prepared dataset the model generates
    ``--num_beams`` hypotheses. Two tables are collected and written to
    ``--output_dir`` as CSV:

    * ``*-steps.csv``: one row per (hypothesis, decoding step) holding the
      metrics returned by the per-step score function.
    * ``*-nosteps.csv``: one row per hypothesis holding the sequence score,
      sentence BLEU, BERTScore F1, the temperature-indexed aggregated metrics
      and, when requested, the set distance and Mahalanobis scores.
    """
    import pickle as pk

    args = parse_args()

    n_hyps = args.num_beams
    use_mahalanobis = args.compute_dist_mahalanobis is not None
    use_set_dist = args.compute_dist_set is not None

    # Fall back to CPU so the script still runs on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model, tokenizer = prep_model(args.model_name)

    model = model.eval().to(device)

    dataset = prep_dataset(
        args.dataset_name,
        args.dataset_config,
        args.dataset_split,
        args.size,
        tokenizer,
        args.switch_lang,
    )

    if args.switch_lang:
        tgt, src = args.dataset_config.split("-")
        args.dataset_config = f"{src}-{tgt}"

    output_dir = Path(args.output_dir)
    file_stem = f"{args.model_name.replace('/', '-')}-{args.dataset_config}-{args.dataset_name.replace('/', '-')}-{args.dataset_split}"

    file_name = output_dir / f"{file_stem}.csv"
    print(file_name)

    config_label = f'{args.dataset_config}{"-shuffled" if args.shuffle_input else ""}'

    records = []
    records_nosteps = []

    # NOTE(security): the reference files below are loaded with pickle, which
    # can execute arbitrary code. Only point these options at trusted files.
    means = {}
    covs = {}
    if use_mahalanobis:
        for fpath in Path(args.compute_dist_mahalanobis).iterdir():
            # The second "-"-separated field of the file name is used as the
            # integer key under which the statistics are stored.
            size = int(fpath.name.split("-")[1])
            with open(fpath, "rb") as fd:
                mincov = pk.load(fd)

            # Each entry of the pickle is indexed as (mean, covariance).
            means[size] = [torch.Tensor(m[0]).to(device) for m in mincov]
            covs[size] = [torch.Tensor(m[1]).to(device) for m in mincov]

    sets = None
    if use_set_dist:
        with open(args.compute_dist_set, "rb") as fd:
            sets = torch.Tensor(pk.load(fd)).to(device)

    compute_scores = mk_score_function()
    compute_scores_bags = mk_score_function_nostep()
    compute_iproj = mk_iproj_score_function()

    bleu = BLEU(effective_order=True)
    bertscorer = BERTScorer(lang="en", rescale_with_baseline=True)

    with torch.no_grad():
        for sample_id, (x, y) in tqdm(enumerate(dataset)):
            if args.shuffle_input:
                x = _shuffle_words(x)

            inputs = prep_inputs(x, tokenizer, args.dataset_name)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            outputs = model.generate(
                **inputs,
                num_beams=n_hyps,
                num_return_sequences=n_hyps,
                output_scores=True,
                return_dict_in_generate=True,
                output_hidden_states=use_mahalanobis,
                max_length=150,
            )

            seqs = outputs.sequences
            if not hasattr(outputs, "encoder_hidden_states"):
                # Presumably a decoder-only output whose sequences start with
                # the prompt: drop the input tokens.
                seqs = seqs[:, inputs["input_ids"].shape[-1] :]

            _scores = outputs.sequences_scores.detach().cpu().squeeze().tolist()

            # The last entry of outputs.scores is excluded from the per-step
            # metrics and records.
            n_steps = len(outputs.scores) - 1

            # Number of non-padding tokens of the first hypothesis; it
            # normalises the summed per-step distributions below.
            n_tokens = sum(
                1 for token in seqs[0] if token != tokenizer.pad_token_id
            )

            scores_bags = {}
            for t in temperatures:
                probs = (
                    sum([torch.softmax(p / t, dim=-1) for p in outputs.scores])[0]
                    / n_tokens
                )
                scores_bags |= compute_scores_bags(probs, t=t)

            # Start from an empty dict so a generation without any scored step
            # yields no step records instead of crashing on ``None.items()``.
            scores = {}
            for step_logits in outputs.scores[:-1]:
                step_scores = compute_scores(step_logits)
                for metric_name, tensor in step_scores.items():
                    scores.setdefault(metric_name, []).append(tensor)

            mahalanobis_score = {}
            if use_mahalanobis:
                if hasattr(outputs, "decoder_hidden_states"):
                    layer_states = outputs.decoder_hidden_states[0]
                else:
                    layer_states = outputs.hidden_states[0]

                # Last layer's state at index [0, 0] of the first generation step.
                hidden_state = layer_states[-1][0, 0, :].detach()

                for size in means:
                    mahalanobis_score[f"mahalanobis-{size}-last"] = (
                        compute_mahalanobis(
                            hidden_state, (means[size][0], covs[size][0])
                        )
                        .detach()
                        .cpu()
                        .squeeze()
                        .tolist()
                    )

            set_dist = {}
            if use_set_dist:
                probs = (
                    sum([torch.softmax(p, dim=-1) for p in outputs.scores])[0]
                    / n_tokens
                )
                set_dist = compute_iproj(set_dist, probs, sets)

            # After the transpose each metric is indexed as
            # [sentence_id, step] when the records are built.
            scores = {
                metric_name: torch.stack(tensors)
                .transpose(0, 1)
                .squeeze()
                .detach()
                .cpu()
                .numpy()
                for metric_name, tensors in scores.items()
            }

            scores_bags = {
                metric_name: tensor.detach().cpu().numpy()
                for metric_name, tensor in scores_bags.items()
            }

            hyps = tokenizer.batch_decode(seqs, skip_special_tokens=True)
            bleus = [bleu.sentence_score(hypothesis=h, references=[y]) for h in hyps]
            berts = [
                bertscorer.score([h], [y])[2][0].cpu().detach().tolist() for h in hyps
            ]

            label_txt = f"{args.model_name}-{args.dataset_name}-{args.dataset_config}-{args.dataset_split}-{args.shuffle_input}-{sample_id}"

            for sentence_id in range(n_hyps):
                # Hash of the run label and hypothesis index, shared by the
                # rows of both tables through the "agg_id" column.
                sentence_index = sha256(
                    (label_txt + str(sentence_id)).encode()
                ).hexdigest()

                # Fields common to the per-sequence and the per-step rows.
                base_record = {
                    "agg_id": sentence_index,
                    "dataset": args.dataset_name,
                    "split": args.dataset_split,
                    "config": config_label,
                    "model": args.model_name,
                    "sample_id": sample_id,
                    "sentence_id": sentence_id,
                }

                records_nosteps.append(
                    {
                        **base_record,
                        "seq_likelyhood": _scores[sentence_id],
                        "bleu": bleus[sentence_id].score,
                        "bertscore_f1": berts[sentence_id],
                        **{
                            metric_name + "_nostep": value
                            for metric_name, value in scores_bags.items()
                        },
                        **set_dist,
                        **mahalanobis_score,
                    }
                )

                for step in range(n_steps):
                    records.append(
                        {
                            **base_record,
                            "step": step,
                            **{
                                metric_name: tensor[sentence_id, step]
                                for metric_name, tensor in scores.items()
                            },
                        }
                    )

    file_name_steps = output_dir / f"{file_stem}-steps.csv"
    file_name_steps.parent.mkdir(exist_ok=True, parents=True)

    pd.DataFrame.from_records(records).to_csv(file_name_steps, float_format="%.2f")

    file_name_nosteps = output_dir / f"{file_stem}-nosteps.csv"
    file_name_nosteps.parent.mkdir(exist_ok=True, parents=True)

    pd.DataFrame.from_records(records_nosteps).to_csv(
        file_name_nosteps, float_format="%.2f"
    )


# Only run the experiment when executed as a script, so that importing this
# module does not trigger model loading and generation.
if __name__ == "__main__":
    main()
