import shutil
import time
from argparse import Namespace
from pprint import pprint
from typing import Dict, List, Optional

import numpy as np
import torch
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          EvalPrediction, Trainer, TrainingArguments)

from datasets import Dataset, DatasetDict  # type: ignore
from lm_understanding.metrics import (accuracy, expected_log_score,
                                      tv_distance_of_preds)
from lm_understanding.question_template import TemplateModelBehavior

from .baseline import Baseline


def to_dataset(model_behavior: TemplateModelBehavior) -> DatasetDict:
    """Convert every split of ``model_behavior`` into a ``DatasetDict``.

    Each split becomes a ``Dataset`` with three columns:

    * ``index``  - the row position of the question within its split,
    * ``text``   - the questions returned by ``model_behavior.questions``,
    * ``labels`` - the answers returned by ``model_behavior.answers``,
      wrapped in a NumPy array.

    Args:
        model_behavior: Source of split names, questions and answers.

    Returns:
        A ``DatasetDict`` keyed by split name.
    """
    per_split = {}
    for name in model_behavior.split_names:
        questions = model_behavior.questions(name)
        columns = {
            'index': list(range(len(questions))),
            'text': questions,
            'labels': np.array(model_behavior.answers(name)),
        }
        per_split[name] = Dataset.from_dict(columns)
    return DatasetDict(per_split)


def compute_metrics(preds: EvalPrediction) -> Dict[str, float]:
    """Score a batch of evaluation predictions.

    The raw logits are softmaxed over the class dimension and the
    probability of class 0 is compared against the label values using the
    project's ``accuracy``, ``expected_log_score`` and
    ``tv_distance_of_preds`` metrics.

    Args:
        preds: Logits (``predictions``) and targets (``label_ids``) as
            produced by the trainer's evaluation loop.

    Returns:
        A dict with the keys ``accuracy``, ``log_score`` and ``etvd``.
    """
    targets: np.ndarray = preds.label_ids  # type: ignore
    class_probs = torch.softmax(torch.tensor(preds.predictions), dim=1)
    # Only the first class column is scored against the labels.
    first_class = class_probs[:, 0].numpy()
    return {
        'accuracy': accuracy(first_class, targets),
        'log_score': expected_log_score(first_class, targets),
        'etvd': tv_distance_of_preds(first_class, targets),
    }


