import copy
import datasets
import functools
import glob
import itertools
import numpy as np
import os
import pandas as pd
import pytorch_lightning as pl
import sqlitedict
import torch
import torch.nn.functional as F
import torch_optimizer as optim
import torchelie
from argparse import ArgumentParser
from datetime import datetime
from itertools import chain
from pytorch_lightning.loggers import NeptuneLogger, CometLogger
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm as tqdm
from transformers import (
    AutoModelForSequenceClassification,
    AutoConfig,
    AutoTokenizer,
)
from scipy.linalg import qr
from sklearn.utils.extmath import randomized_svd

from typing import Optional

# Module-level alias for the built-in ``print``.
k_printer = print
# Print the installed pytorch_lightning version at import time; the trailing
# note records the version string 1.0.4.
print(pl.__version__)  # 1.0.4


def flatten(el):
    """Return ``el`` followed by all of its descendants, depth first.

    ``el`` must expose a ``children()`` method yielding its direct children
    (``torch.nn.Module`` does). The result is a flat list in pre-order:
    the element itself, then the flattened subtree of each child in turn.
    """
    collected = [el]
    for child in el.children():
        collected.extend(flatten(child))
    return collected


def count_parameters(model):
    """Return the total number of scalar entries in ``model``'s trainable parameters.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def weight_init(m, std=1e-3, k=1):
    """Re-initialise ``m.weight`` with small clipped Gaussian noise and zero the bias.

    Args:
        m: Object exposing a ``weight`` tensor and, optionally, a ``bias``
            tensor (for example ``torch.nn.Linear``).
        std: Base standard deviation of the noise.
        k: Scaling factor; the effective standard deviation is
            ``std / sqrt(k)``.

    The weights are drawn from ``N(0, std_eff)`` and then clamped to
    ``[-2 * std_eff, 2 * std_eff]``. A missing or ``None`` bias (e.g.
    ``Linear(..., bias=False)``) is skipped instead of raising.
    """
    std = std / (k ** 0.5)
    m.weight.data.normal_(0.0, std).clamp_(-2 * std, 2 * std)
    bias = getattr(m, "bias", None)
    if bias is not None:
        bias.data.zero_()


class Bilinear(torch.nn.Module):
    """Linear map whose weight matrix is generated from a context vector.

    The layer stores a 3-D tensor ``weight`` of shape
    ``(output_size, input_size, context_size)`` and a ``bias`` of shape
    ``(output_size,)``. Both are registered as ``torch.nn.Parameter`` so
    they show up in ``parameters()`` / ``state_dict()``, receive gradients
    as leaf tensors and follow ``.to()`` / ``.cuda()`` calls on the module.

    Args:
        context_size: Length of the context vector ``t`` passed to
            ``forward``.
        input_size: Size of the last dimension of ``inputs``.
        output_size: Size of the last dimension of the result.
    """

    def __init__(self, context_size, input_size, output_size):
        super().__init__()
        # Keep the original placement (CUDA) when a GPU exists, but fall
        # back to CPU instead of crashing on CPU-only hosts.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Same draw as before, with the axes reversed explicitly; ``.T`` on
        # a 3-D tensor is deprecated.
        initial = torch.randn(context_size, input_size, output_size)
        initial = initial.permute(2, 1, 0).contiguous()
        self.weight = torch.nn.Parameter(initial.to(device))
        self.bias = torch.nn.Parameter(torch.zeros(output_size, device=device))

    def forward(self, t, inputs):
        """Apply the context-dependent linear map.

        Args:
            t: Context vector; a 1-D tensor of length ``context_size`` is
                what the shapes below assume.
            inputs: Tensor whose last dimension is ``input_size``.

        Returns:
            ``inputs @ W(t) + bias`` where ``W(t)`` has shape
            ``(input_size, output_size)``.
        """
        # (output, input, context) @ (context,) -> (output, input)
        generated = self.weight @ t
        return inputs @ generated.T + self.bias


class ResidualEmbedding(torch.nn.Module):
    """Add a second, residual embedding on top of a base embedding.

    Args:
        e1: Base embedding module, called with the raw ids.
        e2: Residual embedding module.
        trained_tokens: Either the string ``"all"`` (``e2`` is indexed with
            the raw ids) or an iterable of token ids. In the latter case
            the ``j``-th listed token (0-based) is looked up in row
            ``j + 1`` of ``e2`` and every other id is looked up in row 0.
            ``None`` is treated as an empty list.
    """

    def __init__(self, e1, e2, trained_tokens=None):
        super().__init__()
        self.e1 = e1
        self.e2 = e2
        self.trained_tokens = trained_tokens

    def forward(self, x):
        """Return ``e1(x) + e2(remapped ids)``; ``x`` itself is not modified."""
        base = self.e1(x)
        tokens = self.trained_tokens
        if isinstance(tokens, str) and tokens == "all":
            return base + self.e2(x)

        # Build the residual ids from the ORIGINAL ids. Rewriting ids in
        # place one token at a time lets an already-remapped position be
        # matched again by a later token whose id equals an earlier row
        # number, and leaves small untrained ids pointing at trained rows.
        residual_ids = torch.zeros_like(x)
        listed = () if tokens is None else tokens
        for row, token in enumerate(listed, start=1):
            residual_ids[x == token] = row
        return base + self.e2(residual_ids)


class CLSEmbedding(torch.nn.Module):
    """Learnable offset added to the first position of an embedded sequence.

    Args:
        config: Object providing ``embedding_size`` (used when present in
            ``config.__dict__``) or otherwise ``hidden_size``; this is the
            width of the learned offset vector.
    """

    def __init__(self, config):
        super().__init__()
        if "embedding_size" in config.__dict__:
            embedding_size = config.embedding_size
        else:
            embedding_size = config.hidden_size
        # Keep the original CUDA placement when a GPU exists, but do not
        # crash on CPU-only hosts.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.cls = torch.nn.Embedding(1, embedding_size).to(device)

    def forward(self, x):
        """Add the learned vector to ``x[:, 0, :]`` in place and return ``x``.

        ``x`` is indexed as a 3-D tensor; its last dimension must match the
        width of the learned vector.
        """
        # The table has a single row, so read it directly instead of
        # building a fresh index tensor on the GPU for every call.
        x[:, 0, :] = x[:, 0, :] + self.cls.weight[0]
        return x


class HyperEmbedding(torch.nn.Module):
    """Base embedding plus per-token offsets generated from a task embedding.

    Args:
        config: Provides ``hidden_dropout_prob``, ``task_embedding_size``,
            a ``task_embedding_activation`` factory, and ``embedding_size``
            (when present in ``config.__dict__``) or ``hidden_size`` for
            the output width.
        T: Callable invoked as ``T(i[0])`` to obtain the task embedding
            (presumably an embedding table; confirm with the caller).
        e1: Base embedding module.
        trained_tokens: Iterable of token ids; one linear projection is
            created per listed token.
        i: Indexable whose first element is stored and later passed to
            ``T``.
    """

    def __init__(self, config, T, e1, trained_tokens=None, i=None):
        super().__init__()
        if "embedding_size" in config.__dict__:
            width = config.embedding_size
        else:
            width = config.hidden_size
        self.trained_tokens = trained_tokens
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.e1 = e1
        self.T = T
        self.i = i[0]
        self.task_embedding_activation = config.task_embedding_activation()
        # One task-embedding -> token-offset projection per trained token.
        heads = [
            torch.nn.Linear(config.task_embedding_size, width)
            for _ in trained_tokens
        ]
        self.projections = torch.nn.ModuleList(heads)
        for head in self.projections:
            head.apply(weight_init)

    def forward(self, x):
        """Return ``e1(x)`` plus the generated offset at each trained token."""
        base = self.e1(x)
        offsets = torch.zeros_like(base, requires_grad=False)
        task = self.task_embedding_activation(self.T(self.i))
        for idx, token in enumerate(self.trained_tokens):
            # Dropout is applied separately for every token's projection.
            offsets[x == token] = self.projections[idx](self.dropout(task))
        return base + offsets


class Adapter(torch.nn.Module):
    """Residual bottleneck: ``inputs + L2(activation(L1(inputs)))``.

    Args:
        config: Provides ``hidden_size``, ``adapter_size`` and an
            ``adapter_activation`` factory.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, config, **kwargs):
        super().__init__()
        self.activation = config.adapter_activation()
        hidden, bottleneck = config.hidden_size, config.adapter_size
        self.L1 = torch.nn.Linear(hidden, bottleneck)
        self.L2 = torch.nn.Linear(bottleneck, hidden)
        # Both projections start from the small clipped-normal init.
        self.L1.apply(weight_init)
        self.L2.apply(weight_init)

    def forward(self, inputs):
        """Project down, apply the activation, project up, add the skip."""
        squeezed = self.activation(self.L1(inputs))
        return inputs + self.L2(squeezed)


