from adapters import *
from easydict import EasyDict as edict
import datasets
import torchelie
import copy
import time
import pandas as pd
import numpy as np
from transformers import AdamW, get_linear_schedule_with_warmup


def count_parameters(m):
    """Count trainable, total and fixed parameters of a model and its parts.

    Four modules are inspected: ``m`` itself, ``m.adapters``,
    ``m.model.auto.embeddings`` and ``m.model.auto``. Their counts are
    stored under the key prefixes ``""``, ``"adapters_"``, ``"emb_"`` and
    ``"encoder_"`` respectively.

    Args:
        m: Object exposing ``parameters()``, ``adapters`` and ``model.auto``.

    Returns:
        dict mapping ``"<prefix>trainable_parameters"``,
        ``"<prefix>total_parameters"`` and ``"<prefix>fixed_parameters"`` to
        element counts.
    """
    prefixed_modules = (
        ("", m),
        ("adapters_", m.adapters),
        ("emb_", m.model.auto.embeddings),
        ("encoder_", m.model.auto),
    )
    counts = {}
    for prefix, module in prefixed_modules:
        trainable = 0
        fixed = 0
        # Single pass: every parameter is either trainable or fixed.
        for param in module.parameters():
            if param.requires_grad:
                trainable += param.numel()
            else:
                fixed += param.numel()
        counts[f"{prefix}trainable_parameters"] = trainable
        counts[f"{prefix}total_parameters"] = trainable + fixed
        counts[f"{prefix}fixed_parameters"] = fixed
    return counts


def get_metric():
    """Return a metric object whose ``compute`` reports macro-F1 and accuracy.

    The returned object is a deep copy of the loaded F1 metric with its
    ``compute`` attribute replaced by a closure that evaluates both the F1
    and the accuracy metric and merges the two result dicts (accuracy keys
    win on collision).
    """
    f1_metric = datasets.load_metric("f1", average="macro")
    accuracy_metric = datasets.load_metric("accuracy")

    def compute(predictions, references):
        scores = {}
        scores.update(
            f1_metric.compute(
                predictions=predictions, references=references, average="macro"
            )
        )
        scores.update(
            accuracy_metric.compute(predictions=predictions, references=references)
        )
        return scores

    combined = copy.deepcopy(f1_metric)
    combined.compute = compute
    return combined


def weight_init(m, std=1e-3, k=1):
    """Initialise ``m.weight`` from a clamped normal and zero ``m.bias``.

    The effective standard deviation is ``std / sqrt(k)``. Weights are drawn
    from ``N(0, std)`` and then clamped in place to ``[-2 * std, 2 * std]``.

    Args:
        m: Module with a ``weight`` tensor and, optionally, a ``bias``.
        std: Base standard deviation before scaling by ``k``.
        k: Scaling factor; the standard deviation is divided by ``sqrt(k)``.
    """
    std = std / (k ** 0.5)
    m.weight.data.normal_(0.0, std).clamp_(-2 * std, 2 * std)
    # A layer built with bias=False still has a ``bias`` attribute, set to
    # None, so hasattr() alone is not enough to know a bias tensor exists.
    if getattr(m, "bias", None) is not None:
        m.bias.data.zero_()


