import gc
import sys
from tqdm import tqdm
import os, pickle
from all_utils.other import get_model_name

import torch
import torch.nn as nn

from ._interface import BaseFactorization
from ._interface import FactorizedMatrix
from ._interface import Hookstuff

@torch.no_grad()
def whitening(dev, raw_scale, name):
    """Return the Cholesky whitening factor of ``raw_scale[name]`` and its inverse.

    The matrix ``S = raw_scale[name]`` is factorized as ``S = L @ L.T`` with
    ``L`` lower triangular. Multiplying a weight matrix by ``L`` and later by
    ``L^-1`` is how the caller applies and undoes the whitening transform.

    If ``S`` is not positive definite, its spectrum is shifted once by
    ``1e-6 - lambda_min`` (added to the diagonal) and the factorization is
    retried. The shift is applied in float64 so the ``1e-6`` margin is not
    lost to rounding, and it is applied to a private copy so the tensor
    stored in ``raw_scale`` is never modified.

    Args:
        dev: Device on which the factorization is computed.
        raw_scale: Mapping from layer name to a square matrix; it is expected
            to be symmetric, as required by ``torch.linalg.cholesky``.
        name: Key of the matrix to whiten.

    Returns:
        ``(L, L_inv)`` as float32 tensors on ``dev``.

    Raises:
        SystemExit: With exit status 1 if the matrix cannot be made positive
            definite, or if the factor cannot be inverted even with ``pinv``.
    """
    source = raw_scale[name]
    # float64 keeps the factorization numerically stable.
    gram = source.double().to(dev)

    try:
        scale_diag = torch.linalg.cholesky(gram)
    except Exception:
        # eigvalsh returns eigenvalues in ascending order, so [0] is the minimum.
        min_eig = torch.linalg.eigvalsh(gram)[0]
        if gram is source:
            # .double()/.to() returned the caller's own tensor; copy before
            # the in-place diagonal update below.
            gram = gram.clone()
        # Shift only the diagonal, in float64, so the smallest eigenvalue
        # becomes ~1e-6 without building a dense identity matrix.
        gram.diagonal().add_(1e-6 - min_eig)
        try:
            scale_diag = torch.linalg.cholesky(gram)
        except Exception:
            print(f"Warning: {name} is not positive!")
            sys.exit(1)
    # Release the float64 working copy before the inverse is allocated.
    del gram

    try:
        scale_diag_inv = torch.linalg.inv(scale_diag)
    except Exception:
        # Replace NaN entries with a tiny constant and try again.
        nan_fill = torch.tensor(
            1e-10, dtype=scale_diag.dtype, device=scale_diag.device
        )
        scale_diag = torch.where(torch.isnan(scale_diag), nan_fill, scale_diag)
        try:
            scale_diag_inv = torch.linalg.inv(scale_diag)
        except Exception:
            # Only fall back to the pseudo-inverse as a last resort.
            try:
                scale_diag_inv = torch.linalg.pinv(scale_diag)
            except Exception:
                print(f"Warning: {name} is not full rank!")
                sys.exit(1)

    return scale_diag.float(), scale_diag_inv.float()

@torch.no_grad()
def whitening_fast(dev, raw_scale, name):
    """Faster variant of the Cholesky whitening of ``raw_scale[name]``.

    The matrix ``S = raw_scale[name]`` is factorized as ``S = L @ L.T`` with
    ``L`` lower triangular, and ``L^-1`` is obtained by solving the
    triangular system ``L @ X = I`` rather than calling a general dense
    inverse.

    If ``S`` is not positive definite, ``1e-6 - lambda_min`` is added to its
    diagonal and the factorization is retried, at most twice. The shift is
    computed and applied in float64 on a private copy, so the tensor stored
    in ``raw_scale`` is never modified.

    Args:
        dev: Device on which the factorization is computed.
        raw_scale: Mapping from layer name to a square matrix; it is expected
            to be symmetric, as required by ``torch.linalg.cholesky``.
        name: Key of the matrix to whiten.

    Returns:
        ``(L, L_inv)`` as float32 tensors.

    Raises:
        SystemExit: With exit status 1 if the matrix is still not positive
            definite after the allowed number of diagonal shifts.
    """
    max_shifts = 2

    source = raw_scale[name]
    # float64 improves Cholesky stability at the cost of memory and speed.
    gram = source.double().to(dev)

    for attempt in range(max_shifts + 1):
        try:
            scale_diag = torch.linalg.cholesky(gram)
            break
        except torch.linalg.LinAlgError:
            if attempt == max_shifts:
                print(f"FATAL: Matrix '{name}' could not be made positive definite.")
                sys.exit(1)
            if gram is source:
                # .double()/.to() returned the caller's own tensor; copy
                # before the in-place diagonal update below.
                gram = gram.clone()
            # eigvalsh returns ascending eigenvalues, so [0] is the minimum.
            min_eig = torch.linalg.eigvalsh(gram)[0]
            # Shift only the diagonal, in float64, so the smallest eigenvalue
            # becomes ~1e-6 without building a dense identity matrix.
            gram.diagonal().add_(1e-6 - min_eig)
    # Release the float64 working copy before the inverse is allocated.
    del gram

    # Invert L by solving L @ X = I with a triangular solver.
    identity = torch.eye(
        scale_diag.shape[0], device=scale_diag.device, dtype=scale_diag.dtype
    )
    scale_diag_inv = torch.linalg.solve_triangular(scale_diag, identity, upper=False)

    return scale_diag.float(), scale_diag_inv.float()