class HyperAdapter(torch.nn.Module):
    """Residual bottleneck adapter whose two projections depend on a task embedding.

    Args:
        config: Provides ``task_embedding_size``, ``hidden_size``,
            ``adapter_size``, ``hidden_dropout_prob`` and the
            ``task_embedding_activation`` / ``adapter_activation``
            factories.
        T: Callable invoked as ``T(i[0])`` to obtain the task embedding
            (presumably an embedding table; confirm with the caller).
        i: Indexable whose first element is stored and later passed to
            ``T``.
    """

    def __init__(self, config, T, i=None):
        super().__init__()
        self.task_embedding_activation = config.task_embedding_activation()
        self.T = T
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.i = i[0]
        self.activation = config.adapter_activation()
        # No explicit ``.cuda()`` here: ``Bilinear`` decides where its own
        # tensors live, so forcing CUDA at this level added nothing.
        self.L1 = Bilinear(
            config.task_embedding_size, config.hidden_size, config.adapter_size
        )
        self.L2 = Bilinear(
            config.task_embedding_size, config.adapter_size, config.hidden_size
        )
        for layer in (self.L1, self.L2):
            layer.apply(
                lambda module: weight_init(module, k=config.task_embedding_size)
            )

    def forward(self, inputs):
        """Return ``inputs + L2(t, activation(L1(t, inputs)))``.

        ``t`` is the activated task embedding after dropout.
        """
        task = self.T(self.i)
        task = self.task_embedding_activation(task)
        task = self.dropout(task)
        squeezed = self.L1(task, inputs)
        expanded = self.L2(task, self.activation(squeezed))
        return expanded + inputs


