import torch
import torch.nn as nn
import time
import gc
import math

"""Hook installation"""

# Module-level state shared by the hook callbacks and helpers in this file.
# Initialize dictionaries to store input_sum and gradients for each layer name

# Running sum of the batch-averaged input Gram matrices, keyed by layer name.
# Written by save_input_hook and emptied by unregister_hooks.
layer_inputs = {} 
# Running sum of the batch-averaged output-gradient Gram matrices, keyed by
# layer name. Written by save_grad_hook and emptied by unregister_hooks.
layer_gradients = {}
# NOTE(review): not read or written by any function visible in this part of
# the file -- presumably used elsewhere; confirm before removing.
layer_weights = {} 
# Handles returned by register_forward_hook; appended in register_hooks and
# removed in unregister_hooks.
hook_handles_input = []
# Handles returned by register_backward_hook; appended in register_hooks and
# removed in unregister_hooks.
hook_handles_grad = []
# NOTE(review): not referenced by any function visible in this part of the file.
layer_modules = {}
# Only cleared (never filled) by the code visible here; see unregister_hooks.
U_dict = {}
U_grad_dict = {}

def get_svd_layers(model):
    """Collect the names of the sub-modules selected for SVD compression.

    A sub-module is selected when all of the following hold:
      * at least one of its parameters has ``requires_grad`` set,
      * it is an ``nn.Linear`` or an ``nn.Conv2d`` with a 1x1 kernel,
      * its qualified name contains none of the substrings ``"classifier"``,
        ``"head.head"``, ``"bert.pooler"`` or ``"heads"``.

    Args:
        model: module whose ``named_modules()`` are inspected.

    Returns:
        List of qualified module names, in ``named_modules()`` order.
    """
    excluded_markers = ("classifier", "head.head", "bert.pooler", "heads")
    selected = []
    for name, module in model.named_modules():
        # Skip modules without any trainable parameter.
        if not any(p.requires_grad for p in module.parameters()):
            continue

        is_linear = isinstance(module, nn.Linear)
        is_pointwise_conv = isinstance(module, nn.Conv2d) and module.kernel_size == (1, 1)
        if not (is_linear or is_pointwise_conv):
            continue

        # Skip layers whose name matches one of the excluded markers.
        if any(marker in name for marker in excluded_markers):
            continue

        selected.append(name)
    return selected

# Forward hook to store covariance of X for each layer
def save_input_hook(name, module, input, output):
    """Accumulate the batch-averaged Gram matrix of a layer's input.

    The first positional input of the layer is detached and, when it is 4-D,
    flattened to 3-D by merging its last two dimensions. Writing ``X`` for the
    resulting ``(d0, d1, d2)`` tensor, the hook computes ``X^T X`` (shape
    ``d2 x d2``) when ``d2 < d1`` and ``X X^T`` (shape ``d1 x d1``) otherwise,
    averages over dimension 0, and adds the result to ``layer_inputs[name]``.

    Args:
        name: key under which the statistic is accumulated.
        module: the hooked module (unused).
        input: tuple of positional inputs to the module; only ``input[0]`` is read.
        output: the module output (unused).

    Raises:
        ValueError: if ``input[0]`` is neither 3-D nor 4-D.
    """
    activations = input[0].detach()
    rank = activations.dim()
    if rank == 4:
        # reshape instead of view: view raises on inputs whose last two
        # dimensions cannot be merged without a copy (e.g. transposed tensors).
        activations = activations.reshape(activations.shape[0], activations.shape[1], -1)
    elif rank != 3:
        raise ValueError(
            f"Unexpected tensor dimensions for layer {name}: "
            f"got {rank}, expected 3 or 4"
        )

    # Build the Gram matrix over whichever of the two trailing dimensions the
    # original rule selects (the smaller one, dimension 1 on ties).
    if activations.shape[2] < activations.shape[1]:
        gram = torch.matmul(activations.transpose(1, 2), activations)
    else:
        gram = torch.matmul(activations, activations.transpose(1, 2))
    batch_mean_gram = torch.mean(gram, dim=0)

    if name not in layer_inputs:
        layer_inputs[name] = batch_mean_gram
    else:
        layer_inputs[name] += batch_mean_gram

