# Based on implementations from the 4M repo: https://github.com/apple/ml-4m/
from typing import List, Set, Optional, Type

import torch
import torch.nn as nn


# Groups of module class names (compared against ``module.__class__.__name__``)
# that can be selected as LoRA targets.
SELF_ATTENTION_MODULES = {'Attention', 'NormAttention'}
CROSS_ATTENTION_MODULES = {'CrossAttention', 'NormCrossAttention'}
# SwiGLUFFNFused is from DINOv2
MLP_MODULES = {'Mlp', 'GatedMlp', 'SwiGLUFFNFused'}
# Composite groups, built from the basic ones above.
ATTENTION_MODULES = SELF_ATTENTION_MODULES.union(CROSS_ATTENTION_MODULES)
TRANSFORMER_MODULES = ATTENTION_MODULES.union(MLP_MODULES)


def get_LoRA_module_names(id: str) -> Set[str]:
    """Return the set of module class names that are LoRA-adapted for the given id.

    Args:
        id: Case-insensitive name of a module group. Accepted aliases:
            self-attention ('selfattn', 'selfattention', 'self_attn', 'self_attention'),
            cross-attention ('crossattn', 'crossattention', 'cross_attn', 'cross_attention'),
            all attention ('attn', 'attention'), MLP ('mlp'),
            everything ('all', 'transformer').

    Returns:
        A new set of class names. A copy is returned so that callers mutating the
        result cannot alter the module-level constants.

    Raises:
        ValueError: If ``id`` is not one of the known aliases.
    """
    id = id.lower()
    if id in ('selfattn', 'selfattention', 'self_attn', 'self_attention'):
        return set(SELF_ATTENTION_MODULES)
    if id in ('crossattn', 'crossattention', 'cross_attn', 'cross_attention'):
        return set(CROSS_ATTENTION_MODULES)
    if id in ('attn', 'attention'):
        return set(ATTENTION_MODULES)
    if id == 'mlp':
        return set(MLP_MODULES)
    if id in ('all', 'transformer'):
        return set(TRANSFORMER_MODULES)
    raise ValueError(f'Unknown LoRA module id {id}.')


