import gc
import torch

from ._interface import BaseFactorization_UniSVD
from ._interface import FactorizedMatrix, FactorizedMatrix_UniSVD
from ._interface import Hookstuff
from .svd_llm import whitening


def lw_shrinkage(S):
    """Shrink a Gram matrix towards a scaled identity (Ledoit-Wolf style).

    The shrinkage target is the identity scaled by the mean of the diagonal
    of ``S``. The result is the convex combination
    ``(1 - alpha) * S + alpha * target`` where ``alpha`` is clipped to
    ``[0, 1]``.

    Args:
        S: Square 2-D tensor; the code reads ``S.shape[0]`` and takes
            ``torch.diag`` / ``torch.trace`` of it.

    Returns:
        Tensor with the same shape as ``S``, located on the device of ``S``.
    """
    n_features = S.shape[0]

    # Build the target on the device of ``S`` rather than a hard-coded
    # "cuda", so CPU inputs and tensors on a non-default GPU work as well.
    target = torch.eye(n_features, device=S.device) * torch.mean(torch.diag(S))

    # Approximate the shrinkage intensity. When ``S`` already equals the
    # target the ratio is 0/0; map that NaN to "no shrinkage" so it does not
    # poison every entry of the result.
    raw_intensity = (
        1e10 * (torch.trace(S) - torch.trace(target)) / (torch.norm(S - target) ** 2)
    )
    shrinkage_intensity = torch.clip(torch.nan_to_num(raw_intensity, nan=0.0), 0, 1)

    # Blend the input with the target.
    S_shrinked = (1 - shrinkage_intensity) * S + shrinkage_intensity * target

    return S_shrinked


class Flar_Hook(Hookstuff):
    """Hook provider that records layer inputs and accumulates their Gram matrices."""

    def _hook_fn(self, layer_name, last_feat=False):
        """Return the forward-hook callable bound to ``layer_name``.

        Args:
            layer_name: Key under which results are stored in ``self.x_dict``,
                ``self.input_shape`` and ``self.profile``.
            last_feat: When true, the hook skips the Gram-matrix update and
                only stores the input on ``self.model.last_feat`` for layers
                whose name contains ``"head"``.

        Returns:
            A callable with the ``(module, input, output)`` hook signature.
        """

        def get_scaling_mat(module, input, output):
            # Work on a detached float64 copy of the first positional input.
            acts = input[0].detach().clone().double()
            n_dims = acts.dim()
            if n_dims > 3:
                # Collapse all middle axes so the tensor becomes 3-D.
                acts = acts.reshape(acts.shape[0], -1, acts.shape[-1])
            elif n_dims == 2:
                # Add a leading axis so the tensor becomes 3-D.
                acts = acts.unsqueeze(0)

            # The most recent input is always kept on CPU, whatever the mode.
            self.x_dict[layer_name] = acts.cpu()

            if self.dump_shape:
                # Shape-dumping mode: store the input shape followed by the
                # module's out_features and a trailing 0, then stop.
                shape_entry = list(acts.shape)
                self.input_shape[layer_name] = shape_entry
                shape_entry.extend([module.out_features, 0])
                return

            if last_feat:
                if "head" in layer_name:
                    self.model.last_feat = acts.clone()
                return

            # Batch-averaged x^T x, shrunk, then moved to CPU.
            gram = lw_shrinkage(
                torch.mean(torch.matmul(acts.transpose(1, 2), acts), dim=0)
            ).cpu()

            if layer_name in self.profile:
                # Later batches are accumulated in place.
                self.profile[layer_name] += gram
            else:
                # First batch seen for this layer.
                self.profile[layer_name] = gram

        return get_scaling_mat


