import copy
import functools
import glob
import itertools
import os
import time
import warnings
from argparse import ArgumentParser
from datetime import datetime
from itertools import chain
from typing import Optional

import datasets
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import sqlitedict
import torch
import torch.nn.functional as F
import torchelie
from utils import BatchSchedulerSampler
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torch.utils.data.dataset import ConcatDataset
from tqdm.notebook import tqdm as tqdm
from transformers import AdamW
from transformers import AutoConfig
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import get_linear_schedule_with_warmup
from transformers import glue_compute_metrics

# Names of the GLUE tasks. MetaEvalDataModule checks membership in this list
# to apply ``glue_max_samples`` and to reuse the validation split as the test
# split.
GLUE_TASKS = ["cola", "sst2", "mrpc", "qqp", "stsb", "mnli", "qnli", "rte", "wnli"]

def label_dataset(dataset, dataset_index):
    """Tag every item produced by ``dataset._getitem`` with ``dataset_index``.

    The dataset's ``_getitem`` attribute is replaced in place by a wrapper
    that calls the previous ``_getitem`` and stores ``dataset_index`` under
    the ``"dataset_index"`` key of the returned item.

    Args:
        dataset: Object exposing a ``_getitem`` callable whose result
            supports item assignment.
        dataset_index: Value written into each returned item.

    Returns:
        The same ``dataset`` object, now carrying the wrapped ``_getitem``.
    """
    # Captured before the attribute is overwritten, so the wrapper keeps
    # delegating to the previous implementation.
    inner_getitem = dataset._getitem

    def getitem_with_index(key, *args, **kwargs):
        # Installed as an instance attribute, so no implicit ``self`` is
        # passed: the first positional argument is the caller's own.
        item = inner_getitem(key, *args, **kwargs)
        item["dataset_index"] = dataset_index
        return item

    dataset._getitem = getitem_with_index
    return dataset


class MetaDataset:
    """Per-task metadata lookups read from ``task_features.pck``.

    The pickled table is kept as ``self.df``; each ``task_*`` attribute is a
    plain dict keyed by the values of the table's ``"task"`` column.
    """

    def __init__(self, metaeval_path):
        """Load ``{metaeval_path}/task_features.pck`` and build the lookups.

        Args:
            metaeval_path: Directory containing ``task_features.pck``.
        """
        # NOTE: read_pickle unpickles the file, so only point this at
        # trusted data.
        self.df = pd.read_pickle(f"{metaeval_path}/task_features.pck")

        # Index by task once; every lookup below is one column of this view.
        by_task = self.df.set_index("task")

        def lookup(column):
            return by_task[column].to_dict()

        self.task_num_labels = lookup("num_labels")
        self.task_text_fields = lookup("text_fields")
        # "a/b" style dataset identifiers become ("a", "b") tuples.
        dataset_parts = by_task["dataset"].map(lambda name: tuple(name.split("/")))
        self.task_tuple = dataset_parts.to_dict()
        self.task_label_fields = lookup("label_fields")
        self.task_splits_mapping = lookup("splits_mapping")
        self.task_labels_name = lookup("labels_name")