class SVD_LLM_Hook(Hookstuff):
    """Hook provider that accumulates the Gram matrix of each layer's input.

    NOTE(review): ``self.profile``, ``self.input_shape``, ``self.dump_shape``
    and ``self.model`` are not defined here; presumably they are set up by
    ``Hookstuff`` -- confirm against the base class.
    """

    def _hook_fn(self, layer_name, last_feat=False):
        """Build the forward hook for ``layer_name``.

        The returned hook operates in one of three modes:

        * ``self.dump_shape`` is truthy: record the (normalized) input shape
          followed by ``module.out_features`` and a trailing ``0`` in
          ``self.input_shape[layer_name]``, and do nothing else.
        * ``last_feat`` is truthy: store a copy of the input on
          ``self.model.last_feat`` when ``"head"`` occurs in ``layer_name``,
          and do nothing else.
        * otherwise: add ``X.T @ X`` to ``self.profile[layer_name]``, where
          ``X`` is the input flattened to ``(rows, features)``.
        """

        def get_scaling_mat(module, input, output):
            x = input[0].detach().float()
            # Normalize to 3-D (batch, tokens, features).
            if x.dim() > 3:     # e.g. for convnext
                x = x.reshape(x.shape[0], -1, x.shape[-1])
            elif x.dim() == 2:  # e.g. for mamba/opt
                x = x.unsqueeze(0)

            if self.dump_shape:
                self.input_shape[layer_name] = [*x.shape, module.out_features, 0]
                return
            if last_feat:
                if "head" in layer_name:
                    self.model.last_feat = x.clone()
                return

            # Summing per-sample x_b.T @ x_b over the batch equals X.T @ X for
            # the row-flattened input; computing it directly avoids a
            # (batch, features, features) intermediate.
            rows = x.reshape(-1, x.shape[-1])
            gram = torch.matmul(rows.transpose(0, 1), rows)

            if layer_name not in self.profile:  # First run through each layer
                self.profile[layer_name] = gram
            else:
                self.profile[layer_name] += gram

            # Drop local references so empty_cache() can release their memory.
            del x, rows, gram
            torch.cuda.empty_cache()

        return get_scaling_mat