def map_label(label_name, task_name=""):
    """Normalise a task-specific label name to a shared label identifier.

    Resolution order:
      1. Known paraphrase aliases map to ``"equivalent"`` or
         ``"not-equivalent"``.
      2. For the ``"paws"`` task, ``"0"`` / ``"1"`` map to the same two
         names; any other label raises.
      3. Labels containing a digit are prefixed with ``"<task_name>-"``.
      4. The result is lower-cased.

    Args:
        label_name: Raw label; converted with ``str()`` first.
        task_name: Name of the task the label belongs to.

    Returns:
        The normalised label string.

    Raises:
        ValueError: If ``task_name`` is ``"paws"`` and the label is neither
            a known alias nor ``"0"`` / ``"1"``.
    """
    positive, negative = "equivalent", "not-equivalent"
    label = str(label_name)

    alias = {
        "not_duplicate": negative,
        "duplicate": positive,
        "not_equivalent": negative,
        "equivalent": positive,
    }.get(label)
    if alias is not None:
        return alias

    if task_name == "paws":
        paws_labels = {"0": negative, "1": positive}
        if label not in paws_labels:
            raise ValueError("PAWS value error")
        return paws_labels[label]

    # Numeric-looking labels are ambiguous across tasks, so scope them.
    has_digit = any(ch.isdigit() for ch in label)
    if has_digit:
        label = f"{task_name}-{label}"
    return label.lower()


def make_zero_shot(model, t):
    """Re-purpose slot 0 of ``model`` for an unseen task ``t``.

    Steps performed, all in place on ``model``:
      * the shared task index ``model.model.i`` is reset to 0;
      * classifier 0 gets fresh weight/bias sized for the labels of ``t``;
      * ``model.l_num_labels[0]`` is updated to the new label count;
      * row 0 of the task-embedding table is overwritten with the mean of
        all task embeddings;
      * ``model.build_features([t])`` is called.

    Args:
        model: Model exposing ``model.i``, ``classifiers``,
            ``task_labels_name``, ``l_num_labels``, ``T`` and
            ``build_features``.
        t: Name of the target task; must be a key of
            ``model.task_labels_name``.

    Returns:
        The same ``model`` object.
    """
    model.model.i[0] = model.model.i[0] * 0

    clf = model.classifiers[0]
    clf.task_name = t
    clf.labels_name = model.task_labels_name[clf.task_name]
    L0 = torch.nn.Linear(768, len(clf.labels_name)).cuda()
    L0.apply(weight_init)
    clf.weight = L0.weight
    clf.bias = L0.bias
    model.classifiers[0] = clf

    model.l_num_labels[0] = len(clf.labels_name)
    # Writing into a row of a leaf parameter that requires grad raises a
    # RuntimeError under autograd; this is a plain re-initialisation, so do
    # it with gradient tracking disabled.
    with torch.no_grad():
        model.T.T.weight[0] = model.T.weight.mean(axis=0)

    # NOTE(review): build_features rebinds model.F; whether model.T.features
    # is expected to follow that rebinding is not visible here - confirm.
    model.build_features([t])
    return model


def offline_inference(args, H, t):
    """Predict an embedding for task ``t`` from task features via Ridge.

    Task features are read from a pickled table (column ``"all"``). A Ridge
    regressor is fitted from the features of ``args.training_task`` to the
    standardised rows of ``H``; the prediction for ``t`` is mapped back to
    the original scale.

    Args:
        args: Object with a ``training_task`` iterable of task names.
        H: Target matrix; presumably one row per entry of
            ``args.training_task`` - TODO confirm with caller.
        t: Name of the task to predict for.

    Returns:
        Array of shape ``(1, H.shape[1])`` with the predicted embedding.
    """
    from sklearn.linear_model import Ridge
    from sklearn.preprocessing import StandardScaler

    feature_table = pd.read_pickle(
        "/cw/working-arwen/damien/datasets/metaeval/df_features.pck"
    )
    features_by_task = feature_table["all"].to_dict()
    train_features = np.vstack(
        [features_by_task[name] for name in args.training_task]
    )

    scaler = StandardScaler()
    ridge = Ridge()
    standardised_targets = scaler.fit_transform(H)
    ridge.fit(train_features, standardised_targets)

    query = np.array(features_by_task[t]).reshape(1, -1)
    return scaler.inverse_transform(ridge.predict(query))