class MetaEvalDataModule(pl.LightningDataModule):
    """Lightning data module serving one MetaEval task from local files.

    ``setup`` reads every file under ``{metaeval_path}/metaeval/{task_name}/``
    with ``pandas.read_csv``, makes sure ``train``/``validation``/``test``
    splits exist, subsamples them, tokenizes them and formats them as torch
    tensors. The three ``*_dataloader`` methods then wrap the splits.
    """

    # Columns handed to the DataLoader; only those actually present in a
    # split are selected when the split is formatted.
    loader_columns = [
        "datasets_idx",
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "start_positions",
        "end_positions",
        "labels",
        "dataset_index",
    ]

    # Not referenced by the dataloaders defined in this class; kept because
    # it is part of the class's public attributes.
    kw_dataloaders = {"drop_last": True, "pin_memory": False}

    def __init__(
        self,
        model_name_or_path: str,
        task_name: str = None,
        max_seq_length: int = 128,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        dataset_index: int = 0,
        max_samples=1000,
        max_test_samples=10000,
        num_workers=0,
        metaeval_path=None,
        seed=0,
        glue_max_samples=None,
        tokenizer_name_or_path=None,
        **kwargs,
    ):
        """Store the configuration and load task metadata and the tokenizer.

        Args:
            model_name_or_path: Model identifier; also used to load the
                tokenizer when ``tokenizer_name_or_path`` is falsy.
            task_name: Key into the metadata table and name of the task's
                data directory.
            max_seq_length: ``max_length`` passed to the tokenizer.
            train_batch_size: Batch size of the training dataloader.
            eval_batch_size: Batch size of the validation/test dataloaders.
            dataset_index: Accepted but not used by this class.
            max_samples: Cap on the train split; the validation split is
                capped at 20% of it. A falsy value disables both caps.
            max_test_samples: Cap on the test split. A falsy value disables
                the cap.
            num_workers: ``num_workers`` of every dataloader.
            metaeval_path: Directory holding ``task_features.pck`` and the
                ``metaeval/<task_name>/`` files.
            seed: Seed used when subsampling splits.
            glue_max_samples: When truthy, replaces ``max_samples`` for
                tasks listed in ``GLUE_TASKS``.
            tokenizer_name_or_path: Tokenizer identifier; defaults to
                ``model_name_or_path``.
            **kwargs: Ignored.

        Raises:
            KeyError: If ``task_name`` is missing from the metadata table.
        """
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.tokenizer_name_or_path = tokenizer_name_or_path or model_name_or_path
        self.task_name = task_name
        self.max_seq_length = max_seq_length
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.max_samples = max_samples
        self.glue_max_samples = glue_max_samples
        self.num_workers = num_workers
        self.max_test_samples = max_test_samples
        self.metaeval_path = metaeval_path
        self.metadataset = MetaDataset(metaeval_path)
        self.seed = seed
        # Column names are fixed here rather than read from the metadata.
        self.text_fields = ["sentence1", "sentence2"]
        self.task_label_fields = ["label"]
        self.num_labels = self.metadataset.task_num_labels[task_name]
        self.task_splits_mapping = self.metadataset.task_splits_mapping[task_name]
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_name_or_path, use_fast=True
        )
        self.labels_name = self.metadataset.task_labels_name[task_name]
        if self.task_name in GLUE_TASKS and self.glue_max_samples:
            self.max_samples = self.glue_max_samples
        print(self.task_name)

    def setup(self, stage=None):
        """Load, split, subsample and tokenize the task's data.

        Args:
            stage: When equal to ``"test"`` only the test split is
                tokenized; any other value tokenizes every split.

        Raises:
            KeyError: If a ``train``, ``validation`` or ``test`` split is
                still missing after the split-completion rules ran.
        """
        self.dataset = self._read_task_files()
        self._complete_splits()
        self._subsample_splits()

        self.train_sampler = None
        self.test_sampler = None
        self.val_sampler = None

        self._encode_splits(stage)
        self.eval_splits = [x for x in self.dataset.keys() if "validation" in x] + [
            "test"
        ]

    def _read_task_files(self):
        """Read every file of the task directory into a ``DatasetDict``."""
        splits = datasets.DatasetDict()
        for path in glob.glob(f"{self.metaeval_path}/metaeval/{self.task_name}/*"):
            frame = pd.read_csv(path)
            for column in frame.columns:
                # Force the text columns to ``str`` so that values pandas
                # parsed as another type still reach the tokenizer as text.
                if "sen" in column:
                    frame[column] = frame[column].map(str)
            # basename handles whichever path separator glob returned.
            split_name = os.path.basename(path).replace(".csv", "")
            splits[split_name] = datasets.Dataset.from_pandas(frame)
        return splits

    def _complete_splits(self):
        """Derive missing splits so train/validation/test all exist."""
        if list(self.dataset.keys()) == ["train"]:
            # Only a train file: carve validation and test out of it.
            train_valtest = self.dataset["train"].train_test_split(seed=0)
            val_test = train_valtest["test"].train_test_split(seed=0)
            self.dataset["train"] = train_valtest["train"]
            self.dataset["validation"] = val_test["train"]
            self.dataset["test"] = val_test["test"]

        if (
            "train" in self.dataset
            and "test" in self.dataset
            and "validation" not in self.dataset
        ):
            t_v = self.dataset["train"].train_test_split(seed=0)
            self.dataset["train"] = t_v["train"]
            self.dataset["validation"] = t_v["test"]

        if self.task_name in GLUE_TASKS + ["hope_edi", "hans"]:
            # For these tasks the validation split serves as the test split
            # and a new validation split is carved out of train.
            self.dataset["test"] = self.dataset["validation"]
            t_v = self.dataset["train"].train_test_split(seed=0)
            self.dataset["train"] = t_v["train"]
            self.dataset["validation"] = t_v["test"]
            self.raw_test = self.dataset["test"]

        # Fail before any split is indexed, with a message naming the gap.
        missing = [
            split
            for split in ("train", "validation", "test")
            if split not in self.dataset
        ]
        if missing:
            raise KeyError(
                f"task {self.task_name!r} is missing split(s) {missing}; "
                f"found {list(self.dataset.keys())}"
            )

    def _sample_limit(self, split):
        """Return the sample cap for ``split``; a falsy result means no cap."""
        if split == "test":
            return self.max_test_samples
        if self.max_samples is None:
            return None
        if split == "validation":
            return int(self.max_samples * 0.2)
        return self.max_samples

    def _subsample_splits(self):
        """Shrink each split that exceeds its sample cap."""
        for split in ("train", "validation", "test"):
            limit = self._sample_limit(split)
            # Truthiness is tested first so a ``None`` cap is never compared.
            if limit and len(self.dataset[split]) > limit:
                subset = copy.deepcopy(self.dataset[split])
                self.dataset[split] = subset.train_test_split(
                    train_size=limit, seed=self.seed
                )["train"]

    def _encode_splits(self, stage):
        """Tokenize the splits needed for ``stage`` and format them for torch."""
        for split in list(self.dataset.keys()):
            if stage == "test" and split != "test":
                continue
            self.dataset[split] = self.dataset[split].map(
                self.convert_to_features,
                batched=True,
            )
            self.columns = [
                c for c in self.dataset[split].column_names if c in self.loader_columns
            ]
            self.dataset[split].set_format(type="torch", columns=self.columns)

    def _build_loader(self, split, batch_size, sampler):
        """Wrap ``self.dataset[split]`` in a ``DataLoader``."""
        # NOTE(review): no ``shuffle`` flag is passed, so without a sampler
        # batches follow dataset order -- confirm this is intended for train.
        return DataLoader(
            self.dataset[split],
            num_workers=self.num_workers,
            pin_memory=True,
            batch_size=batch_size,
            sampler=sampler,
            drop_last=False,
        )

    def train_dataloader(self):
        """Return the dataloader over the ``train`` split."""
        return self._build_loader("train", self.train_batch_size, self.train_sampler)

    def val_dataloader(self):
        """Return the dataloader over the ``validation`` split."""
        print("val_dataloader")
        return self._build_loader("validation", self.eval_batch_size, self.val_sampler)

    def test_dataloader(self):
        """Return the dataloader over the ``test`` split."""
        print("test_dataloader")
        return self._build_loader("test", self.eval_batch_size, self.test_sampler)

    def convert_to_features(self, example_batch, indices=None):
        """Tokenize a batch of examples for ``Dataset.map(batched=True)``.

        Args:
            example_batch: Mapping from column name to a list of values.
            indices: Unused.

        Returns:
            The tokenizer output extended with ``labels``, ``text_a`` and
            ``text_b`` (empty strings when there is no second text column).
        """
        first_field, second_field = self.text_fields[0], self.text_fields[1]
        first_texts = example_batch[first_field]
        has_pair = second_field in example_batch

        # Either encode single sentences or sentence pairs.
        if has_pair:
            texts_or_text_pairs = list(zip(first_texts, example_batch[second_field]))
        else:
            texts_or_text_pairs = first_texts

        # ``padding="max_length"`` is the supported spelling of the
        # deprecated ``pad_to_max_length=True``.
        features = self.tokenizer.batch_encode_plus(
            texts_or_text_pairs,
            max_length=self.max_seq_length,
            padding="max_length",
            truncation=True,
        )

        # Rename label to labels to make it easier to pass to model forward.
        features["labels"] = example_batch[self.task_label_fields[0]]

        features["text_a"] = first_texts
        if has_pair:
            features["text_b"] = example_batch[second_field]
        else:
            features["text_b"] = ["" for _ in first_texts]
        return features


