import os
import re
import random
import numpy as np
import torch
import datetime
import evaluate
from transformers import (
    PretrainedConfig,
    BertTokenizer,
    BertConfig,
    DataCollatorWithPadding
)
from transformers.models.bert.modeling_bert import (
    BertModel,
    BertForSequenceClassification,
    BertForQuestionAnswering,
    BertOutput,
    BertSelfAttention,
    BertSelfOutput,
)
from models.modeling_bert import (
    ReduBertForSequenceClassification,
    ReduBertForQuestionAnswering,
    ReduBertLayer,
    ReduBertSelfOutput,
    ReduBertOutput,
)
from collections import defaultdict
from transformers.utils import logging
from datasets import Dataset, DatasetDict, load_from_disk, load_metric
from typing import Optional, Dict, List, Callable, Any, Union
from trainer.bert_glue_trainer import BertGlueTrainer
from trainer.bert_squad_trainer import BertSquadTrainer
from loguru import logger


class BertUtils:
    """Data-preparation and compression helpers for turning BERT into ReduBert.

    ``collect`` gathers activation statistics from a dense BERT model through
    forward hooks and turns them into projection matrices / channel scores;
    ``load_model_params`` folds those into the dense weights to initialise a
    ``ReduBert*`` model.
    """

    # Collator used to pad tokenized examples into batches.
    DataCollator = DataCollatorWithPadding

    # GLUE task name -> (first text column, second text column or None for
    # single-sentence tasks). The "gelu" spelling (for "glue") is kept because
    # the attribute name is part of the class's public surface.
    gelu_task_to_keys = {
        "cola": ("sentence", None),
        "mnli": ("premise", "hypothesis"),
        "mrpc": ("sentence1", "sentence2"),
        "qnli": ("question", "sentence"),
        "qqp": ("question1", "question2"),
        "rte": ("sentence1", "sentence2"),
        "sst2": ("sentence", None),
        "stsb": ("sentence1", "sentence2"),
        "wnli": ("sentence1", "sentence2"),
    }

    # BERT forward() keyword names plus ``label``. Not referenced inside this
    # class; presumably used by callers to select dataset columns — confirm.
    columns = [
        'input_ids',
        'attention_mask',
        'token_type_ids',
        'position_ids',
        'head_mask',
        'inputs_embeds',
        'output_attentions',
        'output_hidden_states',
        'return_dict',
        'label',
    ]

    def init_preprocess_function(self, args, tokenizer: BertTokenizer):
        """Build a tokenization function for one GLUE task.

        Args:
            args: namespace providing ``dataset`` (must be ``"glue"``) and
                ``task_name`` (a key of ``gelu_task_to_keys``).
            tokenizer: tokenizer applied to the task's one or two text columns.

        Returns:
            A function taking an ``examples`` mapping (column name -> texts) and
            returning the tokenizer output, truncated to 512 tokens. Presumably
            meant for ``Dataset.map`` — confirm at the call site.

        Raises:
            ValueError: if ``args.dataset`` is not ``"glue"``.
            KeyError: if ``args.task_name`` is not a known GLUE task.
        """
        if args.dataset != "glue":
            raise ValueError(
                "only the 'glue' dataset is supported, got {!r}".format(args.dataset))
        sentence1_key, sentence2_key = self.gelu_task_to_keys[args.task_name]

        def preprocess_function(examples):
            # Single-sentence tasks pass one text column, pair tasks pass two.
            if sentence2_key is None:
                texts = (examples[sentence1_key],)
            else:
                texts = (examples[sentence1_key], examples[sentence2_key])
            return tokenizer(*texts, max_length=512, truncation=True)

        return preprocess_function

    def collect(self,
                data_args,
                model_args,
                model: Union[BertForSequenceClassification, BertForQuestionAnswering],
                trainer: Union[BertGlueTrainer, BertSquadTrainer]
                ):
        """Run calibration through ``model`` and derive ReduBert compression parameters.

        Forward hooks are attached to modules whose name contains ``"bert"``:
          * inputs of every non-embedding ``LayerNorm``         -> ``norm_proj``
          * outputs of each self-attention ``query``/``key``    -> ``attn_proj``
          * inputs of each ``BertSelfOutput.dense``             -> ``vo_proj``
          * per-channel mean |x| / mean x^2 of ``BertOutput`` inputs -> ``act_value``
        ``trainer.collect()`` (presumably running forward passes over calibration
        data) triggers the hooks; they are removed afterwards even if it raises.

        Args:
            data_args: provides ``per_sample_token_num`` (token rows kept per hook
                call) and ``token_sample_num`` (token rows kept per module for SVD).
            model_args: provides ``redu_hidden_size`` and ``redu_attention_size``.
            model: dense BERT model to calibrate.
            trainer: object whose ``collect()`` drives the forward passes.

        Returns:
            Dict with keys ``"norm_proj"``, ``"attn_proj"``, ``"vo_proj"`` and
            ``"act_value"``.
        """
        config = model.config
        handlers = []
        norm_dict = {}
        act_dict = {}
        vo_dict = {}
        attn_dict = {}

        def init_fn(name: str, collector: Dict[str, list], collect_input: bool = True):
            """Hook storing a random subset of token rows of the module input (or output)."""
            collector.setdefault(name, [])

            def fn(module, inputs, outputs):
                source = inputs[0] if collect_input else outputs[0]
                # NOTE(review): squeeze(0) assumes batch size 1 during collection — confirm.
                data: torch.Tensor = source.detach().cpu().squeeze(0)  # [L, D]
                size = min(data_args.per_sample_token_num, data.shape[0])
                indices: List[int] = np.random.choice(data.shape[0], size, replace=False).tolist()
                collector[name].append(data[indices, ...])

            return fn

        def init_act_fn(name: str, collector: Dict[str, list], collect_input: bool = True):
            """Hook storing (token count, mean |x|, mean x^2) over the token axis."""
            collector.setdefault(name, [])

            def fn(module, inputs, outputs):
                source = inputs[0] if collect_input else outputs[0]
                data: torch.Tensor = source.detach().cpu().squeeze(0)
                collector[name].append((data.shape[0], data.abs().mean(dim=0), data.pow(2).mean(dim=0)))

            return fn

        for name, module in model.named_modules():
            if 'bert' not in name:
                continue
            if isinstance(module, torch.nn.LayerNorm) and 'embeddings' not in name:
                handlers.append(module.register_forward_hook(init_fn(name, norm_dict)))
            if isinstance(module, BertOutput):
                handlers.append(module.register_forward_hook(init_act_fn(name, act_dict)))
            if isinstance(module, BertSelfAttention):
                for proj_name in ('query', 'key'):
                    hook = init_fn(name + '.' + proj_name, attn_dict, collect_input=False)
                    handlers.append(getattr(module, proj_name).register_forward_hook(hook))
            if isinstance(module, BertSelfOutput):
                handlers.append(module.dense.register_forward_hook(init_fn(name + '.dense', vo_dict)))

        try:
            trainer.collect()
        finally:
            # Detach the hooks: otherwise every later forward pass on ``model``
            # keeps appending (and retaining) tensors in the collectors above.
            for handler in handlers:
                handler.remove()

        def sample_indices(total: int) -> np.ndarray:
            """Pick up to ``token_sample_num`` distinct indices out of ``range(total)``."""
            return np.random.choice(total, size=min(total, data_args.token_sample_num), replace=False)

        def left_svd(x: torch.Tensor):
            """Return ``(U, S)`` of ``x`` without building an oversized right factor.

            For wide inputs (cols >= rows) the reduced SVD gives ``U`` with the same
            shape as the full SVD (equal up to the usual sign/rotation ambiguity of
            singular vectors) while skipping the ``[cols, cols]`` right factor. Tall
            inputs keep the full decomposition so ``U`` stays square.
            """
            U, S, _ = torch.linalg.svd(x, full_matrices=x.shape[-1] < x.shape[-2])
            return U, S

        def dump_norm():
            """Shared hidden basis: top ``redu_hidden_size`` left singular vectors of
            the sampled LayerNorm inputs of all layers, ``[H, redu_hidden_size]``."""
            hiddens = []
            for key, value in norm_dict.items():
                print("key={}".format(key))
                hidden = torch.concat(value, dim=0).transpose(0, 1)  # [H, N]
                hiddens.append(hidden[:, sample_indices(hidden.shape[1])])
            hidden = torch.cat(hiddens, dim=-1)
            U, _ = left_svd(hidden)
            return U[:, :model_args.redu_hidden_size]

        def shape(x: torch.Tensor) -> torch.Tensor:
            """Reshape ``[L, hidden]`` into per-head ``[A, head_size, L]``."""
            num_attention_heads = config.num_attention_heads
            attention_head_size = config.hidden_size // num_attention_heads
            x = x.view(x.size()[:-1] + (num_attention_heads, attention_head_size))
            return x.permute(1, 2, 0)  # [A, H, L]

        def dump_attn():
            """Per-head low-rank factors for the query/key projections.

            With per-head sampled activations ``Q`` and ``K`` (``[h, N]``) and their
            SVDs, ``M = Sq Uq^T Uk Sk`` is truncated to ``target_head_size`` singular
            values. The returned ``.U`` (applied to queries) and ``.V`` (applied to
            keys), both ``[A, target_head_size, h]``, make ``(U q) . (V k)``
            reproduce the best rank-``target_head_size`` approximation of the
            sampled score matrix ``Q^T K``.
            """
            attn_proj = {}
            target_head_size = model_args.redu_attention_size // config.num_attention_heads
            # One entry per attention module: strip the '.query' / '.key' suffixes.
            keys = sorted({x.replace('.query', '').replace('.key', '') for x in attn_dict})
            for key in keys:
                print("key={}".format(key))
                hidden_query = torch.concat(attn_dict[key + '.query'], dim=0)  # [L, H]
                hidden_key = torch.concat(attn_dict[key + '.key'], dim=0)  # [L, H]

                # The same token subset for both, so Q and K columns stay paired.
                indices = sample_indices(hidden_query.shape[0])
                hidden_query = shape(hidden_query[indices, :])  # [A, h, N]
                hidden_key = shape(hidden_key[indices, :])

                # Reduced SVD keeps U's column count equal to len(S) (needed by the
                # diag_embed products below) and skips the unused [A, N, N] factors.
                Uq, Sq, _ = torch.linalg.svd(hidden_query, full_matrices=False)
                Uk, Sk, _ = torch.linalg.svd(hidden_key, full_matrices=False)

                M = torch.diag_embed(Sq) @ Uq.transpose(1, 2) @ Uk @ torch.diag_embed(Sk)
                Um, Sm, VmT = torch.linalg.svd(M)
                sqrt_sm = torch.diag_embed(torch.sqrt(Sm[:, :target_head_size]))
                UT = Uq @ torch.diag_embed(1.0 / Sq) @ Um[..., :target_head_size] @ sqrt_sm
                V = sqrt_sm @ VmT[:, :target_head_size, :] @ torch.diag_embed(1.0 / Sk) @ Uk.transpose(1, 2)
                attn_proj[key + '.U'] = UT.transpose(1, 2)
                attn_proj[key + '.V'] = V
            return attn_proj

        def dump_vo():
            """Per-head basis (top ``r_kv`` left singular vectors) of the attention
            output dense inputs, ``[A, head_size, r_kv]`` per module."""
            r_kv = model_args.redu_attention_size // config.num_attention_heads
            vo_value = {}
            for key, value in vo_dict.items():
                print("key={}".format(key))
                hidden = torch.concat(value, dim=0)  # [L, H]
                hidden = shape(hidden[sample_indices(hidden.shape[0]), :])  # [A, h, N]
                U, _ = left_svd(hidden)
                vo_value[key] = U[..., :r_kv]
            return vo_value

        def dump_act():
            """Importance score per FFN intermediate channel.

            ``(mean|x| + std(|x|)) * ||column of output.dense.weight||``, where the
            means are pooled over hook calls weighted by their token counts and
            ``std(|x|) = sqrt(E[x^2] - E[|x|]^2 + 1e-5)``.
            """
            act_value = {}
            bert_state_dict = model.state_dict()
            for key, value in act_dict.items():
                print("key={}".format(key))
                counts, abs_means, sq_means = zip(*value)
                weights = torch.tensor(counts).unsqueeze(-1)  # [calls, 1]
                total = weights.sum()
                h_mean = (weights * torch.stack(abs_means)).sum(dim=0) / total
                h2_mean = (weights * torch.stack(sq_means)).sum(dim=0) / total
                h_std = (h2_mean - h_mean * h_mean + 1e-5).sqrt()
                # Linear weight is [out, in]; transposed, each row is one input channel.
                weight: torch.Tensor = bert_state_dict[key + ".dense.weight"].transpose(0, 1).detach().cpu()
                act_value[key] = (h_mean + h_std) * weight.norm(dim=1)
            return act_value

        # Evaluated in this order so the random sampling stream is deterministic.
        norm_proj = dump_norm()
        attn_proj = dump_attn()
        vo_proj = dump_vo()
        act_value = dump_act()
        return {
            "norm_proj": norm_proj,
            "attn_proj": attn_proj,
            "vo_proj": vo_proj,
            "act_value": act_value,
        }

    @torch.no_grad()
    def load_model_params(self,
                          model: Union[ReduBertForSequenceClassification, ReduBertForQuestionAnswering],
                          state_dict: Dict[str, torch.Tensor],
                          compression_params: Any,
                          ):
        """Initialise a ReduBert model from dense BERT weights and compression parameters.

        Parameters whose shape is unchanged are copied as-is. The encoder layers,
        and the pooler or ``qa_outputs`` head, are then rebuilt by folding
        ``norm_proj`` (hidden-basis projection), the per-head ``attn_proj`` /
        ``vo_proj`` factors and the FFN channel selection from ``act_value`` into
        the dense weights.

        Args:
            model: target ReduBert model; its config must provide
                ``redu_hidden_size``, ``redu_attention_size`` and
                ``redu_intermediate_size``.
            state_dict: weights of the dense BERT model.
            compression_params: mapping with keys ``norm_proj``, ``attn_proj``,
                ``vo_proj`` and ``act_value`` (presumably the output of ``collect``).
        """
        norm_proj: torch.Tensor = compression_params["norm_proj"]
        attn_proj: Dict[str, Optional[torch.Tensor]] = compression_params["attn_proj"]
        vo_proj: Optional[Dict[str, torch.Tensor]] = compression_params["vo_proj"]
        act_value: Dict[str, Optional[torch.Tensor]] = compression_params["act_value"]

        config = model.config
        layer_pat = re.compile(r'bert\.encoder\.layer\.\d+')
        layer_id_pat = re.compile(r'.*\.(\d+)')

        # Copy every parameter whose shape survives compression unchanged; the
        # reshaped ones are (presumably all) rebuilt below.
        for n, p in model.named_parameters():
            if n not in state_dict:
                continue
            if p.shape == state_dict[n].shape:
                p.copy_(state_dict[n])

        def get_att_layernorm_params(layer_id: int):
            """(weight, bias) of layer ``layer_id``'s attention-output LayerNorm."""
            path = "bert.encoder.layer.{}.attention.output.LayerNorm".format(layer_id)
            return (
                state_dict[path + ".weight"],
                state_dict[path + ".bias"],
            )

        def get_ffn_layernorm_params(layer_id: int):
            """(weight, bias) of layer ``layer_id``'s output LayerNorm; the embedding
            LayerNorm when ``layer_id < 0``."""
            if layer_id < 0:
                path = "bert.embeddings.LayerNorm"
            else:
                path = "bert.encoder.layer.{}.output.LayerNorm".format(layer_id)
            return (
                state_dict[path + ".weight"],
                state_dict[path + ".bias"],
            )

        def fold_norm_into_linear(lin_name: str, norm_w: torch.Tensor, norm_b: torch.Tensor):
            """Fold a dense LayerNorm's affine params into the linear layer after it.

            With ``P = norm_proj``, ``n_w = diag(P^T diag(norm_w) P)`` and
            ``n_b = P^T norm_b`` (the reduced LayerNorm params, as set by
            ``load_norm_params``), the returned ``(W, c)`` satisfy
            ``W @ (n_w * u + n_b) + c == lin_w @ (norm_w * (P @ u) + norm_b) + lin_b``
            for any reduced normalised input ``u``.
            """
            n_w = torch.diag(norm_proj.T @ torch.diag(norm_w) @ norm_proj)
            n_b = norm_proj.T @ norm_b
            lin_w = state_dict[lin_name + '.weight']
            lin_b = state_dict[lin_name + '.bias']
            n_lin_w = lin_w @ torch.diag(norm_w) @ norm_proj @ torch.diag(1.0 / n_w)
            n_lin_b = lin_w @ norm_b + lin_b - n_lin_w @ n_b
            return n_lin_w, n_lin_b

        def load_qk_params(
                lin_name: str,
                linear: torch.nn.Linear,
                norm_w: torch.Tensor,
                norm_b: torch.Tensor,
                proj: torch.Tensor,
        ):
            """Load a query/key projection: LayerNorm folded in, then compressed per
            head by ``proj`` (presumably ``[A, prune_head_size, head_size]``)."""
            num_attention_heads = config.num_attention_heads
            attention_head_size = config.hidden_size // num_attention_heads
            prune_head_size = config.redu_attention_size // num_attention_heads

            n_lin_w, n_lin_b = fold_norm_into_linear(lin_name, norm_w, norm_b)
            n_lin_w = n_lin_w.view(num_attention_heads, attention_head_size, config.redu_hidden_size)
            n_lin_b = n_lin_b.view(num_attention_heads, attention_head_size)
            n_lin_w = (proj @ n_lin_w).view(num_attention_heads * prune_head_size, config.redu_hidden_size)
            n_lin_b = (proj @ n_lin_b.unsqueeze(-1)).squeeze(-1).view(num_attention_heads * prune_head_size)

            linear.weight.copy_(n_lin_w)
            linear.bias.copy_(n_lin_b)

        def prune_ffn_weight(
                w: torch.Tensor,
                b: torch.Tensor,
                retain_indices: torch.Tensor,
                prune_dim: int = 0
        ):
            """Keep output rows (``prune_dim == 0``; bias pruned too) or input columns
            (otherwise; bias unchanged) listed in ``retain_indices``."""
            if prune_dim == 0:
                return w[retain_indices, :], b[retain_indices]
            return w[:, retain_indices], b

        def load_afternorm_linear(
                lin_name: str,
                linear: torch.nn.Linear,
                norm_w: torch.Tensor,
                norm_b: torch.Tensor,
                retain_indices: Optional[torch.Tensor] = None,
                prune_dim: int = 0,
                v_proj: Optional[torch.Tensor] = None
        ):
            """Load a linear layer that reads a LayerNorm output.

            The LayerNorm is folded in (see ``fold_norm_into_linear``); optionally
            the weight is pruned to ``retain_indices`` and/or the output features
            are compressed per head by ``v_proj`` (``[A, head_size, r]``).
            """
            n_lin_w, n_lin_b = fold_norm_into_linear(lin_name, norm_w, norm_b)

            if retain_indices is not None:
                n_lin_w, n_lin_b = prune_ffn_weight(n_lin_w, n_lin_b, retain_indices, prune_dim)

            if v_proj is not None:
                d_kv = config.hidden_size // config.num_attention_heads
                proj_t = v_proj.transpose(1, 2)  # [A, r, d_kv]
                n_lin_w = proj_t @ n_lin_w.view(config.num_attention_heads, d_kv, config.redu_hidden_size)
                n_lin_b = proj_t @ n_lin_b.view(config.num_attention_heads, d_kv, 1)
                n_lin_w = n_lin_w.reshape(config.redu_attention_size, config.redu_hidden_size)
                n_lin_b = n_lin_b.reshape(config.redu_attention_size)

            linear.weight.copy_(n_lin_w)
            linear.bias.copy_(n_lin_b)

        def load_beforenorm_linear(
                lin_name: str,
                linear: torch.nn.Linear,
                retain_indices: Optional[torch.Tensor] = None,
                prune_dim: int = 1,
                o_proj: Optional[torch.Tensor] = None
        ):
            """Load a linear layer whose output is written in the reduced hidden basis.

            Weight and bias are left-multiplied by ``norm_proj.T``; optionally the
            weight is pruned to ``retain_indices`` and/or its input features are
            compressed per head by ``o_proj`` (``[A, head_size, r]``).
            """
            lin_w = state_dict[lin_name + '.weight']
            lin_b = state_dict[lin_name + '.bias']
            n_lin_w = norm_proj.T @ lin_w
            n_lin_b = norm_proj.T @ lin_b

            if retain_indices is not None:
                n_lin_w, n_lin_b = prune_ffn_weight(n_lin_w, n_lin_b, retain_indices, prune_dim)

            if o_proj is not None:
                d_kv = config.hidden_size // config.num_attention_heads
                # [redu_hidden, A * d_kv] -> [A, redu_hidden, d_kv] @ [A, d_kv, r].
                n_lin_w = n_lin_w.view(config.redu_hidden_size, config.num_attention_heads, d_kv).permute(1, 0, 2)
                n_lin_w = n_lin_w @ o_proj
                n_lin_w = n_lin_w.permute(1, 0, 2).reshape(config.redu_hidden_size, config.redu_attention_size)

            linear.weight.copy_(n_lin_w)
            linear.bias.copy_(n_lin_b)

        def load_norm_params(
                norm_name: str,
                norm: torch.nn.LayerNorm
        ):
            """Set a reduced LayerNorm to ``diag(P^T diag(w) P)`` / ``P^T b``, ``P = norm_proj``."""
            norm_w = state_dict[norm_name + '.weight']
            norm_b = state_dict[norm_name + '.bias']
            lam = torch.diag(norm_proj.T @ torch.diag(norm_w) @ norm_proj)
            beta = norm_proj.T @ norm_b
            norm.weight.copy_(lam)
            norm.bias.copy_(beta)

        def load_act_indices(
                act_name: str,
        ):
            """Indices of the ``redu_intermediate_size`` highest-scoring FFN channels.

            Returns None when the stored score is None. Indices come in ascending
            score order; the same order is used for the intermediate rows and the
            output columns, so it only fixes the channel layout.
            """
            h_value = act_value[act_name]
            if h_value is None:
                return None
            order = h_value.sort()[1]
            # Explicit start: ``order[-k:]`` would keep every channel when k == 0.
            return order[max(order.numel() - config.redu_intermediate_size, 0):]

        # The first layer's attention output gets an extra projection into the
        # reduced basis (presumably its q/k/v stay full-width and are covered by
        # the same-shape copy above).
        layer: ReduBertLayer = model.bert.encoder.layer[0]
        layer.attention.output.proj.weight.copy_(norm_proj.T)

        for name, module in model.named_modules():
            if not layer_pat.fullmatch(name):
                continue
            assert isinstance(module, ReduBertLayer)
            print(name)
            layer_id = int(layer_id_pat.match(name).group(1))

            # Attention reads the previous layer's output LayerNorm (the embedding
            # LayerNorm for layer 0); the FFN reads this layer's attention LayerNorm.
            ffn_norm_w, ffn_norm_b = get_ffn_layernorm_params(layer_id - 1)
            att_norm_w, att_norm_b = get_att_layernorm_params(layer_id)

            if layer_id > 0:
                load_qk_params(
                    name + ".attention.self.query",
                    module.attention.self.query,
                    ffn_norm_w,
                    ffn_norm_b,
                    attn_proj[name + '.attention.self.U']
                )

                load_qk_params(
                    name + ".attention.self.key",
                    module.attention.self.key,
                    ffn_norm_w,
                    ffn_norm_b,
                    attn_proj[name + '.attention.self.V']
                )

                load_afternorm_linear(
                    name + ".attention.self.value", module.attention.self.value, ffn_norm_w, ffn_norm_b,
                    v_proj=vo_proj[name + ".attention.output.dense"]
                )

            load_beforenorm_linear(
                name + ".attention.output.dense", module.attention.output.dense,
                o_proj=vo_proj[name + ".attention.output.dense"] if layer_id > 0 else None
            )

            # FFN: keep only the top-scoring intermediate channels in both linears.
            indices = load_act_indices(name + ".output")

            load_afternorm_linear(
                name + ".intermediate.dense",
                module.intermediate.dense,
                att_norm_w,
                att_norm_b,
                retain_indices=indices,
                prune_dim=0
            )

            load_beforenorm_linear(
                name + ".output.dense",
                module.output.dense,
                retain_indices=indices,
                prune_dim=1
            )

            load_norm_params(
                name + ".attention.output.LayerNorm",
                module.attention.output.LayerNorm
            )
            load_norm_params(
                name + ".output.LayerNorm",
                module.output.LayerNorm
            )

        # The head after the encoder reads the last layer's output LayerNorm.
        ffn_norm_w, ffn_norm_b = get_ffn_layernorm_params(
            config.num_hidden_layers - 1)

        if model.bert.add_pooling_layer:
            load_afternorm_linear(
                "bert.pooler.dense",
                model.bert.pooler.dense,
                ffn_norm_w,
                ffn_norm_b,
            )
        else:
            # No pooler: assumes a QA model whose ``qa_outputs`` reads hidden states.
            load_afternorm_linear(
                "qa_outputs",
                model.qa_outputs,
                ffn_norm_w,
                ffn_norm_b,
            )
        logger.info("[finish load state]")
