import torch.nn as nn
import torch
from custom_op.linear.linear_lora import LoRALinear
def get_all_linear_with_name(model):
    """Collect the linear layers of ``model`` whose name contains ``'mlp'``.

    Both ``LoRALinear`` modules and plain ``nn.Linear`` modules are collected.
    Once a ``LoRALinear`` has been collected, every module nested beneath it
    is skipped, so its inner layers are not reported a second time.

    Args:
        model: Module exposing ``named_modules()``.

    Returns:
        dict mapping each matching module name (as yielded by
        ``named_modules()``) to the module, in traversal order.
    """
    found = {}
    # Name prefixes ("<lora name>.") of sub-trees that must be skipped.
    lora_prefixes = ()

    for name, mod in model.named_modules():
        # ``str.startswith(())`` is False, so nothing is skipped until the
        # first LoRALinear has been recorded.
        if name.startswith(lora_prefixes):
            continue
        if 'mlp' not in name:
            continue

        if isinstance(mod, LoRALinear):
            found[name] = mod
            lora_prefixes += (f"{name}.",)
        elif isinstance(mod, nn.Linear):
            found[name] = mod

    return found

def get_active_linear_with_name(model):
    """Return the linear layers selected for fine-tuning.

    The selection is driven by ``model.num_of_finetune``:

    * ``None`` or ``0``: nothing is selected and ``-1`` is returned.
    * ``"all"`` or a number larger than the count of available layers:
      every layer found by ``get_all_linear_with_name`` is returned.
    * any other number ``k``: only the last ``k`` layers (in traversal
      order) are returned.

    Args:
        model: Module exposing ``named_modules()`` and a
            ``num_of_finetune`` attribute.

    Returns:
        dict mapping layer name to module, or ``-1`` when no layer is active.
    """
    num_of_finetune = model.num_of_finetune

    # Must be tested before the numeric comparison below: ``None > int``
    # raises TypeError, which previously made the ``None`` case unreachable.
    if num_of_finetune is None or num_of_finetune == 0:
        return -1

    total_linear_layer = get_all_linear_with_name(model)
    if num_of_finetune == "all" or num_of_finetune > len(total_linear_layer):
        return total_linear_layer

    # Keep only the trailing ``num_of_finetune`` entries, preserving order.
    return dict(list(total_linear_layer.items())[-num_of_finetune:])

def get_all_conv_with_name(model):
    """Map the name of every ``nn.Conv2d`` inside ``model`` to its module.

    Args:
        model: Module exposing ``named_modules()``.

    Returns:
        dict of ``{name: module}`` in ``named_modules()`` traversal order.
    """
    return {
        layer_name: layer
        for layer_name, layer in model.named_modules()
        if isinstance(layer, nn.Conv2d)
    }

def get_all_conv_with_name_and_previous_non_linearity(model, parent_name='', previous_mod=None, previous_name=None):
    """List ``nn.Conv2d`` layers and flag those that directly follow a ReLU.

    Children are walked depth-first in ``named_children()`` order. The most
    recently visited leaf module is carried through the recursion, so a
    convolution is also flagged when the preceding ReLU sits in an earlier
    sibling container.

    Args:
        model: Module whose children are inspected.
        parent_name: Dotted prefix prepended to child names (empty at the root).
        previous_mod: Leaf module visited just before this call, if any.
        previous_name: Name recorded for ``previous_mod``.

    Returns:
        A 4-tuple ``(conv_layers, name_conv_layers_with_relu, previous_mod,
        previous_name)``:

        * ``conv_layers``: dotted names of all ``nn.Conv2d`` leaves.
        * ``name_conv_layers_with_relu``: the same names, with ``"_relu"``
          appended when the previous leaf is ``nn.ReLU`` or ``nn.ReLU6``.
        * ``previous_mod`` / ``previous_name``: the last leaf visited and its
          local (unqualified) child name.
    """
    conv_names = []
    tagged_names = []

    for child_name, child in model.named_children():
        qualified = child_name if not parent_name else f"{parent_name}.{child_name}"

        has_children = any(True for _ in child.children())
        if has_children:
            # Containers are descended into; they never become "previous".
            sub_convs, sub_tagged, previous_mod, previous_name = (
                get_all_conv_with_name_and_previous_non_linearity(
                    child, qualified, previous_mod, previous_name
                )
            )
            tagged_names += sub_tagged
            conv_names += sub_convs
            continue

        if isinstance(child, nn.Conv2d):
            conv_names.append(qualified)
            follows_relu = isinstance(previous_mod, (nn.ReLU, nn.ReLU6))
            tagged_names.append(f"{qualified}_relu" if follows_relu else qualified)

        # Every leaf, convolution or not, becomes the new "previous" module.
        previous_mod, previous_name = child, child_name

    return conv_names, tagged_names, previous_mod, previous_name
    