# Backward hook to store gradient of Y for each layer
def save_grad_hook(name, module, grad_input, grad_output):
    """Accumulate the batch-averaged Gram matrix of a layer's output gradient.

    ``grad_output[0]`` is detached and, when it is 4-D, flattened to 3-D by
    merging its last two dimensions. Writing ``G`` for the resulting
    ``(d0, d1, d2)`` tensor, the hook computes ``G^T G`` (shape ``d2 x d2``)
    when ``d2 < d1`` and ``G G^T`` (shape ``d1 x d1``) otherwise, averages over
    dimension 0, and adds the result to ``layer_gradients[name]``.

    Args:
        name: key under which the statistic is accumulated.
        module: the hooked module (unused).
        grad_input: gradients w.r.t. the module inputs (unused).
        grad_output: gradients w.r.t. the module outputs; only
            ``grad_output[0]`` is read.

    Raises:
        ValueError: if ``grad_output[0]`` is neither 3-D nor 4-D.
    """
    # Detach so the accumulated statistic never keeps an autograd graph alive,
    # mirroring what the forward hook does with the layer input.
    grad = grad_output[0].detach()
    rank = grad.dim()
    if rank == 4:
        # reshape instead of view: view raises on gradients whose last two
        # dimensions cannot be merged without a copy.
        grad = grad.reshape(grad.shape[0], grad.shape[1], -1)
    elif rank != 3:
        raise ValueError(
            f"Unexpected tensor dimensions for layer {name}: "
            f"got {rank}, expected 3 or 4"
        )

    # Build the Gram matrix over whichever of the two trailing dimensions the
    # original rule selects (the smaller one, dimension 1 on ties).
    if grad.shape[2] < grad.shape[1]:
        gram = torch.matmul(grad.transpose(1, 2), grad)
    else:
        gram = torch.matmul(grad, grad.transpose(1, 2))
    batch_mean_gram = torch.mean(gram, dim=0)

    if name not in layer_gradients:
        layer_gradients[name] = batch_mean_gram
    else:
        layer_gradients[name] += batch_mean_gram

def add_activate_to_module():
    """Monkey-patch ``nn.Module`` with an ``activating`` flag and its toggles.

    Effects on the ``nn.Module`` class:
      * ``__init__`` is wrapped (once -- guarded by the presence of
        ``nn.Module._original_init``) so that every newly constructed module
        gets ``activating = True`` stored directly in its instance ``__dict__``.
      * ``enable_compression()`` / ``disable_compression()`` set the flag to
        ``True`` / ``False`` on the module and, through the children's own
        methods, on all descendants; both return ``self``.
      * ``initialize_activate()`` sets the flag to ``True`` recursively and
        returns ``None``.

    The three methods are (re)assigned on every call.
    """
    init_already_wrapped = hasattr(nn.Module, '_original_init')
    if not init_already_wrapped:
        # Keep the pristine constructor reachable as a class attribute.
        nn.Module._original_init = nn.Module.__init__

        def init_with_flag(self, *args, **kwargs):
            self._original_init(*args, **kwargs)
            if hasattr(self, 'activating'):
                return
            # Make sure no parameter/buffer of that name shadows the flag,
            # then write it straight into the instance dictionary.
            self._parameters.pop('activating', None)
            self._buffers.pop('activating', None)
            vars(self)['activating'] = True

        nn.Module.__init__ = init_with_flag

    def enable_compression(self):
        """Set ``activating`` to True on this module and its descendants."""
        vars(self)['activating'] = True
        for child in self.children():
            child.enable_compression()
        return self

    def disable_compression(self):
        """Set ``activating`` to False on this module and its descendants."""
        vars(self)['activating'] = False
        for child in self.children():
            child.disable_compression()
        return self

    def initialize_activate(self):
        """Set ``activating`` to True recursively; returns nothing."""
        vars(self)['activating'] = True
        for child in self.children():
            child.initialize_activate()

    nn.Module.enable_compression = enable_compression
    nn.Module.disable_compression = disable_compression
    nn.Module.initialize_activate = initialize_activate

