
from datetime import datetime
import time
import random
from easydict import EasyDict as edict
from tqdm.auto import tqdm
from itertools import chain
import glob
import itertools
import copy
import pandas as pd
import datasets
import os
import numpy as np
import pytorch_lightning as pl
import torch
import tensorboard
import torch.nn.functional as F
import functools
from collections_extended import setlist
from transformers import AutoTokenizer
from datamodule import concat_dm, MetaEvalDataModule
from encoder import Transformer, count_parameters, make_zero_shot, freeze_except, offline_inference, average_by_feature
from args import args

# Base directory from which every path below is built.
metaeval_path = "."
GLUE_TASKS = ['cola', 'sst2', 'mrpc', 'qqp', 'mnli', 'qnli', 'rte']

datasets_path = f'{metaeval_path}/metaeval/'
base_checkpoint_path = f'{metaeval_path}/checkpoints'
features_path = f'{metaeval_path}/task_features.csv'


# Keep the rows whose `split_keys` cell, rendered as text, is not exactly
# "['test']"; the remaining task names form `all_tasks`.
df = pd.read_csv(features_path)
df = df[df["split_keys"].map(str) != "['test']"]
all_tasks = list(df["task"])

# Second read of the same table, restricted to the tasks retained above.
tf = pd.read_csv(features_path)
tf = tf[tf["task"].map(lambda name: name in all_tasks)]
# task -> task_type for every retained task
tt = tf.set_index("task")["task_type"]
# task -> task_type for the retained tasks that are not listed in GLUE_TASKS
ttng = tf[tf["task"].map(lambda name: name not in GLUE_TASKS)].set_index(
    "task")["task_type"]


def get_same_type_ng_tasks(task):
    """Return the non-GLUE tasks that share the task type of ``task``.

    The type of ``task`` is looked up in the module-level series ``tt``;
    the candidates come from ``ttng`` (same mapping without GLUE tasks).

    Args:
        task: a task name present in the index of ``tt``.

    Returns:
        A list with the index labels of ``ttng`` whose value equals
        ``tt[task]``.
    """
    wanted_type = tt[task]
    same_type = ttng[ttng == wanted_type]
    return list(same_type.index)


class conditional_adapters:
    """Experiment description whose class attributes are grid parameters.

    The script reads the attributes set directly on this class (via
    ``__dict__``) and turns each one into an axis of the experiment grid, so
    nothing other than parameters should be defined in the class body.
    """
    # A one-element list: the single grid value is the full list of tasks.
    training_task = [all_tasks]
    task_embedding_size = 32
    repeat_test = False

class adapters:
    """Experiment description with one single-task entry per task.

    Class attributes are grid parameters; list-valued attributes enumerate
    the values of that parameter, scalars are fixed values.
    """
    # One grid value per task, each being a one-element task list.
    training_task = [[t] for t in all_tasks]
    max_epochs = 3
    learning_rate = 2e-4
    adapter_mode = "single"
    embedding_mode = "freeze"

class baseline_single:
    """Baseline experiment description with one single-task entry per task.

    The attribute values are the same as those of ``adapters``; this class
    also serves as the base of ``baseline_full``.
    """
    # One grid value per task, each being a one-element task list.
    training_task = [[t] for t in all_tasks]
    max_epochs = 3
    learning_rate=2e-4
    adapter_mode = "single"
    embedding_mode = "freeze"

class baseline_full(baseline_single):
    """Variant of ``baseline_single`` overriding three parameters.

    NOTE(review): only the three overrides below live in this class's own
    ``__dict__``; ``training_task`` and ``max_epochs`` are inherited and are
    therefore not seen by code that reads ``baseline_full.__dict__`` directly.
    """
    adapter_mode = "none"
    embedding_mode = "all"
    learning_rate = 2e-5

def process_xp(xp, base_args=None):
    """Normalise an experiment dict so that every value is a list.

    The dict is modified in place: a ``"ts_xp"`` entry holding the current
    Unix time (whole seconds, as a string) is added, then every value that is
    not exactly a ``list`` is wrapped in a one-element list.

    Args:
        xp: mapping of parameter name to a value or a list of values.
        base_args: accepted for call-site symmetry; not used here.

    Returns:
        The same ``xp`` object, after normalisation.

    Raises:
        AssertionError: if any value is a ``tuple``.
    """
    xp["ts_xp"] = str(time.time()).split(".")[0]
    # Iterate over a snapshot of the keys; only existing entries are rewritten.
    for name in list(xp):
        value = xp[name]
        assert type(value) is not tuple
        if type(value) is list:
            continue
        xp[name] = [value]
    return xp


