import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize


# Parametrization module that adds a trainable low-rank update to a weight
class LoRAParametrization(nn.Module):
    """Return ``weight`` plus the contraction of two learnable factors.

    ``A`` is drawn from a standard normal and ``B`` is all zeros, so the
    update is exactly zero right after construction and the parametrized
    weight starts out equal to the original one.

    Args:
        A_shape: Sizes of factor ``A``. Its first two axes are used as
            ``(i, k)`` in the contraction; any further axes are carried
            through to the result unchanged.
        B_shape: Sizes of factor ``B``, used as ``(k, o)``.
        device: Device on which both factors are allocated.
    """

    def __init__(self, A_shape, B_shape, device=None):
        super().__init__()
        # A must be created first: it is the only factor that consumes RNG state.
        random_factor = torch.randn(*A_shape, device=device)
        zero_factor = torch.zeros(*B_shape, device=device)
        self.A = nn.Parameter(random_factor)
        self.B = nn.Parameter(zero_factor)

    def forward(self, weight):
        """Return ``weight`` with the low-rank product of ``A`` and ``B`` added."""
        # Contract over the shared rank axis k; trailing axes of A pass through.
        low_rank_delta = torch.einsum("ik...,ko->io...", self.A, self.B)
        return weight + low_rank_delta


# Depth-first walk that lists the direct parameters of every leaf sub-module
def collect_leaf_layers(layer, prefix=""):
    """Gather ``(module, dotted_name, param_name)`` for each leaf-layer parameter.

    A leaf layer is a descendant of ``layer`` that has no child modules. One
    tuple is produced per parameter registered directly on such a leaf; leaves
    without parameters contribute nothing. ``layer`` itself is never reported,
    only its descendants.

    Args:
        layer: Module whose descendants are inspected.
        prefix: Dotted path prepended to every reported module name.

    Returns:
        List of ``(leaf_module, full_name, param_name)`` tuples in depth-first
        order of ``named_children()``.
    """
    found = []

    for child_name, child in layer.named_children():
        qualified = f"{prefix}.{child_name}" if prefix else child_name

        # A module with at least one child is a container: descend into it.
        if next(child.children(), None) is not None:
            found += collect_leaf_layers(child, qualified)
            continue

        # Leaf module: record one entry per parameter it owns directly.
        found += [
            (child, qualified, own_param)
            for own_param, _ in child.named_parameters(recurse=False)
        ]

    return found


# Attach a LoRA parametrization to every eligible leaf-layer parameter
def apply_parametrization_to_leaf_layers(model, rank=1):
    r"""Freeze leaf-layer weights and add a trainable low-rank update to each.

    f(x; \theta) <- f(x; \theta + \alpha @ \beta)

    Every parameter with at least two dimensions that is registered directly
    on a leaf layer (as reported by ``collect_leaf_layers``) is frozen and
    parametrized with a ``LoRAParametrization``. Parameters with fewer than
    two dimensions are left untouched and stay trainable.

    Args:
        model: Module to modify in place.
        rank: Size of the shared inner axis of the two low-rank factors.

    Returns:
        The same ``model`` object, after modification.

    Side effects:
        Prints the total number of parameters added by the adapters.
    """
    total_added_params = 0
    for leaf_layer, _full_name, param_name in collect_leaf_layers(model):
        # A module reachable under several names is listed once per name;
        # only attach a single adapter to it.
        if parametrize.is_parametrized(leaf_layer, param_name):
            continue

        # Read the parameter straight from its owning layer instead of
        # rebuilding the model-wide name -> parameter dict on every iteration.
        param = getattr(leaf_layer, param_name)
        shape = param.shape
        if len(shape) < 2:
            continue

        # A keeps the leading axis and any trailing axes of the weight;
        # B maps the rank axis to the weight's second axis.
        shape_A = [shape[0], rank, *shape[2:]]
        shape_B = [rank, shape[1]]

        lora = LoRAParametrization(shape_A, shape_B, param.device)
        if param.is_floating_point():
            # Match the weight's dtype so the parametrized weight keeps it
            # (e.g. for half-precision models).
            lora.to(dtype=param.dtype)

        # Count what the adapter actually allocated.
        total_added_params += sum(p.numel() for p in lora.parameters())

        # NOTE: Freeze the original weights by setting requires_grad to False
        param.requires_grad = False

        parametrize.register_parametrization(leaf_layer, param_name, lora)
    print(f"Total number of added parameters: {total_added_params}")
    return model


# Model names for which LoRA parametrization is enabled.
_SUPPORTED_MODEL_NAMES = (
    "mobilenet_v2",
    "resnet18",
    "resnet50",
    "resnext50_32x4d",
    "vit_b_16",
    "vit_b_32",
    "vit_l_16",
    "vit_l_32",
    "vit_h_14",
)


def wrapper(name: str, model: torch.nn.Module, rank: int) -> torch.nn.Module:
    """Apply LoRA parametrization to ``model`` if ``name`` is a supported model.

    All supported names are handled the same way: the model is passed to
    ``apply_parametrization_to_leaf_layers`` with the given ``rank``.

    Args:
        name: Identifier of the model architecture.
        model: Module to parametrize.
        rank: Rank forwarded to ``apply_parametrization_to_leaf_layers``.

    Returns:
        The value returned by ``apply_parametrization_to_leaf_layers``.

    Raises:
        ValueError: If ``name`` is not one of the supported model names.
    """
    # Tuple membership compares with ==, so any unrecognised value of `name`
    # ends up in the ValueError below.
    if name not in _SUPPORTED_MODEL_NAMES:
        raise ValueError(
            f"Unsupported model name: {name!r}. "
            f"Expected one of: {', '.join(_SUPPORTED_MODEL_NAMES)}"
        )
    return apply_parametrization_to_leaf_layers(model, rank=rank)