class FLAR_SVDFactorization(BaseFactorization_UniSVD):
    def compute_scaling(self, model, name_omit, calib_data, mixup_fn, white_list=None):
        """Profile the model on calibration data and record per-layer statistics.

        Runs ``calib_data`` through ``model`` with ``Flar_Hook`` hooks attached,
        copies the collected profiles into ``self.scaling_dict`` and the
        captured inputs into ``self.x_dict``, then runs one dummy batch with a
        second ``Flar_Hook`` to fill ``self.input_shapes``.

        Args:
            model: Module to profile; it is switched to eval mode and moved to
                the current CUDA device.
            name_omit: Forwarded unchanged to ``Flar_Hook``.
            calib_data: Iterable yielding ``(data, target)`` pairs.
            mixup_fn: Callable applied to each ``(data, target)`` pair; its
                first return value is fed to the model.
            white_list: Forwarded to ``Flar_Hook``. ``None`` (the default) is
                treated as an empty list.
        """
        # Use a fresh list per call instead of one shared mutable default.
        if white_list is None:
            white_list = []

        print("\nObtaining activations for FLAR SVD")
        dev = torch.device(torch.cuda.current_device())

        with torch.cuda.device(dev):
            torch.cuda.empty_cache()

        model = model.eval().to(dev)

        extractor = Flar_Hook(model, name_omit, white_list=white_list)
        extractor.attach_hooks()
        try:
            with torch.no_grad():
                for data, target in calib_data:
                    model_inps, targets = mixup_fn(data, target)
                    model_inps = model_inps.to(dev)
                    model(model_inps)
                    del model_inps, targets
        finally:
            # Detach the hooks even when a forward pass raises, so the model
            # is not left instrumented.
            extractor.clear_hooks()

        for key, value in extractor.profile.items():
            self.scaling_dict[key] = value

        for key, value in extractor.x_dict.items():
            self.x_dict[key] = value

        del extractor

        shapes_getter = Flar_Hook(model, name_omit, True, white_list=white_list)
        shapes_getter.attach_hooks()
        # NOTE(review): the probe batch shape (20, 3, 224, 224) is hard-coded;
        # confirm it matches the inputs the model expects.
        dummy_input = torch.randn(20, 3, 224, 224).to(dev)
        try:
            # Only shapes are needed here, so no autograd graph should be
            # built for this forward pass.
            with torch.no_grad():
                model(dummy_input)
        finally:
            shapes_getter.clear_hooks()

        for key, value in shapes_getter.input_shape.items():
            self.input_shapes[key] = value
        del shapes_getter, dummy_input

        gc.collect()

        return

    def _factorize_matrix(self, matrix, name, eq_rank, rank, dev):
        """Factorize ``matrix`` into a truncated pair of factors via scaled SVD.

        The matrix is multiplied by the scaling matrix that ``whitening``
        returns for ``name``, decomposed with SVD, and the square root of the
        singular values is folded into each side. The right factor is
        multiplied by the inverse scaling matrix before truncation.

        Args:
            matrix: Tensor to factorize; moved to ``dev`` before use.
            name: Key passed to ``whitening`` together with
                ``self.scaling_dict``.
            eq_rank: Upper bound for ``rank``; also used when ``rank`` is 0.
            rank: Number of components to keep; 0 means "use ``eq_rank``".
            dev: Device passed to ``whitening`` and used for the computation.

        Returns:
            A ``FactorizedMatrix`` whose factors are float32 CPU tensors, or
            ``None`` (after printing a warning) when ``rank > eq_rank``.
        """
        whiten, whiten_inv = whitening(dev, self.scaling_dict, name)
        whiten = whiten.float()
        whiten_inv = whiten_inv.float()

        if rank == 0:
            # Zero is the sentinel for "fall back to the equivalent rank".
            rank = eq_rank
        elif rank > eq_rank:
            print(f"Warning: {name} rank is larger than equivalent rank!")
            return None

        scaled = torch.matmul(matrix.to(dev), whiten)
        u, s, vh = torch.linalg.svd(scaled, full_matrices=False)

        # Each factor receives sqrt(sigma), so their product carries sigma once.
        sqrt_sigma = torch.sqrt(torch.diag(s))
        left = (u @ sqrt_sigma)[:, :rank]
        right = (sqrt_sigma @ torch.matmul(vh, whiten_inv))[:rank, :]

        return FactorizedMatrix(
            mat_l=left.cpu().float(),  # u * sqrt(sigma), first `rank` columns
            mat_r=right.cpu().float(),  # sqrt(sigma) * vh * inverse scaling, first `rank` rows
            eq_rank=eq_rank,
            active_rank=rank,
        )
    
    def generate_steps(self, a: int, b: int, steps: int = 12) -> list:
        step_size = (b - a) / (steps - 1)
        return [int(round(a + step_size * i)) for i in range(steps)]
    
    def merge_bias(self, W, b):
        return torch.cat([W, b.unsqueeze(1)], dim=1)

    def robust_cholesky(self, M,
                        min_eig_eps=1e-6,
                        jitter_init=1e-8,
                        jitter_mult=10.0,
                        max_tries=6):

        dev = M.device
        dt = M.dtype

        M = 0.5 * (M + M.transpose(-1, -2))

        if not torch.isfinite(M).all():
            M = torch.nan_to_num(M, nan=0.0, posinf=0.0, neginf=0.0)

        I = torch.eye(M.shape[-1], device=dev, dtype=dt)
        jitter = jitter_init
        for _ in range(max_tries):
            try:
                return torch.linalg.cholesky(M + jitter * I)
            except RuntimeError:
                jitter *= jitter_mult

        evals, evecs = torch.linalg.eigh(M) 
        evals_clamped = torch.clamp(evals, min=min_eig_eps)
        M_fixed = (evecs * evals_clamped.unsqueeze(-2)) @ evecs.transpose(-1, -2)
        M_fixed = 0.5 * (M_fixed + M_fixed.transpose(-1, -2))  

        try:
            return torch.linalg.cholesky(M_fixed)
        except RuntimeError:
            L_alt = evecs @ torch.diag_embed(torch.sqrt(evals_clamped))
            return L_alt

    def tail_sum(self, S, n):
        return S[-n:].sum()

    def low_rank_approximation_qk(self, A: torch.Tensor, A2: torch.Tensor,
                                       X: torch.Tensor,
                                       S1: torch.Tensor, S1_i: torch.Tensor,
                                       r: int, name: str, rr: int, visualize: bool=False) -> torch.Tensor:

        U1, S1, Vh1 = torch.linalg.svd(A, full_matrices=False)
        U2, S2, Vh2 = torch.linalg.svd(A2, full_matrices=False)

        if visualize:
            comcat = A2.transpose(-1, -2) @ A
            U3, S3, Vh3 = torch.linalg.svd(comcat, full_matrices=False)

            A_inv_v = torch.linalg.pinv(A[:,:-1])
            A2_inv_v = torch.linalg.pinv(A2[:,:-1].T)

            U_r_1_v = U1[:, :rr]  # (m, r)
            S_r_1_v = S1[:rr]  # (r,)
            Vh_r_1_v = Vh1[:rr, :]  # (r, n)

            U_r_2_v = U2[:, :rr]  # (m, r)
            S_r_2_v = S2[:rr]  # (r,)
            Vh_r_2_v = Vh2[:rr, :]  # (r, n)

            U_r_3_v = U3[:, :rr]  # (m, r)
            S_r_3_v = S3[:rr]  # (r,)
            Vh_r_3_v = Vh3[:rr, :]  # (r, n)

            A_r_1_v = U_r_1_v * S_r_1_v @ Vh_r_1_v
            A_r_2_v = U_r_2_v * S_r_2_v @ Vh_r_2_v
            A_r_3_v = U_r_3_v * S_r_3_v @ Vh_r_3_v

            error1_visual = torch.norm(A[:,:-1]- A_r_1_v[:,:-1] )  # Frobenius norm

            error2_visual = torch.norm(A2[:,:-1]- A_r_2_v[:,:-1] )  # Frobenius norm

            error3_visual = torch.norm(A[:,:-1]- A2_inv_v @ A_r_3_v[:-1,:-1])

            error4_visual = torch.norm(A2[:,:-1].T- A_r_3_v[:-1,:-1] @ A_inv_v)

            error_final = min(error1_visual, error2_visual)

        # For UniSVD
        U_r_1 = U1[:, :r]  # (m, r)
        S_r_1 = S1[:r]  # (r,)
        Vh_r_1 = Vh1[:r, :]  # (r, n)
        A_r_1 = U_r_1 * S_r_1 @ Vh_r_1

        U_r_2 = U2[:, :r]  # (m, r)
        S_r_2 = S2[:r]  # (r,)
        Vh_r_2 = Vh2[:r, :]  # (r, n)
        A_r_2 = U_r_2 * S_r_2 @ Vh_r_2

        # For combined decomposition
        # comcat = A2.transpose(-1, -2) @ A
        # U3, S3, Vh3 = torch.linalg.svd(comcat, full_matrices=False)
        # U_r_3 = U3[:, :r]
        # S_r_3 = S3[:r]  # (r,)
        # Vh_r_3 = Vh3[:r, :]  # (r, n)
        # A_r_3 = U_r_3 * S_r_3 @ Vh_r_3

        error1 = torch.norm(A[:,:-1] - A_r_1[:,:-1] )  # Frobenius norm
        error2 = torch.norm(A2[:,:-1] -  A_r_2[:,:-1] )  # Frobenius norm

        min_error = min(error1, error2) 
        if min_error == error1:
            query = Vh_r_1
            key_0 = (U_r_1 * S_r_1).transpose(-1, -2)
            key = key_0 @ A2
        elif min_error == error2:
            query_0 = (U_r_2 * S_r_2).transpose(-1, -2)
            query = query_0 @ A
            key = Vh_r_2

        return query, key

    def low_rank_approximation_vo(self, A: torch.Tensor, A2: torch.Tensor,
                                      X: torch.Tensor,
                                      S1: torch.Tensor, S1_i: torch.Tensor,
                                      S2: torch.Tensor, r: int, name: str, rr:int, visualize: bool=False) -> torch.Tensor:

        U1, S1, Vh1 = torch.linalg.svd(A, full_matrices=False)
        U2, S2, Vh2 = torch.linalg.svd(A2, full_matrices=False)

        if visualize:
            # combined matrices
            comcat = A2 @ A
        
            U3, S3, Vh3 = torch.linalg.svd(comcat, full_matrices=False)
        
            # Calculate for comparison with the combined decomposition method

            A_inv_v = torch.linalg.pinv(A[:,:-1])
            A2_inv_v = torch.linalg.pinv(A2)

            U_r_1_v = U1[:, :rr]  # (m, r)
            S_r_1_v = S1[:rr]  # (r,)
            Vh_r_1_v = Vh1[:rr, :]  # (r, n)

            U_r_2_v = U2[:, :rr]  # (m, r)
            S_r_2_v = S2[:rr]  # (r,)
            Vh_r_2_v = Vh2[:rr, :]  # (r, n)

            U_r_3_v = U3[:, :rr]  # (m, r)
            S_r_3_v = S3[:rr]  # (r,)
            Vh_r_3_v = Vh3[:rr, :]  # (r, n)

            A_r_1_v = U_r_1_v * S_r_1_v @ Vh_r_1_v
            A_r_2_v = U_r_2_v * S_r_2_v @ Vh_r_2_v
            A_r_3_v = U_r_3_v * S_r_3_v @ Vh_r_3_v

            error1_visual = torch.norm(A[:,:-1] - A_r_1_v[:,:-1] ) 

            error2_visual = torch.norm(A2 - A_r_2_v )  

            error3_visual = torch.norm(A[:,:-1] - A2_inv_v @ A_r_3_v[:,:-1])

            error4_visual = torch.norm(A2 - A_r_3_v[:,:-1] @ A_inv_v)
        
        ###########################################

        # For UniSVD
        U_r_1 = U1[:, :r]  # (m, r)
        S_r_1 = S1[:r]  # (r,)
        Vh_r_1 = Vh1[:r, :]  # (r, n)
        A_r_1 = U_r_1 * S_r_1 @ Vh_r_1

        U_r_2 = U2[:, :r]  # (m, r)
        S_r_2 = S2[:r]  # (r,)
        Vh_r_2 = Vh2[:r, :]  # (r, n)
        A_r_2 = U_r_2 * S_r_2 @ Vh_r_2

        # For combined decomposition
        # U_r_3 = U3[:, :r]
        # S_r_3 = S3[:r]  # (r,)
        # Vh_r_3 = Vh3[:r, :]  # (r, n)
        # A_r_3 = U_r_3 * S_r_3 @ Vh_r_3

        error1 = torch.norm(A[:,:-1] - A_r_1[:,:-1])  # Frobenius norm
        error2 = torch.norm(A2 - A_r_2)  # Frobenius norm
        
        min_error = min(error1, error2)

        if min_error == error1:
            value = Vh_r_1
            output = A2 @ (U_r_1 * S_r_1)

        elif min_error == error2:
            value = Vh_r_2 @ A
            output = U_r_2 * S_r_2

        return value, output
    
    
    def _factorize_matrix_unisvd(self, matrix1, matrix2, bias1, bias2, name1, name2, eq_rank1, eq_rank2, rank1, rank1_2, rank2, rank2_2, dev,
                           head_dim=64, num_heads=12, num_layers=12):
        """Jointly factorize a fused QKV projection and its output projection.

        The fused weight and bias are split into query/key/value parts, the
        bias is appended to each weight as an extra column, and every head is
        compressed with ``low_rank_approximation_qk`` (query/key) and
        ``low_rank_approximation_vo`` (value/output). The per-head ranks come
        from linear schedules over the layers, indexed by the layer number
        parsed from ``name1``.

        Args:
            matrix1: Fused QKV weight; split into three equal chunks on dim 0.
            matrix2: Output projection weight.
            bias1: Fused QKV bias; split into three equal chunks on dim 0.
            bias2: Output projection bias; returned unchanged (as float32 on
                CPU).
            name1: Layer name of the QKV projection. Its second
                dot-separated field must be the integer layer index.
            name2: Layer name of the output projection.
            eq_rank1: Upper bound for ``rank1``; used when ``rank1`` is 0.
            eq_rank2: Not used by this method.
            rank1: Query/key rank at the first layer; 0 means ``eq_rank1``.
            rank1_2: Query/key rank at the last layer.
            rank2: Value/output rank at the first layer.
            rank2_2: Value/output rank at the last layer.
            dev: Device used for the computation.
            head_dim: Per-head feature size.
            num_heads: Number of attention heads.
            num_layers: Number of entries in each rank schedule.

        Returns:
            A ``FactorizedMatrix_UniSVD`` with float32 CPU tensors, or ``None``
            (after printing a warning) when ``rank1 > eq_rank1``.
        """
        raw_profile = self.scaling_dict
        x_dict = self.x_dict

        # NOTE(review): these scaling matrices are only forwarded to the
        # low-rank helpers below, which do not read them in the code visible
        # in this file. The whitening calls are kept so any errors they raise
        # still surface here.
        scale_diag1, scale_diag_inv1 = whitening(dev, raw_profile, name1)
        scale_diag1, scale_diag_inv1 = scale_diag1.float(), scale_diag_inv1.float()

        scale_diag2, _ = whitening(dev, raw_profile, name2)
        scale_diag2 = scale_diag2.float()

        if rank1 == 0:
            rank1 = eq_rank1
        elif rank1 > eq_rank1:
            print(f"Warning: {name1} rank is larger than equivalent rank!")
            return None

        qkv_w = matrix1.to(dev)
        o_w = matrix2.to(dev)
        # The weight was just moved to ``dev``; move the bias with it so the
        # concatenation in merge_bias does not mix devices.
        qkv_b = bias1.to(qkv_w.device)
        o_b = bias2

        q_w, k_w, v_w = qkv_w.chunk(3, dim=0)
        q_b, k_b, v_b = qkv_b.chunk(3, dim=0)

        # Append each bias as an extra trailing column of its weight.
        q_w = self.merge_bias(q_w, q_b)
        k_w = self.merge_bias(k_w, k_b)
        v_w = self.merge_bias(v_w, v_b)

        in_feats = q_w.shape[1]  # includes the appended bias column
        dim = int(head_dim * num_heads)

        wq = q_w.reshape(num_heads, head_dim, in_feats)
        wk = k_w.reshape(num_heads, head_dim, in_feats)
        wvi = v_w.reshape(num_heads, head_dim, in_feats)
        # (dim, heads, head_dim) -> (heads, dim, head_dim): one slice per head.
        woi = o_w.reshape(dim, num_heads, head_dim).transpose(0, 1)

        # Per-layer rank schedules; this layer's entry is picked by its index.
        layer_idx = int(name1.split(".")[1])
        qk_head_dim = self.generate_steps(rank1, rank1_2, num_layers)[layer_idx]
        vo_head_dim = self.generate_steps(rank2, rank2_2, num_layers)[layer_idx]

        wq_l = []
        wk_r = []
        for i in range(num_heads):
            q_i, k_i = self.low_rank_approximation_qk(
                wq[i, :, :], wk[i, :, :], x_dict[name1],
                scale_diag1, scale_diag_inv1, qk_head_dim, name1,
                rr=24, visualize=False,
            )
            wq_l.append(q_i)
            wk_r.append(k_i)

        # Stack the heads on dim 0, then split the bias column off again.
        wq = torch.cat(wq_l, 0)
        wk = torch.cat(wk_r, 0)
        bq = wq[:, -1]
        bk = wk[:, -1]
        wq = wq[:, :-1]
        wk = wk[:, :-1]

        wv_l = []
        wo_r = []
        for i in range(num_heads):
            v_i, o_i = self.low_rank_approximation_vo(
                wvi[i, :, :], woi[i, :, :], x_dict[name1],
                scale_diag1, scale_diag_inv1, scale_diag2, vo_head_dim, name1,
                rr=24, visualize=False,
            )
            wv_l.append(v_i)
            wo_r.append(o_i)

        # Value factors stack on dim 0, output factors on dim 1.
        wv = torch.cat(wv_l, 0)
        wo = torch.cat(wo_r, 1)
        bv = wv[:, -1]
        wv = wv[:, :-1]

        return FactorizedMatrix_UniSVD(
            q_w=wq.cpu().float(),
            k_w=wk.cpu().float(),
            v_w=wv.cpu().float(),
            o_w=wo.cpu().float(),
            q_b=bq.cpu().float(),
            k_b=bk.cpu().float(),
            v_b=bv.cpu().float(),
            o_b=o_b.cpu().float(),
            eq_rank=eq_rank1,  # Equivalent rank
            active_rank=rank1,  # Active rank
            qk_head_dim=qk_head_dim,
            vo_head_dim=vo_head_dim,
        )

