import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import einsum, rearrange, repeat, einsum
from einops.layers.torch import Rearrange

from ntldm.networks.gconv import GConv
from ntldm.networks.s4 import FFTConv, S4Block


class AutoEncoderBlock(nn.Module):
    """Residual autoencoder block: temporal mixing followed by channel mixing.

    Each half normalises the residual stream with a non-affine
    ``InstanceNorm1d``, applies a learned per-channel shift/scale taken from
    ``ada_ln``, runs the mixer and adds the result back onto the stream.

    Args:
        C: number of channels of the ``[B, C, L]`` input.
        L: sequence length; only forwarded to the ``"gconv"`` kernel as ``l_max``.
        kernel: ``"s4"``, ``"gconv"``, ``"conv"`` or a name containing
            ``"simpleconv"`` whose last ``_``-separated field is the kernel
            size (e.g. ``"simpleconv_5"``).
        bidirectional: forwarded to the FFTConv / GConv mixers.
        kernel_params: stored on the module; not used by this block.
        num_lin_per_mlp: number of linear layers in the channel mixer
            (values below 2 behave like 2).
        use_act2: apply a GELU to the channel-mixer output when true.
    """

    def __init__(
        self,
        C,
        L,
        kernel="s4",
        bidirectional=True,
        kernel_params=None,
        num_lin_per_mlp=2,
        use_act2=False,
    ):
        super().__init__()
        self.C = C
        self.L = L
        self.bidirectional = bidirectional
        self.kernel = kernel
        self.kernel_params = kernel_params
        self.time_mixer = self.get_time_mixer()
        # depthwise 1x1 conv == per-channel scale and bias after the activation
        self.post_tm_scale = nn.Conv1d(
            C, C, 1, bias=True, groups=C, padding="same"
        )
        self.channel_mixer = self.get_channel_mixer(num_lin_per_mlp=num_lin_per_mlp)
        # norms are non-affine because shift/scale come from ada_ln; input must be [B, C, L]
        self.norm1 = nn.InstanceNorm1d(C, affine=False)
        self.norm2 = nn.InstanceNorm1d(C, affine=False)
        self.act1 = nn.GELU()
        self.act2 = nn.GELU() if use_act2 else nn.Identity()

        # 6 chunks of C values: (shift, scale, gate) for the time mixer, then
        # the same for the channel mixer. The gates are currently unused.
        self.ada_ln = nn.Parameter(torch.zeros(1, C * 6, 1), requires_grad=True)

    @staticmethod
    def affine_op(x_, shift, scale):
        """Return ``x_ * (1 + scale) + shift``.

        ``x_`` is ``[B, C, L]``; ``shift`` and ``scale`` must have the same
        number of dimensions (``[B, C, 1]``) so that they broadcast over ``L``.
        """
        assert len(x_.shape) == len(shift.shape), f"{x_.shape} != {shift.shape}"
        return x_ * (1 + scale) + shift

    @staticmethod
    def _mixer_output(out):
        """Return the output tensor of a time mixer.

        Mixers that return ``(output, state)`` are unwrapped; a plain tensor
        (e.g. from ``nn.Conv1d``) is returned unchanged. Indexing a tensor
        with ``[0]`` would select the first batch element instead.
        """
        if isinstance(out, (tuple, list)):
            return out[0]
        return out

    def get_time_mixer(self):
        """Build the temporal mixing layer selected by ``self.kernel``.

        Raises:
            ValueError: if ``self.kernel`` is not a recognised kernel name.
        """
        if self.kernel == "s4":
            time_mixer = FFTConv(
                self.C,
                self.C,
                bidirectional=self.bidirectional,
                activation=None,
            )
        elif self.kernel == "gconv":
            time_mixer = GConv(
                self.C,
                self.C,
                bidirectional=self.bidirectional,
                activation=None,
                l_max=self.L,
            )
        elif self.kernel == "conv":
            # GConv restricted to a single scale
            time_mixer = GConv(
                self.C,
                self.C,
                bidirectional=self.bidirectional,
                activation=None,
                n_scales=1,
            )
        elif "simpleconv" in self.kernel:
            # kernel size is encoded as the last "_"-separated field of the name
            kernel_size = int(self.kernel.split("_")[-1])
            time_mixer = nn.Conv1d(
                self.C, self.C, kernel_size, padding="same", padding_mode="circular"
            )
        else:
            raise ValueError(f"unknown kernel {self.kernel}")

        return time_mixer

    def get_channel_mixer(self, num_lin_per_mlp=2):
        """Build the bias-free channel MLP ``C -> 2C -> ... -> C`` on ``[B, C, L]``.

        ``num_lin_per_mlp - 2`` extra ``2C -> 2C`` linear layers (each preceded
        by a GELU) are inserted between the first and the last layer.
        """
        layers = [
            Rearrange("b c l -> b l c"),
            nn.Linear(self.C, self.C * 2, bias=False),  # required for zero-init block
        ]
        for _ in range(max(num_lin_per_mlp - 2, 0)):
            layers.extend(
                [
                    nn.GELU(),
                    nn.Linear(self.C * 2, self.C * 2, bias=False),
                ]
            )
        layers.extend(
            [
                nn.GELU(),
                nn.Linear(self.C * 2, self.C, bias=False),
            ]
        )
        # back to [B, C, L]
        layers.append(Rearrange("b l c -> b c l"))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block to ``x`` of shape ``[B, C, L]``; output has the same shape."""
        # broadcast the shared modulation parameters over the batch: [B, 6C, 1]
        ada_ln = self.ada_ln.expand(x.shape[0], -1, -1)
        shift_tm, scale_tm, _gate_tm, shift_cm, scale_cm, _gate_cm = ada_ln.chunk(
            6, dim=1
        )

        # --- time mixing (x is the residual stream) ---
        y = self.affine_op(self.norm1(x), shift_tm, scale_tm)
        y = self._mixer_output(self.time_mixer(y))
        y = x + self.post_tm_scale(self.act1(y))

        # --- channel mixing (y is now the residual stream) ---
        residual = y
        y = self.affine_op(self.norm2(y), shift_cm, scale_cm)
        return residual + self.act2(self.channel_mixer(y))


class TimestepEmbedder(nn.Module):
    """
    Map scalar timesteps to vectors: sinusoidal features followed by a 2-layer MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        in_proj = nn.Linear(frequency_embedding_size, hidden_size, bias=True)
        out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
        self.mlp = nn.Sequential(in_proj, nn.SiLU(), out_proj)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Build sinusoidal embeddings ``[cos | sin]`` for a batch of timesteps.

        :param t: 1-D tensor with one (possibly fractional) timestep per
                  batch element.
        :param dim: width of the returned embedding.
        :param max_period: sets the lowest frequency used.
        :return: tensor of shape (N, dim); when ``dim`` is odd the last
                 column is zero padding.
        """
        # adapted from https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        n_freqs = dim // 2
        positions = torch.arange(start=0, end=n_freqs, dtype=torch.float32)
        freqs = torch.exp(-math.log(max_period) * positions / n_freqs)
        freqs = freqs.to(device=t.device)

        angles = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if dim % 2 == 0:
            return embedding

        # odd width: append a single zero column
        pad = torch.zeros_like(embedding[:, :1])
        return torch.cat([embedding, pad], dim=-1)

    def forward(self, t):
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(features)


class AutoEncoder(nn.Module):
    """Autoencoder on ``[B, C_in, L]`` sequences.

    Encoder: 1x1 conv ``C_in -> C``, ``num_blocks`` ``AutoEncoderBlock`` layers,
    1x1 conv ``C -> C_latent``. Decoder: optional GELU, 1x1 conv
    ``C_latent -> C``, the decoder layers, 1x1 conv ``C -> C_in``.

    ``num_blocks_decoder=None`` reuses ``num_blocks``; ``num_blocks_decoder=0``
    replaces the decoder blocks with a single GELU.
    """

    def __init__(
        self,
        C_in,
        C,
        C_latent,
        L,
        kernel="s4",
        bidirectional=True,
        kernel_params=None,
        in_groups=None,
        bottleneck_groups=None,
        num_blocks=4,
        num_blocks_decoder=None,
        num_lin_per_mlp=2,
        use_act_bottleneck=False,
    ):
        super().__init__()
        self.C_in = C_in
        self.C = C
        self.C_latent = C_latent
        self.L = L
        self.bidirectional = bidirectional
        self.kernel = kernel
        self.kernel_params = kernel_params

        # Default grouping: one group per narrow channel when the wide width is
        # a strict multiple of it, otherwise a dense (ungrouped) conv.
        if in_groups is None:
            # NOTE: resolved for parity with the bottleneck, but the input and
            # output convs below are currently built without groups.
            in_groups = C_in if (C % C_in == 0 and C > C_in) else 1
        if bottleneck_groups is None:
            bottleneck_groups = C_latent if (C % C_latent == 0 and C > C_latent) else 1

        def make_block():
            return AutoEncoderBlock(
                C,
                L,
                kernel,
                bidirectional,
                kernel_params,
                num_lin_per_mlp=num_lin_per_mlp,
            )

        self.encoder_in = nn.Conv1d(C_in, C, 1)
        self.encoder = nn.ModuleList()
        for _ in range(num_blocks):
            self.encoder.append(make_block())

        self.bottleneck = nn.Conv1d(C, C_latent, 1, groups=bottleneck_groups)
        self.act_bottleneck = nn.GELU() if use_act_bottleneck else nn.Identity()
        self.unbottleneck = nn.Conv1d(C_latent, C, 1, groups=bottleneck_groups)

        if num_blocks_decoder == 0:
            # no decoder blocks: just the activation
            self.decoder = nn.ModuleList([nn.GELU()])
        else:
            n_decoder = num_blocks if num_blocks_decoder is None else num_blocks_decoder
            self.decoder = nn.ModuleList()
            for _ in range(n_decoder):
                self.decoder.append(make_block())

        self.decoder_out = nn.Conv1d(C, C_in, 1)

    def encode(self, x):
        """Map ``x`` ``[B, C_in, L]`` to the latent ``[B, C_latent, L]``."""
        hidden = self.encoder_in(x)
        for layer in self.encoder:
            hidden = layer(hidden)
        return self.bottleneck(hidden)

    def decode(self, z):
        """Map a latent ``[B, C_latent, L]`` back to ``[B, C_in, L]``."""
        hidden = self.unbottleneck(self.act_bottleneck(z))
        for layer in self.decoder:
            hidden = layer(hidden)
        return self.decoder_out(hidden)

    def forward(self, x):
        """Return ``(reconstruction, latent)`` for ``x``."""
        latent = self.encode(x)
        return self.decode(latent), latent


class DenoiserBlock(nn.Module):
    """Residual denoiser block modulated by a conditioning vector.

    A conditioning vector of width ``C // 4`` is projected by ``ada_ln`` to six
    ``[B, C]`` chunks: shift, scale and gate for the time mixer, then shift,
    scale and gate for the channel mixer. The projection is zero-initialised,
    so a freshly constructed block is the identity.

    Args:
        C: number of channels of the ``[B, C, L]`` input.
        L: sequence length; only forwarded to the ``"gconv"`` kernel as ``l_max``.
        kernel: ``"s4"``, ``"gconv"``, ``"conv"`` or a name containing
            ``"simpleconv"`` whose last ``_``-separated field is the kernel size.
        bidirectional: forwarded to the FFTConv / GConv mixers.
        kernel_params: stored on the module; not used by this block.
        use_act2: apply a GELU to the channel-mixer output when true.
    """

    def __init__(
        self, C, L, kernel="s4", bidirectional=True, kernel_params=None, use_act2=False
    ):
        super().__init__()
        self.C = C
        self.L = L
        self.bidirectional = bidirectional
        self.kernel = kernel
        self.kernel_params = kernel_params
        self.time_mixer = self.get_time_mixer()
        self.channel_mixer = self.get_channel_mixer()
        # affine=False because shift/scale come from ada_ln
        self.norm1 = nn.InstanceNorm1d(C, affine=False)
        self.norm2 = nn.InstanceNorm1d(C, affine=False)
        self.act1 = nn.GELU()
        self.act2 = nn.GELU() if use_act2 else nn.Identity()
        self.ada_ln = nn.Sequential(  # input is [B, C // 4]
            nn.GELU(),
            nn.Linear(
                C // 4,
                C * 6,
                bias=True,
            ),
        )

        # zero-init the modulation projection (weights and bias)
        self.ada_ln[-1].weight.data.zero_()
        self.ada_ln[-1].bias.data.zero_()

    @staticmethod
    def _mixer_output(out):
        """Return the output tensor of a time mixer.

        Mixers that return ``(output, state)`` are unwrapped; a plain tensor
        (e.g. from ``nn.Conv1d``) is returned unchanged. Indexing a tensor
        with ``[0]`` would select the first batch element instead.
        """
        if isinstance(out, (tuple, list)):
            return out[0]
        return out

    def get_time_mixer(self):
        """Build the temporal mixing layer selected by ``self.kernel``.

        Raises:
            ValueError: if ``self.kernel`` is not a recognised kernel name.
        """
        if self.kernel == "s4":
            return FFTConv(
                self.C, self.C, bidirectional=self.bidirectional, activation=None
            )
        elif self.kernel == "gconv":
            return GConv(
                self.C,
                self.C,
                bidirectional=self.bidirectional,
                activation=None,
                l_max=self.L,
            )
        elif self.kernel == "conv":
            # GConv restricted to a single scale
            return GConv(
                self.C,
                self.C,
                bidirectional=self.bidirectional,
                activation=None,
                n_scales=1,
            )
        elif "simpleconv" in self.kernel:
            # kernel size is encoded as the last "_"-separated field of the name
            kernel_size = int(self.kernel.split("_")[-1])
            return nn.Conv1d(
                self.C, self.C, kernel_size, padding="same", padding_mode="circular"
            )
        else:
            raise ValueError(f"unknown kernel {self.kernel}")

    def get_channel_mixer(self):
        """Build the channel MLP ``C -> 2C -> C`` operating on ``[B, C, L]``."""
        return nn.Sequential(
            Rearrange("b c l -> b l c"),
            nn.Linear(self.C, self.C * 2),
            nn.GELU(),
            nn.Linear(self.C * 2, self.C),
            Rearrange("b l c -> b c l"),
        )

    @staticmethod
    def affine_op(x, shift, scale):
        """Return ``x * (1 + scale) + shift`` for ``x`` ``[B, C, L]`` and ``[B, C]`` shift/scale."""
        return x * (1 + scale.unsqueeze(-1)) + shift.unsqueeze(-1)

    def forward(self, x, t_cond):
        """Apply the block.

        Args:
            x: residual stream, ``[B, C, L]``.
            t_cond: conditioning vector, ``[B, C // 4]``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        shift_tm, scale_tm, gate_tm, shift_cm, scale_cm, gate_cm = self.ada_ln(
            t_cond
        ).chunk(6, dim=1)

        # --- time mixing (x is the residual stream) ---
        y = self.affine_op(self.norm1(x), shift_tm, scale_tm)
        y = self._mixer_output(self.time_mixer(y))
        y = x + gate_tm.unsqueeze(-1) * self.act1(y)

        # --- channel mixing (y is now the residual stream) ---
        residual = y
        y = self.affine_op(self.norm2(y), shift_cm, scale_cm)
        return residual + gate_cm.unsqueeze(-1) * self.act2(self.channel_mixer(y))