class DistillationTrainer(Trainer):
    """Trainer whose loss is cross-entropy against soft two-class targets.

    Each label value is used as the target probability of class 0; class 1
    receives the complement ``1 - label``.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        """Compute the soft-label cross-entropy loss for one batch.

        Args:
            model: The model being trained; called as ``model(**inputs)`` and
                expected to return an object with a ``logits`` attribute.
            inputs: Batch dict. The ``'labels'`` entry is removed from it
                (the dict is mutated) before the forward pass.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                just the loss.
            num_items_in_batch: Accepted only for compatibility with newer
                ``transformers`` releases, which pass it to ``compute_loss``;
                it is not used because the loss is a plain batch mean.
            **kwargs: Any further keyword arguments a future ``Trainer`` may
                pass; ignored.

        Returns:
            The scalar loss, or ``(loss, outputs)`` if ``return_outputs``.
        """
        labels = inputs.pop('labels')
        # Probability targets must be floating point; integer (hard 0/1)
        # labels would otherwise yield an integer (N, 2) target that
        # cross_entropy rejects.
        if not labels.is_floating_point():
            labels = labels.float()
        # Column 0 carries the label probability, column 1 its complement.
        probs = torch.stack((labels, 1 - labels), dim=1)
        out = model(**inputs)
        loss = torch.nn.functional.cross_entropy(out.logits, probs)
        return (loss, out) if return_outputs else loss


class Distillation(Baseline):
    """Baseline that fine-tunes a sequence classifier on model behavior.

    A two-label ``AutoModelForSequenceClassification`` is trained with
    ``DistillationTrainer`` on prompts built from the questions of
    ``model_behavior``, using the recorded answers as soft labels.
    """

    def __init__(self, model_behavior: TemplateModelBehavior, baseline_config: Namespace, *args, **kwargs):
        """Build the tokenized datasets, the model and the trainer.

        Args:
            model_behavior: Questions/answers per split to distill from.
            baseline_config: Must provide ``n_shot``, ``seed``,
                ``batch_size``, ``train_epochs`` and ``model_name``.
            *args: Forwarded to ``Baseline.__init__``.
            **kwargs: Forwarded to ``Baseline.__init__``.
        """
        super().__init__(model_behavior, baseline_config, *args, **kwargs)
        self.n_shot = baseline_config.n_shot
        self.rng = np.random.RandomState(baseline_config.seed)
        # Time-stamped scratch directory for checkpoints; removed by train().
        self.save_path = f'tmp_trainer/models/{str(int(time.time() * 100))}'
        training_args = TrainingArguments(
            per_device_eval_batch_size=baseline_config.batch_size,
            per_device_train_batch_size=baseline_config.batch_size,
            output_dir=self.save_path,
            num_train_epochs=baseline_config.train_epochs,
            load_best_model_at_end=True,
            evaluation_strategy='epoch',
            save_strategy='epoch',
        )
        tokenizer = AutoTokenizer.from_pretrained(baseline_config.model_name, use_fast=False)
        self.datasets = self.make_datasets(tokenizer)
        # Hold out 10% of the training split for per-epoch evaluation. The
        # split is seeded from the config so repeated runs with the same seed
        # select the same validation examples.
        train_dataset = self.datasets['train'].train_test_split(
            test_size=0.1,
            seed=baseline_config.seed,
        )

        model = AutoModelForSequenceClassification.from_pretrained(baseline_config.model_name, num_labels=2)
        self.trainer = DistillationTrainer(
            model=model,
            tokenizer=tokenizer,
            train_dataset=train_dataset["train"],  # type: ignore
            eval_dataset=train_dataset['test'],  # type: ignore
            compute_metrics=compute_metrics,
            args=training_args
        )

    def make_datasets(self, tokenizer) -> DatasetDict:
        """Create tokenized datasets for every split.

        A ``prompt`` column is added to each split via ``self.prompt`` and
        then tokenized to a fixed length of 512 tokens (padded/truncated).

        Args:
            tokenizer: Tokenizer callable applied to the ``prompt`` column.

        Returns:
            The ``DatasetDict`` with prompt and tokenizer output columns added.
        """
        datasets = to_dataset(self.model_behavior)

        for split_name in self.model_behavior.split_names:
            # Bind the loop variable as a default so the function never
            # depends on late-binding of ``split_name``.
            def make_prompt(example, split_name=split_name):
                return dict(prompt=self.prompt(example['index'], split_name))

            datasets[split_name] = datasets[split_name].map(make_prompt)

        def tokenize(examples):
            return tokenizer(examples["prompt"], max_length=512, padding='max_length', truncation=True)

        return datasets.map(tokenize, batched=True)

    def train(self, *args, **kwargs):
        """Run fine-tuning and remove the checkpoint directory afterwards.

        The positional and keyword arguments are accepted but unused. The
        scratch directory is deleted even when training raises, and a
        missing directory is tolerated so cleanup cannot mask the original
        training error.
        """
        try:
            self.trainer.train()
        finally:
            shutil.rmtree(self.save_path, ignore_errors=True)

    def prompt(self, question_idx: int, split_name: str):
        """Build the classifier input text for one question.

        The question comes first, followed by the prompts of ``self.n_shot``
        few-shot examples obtained from ``self.few_shot_examples`` (not
        defined in this class; presumably inherited from ``Baseline``). The
        parts are joined with the literal string ``'[SEP]'``.

        Args:
            question_idx: Position of the question within the split.
            split_name: Name of the split the question belongs to.

        Returns:
            The joined prompt string.
        """
        question = self.model_behavior.questions(split_name)[question_idx]
        few_shot_examples = self.few_shot_examples(question_idx, split_name, self.n_shot)
        examples = [question] + [fse.as_prompt() for fse in few_shot_examples]
        return '[SEP]'.join(examples)

    def split_predictions(self, split_name: str) -> Dict[str, List]:
        """Predict class-0 probabilities for every example of a split.

        The trainer's metrics for the split are pretty-printed as a side
        effect.

        Args:
            split_name: Name of the split to run prediction on.

        Returns:
            A dict with a single key ``predictions`` holding the softmaxed
            probability of class 0 for each example, as a list of floats.
        """
        out = self.trainer.predict(self.datasets[split_name])  # type: ignore
        assert out.metrics is not None
        pprint(out.metrics)
        return dict(predictions=torch.from_numpy(out.predictions).softmax(dim=1)[:, 0].tolist())
