import framework
from framework.dataset import Lambada, BLiMP, WinoGrande, PIQA
from framework.dataset import HellaSwag, AI2ARC
from typing import Tuple, Any, Dict
from tqdm import tqdm
import torch
from framework.loader.dataset_splitter import DatasetSplitter
import torch.utils.data
import torch.nn.functional as F

import os

from framework.task import args


@args
def a(parser: framework.helpers.ArgumentParser):
    """Register the ``lm.eval.*`` options read by ``LMEvalMixin``.

    Each entry below is an option name paired with its default value. The
    ``*.enabled`` flags select which evaluation datasets are created, the
    ``*.batch_mul`` values are multiplied with the test batch size when the
    loaders are built, and ``lm.eval.validate`` controls whether the parent
    class's validation is run.
    """
    # Kept as an ordered table so options are registered in a fixed order.
    option_defaults = (
        ("-lm.eval.blimp.batch_mul", 16),
        ("-lm.eval.enabled", False),
        ("-lm.eval.lambada.enabled", False),
        ("-lm.eval.blimp.enabled", False),
        ("-lm.eval.winogrande.enabled", False),
        ("-lm.eval.winogrande.batch_mul", 16),
        ("-lm.eval.hellaswag.enabled", False),
        ("-lm.eval.hellaswag.batch_mul", 16),
        ("-lm.eval.piqa.enabled", False),
        ("-lm.eval.piqa.batch_mul", 16),
        ("-lm.eval.ai2arc.enabled", False),
        ("-lm.eval.ai2arc.batch_mul", 4),
        ("-lm.eval.validate", True),
    )
    for flag, default_value in option_defaults:
        parser.add_argument(flag, default=default_value)


