import torch
import torch.nn as nn
from einops import rearrange


class DataProvider:
    """Mutable holder for keyword data shared between modules.

    Values are stored with :meth:`set`, read back one key at a time with
    :meth:`get`, and cleared with :meth:`reset`.
    """

    def __init__(self):
        # ``None`` means "nothing has been set yet".
        self.data = None

    def set(self, **kwargs):
        """Replace all stored data with a fresh dict of the given keyword arguments."""
        self.data = dict(kwargs)

    def get(self, key):
        """Return the value stored under ``key``.

        Raises:
            AssertionError: if no data has been set, or ``key`` was not set
                (only while assertions are enabled).
        """
        stored = self.data
        assert stored is not None, "Error: need to set data first"
        assert key in stored, f"Error: need to set key {key} first"
        return stored[key]

    def reset(self):
        """Forget all stored data."""
        self.data = None


class SimpleLoraLinear(torch.nn.Module):
    def __init__(
        self,
        out_features: int,
        in_features: int,
        c_dim: int,
        rank: int | float,
        data_provider: DataProvider,
        alpha: float = 1.0,
        lora_scale: float = 1.0,
        broadcast_tokens: bool = True,
        depth: int | None = None,
        use_depth: bool = False,
        n_transformations: int = 1,
        with_conditioning: bool = True,
        base_bias: bool = True,
        lora_bias: bool = False,
        frozen_weights_dtype=torch.bfloat16,
        target_path=None,
        **kwargs,
    ):
        super().__init__()

        self.data_provider = data_provider
        self.lora_scale = lora_scale
        self.broadcast_tokens = broadcast_tokens
        self.depth = depth
        self.use_depth = use_depth
        self.n_transformations = n_transformations
        self.rank = rank
        self.target_path = target_path

        # original weight of the matrix
        self.W = nn.Linear(in_features, out_features, bias=base_bias).to(dtype=frozen_weights_dtype)
        for p in self.W.parameters():
            p.requires_grad_(False)

        if type(rank) == float:
            self.rank = int(in_features * self.rank)

        self.A = nn.Linear(in_features, self.rank, bias=False)
        self.B = nn.Linear(self.rank, out_features, bias=lora_bias)

        nn.init.zeros_(self.B.weight)
        if lora_bias:
            nn.init.zeros_(self.B.bias)
        nn.init.kaiming_normal_(self.A.weight, a=1)

        self.with_conditioning = with_conditioning
        if with_conditioning:
            self.emb_gamma = nn.Linear(c_dim, self.rank * n_transformations, bias=False)
            self.emb_beta = nn.Linear(c_dim, self.rank * n_transformations, bias=False)

        # self.__old_A_weights = self.A.weight.detach().clone()

    def forward(self, x: torch.Tensor, *args, **kwargs):
        # diff = (self.A.weight.detach().cpu() - self.__old_A_weights).abs().sum()
        w_out = self.W(x)

        if self.lora_scale == 0.0:
            return w_out

        c: torch.Tensor = self.data_provider.get("cond_lora")
        if self.use_depth:
            assert self.depth is not None, "block depth has to be provided"
            c = c[self.depth]

        if self.with_conditioning:
            scale = self.emb_gamma(c) + 1.0
            shift = self.emb_beta(c)

            # we need to do that when we only get a single embedding vector
            # e.g pooled clip img embedding
            # out is [B, 1, rank]
            if self.broadcast_tokens:
                scale = scale.unsqueeze(2)
                shift = shift.unsqueeze(2)

            if self.n_transformations > 1:
                # out is [B, 1, trans, rank]
                scale = scale.reshape(-1, 1, self.n_transformations, self.rank)
                shift = shift.reshape(-1, 1, self.n_transformations, self.rank)

        a_out = self.A(x.to(self.A.weight.dtype))  # [B, N, D]
        if self.n_transformations > 1:
            a_out = a_out.unsqueeze(-2).expand(-1, -1, self.n_transformations, -1)  # [B, N, trans, rank]

        # reshape for temporal injection
        if self.with_conditioning:
            a_out = rearrange(a_out, "b (t l) d -> b t l d", t=c.shape[1])
            a_cond = scale * a_out
            a_cond = a_cond + shift
            a_cond = rearrange(a_cond, "b t l d -> b (t l) d")
        else:
            a_cond = a_out

        if self.n_transformations > 1:
            a_cond = a_cond.mean(dim=-2)

        b_out = self.B(a_cond)

        return w_out + b_out.to(dtype=self.W.weight.dtype) * self.lora_scale
