import argparse
from os import truncate
import os
from pathlib import Path

from datasets import load_dataset, load_metric
from infomet import InfoMet, against_uniform, get_measure_fn, measures as Measures
import numpy as np
import pandas as pd
from sacrebleu import BLEU
from sklearn.covariance import MinCovDet
import torch
from torch.utils.data import random_split
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM

from utils.datasets import prep_dataset, prep_model


def parse_args():
    """Parse the command-line options of this script.

    Reads ``sys.argv`` and returns an ``argparse.Namespace`` with the
    attributes ``dataset_name``, ``dataset_config``, ``dataset_split``,
    ``mk_mahalanobis``, ``compute_dist_mahalanobis``, ``model_name``,
    ``output_dir``, ``size`` and ``switch_lang``.

    Exits the process (via argparse) on ``--help`` or on invalid arguments.
    """
    # The first positional parameter of ArgumentParser is ``prog``, not the
    # description; pass the description by keyword so the usage line shows the
    # real script name.
    parser = argparse.ArgumentParser(
        description=(
            "Run a seq2seq model over a dataset and store the averaged "
            "output token distribution of every sample."
        )
    )

    # Task configuration
    parser.add_argument(
        "--dataset_name", type=str, help="Dataset name", default="wmt16"
    )
    parser.add_argument(
        "--dataset_config", type=str, help="Dataset config", default="de-en"
    )
    parser.add_argument(
        "--dataset_split", type=str, help="Train, validation or test", default="test"
    )

    parser.add_argument(
        "--mk_mahalanobis",
        action="store_true",
        default=False,
        help="if we should compute covariance matrix and mean for future mahalanobis distance",
    )
    parser.add_argument(
        "--compute_dist_mahalanobis",
        type=str,
        default=None,
        help="Path to mahalanobis reference file.",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        help="Huggingface model name",
        default="Helsinki-NLP/opus-mt-de-en",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        help="Where to store the results.",
        default="data/",
    )
    parser.add_argument(
        "--size",
        type=int,
        help="Number of samples to store",
        default=2000,
    )
    parser.add_argument(
        "--switch_lang",
        action="store_true",
        default=False,
        help="Whether to swap target lang and source lang",
    )
    return parser.parse_args()


def main():
    """Generate with the model on each sample and pickle the mean distributions.

    For every item of the prepared dataset the source text is tokenized
    (truncated to 256 tokens), the model generates with ``num_beams=1`` and
    ``max_length=256``, the per-step scores are turned into probabilities with
    a softmax over the last dimension, and those distributions are averaged
    over the generation steps. The per-sample averages are collected into one
    NumPy array and pickled to
    ``<output_dir>/distribs_signature-<model>-<config>-<dataset>-<split>.dat``.

    Takes no parameters (options come from ``parse_args``) and returns ``None``.
    """
    import pickle as pk

    args = parse_args()

    # Use the GPU when one is available, otherwise fall back to CPU instead of
    # crashing on ``.to("cuda")`` on a CUDA-less machine.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model, tokenizer = prep_model(args.model_name)
    model = model.eval().to(device)

    dataset = prep_dataset(
        args.dataset_name,
        args.dataset_config,
        args.dataset_split,
        args.size,
        tokenizer,
        switch_lang=args.switch_lang,
    )
    if args.switch_lang:
        # Reverse the two halves of the config (e.g. "de-en" -> "en-de") so the
        # output file name reflects the swapped direction.
        tgt, src = args.dataset_config.split("-")
        args.dataset_config = f"{src}-{tgt}"
    distribs = []

    with torch.no_grad():
        for i in tqdm(range(len(dataset))):
            # Each item unpacks into two values; only the first is used here
            # (presumably source text and reference -- verify in prep_dataset).
            x, _ = dataset[i]
            inputs = tokenizer(x, return_tensors="pt", truncation=True, max_length=256)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = model.generate(
                **inputs,
                num_beams=1,
                num_return_sequences=1,
                output_scores=True,
                return_dict_in_generate=True,
                # NOTE(review): hidden states are requested when
                # --compute_dist_mahalanobis is set but are not used below.
                output_hidden_states=(args.compute_dist_mahalanobis is not None),
                max_length=256,
            )

            # One score tensor per generation step -> one distribution per step.
            step_probs = [torch.softmax(p, dim=-1) for p in outputs.scores]
            # Average over steps; squeeze() drops size-1 dimensions
            # (presumably the batch dimension of the single input).
            mean_probs = (sum(step_probs) / len(outputs.scores)).squeeze()
            distribs.append(mean_probs.detach().cpu().numpy())

    file_name = (
        Path(args.output_dir)
        / f"distribs_signature-{args.model_name.replace('/', '-')}-{args.dataset_config}-{args.dataset_name.replace('/', '-')}-{args.dataset_split}.dat"
    )

    # Create the output directory (and any missing parents) before writing.
    file_name.parent.mkdir(exist_ok=True, parents=True)

    with open(file_name, "wb") as fd:
        pk.dump(np.array(distribs), fd)


# Run the pipeline only when executed as a script, so that importing this
# module does not parse arguments, load a model and start inference.
if __name__ == "__main__":
    main()