def concat_dm(args, task_names, test_mode=False):
    """Build one data module per task and merge them into a multi-task module.

    For each entry of ``task_names`` a ``MetaEvalDataModule`` is created from
    ``args`` and set up for the ``"fit"`` stage. ``args.task_name`` and
    ``args.seed`` are overwritten for every task, so ``args`` is left holding
    the values of the last one.

    Args:
        args: Namespace accepted by ``MetaEvalDataModule.from_argparse_args``
            and by ``BatchSchedulerSampler``. Mutated in place.
        task_names: Task names; a name may appear several times, in which
            case the k-th repetition is built with ``seed=k`` (0 for the
            first occurrence).
        test_mode: When true, return a deep copy of the first task's module
            untouched: no ``dataset_index`` labels, no samplers and no
            ``l_num_labels`` / ``l_labels_name`` attributes.

    Returns:
        In test mode, a copy of the first module. Otherwise a module whose
        ``train`` / ``validation`` / ``test`` entries are ``ConcatDataset``
        objects over all tasks (items tagged with ``dataset_index``), with
        ``BatchSchedulerSampler`` samplers and the per-task ``l_num_labels``
        and ``l_labels_name`` lists attached.

    Raises:
        IndexError: If ``task_names`` is empty.
    """
    modules = []
    l_num_labels = []
    l_labels_name = []
    # Running count of how often each task name was already seen; that
    # count is the seed, so repeated tasks draw different subsamples.
    occurrences = {}
    for task_name in task_names:
        seed = occurrences.get(task_name, 0)
        occurrences[task_name] = seed + 1
        args.task_name = task_name
        args.seed = seed
        module = MetaEvalDataModule.from_argparse_args(args)
        module.setup("fit")
        modules.append(module)
        l_num_labels.append(module.num_labels)
        l_labels_name.append(module.labels_name)

    if test_mode:
        return copy.deepcopy(modules[0])

    modules = copy.deepcopy(modules)
    # NOTE(review): the merged module is built on the LAST task's module, so
    # its task_name / num_labels / labels_name are that task's. This keeps
    # the existing behaviour, which came from a loop variable rebinding
    # ``dm``; confirm whether the first task's module was intended.
    merged = modules[-1]
    for split in ("train", "validation", "test"):
        labeled_splits = [
            label_dataset(module.dataset[split], index)
            for index, module in enumerate(modules)
        ]
        merged.dataset[split] = ConcatDataset(labeled_splits)

    merged.train_sampler = BatchSchedulerSampler(
        dataset=merged.dataset["train"],
        batch_size=merged.train_batch_size,
        args=args,
        test_mode=False,
    )
    merged.val_sampler = BatchSchedulerSampler(
        dataset=merged.dataset["validation"],
        batch_size=merged.eval_batch_size,
        args=args,
        test_mode=False,
    )
    merged.test_sampler = BatchSchedulerSampler(
        dataset=merged.dataset["test"],
        batch_size=merged.eval_batch_size,
        args=args,
        test_mode=True,
    )
    merged.l_num_labels = l_num_labels
    merged.l_labels_name = l_labels_name
    return merged