class LoRAWrapper(nn.Module):
    """Low-Rank Adaptation Wrapper for linear layers.
    See https://arxiv.org/abs/2106.09685

    The wrapped layer is stored as ``self.linear``; the low-rank update is the
    product of ``self.lora_up`` (B, zero-initialized) and ``self.lora_down``
    (A, normal-initialized with std 1/rank), so a freshly created wrapper
    computes exactly the same output as the wrapped layer.

    Args:
        linear: nn.Linear layer to wrap
        rank: Rank of adaptation matrix B@A
        scale: x = W_0@x + scale * B@A@x
        num_packed_linear: Set to > 1 when wrapping e.g. packed kv, or qkv attention weights.
            Weights will be initialized as if num_packed_linear = 1, but the LoRA bottleneck will
            be num_packed_linear times larger.
    """
    def __init__(self, linear: nn.Module, rank: int = 4, scale: float = 1.0, num_packed_linear: int = 1):
        super().__init__()
        self.rank = rank
        self.scale = scale
        self.in_features, self.out_features = linear.in_features, linear.out_features
        assert num_packed_linear * rank <= min(self.in_features, self.out_features), \
            f'LoRA rank {num_packed_linear} * {rank} must be less or equal than {min(self.in_features, self.out_features)}'

        self.linear = linear
        # Allocate the adapters on the wrapped layer's device so that wrapping a
        # layer that was already moved off the default device works without an
        # extra `.to(...)` call. The dtype is left at the default on purpose.
        bottleneck = num_packed_linear * rank
        device = linear.weight.device
        self.lora_down = nn.Linear(self.in_features, bottleneck, bias=False, device=device)
        self.lora_up = nn.Linear(bottleneck, self.out_features, bias=False, device=device)

        # std depends on `rank` only (not on the packed bottleneck size).
        nn.init.normal_(self.lora_down.weight, std=1/rank)
        # Zero up-projection: the adapter contributes nothing until trained.
        nn.init.zeros_(self.lora_up.weight)

    def fuse_LoRA_into_linear(self) -> nn.Linear:
        """Return a single nn.Linear layer with the LoRA matrix fused into the original one.

        The returned layer is created on the device and with the dtype of the
        wrapped layer's weight. Its parameters are independent copies: they do
        not share storage with the wrapped layer and carry no autograd history.
        """
        base_weight = self.linear.weight
        has_bias = self.linear.bias is not None
        fused_linear = nn.Linear(
            self.in_features,
            self.out_features,
            bias=has_bias,
            device=base_weight.device,
            dtype=base_weight.dtype,
        )
        # Merging is pure bookkeeping; it must not be recorded by autograd.
        with torch.no_grad():
            low_rank_update = self.lora_up.weight @ self.lora_down.weight
            # copy_ casts to the fused layer's dtype if the sum was promoted.
            fused_linear.weight.copy_(base_weight + self.scale * low_rank_update)
            if has_bias:
                fused_linear.bias.copy_(self.linear.bias)
        return fused_linear

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ LoRA adapted linear layer forward pass. """
        return self.linear(x) + self.lora_up(self.lora_down(x)) * self.scale
    

def _find_modules(
    model,
    ancestor_class: Optional[Set[str]] = None,
    search_class: List[Type[nn.Module]] = [nn.Linear],
    exclude_children_of: Optional[List[Type[nn.Module]]] = [LoRAWrapper],
):
    """
    Find all modules of a certain class (or union of classes) that are direct or
    indirect descendants of other modules of a certain class (or union of classes).

    Yields all matching modules, along with the parent of those modules and the
    names they are referenced by.

    Adapted from https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py

    Args:
        model: Module whose module tree is searched.
        ancestor_class: Class names (compared with ``__class__.__name__``) under
            which to search. If None, every module of ``model`` is used as an ancestor.
        search_class: Classes of the modules to look for. Never mutated here.
        exclude_children_of: Matches whose direct parent is an instance of one of
            these classes are skipped. None or empty disables the filter.
            Never mutated here.

    Yields:
        Tuples ``(parent, name, module)`` where ``module`` is registered on
        ``parent`` under ``name``. Each ``(parent, name)`` slot is yielded at most
        once, and an ancestor is never yielded as its own child.
    """
    search_types = tuple(search_class)
    excluded_parent_types = tuple(exclude_children_of) if exclude_children_of else ()

    # Get the targets we should replace all linears under
    if ancestor_class is not None:
        # Filtered lazily while the caller consumes the results.
        ancestors = (
            module
            for module in model.modules()
            if module.__class__.__name__ in ancestor_class
        )
    else:
        # No filter requested: every module acts as an ancestor (snapshot taken up front).
        ancestors = list(model.modules())

    # Slots already reported. Storing the parent object itself (modules hash by
    # identity) keeps it alive, so a recycled id can never cause a false hit.
    seen = set()

    # For each target find every search_class module that isn't a child of an excluded layer
    for ancestor in ancestors:
        for fullname, module in ancestor.named_modules():
            if not isinstance(module, search_types):
                continue
            # named_modules() reports the ancestor itself under the empty name.
            # It has no parent inside this subtree, so reporting it would make
            # callers register a replacement as a child of the module itself.
            if not fullname:
                continue
            # Find the direct parent if this is a descendant, not a child, of target
            parent_path, _, name = fullname.rpartition(".")
            parent = ancestor.get_submodule(parent_path)
            # Skip this module if it's a child of an excluded (e.g. LoRA) layer
            if isinstance(parent, excluded_parent_types):
                continue
            # Nested ancestors reach the same slot several times; report it once.
            slot = (parent, name)
            if slot in seen:
                continue
            seen.add(slot)
            yield parent, name, module
                

def inject_trainable_LoRA(
    model: nn.Module,
    rank: int = 4,
    scale: float = 1.0,
    target_replace_modules: Set[str] = ATTENTION_MODULES
) -> None:
    """Wrap every linear layer found under the targeted modules in a LoRAWrapper.
    The model is modified in-place.

    Layers whose attribute name is a permutation of 'qkv' are treated as three
    packed projections, permutations of 'kv', 'qk' or 'qv' as two, and any
    other name as a single projection.

    Args:
        model: nn.Module to modify
        rank: Rank of adaptation matrix B@A
        scale: x = W_0@x + scale * B@A@x
        target_replace_modules: Set of module class names to replace linear layers in.
    """
    # Keys are the alphabetically sorted characters of the attribute name, so
    # every ordering of the same letters maps to the same entry.
    packed_counts = {'kqv': 3, 'kv': 2, 'kq': 2, 'qv': 2}

    matches = _find_modules(model, target_replace_modules, search_class=[nn.Linear])
    for parent, attr_name, linear in matches:
        sorted_letters = ''.join(sorted(attr_name))
        packing = packed_counts.get(sorted_letters, 1)
        parent._modules[attr_name] = LoRAWrapper(
            linear, rank=rank, scale=scale, num_packed_linear=packing
        )
        

def fuse_LoRA_into_linear(
    model: nn.Module,
    target_replace_modules: Set[str] = ATTENTION_MODULES
) -> None:
    """Replace every LoRAWrapper under the targeted modules by its fused linear layer.
    The model is modified in-place.

    Args:
        model: nn.Module to modify
        target_replace_modules: Set of module class names in which LoRA-adapted
            layers are fused back into plain linear layers.
    """
    wrapped_layers = _find_modules(model, target_replace_modules, search_class=[LoRAWrapper])
    for parent, attr_name, _ in wrapped_layers:
        # Look the wrapper up on the parent at replacement time.
        lora_layer = parent._modules[attr_name]
        parent._modules[attr_name] = lora_layer.fuse_LoRA_into_linear()


def unfreeze_all_LoRA_layers(model: nn.Module) -> None:
    """ Enables gradients for every parameter whose qualified name contains 'lora'. """
    lora_params = (
        weight
        for qualified_name, weight in model.named_parameters()
        if 'lora' in qualified_name
    )
    for weight in lora_params:
        weight.requires_grad = True
