import gc

import torch
from torch import nn

from ._interface import BaseFactorization
from ._interface import FactorizedMatrix
from tqdm import tqdm

import cupy as cp
from cupyx.scipy.sparse.linalg import svds, LinearOperator
#from scipy.sparse.linalg import LinearOperator, svds
import numpy as np

from ._interface import get_valid_layers

# Number of consecutive calibration batches whose weight gradients are summed
# before one averaged snapshot (sum / GRAD_ACC_STEPS) is appended per layer in
# GFWSVDFactorization._compute_scaling. Batches with more samples than this are
# recorded individually without averaging. Each recorded snapshot is stored on
# the CPU, so smaller values mean more stored snapshots per layer.
GRAD_ACC_STEPS = 32

def is_positive_definite(matrix: np.ndarray) -> bool:
    """Report whether ``matrix`` admits a Cholesky factorization.

    ``np.linalg.cholesky`` does not verify symmetry; only the lower triangle
    and diagonal are used. A ``True`` result therefore means "the lower
    triangle describes a positive-definite matrix".
    """
    try:
        np.linalg.cholesky(matrix)
    except np.linalg.LinAlgError:
        return False
    else:
        return True

def matrix_sqrt_invsqrt(X: torch.Tensor, lmbd: float = 1e-6, alpha_increase_factor=1e-1, max_reg_tries=10, reg_alpha: float = 1e-1):
    """Return the lower Cholesky factor ``L`` of ``X`` (``X = L @ L.T``) and its inverse.

    If ``X`` is not positive definite, it is blended towards a scaled identity,
    ``(1 - alpha) * X + alpha * mean(diag(X)) * I``. The first regularized
    attempt uses ``alpha = reg_alpha``. Each further attempt adds
    ``alpha_increase_factor``, and ``alpha`` is capped at 1. At the cap the
    blend is a positive scaled identity, because ``mean(diag(X))`` is clamped
    to at least 1e-6.

    Args:
        X: square matrix, as a torch tensor or a NumPy array.
        lmbd: unused; kept for backward compatibility.
        alpha_increase_factor: amount added to ``alpha`` after each failed
            regularized attempt.
        max_reg_tries: maximum number of regularized attempts.
        reg_alpha: blend weight used on the first regularized attempt.

    Returns:
        ``(L, L_inv)`` as torch tensors. For tensor input they have ``X``'s
        dtype. For NumPy input they have the torch dtype matching the array's
        dtype.

    Raises:
        RuntimeError: if no attempt yields a positive-definite matrix.
        numpy.linalg.LinAlgError: if the Cholesky factor cannot be inverted.
    """
    if isinstance(X, torch.Tensor):
        # detach() so tensors that require grad can still be converted to NumPy.
        XF = X.detach().cpu().numpy()
        out_dtype = X.dtype
    else:
        XF = np.asarray(X)
        out_dtype = torch.as_tensor(XF).dtype

    _, in_features = XF.shape
    eye = np.eye(in_features, dtype=np.float32)
    # Clamp so the identity target stays positive even for negative or zero diagonals.
    diag_mean = max(np.mean(np.diag(XF)), 1e-6)

    candidate = XF
    alpha = reg_alpha
    attempt = 0
    while True:
        # The Cholesky attempt doubles as the positive-definiteness test,
        # so each candidate is factorized only once.
        try:
            X_chol = np.linalg.cholesky(candidate)
            break
        except np.linalg.LinAlgError as err:
            if attempt >= max_reg_tries:
                raise RuntimeError(f"Failed to regularize factor after {max_reg_tries} attempts.") from err
        attempt += 1
        print(f"  Regularizing factor (try {attempt}, alpha={alpha:.2e})")
        # Always blend from the original matrix (NumPy), never from a previous candidate.
        candidate = (1 - alpha) * XF + alpha * diag_mean * eye
        # Beyond alpha = 1 the blend would start subtracting X; stop at the pure identity.
        alpha = min(alpha + alpha_increase_factor, 1.0)
    print("  Cholesky decomposition successful.")

    try:
        inv_X_chol = np.linalg.inv(X_chol)
    except np.linalg.LinAlgError as e:
        print(f"ERROR: Failed to invert Cholesky factors: {e}")
        raise
    print("  Cholesky factor inverses computed.")
    return torch.tensor(X_chol, dtype=out_dtype), torch.tensor(inv_X_chol, dtype=out_dtype)


