import gc

import torch
from torch import nn
import sys

from ._interface import BaseFactorization
from ._interface import FactorizedMatrix
from ._interface import Hookstuff
from ._interface import get_valid_layers
from tqdm import tqdm


@torch.no_grad()
def whitening(dev, raw_scale, name, alpha=0.0):
    """
    Compute a Cholesky whitening factor and its inverse for one statistics matrix.

    ``raw_scale[name]`` is copied to ``dev`` as float32, optionally shrunk
    towards a scaled identity, and factorized as ``S = L @ L.T``.

    Args:
        dev: Device on which the factorization is carried out.
        raw_scale: Mapping from layer name to a square symmetric matrix.
            The stored tensor is never modified.
        name: Key of the matrix to factorize.
        alpha: Shrinkage strength. When positive, ``S`` is replaced by
            ``(1 - alpha) * S + alpha * mean(diag(S)) * I`` before factorizing.

    Returns:
        Tuple ``(L, L_inv)``: the lower-triangular Cholesky factor and its
        inverse, both float32 tensors on ``dev``.

    Raises:
        SystemExit: With a non-zero status when the matrix is still not
            positive definite after the diagonal shift.
    """
    # One forced copy performs the device move and the float32 cast together,
    # and guarantees the in-place updates below never touch `raw_scale[name]`,
    # even when it already lives on `dev` as float32.
    raw_scale_mat = raw_scale[name].to(device=dev, dtype=torch.float32, copy=True)

    if alpha > 0.0:
        # Shrink in place so no temporary identity matrix is materialised.
        diagonal = raw_scale_mat.diagonal()  # view into raw_scale_mat
        # Computed before the rescale so it is based on the original diagonal.
        reg_term = diagonal.mean() * alpha
        raw_scale_mat.mul_(1.0 - alpha)
        diagonal.add_(reg_term)

    try:
        scale_diag = torch.linalg.cholesky(raw_scale_mat)
    except torch.linalg.LinAlgError:
        # Not numerically positive definite: lift the spectrum by shifting the
        # diagonal in place.
        eigenvalues = torch.linalg.eigvalsh(raw_scale_mat)
        # eigvalsh returns ascending eigenvalues, so [0] is the smallest. The
        # clamp keeps the shift at >= 1e-6 even when that eigenvalue is
        # reported as slightly positive; otherwise the shift could shrink to
        # nothing or turn negative and make the matrix worse.
        shift = torch.clamp(-eigenvalues[0], min=0.0) + 1e-6
        raw_scale_mat.diagonal().add_(shift)
        del eigenvalues  # Free memory

        try:
            scale_diag = torch.linalg.cholesky(raw_scale_mat)
        except torch.linalg.LinAlgError:
            print(f"FATAL: Matrix '{name}' could not be made positive definite.")
            # Non-zero status so calling scripts can tell this was a failure.
            sys.exit(1)

    # The regularized matrix is no longer needed once it has been factorized.
    del raw_scale_mat

    # Invert the triangular factor by solving L @ X = I.
    identity = torch.eye(scale_diag.shape[0], device=dev, dtype=scale_diag.dtype)
    scale_diag_inv = torch.linalg.solve_triangular(scale_diag, identity, upper=False)
    del identity  # Free memory

    return scale_diag, scale_diag_inv


