import argparse
from collections import defaultdict
from os import truncate
from pathlib import Path

from datasets import load_dataset, load_metric
from infomet import InfoMet, against_uniform, get_measure_fn, measures as Measures
import numpy as np
import pandas as pd
from sacrebleu import BLEU
from sklearn.covariance import MinCovDet
import torch
from torch.utils.data import random_split
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
from sklearn.covariance import OAS

from utils.datasets import prep_dataset, prep_model


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of this script.

    Every option has a default, so the script can be invoked without any
    argument.

    Returns:
        The ``argparse.Namespace`` holding the parsed option values.
    """
    # The first positional parameter of ArgumentParser is ``prog`` (the program
    # name printed in the usage line), so the description is passed by keyword.
    parser = argparse.ArgumentParser(
        description=(
            "Extract hidden states from a Hugging Face model and store "
            "Mahalanobis reference statistics."
        )
    )

    # Task configuration
    parser.add_argument(
        "--dataset_name", type=str, help="Dataset name", default="wmt16"
    )
    parser.add_argument(
        "--dataset_config", type=str, help="Dataset config", default="de-en"
    )
    parser.add_argument(
        "--dataset_split", type=str, help="Train, validation or test", default="test"
    )

    parser.add_argument(
        "--size",
        type=int,
        help="Number of samples to use to compute mahalanobis",
        default=100000,
    )
    parser.add_argument(
        "--compute_dist_mahalanobis",
        type=str,
        default=None,
        help="Path to mahalanobis reference file.",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        help="Huggingface model name",
        default="Helsinki-NLP/opus-mt-de-en",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        help="Where to store the results.",
        default="data/",
    )
    parser.add_argument(
        "--switch_lang",
        action="store_true",
        default=False,
        help="Whether to swap target lang and source lang",
    )
    return parser.parse_args()


def main():
    """Fit a covariance estimate on generation hidden states and pickle it.

    For every dataset sample the model generates an output. From the hidden
    states of the first generation step, the vector at index ``[0, 0, :]`` of
    each entry is collected (one entry per layer). An ``OAS`` estimator is then
    fitted on the vectors collected for the last entry only, and the resulting
    ``(location_, covariance_)`` pair is pickled into ``args.output_dir``.

    Raises:
        ValueError: If the dataset yields no sample, so nothing can be fitted.
    """
    import pickle as pk

    args = parse_args()
    # Fall back to CPU rather than crashing on hosts without a CUDA device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, tokenizer = prep_model(args.model_name)
    model = model.eval().to(device)

    dataset = prep_dataset(
        args.dataset_name,
        args.dataset_config,
        args.dataset_split,
        args.size,
        tokenizer,
        switch_lang=args.switch_lang,
    )
    if args.switch_lang:
        # Reverse the language pair so the output file name reflects the swap.
        tgt, src = args.dataset_config.split("-")
        args.dataset_config = f"{src}-{tgt}"

    # Maps an index into the per-step hidden-state tuple to the list of
    # vectors gathered for it, one vector per dataset sample.
    hidden_representations = defaultdict(list)

    with torch.no_grad():
        # Index-based access: only ``len`` and ``__getitem__`` are assumed.
        for i in tqdm(range(len(dataset))):
            x, _ = dataset[i]
            inputs = tokenizer(x, return_tensors="pt", truncation=True)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            outputs = model.generate(
                **inputs,
                num_beams=2,
                output_hidden_states=True,
                output_scores=True,
                return_dict_in_generate=True,
                max_length=128,
            )

            # Hidden states of the first generation step; the attribute name
            # depends on the kind of output object returned by ``generate``.
            if hasattr(outputs, "decoder_hidden_states"):
                step_states = outputs.decoder_hidden_states[0]
            else:
                step_states = outputs.hidden_states[0]

            # One embedding per layer: first row, first position.
            for layer, state in enumerate(step_states):
                hidden_representations[layer].append(state[0, 0, :].detach().cpu())

    if not hidden_representations:
        raise ValueError(
            "No hidden representation was collected: the dataset is empty."
        )

    file_name = (
        Path(args.output_dir)
        / f"mahalanobis-{args.size}-{args.model_name.replace('/', '-')}-{args.dataset_config}-{args.dataset_name.replace('/', '-')}-{args.dataset_split}.dat"
    )

    # The collected vectors already live on the CPU, detached from any graph.
    hidden_representations = [
        torch.stack(vectors).numpy() for vectors in hidden_representations.values()
    ]

    print(len(hidden_representations))
    print(hidden_representations[0].shape)

    # Only the last entry is used for the covariance fit.
    mincov_list = [OAS().fit(X=elem) for elem in tqdm(hidden_representations[-1:])]

    file_name.parent.mkdir(exist_ok=True, parents=True)

    with open(file_name, "wb") as fd:
        pk.dump([(mincov.location_, mincov.covariance_) for mincov in mincov_list], fd)


# Run the pipeline only when executed as a script, so that importing this
# module does not trigger model loading and generation.
if __name__ == "__main__":
    main()
