from functools import partial

import torch
from kappamodules.layers import LinearProjection
from kappamodules.transformer import DitBlock, PrenormBlock
from torch import nn

from src.modules.act import GEGLU


class TransformerModel(nn.Module):
    """Stack of transformer blocks applied to a sequence of latent tokens.

    The input is linearly projected from ``latent_dim`` to ``dim``, passed
    through ``depth`` blocks (``DitBlock`` when ``condition_dim`` is given,
    otherwise ``PrenormBlock``), and optionally projected back to
    ``latent_dim``.

    Args:
        latent_dim: Size of the last axis of the input tensor.
        dim: Hidden width used inside the blocks.
        depth: Number of blocks.
        num_attn_heads: Number of attention heads, forwarded to every block.
        project_to_input_dim: If True, project the block output back to
            ``latent_dim``; otherwise the output keeps width ``dim``.
        drop_path_rate: Drop-path rate forwarded to the blocks (the maximum
            rate when ``drop_path_decay`` is True).
        drop_path_decay: If True, per-block rates increase linearly from 0 to
            ``drop_path_rate`` across the stack; otherwise every block uses
            ``drop_path_rate``.
        init_weights: Initialization scheme forwarded to the projections and
            blocks.
        init_last_proj_zero: Forwarded to every block.
        full_residual: If True, add the model input (after ``input_ln``, if
            enabled) to the output. Requires the output width to equal
            ``latent_dim``.
        condition_dim: Width of the conditioning input; when set, ``DitBlock``
            is used and ``condition`` is passed to every block as ``cond``.
        input_ln: Apply a layer norm without learnable affine parameters to
            the input.
        output_ln: Apply a layer norm without learnable affine parameters to
            the output.
        act: Activation class. NOTE(review): stored on the module but not
            currently passed to any layer.

    Raises:
        ValueError: If ``full_residual`` is True but the output width
            (``dim`` when ``project_to_input_dim`` is False) differs from
            ``latent_dim``, since the residual addition could not be formed.
    """

    def __init__(
        self,
        latent_dim,
        dim,
        depth,
        num_attn_heads,
        project_to_input_dim=False,
        drop_path_rate=0.0,
        drop_path_decay=True,
        init_weights="truncnormal",
        init_last_proj_zero=True,
        full_residual=True,
        condition_dim=None,
        input_ln=False,
        output_ln=False,
        act: "type[nn.Module]" = GEGLU,
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.dim = dim
        self.depth = depth
        self.num_attn_heads = num_attn_heads
        self.drop_path_rate = drop_path_rate
        self.drop_path_decay = drop_path_decay
        self.init_weights = init_weights
        self.init_last_proj_zero = init_last_proj_zero
        self.project_to_input_dim = project_to_input_dim
        self.full_residual = full_residual
        self.condition_dim = condition_dim
        self.input_ln = input_ln
        self.output_ln = output_ln
        # NOTE(review): kept for introspection only; not wired into the blocks.
        self.act = act

        # Width of the tensor leaving output_proj.
        out_dim = latent_dim if project_to_input_dim else dim
        if full_residual and out_dim != latent_dim:
            # Without this check the mismatch only surfaces as an opaque
            # broadcasting error on the first forward pass.
            raise ValueError(
                f"full_residual=True requires the output width ({out_dim}) to "
                f"match latent_dim ({latent_dim}); set project_to_input_dim=True "
                f"or use dim == latent_dim"
            )

        self.input_proj = LinearProjection(latent_dim, dim, init_weights=init_weights)

        # blocks
        if self.condition_dim is not None:
            block_ctor = partial(DitBlock, cond_dim=self.condition_dim)
        else:
            block_ctor = PrenormBlock
        if drop_path_decay:
            # Linearly increasing schedule: first block 0, last block drop_path_rate.
            dpr = torch.linspace(0, drop_path_rate, depth).tolist()
        else:
            dpr = [drop_path_rate] * depth
        self.blocks = nn.ModuleList(
            [
                block_ctor(
                    dim=dim,
                    num_heads=num_attn_heads,
                    drop_path=rate,
                    init_weights=init_weights,
                    init_last_proj_zero=init_last_proj_zero,
                )
                for rate in dpr
            ]
        )

        if project_to_input_dim:
            self.output_proj = LinearProjection(
                dim, latent_dim, init_weights=init_weights
            )
        else:
            self.output_proj = nn.Identity()

    def forward(self, x, condition=None):
        """Run the input through projection, blocks, and optional residual/norms.

        Args:
            x: 3D tensor whose last axis has size ``latent_dim``
                (presumably ``(batch, tokens, latent_dim)`` — TODO confirm).
            condition: Optional conditioning tensor, passed to every block as
                ``cond`` when not None.

        Returns:
            Tensor whose last axis is ``latent_dim`` when
            ``project_to_input_dim`` is True, else ``dim``.

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        if x.ndim != 3:
            raise ValueError(
                f"expected a 3D input tensor, got {x.ndim}D with shape {tuple(x.shape)}"
            )

        if self.input_ln:
            x = nn.functional.layer_norm(x, (self.latent_dim,), eps=1e-6)

        # The residual uses the (optionally normalized) input.
        og_x = x

        # input projection
        x = self.input_proj(x)

        # apply blocks
        blk_kwargs = dict(cond=condition) if condition is not None else dict()
        for blk in self.blocks:
            x = blk(x, **blk_kwargs)

        # output projection
        x = self.output_proj(x)

        if self.full_residual:
            x = og_x + x

        if self.output_ln:
            # Normalize over the actual output width: it is `dim`, not
            # `latent_dim`, when project_to_input_dim is False.
            x = nn.functional.layer_norm(x, (x.shape[-1],), eps=1e-6)

        return x