class KFAC_SVD_Hook(Hookstuff):
    """
    Hooks that collect per-layer input and output-gradient Gram matrices.

    The forward hook caches each hooked layer's input. The backward hook
    combines that cached input with the gradient w.r.t. the layer output into
    two batch-averaged matrices, ``x^T x`` ("column" scaling) and ``g^T g``
    ("row" scaling), and accumulates them in CPU buffers.

    The containers used here (``activation_cache``, ``input_shape``,
    ``row_scale``, ``column_scale``, ``buf_1``, ``buf_2``, ``layer_trigger``)
    and ``model`` / ``dump_shape`` are not defined in this class; they are
    presumably initialised by ``Hookstuff`` — verify against the base class.
    """

    def __init__(self, model, name_omit, dump_shape=False, white_list: "list | None" = None, name_prefix="", vision=False):
        """
        Args:
            model: Module whose layers get hooked (forwarded to ``Hookstuff``).
            name_omit: Forwarded to ``Hookstuff``.
            dump_shape: When true, the forward hook only records shapes.
            white_list: Forwarded to ``Hookstuff``; ``None`` means an empty list.
            name_prefix: Forwarded to ``Hookstuff``.
            vision: When true, no tokens are trimmed from the sequence ends in
                the backward hook.
        """
        super().__init__(
            model=model,
            name_omit=name_omit,
            dump_shape=dump_shape,
            # Fresh list per instance instead of a shared mutable default.
            white_list=[] if white_list is None else white_list,
            name_prefix=name_prefix,
        )
        self.vision = vision

    def _hook_fn(self, layer_name, last_feat=False):
        """Build the forward hook that caches the input of ``layer_name``."""

        def get_scaling_mat(module, inputs, output):
            x = inputs[0].detach().clone()
            # Bring the input to three dimensions: >3-D inputs are collapsed
            # between the first and last axis, 2-D inputs get a leading axis.
            if x.dim() > 3:
                x = x.reshape(x.shape[0], -1, x.shape[-1])
            elif x.dim() == 2:
                x = x.unsqueeze(0)
            if self.dump_shape:
                # Shape-only mode: record [*x.shape, out_features, 0].
                self.input_shape[layer_name] = list(x.shape)
                self.input_shape[layer_name].extend([module.out_features, 0])
                return
            if last_feat:
                # Only expose the input of layers whose name contains "head".
                if "head" in layer_name:
                    self.model.last_feat = x
                return

            self.activation_cache[layer_name] = x

        return get_scaling_mat

    @staticmethod
    def _rescale_token_grads(grads, low=0.1, high=10.0, eps=1e-10):
        """
        Pull per-token gradient norms into ``[low, high]`` times their mean.

        Args:
            grads: Tensor indexed as ``[batch, token, feature]``.
            low: Tokens with a non-zero norm below ``low * mean`` are scaled up
                to that norm.
            high: Tokens with a norm above ``high * mean`` are scaled down to
                that norm.
            eps: Norms below this are treated as zero: they are excluded from
                the mean and are never boosted.

        Returns:
            Tensor of the same shape as ``grads`` with rescaled token vectors.
        """
        norms = torch.norm(grads, dim=2, keepdim=True)  # [batch, token, 1]

        # Per-sample mean norm over the tokens that are not (numerically) zero.
        masked = torch.where(norms < eps, torch.zeros_like(norms), norms)
        counts = (masked > 0).float().sum(dim=1).clamp(min=1.0)  # [batch, 1]
        mean_norm = (masked.sum(dim=1) / counts).clamp(min=eps)  # [batch, 1]

        lower = (mean_norm * low).unsqueeze(1)  # [batch, 1, 1]
        upper = (mean_norm * high).unsqueeze(1)  # [batch, 1, 1]

        safe_norms = norms.clamp(min=eps)  # guards the divisions below
        ones = torch.ones_like(norms)
        boost = torch.where((norms < lower) & (norms > eps), lower / safe_norms, ones)
        clip = torch.where(norms > upper, upper / safe_norms, ones)

        # boost >= 1 and clip <= 1, and a token can never be both below `lower`
        # and above `upper`. The product applies whichever factor is active;
        # taking the element-wise minimum would silently drop every boost.
        return grads * (boost * clip)

    def _bw_hook_fn(self, layer_name):
        """Build the backward hook that accumulates statistics for ``layer_name``."""

        def get_scaling_mat_grad(module, ginput, goutput):
            # Remember the first layer whose backward hook fires; it is the
            # one that triggers the synchronisation further down.
            if not self.layer_trigger:
                self.layer_trigger = layer_name

            gout = goutput[0].detach().clone().float()
            if gout.dim() > 3:
                gout = gout.reshape(gout.shape[0], -1, gout.shape[-1])
            elif gout.dim() == 2:
                gout = gout.unsqueeze(0)

            x = self.activation_cache[layer_name].detach().float()
            seq_len = x.shape[1]
            # Outside vision mode, drop 12.5% of the tokens at each end of the
            # gradient sequence. The cached input is used in full.
            cutoff = 0 if self.vision else int(seq_len * 0.125)
            gout_middle = self._rescale_token_grads(gout[:, cutoff:seq_len - cutoff, :])

            # Per-sample Gram matrices, averaged over the batch dimension.
            column_scaling = torch.mean(x.transpose(1, 2) @ x, dim=0)
            row_scaling = torch.mean(gout_middle.transpose(1, 2) @ gout_middle, dim=0)

            if layer_name not in self.row_scale:  # First run through this layer
                # Pinned CPU staging buffers and accumulators.
                self.buf_1[layer_name] = torch.zeros_like(row_scaling, device='cpu', pin_memory=True)
                self.buf_2[layer_name] = torch.zeros_like(column_scaling, device='cpu', pin_memory=True)
                self.row_scale[layer_name] = torch.zeros_like(row_scaling, device='cpu', pin_memory=True)
                self.column_scale[layer_name] = torch.zeros_like(column_scaling, device='cpu', pin_memory=True)

                torch.cuda.synchronize()
            else:
                # The staging buffers still hold the previous call's result.
                # Wait once per backward pass (on the trigger layer) for those
                # asynchronous copies before adding them to the accumulators.
                if self.layer_trigger == layer_name:
                    torch.cuda.synchronize()
                self.row_scale[layer_name] += self.buf_1[layer_name]
                self.column_scale[layer_name] += self.buf_2[layer_name]

            # Stage the new statistics. They are folded into the accumulators
            # on the next call, so the last batch must be flushed by the caller.
            self.buf_1[layer_name].copy_(row_scaling, non_blocking=True)
            self.buf_2[layer_name].copy_(column_scaling, non_blocking=True)
            # No explicit `del` of locals: they are released on return, and
            # deleting a name that was never bound raises UnboundLocalError.
            # gc.collect() is deliberately avoided here: it slows the hook
            # down by roughly 10x on small models.

        return get_scaling_mat_grad


