import gc
import sys
from tqdm import tqdm

import torch
import torch.nn as nn

from ._interface import BaseFactorization
from ._interface import FactorizedMatrix
from ._interface import Hookstuff

@torch.no_grad()
def whitening(dev, raw_scale, name):
    """
    Build a whitening / de-whitening matrix pair from a stored scale matrix.

    The tensor found at ``raw_scale[name]`` is moved to ``dev``, promoted to
    float64 and decomposed with SVD as ``U @ diag(S) @ Vh``. SVD is used
    rather than a Cholesky factorization so that an input which is not
    strictly positive definite can still be processed; a small ``eps`` is
    added to the singular values so the inverse square root stays finite.

    Args:
        dev: Device on which the decomposition runs and on which both
            results are returned.
        raw_scale: Mapping from layer name to a 2-D tensor.
            NOTE(review): presumably the symmetric activation Gram matrix
            accumulated during calibration -- confirm against the caller.
        name: Key of the entry in ``raw_scale`` to process.

    Returns:
        Tuple ``(whitening_matrix, dewhitening_matrix)`` of float32 tensors:
        ``U @ diag(sqrt(S + eps))`` and ``diag(rsqrt(S + eps)) @ U.T``.
        For a square input their product is the identity up to rounding.

    Raises:
        SystemExit: With exit status 1 when the decomposition fails, after
            printing the failure reason. A non-zero status is used so the
            failure is visible to whatever launched the process.
    """
    scale = raw_scale[name].double().to(dev)

    try:
        # scale = U @ diag(S) @ Vh; Vh is not needed below.
        left_vecs, sing_vals, _ = torch.linalg.svd(scale, full_matrices=False)

        # Regularize small singular values to avoid division by zero.
        eps = 1e-6
        regularized = sing_vals + eps

        # Multiplying by a diagonal matrix is a per-column (or per-row)
        # scaling, so broadcast instead of materializing diag(...).
        whitening_matrix = left_vecs * torch.sqrt(regularized)
        dewhitening_matrix = torch.rsqrt(regularized).unsqueeze(1) * left_vecs.T

    except Exception as e:
        print(f"Warning: SVD failed for {name} - {e}")
        # Exit with a failure status; a bare sys.exit() would report success.
        sys.exit(1)

    # Drop the float64 intermediates before allocating the float32 copies.
    del scale, left_vecs, sing_vals, regularized

    wmf, dwmf = whitening_matrix.float(), dewhitening_matrix.float()
    del whitening_matrix, dewhitening_matrix

    return wmf, dwmf


class SVD_LLMV2_Hook(Hookstuff):
    """
    Hook provider that accumulates, per layer, the sum of ``x^T @ x`` over
    every input ``x`` the layer receives.

    NOTE(review): the callback signature ``(module, input, output)`` matches
    a PyTorch forward hook; the actual registration is done by ``Hookstuff``
    and is not visible here.
    """

    def _hook_fn(self, layer_name, last_feat=False):
        """
        Create the callback for the layer called ``layer_name``.

        Depending on state, the callback does exactly one of:
          * ``self.dump_shape`` truthy: store the (normalized) input shape
            followed by ``module.out_features`` and ``0`` in
            ``self.input_shape[layer_name]``.
          * ``last_feat`` truthy: store a copy of the input on
            ``self.model.last_feat`` if ``"head"`` occurs in the layer name.
          * otherwise: add the batch-summed ``x^T @ x`` to
            ``self.profile[layer_name]``.
        """
        def get_scaling_mat(module, input, output):
            acts = input[0].detach().float()

            # Normalize the input to three dimensions.
            n_dims = acts.dim()
            if n_dims > 3:      # e.g. for convnext
                acts = acts.reshape(acts.shape[0], -1, acts.shape[-1])
            elif n_dims == 2:   # e.g. for mamba/opt
                acts = acts.unsqueeze(0)

            # Shape-recording mode: no statistics are gathered.
            if self.dump_shape:
                self.input_shape[layer_name] = list(acts.shape)
                self.input_shape[layer_name].extend([module.out_features, 0])
                return

            # Feature-capture mode: only layers with "head" in the name.
            if last_feat:
                if "head" in layer_name:
                    self.model.last_feat = acts.clone().detach()
                return

            # Per-sample x^T @ x, reduced over the leading (batch) dimension.
            gram = torch.matmul(acts.transpose(1, 2), acts).sum(dim=0)

            if layer_name in self.profile:
                self.profile[layer_name] += gram
            else:  # First run through each layer
                self.profile[layer_name] = gram

            # Release local references, then hand cached blocks back.
            del acts, gram, output
            torch.cuda.empty_cache()

        return get_scaling_mat