class SVD_LLMFactorization(BaseFactorization):
    """Low-rank factorization that whitens weights with profiled activations.

    NOTE(review): ``self.dev``, ``self.vision``, ``self.scaling_dict``,
    ``self.input_shapes`` and ``self._do_post_calibration`` are not defined
    here; presumably they are provided by ``BaseFactorization`` -- confirm.
    """

    @property
    def post_search_calibration(self):
        """Return ``False`` for the ``"default"`` setting, else the configured value."""
        return False if self._do_post_calibration == "default" else self._do_post_calibration

    def _compute_scaling(self, model, hook_module, name_prefix, calib_data, name_omit, mixup_fn=None, white_list=None, tqdm_message="Gathering "):
        """Run calibration data through ``model`` and collect per-layer profiles.

        Hooks from ``SVD_LLM_Hook`` are attached to ``hook_module``, every
        batch of ``calib_data`` is forwarded under ``torch.no_grad()``, and
        the accumulated profiles are copied into ``self.scaling_dict``. In
        vision mode an extra dummy forward pass records input shapes into
        ``self.input_shapes``.

        Args:
            model: Model to run; it is put in eval mode and moved to ``self.dev``.
            hook_module: Module handed to the hook extractor.
            name_prefix: Forwarded to ``SVD_LLM_Hook``.
            calib_data: Iterable of ``(data, target)`` pairs in vision mode,
                otherwise of dicts of tensors passed as keyword arguments.
            name_omit: Forwarded to ``SVD_LLM_Hook``.
            mixup_fn: Optional callable applied to ``(data, target)`` in
                vision mode.
            white_list: Forwarded to ``SVD_LLM_Hook``; ``None`` means an
                empty list.
            tqdm_message: Prefix of the progress-bar description.
        """
        torch.cuda.empty_cache()
        model = model.eval().to(self.dev)
        extractor = SVD_LLM_Hook(
            model=hook_module,
            name_omit=name_omit, dump_shape=False,
            name_prefix=name_prefix,
            # A fresh list per call avoids sharing one mutable default.
            white_list=[] if white_list is None else white_list)
        extractor.attach_hooks()
        progress_label = tqdm_message + "(Activations for SVD-LLM)"
        try:
            with torch.no_grad():
                if self.vision:
                    for data, target in tqdm(calib_data, desc=progress_label):
                        model_inps, targets = mixup_fn(data, target) if mixup_fn is not None else (data, target)
                        model_inps = model_inps.to(self.dev)
                        model(model_inps)
                        del model_inps, targets
                    # Switch the hooks to shape-recording mode and run one
                    # dummy batch to capture each layer's input shape.
                    extractor.dump_shape = True
                    dummy_input = torch.randn(20, 3, 224, 224).to(self.dev)
                    model(dummy_input)
                    for key, value in extractor.input_shape.items():
                        self.input_shapes[key] = value
                    del dummy_input
                else:
                    for batch in tqdm(calib_data, desc=progress_label):
                        batch = {k: v.to(self.dev)
                                 for k, v in batch.items()}
                        model(**batch)
                        del batch
        finally:
            # Detach the hooks even if a forward pass raised, so they do not
            # stay attached to the model.
            extractor.clear_hooks()

        for key, value in extractor.profile.items():
            self.scaling_dict[key] = value
        del extractor

        torch.cuda.empty_cache()
        gc.collect()

    def _factorize_cleanup(self, name):
        """Drop the stored profile for ``name`` (if any) and free cached memory."""
        if name in self.scaling_dict:
            del self.scaling_dict[name]
        torch.cuda.empty_cache()
        gc.collect()

    def _factorize_matrix(self, matrix, eq_rank, rank, name, dev, verbose=False):
        """Factorize ``matrix`` into two rank-``rank`` factors via whitened SVD.

        With ``L`` the Cholesky factor of the profile stored for ``name``,
        the SVD ``matrix @ L = U S Vh`` is computed and split as
        ``mat_l = U sqrt(S)`` and ``mat_r = sqrt(S) Vh L^-1``, each truncated
        to ``rank``.

        Args:
            matrix: Weight matrix to factorize.
            eq_rank: Upper bound for ``rank``; also used when ``rank == 0``.
            rank: Number of components to keep; ``0`` selects ``eq_rank``.
            name: Key of the profile in ``self.scaling_dict``.
            dev: Device on which the computation runs.
            verbose: Unused.

        Returns:
            A ``FactorizedMatrix`` whose factors are on the CPU in
            ``matrix.dtype``, or ``None`` (after printing a warning) when
            ``rank > eq_rank``.
        """
        # Validate the rank first so an invalid request does not pay for the
        # whitening below.
        if rank == 0:
            rank = eq_rank
        elif rank > eq_rank:
            print(f"Warning: {name} rank is larger than equivalent rank!")
            return None

        whiten = whitening if self.vision else whitening_fast
        chol, chol_inv = whiten(dev, self.scaling_dict, name)

        dtype = matrix.dtype
        # Convert to float32 to avoid "svd_cuda_gesvdj" error for attempting svd on float16
        mat_scaled = torch.matmul(matrix.float().to(dev), chol)

        u, s, vh = torch.linalg.svd(mat_scaled, full_matrices=False)
        # Each factor carries half (the square root) of the singular values.
        # Truncate before the matmul so only the kept rows are multiplied.
        s_val = torch.sqrt(s[:rank])
        mat_l = (u[:, :rank] * s_val.unsqueeze(0)).cpu().to(dtype)
        mat_r = s_val.unsqueeze(1) * torch.matmul(vh[:rank, :], chol_inv)
        mat_r = mat_r.cpu().to(dtype)

        torch.cuda.empty_cache()
        gc.collect()

        return FactorizedMatrix(
            mat_l=mat_l,  # U * sqrt(S), truncated to `rank` columns
            mat_r=mat_r,  # sqrt(S) * Vh @ L^-1, truncated to `rank` rows
            eq_rank=eq_rank,  # Equivalent rank
            active_rank=rank,  # Active rank
            singular_values=s  # Full (untruncated) singular values
        )