# Add P, Q to the model
def add_compression_tensor_to_module():
    """Monkey-patch ``nn.Module`` with compression buffers and an updater.

    Effects on the ``nn.Module`` class:
      * ``__init__`` is wrapped so every newly constructed module registers the
        buffers ``CompressionTensor_x`` and ``CompressionTensor_gy`` with a
        ``None`` placeholder (unless an attribute of that name already exists).
        The wrapping happens only once: a class-level marker prevents repeated
        calls from stacking one wrapper on top of another.
      * ``update_compression(P_dict, Q_dict)`` is (re)assigned on every call.
    """
    # Define a method to register buffers for a module instance
    def register_compression_buffers(module):
        # Register buffers with None as placeholder
        for buffer_name in ('CompressionTensor_x', 'CompressionTensor_gy'):
            if not hasattr(module, buffer_name):
                module.register_buffer(buffer_name, None)

    # Override __init__ of nn.Module to register buffers for all new modules.
    # Guarded so that calling this function again does not re-wrap __init__.
    if not getattr(nn.Module, '_compression_init_patched', False):
        original_init = nn.Module.__init__

        def new_init(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            register_compression_buffers(self)

        nn.Module.__init__ = new_init
        nn.Module._compression_init_patched = True

    # Add update_compression method to nn.Module
    def update_compression(self, P_dict, Q_dict):
        """Install per-layer compression tensors as buffers.

        For every sub-module whose qualified name is a key of both
        dictionaries, the detached ``P_dict[name]`` / ``Q_dict[name]`` are
        registered as ``CompressionTensor_x`` / ``CompressionTensor_gy``.

        Returns:
            ``self``, to allow chaining.

        Raises:
            AssertionError: if a selected tensor is not 2-D.
        """
        for name, module in self.named_modules():
            if name not in P_dict or name not in Q_dict:
                continue
            compression_x = P_dict[name].detach().requires_grad_(False)
            compression_gy = Q_dict[name].detach().requires_grad_(False)

            assert compression_x.dim() == 2, f"CompressionTensor_x for {name} is not 2D"
            assert compression_gy.dim() == 2, f"CompressionTensor_gy for {name} is not 2D"

            # Register updated buffers
            module.register_buffer('CompressionTensor_x', compression_x)
            module.register_buffer('CompressionTensor_gy', compression_gy)
        return self
    nn.Module.update_compression = update_compression

# REGISTER HOOK FOR LINEAR LAYER
def register_hooks(model):
    """Attach the statistic-collecting hooks to the eligible layers of ``model``.

    A sub-module is hooked when it has at least one parameter with
    ``requires_grad`` set, is an ``nn.Linear`` or a 1x1 ``nn.Conv2d``, and its
    qualified name contains none of ``"classifier"``, ``"head"``,
    ``"cpb_mlp"`` or ``"downsample"``. Each hooked module receives a forward
    hook calling ``save_input_hook`` and a backward hook calling
    ``save_grad_hook``, both bound to the module's name. The returned handles
    are appended to ``hook_handles_input`` and ``hook_handles_grad``.

    Args:
        model: module whose ``named_modules()`` are inspected.
    """
    global hook_handles_input, hook_handles_grad, hook_handles_weight

    skipped_fragments = ("classifier", "head", "cpb_mlp", "downsample")

    def make_forward_hook(layer_name):
        # Bind the layer name now so each hook reports under its own key.
        def forward_hook(module, input, output):
            return save_input_hook(layer_name, module, input, output)
        return forward_hook

    def make_backward_hook(layer_name):
        def backward_hook(module, grad_input, grad_output):
            return save_grad_hook(layer_name, module, grad_input, grad_output)
        return backward_hook

    for name, module in model.named_modules():
        if not any(p.requires_grad for p in module.parameters()):
            continue

        is_linear = isinstance(module, nn.Linear)
        is_pointwise_conv = isinstance(module, nn.Conv2d) and module.kernel_size == (1, 1)
        if not (is_linear or is_pointwise_conv):
            continue

        if any(fragment in name for fragment in skipped_fragments):
            continue

        forward_handle = module.register_forward_hook(make_forward_hook(name))
        backward_handle = module.register_backward_hook(make_backward_hook(name))

        hook_handles_input.append(forward_handle)
        hook_handles_grad.append(backward_handle)

# Unregister only the specific hooks
def unregister_hooks():
    """Remove every recorded hook and drop the collected statistics.

    All handles stored in ``hook_handles_input`` and ``hook_handles_grad`` are
    removed, both names are rebound to fresh empty lists, and the dictionaries
    ``layer_inputs``, ``layer_gradients``, ``U_dict`` and ``U_grad_dict`` are
    cleared in place. Finally a garbage collection pass is run and the CUDA
    caching allocator is asked to release its cached memory.
    """
    global hook_handles_input, hook_handles_grad

    # Forward handles first, then backward handles.
    for handle_group in (hook_handles_input, hook_handles_grad):
        for handle in handle_group:
            handle.remove()

    # Rebind (rather than clear) the handle lists.
    hook_handles_input, hook_handles_grad = [], []

    # Clear the accumulated statistics and cached bases in place.
    for store in (layer_inputs, layer_gradients, U_dict, U_grad_dict):
        store.clear()

    gc.collect()
    torch.cuda.empty_cache()

# SVD calculation
def SVD_expected_value(input, var = 0.95, p = 0):
    """Return an energy-truncated, energy-rescaled left singular basis.

    The matrix is decomposed with ``torch.svd`` and "energy" is defined as the
    squared singular values. The function keeps the smallest number of leading
    left singular vectors whose cumulative energy exceeds ``var`` times the
    total energy, plus ``p`` extra vectors, clamped to between one and all
    available vectors. The kept basis is multiplied by
    ``sqrt(total_energy / kept_energy)``, where ``kept_energy`` is the energy
    of exactly the columns that are returned.

    Args:
        input: matrix to decompose. The indexing below treats the singular
            values as a 1-D vector, i.e. it assumes a single 2-D matrix.
        var: fraction of the total energy used as the truncation threshold.
        p: number of additional components kept beyond the threshold
            (may be negative).

    Returns:
        Tensor holding the first ``k`` columns of ``U``, rescaled as above.
        When the kept energy is zero (e.g. an all-zero matrix) the columns are
        returned unscaled instead of being multiplied by NaN/inf.
    """
    # NOTE: torch.svd is deprecated upstream in favour of torch.linalg.svd;
    # it is kept here so the block runs on the torch versions it was written for.
    U, S, _ = torch.svd(input)

    num_components = U.shape[1]
    if num_components == 0:
        return U

    # Calculate the total energy (sum of squared singular values)
    energy = S ** 2
    total_energy = torch.sum(energy)

    # Compute the cumulative energy
    cumulative_energy = torch.cumsum(energy, dim=0)

    # Count the leading components whose cumulative energy is still within the
    # threshold; one more component is needed to exceed it.
    energy_threshold = var * total_energy
    within_threshold = int(torch.sum(cumulative_energy <= energy_threshold).item())
    keep = within_threshold + 1 + p
    # Clamp so at least one and at most all components are kept.
    keep = max(1, min(keep, num_components))

    U_truncated = U[:, :keep]
    # Energy of exactly the columns being returned (index keep - 1).
    retained_energy = cumulative_energy[keep - 1]
    if retained_energy <= 0:
        return U_truncated

    # Energy offset
    retained_energy_per = float(retained_energy / total_energy)
    return U_truncated * math.sqrt(1 / retained_energy_per)