from dataclasses import dataclass
from typing import Optional
import numpy as np
import torch
from torch import nn
from collections import defaultdict
from torch.nn import functional as F
from timm.layers.attention import maybe_add_mask
import matplotlib.pyplot as plt
    
class Attention_UniSVD(nn.Module):
    """Multi-head self-attention built from pre-factorized projection tensors.

    Four linear layers are created (``w_qs``, ``w_ks``, ``w_vs`` and ``proj``)
    and their parameters are replaced by the tensors carried by ``factorized``.

    Args:
        factorized: Object exposing the weights ``q_w``, ``k_w``, ``v_w``
            (each ``(rank * heads, in_features)``), ``o_w``
            (``(out_features, rank * heads)``) and the biases ``q_b``,
            ``k_b``, ``v_b``, ``o_b``. The biases are only read when
            ``attn2_with_bias`` is true.
        attn2_with_bias: Whether the four linear layers carry a bias.
        attn_drop: Dropout probability applied to the attention weights.
        drop: Dropout probability applied after the output projection.
        qk_head_dim: Width used for the ``** -0.5`` softmax scale of the
            non-fused path; ``head_dim`` is used when it is ``None``.
        vo_head_dim: Stored on the module as ``vo_head_dims``; not read by
            this class.
        head_dim: ``head_dim * num_heads`` is the feature size of the layer
            input and output.
        num_heads: Number of attention heads.
    """

    def __init__(self, factorized, attn2_with_bias=True, attn_drop=0, drop=0,
                 qk_head_dim=None, vo_head_dim=None, head_dim=64, num_heads=12,
                 ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.qk_head_dims = qk_head_dim
        self.vo_head_dims = vo_head_dim
        self.fused_attn = True
        # ``qk_head_dim`` defaults to None; fall back to ``head_dim`` instead
        # of failing on ``None ** -0.5``.
        scale_dim = head_dim if qk_head_dim is None else qk_head_dim
        self.scale = scale_dim ** -0.5

        dim = int(head_dim * num_heads)

        self.w_qs = nn.Linear(dim, factorized.q_w.shape[0], bias=attn2_with_bias)
        self.w_ks = nn.Linear(dim, factorized.k_w.shape[0], bias=attn2_with_bias)
        self._load_linear(self.w_qs, factorized.q_w, factorized.q_b)  # (rank * head, in_features)
        self._load_linear(self.w_ks, factorized.k_w, factorized.k_b)  # (rank * head, in_features)

        self.w_vs = nn.Linear(dim, factorized.v_w.shape[0], bias=attn2_with_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(factorized.o_w.shape[1], dim, bias=attn2_with_bias)
        self.proj_drop = nn.Dropout(drop)
        self._load_linear(self.w_vs, factorized.v_w, factorized.v_b)  # (rank * head, in_features)
        self._load_linear(self.proj, factorized.o_w, factorized.o_b)  # (out_features, rank * head)

    @staticmethod
    def _load_linear(layer, weight, bias):
        """Point ``layer``'s parameters at the given tensors.

        The bias is only touched when the layer was created with one, so a
        bias-free configuration does not dereference a missing parameter.
        """
        layer.weight.data = weight
        if layer.bias is not None:
            layer.bias.data = bias

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(B, N, C)``.
            attn_mask: Optional mask handed to the attention computation.

        Returns:
            Tensor of shape ``(B, N, head_dim * num_heads)``.
        """
        B, N, C = x.shape

        # (B, N, heads * r) -> (B, heads, N, r); r is inferred per projection.
        q = self.w_qs(x).view(B, N, self.num_heads, -1).transpose(1, 2)
        k = self.w_ks(x).view(B, N, self.num_heads, -1).transpose(1, 2)
        v = self.w_vs(x).view(B, N, self.num_heads, -1).transpose(1, 2)

        if self.fused_attn:
            # NOTE(review): the fused kernel applies its default scaling
            # (1 / sqrt of q's last dimension), not ``self.scale``; the two
            # agree only when ``qk_head_dim`` equals the per-head width of
            # ``w_qs`` -- confirm this is intended.
            x = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                # Mirror the non-fused branch: attention dropout is active
                # only in training mode.
                dropout_p=float(self.attn_drop.p) if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = maybe_add_mask(attn, attn_mask)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        # (B, heads, N, r) -> (B, N, heads * r) before the output projection.
        x = x.transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
    
@dataclass
class FactorizedMatrix:
    """Result of factorizing one weight matrix into a left/right factor pair.

    Attributes:
        mat_l: Left factor; its first ``active_rank`` columns become the
            weight of the second (rank -> out_features) linear layer.
        mat_r: Right factor; its first ``active_rank`` rows become the weight
            of the first (in_features -> rank) linear layer.
        eq_rank: ``rows * cols // (rows + cols)`` of the dense matrix, i.e.
            the rank at which both factors together hold as many parameters
            as the dense matrix.
        active_rank: Number of rank components kept when the replacement
            layers are built.
    """

    mat_l: Optional[torch.Tensor] = None
    mat_r: Optional[torch.Tensor] = None
    eq_rank: int = 0
    active_rank: int = 0

@dataclass
class FactorizedMatrix_UniSVD:
    """Factorized query/key/value/output projections of one attention block.

    Attributes:
        q_w, k_w, v_w: Projection weights; the first dimension is used as the
            output width of the corresponding linear layer.
        o_w: Output projection weight; the second dimension is used as the
            input width of the output linear layer.
        q_b, k_b, v_b, o_b: Biases matching the four weights.
        eq_rank: Presumably the parameter-equivalent rank of the source
            matrix -- not read in this file, confirm with the producer.
        active_rank: Presumably the rank actually kept -- not read in this
            file, confirm with the producer.
        qk_head_dim: Per-head width reported for the query/key projections.
        vo_head_dim: Per-head width reported for the value/output projections.
    """

    q_w: Optional[torch.Tensor] = None
    k_w: Optional[torch.Tensor] = None
    v_w: Optional[torch.Tensor] = None
    o_w: Optional[torch.Tensor] = None
    q_b: Optional[torch.Tensor] = None
    k_b: Optional[torch.Tensor] = None
    v_b: Optional[torch.Tensor] = None
    o_b: Optional[torch.Tensor] = None
    eq_rank: int = 0
    active_rank: int = 0
    qk_head_dim: int = 64
    vo_head_dim: int = 64


class BaseFactorization:
    """Base class that swaps ``nn.Linear`` layers for low-rank factor pairs.

    Subclasses provide the actual decomposition in ``_factorize_matrix``;
    this class validates the rank request, builds the replacement module and
    installs it on the model.
    """

    def __init__(self):
        self.scaling_dict = {}
        self.input_shapes = {}

    @staticmethod
    def _default_device():
        """Return the current CUDA device, or the CPU when CUDA is unavailable."""
        if torch.cuda.is_available():
            return torch.device("cuda", torch.cuda.current_device())
        return torch.device("cpu")

    @staticmethod
    def _release_cached_memory(dev):
        """Empty the CUDA allocator cache of ``dev``; no-op for other devices."""
        if torch.device(dev).type == "cuda":
            with torch.cuda.device(dev):
                torch.cuda.empty_cache()

    @staticmethod
    def _split_parent(model, qualified_name):
        """Return ``(parent_module, attribute_name)`` for a dotted module name."""
        *parents, leaf = qualified_name.split(".")
        owner = model
        for part in parents:
            owner = getattr(owner, part)
        return owner, leaf

    def compute_scaling(self, model, name_omit, calib_data, mixup_fn, white_list=None):
        """Placeholder; the base class computes no scaling and only reports it."""
        print("\nNo scaling method implemented.")

    def factorize_matrix(self, matrix, name, rank=-1, ratio=-1, verbose=True) -> "FactorizedMatrix":
        """Validate the rank request and delegate to ``_factorize_matrix``.

        Exactly one of ``rank`` and ``ratio`` must differ from ``-1``.

        Args:
            matrix: Two-dimensional weight to factorize.
            name: Label used in the printed messages.
            rank: Requested rank; ``0`` selects the equivalent rank
                ``rows * cols // (rows + cols)``.
            ratio: Fraction of the equivalent rank to keep (truncated to int).
            verbose: Print a progress line when true.

        Returns:
            Whatever ``_factorize_matrix`` returns, or ``None`` (after a
            printed warning) when the request is invalid or ``rank`` exceeds
            the equivalent rank.
        """
        if verbose:
            print(f"Factorizing {name} matrix")
        if rank == -1 and ratio == -1:
            print(f"Warning: {name} rank or ratio must be defined!")
            return None
        if rank != -1 and ratio != -1:
            print(
                f"Warning: {name} rank and ratio are both defined, "
                "only one can be used at a time!"
            )
            return None

        dev = self._default_device()
        rows, cols = matrix.shape[0], matrix.shape[1]
        eq_rank = rows * cols // (rows + cols)
        if rank == 0:
            rank = eq_rank
        elif ratio != -1:
            rank = int(eq_rank * ratio)
        elif rank > eq_rank:
            print(f"Warning: {name} rank is larger than equivalent rank!")
            return None
        return self._factorize_matrix(
            matrix=matrix, name=name, eq_rank=eq_rank, rank=rank, dev=dev
        )

    def _factorize_matrix(self, matrix, name, eq_rank, rank, dev) -> "FactorizedMatrix":
        """Decompose ``matrix`` at ``rank``; must be provided by subclasses."""
        raise NotImplementedError("Subclasses should implement this method.")

    def create_factorized_sequential(
        self, factorized_matrix: "FactorizedMatrix", original_module
    ) -> nn.Module:
        """Build the two-layer replacement for ``original_module``.

        The first layer maps ``in_features -> active_rank`` and takes the
        leading rows of ``mat_r``; the second maps
        ``active_rank -> out_features`` and takes the leading columns of
        ``mat_l``. The original bias (if any) is reused by the wrapper, so
        neither inner layer has a bias of its own.
        """
        dev = original_module.weight.device
        rank = factorized_matrix.active_rank
        module_l = nn.Linear(original_module.in_features, rank, bias=False).to(dev)
        module_r = nn.Linear(rank, original_module.out_features, bias=False).to(dev)

        module_l.weight.data.copy_(factorized_matrix.mat_r[:rank, :].to(dev))
        module_r.weight.data.copy_(factorized_matrix.mat_l[:, :rank].to(dev))

        # Only meaningful (and only legal) when the module lives on a GPU.
        self._release_cached_memory(dev)

        return SeqSVD(module_l, module_r, getattr(original_module, "bias", None)).to(dev)

    def factorize_model(self, uncom_model, rank_dict, name_omit, verbose=True) -> None:
        """Apply low-rank decomposition to the model in place.

        Name omission is also supported implicitly: a layer that is missing
        from ``rank_dict`` is left uncompressed.

        Args:
            uncom_model (nn.Module): Model to compress; it is switched to eval
                mode and moved to the current CUDA device (CPU when CUDA is
                unavailable).
            rank_dict (dict): Maps module names to an ``int`` rank or a
                ``float`` ratio. ``-1`` and ``1.0`` mean "leave as is".
            name_omit: Substrings; any linear layer whose name contains one of
                them is skipped.
            verbose (bool): Print progress information.
        """
        print("\nApplying factorization")
        dev = self._default_device()
        self._release_cached_memory(dev)

        model = uncom_model.eval().to(dev)
        # Snapshot the candidates first: the model is mutated while looping.
        candidates = {
            name: module_sub
            for name, module_sub in model.named_modules()
            if isinstance(module_sub, nn.Linear)
            and all(omit not in name for omit in name_omit)
        }
        if verbose:
            print(rank_dict)
        for name, module_sub in candidates.items():
            spec = rank_dict.get(name, -1)
            # condition for not applying low rank
            if (
                module_sub.out_features < 10
                or name in name_omit
                or spec == -1
                or spec == 1.0
            ):
                continue

            factorized = self.factorize_matrix(
                matrix=module_sub.weight,
                rank=spec if isinstance(spec, int) else -1,
                ratio=spec if isinstance(spec, float) else -1,
                name=name,
                verbose=verbose,
            )
            if factorized is None:
                # The request was rejected (a warning was printed); keep the
                # dense layer instead of failing on a missing result.
                continue

            svd_replacement = self.create_factorized_sequential(
                factorized_matrix=factorized, original_module=module_sub
            )

            if verbose:
                print(f"Applying low rank on {name:^10}, rank {spec}")

            parent, attr_name = self._split_parent(model, name)
            setattr(parent, attr_name, svd_replacement)


class SeqSVD(nn.Module):
    """Run two modules back to back and optionally add a bias to the result.

    Args:
        mod_a: Module applied to the input first.
        mod_b: Module applied to the output of ``mod_a``.
        bias: Optional tensor added to the output of ``mod_b``; ``None``
            disables the addition.
    """

    def __init__(self, mod_a, mod_b, bias=None):
        super().__init__()
        self.mod_a, self.mod_b = mod_a, mod_b
        self.bias = bias

    def forward(self, x):
        """Return ``mod_b(mod_a(x))``, plus the bias when one is set."""
        out = self.mod_b(self.mod_a(x))
        if self.bias is None:
            return out
        out += self.bias
        return out


class Hookstuff:
    """Unified forward-hook manager used to obtain scalings for linear layers.

    Subclasses override ``_hook_fn`` to return the hook that should run on
    each selected ``nn.Linear``.

    Args:
        model: Model whose ``nn.Linear`` sub-modules are collected.
        name_omit: Substrings; layers whose name contains one are skipped.
        dump_shape: Stored on the instance; not read by this class.
        white_list: When non-empty, only layers whose name contains one of
            the entries are hooked. Entries are strings or sequences whose
            first item is the string to match.
    """

    def __init__(self, model: nn.Module, name_omit=None, dump_shape=False, white_list=None):
        self.model = model
        # Fresh lists per instance: a shared default list would leak state
        # between instances if it were ever mutated.
        self.name_omit = [] if name_omit is None else name_omit
        self.white_list = [] if white_list is None else white_list
        self.dump_shape = dump_shape
        self.x_dict = {}
        self.profile = {}
        self.input_shape = {}
        self.hooks = []
        linear_layers = [
            (name, module_sub)
            for name, module_sub in model.named_modules()
            if isinstance(module_sub, nn.Linear)
        ]
        # Kept as a list (last layer first) so hooks can be attached again
        # after ``clear_hooks``; a bare ``reversed`` iterator is single-use.
        linear_layers.reverse()
        self.cp_modules = linear_layers

    def _hook_fn(self, layer_name):
        """Return the forward hook for ``layer_name``; the base hook does nothing."""
        def get_scaling_mat(module, input, output):
            pass

        return get_scaling_mat

    def _in_white_list(self, name):
        """Tell whether ``name`` passes the white list (always true if empty)."""
        if not self.white_list:
            return True
        for entry in self.white_list:
            # Entries may be plain strings or (pattern, value) pairs.
            pattern = entry if isinstance(entry, str) else entry[0]
            if pattern in name:
                return True
        return False

    def _register_hooks_recursive(self, cp_modules, prefix=""):
        """Attach a forward hook to every eligible layer.

        Args:
            cp_modules: Iterable of ``(name, module)`` pairs.
            prefix: Unused; kept for interface compatibility.
        """
        for name, layer in cp_modules:
            if not isinstance(layer, nn.Linear):
                continue
            if any(n in name for n in self.name_omit):
                continue
            if layer.out_features < 10:
                continue
            if not self._in_white_list(name):
                continue
            self.hooks.append(layer.register_forward_hook(self._hook_fn(name)))

    def attach_hooks(self):
        """Register hooks on all collected layers."""
        self._register_hooks_recursive(self.cp_modules)

    def clear_hooks(self):
        """Remove every registered hook and forget the handles."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

class BaseFactorization_UniSVD:
    """Base class for the UniSVD variant of model factorization.

    Attention blocks (a ``.attn.qkv`` / ``.attn.proj`` pair of linear layers)
    are factorized jointly and replaced by an ``Attention_UniSVD`` module;
    every other linear layer is factorized on its own. Subclasses provide
    ``_factorize_matrix_unisvd`` and ``_factorize_matrix``.
    """

    # Ranks passed to ``factorize_matrix_unisvd`` for every attention block as
    # (rank1, rank1_2) and (rank2, rank2_2). Override in a subclass to change.
    qk_ranks = (24, 24)
    vo_ranks = (24, 24)

    def __init__(self):
        self.scaling_dict = {}
        self.input_shapes = {}
        self.x_dict = {}

    @staticmethod
    def _default_device():
        """Return the current CUDA device, or the CPU when CUDA is unavailable."""
        if torch.cuda.is_available():
            return torch.device("cuda", torch.cuda.current_device())
        return torch.device("cpu")

    @staticmethod
    def _release_cached_memory(dev):
        """Empty the CUDA allocator cache of ``dev``; no-op for other devices."""
        if torch.device(dev).type == "cuda":
            with torch.cuda.device(dev):
                torch.cuda.empty_cache()

    @staticmethod
    def _split_parent(model, qualified_name):
        """Return ``(parent_module, attribute_name)`` for a dotted module name."""
        *parents, leaf = qualified_name.split(".")
        owner = model
        for part in parents:
            owner = getattr(owner, part)
        return owner, leaf

    @staticmethod
    def _equivalent_rank(matrix):
        """Return ``rows * cols // (rows + cols)`` for a 2-D ``matrix``."""
        rows, cols = matrix.shape[0], matrix.shape[1]
        return rows * cols // (rows + cols)

    def compute_scaling(self, model, name_omit, calib_data, mixup_fn, white_list=None):
        """Placeholder; the base class computes no scaling and only reports it."""
        print("\nNo scaling method implemented.")

    def factorize_matrix_unisvd(self, matrix1, matrix2, bias1, bias2, name1, name2, rank1=-1, ratio1=-1, rank1_2=-1, rank2=-1, rank2_2=-1, verbose=True,
                          head_dim=64, num_heads=12, num_layers=12) -> "FactorizedMatrix_UniSVD":
        """Validate the request for a matrix pair and delegate the joint factorization.

        Exactly one of ``rank1`` and ``ratio1`` must differ from ``-1``;
        ``rank1 == 0`` selects the equivalent rank of ``matrix1`` and
        ``ratio1`` selects a truncated fraction of it. The remaining rank
        arguments are forwarded untouched.

        Returns:
            Whatever ``_factorize_matrix_unisvd`` returns, or ``None`` (after
            a printed warning) when the request is invalid or ``rank1``
            exceeds the equivalent rank of ``matrix1``.
        """
        if verbose:
            print(f"Factorizing {name1} matrix")
            print(f"Factorizing {name2} matrix")

        if rank1 == -1 and ratio1 == -1:
            print(f"Warning: {name1} rank or ratio must be defined!")
            return None
        if rank1 != -1 and ratio1 != -1:
            print(
                f"Warning: {name1} rank and ratio are both defined, "
                "only one can be used at a time!"
            )
            return None

        dev = self._default_device()
        eq_rank1 = self._equivalent_rank(matrix1)
        eq_rank2 = self._equivalent_rank(matrix2)

        if rank1 == 0:
            rank1 = eq_rank1
        elif ratio1 != -1:
            rank1 = int(eq_rank1 * ratio1)
        elif rank1 > eq_rank1:
            print(f"Warning: {name1} rank is larger than equivalent rank!")
            return None

        return self._factorize_matrix_unisvd(
            matrix1=matrix1, matrix2=matrix2, bias1=bias1, bias2=bias2, name1=name1, name2=name2,
            eq_rank1=eq_rank1, eq_rank2=eq_rank2, rank1=rank1, rank1_2=rank1_2, rank2=rank2, rank2_2=rank2_2, dev=dev,
            head_dim=head_dim, num_heads=num_heads, num_layers=num_layers
        )

    def _factorize_matrix_unisvd(self, matrix1, matrix2, bias1, bias2, name1, name2,
                                 eq_rank1, eq_rank2, rank1, rank1_2, rank2, rank2_2, dev,
                                 head_dim, num_heads, num_layers) -> "FactorizedMatrix_UniSVD":
        """Jointly decompose a matrix pair; must be provided by subclasses."""
        raise NotImplementedError("Subclasses should implement this method.")

    def factorize_matrix(self, matrix, name, rank=-1, ratio=-1, verbose=True) -> "FactorizedMatrix":
        """Validate the rank request and delegate to ``_factorize_matrix``.

        Exactly one of ``rank`` and ``ratio`` must differ from ``-1``;
        ``rank == 0`` selects the equivalent rank and ``ratio`` selects a
        truncated fraction of it.

        Returns:
            Whatever ``_factorize_matrix`` returns, or ``None`` (after a
            printed warning) when the request is invalid or ``rank`` exceeds
            the equivalent rank.
        """
        if verbose:
            print(f"Factorizing {name} matrix")
        if rank == -1 and ratio == -1:
            print(f"Warning: {name} rank or ratio must be defined!")
            return None
        if rank != -1 and ratio != -1:
            print(
                f"Warning: {name} rank and ratio are both defined, "
                "only one can be used at a time!"
            )
            return None

        dev = self._default_device()
        eq_rank = self._equivalent_rank(matrix)
        if rank == 0:
            rank = eq_rank
        elif ratio != -1:
            rank = int(eq_rank * ratio)
        elif rank > eq_rank:
            print(f"Warning: {name} rank is larger than equivalent rank!")
            return None
        return self._factorize_matrix(
            matrix=matrix, name=name, eq_rank=eq_rank, rank=rank, dev=dev
        )

    def _factorize_matrix(self, matrix, name, eq_rank, rank, dev) -> "FactorizedMatrix":
        """Decompose ``matrix`` at ``rank``; must be provided by subclasses."""
        raise NotImplementedError("Subclasses should implement this method.")

    def create_factorized_sequential(
        self, factorized_matrix: "FactorizedMatrix", original_module
    ) -> nn.Module:
        """Build the two-layer replacement for ``original_module``.

        The first layer maps ``in_features -> active_rank`` and takes the
        leading rows of ``mat_r``; the second maps
        ``active_rank -> out_features`` and takes the leading columns of
        ``mat_l``. The original bias (if any) is reused by the wrapper.
        """
        dev = original_module.weight.device
        rank = factorized_matrix.active_rank
        module_l = nn.Linear(original_module.in_features, rank, bias=False).to(dev)
        module_r = nn.Linear(rank, original_module.out_features, bias=False).to(dev)

        module_l.weight.data.copy_(factorized_matrix.mat_r[:rank, :].to(dev))
        module_r.weight.data.copy_(factorized_matrix.mat_l[:, :rank].to(dev))

        # Only meaningful (and only legal) when the module lives on a GPU.
        self._release_cached_memory(dev)

        return SeqSVD(module_l, module_r, getattr(original_module, "bias", None)).to(dev)

    def generate_steps(self, a: int, b: int, steps: int = 12) -> list:
        """Return ``steps`` rounded integers evenly spaced from ``a`` to ``b``.

        A single step yields ``[a]`` (there is no interval to divide), and a
        non-positive ``steps`` yields an empty list.
        """
        if steps == 1:
            return [int(round(a))]
        step_size = (b - a) / (steps - 1)
        return [int(round(a + step_size * i)) for i in range(steps)]

    def factorize_model(self, uncom_model, rank_dict, name_omit, verbose=True) -> None:
        """Apply low-rank decomposition to the model in place.

        Name omission is also supported implicitly: a non-attention layer
        that is missing from ``rank_dict`` is left uncompressed.

        Args:
            uncom_model (nn.Module): Model to compress; it is switched to eval
                mode and moved to the current CUDA device (CPU when CUDA is
                unavailable).
            rank_dict (dict): Maps module names to an ``int`` rank or a
                ``float`` ratio for the non-attention layers. ``-1`` and
                ``1.0`` mean "leave as is". Attention blocks use
                ``qk_ranks`` / ``vo_ranks`` instead.
            name_omit: Substrings; any linear layer whose name contains one of
                them is skipped.
            verbose (bool): Print progress information.
        """
        print("\nApplying factorization")
        dev = self._default_device()
        self._release_cached_memory(dev)

        model = uncom_model.eval().to(dev)
        # Snapshot the candidates first: the model is mutated while looping.
        candidates = {
            name: module_sub
            for name, module_sub in model.named_modules()
            if isinstance(module_sub, nn.Linear)
            and all(omit not in name for omit in name_omit)
        }
        if verbose:
            print('# rank_dict : \n', rank_dict)

        # Pair the qkv and proj layers of one attention module under the
        # module's own path; every other layer forms a group of its own.
        grouped_modules = defaultdict(dict)
        for name, module_sub in candidates.items():
            if ".attn.qkv" in name:
                key = name.replace(".attn.qkv", ".attn")
                grouped_modules[key]["qkv"] = (name, module_sub)
            elif ".attn.proj" in name:
                key = name.replace(".attn.proj", ".attn")
                grouped_modules[key]["proj"] = (name, module_sub)
            else:
                grouped_modules[name]["single"] = (name, module_sub)

        # Rank for decomposition
        qk_rank1, qk_rank2 = self.qk_ranks
        vo_rank1, vo_rank2 = self.vo_ranks
        print(f">> rank1 : {qk_rank1} , {qk_rank2} | rank2 : {vo_rank1} , {vo_rank2} ")

        for key, layers in grouped_modules.items():
            if "qkv" in layers and "proj" in layers:
                self._factorize_attention(model, key, layers, verbose,
                                          qk_rank1, qk_rank2, vo_rank1, vo_rank2)
            else:
                for lname, lmod in layers.values():
                    self._factorize_single(model, lname, lmod, rank_dict, name_omit, verbose)

    def _factorize_attention(self, model, key, layers, verbose,
                             qk_rank1, qk_rank2, vo_rank1, vo_rank2):
        """Replace the attention module at path ``key`` by an ``Attention_UniSVD``."""
        qkv_name, qkv_mod = layers["qkv"]
        proj_name, proj_mod = layers["proj"]
        print(f"[Same Block ATTENTION] {qkv_name} & {proj_name}")

        # ``key`` is the dotted path of the attention module itself.
        block, attn_attr = self._split_parent(model, key)
        attn_module = getattr(block, attn_attr)
        num_heads = attn_module.num_heads
        head_dim = attn_module.head_dim

        # NOTE(review): the layer count is inferred from the head count
        # (16 heads -> 24 layers, otherwise 12) -- confirm for other models.
        num_layers = 24 if num_heads == 16 else 12

        factorized = self.factorize_matrix_unisvd(matrix1=qkv_mod.weight, matrix2=proj_mod.weight, bias1=qkv_mod.bias, bias2=proj_mod.bias, name1=qkv_name, name2=proj_name,
                                            rank1=qk_rank1, rank1_2=qk_rank2, rank2=vo_rank1, rank2_2=vo_rank2,
                                            head_dim=head_dim, num_heads=num_heads, num_layers=num_layers)
        if factorized is None:
            # The request was rejected (a warning was printed); keep the
            # original attention module.
            return

        qk_head_dim = factorized.qk_head_dim
        vo_head_dim = factorized.vo_head_dim

        if verbose:
            print(f"Applying low rank on {qkv_name:^10}, rank {qk_head_dim}")

        attn_dev = qkv_mod.weight.device
        atten = Attention_UniSVD(factorized, attn2_with_bias=True, attn_drop=0, drop=0, qk_head_dim=qk_head_dim, vo_head_dim=vo_head_dim,
                                 head_dim=head_dim, num_heads=num_heads)

        setattr(block, attn_attr, atten.to(attn_dev))

        if verbose:
            print(f"Applying low rank on {proj_name:^10}, rank {vo_head_dim}")

    def _factorize_single(self, model, lname, lmod, rank_dict, name_omit, verbose):
        """Replace one stand-alone linear layer by its low-rank factor pair."""
        spec = rank_dict.get(lname, -1)
        # condition for not applying low rank
        if (
            lmod.out_features < 10
            or lname in name_omit
            or spec == -1
            or spec == 1.0
        ):
            return

        factorized = self.factorize_matrix(
            matrix=lmod.weight,
            rank=spec if isinstance(spec, int) else -1,
            ratio=spec if isinstance(spec, float) else -1,
            name=lname,
            verbose=False,
        )
        if factorized is None:
            # The request was rejected (a warning was printed); keep the
            # dense layer instead of failing on a missing result.
            return

        svd_replacement = self.create_factorized_sequential(
            factorized_matrix=factorized, original_module=lmod
        )

        if verbose:
            print(f"Applying low rank on {lname:^10}, rank {spec}")

        parent, attr_name = self._split_parent(model, lname)
        setattr(parent, attr_name, svd_replacement)



                    