from pydoc import locate
from torch import nn
import torch
from jaxtyping import Float
from einops import rearrange


class Level(nn.ModuleList):
    """A list of layers applied one after another.

    The output of each layer becomes the input of the next. Any extra positional
    and keyword arguments given to ``forward`` are passed unchanged to every layer.
    An empty ``Level`` returns its input as-is.
    """

    def forward(self, x, *args, **kwargs):
        hidden = x
        for block in self:
            hidden = block(hidden, *args, **kwargs)
        return hidden


class Transformer(nn.Module):
    """U-shaped transformer: merge down through ``down_levels``, optionally run a middle
    ``main_level``, then split back up through ``up_levels`` using skip connections.

    Each level spec is read through its ``width``, ``depth``, ``layer_class``,
    ``layer_params``, ``proj_cls`` and ``proj_params`` attributes. ``layer_class`` and
    ``proj_cls`` are dotted import paths resolved with :func:`pydoc.locate`.

    Args:
        in_features: channel count of the input ``x``.
        out_features: channel count of the returned tensor.
        main_level: optional spec for the middle (lowest-resolution) level.
        up_levels: decoder level specs, listed high res -> low res (same order as
            ``down_levels``). ``None`` means no up levels.
        down_levels: encoder level specs, listed high res -> low res. ``None`` means no
            down levels.
    """

    def __init__(self, in_features, out_features, main_level=None, up_levels=None, down_levels=None):
        super().__init__()

        # ``None`` defaults instead of shared mutable ``[]`` defaults.
        up_levels = [] if up_levels is None else up_levels
        down_levels = [] if down_levels is None else down_levels

        # assumes up and down levels are already in the correct order
        prev_in = in_features
        prev_out = out_features

        self.down_levels, self.up_levels = nn.ModuleList(), nn.ModuleList()
        self.merges, self.splits = nn.ModuleList(), nn.ModuleList()

        # high res -> low res: each merge maps the previous width (or in_features) to the
        # level's width. Modules are created merge-first, then layers, so parameter
        # initialisation order (and thus seeded init) is unchanged.
        for spec in down_levels:
            self.merges.append(self._build_proj(spec, "merge", prev_in, spec.width))
            self.down_levels.append(self._build_level(spec))
            prev_in = spec.width

        # NOTE: up levels in same order as down levels, so high res -> low res.
        # forward() pairs up level i with down level i's skip/pos positionally (zip), so
        # the two lists are expected to have the same length and matching widths.
        for i, spec in enumerate(up_levels):
            self.up_levels.append(self._build_level(spec))
            # splits[0] projects to ``out_features`` and is the last split applied in
            # forward(), i.e. it is the final-output split. This mirrors mid_split below,
            # which is "split_last" exactly when no up level follows it.
            split_type = "split_last" if i == 0 else "split"
            self.splits.append(self._build_proj(spec, split_type, spec.width, prev_out))
            prev_out = spec.width

        self.mid_level, self.mid_merge, self.mid_split = None, None, None
        if main_level is not None:
            self.mid_level = self._build_level(main_level)
            self.mid_merge = self._build_proj(main_level, "merge", prev_in, main_level.width)
            self.mid_split = self._build_proj(
                main_level,
                "split" if len(up_levels) > 0 else "split_last",
                main_level.width,
                prev_out,
            )
            # NOTE(review): forward is only compiled when a main level is configured; this
            # may be an indentation slip -- confirm whether that is intended.
            self.forward = torch.compile(self.forward, fullgraph=True, dynamic=False)

    @staticmethod
    def _locate(path):
        """Resolve a dotted import path, raising ImportError instead of returning None."""
        obj = locate(path)
        if obj is None:
            raise ImportError(f"Could not locate {path!r}")
        return obj

    @classmethod
    def _build_level(cls, spec):
        """Build a ``Level`` of ``spec.depth`` layers of ``spec.layer_class`` with ``d_model=spec.width``."""
        # the class path is resolved per layer (as before), so depth == 0 never touches it
        return Level(
            [cls._locate(spec.layer_class)(d_model=spec.width, **spec.layer_params) for _ in range(spec.depth)]
        )

    @classmethod
    def _build_proj(cls, spec, proj_type, in_features, out_features):
        """Instantiate ``spec.proj_cls`` as a ``proj_type`` projection (merge/split/split_last)."""
        return cls._locate(spec.proj_cls)(
            type=proj_type,
            in_features=in_features,
            out_features=out_features,
            **spec.proj_params,
        )

    @staticmethod
    def _run_level(level, x, pos, C_pos, check_dict, kwargs):
        """Apply ``level`` to tokens flattened to ``(B, L, ·)``, then restore the spatial shape.

        ``x`` is channels-last ``(B, *DIMS, C)``; ``pos`` is reshaped with ``C_pos`` trailing
        channels. Returns ``(x, pos)``, both reshaped back to ``(B, *DIMS, ·)``.
        """
        B, *DIMS, C = x.shape
        x = level(x.reshape(B, -1, C), pos=pos.reshape(B, -1, C_pos), check_dict=check_dict, **kwargs)
        return x.reshape(B, *DIMS, C), pos.reshape(B, *DIMS, C_pos)

    def forward(
        self,
        x: Float[torch.Tensor, "B C *DIMS"],
        pos: Float[torch.Tensor, "B cn *DIM"],
        **kwargs,
    ):
        """Run the network on channels-first ``x`` with per-position values ``pos``.

        Extra ``kwargs`` are forwarded to every merge, layer and split together with a
        shared ``check_dict``. Every entry of ``check_dict`` must be truthy at the end
        (submodules are presumably expected to flag the kwargs they consume), otherwise
        an ``AssertionError`` is raised.

        Returns:
            Channels-first output tensor.
        """
        check_dict = {k: False for k in kwargs}

        # work channels-last internally
        x = rearrange(x, "b c ... -> b ... c")
        pos = rearrange(pos, "b cn ... -> b ... cn")

        C_pos = pos.shape[-1]

        # encoder: remember the pre-merge activations (skips) and post-merge positions
        skips, poses = [], []
        for merge, level in zip(self.merges, self.down_levels):
            skips.append(x)
            x, pos = merge(x, pos, check_dict=check_dict, **kwargs)
            poses.append(pos)
            x, pos = self._run_level(level, x, pos, C_pos, check_dict, kwargs)

        if self.mid_level is not None:
            skip = x
            x, pos = self.mid_merge(x, pos, check_dict=check_dict, **kwargs)
            x, _ = self._run_level(self.mid_level, x, pos, C_pos, check_dict, kwargs)
            x = self.mid_split(x, skip=skip, check_dict=check_dict, **kwargs)

        # decoder: low res -> high res, reusing each down level's skip and positions
        for split, level, skip, level_pos in reversed(list(zip(self.splits, self.up_levels, skips, poses))):
            x, _ = self._run_level(level, x, level_pos, C_pos, check_dict, kwargs)
            x = split(x, skip=skip, check_dict=check_dict, **kwargs)

        x = rearrange(x, "b ... c -> b c ...")

        assert all([v for v in check_dict.values()]), f"Not all kwargs were used in the forward pass: {check_dict}"

        return x
