import torch
import torch.nn as nn 
import torch.nn.functional as F

# from utils.init_utils import init_adapters

class BaseLayer(nn.Module):
    """Container module that owns a layer's original weight as a Parameter.

    The adapter helpers in this file move a layer's ``weight`` into an instance
    of this class (attached as ``base_layer``), so the tensor stays reachable
    as ``<layer>.base_layer.weight`` after it is removed from the layer itself.
    """

    def __init__(self, weight):
        super().__init__()
        # Assigning an nn.Parameter attribute registers it under that name.
        # `.data` makes the new Parameter share storage with `weight`.
        self.weight = nn.Parameter(weight.data)

    def forward(self, x):
        """Ignore the input and return the stored weight Parameter."""
        return self.weight

def add_lora_adapters(model, rank, alpha, b_var, a_var, target_modules):
    """Attach LoRA adapters to matching sub-modules of ``model`` in place.

    For every sub-module whose qualified name contains one of
    ``target_modules`` (plain substring match), the module's ``weight`` is
    moved into a ``BaseLayer`` child named ``base_layer`` and the module's
    ``forward`` is replaced by

        y = W x + bias + (alpha / rank) * B (A x)

    where ``B`` has shape ``(out_dim, rank)`` and ``A`` has shape
    ``(rank, in_dim)``. Afterwards every parameter whose name does not contain
    ``'lora'`` is frozen.

    Args:
        model: The ``nn.Module`` to modify.
        rank: Inner dimension of the low-rank update.
        alpha: Numerator of the ``alpha / rank`` scaling of the update.
        b_var: Multiplier applied to standard-normal samples when initialising
            ``lora_B``, i.e. its standard deviation (0 gives a zero update).
        a_var: Same as ``b_var``, for ``lora_A``.
        target_modules: Substrings selecting modules to adapt. Matching modules
            must expose a 2-D ``weight``.

    Returns:
        The same ``model`` object, modified in place.
    """
    for mod_name, module in list(model.named_modules()):
        # Already wrapped by an earlier call: leave it alone.
        if hasattr(module, 'base_layer'):
            continue
        if not any(target in mod_name for target in target_modules):
            continue

        weight = module.weight
        out_dim, in_dim = weight.shape
        device, dtype = weight.device, weight.dtype

        # Move the original weight into a child module so it is no longer a
        # direct parameter of `module`.
        module.base_layer = BaseLayer(weight)
        del module._parameters['weight']

        # Sample in float32 (same RNG stream as before), then cast to the
        # frozen weight's dtype. Otherwise F.linear fails with a dtype mismatch
        # for half / bfloat16 / float64 models.
        lora_B = nn.Parameter((torch.randn(out_dim, rank, device=device) * b_var).to(dtype))
        lora_A = nn.Parameter((torch.randn(rank, in_dim, device=device) * a_var).to(dtype))
        module.register_parameter('lora_B', lora_B)
        module.register_parameter('lora_A', lora_A)

        # `module` is bound as a default argument so each closure keeps its own
        # layer (avoids the late-binding-closure pitfall inside this loop).
        def new_forward(x, module=module):
            # F.linear accepts bias=None, so a missing or None bias needs no branch.
            result = F.linear(x, module.base_layer.weight, getattr(module, 'bias', None))
            update = F.linear(F.linear(x, module.lora_A), module.lora_B)
            return result + (alpha / rank) * update

        module.forward = new_forward

    # Only the LoRA factors remain trainable.
    for name, parameter in model.named_parameters():
        if 'lora' not in name:
            parameter.requires_grad = False

    return model

