import math
from functools import partial
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class Lorax(torch.nn.Linear):
    """Linear layer with a shared LoRA ``A`` projection and per-sample ``B`` adapters.

    The layer always computes the ordinary affine map ``F.linear(x, weight, bias)``.
    When ``lora_r > 0`` it adds a low-rank update ``B_k(A(x))``:

    * ``A`` (``lora_A``) is one ``in_features -> lora_r`` projection shared by all samples.
    * ``B_k`` is one of ``num_loras`` ``lora_r -> out_features`` projections in
      ``lora_B_list``. It is selected independently for every batch element by the
      ``lora_B_idx`` argument of :meth:`forward`.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        bias: Whether the base linear map has an additive bias.
        lora_r: Rank of the low-rank update; ``0`` disables the LoRA path entirely.
        num_loras: Number of ``B`` adapters. With ``lora_r > 0`` this must be
            positive for :meth:`forward` to accept any index.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool,
        lora_r: int = 0,
        num_loras: int = 0,
    ):
        nn.Linear.__init__(self, in_features, out_features, bias=bias)

        self.lora_r = lora_r
        self.num_loras = num_loras

        if lora_r > 0:
            # A single down-projection shared by every adapter.
            self.lora_A = nn.Linear(in_features, lora_r, bias=False)
            # One up-projection per adapter; forward() picks one per sample.
            self.lora_B_list = nn.ModuleList([
                nn.Linear(lora_r, out_features, bias=False)
                for _ in range(num_loras)
            ])
            self.lorax_reset_parameters()

    def lorax_reset_parameters(self):
        """(Re)initialise the LoRA parameters, if this layer has any.

        ``lora_A`` gets the same Kaiming-uniform init ``nn.Linear`` uses. Every
        ``B`` adapter is zeroed, so a freshly initialised layer's output equals
        the plain linear output.
        """
        if hasattr(self, "lora_A"):
            nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
            for lora_B in self.lora_B_list:
                nn.init.zeros_(lora_B.weight)

    def forward(
        self,
        x: torch.Tensor,
        lora_B_idx: torch.Tensor,
        should_store: bool = False,
        representations: Optional[List[torch.Tensor]] = None,
    ):
        """Apply the base linear map plus the per-sample LoRA update.

        Args:
            x: Input of shape ``(batch, *, in_features)``. Any number (including
                zero) of middle dimensions is supported.
            lora_B_idx: 1-D integer tensor of length ``batch``. Entry ``i`` selects
                the adapter in ``lora_B_list`` used for ``x[i]``. Ignored when
                ``lora_r == 0``.
            should_store: When true and ``representations`` is a list, the shared
                ``lora_A(x)`` activation is appended to it.
            representations: Optional list that collects ``lora_A(x)`` outputs.

        Returns:
            Tensor of shape ``(batch, *, out_features)``.

        Raises:
            ValueError: If ``lora_B_idx`` does not have one entry per batch element,
                or contains an index outside ``[0, num_loras)``.
        """
        base_output = F.linear(x, self.weight, self.bias)  # (B, *, out_features)
        if self.lora_r <= 0:
            return base_output

        # Explicit raises instead of `assert`, which is stripped under `python -O`.
        if x.size(0) != lora_B_idx.size(0):
            raise ValueError(
                f"x.shape[0]: {x.size(0)} != lora_B_idx.shape[0]: {lora_B_idx.size(0)}"
            )
        if torch.any(lora_B_idx >= self.num_loras):
            raise ValueError(
                f"lora_B_idx should be less than {self.num_loras}, but got {lora_B_idx.max()}"
            )
        if torch.any(lora_B_idx < 0):
            # Negative indices would otherwise silently wrap around in index_select.
            raise ValueError(
                f"lora_B_idx should be non-negative, but got {lora_B_idx.min()}"
            )

        lora_A_out = self.lora_A(x)  # (B, *, r)
        all_B = torch.stack([lora_B.weight for lora_B in self.lora_B_list], dim=0)  # (L, out, r)

        # Gather each sample's adapter instead of contracting against a one-hot mask
        # over all adapters. Unused adapters still receive (zero) gradients via stack.
        sample_idx = lora_B_idx.to(device=all_B.device, dtype=torch.long)
        selected_B = all_B.index_select(0, sample_idx)  # (B, out, r)

        # The ellipsis covers any middle dimensions of x, so 2-D and N-D inputs both work.
        delta = torch.einsum("b...r,bor->b...o", lora_A_out, selected_B)  # (B, *, out)

        if should_store and representations is not None:
            representations.append(lora_A_out)

        return base_output + delta


def wrap_linear(
    model: nn.Module,
    target_modules: Tuple[str, ...],
    config: Dict[str, Any],
):
    """Replace matching ``nn.Linear`` submodules of *model* with ``Lorax`` layers, in place.

    A submodule is replaced when any entry of *target_modules* is a substring of its
    dotted name, as reported by ``model.named_modules()``. Modules that are already
    ``Lorax``, or are not ``nn.Linear`` at all, are skipped. Each replacement:

    * copies the original layer's weight and bias;
    * is moved to that layer's own device and dtype, so models whose layers live
      on different devices keep their placement;
    * gets freshly initialised LoRA parameters.

    Args:
        model: Root module to modify.
        target_modules: Name fragments to match. A single string is treated as one
            fragment rather than as a sequence of characters.
        config: Extra keyword arguments for ``Lorax`` (e.g. ``lora_r``, ``num_loras``).

    Returns:
        The same *model* object, modified in place.
    """
    if isinstance(target_modules, str):
        target_modules = (target_modules,)

    lora_impl = partial(Lorax, **config)
    # Snapshot the names first: replacing modules while walking named_modules()
    # would mutate the tree being iterated.
    key_list = [key for key, _ in model.named_modules()]

    for key in key_list:
        # The root module has key "" and cannot be swapped via its parent.
        if not key or not any(fragment in key for fragment in target_modules):
            continue

        target = model.get_submodule(key)
        # Lorax subclasses nn.Linear, so it must be excluded explicitly. Other matches
        # (containers, norms, ...) have no in_features/out_features to copy.
        if isinstance(target, Lorax) or not isinstance(target, nn.Linear):
            continue

        parent_name, _, child_name = key.rpartition(".")
        parent = model.get_submodule(parent_name)

        new_module = lora_impl(
            in_features=target.in_features,
            out_features=target.out_features,
            bias=target.bias is not None,
        )
        new_module.to(device=target.weight.device, dtype=target.weight.dtype)
        # strict=False: the original layer has no lora_A / lora_B_list entries.
        new_module.load_state_dict(target.state_dict(), strict=False)
        new_module.lorax_reset_parameters()
        setattr(parent, child_name, new_module)

    return model
            

if __name__ == "__main__":
    # Self-test: compare Lorax against a per-sample reference and check that
    # every B adapter receives a gradient.
    torch.manual_seed(0)

    B, in_dim, out_dim, lora_r, num_loras = 4, 8, 16, 4, 3
    x = torch.randn(B, 16, in_dim, requires_grad=True)
    lora_B_idx = torch.LongTensor([2, 1, 0, 1])

    # result check
    model = Lorax(in_dim, out_dim, bias=True, lora_r=lora_r, num_loras=num_loras)
    # The B adapters are zero-initialised, which makes the LoRA delta vanish and the
    # comparison below pass trivially; randomise them so adapter routing is exercised.
    for lora_B in model.lora_B_list:
        nn.init.normal_(lora_B.weight)
    model_output = model(x, lora_B_idx)

    def ori_impl_forward(model, x, lora_B_idx):
        """Reference implementation: apply each sample's adapter one at a time."""
        base = F.linear(x, model.weight, model.bias)
        lora_A_out = model.lora_A(x)
        rows = [
            base_i + model.lora_B_list[int(idx)](a_i)
            for base_i, a_i, idx in zip(base, lora_A_out, lora_B_idx)
        ]
        return torch.stack(rows)

    ori_output = ori_impl_forward(model, x, lora_B_idx)
    # With a non-zero delta, batched and per-sample paths round differently in fp32.
    assert torch.allclose(model_output, ori_output, atol=1e-5), f"model_output: {model_output}, ori_output: {ori_output}"

    # gradient check
    def check_grad(model, x, lora_B_idx):
        loss = model(x, lora_B_idx).sum()
        loss.backward()
        for i, lora_B in enumerate(model.lora_B_list):
            # Assert before dereferencing .grad so a missing gradient gives a clear message.
            assert lora_B.weight.grad is not None, f"lora_B_list[{i}].weight.grad is None"
            print(f"lora_B_list[{i}].weight.grad: {lora_B.weight.grad.abs().sum()}")

    check_grad(model, x, lora_B_idx)