class LMEvalMixin:
    """Mixin that adds downstream language-model evaluations to a task.

    Two groups of datasets are handled:

    * ``self.valid_sets`` entries (e.g. ``lambada``), evaluated through the
      host class's ``validate_on`` via ``validate_on_name``.
    * ``self.prob_compare_valid_sets`` entries (WinoGrande, BLiMP, HellaSwag,
      PIQA, AI2ARC), evaluated by comparing the log-probabilities of a "good"
      sequence against one or more "bad" ones in ``validate_on_pc_dataset``.

    The host class is expected to provide ``valid_sets``, ``valid_loaders``,
    ``train_set``, ``model``, ``run_model``, ``prepare_data``, ``validate_on``,
    ``validate_on_token``, ``test_batch_size``, ``batch_dim`` and
    ``VALID_NUM_WORKERS``; none of them are defined here.
    """

    helper: framework.helpers.TrainingHelper

    def create_datasets(self):
        """Create the evaluation datasets enabled by the ``lm.eval.*`` options.

        Also initializes ``state_enabled`` and ``pad_quantum`` to their
        defaults and (re)creates the ``prob_compare_valid_sets`` container.
        """
        self.state_enabled = True
        self.pad_quantum = None
        # sep = "<STORY_SEP>"
        sep = ""
        self.prob_compare_valid_sets = framework.data_structures.DotDict()

        eval_args = self.helper.args.lm.eval
        if not eval_args.enabled:
            return

        vocabulary = self.train_set.vocabulary
        cache_dir = os.path.join(os.environ.get('DATA_PATH', ''), 'cache')

        if eval_args.lambada.enabled:
            self.valid_sets.lambada = Lambada(vocabulary, sep=sep, cache_dir=cache_dir)
        if eval_args.winogrande.enabled:
            self.prob_compare_valid_sets.winogrande = WinoGrande(
                "valid", vocabulary, sep=sep, cache_dir=cache_dir)
        if eval_args.blimp.enabled:
            self.prob_compare_valid_sets.blimp = BLiMP(vocabulary, sep=sep, cache_dir=cache_dir)
        if eval_args.hellaswag.enabled:
            self.prob_compare_valid_sets.hellaswag = HellaSwag(vocabulary, cache_dir=cache_dir)
        if eval_args.piqa.enabled:
            self.prob_compare_valid_sets.piqa = PIQA(vocabulary, cache_dir=cache_dir)
        if eval_args.ai2arc.enabled:
            self.prob_compare_valid_sets.ai2arc = AI2ARC(vocabulary, cache_dir=cache_dir)

    def get_test_logits(self, data: torch.Tensor) -> Any:
        """Run the model on ``data`` and return its outputs without the last step.

        The outputs are cut to the first ``data.shape[0] - 1`` entries along
        dimension 0.
        NOTE(review): presumably dimension 0 is time and the final output is
        dropped because it has no target token - confirm against ``run_model``.
        """
        olen = data.shape[0]
        res, _ = self.run_model({"data": data}, 0)
        return res.outputs[:olen - 1]

    @torch.no_grad()
    def validate_on_pc_dataset(self, name) -> Dict[str, Any]:
        """Evaluate the probability-comparison dataset registered under ``name``.

        For every batch the log-probabilities of the "good" sequence and of
        each of the ``n_ways - 1`` "bad" sequences are computed and handed to
        the dataset's test object.

        Returns:
            The test object's ``plot()`` entries, with keys prefixed by
            ``"{name}/"``.
        """
        self.model.eval()

        ds = self.prob_compare_valid_sets[name]
        ds_name = ds.__class__.__name__
        print(f"Validating on {ds_name}")

        # Use a pad quantum of 128 unless the dataset's longest sequence is
        # shorter than twice that, in which case use the maximum length itself.
        default_quantum = 128
        self.pad_quantum = ds.maxlen if 2 * default_quantum > ds.maxlen else default_quantum

        try:
            test = ds.start_test()
            for d in tqdm(self.valid_loaders[name]):
                d = self.prepare_data(d)
                prefix_len = d.get("prefix_len")

                good = d["sentence_good"]
                good_lprob = test.get_lprobs(self.get_test_logits(good), good, d["good_len"], prefix_len)

                bad_lprobs = []
                for i in range(test.n_ways - 1):
                    bad = d[f"sentence_bad_{i}"]
                    bad_lprobs.append(
                        test.get_lprobs(self.get_test_logits(bad), bad, d[f"bad_len_{i}"], prefix_len))

                test.step(good_lprob, bad_lprobs, d)
        finally:
            # Restore the defaults even if evaluation fails. The original code
            # reset a misspelled attribute ("model_pad_quantum"), which left
            # pad_quantum at its evaluation value afterwards.
            self.pad_quantum = None
            self.model.train()

        print(f"{ds_name} accuracy: {test.accuracy}")
        return {f"{name}/{k}": v for k, v in test.plot().items()}

    def validate_on_name(self, name: str, token_ids=None) -> Tuple[Any, float, Dict[str, Any]]:
        """Validate on the regular validation set ``name``.

        For ``"lambada"`` the ``state_enabled`` flag is switched off and
        ``pad_quantum`` is set to 128 for the duration of the call. The special
        name ``'val_fair_vs_unfair'`` runs ``validate_on_token`` on the
        ``'val'`` set instead of ``validate_on``. Both attributes are restored
        to their defaults afterwards.

        ``token_ids`` is accepted but not used by this method.
        """
        print(f"Starting validation on {name}...")
        is_lambada = name == "lambada"
        self.state_enabled = not is_lambada
        self.pad_quantum = 128 if is_lambada else None
        try:
            if name == 'val_fair_vs_unfair':
                return self.validate_on_token(self.valid_sets['val'], self.valid_loaders['val'])
            return self.validate_on(self.valid_sets[name], self.valid_loaders[name])
        finally:
            # Always return to the defaults, including when validation raises.
            self.pad_quantum = None
            self.state_enabled = True

    def create_loaders(self):
        """Create the parent's loaders, then one per probability-comparison set.

        Each dataset is split across the distributed workers. If its samples
        carry a ``"max_length"`` entry, batches are built by a
        ``BucketedSampler`` keyed on it; otherwise a plain fixed batch size is
        used.
        """
        super().create_loaders()
        for name, ds in self.prob_compare_valid_sets.items():
            # Per-dataset multiplier for the test batch size; 1 if not configured.
            batch_mul = self.helper.args.lm.eval.get(name, {}).get("batch_mul", 1)
            batch_size = batch_mul * self.test_batch_size

            ds = DatasetSplitter(ds, n_partitions=self.helper.dist_env.world_size,
                                 current=self.helper.dist_env.rank)

            kwargs = {}
            if "max_length" in ds[0]:
                kwargs["batch_sampler"] = framework.loader.sampler.BucketedSampler(
                    ds, batch_size, infinite=False, long_first=True, random_order=False, seed=0,
                    length_key_names=["max_length"])
            else:
                kwargs["batch_size"] = batch_size

            self.valid_loaders.update({
                name: torch.utils.data.DataLoader(
                    ds,
                    collate_fn=framework.loader.collate.VarLengthCollate(batch_dim=self.batch_dim),
                    num_workers=self.VALID_NUM_WORKERS,
                    **kwargs)
            })

    def validate(self) -> Dict[str, Any]:
        """Run the parent's validation and all probability-comparison evaluations.

        The parent's ``validate`` is only called when ``lm.eval.validate`` is
        set. A dataset whose ``end_only`` option is true is evaluated only when
        the current iteration equals ``stop_after``.

        Returns:
            The parent's results (or an empty dict), updated with the results
            of every evaluated probability-comparison dataset.
        """
        is_end = self.helper.state.iter == self.helper.args.stop_after

        res = super().validate() if self.helper.args.lm.eval.validate else {}

        for name in self.prob_compare_valid_sets:
            end_only = self.helper.args.lm.eval.get(name, {}).get("end_only", False)
            if end_only and not is_end:
                continue
            # Announce only the datasets that are actually evaluated.
            print(f"Starting validation on {name}...")
            res.update(self.validate_on_pc_dataset(name))
        return res