def add_vravan_adapters(model, rank, alpha, b_var, r_var, a_var, num_heads, target_modules):
    """Attach multi-head B_i R_i A_i adapters to matching sub-modules in place.

    For every sub-module whose qualified name contains one of
    ``target_modules`` (plain substring match), the module's ``weight`` is
    moved into a ``BaseLayer`` child named ``base_layer``. Then ``num_heads``
    triples ``lora_A_i`` ``(rank, in_dim)``, ``lora_R_i`` ``(rank, rank)`` and
    ``lora_B_i`` ``(out_dim, rank)`` are registered, and ``forward`` becomes

        y = W x + bias + sum_i B_i (R_i (A_i x))

    Afterwards only the ``lora_R_*`` parameters remain trainable. The ``A`` and
    ``B`` factors stay at their random initialisation.

    Args:
        model: The ``nn.Module`` to modify.
        rank: Inner dimension of each head's update.
        alpha: Currently unused; no ``alpha / rank`` scaling is applied.
        b_var, r_var, a_var: Multipliers (standard deviations) applied to
            standard-normal samples initialising ``B_i``, ``R_i``, ``A_i``.
        num_heads: Number of independent B/R/A triples per adapted module.
        target_modules: Substrings selecting modules to adapt. Matching modules
            must expose a 2-D ``weight``.

    Returns:
        The same ``model`` object, modified in place.
    """
    for mod_name, module in list(model.named_modules()):
        # Already wrapped by an earlier call: leave it alone.
        if hasattr(module, 'base_layer'):
            continue
        if not any(target in mod_name for target in target_modules):
            continue

        weight = module.weight
        out_dim, in_dim = weight.shape
        device, dtype = weight.device, weight.dtype

        module.base_layer = BaseLayer(weight)
        del module._parameters['weight']

        for i in range(num_heads):
            # Sample in float32 (same RNG order as before: B, R, A), then cast
            # to the frozen weight's dtype so F.linear sees matching dtypes.
            B_i = nn.Parameter((torch.randn(out_dim, rank, device=device) * b_var).to(dtype), requires_grad=True)
            R_i = nn.Parameter((torch.randn(rank, rank, device=device) * r_var).to(dtype), requires_grad=True)
            A_i = nn.Parameter((torch.randn(rank, in_dim, device=device) * a_var).to(dtype), requires_grad=True)

            module.register_parameter(f'lora_A_{i}', A_i)
            module.register_parameter(f'lora_R_{i}', R_i)
            module.register_parameter(f'lora_B_{i}', B_i)

        # Defaults bind this iteration's module / head count to the closure.
        def new_forward(x, module=module, num_heads=num_heads):
            result = F.linear(x, module.base_layer.weight, getattr(module, 'bias', None))

            mh_update = 0
            for i in range(num_heads):
                A_i = getattr(module, f'lora_A_{i}')
                R_i = getattr(module, f'lora_R_{i}')
                B_i = getattr(module, f'lora_B_{i}')
                update = F.linear(F.linear(F.linear(x, A_i), R_i), B_i)
                mh_update = mh_update + update

            return result + mh_update

        module.forward = new_forward

    # Only the per-head R matrices remain trainable.
    for name, parameter in model.named_parameters():
        if 'lora_R' not in name:
            parameter.requires_grad = False

    return model

