import gc

import torch
from torch import nn

from ._interface import BaseFactorization_UniSVD
from ._interface import FactorizedMatrix, FactorizedMatrix_UniSVD
from ._interface import Hookstuff
from .svd_llm import whitening

class FWSVD_Hook(Hookstuff):
    """Forward-hook collector that records each hooked layer's input and shape.

    For every hooked module it stores the detached, float32, CPU copy of the
    first positional input in ``self.x_dict[layer_name]`` and, in
    ``self.input_shape[layer_name]``, that input's shape followed by
    ``[module.out_features, 0]``. Both containers are presumably set up by
    ``Hookstuff`` (not visible here — TODO confirm).
    """

    def _hook_fn(self, layer_name):
        """Return a forward hook that records data under ``layer_name``."""

        def get_scaling_mat(module, inputs, _output):
            activations = input_tensor = inputs[0].detach().float()
            self.x_dict[layer_name] = input_tensor.cpu()
            # Input shape, then the layer's output width and a trailing 0.
            self.input_shape[layer_name] = [*activations.shape, module.out_features, 0]

        return get_scaling_mat


class FWSVDFactorization(BaseFactorization_UniSVD):
    """Fisher-weighted SVD (FWSVD) factorization with a UniSVD attention path.

    ``compute_scaling`` accumulates a per-input-column Fisher importance
    profile for every ``nn.Linear`` from calibration gradients and records
    layer input shapes with a dummy forward pass. ``_factorize_matrix``
    factorizes a single weight with importance-scaled low-rank SVD, and
    ``_factorize_matrix_unisvd`` jointly factorizes an attention block's fused
    QKV projection and its output projection head by head.
    """

    def __init__(self, alpha, *args, **kwargs):
        """Create the factorizer.

        Args:
            alpha: Exponent applied to the importance profile before it is
                used as a column scaling in ``_factorize_matrix``.
            *args, **kwargs: Forwarded unchanged to ``BaseFactorization_UniSVD``.
        """
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        # "CE": cross-entropy between model output and targets; any other
        # value treats the model output itself as the loss.
        self.loss_type = "CE"

    def compute_scaling(self, model, name_omit, calib_data, mixup_fn, white_list=None):
        """Collect Fisher importance for every ``nn.Linear`` and record input shapes.

        For each calibration batch the loss is back-propagated. Each linear
        layer's squared weight gradient is averaged over output rows
        (``mean(0)``, one value per input column) and accumulated into
        ``self.scaling_dict[name]``. After the loop, every accumulator is
        divided by the number of batches and square-rooted. A dummy forward
        pass with ``FWSVD_Hook`` attached then fills ``self.x_dict`` (captured
        inputs) and ``self.input_shapes`` (input shape + ``[out_features, 0]``).

        Args:
            model: Model to calibrate. It is switched to eval mode and moved
                to the current CUDA device.
            name_omit: Forwarded to ``FWSVD_Hook``.
            calib_data: Iterable of ``(data, target)`` batches.
            mixup_fn: Callable ``(data, target) -> (inputs, targets)`` applied
                to every batch.
            white_list: Forwarded to ``FWSVD_Hook``. ``None`` means an empty list.
        """
        white_list = [] if white_list is None else white_list
        print("\nCollecting Fisher importance...")
        dev = torch.device(torch.cuda.current_device())
        loss_fn = nn.CrossEntropyLoss()

        with torch.cuda.device(dev):
            torch.cuda.empty_cache()

        model = model.eval().to(dev)

        num_batches = 0
        for data, target in calib_data:
            model_inputs, target_mix = mixup_fn(data, target)
            model_inputs, target_mix = model_inputs.to(dev), target_mix.to(dev)
            out = model(model_inputs)
            loss = loss_fn(out, target_mix) if self.loss_type == "CE" else out
            loss.mean().backward()
            del model_inputs, target_mix, out, loss
            num_batches += 1

            for name, module in model.named_modules():
                if not isinstance(module, nn.Linear):
                    continue
                grad = module.weight.grad
                if grad is None:
                    # This layer did not contribute to the loss (or is frozen).
                    continue
                fisher = grad.detach().pow(2).mean(0)
                if name in self.scaling_dict:
                    self.scaling_dict[name] += fisher
                else:
                    self.scaling_dict[name] = fisher
            model.zero_grad()
            with torch.cuda.device(dev):
                torch.cuda.empty_cache()

        # Counting batches (instead of len()) also supports plain iterables.
        if num_batches:
            for key, val in self.scaling_dict.items():
                self.scaling_dict[key] = (val / num_batches).sqrt()

        shapes_getter = FWSVD_Hook(model, name_omit, True, white_list=white_list)
        shapes_getter.attach_hooks()

        dummy_input = torch.randn(20, 3, 224, 224).to(dev)
        # The hooks only keep detached inputs, so no autograd graph is needed.
        with torch.no_grad():
            model(dummy_input)

        # Copy captured inputs only after the forward pass populated them, and
        # before the hooks are cleared.
        for key, value in shapes_getter.x_dict.items():
            self.x_dict[key] = value

        shapes_getter.clear_hooks()
        for key, value in shapes_getter.input_shape.items():
            self.input_shapes[key] = value
        del shapes_getter, dummy_input

        gc.collect()

        return

    def _factorize_matrix(self, matrix, eq_rank, rank, name, dev):
        """Factorize ``matrix`` as ``mat_l @ mat_r`` with importance-scaled SVD.

        Columns are scaled by ``scaling_dict[name] ** alpha + 1e-6``. A
        randomized rank-``rank`` SVD (``torch.svd_lowrank``) is taken, the
        scaling is undone on the right factor, and the singular values are
        split evenly (square root) between both factors.

        Args:
            matrix: 2-D weight whose columns match ``scaling_dict[name]``.
            eq_rank: Equivalent rank stored in the result. It is also the
                default and the upper bound for ``rank``.
            rank: Target rank. ``0`` means use ``eq_rank``.
            name: Key into ``self.scaling_dict``.
            dev: Unused. Kept for interface compatibility; factors are
                returned on CPU.

        Returns:
            ``FactorizedMatrix`` with CPU factors ``mat_l`` ``(rows, rank)`` and
            ``mat_r`` ``(rank, cols)``. Returns ``None`` after printing a
            warning when ``rank > eq_rank``.
        """
        scale_diag = self.scaling_dict[name] ** self.alpha + 1e-6

        if rank == 0:
            rank = eq_rank
        elif rank > eq_rank:
            print(f"Warning: {name} rank is larger than equivalent rank!")
            return

        mat_scaled = matrix * scale_diag.view(1, -1)

        u, s, v = torch.svd_lowrank(mat_scaled, q=rank)
        # mat_scaled ~= u diag(s) v^T, hence matrix ~= u diag(s) (v / scale)^T.
        vh = (v / scale_diag.view(-1, 1)).t()

        s_half = s.sqrt()  # half of each singular value goes to each factor
        mat_l = (u * s_half)[:, :rank]
        mat_r = (s_half.view(-1, 1) * vh)[:rank, :]

        return FactorizedMatrix(
            mat_l=mat_l.cpu(),  # Left factor
            mat_r=mat_r.cpu(),  # Right factor
            eq_rank=eq_rank,  # Equivalent rank
            active_rank=rank,  # Active rank
        )

    def generate_steps(self, a: int, b: int, steps: int = 12) -> list:
        """Return ``steps`` integers spaced evenly from ``a`` to ``b`` inclusive.

        Values are rounded to the nearest integer. ``steps == 1`` yields
        ``[round(a)]`` and ``steps <= 0`` yields ``[]``.
        """
        if steps <= 0:
            return []
        if steps == 1:
            return [int(round(a))]
        step_size = (b - a) / (steps - 1)
        return [int(round(a + step_size * i)) for i in range(steps)]

    def merge_bias(self, W, b):
        """Append bias ``b`` to ``W`` as an extra last column (``(out, in + 1)``)."""
        return torch.cat([W, b.unsqueeze(1)], dim=1)

    def robust_cholesky(self, M,
                        min_eig_eps=1e-6,
                        jitter_init=1e-8,
                        jitter_mult=10.0,
                        max_tries=6):
        """Return a factor ``L`` with ``L @ L.mT ~= M`` for a (near) SPD ``M``.

        ``M`` is symmetrized and non-finite entries are zeroed. Cholesky is
        then tried on ``M + jitter * I``, with ``jitter`` starting at
        ``jitter_init`` and multiplied by ``jitter_mult``, up to ``max_tries``
        times. If every attempt fails, the eigenvalues are clamped to at least
        ``min_eig_eps`` and Cholesky is retried on the rebuilt matrix. As a
        last resort, the non-triangular factor ``V @ diag(sqrt(eigvals))`` is
        returned.
        """
        dev = M.device
        dt = M.dtype

        M = 0.5 * (M + M.transpose(-1, -2))  # enforce exact symmetry

        if not torch.isfinite(M).all():
            M = torch.nan_to_num(M, nan=0.0, posinf=0.0, neginf=0.0)

        eye = torch.eye(M.shape[-1], device=dev, dtype=dt)
        jitter = jitter_init
        for _ in range(max_tries):
            try:
                return torch.linalg.cholesky(M + jitter * eye)
            except RuntimeError:  # torch.linalg.LinAlgError subclasses RuntimeError
                jitter *= jitter_mult

        # Jitter was not enough: clamp the spectrum to force positive definiteness.
        evals, evecs = torch.linalg.eigh(M)
        evals_clamped = torch.clamp(evals, min=min_eig_eps)
        M_fixed = (evecs * evals_clamped.unsqueeze(-2)) @ evecs.transpose(-1, -2)
        M_fixed = 0.5 * (M_fixed + M_fixed.transpose(-1, -2))

        try:
            return torch.linalg.cholesky(M_fixed)
        except RuntimeError:
            # Not lower-triangular, but still satisfies L @ L.mT == M_fixed.
            return evecs @ torch.diag_embed(torch.sqrt(evals_clamped))

    def tail_sum(self, S, n):
        """Return the sum of the last ``n`` entries of 1-D ``S``.

        Returns 0 when ``n <= 0``, and the full sum when ``n`` exceeds the
        length of ``S``.
        """
        # ``S[-0:]`` would select everything, so compute the start index explicitly.
        start = max(S.shape[0] - max(n, 0), 0)
        return S[start:].sum()

    @staticmethod
    def _truncate_svd(U, S, Vh, r):
        """Return the rank-``r`` slices ``(U_r, S_r, Vh_r)`` and their product."""
        U_r, S_r, Vh_r = U[:, :r], S[:r], Vh[:r, :]
        return U_r, S_r, Vh_r, (U_r * S_r) @ Vh_r

    def low_rank_approximation_qk(self, A: torch.Tensor, A2: torch.Tensor,
                                  X: torch.Tensor,
                                  S1: torch.Tensor, S1_i: torch.Tensor,
                                  r: int, name: str, rr: int, visualize: bool=False) -> tuple:
        """Rank-``r`` factorization of one head's query-key product ``A.T @ A2``.

        Both ``A`` (query) and ``A2`` (key) are truncated to rank ``r`` by SVD.
        Whichever truncation has the smaller Frobenius error on the weight
        columns (last column excluded, which holds the merged bias) is kept.
        The other matrix is projected onto that factor, so that
        ``query.T @ key`` approximates ``A.T @ A2``.

        Args:
            A: Per-head query weight with bias merged as the last column.
            A2: Per-head key weight with bias merged as the last column.
            X, S1, S1_i, name: Unused. Kept for interface compatibility.
            r: Target rank.
            rr: Rank used only by the ``visualize`` diagnostics.
            visualize: If true, also compute diagnostic reconstruction errors
                (including a combined ``A2.T @ A`` SVD). They are neither
                returned nor logged; inspect them under a debugger.

        Returns:
            ``(query, key)``, each with at most ``r`` rows and the same number
            of columns as ``A``.
        """
        U1, sv1, Vh1 = torch.linalg.svd(A, full_matrices=False)
        U2, sv2, Vh2 = torch.linalg.svd(A2, full_matrices=False)

        if visualize:
            # Diagnostics only: compare per-matrix and combined decompositions.
            comcat = A2.transpose(-1, -2) @ A
            U3, sv3, Vh3 = torch.linalg.svd(comcat, full_matrices=False)

            A_inv_v = torch.linalg.pinv(A[:, :-1])
            A2_inv_v = torch.linalg.pinv(A2[:, :-1].T)

            A_r_1_v = self._truncate_svd(U1, sv1, Vh1, rr)[3]
            A_r_2_v = self._truncate_svd(U2, sv2, Vh2, rr)[3]
            A_r_3_v = self._truncate_svd(U3, sv3, Vh3, rr)[3]

            error1_visual = torch.norm(A[:, :-1] - A_r_1_v[:, :-1])
            error2_visual = torch.norm(A2[:, :-1] - A_r_2_v[:, :-1])
            error3_visual = torch.norm(A[:, :-1] - A2_inv_v @ A_r_3_v[:-1, :-1])
            error4_visual = torch.norm(A2[:, :-1].T - A_r_3_v[:-1, :-1] @ A_inv_v)
            error_final = min(error1_visual, error2_visual)

        U_r_1, S_r_1, Vh_r_1, A_r_1 = self._truncate_svd(U1, sv1, Vh1, r)
        U_r_2, S_r_2, Vh_r_2, A_r_2 = self._truncate_svd(U2, sv2, Vh2, r)

        # Frobenius reconstruction error on the weight columns (bias excluded).
        error1 = torch.norm(A[:, :-1] - A_r_1[:, :-1])
        error2 = torch.norm(A2[:, :-1] - A_r_2[:, :-1])

        # Ties prefer the query factorization; a NaN error falls through to the
        # key factorization so both outputs are always bound.
        if error1 <= error2:
            query = Vh_r_1
            key = (U_r_1 * S_r_1).transpose(-1, -2) @ A2
        else:
            query = (U_r_2 * S_r_2).transpose(-1, -2) @ A
            key = Vh_r_2

        return query, key

    def low_rank_approximation_vo(self, A: torch.Tensor, A2: torch.Tensor,
                                  X: torch.Tensor,
                                  S1: torch.Tensor, S1_i: torch.Tensor,
                                  S2: torch.Tensor, r: int, name: str, rr:int, visualize: bool=False) -> tuple:
        """Rank-``r`` factorization of one head's value-output product ``A2 @ A``.

        Both ``A`` (value weight, bias merged as the last column) and ``A2``
        (output-projection slice for this head) are truncated to rank ``r`` by
        SVD. Whichever truncation has the smaller Frobenius error is kept; for
        ``A``, the bias column is excluded from the error. The other matrix is
        folded into that factor, so that ``output @ value`` approximates
        ``A2 @ A``.

        Args:
            A: Per-head value weight with bias merged as the last column.
            A2: Per-head output-projection weight (columns match ``A``'s rows).
            X, S1, S1_i, S2, name: Unused. Kept for interface compatibility.
            r: Target rank.
            rr: Rank used only by the ``visualize`` diagnostics.
            visualize: If true, also compute diagnostic reconstruction errors
                (including a combined ``A2 @ A`` SVD). They are neither
                returned nor logged; inspect them under a debugger.

        Returns:
            ``(value, output)``: ``value`` has at most ``r`` rows and ``A``'s
            columns; ``output`` has ``A2``'s rows and at most ``r`` columns.
        """
        U1, sv1, Vh1 = torch.linalg.svd(A, full_matrices=False)
        U2, sv2, Vh2 = torch.linalg.svd(A2, full_matrices=False)

        if visualize:
            # Diagnostics only: compare per-matrix and combined decompositions.
            comcat = A2 @ A
            U3, sv3, Vh3 = torch.linalg.svd(comcat, full_matrices=False)

            A_inv_v = torch.linalg.pinv(A[:, :-1])
            A2_inv_v = torch.linalg.pinv(A2)

            A_r_1_v = self._truncate_svd(U1, sv1, Vh1, rr)[3]
            A_r_2_v = self._truncate_svd(U2, sv2, Vh2, rr)[3]
            A_r_3_v = self._truncate_svd(U3, sv3, Vh3, rr)[3]

            error1_visual = torch.norm(A[:, :-1] - A_r_1_v[:, :-1])
            error2_visual = torch.norm(A2 - A_r_2_v)
            error3_visual = torch.norm(A[:, :-1] - A2_inv_v @ A_r_3_v[:, :-1])
            error4_visual = torch.norm(A2 - A_r_3_v[:, :-1] @ A_inv_v)

        U_r_1, S_r_1, Vh_r_1, A_r_1 = self._truncate_svd(U1, sv1, Vh1, r)
        U_r_2, S_r_2, Vh_r_2, A_r_2 = self._truncate_svd(U2, sv2, Vh2, r)

        error1 = torch.norm(A[:, :-1] - A_r_1[:, :-1])  # bias column excluded
        error2 = torch.norm(A2 - A_r_2)

        # Ties prefer the value factorization; a NaN error falls through to the
        # output factorization so both outputs are always bound.
        if error1 <= error2:
            value = Vh_r_1
            output = A2 @ (U_r_1 * S_r_1)
        else:
            value = Vh_r_2 @ A
            output = U_r_2 * S_r_2

        return value, output

    def _factorize_matrix_unisvd(self, matrix1, matrix2, bias1, bias2, name1, name2, eq_rank1, eq_rank2, rank1, rank1_2, rank2, rank2_2, dev,
                           head_dim=64, num_heads=12, num_layers=12):
        """Jointly factorize one attention block (fused QKV + output projection).

        ``matrix1`` is split row-wise into Q, K and V thirds. Each third gets
        its bias merged as an extra column and is reshaped to
        ``(num_heads, head_dim, in + 1)``. ``matrix2`` is reshaped into
        per-head ``(dim, head_dim)`` slices, where ``dim = head_dim * num_heads``.
        Per-head ranks are interpolated linearly over ``num_layers`` layers
        (``rank1 -> rank1_2`` for QK, ``rank2 -> rank2_2`` for VO). The value
        for this block is selected by the integer in the second dot-separated
        field of ``name1`` (presumably the block index, e.g.
        ``blocks.<i>.attn.qkv`` — TODO confirm).

        ``whitening`` results for ``name1``/``name2`` are computed and cast to
        float, then passed to the per-head routines, which currently ignore them.

        Args:
            matrix1, bias1: Fused QKV weight ``(3 * dim, in)`` and bias
                ``(3 * dim,)``. A ``None`` bias is treated as zeros.
            matrix2, bias2: Output-projection weight ``(dim, dim)`` and bias. A
                ``None`` bias is treated as zeros.
            name1, name2: Layer names used for ``whitening`` / ``x_dict``.
            eq_rank1: Equivalent rank. It is also the default and upper bound
                for ``rank1``. ``eq_rank2`` is unused.
            rank1, rank1_2: First/last-layer QK head rank (``rank1 == 0`` means
                ``eq_rank1``).
            rank2, rank2_2: First/last-layer VO head rank.
            dev: Device on which the factorization is computed.

        Returns:
            ``FactorizedMatrix_UniSVD`` with CPU float32 tensors. Returns
            ``None`` after printing a warning when ``rank1 > eq_rank1``.
        """
        raw_profile = self.scaling_dict
        x_dict = self.x_dict

        scale_diag1, scale_diag_inv1 = whitening(dev, raw_profile, name1)
        scale_diag1, scale_diag_inv1 = scale_diag1.float(), scale_diag_inv1.float()

        scale_diag2, scale_diag_inv2 = whitening(dev, raw_profile, name2)
        scale_diag2, scale_diag_inv2 = scale_diag2.float(), scale_diag_inv2.float()

        if rank1 == 0:
            rank1 = eq_rank1
        elif rank1 > eq_rank1:
            print(f"Warning: {name1} rank is larger than equivalent rank!")
            return

        qkv_w = matrix1.to(dev)
        o_w = matrix2.to(dev)

        # A missing bias is mathematically identical to a zero bias.
        if bias1 is None:
            bias1 = qkv_w.new_zeros(qkv_w.shape[0])
        if bias2 is None:
            bias2 = o_w.new_zeros(o_w.shape[0])
        o_b = bias2

        q_w, k_w, v_w = qkv_w.chunk(3, dim=0)
        # The bias must live on the weights' device for merge_bias' torch.cat.
        q_b, k_b, v_b = bias1.to(dev).chunk(3, dim=0)

        q_w = self.merge_bias(q_w, q_b)
        k_w = self.merge_bias(k_w, k_b)
        v_w = self.merge_bias(v_w, v_b)

        in_feats = q_w.shape[1]  # in_features + 1 (merged bias column)
        dim = int(head_dim * num_heads)

        wq = q_w.reshape(num_heads, head_dim, in_feats)
        wk = k_w.reshape(num_heads, head_dim, in_feats)
        wvi = v_w.reshape(num_heads, head_dim, in_feats)
        # (dim, num_heads * head_dim) -> (num_heads, dim, head_dim)
        woi = o_w.reshape(dim, num_heads, head_dim).transpose(0, 1)

        layer_idx = int(name1.split(".")[1])
        qk_head_dim = self.generate_steps(rank1, rank1_2, num_layers)[layer_idx]
        vo_head_dim = self.generate_steps(rank2, rank2_2, num_layers)[layer_idx]

        x1 = x_dict[name1]

        qk_pairs = [
            self.low_rank_approximation_qk(
                wq[i], wk[i], x1, scale_diag1, scale_diag_inv1,
                qk_head_dim, name1, rr=24, visualize=False,
            )
            for i in range(num_heads)
        ]
        wq = torch.cat([q for q, _ in qk_pairs], 0)  # (rank * heads, in_features + 1)
        wk = torch.cat([k for _, k in qk_pairs], 0)  # (rank * heads, in_features + 1)

        # Split the merged bias column back out.
        bq, wq = wq[:, -1], wq[:, :-1]
        bk, wk = wk[:, -1], wk[:, :-1]

        vo_pairs = [
            self.low_rank_approximation_vo(
                wvi[i], woi[i], x1, scale_diag1, scale_diag_inv1, scale_diag2,
                vo_head_dim, name1, rr=24, visualize=False,
            )
            for i in range(num_heads)
        ]
        wv = torch.cat([v for v, _ in vo_pairs], 0)  # (rank * heads, in_features + 1)
        wo = torch.cat([o for _, o in vo_pairs], 1)  # (dim, rank * heads)

        bv, wv = wv[:, -1], wv[:, :-1]

        return FactorizedMatrix_UniSVD(
            q_w=wq.cpu().float(),
            k_w=wk.cpu().float(),
            v_w=wv.cpu().float(),
            o_w=wo.cpu().float(),
            q_b=bq.cpu().float(),
            k_b=bk.cpu().float(),
            v_b=bv.cpu().float(),
            o_b=o_b.cpu().float(),
            eq_rank=eq_rank1,  # Equivalent rank
            active_rank=rank1,  # Active rank
            qk_head_dim=qk_head_dim,
            vo_head_dim=vo_head_dim,
        )