def average_by_feature(args, H, t):
    """Average the rows of ``H`` belonging to tasks of the same type as ``t``.

    The CSV at ``args.features_path`` must provide ``task`` and
    ``task_type`` columns. Row ``k`` of ``H`` is taken to correspond to
    ``args.training_task[k]``.

    Args:
        args: Object with ``features_path`` and ``training_task``.
        H: Array indexable by a list of row positions.
        t: Task name whose ``task_type`` selects the rows to average.

    Returns:
        Mean over the selected rows of ``H`` (``axis=0``).
    """
    task_types = pd.read_csv(args.features_path).set_index("task")["task_type"]
    # Computed once: the set of tasks sharing t's type does not depend on
    # the training task being examined.
    same_type_tasks = task_types[task_types == task_types[t]].index
    selected_rows = [
        position
        for position, task in enumerate(args.training_task)
        if task in same_type_tasks
    ]
    return H[selected_rows].mean(axis=0)


class TiedLinear(torch.nn.Module):
    """Classifier head that adds shared label embeddings to a base head.

    The output is ``clf(inputs) + inputs @ labels_weight.T`` where row ``k``
    of ``labels_weight`` is the embedding (from ``L``) of the shared label
    that the task's ``k``-th label name maps to via ``map_label``.
    """

    def __init__(
        self,
        W=None,
        clf=None,
        L=None,
        labels_name=None,
        all_labels=None,
        task_name=None,
        device=None,
    ):
        """Store the wrapped head, the label-embedding table and metadata.

        Args:
            W: Unused.
            clf: Base head; its ``weight`` and ``bias`` are re-exposed on
                this module.
            L: Embedding table over shared labels; moved to CUDA.
            labels_name: Ordered label names of this task.
            all_labels: Mapping from normalised label name to row of ``L``.
            task_name: Task name passed to ``map_label``.
            device: Device used for the tensors built in ``forward``.
        """
        super().__init__()
        self.clf = clf
        # Expose the base head's tensors so they can be swapped from outside.
        self.weight = self.clf.weight
        self.bias = self.clf.bias
        self.L = L.cuda()
        self.labels_name = labels_name
        self.all_labels = all_labels
        self.task_name = task_name
        self.device = device

    def forward(self, inputs):
        """Return base-head logits plus the shared-label-embedding logits."""
        # Push possibly re-assigned weight/bias back onto the wrapped head.
        self.clf.weight = self.weight
        self.clf.bias = self.bias

        label_rows = torch.zeros_like(self.clf.weight, device=self.device)
        self.labels_weight = label_rows
        for row, label_name in enumerate(self.labels_name):
            shared_index = self.all_labels[map_label(label_name, self.task_name)]
            shared_index = torch.tensor(shared_index, device=self.device)
            row_index = torch.tensor(row, device=self.device)
            label_rows[row_index] = self.L(shared_index)

        return self.clf(inputs) + F.linear(inputs, label_rows)


class ReturnsBias(torch.nn.Module):
    """Stand-in head that ignores its input and returns only its bias."""

    def __init__(self, weight, bias):
        """Keep ``weight`` and ``bias`` as attributes of this module."""
        super().__init__()
        self.weight = weight
        self.bias = bias

    def forward(self, inputs):
        """Return ``self.bias``; ``inputs`` is not used."""
        del inputs  # intentionally ignored
        return self.bias


class TaskEmbedding(torch.nn.Module):
    """Task embedding lookup with dropout and an optional feature term.

    ``forward(x)`` returns ``dropout(T(x))`` plus, when
    ``use_task_features`` is set, a linear projection of ``features[x]``.
    """

    def __init__(self, T, i, features, config=None):
        """Wrap the embedding table ``T``.

        Args:
            T: Embedding table; its ``weight`` and ``num_embeddings`` are
                re-exposed on this module.
            i: Indexable holder of the task index; element 0 is stored.
            features: Per-task feature matrix, indexed by task id.
            config: Provides ``hidden_dropout_prob`` and
                ``use_task_features``.
        """
        super().__init__()
        self.T = T
        self.i = i[0]
        self.features = features

        feature_dim = self.features.shape[1]
        projection = torch.nn.Linear(feature_dim, T.embedding_dim).cuda()
        projection.apply(weight_init)
        self.L1 = projection

        self.num_embeddings = self.T.num_embeddings
        self.weight = self.T.weight
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.use_task_features = config.use_task_features

    def forward(self, x):
        """Embed task index tensor ``x``."""
        feature_term = self.L1(self.features[x]) if self.use_task_features else 0
        return self.dropout(self.T(x)) + feature_term


