from typing import List, Optional, Sequence

import numpy as np
import torch
import torch.nn as nn

from cleandiffuser.nn_diffusion import BaseNNDiffusion
from cleandiffuser.nn_diffusion.jannerunet import ResidualBlock, Downsample1d


class HalfJannerUNet1d(BaseNNDiffusion):
    """Encoder-only ("half") 1D Janner U-Net that reduces a sequence to a vector.

    The input runs through the down-sampling half of a Janner-style U-Net
    (two ``ResidualBlock``s plus a downsample per stage), then two extra
    residual stages that halve the channel count and downsample again.
    The result is flattened, concatenated with the mapped
    (timestep + condition) embedding, and fed to a two-layer MLP.

    Args:
        horizon: Temporal length of the input sequence.
        in_dim: Number of features per time step (the channel axis after the
            ``permute`` in ``forward``).
        out_dim: Size of the output vector.
        kernel_size: Kernel size of the residual blocks in the down path.
        model_dim: Base channel width; also the width of the mapped embedding
            that is passed to every ``ResidualBlock``.
        emb_dim: Width of the embedding fed to ``map_emb``.
        dim_mult: Per-stage channel multipliers, applied cumulatively
            (``np.cumprod``). The default ``(1, 2, 2, 2)`` therefore yields
            stage widths ``model_dim * (1, 2, 4, 8)``.
        timestep_emb_type: Forwarded to ``BaseNNDiffusion``.
        norm_type: Forwarded to every ``ResidualBlock``.
    """

    def __init__(
            self,
            horizon: int,
            in_dim: int,
            out_dim: int = 1,
            kernel_size: int = 3,
            model_dim: int = 32,
            emb_dim: int = 32,
            dim_mult: Sequence[int] = (1, 2, 2, 2),
            timestep_emb_type: str = "positional",
            norm_type: str = "groupnorm",
    ):
        super().__init__(emb_dim, timestep_emb_type)

        # Cumulative channel widths. Cast to int so module constructors get
        # plain Python ints rather than numpy scalars from np.cumprod.
        dims = [in_dim] + [int(model_dim * m) for m in np.cumprod(dim_mult)]
        in_out = list(zip(dims[:-1], dims[1:]))

        self.map_emb = nn.Sequential(
            nn.Linear(emb_dim, model_dim * 4), nn.Mish(),
            nn.Linear(model_dim * 4, model_dim))

        self.downs = nn.ModuleList([])
        # Never populated or used in forward; kept so the attribute still exists.
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            # Creation order is unchanged, so seeded parameter init is unchanged.
            self.downs.append(nn.ModuleList([
                ResidualBlock(dim_in, dim_out, model_dim, kernel_size, norm_type),
                ResidualBlock(dim_out, dim_out, model_dim, kernel_size, norm_type),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

            # Track the temporal length the downsample really produces.
            # Identity leaves it unchanged.
            horizon = self._downsampled_length(self.downs[-1][2], dim_out, horizon)

        mid_dim = dims[-1]
        mid_dim_2 = mid_dim // 2
        mid_dim_3 = mid_dim // 4

        self.mid_block1 = nn.ModuleList([
            ResidualBlock(mid_dim, mid_dim_2, model_dim, kernel_size=5, norm_type=norm_type),
            Downsample1d(mid_dim_2)])
        horizon = self._downsampled_length(self.mid_block1[1], mid_dim_2, horizon)

        self.mid_block2 = nn.ModuleList([
            ResidualBlock(mid_dim_2, mid_dim_3, model_dim, kernel_size=5, norm_type=norm_type),
            Downsample1d(mid_dim_3)])
        horizon = self._downsampled_length(self.mid_block2[1], mid_dim_3, horizon)

        # Flattened feature size after mid_block2. Assumes ResidualBlock keeps
        # the temporal length; only the downsamples change it.
        fc_dim = mid_dim_3 * max(horizon, 1)

        self.final_block = nn.Sequential(
            nn.Linear(fc_dim + model_dim, fc_dim // 2),
            nn.Mish(),
            nn.Linear(fc_dim // 2, out_dim))

    @staticmethod
    def _downsampled_length(module: nn.Module, channels: int, length: int) -> int:
        """Return the temporal length ``module`` outputs for a length-``length`` input.

        Halving with ``length // 2`` is wrong for odd lengths when the
        downsample rounds up, as a stride-2, padding-1 conv does. This probes
        the real module with a zero tensor of shape ``(1, channels, length)``
        so ``fc_dim`` matches the tensor that reaches ``final_block``.
        NOTE(review): assumes the downsample module is stateless (no running
        statistics updated by the probe) — confirm if Downsample1d changes.
        """
        with torch.no_grad():
            probe = torch.zeros(1, channels, length)
            return int(module(probe).shape[-1])

    def forward(self,
                x: torch.Tensor, noise: torch.Tensor,
                condition: Optional[torch.Tensor] = None):
        """Encode ``x`` into a vector of size ``out_dim``.

        Args:
            x: Input sequence. After ``permute(0, 2, 1)``, axis 1 is fed to a
                ``ResidualBlock`` built with ``in_dim`` input channels, so this
                is presumably ``(batch, horizon, in_dim)``.
            noise: Diffusion noise level / timestep, embedded by ``map_noise``.
            condition: Optional tensor added to the noise embedding before
                ``map_emb``; it must broadcast against that embedding.

        Returns:
            Output of ``final_block``, presumably shaped ``(batch, out_dim)``.
        """
        # Move features onto the channel axis expected by the 1D convolutions.
        x = x.permute(0, 2, 1)

        emb = self.map_noise(noise)
        if condition is not None:
            emb = emb + condition
        emb = self.map_emb(emb)

        for resnet1, resnet2, downsample in self.downs:
            x = resnet1(x, emb)
            x = resnet2(x, emb)
            x = downsample(x)

        x = self.mid_block1[0](x, emb)
        x = self.mid_block1[1](x)
        x = self.mid_block2[0](x, emb)
        x = self.mid_block2[1](x)

        # Flatten channels x time, then append the embedding for the MLP head.
        x = x.flatten(1)
        out = self.final_block(torch.cat([x, emb], dim=-1))
        return out