def get_kron_factors(list_of_grads, top_k=10, layer_name="linear", device="cuda:0", chunk_size=4):
    """Approximate the empirical Fisher of a layer as a Kronecker product of two factors.

    Given gradient samples G_1..G_k of shape (m, n), finds XF (m x m) and
    YF (n x n) such that (1/k) sum_i vec(G_i) vec(G_i)^T ~= YF (x) XF, with
    column-major ``vec``. The Fisher is never materialized. Instead, its
    rearrangement is exposed to CuPy's ``svds`` as a matrix-free
    ``LinearOperator``, and the leading singular pair yields the factors.

    Args:
        list_of_grads: list or stacked tensor of k gradients of shape (m, n).
            Non-finite entries are replaced in place with 1e-15.
        top_k: number of singular triplets requested from ``svds``; only the
            leading one is used.
        layer_name: used only in log messages.
        device: unused; the computation always runs on CuPy device 0.
        chunk_size: number of gradient samples per batched matmul inside the
            matvecs. Larger values are faster but use more GPU memory.

    Returns:
        ``(XF, YF)`` as float32 CPU tensors of shapes (m, m) and (n, n). Both
        are all-zero when the gradients are numerically zero or when the SVD
        fails.
    """

    def matvec(vec, grad_vectors, chunk_size=4):
        # vec = vec_F(V) for an (n, n) matrix V; returns vec_F(mean_i G_i V G_i^T).
        k, m, n = grad_vectors.shape
        V = vec.reshape(n, n, order="F")
        result = cp.zeros((m, m), dtype=cp.float32)
        for i in range(0, k, chunk_size):
            chunk = grad_vectors[i : i + chunk_size]
            prod = chunk @ V @ chunk.transpose(0, 2, 1)
            result += cp.sum(prod, axis=0)
        # Transpose followed by C-order ravel == column-major ravel of `result`.
        return (result / k).T.ravel()

    def r_matvec(vec, grad_vectors, chunk_size=4):
        # Adjoint of `matvec`: vec = vec_F(W) for (m, m) W; returns vec_F(mean_i G_i^T W G_i).
        k, m, n = grad_vectors.shape
        V = vec.reshape(m, m, order="F")
        result = cp.zeros((n, n), dtype=cp.float32)
        for i in range(0, k, chunk_size):
            chunk = grad_vectors[i : i + chunk_size]
            prod = chunk.transpose(0, 2, 1) @ V @ chunk
            result += cp.sum(prod, axis=0)
        return (result / k).T.ravel()

    m, n = list_of_grads[0].shape

    def zero_factors():
        return torch.zeros((m, m), dtype=torch.float32), torch.zeros((n, n), dtype=torch.float32)

    total_grad_norm = 0.0
    for i, g in enumerate(list_of_grads):
        finite = torch.isfinite(g)
        if not torch.all(finite):
            print(f"⚠️  Warning: Gradients for layer {layer_name} {i}/{len(list_of_grads)} are non-finite. Replacing with small value.")
            g = torch.where(finite, g, torch.tensor(1e-15, device=g.device))
            list_of_grads[i] = g
        # Use the sanitized gradient so NaN/Inf cannot poison the zero check below.
        total_grad_norm += torch.sum(torch.abs(g))

    if total_grad_norm < 1e-9:
        print(f"⚠️  Warning: Gradients for layer {layer_name} are all-zero. Skipping SVD.")
        return zero_factors()

    device_id = 0
    num_devices = cp.cuda.runtime.getDeviceCount()
    device_pool = [cp.cuda.Device(i) for i in range(num_devices)]

    try:
        with device_pool[device_id]:
            grad_vectors = cp.stack([cp.asarray(grad).reshape(m, n, order="F") for grad in list_of_grads], dtype=cp.float32)
            linop = LinearOperator(
                shape=(m * m, n * n),
                # Forward the caller's chunk_size (previously dropped, so the default of 4 was always used).
                matvec=lambda vec: matvec(vec, grad_vectors, chunk_size),
                rmatvec=lambda vec: r_matvec(vec, grad_vectors, chunk_size),
                dtype=cp.float32,
            )

            u, s, vt = svds(linop, k=top_k, return_singular_vectors=True)
            print(f"✔ Layer {layer_name} on device {device_id} done | singular values: {s}")
            lead = int(cp.argmax(s))
            sqrt_s = s[lead] ** 0.5

            XF = (u[:, lead] * sqrt_s).reshape(m, m, order="F")
            YF = (vt[lead, :] * sqrt_s).reshape(n, n, order="F")
            # A singular pair is defined only up to a joint sign flip. Pick the sign
            # that gives XF a positive trace, so both factors are positive
            # semidefinite rather than negative; YF (x) XF is unchanged by the flip.
            if float(cp.trace(XF)) < 0:
                XF, YF = -XF, -YF

            return torch.tensor(XF.get(), dtype=torch.float32), torch.tensor(YF.get(), dtype=torch.float32)
    except Exception as err:  # best-effort: fall back to zero factors, but let KeyboardInterrupt etc. propagate
        print(f"⚠️  Warning: SVD failed for layer {layer_name}. Returning zero factors. ({err!r})")
        return zero_factors()