def add_ravan_adapters(model, rank, alpha, b_var, r_var, a_var, num_heads, target_modules):
    """Attach multi-head, per-head-scaled B_i R_i A_i adapters in place.

    For every sub-module whose qualified name contains one of
    ``target_modules`` (plain substring match), the module's ``weight`` is
    moved into a ``BaseLayer`` child named ``base_layer``. Then ``num_heads``
    sets of ``lora_A_i`` ``(rank, in_dim)``, ``lora_R_i`` ``(rank, rank)``,
    ``lora_B_i`` ``(out_dim, rank)`` and a scalar ``lora_scaling_i``
    (initialised to 1.0) are registered, and ``forward`` becomes

        y = W x + bias + sum_i s_i * B_i (R_i (A_i x))

    Afterwards only ``lora_R_*`` and ``lora_scaling_*`` remain trainable.

    Args:
        model: The ``nn.Module`` to modify.
        rank: Inner dimension of each head's update.
        alpha: Currently unused; no ``alpha / rank`` factor is applied.
        b_var, r_var, a_var: Multipliers (standard deviations) applied to
            standard-normal samples initialising ``B_i``, ``R_i``, ``A_i``.
        num_heads: Number of heads per adapted module.
        target_modules: Substrings selecting modules to adapt. Matching modules
            must expose a 2-D ``weight``.

    Returns:
        The same ``model`` object, modified in place.
    """
    for mod_name, module in list(model.named_modules()):
        # Already wrapped by an earlier call: leave it alone.
        if hasattr(module, 'base_layer'):
            continue
        if not any(target in mod_name for target in target_modules):
            continue

        weight = module.weight
        out_dim, in_dim = weight.shape
        device, dtype = weight.device, weight.dtype

        module.base_layer = BaseLayer(weight)
        del module._parameters['weight']

        for i in range(num_heads):
            # Sample in float32 (same RNG order as before: B, R, A), then cast
            # to the frozen weight's dtype so F.linear sees matching dtypes.
            B_i = nn.Parameter(
                (torch.randn(out_dim, rank, device=device) * b_var).to(dtype),
                requires_grad=True,
            )
            R_i = nn.Parameter(
                (torch.randn(rank, rank, device=device) * r_var).to(dtype),
                requires_grad=True,
            )
            A_i = nn.Parameter(
                (torch.randn(rank, in_dim, device=device) * a_var).to(dtype),
                requires_grad=True,
            )
            # Trainable per-head scale, starting as the identity (1.0).
            scaling_i = nn.Parameter(
                torch.tensor(1.0, device=device, dtype=dtype),
                requires_grad=True,
            )

            module.register_parameter(f'lora_A_{i}', A_i)
            module.register_parameter(f'lora_R_{i}', R_i)
            module.register_parameter(f'lora_B_{i}', B_i)
            module.register_parameter(f'lora_scaling_{i}', scaling_i)

        # Output = Wx + bias + sum_i s_i * B_i R_i A_i x   (alpha is not applied)
        def new_forward(x, module=module, num_heads=num_heads):
            result = F.linear(x, module.base_layer.weight, getattr(module, 'bias', None))

            mh_update = 0
            for i in range(num_heads):
                A_i = getattr(module, f'lora_A_{i}')
                R_i = getattr(module, f'lora_R_{i}')
                B_i = getattr(module, f'lora_B_{i}')
                scaling_i = getattr(module, f'lora_scaling_{i}')
                update = F.linear(F.linear(F.linear(x, A_i), R_i), B_i)
                mh_update = mh_update + scaling_i * update
            return result + mh_update

        module.forward = new_forward

    # Freeze everything except the adapter's trainable parts (R and the scales).
    for name, parameter in model.named_parameters():
        if ('lora_R' not in name) and ('lora_scaling' not in name):
            parameter.requires_grad = False

    return model

def add_sb_adapters(model, rank, alpha, target_modules, delta=None):
    """Attach B R A adapters, optionally initialised from a truncated SVD.

    For every sub-module whose qualified name contains one of
    ``target_modules`` (plain substring match), the module's ``weight`` is
    moved into a ``BaseLayer`` child named ``base_layer`` and ``forward`` is
    replaced by

        y = W x + bias + B (R (A x))

    If ``delta`` is truthy, ``delta[mod_name + ".weight"]`` is decomposed as
    ``U diag(S) Vh``. The factors are set to its leading ``rank`` components:
    ``B = U[:, :rank]``, ``R = diag(S[:rank])``, ``A = Vh[:rank, :]``. Their
    product is therefore the best rank-``rank`` approximation of that delta.
    Otherwise all three factors start at zero. Afterwards only the ``lora_R``
    parameters remain trainable.

    NOTE(review): in the zero-initialised branch B and A are both zero and
    frozen, so the gradient w.r.t. R is zero and R cannot learn. Confirm that
    this path is intended.

    Args:
        model: The ``nn.Module`` to modify.
        rank: Number of singular components / inner dimension of the update.
        alpha: Currently unused; no ``alpha / rank`` scaling is applied.
        target_modules: Substrings selecting modules to adapt. Matching modules
            must expose a 2-D ``weight``.
        delta: Optional mapping from ``"<module name>.weight"`` to a 2-D tensor
            of shape ``(out_dim, in_dim)``.

    Returns:
        The same ``model`` object, modified in place.
    """
    for mod_name, module in list(model.named_modules()):
        # Already wrapped by an earlier call: leave it alone.
        if hasattr(module, 'base_layer'):
            continue
        if not any(target in mod_name for target in target_modules):
            continue

        weight = module.weight
        out_dim, in_dim = weight.shape
        device, dtype = weight.device, weight.dtype

        module.base_layer = BaseLayer(weight)
        del module._parameters['weight']

        if delta:
            param_delta = delta[mod_name + ".weight"].detach()
            # torch.linalg.svd has no half-precision kernels, so run it in at
            # least float32. Only the leading `rank` factors are needed, so the
            # reduced SVD suffices.
            svd_dtype = torch.promote_types(param_delta.dtype, torch.float32)
            U, S, Vh = torch.linalg.svd(param_delta.to(svd_dtype), full_matrices=False)
            # Put the factors next to the frozen weight rather than assuming CUDA.
            U, S, Vh = (t.to(device=device, dtype=dtype) for t in (U, S, Vh))

            B = nn.Parameter(U[:, :rank].contiguous())
            R = nn.Parameter(torch.diag(S[:rank]))
            A = nn.Parameter(Vh[:rank, :].contiguous())
        else:
            B = nn.Parameter(torch.zeros(out_dim, rank, device=device, dtype=dtype))
            R = nn.Parameter(torch.zeros(rank, rank, device=device, dtype=dtype))
            A = nn.Parameter(torch.zeros(rank, in_dim, device=device, dtype=dtype))

        module.register_parameter('lora_B', B)
        module.register_parameter('lora_R', R)
        module.register_parameter('lora_A', A)

        # Output = Wx + bias + B R A x   (alpha is not applied)
        def new_forward(x, module=module):
            result = F.linear(x, module.base_layer.weight, getattr(module, 'bias', None))
            update = F.linear(F.linear(F.linear(x, module.lora_A), module.lora_R), module.lora_B)
            return result + update

        module.forward = new_forward

    # Only the R matrices remain trainable.
    for name, parameter in model.named_parameters():
        if 'lora_R' not in name:
            parameter.requires_grad = False

    return model