class ConditionalLayerNorm(torch.nn.Module):
    """Layer norm followed by a task-dependent scale and shift.

    Args:
        config: Provides ``task_embedding_size``, ``hidden_size`` and a
            ``task_embedding_activation`` factory.
        T: Callable invoked as ``T(i[0])`` to obtain the task embedding
            (presumably an embedding table; confirm with the caller).
        i: Indexable whose first element is stored and later passed to
            ``T``.
        base_LN: Existing normalisation module; it is passed through
            ``torchelie.utils.freeze`` (presumably disabling gradients on
            its parameters; confirm against torchelie) and reused.
    """

    def __init__(self, config, T, i=None, base_LN=None):
        super().__init__()
        self.task_embedding_activation = config.task_embedding_activation()
        self.T = T
        self.i = i[0]
        self.LN = torchelie.utils.freeze(base_LN)
        task_dim, hidden = config.task_embedding_size, config.hidden_size
        self.L1 = torch.nn.Linear(task_dim, hidden)
        self.L2 = torch.nn.Linear(task_dim, hidden)
        for head in (self.L1, self.L2):
            head.apply(lambda module: weight_init(module, k=1))

    def forward(self, inputs):
        """Return ``LN(inputs) * (L1(t) + 1) + L2(t)`` for the activated task embedding ``t``."""
        task = self.task_embedding_activation(self.T(self.i))
        scale = self.L1(task) + 1
        shift = self.L2(task)
        return self.LN(inputs) * scale + shift


class GatedAdapter(Adapter):
    """Adapter whose bottleneck activations are gated by the task embedding.

    Args:
        config: As for ``Adapter``, plus ``task_embedding_size`` and a
            ``task_embedding_activation`` factory.
        T: Callable invoked as ``T(i[0])`` to obtain the task embedding
            (presumably an embedding table; confirm with the caller).
        i: Indexable whose first element is stored and later passed to
            ``T``.
    """

    def __init__(self, config, T, i=None):
        super().__init__(config)
        self.task_embedding_activation = config.task_embedding_activation()
        self.T = T
        self.gate = torch.nn.Linear(config.task_embedding_size, config.adapter_size)
        self.i = i[0]
        self.gate_activation = torch.nn.Sigmoid()
        # L1 and L2 were already initialised by the base class; they are
        # initialised again here together with the gate.
        for layer in (self.L1, self.L2, self.gate):
            layer.apply(lambda module: weight_init(module, k=1))

    def forward(self, inputs):
        """Return ``inputs + L2(sigmoid(gate(t)) * activation(L1(inputs)))``."""
        task = self.task_embedding_activation(self.T(self.i))
        gate = self.gate_activation(self.gate(task))
        squeezed = self.activation(self.L1(inputs))
        return inputs + self.L2(gate * squeezed)