class SVD_LLMV2Factorization(BaseFactorization):
    """
    Low-rank factorization that whitens each weight matrix with statistics
    gathered from calibration data before applying a truncated SVD.
    """

    @property
    def post_search_calibration(self):
        """Return ``False`` when ``_do_post_calibration`` is ``"default"``, else its value."""
        if self._do_post_calibration == "default":
            return False
        return self._do_post_calibration

    def _compute_scaling(self, model, hook_module, name_prefix, calib_data, name_omit, mixup_fn=None, white_list=None, tqdm_message="Gathering "):
        """
        Run the calibration data through ``model`` and store the statistics
        collected by ``SVD_LLMV2_Hook`` in ``self.scaling_dict``.

        Args:
            model: Callable model that is fed the calibration batches.
            hook_module: Module handed to the hook extractor as ``model``.
            name_prefix: Forwarded to the hook extractor.
            calib_data: Iterable of calibration batches. When ``self.vision``
                is truthy each item is a ``(data, target)`` pair; otherwise
                each item is a mapping whose values are moved to
                ``self.dev`` and passed to ``model`` as keyword arguments.
            name_omit: Forwarded to the hook extractor.
            mixup_fn: Optional callable applied to ``(data, target)`` in the
                vision path.
            white_list: Forwarded to the hook extractor; ``None`` (default)
                is treated as an empty list.
            tqdm_message: Prefix of the progress-bar description.

        Returns:
            None. Results are written to ``self.scaling_dict`` and, in the
            vision path, ``self.input_shapes``.
        """
        torch.cuda.empty_cache()

        extractor = SVD_LLMV2_Hook(
            model=hook_module,
            name_omit=name_omit, dump_shape=False,
            name_prefix=name_prefix,
            # Fresh list per call instead of a shared mutable default.
            white_list=[] if white_list is None else white_list)
        extractor.attach_hooks()
        progress_desc = tqdm_message + "(Activations for SVD-LLM)"

        # Detach the hooks even if a forward pass raises, so a failed
        # calibration run does not leave them installed on the model.
        try:
            with torch.no_grad():
                if self.vision:
                    for data, target in tqdm(calib_data, desc=progress_desc):
                        model_inps, targets = mixup_fn(data, target) if mixup_fn is not None else (data, target)
                        model_inps = model_inps.to(self.dev)
                        model(model_inps)
                        del model_inps, targets
                    # One extra pass with the extractor in shape-recording
                    # mode to capture per-layer input shapes.
                    extractor.dump_shape = True
                    dummy_input = torch.randn(20, 3, 224, 224).to(self.dev)
                    model(dummy_input)
                    for key, value in extractor.input_shape.items():
                        self.input_shapes[key] = value
                    del dummy_input
                else:
                    for batch in tqdm(calib_data, desc=progress_desc):
                        batch = {k: v.to(self.dev)
                                 for k, v in batch.items()}
                        model(**batch)
                        del batch
        finally:
            extractor.clear_hooks()

        for key, value in extractor.profile.items():
            self.scaling_dict[key] = value

        # Clear the extractor's own dictionary explicitly so it no longer
        # references the collected tensors.
        extractor.profile.clear()

        del extractor
        # Force a collection to clean up the cyclic hook references.
        gc.collect()
        torch.cuda.empty_cache()

    def _factorize_cleanup(self, name):
        """Drop the stored statistics for ``name`` (if any) and release cached memory."""
        if name in self.scaling_dict:
            del self.scaling_dict[name]
        torch.cuda.empty_cache()
        gc.collect()

    @torch.no_grad()
    def _factorize_matrix(self, matrix, eq_rank, rank, name, dev, verbose=False):
        """
        Factorize ``matrix`` into two rank-limited factors using the
        whitening pair derived from ``self.scaling_dict[name]``.

        The matrix is multiplied by the whitening matrix, decomposed with
        SVD, truncated to ``rank`` components, and the right factor is
        multiplied by the de-whitening matrix. Each factor carries the
        square root of the retained singular values.

        Args:
            matrix: Weight tensor to factorize.
                NOTE(review): assumed 2-D with its last dimension equal to
                the size of the stored statistics -- confirm with caller.
            eq_rank: Equivalent rank; used when ``rank`` is 0 and as the
                upper bound for ``rank``.
            rank: Number of components to keep; 0 selects ``eq_rank``.
            name: Layer name used to look up the statistics.
            dev: Device on which the computation runs.
            verbose: Unused in this implementation.

        Returns:
            A ``FactorizedMatrix`` whose ``mat_l`` / ``mat_r`` are on the CPU
            in the dtype of ``matrix``, or ``None`` (after printing a
            warning) when ``rank`` exceeds ``eq_rank``.
        """
        # Validate the rank first so a rejected request does not pay for the
        # whitening decomposition.
        if rank == 0:
            rank = eq_rank
        elif rank > eq_rank:
            print(f"Warning: {name} rank is larger than equivalent rank!")
            return None

        whiten, dewhiten = whitening(dev, self.scaling_dict, name)

        dtype = matrix.dtype
        # Convert to float32 to avoid "svd_cuda_gesvdj" error for attempting svd on float16
        mat_scaled = torch.matmul(matrix.float().to(dev), whiten)

        u, s, vh = torch.linalg.svd(mat_scaled, full_matrices=False)

        # Split each retained singular value evenly between both factors.
        # Truncating before scaling avoids a dense diag(s) and full-size
        # products that would be sliced away afterwards.
        sqrt_s = torch.sqrt(s[:rank])  # half singular value
        mat_l = (u[:, :rank] * sqrt_s).cpu().to(dtype)
        mat_r = sqrt_s.unsqueeze(1) * torch.matmul(vh[:rank, :], dewhiten)
        mat_r = mat_r.cpu().to(dtype)

        del mat_scaled, u, vh, whiten, dewhiten, matrix
        del sqrt_s
        gc.collect()
        torch.cuda.empty_cache()

        return FactorizedMatrix(
            mat_l=mat_l,  # Left factor: U_r * sqrt(S_r)
            mat_r=mat_r,  # Right factor: sqrt(S_r) * Vh_r @ de-whitening
            eq_rank=eq_rank,  # Equivalent rank
            active_rank=rank,  # Active rank
            singular_values=s  # Full (untruncated) singular values
        )