def dirl(xp):
    """Return the sorted, de-duplicated public names exposed by ``xp``.

    ``xp`` is either one object or a ``list`` of objects. For each object the
    candidate names are its keys when it is a mapping, and ``dir(obj)``
    otherwise. Names starting with ``"__"`` are discarded.

    Mappings are handled explicitly because ``dir()`` of a dict only reports
    dict method names (``"keys"``, ``"update"``, ...), never the parameter
    names stored in it.

    Args:
        xp: an object, a mapping, or a list mixing both.

    Returns:
        A list of unique names in a deterministic (sorted) order.
    """
    from collections.abc import Mapping

    sources = xp if isinstance(xp, list) else [xp]
    names = set()
    for source in sources:
        candidates = source.keys() if isinstance(source, Mapping) else dir(source)
        names.update(
            name for name in candidates
            if not (isinstance(name, str) and name.startswith("__"))
        )
    # Sorting removes the run-to-run ordering differences of set iteration.
    return sorted(names, key=str)


def stage_1_args(args):
    """Print the stage-1 subset of ``args`` and return a fresh key.

    The subset keeps the entries of ``args`` whose name is reported by
    ``dirl(xp)`` (``xp`` being the module-level experiment) and does not
    contain any of the excluded fragments listed below.

    Args:
        args: mapping of argument name to value for the current run.

    Returns:
        ``str(time.time())``. The value differs on every call, so two calls
        never produce the same key.
    """
    excluded_fragments = (
        "_2", "inference_task_embedding", "ts", "logger",
        "task_labels_name", "finetuning_trained_modules", "task_name",
        "num_labels",
    )
    # dirl(xp) does not depend on the key under test: compute it once.
    experiment_names = set(dirl(xp))
    d = {
        k: v
        for (k, v) in args.items()
        if k in experiment_names
        and not any(fragment in k for fragment in excluded_fragments)
    }
    print(d)
    # NOTE(review): returning str(d) here instead would give identical
    # stage-1 settings the same key; the timestamp gives every call its own.
    return str(time.time())


def main(args):
    """Build the data module, model and trainer for one run.

    ``args`` is updated in place with ``l_num_labels``, ``num_labels`` and
    ``checkpoint_callback`` (set to ``None``) before the trainer is created.

    Args:
        args: run arguments; ``seed`` and ``training_task`` are read here,
            the rest is forwarded to the data module, model and trainer.

    Returns:
        A ``(data module, model, trainer)`` tuple. The model has been moved
        with ``.cuda()`` and the hyper-parameters and parameter counts have
        been sent to the trainer's logger.
    """
    pl.seed_everything(args.seed)
    args.l_num_labels = None

    if type(args.training_task) is list:
        # A list of tasks is served by the concatenated data module.
        data = concat_dm(args, args.training_task)
        args.l_num_labels = data.l_num_labels
    else:
        data = MetaEvalDataModule.from_argparse_args(args)
        data.setup('fit')
    args.num_labels = data.num_labels

    label_names = data.metadataset.task_labels_name
    net = Transformer(args, task_labels_name=label_names).cuda()

    args.checkpoint_callback = None
    runner = pl.Trainer.from_argparse_args(args)
    run_logger = runner.logger
    run_logger.log_hyperparams(args)
    runner.logger.log_hyperparams(count_parameters(net))
    return data, net, runner

def loop(xp, base_args=None):
    """Yield one argument set per combination of the experiment grid.

    ``xp`` is first normalised by ``process_xp`` (every value becomes a
    list) and printed. The Cartesian product of those lists is then walked
    under a progress bar. Each combination is written on top of a fresh
    ``edict(base_args)``; ``seed`` is forced to 0, and ``xp_name`` and ``ts``
    are filled in before the arguments are yielded.

    Args:
        xp: experiment dict, parameter name -> value or list of values.
        base_args: default arguments copied into every yielded set.

    Yields:
        An ``edict`` holding the arguments of one run.
    """
    xp = process_xp(xp, base_args=base_args)
    print(xp)
    combinations = list(itertools.product(*(xp[name] for name in xp)))
    for combination in tqdm(combinations):
        args = edict(base_args)
        for name, value in zip(xp, combination):
            args[name] = value

        args.seed = 0
        task = args.training_task
        # Number of tasks for a task list, otherwise the task itself.
        task_tag = len(task) if type(task) is list else task
        args.xp_name = f"{args.adapter_mode}-{task_tag}"
        args.ts = str(time.time())
        yield args