class RandomAdapter(Adapter):
    """Residual low-rank update built from fixed random factors scaled by the task embedding.

    A random ``hidden_size x hidden_size`` matrix is factorised with
    ``randomized_svd`` into ``task_embedding_size`` components. The left
    factor (multiplied by the singular values and ``1e-3``) and the right
    factor are stored as plain CUDA tensors, not as parameters. The
    ``L1`` / ``L2`` layers created by the base class are not used in
    ``forward``.

    Args:
        config: As for ``Adapter``, plus ``task_embedding_size``, ``seed``
            and a ``task_embedding_activation`` factory.
        T: Callable invoked as ``T(i[0])`` to obtain the task embedding
            (presumably an embedding table; confirm with the caller).
        i: Indexable whose first element is stored and later passed to
            ``T``.
    """

    def __init__(self, config, T, i=None):
        super().__init__(config)
        self.task_embedding_activation = config.task_embedding_activation()
        self.T = T
        self.i = i[0]
        side = config.hidden_size
        dense = np.random.random((side, side))
        left, singular, right = randomized_svd(
            dense,
            n_components=config.task_embedding_size,
            n_iter=5,
            random_state=config.seed,
        )
        scaled_left = left * singular * 1e-3
        self.U = torch.from_numpy(scaled_left).float().cuda().detach()
        self.VT = torch.from_numpy(right).float().cuda().detach()

    def forward(self, inputs):
        """Return ``inputs + inputs @ ((U * (t + 1)) @ VT)`` for the activated task embedding ``t``."""
        task = self.task_embedding_activation(self.T(self.i)) + 1
        low_rank = (self.U * task) @ self.VT
        return inputs + inputs @ low_rank


def freeze_model(m):
    """Freeze ``m`` except for its classifier and layer-norm / "cls" parameters.

    The trainable-parameter count is printed before and after. Every
    parameter is first frozen; then the parameters of ``m.classifier`` and
    any parameter whose lower-cased, underscore-stripped name contains
    ``"layernorm"`` or ``"cls"`` are made trainable again.

    Returns:
        The same model object, modified in place.
    """
    print(count_parameters(m))

    for param in m.parameters():
        param.requires_grad = False

    for param in m.classifier.parameters():
        param.requires_grad = True

    for name, param in m.named_parameters():
        key = name.lower().replace("_", "")
        if any(tag in key for tag in ("layernorm", "cls")):
            param.requires_grad = True

    print(count_parameters(m))
    return m


def add_adapters(m, adapters, T=None):
    """Splice adapter modules into the encoder of ``m`` in place.

    Args:
        m: Model exposing ``config.model_type`` and a backbone attribute of
            that name.
        adapters: Sequence of adapter modules.
        T: Unused; kept for signature compatibility.

    Behaviour by model type:

    * ``"albert"`` in the model type: in the first layer of the first
      layer group, ``full_layer_layer_norm`` becomes
      ``Sequential(adapters[0], original)`` and ``attention.LayerNorm``
      becomes ``Sequential(adapters[1], original)``.
    * otherwise: for the last ``len(adapters) // 2`` encoder layers,
      ``attention.output.dense`` becomes
      ``Sequential(original, adapters[2 * j])`` and ``output.dense``
      becomes ``Sequential(original, adapters[2 * j + 1])``, where ``j``
      counts the wrapped layers from the first wrapped one.

    Returns:
        ``m`` itself.
    """
    m_name = m.config.model_type

    if "albert" in m_name:
        shared_layer = m.albert.encoder.albert_layer_groups[0].albert_layers[0]
        ffn_norm = shared_layer.full_layer_layer_norm
        attn_norm = shared_layer.attention.LayerNorm
        shared_layer.full_layer_layer_norm = torch.nn.Sequential(
            adapters[0], ffn_norm
        )
        # The original referenced the undefined name ``dapters`` here.
        shared_layer.attention.LayerNorm = torch.nn.Sequential(
            adapters[1], attn_norm
        )
        return m

    num_adapter_layers = len(adapters) // 2
    for j in range(num_adapter_layers):
        # Negative index: walk the last ``num_adapter_layers`` layers in order.
        layer = getattr(m, m_name).encoder.layer[j - num_adapter_layers]
        attn_output = layer.attention.output
        attn_output.dense = torch.nn.Sequential(attn_output.dense, adapters[2 * j])
        ffn_output = layer.output
        ffn_output.dense = torch.nn.Sequential(
            ffn_output.dense, adapters[2 * j + 1]
        )

    return m