class Denoiser(nn.Module):
    """Unconditional denoiser: 1x1 conv in, a stack of ``DenoiserBlock``, 1x1 conv out.

    The diffusion timestep is embedded to width ``C // 4`` and passed to
    every block.
    """

    def __init__(
        self,
        C_in,
        C,
        L,
        kernel="s4",
        bidirectional=True,
        kernel_params=None,
        in_groups=None,
        num_blocks=6,
    ):
        super().__init__()
        self.C_in = C_in
        self.C = C
        self.L = L
        self.bidirectional = bidirectional
        self.kernel = kernel
        self.kernel_params = kernel_params

        # Grouped in/out convs by default: one group per input channel when C
        # is a strict multiple of C_in, otherwise a dense conv.
        if in_groups is None:
            in_groups = C_in if (C % C_in == 0 and C > C_in) else 1

        self.conv_in = nn.Conv1d(C_in, C, 1, groups=in_groups)
        self.blocks = nn.ModuleList()
        for _ in range(num_blocks):
            self.blocks.append(DenoiserBlock(C, L, kernel, bidirectional, kernel_params))
        self.conv_out = nn.Conv1d(C, C_in, 1, groups=in_groups)

        # embedding width C // 4 keeps the parameter count in check
        self.t_emb = TimestepEmbedder(C // 4)

    def forward(self, x, t):
        """Denoise ``x`` (``[B, C_in, L]``) given one timestep per batch element in ``t``."""
        hidden = self.conv_in(x)
        cond = self.t_emb(t.to(hidden.device))
        for layer in self.blocks:
            hidden = layer(hidden, cond)
        return self.conv_out(hidden)


class ConditionalDenoiser(nn.Module):
    """Denoiser with an optional global conditioning vector.

    Same layout as ``Denoiser``; when a condition ``c`` is supplied its
    embedding is added to the timestep embedding before the blocks.
    """

    def __init__(
        self,
        C_in,
        C,
        L,
        kernel="s4",
        bidirectional=True,
        kernel_params=None,
        in_groups=None,
        num_blocks=6,
        condition_dim=2,  # for now, just 2D condition
    ):
        super().__init__()
        self.C_in = C_in
        self.C = C
        self.L = L
        self.bidirectional = bidirectional
        self.kernel = kernel
        self.kernel_params = kernel_params

        # Grouped in/out convs by default: one group per input channel when C
        # is a strict multiple of C_in, otherwise a dense conv.
        if in_groups is None:
            in_groups = C_in if (C % C_in == 0 and C > C_in) else 1

        self.conv_in = nn.Conv1d(C_in, C, 1, groups=in_groups)
        self.blocks = nn.ModuleList()
        for _ in range(num_blocks):
            self.blocks.append(DenoiserBlock(C, L, kernel, bidirectional, kernel_params))
        self.conv_out = nn.Conv1d(C, C_in, 1, groups=in_groups)

        emb_dim = C // 4  # keeps the parameter count in check
        self.t_emb = TimestepEmbedder(emb_dim)
        # maps [B, condition_dim] -> [B, C // 4]
        self.c_emb = nn.Sequential(
            nn.Linear(condition_dim, emb_dim, bias=True),
            nn.SiLU(),
            nn.Linear(emb_dim, emb_dim, bias=True),
        )

    def forward(self, x, t, c=None):
        """Denoise ``x`` given timesteps ``t`` and, optionally, a condition ``c``."""
        hidden = self.conv_in(x)
        cond = self.t_emb(t.to(hidden.device))

        # without a condition only the timestep embedding is used
        if c is not None:
            cond = cond + self.c_emb(c.to(hidden.device))

        for layer in self.blocks:
            hidden = layer(hidden, cond)
        return self.conv_out(hidden)


class TSConditionalDenoiserBlock(nn.Module):
    """Denoiser block with per-timepoint mask and optional per-timepoint conditioning.

    Works like ``DenoiserBlock`` (adaLN shift/scale/gate derived from a
    ``C // 4`` wide embedding, zero-initialised), but before the gated
    residual of the time-mixing half it additionally adds

    * a projection of the mask embedding (``mask_cond_emb``), and
    * when ``time_condition_dim`` is set, the output of a second time mixer
      applied to the timepoint-wise conditioning signal.

    Args:
        C: number of channels of the ``[B, C, L]`` input.
        L: sequence length; only forwarded to the ``"gconv"`` kernel as ``l_max``.
        kernel: ``"s4"``, ``"gconv"``, ``"conv"`` or a name containing
            ``"simpleconv"`` whose last ``_``-separated field is the kernel size.
        bidirectional: forwarded to the FFTConv / GConv mixers.
        kernel_params: stored on the module; not used by this block.
        use_act2: apply a GELU to the channel-mixer output when true.
        time_condition_dim: ``None`` disables timepoint-wise conditioning; any
            other value enables the extra time mixer (the value itself is only
            used as an on/off switch here).
    """

    def __init__(
        self,
        C,
        L,
        kernel="s4",
        bidirectional=True,
        kernel_params=None,
        use_act2=False,
        time_condition_dim=None,  # time-point specific conditioning
    ):
        super().__init__()
        self.C = C
        self.L = L
        self.bidirectional = bidirectional
        self.kernel = kernel
        self.kernel_params = kernel_params
        self.time_mixer = self.get_time_mixer()
        self.channel_mixer = self.get_channel_mixer()
        # affine=False because shift/scale come from ada_ln
        self.norm1 = nn.InstanceNorm1d(C, affine=False)
        self.norm2 = nn.InstanceNorm1d(C, affine=False)
        self.act1 = nn.GELU()
        self.act2 = nn.GELU() if use_act2 else nn.Identity()
        self.ada_ln = nn.Sequential(  # input is [B, C // 4]
            nn.GELU(),
            nn.Linear(
                C // 4,
                C * 6,
                bias=True,
            ),
        )

        # zero-init the modulation projection (weights and bias)
        self.ada_ln[-1].weight.data.zero_()
        self.ada_ln[-1].bias.data.zero_()

        # [B, C // 4, L] -> [B, C, L]
        self.mask_cond_emb = nn.Sequential(
            nn.GELU(),
            Rearrange("b c l -> b l c"),
            nn.Linear(C // 4, C, bias=False),
            Rearrange("b l c -> b c l"),
        )

        self.time_condition_dim = time_condition_dim
        if time_condition_dim is not None:
            self.t_cond_time_mixer = self.get_time_mixer()
        else:
            self.t_cond_time_mixer = nn.Identity()

    @staticmethod
    def _mixer_output(out):
        """Return the output tensor of a time mixer.

        Mixers that return ``(output, state)`` are unwrapped; a plain tensor
        (e.g. from ``nn.Conv1d``) is returned unchanged. Indexing a tensor
        with ``[0]`` would select the first batch element instead.
        """
        if isinstance(out, (tuple, list)):
            return out[0]
        return out

    def get_time_mixer(self):
        """Build a temporal mixing layer selected by ``self.kernel``.

        Raises:
            ValueError: if ``self.kernel`` is not a recognised kernel name.
        """
        if self.kernel == "s4":
            return FFTConv(
                self.C, self.C, bidirectional=self.bidirectional, activation=None
            )
        elif self.kernel == "gconv":
            return GConv(
                self.C,
                self.C,
                bidirectional=self.bidirectional,
                activation=None,
                l_max=self.L,
            )
        elif self.kernel == "conv":
            # GConv restricted to a single scale
            return GConv(
                self.C,
                self.C,
                bidirectional=self.bidirectional,
                activation=None,
                n_scales=1,
            )
        elif "simpleconv" in self.kernel:
            # kernel size is encoded as the last "_"-separated field of the name
            kernel_size = int(self.kernel.split("_")[-1])
            return nn.Conv1d(
                self.C, self.C, kernel_size, padding="same", padding_mode="circular"
            )
        else:
            raise ValueError(f"unknown kernel {self.kernel}")

    def get_channel_mixer(self):
        """Build the channel MLP ``C -> 2C -> C`` operating on ``[B, C, L]``."""
        return nn.Sequential(
            Rearrange("b c l -> b l c"),
            nn.Linear(self.C, self.C * 2),
            nn.GELU(),
            nn.Linear(self.C * 2, self.C),
            Rearrange("b l c -> b c l"),
        )

    @staticmethod
    def affine_op(x, shift, scale):
        """Return ``x * (1 + scale) + shift`` for ``x`` ``[B, C, L]`` and ``[B, C]`` shift/scale."""
        return x * (1 + scale.unsqueeze(-1)) + shift.unsqueeze(-1)

    def forward(
        self,
        x,
        t_emb,
        mask_cond,
        t_cond=None,
    ):
        """
        Args:
            x: [B, C, L]
            t_emb: timestep embedding [B, C//4]; drives shift/scale/gate via ada_ln
            mask_cond: timepoint-wise mask embedding [B, C//4, L]; projected to
                    [B, C, L] and added to the time-mixer output before gating
            t_cond [Optional]: timepoint-wise conditioning [B, C, L]; passed
                    through its own time mixer and added before gating

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: if ``t_cond`` is given but the block was built with
                ``time_condition_dim=None``.
        """
        if self.time_condition_dim is None and t_cond is not None:
            raise ValueError(
                "timepoint conditioning is passed but time_condition_dim is None"
            )

        shift_tm, scale_tm, gate_tm, shift_cm, scale_cm, gate_cm = self.ada_ln(
            t_emb
        ).chunk(6, dim=1)

        # --- time mixing (x is the residual stream) ---
        y = self.affine_op(self.norm1(x), shift_tm, scale_tm)
        y = self._mixer_output(self.time_mixer(y))

        # mask conditioning (subspace change for mask vs no mask)
        y = y + self.mask_cond_emb(mask_cond)

        # timepoint conditioning
        if t_cond is not None:
            y = y + self._mixer_output(self.t_cond_time_mixer(t_cond))

        y = x + gate_tm.unsqueeze(-1) * self.act1(y)

        # --- channel mixing (y is now the residual stream) ---
        residual = y
        y = self.affine_op(self.norm2(y), shift_cm, scale_cm)
        return residual + gate_cm.unsqueeze(-1) * self.act2(self.channel_mixer(y))


class TSConditionalDenoiser(nn.Module):
    """Denoiser with timepoint-wise mask, optional timepoint-wise and global conditioning.

    Args:
        C_in: number of input/output channels.
        C: hidden width of the blocks.
        L: sequence length, forwarded to the blocks.
        kernel, bidirectional, kernel_params: forwarded to every
            ``TSConditionalDenoiserBlock``.
        in_groups: groups of the 1x1 in/out convs. ``None`` selects ``C_in``
            when ``C`` is a strict multiple of ``C_in`` and 1 otherwise.
        num_blocks: number of blocks.
        time_condition_dim: channel count of the timepoint-wise conditioning
            signal ``c_t``; ``None`` disables it.
        condition_dim: width of the global conditioning vector ``c``;
            ``None`` disables it.
    """

    def __init__(
        self,
        C_in,
        C,
        L,
        kernel="s4",
        bidirectional=True,
        kernel_params=None,
        in_groups=None,
        num_blocks=6,
        time_condition_dim=None,  # time-point specific conditioning
        condition_dim=None,  # global conditioning across all timesteps
    ):
        super().__init__()
        self.C_in = C_in
        self.C = C
        self.L = L
        self.bidirectional = bidirectional
        self.kernel = kernel
        self.kernel_params = kernel_params
        if in_groups is None:  # grouped by default
            if C % C_in == 0 and C > C_in:
                in_groups = C_in
            else:
                in_groups = 1

        self.conv_in = nn.Conv1d(C_in, C, 1, groups=in_groups)
        self.blocks = nn.ModuleList(
            [
                TSConditionalDenoiserBlock(
                    C,
                    L,
                    kernel,
                    bidirectional,
                    kernel_params,
                    time_condition_dim=time_condition_dim,
                )
                for _ in range(num_blocks)
            ]
        )
        self.conv_out = nn.Conv1d(C, C_in, 1, groups=in_groups)

        self.t_emb = TimestepEmbedder(C // 4)  # [B, C//4] to keep param count in check

        # nn.Conv1d cannot be built with in_channels=None, so the projection
        # only exists when timepoint-wise conditioning is enabled.
        self.time_condition_dim = time_condition_dim
        if time_condition_dim is not None:
            self.c_t_emb = nn.Conv1d(time_condition_dim, C, 1)  # [B, C, L]
        else:
            self.c_t_emb = nn.Identity()

        self.condition_dim = condition_dim
        if condition_dim is not None:
            self.c_emb = nn.Sequential(
                nn.Linear(condition_dim, C // 4, bias=True),
                nn.SiLU(),
                nn.Linear(C // 4, C // 4, bias=True),
            )  # [B, C//4]
        else:
            self.c_emb = nn.Identity()

        self.mask_emb = nn.Sequential(
            Rearrange("b c l -> b l c"),
            nn.Linear(2, C // 4, bias=True),  # one for latent one for embedding
            nn.SiLU(),
            nn.Linear(C // 4, C // 4, bias=True),
            Rearrange("b l c -> b c l"),
        )  # [B, C//4, L]

    def forward(self, x, t, t_mask, c_t=None, c=None):
        """Denoise ``x``.

        Args:
            x: input, ``[B, C_in, L]``.
            t: diffusion timesteps, one per batch element.
            t_mask: timepoint mask with 2 channels, ``[B, 2, L]``.
            c_t: optional timepoint-wise conditioning,
                ``[B, time_condition_dim, L]``.
            c: optional global conditioning, ``[B, condition_dim]``.

        Raises:
            ValueError: if ``c`` is passed but ``condition_dim`` is ``None``,
                or ``c_t`` is passed but ``time_condition_dim`` is ``None``.
        """
        # validate before doing any work
        if c is not None and self.condition_dim is None:
            raise ValueError(
                "global conditioning c is passed but condition_dim is None"
            )
        if c_t is not None and self.time_condition_dim is None:
            raise ValueError(
                "timepoint conditioning c_t is passed but time_condition_dim is None"
            )

        x = self.conv_in(x)

        # diffusion timestep embedding
        t_emb = self.t_emb(t.to(x.device))

        # timepoint mask embedding
        mask_emb = self.mask_emb(t_mask.to(x.device))

        # global external conditioning; otherwise just the timestep embedding
        if c is not None:
            t_emb = t_emb + self.c_emb(c.to(x.device))

        # timepoint-specific external conditioning
        c_t_emb = self.c_t_emb(c_t.to(x.device)) if c_t is not None else None

        for block in self.blocks:
            x = block(x, t_emb, mask_cond=mask_emb, t_cond=c_t_emb)

        return self.conv_out(x)


class S4AE(nn.Module):
    """Autoencoder built from ``S4Block`` layers.

    Encoder: 1x1 conv ``d_model -> d_hidden``, ``num_blocks`` S4 blocks,
    1x1 conv ``d_hidden -> d_latent``. Decoder: ``num_blocks`` S4 blocks
    followed by a 1x1 conv ``d_hidden -> d_model``.

    Args:
        d_model: number of input/output channels.
        d_hidden: hidden width of the S4 blocks.
        num_blocks: number of S4 blocks in the encoder and in the decoder.
        downsample_factor: stored on the module; the blocks themselves are
            built with fixed factors (see ``__init__``).
        d_latent: number of latent channels.
    """

    def __init__(self, d_model, d_hidden, num_blocks, downsample_factor, d_latent):
        super().__init__()
        self.downsample_factor = downsample_factor
        self.d_model = d_model

        self.conv_start = nn.Conv1d(d_model, d_hidden, 1)

        common_args = dict(
            bottleneck=None,  # or 4 for GSS
            gate=None,  # or 4 for GSS
            channels=1,
            gate_act="glu",  # or 4 for GSS
            mult_act=None,
            final_act="glu",  # Final activation after FF layer; new name for 'postact'
            initializer=None,
            weight_norm=False,
            dropout=0.0,  # Same as null
            tie_dropout=None,  # or model.tie_dropout
            transposed=True,  # Choose backbone axis ordering,
            use_slconv=False,
        )
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        # encoder blocks all run at width d_hidden with no resampling
        for _ in range(num_blocks):
            self.encoder.append(
                S4Block(
                    d_model=d_hidden,
                    downsample_factor=1,
                    **common_args,
                )
            )
        self.conv_middle = nn.Conv1d(d_hidden, d_latent, 1)

        for i in range(num_blocks):
            if i == 0:
                # The first decoder block receives the d_latent-channel code.
                # NOTE(review): downsample_factor = d_latent / d_hidden is
                # presumably what brings it back to d_hidden channels —
                # confirm against S4Block.
                block = S4Block(
                    d_model=d_latent,
                    downsample_factor=d_latent / d_hidden,
                    **common_args,
                )
            else:
                block = S4Block(
                    d_model=d_hidden,
                    downsample_factor=1.0,
                    **common_args,
                )
            self.decoder.append(block)
        self.conv_end = nn.Conv1d(d_hidden, d_model, 1)

    def forward(self, x):
        """Return ``(reconstruction, latent)`` for ``x``."""
        z = self.forward_encoder(x)
        return self.forward_decoder(z), z

    def forward_encoder(self, x):
        """Encode ``x`` to the latent produced by ``conv_middle``."""
        x = self.conv_start(x)
        for block in self.encoder:
            x, _ = block(x)  # each block returns (output, state)
        return self.conv_middle(x)

    def forward_decoder(self, x):
        """Decode a latent: decoder blocks first, then ``conv_end``.

        This is the same order ``forward`` uses, so
        ``forward_decoder(forward_encoder(x))`` equals the reconstruction
        returned by ``forward(x)``.
        """
        for block in self.decoder:
            x, _ = block(x)  # each block returns (output, state)
        return self.conv_end(x)


class SimpleEncoder(nn.Module):
    """Encoder: depthwise temporal smoothing conv, two ELU MLP layers, two heads.

    Takes ``[batch, seqlen, input_channels]`` and returns a ``mean`` head and a
    ``logvar`` head (the latter passed through a Softplus, so it is
    non-negative), both ``[batch, seqlen, bottleneck_size]``.
    """

    def __init__(self, input_channels, hidden_sizes, bottleneck_size):
        super().__init__()
        h1, h2 = hidden_sizes[0], hidden_sizes[1]

        # per-channel (depthwise) length-5 filter; replicate padding keeps the length
        self.temporal_conv = nn.Conv1d(
            in_channels=input_channels,
            out_channels=input_channels,
            kernel_size=5,
            stride=1,
            padding=2,
            groups=input_channels,
            bias=False,
            padding_mode="replicate",
        )

        self.linear1 = nn.Sequential(nn.Linear(input_channels, h1), nn.ELU())
        self.linear2 = nn.Sequential(nn.Linear(h1, h2), nn.ELU())

        # output heads
        self.mean = nn.Linear(h2, bottleneck_size)
        self.logvar = nn.Sequential(nn.Linear(h2, bottleneck_size), nn.Softplus())

    def forward(self, x):
        # conv wants channels first; the linear layers want them last
        smoothed = self.temporal_conv(x.permute(0, 2, 1)).permute(0, 2, 1)
        hidden = self.linear2(self.linear1(smoothed))
        return self.mean(hidden), self.logvar(hidden)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return mu + noise * std


class Decoder(nn.Module):
    """Linear read-out followed by a softplus.

    Args:
        bottleneck_size: width of the latent input.
        output_channels: width of the output.
        frozen: when true, the weight is set to ``C.T`` and frozen; ``C`` is
            then required. When false, a plain learnable ``nn.Linear`` is
            used and ``C``, ``b`` and ``bias`` are ignored.
        C: tensor whose transpose becomes the linear weight, i.e. expected as
            ``[bottleneck_size, output_channels]``.
        b: optional bias tensor with ``output_channels`` elements. It is only
            installed (and frozen) when ``frozen`` and ``bias`` are both true;
            otherwise the layer keeps its default learnable bias.
        bias: whether the frozen linear layer has a bias term.

    Raises:
        ValueError: if ``frozen`` is true and ``C`` is ``None``.
    """

    def __init__(
        self, bottleneck_size, output_channels, frozen=True, C=None, b=None, bias=True
    ):
        super().__init__()
        if frozen:
            if C is None:
                # fail early with a clear message instead of AttributeError on C.T
                raise ValueError("frozen=True requires the weight matrix C")
            self.decoder = nn.Linear(bottleneck_size, output_channels, bias=bias)
            self.decoder.weight = nn.Parameter(C.T, requires_grad=False)
            print("Setting decoder weights to be frozen")

            if bias and (b is not None):
                print("Setting bias")
                # flatten so both [out] and [1, out] / [out, 1] inputs give a 1-D bias
                self.decoder.bias = nn.Parameter(b.reshape(-1), requires_grad=False)

        else:
            print("Simple linear layer with learnable weights and biases")
            self.decoder = nn.Linear(bottleneck_size, output_channels)

    def forward(self, x):
        # x: [batch, seqlen, bottleneck], no need to permute for Linear
        return F.softplus(self.decoder(x))


class VAE(nn.Module):
    """Variational autoencoder pairing ``SimpleEncoder`` with ``Decoder``.

    ``C``, ``b``, ``bias`` and ``frozen`` are passed straight to ``Decoder``.
    """

    def __init__(
        self,
        input_channels,
        hidden_sizes,
        bottleneck_size,
        output_channels,
        C=None,
        b=None,
        bias=True,
        frozen=True,
    ):
        super().__init__()
        self.encoder = SimpleEncoder(input_channels, hidden_sizes, bottleneck_size)
        self.decoder = Decoder(
            bottleneck_size,
            output_channels,
            frozen=frozen,
            C=C,
            b=b,
            bias=bias,
        )

    def forward(self, x):
        """Encode, draw a latent sample, decode; returns ``(out, mu, logvar)``."""
        mu, logvar = self.encoder(x)
        sample = self.encoder.reparameterize(mu, logvar)
        return self.decoder(sample), mu, logvar

    def decode_mu(self, x):
        """Like ``forward`` but decodes the mean directly, without sampling."""
        mu, logvar = self.encoder(x)
        return self.decoder(mu), mu, logvar


class QueryBasedPooling(nn.Module):
    """Pool a sequence ``(B, T, in_dim)`` into ``(B, hidden_dim)``.

    A single learned query attends over the sequence with multi-head
    scaled dot-product attention; the attended values are projected by
    a final linear layer.
    """

    def __init__(self, in_dim, hidden_dim, num_heads=8):
        super().__init__()

        # one learned query shared by every batch element
        self.query = nn.Parameter(torch.randn(1, hidden_dim))

        self.key_layer = nn.Linear(in_dim, hidden_dim)
        self.value_layer = nn.Linear(in_dim, hidden_dim)
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        heads = self.num_heads
        split_heads = "b t (h c) -> b h t c"

        # keys and values: (B, T, hidden_dim) -> (B, H, T, hidden_dim // H)
        k = rearrange(self.key_layer(x), split_heads, h=heads)
        v = rearrange(self.value_layer(x), split_heads, h=heads)
        # query: (1, hidden_dim) -> (B, 1, hidden_dim) -> (B, H, 1, hidden_dim // H)
        q = repeat(self.query, "1 c -> b 1 c", b=x.shape[0])
        q = rearrange(q, split_heads, h=heads)

        # scaled dot-product attention of the single query over all T positions
        scores = einsum(q, k, "b h t c, b h s c -> b h t s")
        scores = scores / np.sqrt(self.hidden_dim // heads)
        weights = torch.nn.functional.softmax(scores, dim=-1)

        pooled = einsum(weights, v, "b h t s, b h s c -> b h t c")
        pooled = rearrange(pooled, "b h t c -> b t (h c)")

        # drop the length-1 query axis: (B, 1, hidden_dim) -> (B, hidden_dim)
        return self.proj(pooled).squeeze(1)


if __name__ == "__main__":
    # Smoke test: pool a random batch with QueryBasedPooling and print the
    # parameter count and the output shape.

    import lovely_tensors
    from torchinfo import summary

    lovely_tensors.monkey_patch()

    batch_size, seq_len = 10, 20
    in_dim, hidden_dim, num_heads = 768, 128, 4

    embeddings = torch.randn(batch_size, seq_len, in_dim)
    model = QueryBasedPooling(in_dim, hidden_dim, num_heads)
    n_params = sum(p.numel() for p in model.parameters())
    print("number of params", n_params / 1e6, "M")
    pooled_output = model(embeddings)
    print(pooled_output.shape)  # expected: (batch_size, hidden_dim)

    # # test AutoEncoderBlock
    # block = AutoEncoderBlock(128, 500, kernel="s4")
    # x = torch.randn(2, 128, 500)
    # print(summary(block, (2, 128, 500), device="cpu"))
    # print(x - block(x))

    # # test DenoiserBlock
    # denoiser_block = DenoiserBlock(128, 500, kernel="gconv")
    # x = torch.randn(2, 128, 500)
    # t_cond = torch.randn(2, 128 // 4)  # [B, C//4] to keep param count in check
    # print(summary(denoiser_block, [(2, 128, 500), (2, 128 // 4)], device="cpu"))
    # print(x - denoiser_block(x, t_cond))

    # test AutoEncoder
    # ae = AutoEncoder(128, 512, 16, 500, kernel="s4", num_blocks=4, num_lin_per_mlp=1)
    # x = torch.randn(2, 128, 500)
    # print(summary(ae, (2, 128, 500), device="cpu", depth=4))

    # # test Denoiser
    # denoiser = Denoiser(8, 256, 500, kernel="s4")
    # x = torch.randn(2, 8, 500)
    # t = torch.randn(2)
    # print(summary(denoiser, [(2, 8, 500), (2,)], device="cpu"))

    # vae = VAE(3, [32, 16], 8, 3, bias=False, frozen=False)

    # x = torch.randn(32, 10, 3)
    # out, mu, logvar = vae(x)
    # print(out.shape)
    # print(mu.shape)
    # print(logvar.shape)