class Classifier(torch.nn.Module):
    """Head applying dropout and a linear projection to position 0."""

    def __init__(self, config):
        """Build the head.

        Args:
            config: Provides ``hidden_size``, ``num_labels``,
                ``hidden_dropout_prob`` and ``strip_pooler``.
        """
        super().__init__()
        projection = torch.nn.Linear(config.hidden_size, config.num_labels)
        self.out_proj = projection.cuda()
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.strip_pooler = config.strip_pooler

    def forward(self, features, **kwargs):
        """Select index 0 along dim 1, then apply dropout and ``out_proj``.

        With ``strip_pooler`` set the input is indexed as a rank-3 tensor
        (``features[:, 0, :]``); otherwise as ``features[:, 0]``.
        """
        first = features[:, 0, :] if self.strip_pooler else features[:, 0]
        return self.out_proj(self.dropout(first))


class Transformer(pl.LightningModule):
    def __init__(self, hparams, task_labels_name=None):
        """Build the backbone, task embedding, adapters and per-task heads.

        Args:
            hparams: Hyper-parameter container read both by attribute and
                through ``.items()``; every entry is copied onto the model
                config.
            task_labels_name: Mapping from task name to that task's label
                names; used by ``share_labels``.
        """
        super().__init__()
        self.total_steps = 0
        self.save_hyperparameters()

        self.config = AutoConfig.from_pretrained(
            hparams.model_name_or_path, num_labels=hparams.num_labels
        )
        # One label count per task; falls back to a single task.
        if hparams.l_num_labels:
            self.l_num_labels = hparams.l_num_labels
        else:
            self.l_num_labels = [hparams.num_labels]
        self.task_labels_name = task_labels_name
        self.training_task = hparams.training_task
        hparams.num_tasks = len(self.l_num_labels)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            hparams.model_name_or_path, config=self.config
        )
        # Mirror every hyper-parameter onto the config so that modules built
        # from the config below can read them.
        for (k, v) in hparams.items():
            self.config.__setattr__(k, v)
        self.model.config = self.config

        m_name = self.model.config.model_type
        self.hparams = hparams
        self.model.classifier = Classifier(self.config)

        # Alias the backbone (stored under the model-type attribute) as
        # ``model.auto`` so the rest of the code is architecture-agnostic.
        m_ = getattr(self.model, m_name)
        self.model.auto = m_

        if "strip_pooler" in dir(self.config) and self.config.strip_pooler:
            self.model.auto.pooler = torch.nn.Identity()

        self.task_embedding_size = self.config.task_embedding_size

        self.T = torch.nn.Embedding(
            self.config.num_tasks, self.task_embedding_size
        ).cuda()
        # One-element tensor holding the active task index; the same tensor
        # object is handed to the adapters and embeddings below.
        self.model.i = torch.tensor([0]).cuda()

        self.T.apply(weight_init)
        # NOTE(review): hard-coded absolute path to the task-feature table.
        self.task_features = pd.read_pickle(
            "/cw/working-arwen/damien/datasets/metaeval/df_features.pck"
        )["all"].to_dict()
        self.build_features(self.training_task)
        # Replace the raw embedding table by the wrapper that adds dropout
        # and the optional task-feature projection.
        self.T = TaskEmbedding(
            T=self.T, i=self.model.i, features=self.F, config=self.config
        )

        self.test_task = self.hparams.training_task
        assert self.config.adapter_mode in [
            "none",
            "single",
            "gated",
            "bilinear",
            "random",
        ]

        # Every mode except "none" freezes the model and inserts
        # ``2 * num_adapter_layers`` adapter modules of the matching class.
        if self.config.adapter_mode == "none":
            self.adapters = torch.nn.ModuleList()

        if self.config.adapter_mode == "single":
            self.model = freeze_model(self.model)
            self.adapters = torch.nn.ModuleList(
                [
                    Adapter(self.config).cuda()
                    for _ in range(self.config.num_adapter_layers * 2)
                ]
            )
            self.model = add_adapters(self.model, self.adapters)

        if self.config.adapter_mode == "bilinear":
            self.model = freeze_model(self.model)
            self.adapters = torch.nn.ModuleList(
                [
                    HyperAdapter(self.config, self.T, i=self.model.i).cuda()
                    for _ in range(self.config.num_adapter_layers * 2)
                ]
            )
            self.model = add_adapters(self.model, self.adapters)

        if self.config.adapter_mode == "gated":
            print("gated adapter")
            self.model = freeze_model(self.model)
            self.adapters = torch.nn.ModuleList(
                [
                    GatedAdapter(self.config, self.T, i=self.model.i).cuda()
                    for _ in range(self.config.num_adapter_layers * 2)
                ]
            )
            self.model = add_adapters(self.model, self.adapters)
        if self.config.adapter_mode == "random":
            print("random adapter")
            self.model = freeze_model(self.model)
            self.adapters = torch.nn.ModuleList(
                [
                    RandomAdapter(self.config, self.T, i=self.model.i).cuda()
                    for _ in range(self.config.num_adapter_layers * 2)
                ]
            )
            self.model = add_adapters(self.model, self.adapters)

        self.model = add_embeddings(self.model, T=self.T)
        self.model = add_cln(self.model, T=self.T, i=self.model.i)

        self.metric = get_metric()
        print("l_num_label:", self.l_num_labels)
        # One linear head per task.
        # NOTE(review): input width is hard-coded to 768 here while
        # share_labels uses config.hidden_size - confirm they always agree.
        self.classifiers = torch.nn.ModuleList(
            [torch.nn.Linear(768, n).cuda() for n in self.l_num_labels]
        )
        if self.config.share_labels:
            self.share_labels(self.config)

        self.parameters_count = count_parameters(self)
        self.log_dict(self.parameters_count)
        print(self.parameters_count)
        if hparams.l_num_labels:
            self.select_classifier(0)

    def build_features(self, tasks):
        """Stack the stored features of ``tasks`` into ``self.F``.

        ``self.F`` becomes a detached float CUDA tensor with one row per
        entry of ``tasks``, looked up in ``self.task_features``.
        """
        rows = [self.task_features[task] for task in tasks]
        feature_matrix = torch.from_numpy(np.array(rows)).float()
        self.F = feature_matrix.cuda().detach()

    def save_embedding(self):
        """Pickle the learned task embeddings, one row per training task.

        Nothing is written when the adapter mode is ``"single"``, when
        ``config.version`` is ``-1``, or when ``training_task`` is not a
        list. The output frame has the columns ``name``, ``i``, ``H`` and
        ``args``; the file name embeds the embedding size and the current
        Unix time in seconds.
        """
        if (
            self.config.adapter_mode == "single"
            or self.config.version == -1
            # The original test, type(self.training_task == list), was
            # always truthy and therefore never filtered anything.
            or not isinstance(self.training_task, list)
        ):
            return

        embeddings = self.T.T.weight.cpu().detach().numpy()
        dfp = pd.DataFrame([[row] for row in embeddings], columns=["H"])
        dfp["name"] = self.training_task
        dfp["args"] = [str(self.hparams.hparams)] * len(dfp)
        dfp["i"] = range(len(dfp))
        t = str(time.time()).split(".")[0]
        tes = self.config.task_embedding_size
        print(f"saving embedding {t}")
        dfp[["name", "i", "H", "args"]].to_pickle(
            f"/cw/working-arwen/damien/datasets/metaeval/dfp-{tes}-{t}.pickle"
        )

    def share_labels(self, config):
        """Tie every task head to a shared label-embedding table.

        All label names of all tasks are normalised with ``map_label`` and
        given one row each in a new embedding table ``self.L``. Every
        classifier of ``self.training_task`` is then wrapped in a
        ``TiedLinear``; with ``config.share_labels == "hard"`` the base head
        is first replaced by ``ReturnsBias`` so only its bias contributes.

        Args:
            config: Provides ``share_labels``.
        """
        print("sharing labels")
        mapped = [
            map_label(label_name, task_name=task_name)
            for (task_name, labels_name) in self.task_labels_name.items()
            for label_name in labels_name
        ]
        # Sort before numbering: iterating a set of strings gives a
        # different order from one process to the next, which would make
        # the rows of self.L mean different labels after a reload.
        all_labels = {
            label: index for index, label in enumerate(sorted(set(mapped)))
        }

        self.L = torch.nn.Embedding(len(all_labels), self.config.hidden_size).cuda()
        self.L.apply(weight_init)
        for i, task_name in enumerate(self.training_task):
            labels_name = self.task_labels_name[task_name]
            clf = self.classifiers[i]
            if config.share_labels == "hard":
                clf = ReturnsBias(weight=clf.weight, bias=clf.bias)
            self.classifiers[i] = TiedLinear(
                L=self.L,
                clf=clf,
                device=torch.device("cuda:0"),
                labels_name=labels_name,
                all_labels=all_labels,
                task_name=task_name,
            )

    def select_classifier(self, di):
        """Activate the head and label count of task index ``di``."""
        head = self.classifiers[di]
        self.model.classifier.out_proj = head
        label_count = self.l_num_labels[di]
        self.model.num_labels = label_count

    def forward(self, **inputs):
        """Run the backbone with the head of the batch's task.

        If ``dataset_index`` is supplied, its first element becomes the
        active task index (stored in ``self.model.i``); otherwise the
        currently stored index is used. Labels equal to ``-1`` are replaced
        by ``0`` *in place* on the tensor the caller passed in, so the
        caller sees the modified labels afterwards. ``labels`` may be
        omitted; keyword arguments whose value is ``None`` are dropped
        before the call to the wrapped model.

        Returns:
            Whatever the wrapped model returns for the remaining inputs.
        """
        if "dataset_index" in inputs:
            di = inputs.pop("dataset_index")[0].cpu().item()
            self.model.i[0] = di
        else:
            di = self.model.i[0].cpu().item()

        labels = inputs.get("labels")
        if labels is not None:
            # In-place on purpose: callers read batch["labels"] afterwards
            # and rely on the -1 entries having been rewritten.
            labels[labels == -1] = 0

        self.select_classifier(di)
        self.log("i", self.model.i.item())
        inputs = {k: v.cuda() for (k, v) in inputs.items() if v is not None}
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        """Compute and log the (optionally capped) loss for one batch.

        When ``hparams.loss_cap`` is set and the loss exceeds it, the loss
        is rescaled so its value equals the cap. If the embedding table has
        at least two rows, a ``stability`` value is also logged: the mean
        norm of the difference between embedding ``i`` and embedding
        ``i + n`` (``n`` = half the table size), divided by the mean
        per-dimension variance of the table.

        Returns:
            ``{"loss": loss}``.
        """
        loss = self(**batch)[0]

        cap = self.hparams.loss_cap
        if cap:
            loss_value = loss.detach().clone()
            if loss_value > cap:
                loss = loss * cap / loss_value
        self.log("loss", loss)
        self.log(f"loss_{self.model.i.item()}", loss)

        half = self.T.num_embeddings // 2
        if half:
            distance_sum = 0
            for first in range(half):
                left = self.T(torch.tensor(first, device=self.device))
                right = self.T(torch.tensor(first + half, device=self.device))
                distance_sum = distance_sum + (left - right).norm()
            stability = distance_sum / half
            stability = stability.detach() / self.T.weight.var(dim=0).mean()
            self.log("stability", stability.detach())

        return {"loss": loss}

    def validation_step(self, batch, batch_idx, dataloader_idx=0, stage="val"):
        """Evaluate one batch and return its loss, predictions and labels.

        Predictions are the arg-max over the logits when the active head has
        more than one output, and the logits with their last dimension
        removed when it has exactly one.

        Returns:
            dict with ``loss``, ``preds``, ``labels`` and a
            ``"<stage>_loss_<task index>"`` entry holding the same loss.
        """
        outputs = self(**batch)
        val_loss, logits = outputs[:2]

        # self.model.num_labels is set by select_classifier() during the
        # forward pass above, so it describes the head that produced these
        # logits. The earlier ``>= 1`` / ``== 1`` chain could never reach
        # its second branch.
        if self.model.num_labels > 1:
            preds = torch.argmax(logits, axis=1)
        else:
            # squeeze(-1) keeps the batch dimension even for batch size 1.
            preds = logits.squeeze(-1)

        labels = batch["labels"]

        return {
            "loss": val_loss,
            "preds": preds,
            "labels": labels,
            f"{stage}_loss_{self.model.i.item()}": val_loss,
        }

    def validation_epoch_end(self, outputs, stage="val"):
        """Aggregate step outputs, log loss and metrics, return mean loss.

        Metrics are logged twice: once as ``"<stage>_<metric>"`` and once
        with a suffix that is the test task itself, or ``"L_<count>"`` when
        the test task is a list.
        """
        preds = torch.cat([step["preds"] for step in outputs])
        preds = preds.detach().cpu().numpy()
        labels = torch.cat([step["labels"] for step in outputs])
        labels = labels.detach().cpu().numpy()
        loss = torch.stack([step["loss"] for step in outputs]).mean()

        for loss_key in (f"{stage}_loss", f"{stage}_loss_"):
            self.log(loss_key, loss, prog_bar=True)
        print("1", self.model.i)
        print("2", self.model.i.item())
        self.log(f"{stage}_loss_{self.model.i.item()}", loss)

        metrics = self.metric.compute(predictions=preds, references=labels)
        self.log_dict(
            {f"{stage}_{name}": value for (name, value) in metrics.items()},
            prog_bar=True,
        )

        # Exact type check kept as-is: only a plain list gets the L_<n> tag.
        if type(self.test_task) is list:
            suffix = f"L_{len(self.test_task)}"
        else:
            suffix = self.test_task
        suffixed = {
            f"{stage}_{name}_{suffix}": value for (name, value) in metrics.items()
        }
        print("LOGDICT", suffixed)
        self.log_dict(suffixed, prog_bar=True)
        return loss

    def training_epoch_end(self, outputs, stage="training"):
        """Log the mean step loss; save embeddings after the last epoch."""
        step_losses = [step["loss"] for step in outputs]
        self.log(f"{stage}_loss", torch.stack(step_losses).mean(), prog_bar=True)

        is_last_epoch = self.current_epoch == self.config.max_epochs - 1
        if is_last_epoch:
            self.save_embedding()

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Run ``validation_step`` with ``stage="test"``."""
        # Forward the caller's dataloader index instead of a literal 0.
        return self.validation_step(
            batch, batch_idx, dataloader_idx=dataloader_idx, stage="test"
        )

    def test_epoch_end(self, outputs):
        """Aggregate test outputs via ``validation_epoch_end``."""
        epoch_loss = self.validation_epoch_end(outputs, stage="test")
        return epoch_loss

    def setup(self, stage):
        """Compute ``self.total_steps`` when entering the ``"fit"`` stage.

        The value is the dataset size floor-divided by the effective batch
        size (``train_batch_size * max(1, gpus)``), floor-divided by
        ``accumulate_grad_batches``, times ``max_epochs`` as a float.
        """
        if stage != "fit":
            return

        # train_dataloader() is normally called after setup(); call it here
        # only to learn the dataset size.
        train_loader = self.train_dataloader()

        dataset_size = len(train_loader.dataset)
        effective_batch = self.hparams.train_batch_size * max(1, self.hparams.gpus)
        batches_per_epoch = dataset_size // effective_batch
        updates_per_epoch = batches_per_epoch // self.hparams.accumulate_grad_batches
        self.total_steps = updates_per_epoch * float(self.hparams.max_epochs)

    def configure_optimizers(self):
        """Create AdamW with a linear warm-up / decay schedule.

        Parameters whose name contains ``"bias"`` or ``"LayerNorm.weight"``
        get no weight decay; all others use ``hparams.weight_decay``. The
        scheduler is stepped every optimisation step.

        Returns:
            ``([optimizer], [scheduler_config])``.
        """
        no_decay_markers = ("bias", "LayerNorm.weight")
        decayed, undecayed = [], []
        for name, param in self.named_parameters():
            if any(marker in name for marker in no_decay_markers):
                undecayed.append(param)
            else:
                decayed.append(param)

        optimizer = AdamW(
            [
                {"params": decayed, "weight_decay": self.hparams.weight_decay},
                {"params": undecayed, "weight_decay": 0.0},
            ],
            lr=self.hparams.learning_rate,
            eps=self.hparams.adam_epsilon,
        )

        warmup_steps = self.hparams.warmup_prop * self.total_steps
        lr_schedule = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=self.total_steps,
        )
        return [optimizer], [
            {"scheduler": lr_schedule, "interval": "step", "frequency": 1}
        ]


def freeze_except(model, trained_modules=None):
    """Freeze ``model`` except for the requested parameter groups.

    With no groups, or with ``"all"`` among them, the model is returned
    untouched. Otherwise every parameter is frozen, after which these are
    re-enabled:

      * always: parameters whose lower-cased, underscore-free name contains
        ``"layernorm"`` or ``"cls"``;
      * ``"clf_bias"`` / ``"clf_weight"``: bias / weight of each entry of
        ``model.classifiers``;
      * ``"cls_weight"``: all parameters of ``model.L``;
      * ``"task_embedding"``: all parameters of ``model.T.T``.

    Args:
        model: Module exposing ``classifiers`` and, for the matching
            groups, ``L`` and ``T.T``.
        trained_modules: Iterable of group names, or ``None``.

    Returns:
        The same ``model`` object.

    Raises:
        AssertionError: If an unknown group name is given.
    """
    allowed = ("all", "clf_bias", "clf_weight", "cls_weight", "task_embedding")
    trained_modules = trained_modules or ()

    # "cls_weight" is accepted here because it is handled below; it used to
    # be rejected, which made its branch unreachable.
    assert all(x in allowed for x in trained_modules)

    if not trained_modules or "all" in trained_modules:
        return model

    for param in model.parameters():
        param.requires_grad = False

    for n, p in model.named_parameters():
        n = n.lower().replace("_", "")
        if "layernorm" in n or "cls" in n:
            p.requires_grad = True
    for c in model.classifiers:
        if "clf_bias" in trained_modules:
            c.bias.requires_grad = True
        if "clf_weight" in trained_modules:
            c.weight.requires_grad = True
    # Assigning ``module.requires_grad = True`` only creates a plain
    # attribute on the module; requires_grad_() is what flips the flag on
    # the module's parameters.
    if "cls_weight" in trained_modules:
        model.L.requires_grad_(True)
    if "task_embedding" in trained_modules:
        model.T.T.requires_grad_(True)
    return model