# Deep-copied snapshot of the attributes of `args` as a plain dict.
base_args = copy.deepcopy(
    {name: value for name, value in vars(args).items()
     if not name.startswith("__")}
)

# Experiment to run: the attributes set directly on the selected class.
xp = conditional_adapters
xp = {name: value for name, value in vars(xp).items()
      if not name.startswith("__")}

# Keys of the stage-1 configurations that have been registered so far.
trained = {}

# Run every configuration of the grid: train, optionally checkpoint, then
# test on each task of `inference_test_task` (or of `training_task`).
for args in loop(xp, base_args=base_args):

    dm, model, trainer = main(args)
    task = args.training_task
    print(task)

    # stage_1_args() returns a different value on every call, so it is
    # evaluated once and the same key is used for the checks and for the
    # registration in `trained`.
    stage_key = stage_1_args(args)
    needs_training = stage_key not in trained
    if needs_training:
        trainer.fit(model, copy.deepcopy(dm))
        checkpoint_path = None

    if args.inference != "supervised" and needs_training:
        checkpoint_path = f"{base_checkpoint_path}/{time.time()}"
        model.select_classifier(0)
        trainer.save_checkpoint(checkpoint_path)
        args0 = copy.deepcopy(args)
        args0.inference_task_embedding = args.inference_task_embedding
        trained[stage_key] = 1

    evaluated_tasks = {}
    task = args.inference_test_task if args.inference_test_task else args.training_task
    for di, t in enumerate(task):
        # Position of the test task in the list; the meaning of
        # `model.model.i` is defined by the encoder module (not visible here).
        model.model.i[0] = di
        test_dm = MetaEvalDataModule.from_argparse_args(
            edict(args, **{
                "task_name": t,
                "max_samples": args.max_samples_2
            }))
        test_dm.setup('fit')

        if args.inference != "supervised":
            # Reload the checkpoint saved above. `Transformer` is the class
            # main() instantiates; the name used here before
            # (`GLUETransformer`) is neither imported nor defined in this
            # file and raised NameError.
            model = Transformer.load_from_checkpoint(
                hparams=copy.deepcopy(args0),
                checkpoint_path=checkpoint_path,
                strict=False)
            model = make_zero_shot(model, t)
            if args.inference_task_embedding == "average":
                model.F = model.F * 0
            if args.inference_task_embedding == "offline":
                H = model.T.weight.cpu().detach().numpy()
                h = offline_inference(args, H, t)
                model.T.T.weight[0] = torch.tensor(h[0]).cuda()
            if args.inference_task_embedding == "average_by_feature":
                H = model.T.weight.cpu().detach().numpy()
                h = average_by_feature(args, H, t)
                model.T.T.weight[0] = torch.tensor(h[0]).cuda()

        model.test_task = t

        if args.inference == "fine-tune" and args.finetuning_trained_modules and args.max_samples_2:
            test_dm.setup("fit")
            trainer = pl.Trainer.from_argparse_args(args0)
            # Replace the weight of model.T (and model.T.T) with a copy held
            # by a new Embedding before freezing the other modules.
            MTD = copy.deepcopy(model.T.weight.data)
            MTLEAF = torch.nn.Embedding(*MTD.shape)
            MTLEAF.weight.data = MTD
            model.T.weight = MTLEAF.weight
            model.T.T.weight = MTLEAF.weight
            model = freeze_except(
                model, trained_modules=args.finetuning_trained_modules)
            trainer.max_epochs = args.max_epochs_2
            trainer.fit(model, test_dm)
        # `xp` is a dict at this point, so its keys are the experiment
        # parameter names; dir()-based listing would yield dict method names.
        trainer.logger.log_hyperparams({k: getattr(args, k) for k in xp})
        trainer.logger.log_hyperparams(count_parameters(model))
        trainer.logger.log_hyperparams({"test_task": t})

        trainer.test(model, test_dm.test_dataloader())
        evaluated_tasks[t] = 1