class GFWSVDFactorization(BaseFactorization):
    """Low-rank factorization weighted by a Kronecker-factored (generalized) Fisher.

    Workflow:
    1. ``_compute_scaling`` runs calibration batches and appends averaged
       weight-gradient snapshots per layer to ``self.grad_dict``.
    2. ``_factorize_matrix`` turns a layer's snapshots into Kronecker factors
       A (rows x rows) and B (cols x cols) via ``get_kron_factors``. It takes
       their Cholesky factors L_A and L_B, and returns the rank-r factorization
       of W that minimizes ``||L_A^T (W - W_r) L_B||_F``.
    """

    def __init__(self, alpha=1, processing_chunk_size=64, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        # Per-layer list of CPU gradient snapshots, filled by _compute_scaling.
        self.grad_dict = {}
        # Per-layer cache: [A_scale_inv, B_scale_inv, A_scale, B_scale].
        self.scaling_matrices = {}
        # Gradient samples per batched matmul inside get_kron_factors' matvecs:
        # larger is faster but uses more GPU memory.
        self.processing_chunk_size = processing_chunk_size

    @property
    def post_search_calibration(self):
        return False if self._do_post_calibration == "default" else self._do_post_calibration

    @staticmethod
    def _reset_weight_grads(modules):
        """Drop the accumulated ``weight.grad`` of every ``(name, module)`` pair."""
        for _, module in modules:
            module.weight.grad = None

    def _record_grads(self, modules, name_prefix, divisor):
        """Append ``weight.grad / divisor`` (on CPU) per module to ``self.grad_dict``, then clear the grads."""
        for name, module in modules:
            grad = module.weight.grad
            if grad is None:
                continue
            grad_cpu = grad.detach().cpu()
            if divisor != 1:
                grad_cpu = grad_cpu / divisor
            self.grad_dict.setdefault(name_prefix + name, []).append(grad_cpu)
        # Clear only after every module has been recorded, so a weight shared by
        # several modules is recorded for each of them. Clearing starts the next
        # accumulation window from zero; otherwise every later snapshot would
        # contain the sum of all previous batches.
        self._reset_weight_grads(modules)

    def _compute_scaling(self, model, hook_module, name_prefix, calib_data, name_omit, mixup_fn=None, white_list=None, tqdm_message="Gathering "):
        """Collect averaged weight-gradient snapshots of the selected layers.

        Only non-bias parameters of ``nn.Linear`` modules returned by
        ``get_valid_layers`` receive gradients. Gradients are summed over
        ``GRAD_ACC_STEPS`` batches and stored divided by ``GRAD_ACC_STEPS``.
        A batch with more than ``GRAD_ACC_STEPS`` samples is stored on its own.
        A trailing partial window is averaged over its actual length.

        Args:
            model: model to run; moved to ``self.dev`` and left in eval mode.
            hook_module: module searched for layers via ``get_valid_layers``.
            name_prefix: prefix prepended to layer names to form grad_dict keys.
            calib_data: iterable of batches. Vision batches are ``(data, target)``
                pairs; language batches are dicts containing ``"input_ids"``.
            name_omit: forwarded to ``get_valid_layers``.
            mixup_fn: optional ``(inputs, targets) -> (inputs, targets)`` applied
                to vision batches after they are moved to ``self.dev``.
            white_list: forwarded to ``get_valid_layers`` (default: empty list).
            tqdm_message: progress-bar prefix.
        """
        if white_list is None:
            white_list = []
        model = model.to(self.dev)

        for p in model.parameters():
            p.requires_grad = False

        # Re-enable gradients only for the weights of the relevant linear layers.
        copied_modules = get_valid_layers(hook_module, name_omit, white_list=white_list)
        for _, module in copied_modules:
            if isinstance(module, nn.Linear):
                for param_name, p in module.named_parameters():
                    if "bias" not in param_name:
                        p.requires_grad = True

        # Start from clean buffers so stale gradients do not leak into the first snapshot.
        self._reset_weight_grads(copied_modules)

        loss_fn = nn.CrossEntropyLoss() if self.vision else None
        idx = 1  # 1 + number of backward passes accumulated in the current window
        for batch in tqdm(calib_data, desc=tqdm_message + " (generalized fisher information)"):
            if self.vision:
                data, target = batch
                model_inputs = data.to(self.dev, non_blocking=True)
                targets = target.to(self.dev, non_blocking=True)
                if mixup_fn is not None:
                    # Previously the mixed batch was computed and then discarded.
                    model_inputs, targets = mixup_fn(model_inputs, targets)
                out = model(model_inputs)
                loss = loss_fn(out, targets)
                batch_dim = data.shape[0]
            else:
                input_ids = batch["input_ids"].to(self.dev, non_blocking=True)
                # NOTE(review): if `model` is a Hugging Face causal LM, it already
                # shifts `labels` internally, so pre-shifting here would target the
                # token two positions ahead. Confirm this is intended.
                out = model(input_ids=input_ids[:, :-1], labels=input_ids[:, 1:])
                loss = out.loss
                batch_dim = input_ids.shape[0]
            loss.backward()

            if batch_dim > GRAD_ACC_STEPS:
                # Large batch: its mean-loss gradient is already an average; store it as is.
                self._record_grads(copied_modules, name_prefix, 1)
                idx = 1
            elif idx % GRAD_ACC_STEPS == 0:
                self._record_grads(copied_modules, name_prefix, GRAD_ACC_STEPS)
                idx = 1
            else:
                idx += 1

        if idx != 1:
            # Flush the trailing partial window, averaged over the idx - 1 batches it holds.
            self._record_grads(copied_modules, name_prefix, idx - 1)

        model.zero_grad(set_to_none=True)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        model = model.eval()
        return

    def _factorize_cleanup(self, name):
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if name in self.grad_dict:
            del self.grad_dict[name]
        if name in self.scaling_matrices:
            del self.scaling_matrices[name]

    def _factorize_matrix(self, matrix, eq_rank, rank, name, dev, verbose=False):
        """Return a Fisher-weighted low-rank factorization of ``matrix``.

        The Fisher for ``name`` is approximated as B (x) A, with A (rows x rows)
        and B (cols x cols) taken from ``get_kron_factors``. Write
        A = L_A L_A^T and B = L_B L_B^T. The weighted error
        vec(D)^T (B (x) A) vec(D) then equals ||L_A^T D L_B||_F^2. The optimal
        rank-r approximation is therefore L_A^{-T} SVD_r(L_A^T W L_B) L_B^{-1}.

        Args:
            matrix: 2-D weight with the same shape as the gradients stored for ``name``.
            eq_rank: rank cap, used when ``rank`` is 0 or exceeds it.
            rank: requested rank (0 means ``eq_rank``).
            name: key into ``self.grad_dict`` / ``self.scaling_matrices``.
            dev: ignored; computation runs on ``matrix.device``.
            verbose: print diagnostics.

        Returns:
            FactorizedMatrix with ``mat_l @ mat_r`` approximating ``matrix``,
            cast to ``matrix.dtype``.
        """
        dev = matrix.device
        dtype = matrix.dtype

        if name in self.scaling_matrices:
            A_scale_inv, B_scale_inv, A_scale, B_scale = self.scaling_matrices[name]
        else:
            grads = self.grad_dict[name]
            if verbose:
                print(len(grads))
            A, B = get_kron_factors(
                torch.stack(grads).float(),
                top_k=1,
                layer_name=name,
                device=dev,
                chunk_size=self.processing_chunk_size,
            )
            A_scale, A_scale_inv = matrix_sqrt_invsqrt(A)
            B_scale, B_scale_inv = matrix_sqrt_invsqrt(B)
            self.scaling_matrices[name] = [A_scale_inv, B_scale_inv, A_scale, B_scale]
            if verbose:
                for label, scale in (("A", A_scale), ("B", B_scale)):
                    diag = torch.diag(scale)
                    print(f"Hessian whitening matrix {label} {name} min: {diag.min()}, max: {diag.max()}, median: {diag.median()}")

        if rank == 0:
            rank = eq_rank
        elif rank > eq_rank:
            print(f"Warning: {name} rank ({rank}) is larger than equivalent rank ({eq_rank})!")
            rank = eq_rank

        A_scale, A_scale_inv = A_scale.to(dev), A_scale_inv.to(dev)
        B_scale, B_scale_inv = B_scale.to(dev), B_scale_inv.to(dev)

        # Left factor must be transposed: ||L_A^T D L_B||_F (not ||L_A D L_B||_F)
        # is the norm induced by B (x) A when A = L_A L_A^T.
        mat_scaled = A_scale.T @ matrix.to(device=dev, dtype=torch.float32) @ B_scale
        u, s, vh = torch.linalg.svd(mat_scaled, full_matrices=False)

        active_rank = min(rank, len(s))
        s_sqrt = torch.sqrt(s[:active_rank])

        # Undo the whitening and split sqrt(S_r) between the two factors:
        # W_r = (L_A^{-T} U_r sqrt(S_r)) @ (sqrt(S_r) V_r^T L_B^{-1}).
        mat_l = (A_scale_inv.T @ u[:, :active_rank]) * s_sqrt
        mat_r = (vh[:active_rank, :] @ B_scale_inv) * s_sqrt.unsqueeze(1)

        return FactorizedMatrix(
            mat_l=mat_l.to(dtype),
            mat_r=mat_r.to(dtype),
            eq_rank=eq_rank,
            active_rank=active_rank,
            singular_values=s,
        )