class KFAC_SVDFactorization(BaseFactorization):
    """
    Low-rank factorization by truncated SVD of a two-sided whitened weight.

    ``_compute_scaling`` gathers per-layer "row" (output-gradient) and
    "column" (input) statistics through ``KFAC_SVD_Hook``. ``_factorize_matrix``
    whitens a weight with the Cholesky factors of those statistics, truncates
    its SVD, and undoes the whitening on the two resulting factors.

    ``self.dev``, ``self.vision`` and ``self._do_post_calibration`` are read
    but not assigned in this class; they are presumably set by
    ``BaseFactorization`` — verify against the base class.
    """

    def __init__(self, vision, *args, **kwargs):
        super().__init__(*args, vision=vision, **kwargs)
        self.scaling_dict_gout = {}
        # Layer name -> accumulated input statistics (x^T x).
        self.column_scaling_dict = {}
        # Layer name -> accumulated output-gradient statistics (g^T g).
        self.row_scaling_dict = {}
        self.factorize_cache_dict: dict[str, FactorizedMatrix] = {}
        self.use_debug_cache = False

    @property
    def post_search_calibration(self):
        """Whether recalibration after search is requested; "default" means False."""
        # if the factorization method requires recalibration after search
        # (e.g. because it uses the scaling statistics to determine the rank)
        return False if self._do_post_calibration == "default" else self._do_post_calibration

    def _compute_scaling(self, model, hook_module, name_prefix, calib_data, name_omit, mixup_fn=None, white_list=None, tqdm_message="Gathering "):
        """
        Run the calibration data through ``model`` and collect layer statistics.

        Side effects: ``requires_grad`` is switched off for every parameter of
        ``model`` and back on only for the non-bias parameters of the selected
        ``nn.Linear`` layers. This is not restored afterwards.

        Args:
            model: Model that is run forward and backward.
            hook_module: Module whose layers are hooked.
            name_prefix: Forwarded to ``KFAC_SVD_Hook``.
            calib_data: Iterable of ``(data, target)`` pairs in vision mode,
                otherwise of dicts that are passed to ``model`` as keyword
                arguments and must contain ``"input_ids"``.
            name_omit: Forwarded to the hook and to ``get_valid_layers``.
            mixup_fn: Unused.
            white_list: Forwarded to the hook and to ``get_valid_layers``;
                ``None`` means an empty list.
            tqdm_message: Progress-bar description.
        """
        dev = self.dev
        # Fresh list per call instead of a shared mutable default.
        white_list = [] if white_list is None else white_list

        extractor = KFAC_SVD_Hook(hook_module, name_omit, False, name_prefix=name_prefix, vision=self.vision, white_list=white_list)
        extractor.attach_hooks()
        extractor.attach_bw_hooks()

        # disable gradient for all layers
        for p in model.parameters():
            p.requires_grad = False

        # reenable gradient for relevant layers
        for _, module in get_valid_layers(hook_module, name_omit, white_list=white_list):
            if isinstance(module, nn.Linear):
                for n, p in module.named_parameters():
                    if "bias" not in n:
                        p.requires_grad = True

        if self.vision:
            loss_fn = torch.nn.CrossEntropyLoss()
            # The backward pass fires the hooks that gather the statistics.
            for data, target in tqdm(calib_data, desc=tqdm_message):
                model_inps, targets = data.to(dev), target.to(dev)
                out = model(model_inps)
                loss = loss_fn(out, targets)
                loss.backward()
                # Only the hooks matter; drop parameter gradients so they do
                # not accumulate across batches (matches the branch below).
                model.zero_grad()
                del model_inps, targets, loss, out
        else:
            loss_fct = torch.nn.CrossEntropyLoss(reduction="sum", ignore_index=-100)
            # The backward pass fires the hooks that gather the statistics.
            for batch in tqdm(calib_data, desc=tqdm_message):
                with torch.autocast(device_type="cuda", enabled=False):
                    batch = {k: v.to(dev) for k, v in batch.items()}
                    out = model(**batch)    # use_cache=False
                    lm_logits = out.logits

                    if not torch.isfinite(lm_logits).all():
                        print("Warning: Non-finite logits detected, skipping batch.")
                        continue

                    # Next-token objective: logits at position t are scored
                    # against the input id at position t + 1.
                    shift_logits = lm_logits[:, :-1, :].contiguous()
                    shift_labels = batch["input_ids"][:, 1:].clone().contiguous()
                    loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.view(-1))
                    loss.backward()

                    model.zero_grad()
                    torch.cuda.empty_cache()
                    del batch, out, loss

        # The hooks fold each result into the accumulators one call late, so
        # the last staged result per layer is still pending. One
        # synchronisation covers every outstanding asynchronous copy.
        if extractor.row_scale:
            torch.cuda.synchronize()
        for layer_name in extractor.row_scale:
            extractor.row_scale[layer_name] += extractor.buf_1[layer_name]
            extractor.column_scale[layer_name] += extractor.buf_2[layer_name]

        self.column_scaling_dict.update(extractor.column_scale)
        self.row_scaling_dict.update(extractor.row_scale)
        extractor.clear_hooks()

        del extractor.activation_cache, extractor.buf_1, extractor.buf_2, extractor.row_scale, extractor.column_scale, extractor
        gc.collect()

        return

    def _factorize_cleanup(self, name):
        """Drop the stored statistics of layer ``name``; missing keys are ignored."""
        self.row_scaling_dict.pop(name, None)
        self.column_scaling_dict.pop(name, None)

    def _factorize_matrix(self, matrix, eq_rank, rank, name, dev, verbose=False):
        """
        Factorize ``matrix`` into two rank-``rank`` factors.

        Args:
            matrix: Weight to factorize; statistics for ``name`` must already
                be present in the row/column scaling dicts.
            eq_rank: Equivalent rank; also used as the rank when ``rank == 0``.
            rank: Number of singular directions to keep, or 0 for ``eq_rank``.
            name: Layer name used to look up the statistics.
            dev: Ignored; the current CUDA device is used instead (see below).
            verbose: Print progress information.

        Returns:
            A ``FactorizedMatrix`` whose ``mat_l`` / ``mat_r`` are on the CPU in
            the dtype of ``matrix``, or ``None`` when ``rank > eq_rank``.
        """
        if rank == 0:
            rank = eq_rank
        elif rank > eq_rank:
            print(f"Warning: {name} rank is larger than equivalent rank!")
            return None

        if verbose:
            print("Factorizing matrix:", name)
        # NOTE(review): this overrides the `dev` argument with the current CUDA
        # device — confirm that callers rely on this.
        dev = torch.device(torch.cuda.current_device())

        # Only the shrinkage applied to the row statistics depends on the mode.
        row_alpha = 0.7 if self.vision else 0.3
        column_scale_diag, column_scale_diag_inv = whitening(dev, self.column_scaling_dict, name, alpha=0.1)
        row_scale_diag, row_scale_diag_inv = whitening(dev, self.row_scaling_dict, name, alpha=row_alpha)

        dtype_final = matrix.dtype
        if verbose:
            print(matrix.shape)
        temp_dtype = row_scale_diag.dtype
        # NOTE(review): `whitening` returns the lower Cholesky factor L with
        # S = L @ L.T. On the right, ||dW @ L_c||_F^2 = tr(dW S_c dW^T); the
        # matching identity on the left needs L_r.T rather than L_r. Confirm
        # that multiplying by L_r here is intentional.
        mat_scaled = row_scale_diag @ matrix.to(dev).to(temp_dtype) @ column_scale_diag

        try:
            u, s, vh = torch.linalg.svd(mat_scaled, full_matrices=False)
        except RuntimeError:
            # torch.linalg.LinAlgError is a RuntimeError subclass, so this still
            # covers SVD failures without trapping KeyboardInterrupt or
            # SystemExit. Non-finite whitening factors are replaced by the
            # identity, then the SVD is retried once; a second failure
            # propagates.
            if not torch.all(torch.isfinite(row_scale_diag)) or not torch.all(torch.isfinite(row_scale_diag_inv)):
                print(f"⚠️  Warning: Row scaling for layer {name} is non-finite. Replacing with identity.")
                row_scale_diag = torch.eye(row_scale_diag.shape[0], device=mat_scaled.device, dtype=temp_dtype)
                row_scale_diag_inv = torch.eye(row_scale_diag_inv.shape[0], device=mat_scaled.device, dtype=temp_dtype)
            if not torch.all(torch.isfinite(column_scale_diag)) or not torch.all(torch.isfinite(column_scale_diag_inv)):
                print(f"⚠️  Warning: Column scaling for layer {name} is non-finite. Replacing with identity.")
                column_scale_diag = torch.eye(column_scale_diag.shape[0], device=mat_scaled.device, dtype=temp_dtype)
                column_scale_diag_inv = torch.eye(column_scale_diag_inv.shape[0], device=mat_scaled.device, dtype=temp_dtype)

            mat_scaled = row_scale_diag @ matrix.to(dev).to(temp_dtype) @ column_scale_diag
            u, s, vh = torch.linalg.svd(mat_scaled, full_matrices=False)

        # Split each kept singular value evenly between the two factors, and
        # truncate before undoing the whitening so only `rank` columns/rows are
        # multiplied.
        s_val = torch.sqrt(s[:rank])
        mat_l = row_scale_diag_inv @ (u[:, :rank] * s_val.unsqueeze(0))
        mat_r = s_val.unsqueeze(1) * torch.matmul(vh[:rank, :], column_scale_diag_inv)

        self.row_scaling_dict[name] = self.row_scaling_dict[name].to('cpu', non_blocking=True)
        self.column_scaling_dict[name] = self.column_scaling_dict[name].to('cpu', non_blocking=True)
        return FactorizedMatrix(
            mat_l=mat_l.cpu().to(dtype_final),  # Left factor
            mat_r=mat_r.cpu().to(dtype_final),  # Right factor
            eq_rank=eq_rank,  # Equivalent rank
            active_rank=rank,  # Active rank
            singular_values=s  # Full, untruncated spectrum of the whitened matrix
        )