def add_adapters(model, rank, alpha, b_var, r_var, a_var, num_heads, adaptation_method, delta=None):
    """Inject the adapters selected by ``adaptation_method`` into ``model``.

    Target modules are chosen from ``model.config._name_or_path``:
    ViT-base-in21k / roberta-large -> ``query``/``value``; t5-base ->
    ``SelfAttention.q``/``SelfAttention.v``; any name containing ``llama``
    (case-insensitive) -> ``q_proj``/``v_proj``.

    Args:
        model: Model exposing ``config._name_or_path``.
        rank, alpha, b_var, r_var, a_var, num_heads, delta: Forwarded to the
            chosen adapter function (each uses only the subset it accepts).
        adaptation_method: One of ``'lora'``, ``'sb'``, ``'vravan'``,
            ``'ravan'`` or ``'full_ft'``. ``'full_ft'`` returns ``model``
            unchanged.

    Returns:
        The (possibly modified in place) model.

    Raises:
        NotImplementedError: For an unknown adaptation method, or for a model
            name with no known target modules (except with ``'full_ft'``).
    """
    # Full fine-tuning touches no modules, so it needs no target lookup and
    # must not fail just because the model name is unrecognised.
    if adaptation_method == 'full_ft':
        return model

    model_name = model.config._name_or_path
    if model_name in ("google/vit-base-patch16-224-in21k", 'roberta-large'):
        target_modules = ['query', 'value']
    elif model_name == "t5-base":
        target_modules = ['SelfAttention.q', 'SelfAttention.v']
    elif 'llama' in model_name.lower():
        target_modules = ["q_proj", "v_proj"]
    else:
        raise NotImplementedError(f"No adapter target modules defined for model {model_name!r}")

    if adaptation_method == 'lora':
        model = add_lora_adapters(model, rank, alpha, b_var, a_var, target_modules)
    elif adaptation_method == 'sb':
        model = add_sb_adapters(model, rank, alpha, target_modules, delta=delta)
    elif adaptation_method == 'vravan':
        model = add_vravan_adapters(model, rank, alpha, b_var, r_var, a_var, num_heads, target_modules)
    elif adaptation_method == 'ravan':
        model = add_ravan_adapters(model, rank, alpha, b_var, r_var, a_var, num_heads, target_modules)
    else:
        raise NotImplementedError(f"Unknown adaptation method {adaptation_method!r}")

    return model