def get_active_conv_with_name(model):
    """Return the convolution layers selected for fine-tuning.

    The selection is driven by ``model.num_of_finetune``:

    * ``None`` or ``0``: nothing is selected and ``-1`` is returned.
    * ``"all"`` or a number larger than the count of available layers:
      every layer found by ``get_all_conv_with_name`` is returned.
    * any other number ``k``: only the last ``k`` layers (in traversal
      order) are returned.

    Args:
        model: Module exposing ``named_modules()`` and a
            ``num_of_finetune`` attribute.

    Returns:
        dict mapping layer name to module, or ``-1`` when no layer is active.
    """
    num_of_finetune = model.num_of_finetune

    # Must be tested before the numeric comparison below: ``None > int``
    # raises TypeError, which previously made the ``None`` case unreachable.
    if num_of_finetune is None or num_of_finetune == 0:
        return -1

    total_conv_layer = get_all_conv_with_name(model)
    if num_of_finetune == "all" or num_of_finetune > len(total_conv_layer):
        return total_conv_layer

    # Keep only the trailing ``num_of_finetune`` entries, preserving order.
    return dict(list(total_conv_layer.items())[-num_of_finetune:])

class Hook:
    """Forward hook that records a module's inputs, outputs and weights.

    On every forward pass of the hooked module (while the hook is active) a
    detached copy of the first positional input and of the output is appended
    to ``inputs`` / ``outputs``, and the module's weight is snapshotted.

    A module is treated as a LoRA layer when it has ``original_layer``,
    ``lora_A`` and ``lora_B`` attributes; in that case the individual weights
    and, when possible, the merged weight are recorded as well.
    """

    def __init__(self, module):
        """Register the hook on ``module``; recording starts immediately."""
        self.module = module
        # Shapes of the most recent input/output, stored as 1-D tensors.
        self.input_size = None
        self.output_size = None
        # Detached copies of every recorded input/output, in call order.
        self.inputs = []
        self.outputs = []

        # Snapshot of the (base) weight seen on the most recent call.
        self.weight_size = None
        self.weight = None

        # LoRA-specific state; only populated for LoRA-style modules.
        self.is_lora = False
        self.original_weight = None
        self.lora_A_weight = None
        self.lora_B_weight = None
        self.lora_rank = None
        self.lora_alpha = None
        self.lora_scaling = None
        # original_weight + (lora_B @ lora_A) * scaling, when computable.
        self.effective_weight = None

        self.active = True
        self.hook = module.register_forward_hook(self.hook_fn)

    @staticmethod
    def _weight_snapshot(layer):
        """Return a detached copy of ``layer.weight``, or None if unavailable."""
        weight = getattr(layer, 'weight', None)
        return None if weight is None else weight.clone().detach()

    def hook_fn(self, module, input, output):
        """Forward-hook callback: record input, output and weight snapshots.

        Always returns ``None`` so the module's output is never replaced.
        """
        if not self.active:
            return

        captured_input = input[0].clone().detach()
        captured_output = output.clone().detach()

        self.input_size = torch.tensor(captured_input.shape)
        self.output_size = torch.tensor(captured_output.shape)

        self.inputs.append(captured_input)
        self.outputs.append(captured_output)

        is_lora_module = all(
            hasattr(module, attr) for attr in ('original_layer', 'lora_A', 'lora_B')
        )
        if not is_lora_module:
            snapshot = self._weight_snapshot(module)
            if snapshot is not None:
                self.weight_size = snapshot.shape
                self.weight = snapshot
            return

        self.is_lora = True

        # Each field keeps its previous value when the layer has no weight.
        snapshot = self._weight_snapshot(module.original_layer)
        if snapshot is not None:
            self.original_weight = snapshot

        snapshot = self._weight_snapshot(module.lora_A)
        if snapshot is not None:
            self.lora_A_weight = snapshot

        snapshot = self._weight_snapshot(module.lora_B)
        if snapshot is not None:
            self.lora_B_weight = snapshot

        self.lora_rank = getattr(module, 'rank', None)
        self.lora_scaling = getattr(module, 'scaling', None)
        self.lora_alpha = getattr(module, 'alpha', None)

        weights_available = all(
            w is not None
            for w in (self.original_weight, self.lora_A_weight, self.lora_B_weight)
        )
        # The merged weight needs a scaling factor; without one, multiplying
        # by None would raise TypeError inside the forward pass, so skip it.
        if (
            weights_available
            and self.lora_scaling is not None
            and self.lora_A_weight.ndim == 2
            and self.lora_B_weight.ndim == 2
        ):
            lora_weight = self.lora_B_weight @ self.lora_A_weight
            self.effective_weight = self.original_weight + lora_weight * self.lora_scaling

        self.weight = self.original_weight
        self.weight_size = self.original_weight.shape if self.original_weight is not None else None

    def get_lora_details(self):
        """Summarise the recorded LoRA state.

        Returns:
            The string ``"Not LoRA"`` if no LoRA module has been observed,
            otherwise a dict with the recorded weight shapes, rank, scaling
            and alpha (entries are ``None`` when not recorded).
        """
        if not self.is_lora:
            return "Not LoRA"

        def shape_of(tensor):
            return tensor.shape if tensor is not None else None

        return {
            "original_weight_shape": shape_of(self.original_weight),
            "lora_A_weight_shape": shape_of(self.lora_A_weight),
            "lora_B_weight_shape": shape_of(self.lora_B_weight),
            "lora_rank": self.lora_rank,
            "lora_scaling": self.lora_scaling,
            "lora_alpha": self.lora_alpha,
        }

    def activate(self, active):
        """Enable (``True``) or pause (``False``) recording without detaching."""
        self.active = active

    def remove(self):
        """Drop all recorded data and detach the hook from the module.

        ``is_lora`` is left untouched.
        """
        self.input_size = None
        self.output_size = None
        self.inputs.clear()
        self.outputs.clear()
        self.weight_size = None
        self.weight = None

        self.original_weight = None
        self.lora_A_weight = None
        self.lora_B_weight = None
        self.lora_rank = None
        self.lora_alpha = None  # previously left stale while its siblings were reset
        self.lora_scaling = None
        self.effective_weight = None

        self.active = False
        self.hook.remove()