def add_cln(m, T=None, i=None):
    """Swap every encoder layer norm for a ``ConditionalLayerNorm`` when enabled.

    Nothing happens unless ``m.config.ln_mode == "cln"``. Otherwise, in
    each layer of ``<backbone>.encoder.layer``, both
    ``attention.output.LayerNorm`` and ``output.LayerNorm`` are replaced
    by a ``ConditionalLayerNorm`` that reuses the existing norm as its
    base.

    Args:
        m: Model exposing ``config`` (``ln_mode``, ``model_type``), ``i``
            and a backbone attribute named after ``config.model_type``.
        T: Forwarded to ``ConditionalLayerNorm``.
        i: Accepted but not used; the task index is taken from ``m.i``.

    Returns:
        ``m`` itself.
    """
    if m.config.ln_mode != "cln":
        return m
    print("use cln")
    m_name = m.config.model_type
    backbone = getattr(m, m_name)
    for layer in backbone.encoder.layer:
        attn_out = layer.attention.output
        attn_out.LayerNorm = ConditionalLayerNorm(
            m.config, T, m.i, base_LN=attn_out.LayerNorm
        )
        ffn_out = layer.output
        ffn_out.LayerNorm = ConditionalLayerNorm(
            m.config, T, m.i, base_LN=ffn_out.LayerNorm
        )
    setattr(m, m_name, backbone)
    return m


def add_embeddings(m, T=None):
    """Configure the embedding layer of ``m`` according to ``m.config.embedding_mode``.

    Args:
        m: Model exposing ``config`` and a backbone attribute named after
            ``config.model_type`` with ``embeddings.word_embeddings`` and
            ``embeddings.token_type_embeddings``. The ``"hyper"`` mode
            also reads ``m.i``.
        T: Forwarded to ``HyperEmbedding`` in ``"hyper"`` mode.

    Modes:

    * ``"freeze"``: word and token-type embedding weights stop requiring
      gradients.
    * ``"cls"``: the word embedding is followed by a ``CLSEmbedding``.
    * ``"special"``: the word embedding is wrapped in a
      ``ResidualEmbedding`` with a zero-initialised table of
      ``len(trained_tokens) + 1`` rows; token-type weights become
      trainable. The table width follows the wrapped word embedding's
      ``embedding_dim`` (falling back to ``config.hidden_size``), so the
      sum inside ``ResidualEmbedding`` is shape-compatible even when the
      embedding width differs from the hidden size.
    * ``"hyper"``: word and token-type embeddings are wrapped in
      ``HyperEmbedding`` modules.
    * ``"all"``: word and token-type embedding weights require gradients.

    Returns:
        ``m`` itself.

    Raises:
        AssertionError: If the mode is not one of the five above.
    """
    mode = m.config.embedding_mode
    assert mode in {"freeze", "cls", "special", "all", "hyper"}
    m_name = m.config.model_type
    m_ = getattr(m, m_name)
    embeddings = m_.embeddings
    e1 = embeddings.word_embeddings

    if mode == "freeze":
        embeddings.token_type_embeddings.weight.requires_grad = False
        embeddings.word_embeddings.weight.requires_grad = False

    elif mode == "cls":
        embeddings.word_embeddings = torch.nn.Sequential(e1, CLSEmbedding(m.config))

    elif mode == "special":
        print("fine-tune special", m.config.trained_tokens[:20])
        # Match the residual table to the base embedding's output width.
        width = getattr(e1, "embedding_dim", m.config.hidden_size)
        e2 = torch.nn.Embedding(len(m.config.trained_tokens) + 1, width)
        embeddings.token_type_embeddings.weight.requires_grad = True
        e2.weight.data.fill_(0)
        e2.weight.requires_grad = True
        embeddings.word_embeddings = ResidualEmbedding(
            e1, e2, trained_tokens=m.config.trained_tokens
        )

    elif mode == "hyper":
        print("using hyperembedding", m.config.trained_tokens[:20])
        embeddings.word_embeddings = HyperEmbedding(
            config=m.config, T=T, e1=e1, trained_tokens=m.config.trained_tokens, i=m.i
        )
        embeddings.token_type_embeddings = HyperEmbedding(
            config=m.config,
            T=T,
            e1=embeddings.token_type_embeddings,
            trained_tokens=[0, 1],
            i=m.i,
        )

    elif mode == "all":
        embeddings.token_type_embeddings.weight.requires_grad = True
        embeddings.word_embeddings.weight.requires_grad = True

    setattr(m, m_name, m_)
    return m