def calculate_flops_subspace_iteration(size_1, size_2, rank):
    """Return the FLOP estimate for a subspace iteration.

    The value is::

        size_1 * rank * (2 * size_2 - 1)
        + size_2 * rank * (2 * size_1 - 1)
        + 2 * m * n**2

    with ``m = max(size_1, rank)`` and ``n = min(size_1, rank)``.
    NOTE(review): the first two terms look like two matrix products and the
    last like an orthogonalisation step -- confirm against the algorithm used.

    Args:
        size_1: First matrix dimension (Python number or tensor).
        size_2: Second matrix dimension (Python number or tensor).
        rank: Target rank (Python number or tensor).

    Returns:
        The estimate, as a tensor if any operand is a tensor, else a number.
    """
    if isinstance(size_1, torch.Tensor):
        # torch.max/min(tensor, int) would read the int as a ``dim`` argument,
        # so make sure the second operand is a tensor for the element-wise form.
        rank_t = torch.as_tensor(rank, device=size_1.device)
        m = torch.max(size_1, rank_t)
        n = torch.min(size_1, rank_t)
    else:
        m = max(size_1, rank)
        n = min(size_1, rank)
    return size_1 * rank * (2*size_2 - 1) + size_2 * rank * (2*size_1 - 1) + 2*m*n**2

def calculate_flops_SVD(size_1, size_2):
    """Return the FLOP estimate ``4*m*n**2 + 8*n**3`` for an SVD.

    ``m`` is the larger and ``n`` the smaller of the two dimensions.

    Args:
        size_1: First matrix dimension (Python number or tensor).
        size_2: Second matrix dimension (Python number or tensor).

    Returns:
        The estimate, as a tensor if ``size_1`` is a tensor, otherwise with
        whatever type Python arithmetic on the operands produces.
    """
    if isinstance(size_1, torch.Tensor):
        # torch.max/min(tensor, int) would read the int as a ``dim`` argument,
        # so make sure the second operand is a tensor for the element-wise form.
        size_2_t = torch.as_tensor(size_2, device=size_1.device)
        m = torch.max(size_1, size_2_t)
        n = torch.min(size_1, size_2_t)
    else:
        m = max(size_1, size_2)
        n = min(size_1, size_2)
    return 4*m*n**2 + 8*n**3
