# https://github.com/mirthAI/Fast-DDPM/blob/main/models/diffusion.py
from typing import List, Tuple

import torch
import torch.nn as nn

from model.denoising_network.resnet.attnblock import AttnBlock
from model.denoising_network.resnet.block import ResnetBlock
from model.denoising_network.resnet.downsample import Downsample
from model.denoising_network.resnet.nonlinearity import non_linearity
from model.denoising_network.resnet.normalize import get_normalize_block
from model.denoising_network.resnet.timestep_embedding import get_timestep_embedding
from model.denoising_network.resnet.upsample import Upsample


class DenoisingNetworkResnet(nn.Module):
    """1-D U-Net style denoising network with an autoregressive output head.

    The backbone is an encoder / bottleneck / decoder stack of ``ResnetBlock``s
    (optionally followed by ``AttnBlock``s) with skip connections between the
    down- and up-sampling paths, conditioned on a time-step embedding.

    The output head splits the final feature map along the sequence axis:

    * the first ``max_lag`` positions are projected by ``conv_out`` to the first
      ``max_lag`` output time-steps;
    * the remaining positions are projected by ``out_coefficients`` (one
      convolution per output feature) to per-step coefficients, which are then
      applied to the previous ``max_lag`` outputs to produce each following
      time-step.

    Args:
        ch: Base channel count; also the size of the raw time-step embedding.
        ch_mult: Channel multiplier per resolution level (list or tuple).
        num_res_blocks: Number of ``ResnetBlock``s per level on the down path
            (the up path uses ``num_res_blocks + 1``).
        attn_resolutions: Sequence lengths at which attention blocks are added.
        dropout: Dropout value forwarded to every ``ResnetBlock``.
        seq_len: Sequence length used to track the current resolution and to
            derive the number of autoregressive steps (``seq_len - max_lag``).
        resamp_with_conv: Forwarded to ``Downsample`` / ``Upsample``.
        n_features: Number of input and output features per time-step.
        max_lag: Number of directly predicted initial time-steps, and the size
            of the window of past outputs each later step is computed from.
    """

    def __init__(
        self, ch: int, ch_mult: List[int], num_res_blocks: int, attn_resolutions: List[int], dropout: float,
        seq_len: int, resamp_with_conv: bool, n_features: int, max_lag: int
    ) -> None:
        super().__init__()
        self.max_lag = max_lag
        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.seq_len = seq_len
        self.n_features = n_features

        # Timestep embedding
        self.temb = nn.Module()
        self.temb.dense = nn.ModuleList([nn.Linear(self.ch, self.temb_ch), nn.Linear(self.temb_ch, self.temb_ch)])

        # Input projection
        self.conv_in = nn.Conv1d(n_features, self.ch, kernel_size=3, stride=1, padding=1)

        # Down-sampling
        curr_res = seq_len
        # ``ch_mult`` is annotated as a list; ``(1,) + list`` would raise a
        # TypeError, so normalise to a tuple before concatenating.
        ch_mult = tuple(ch_mult)
        in_ch_mult = (1,) + ch_mult
        self.down = nn.ModuleList()
        block_in = None
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            # The last level keeps its resolution; every other level halves it.
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # Middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )

        # Up-sampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                # The last block of a level consumes the skip tensor that was
                # the *input* of the matching down level, hence ``in_ch_mult``.
                if i_block == self.num_res_blocks:
                    skip_in = ch * in_ch_mult[i_level]
                block.append(
                    ResnetBlock(
                        in_channels=block_in + skip_in, out_channels=block_out, temb_channels=self.temb_ch,
                        dropout=dropout
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        self.norm_out = get_normalize_block(block_in)

        # Output first time-steps
        self.conv_out = nn.Conv1d(block_in, n_features, kernel_size=3, stride=1, padding=1)

        # Output coefficients: one convolution per output feature, each producing
        # ``n_features * max_lag`` coefficients per position.
        self.out_coefficients = nn.ModuleList(
            [nn.Conv1d(block_in, n_features * max_lag, kernel_size=3, stride=1, padding=1) for _ in range(n_features)]
        )

    @staticmethod
    def _autoregressive_rollout(
        initial: torch.Tensor, coefficients: torch.Tensor, max_lag: int, n_steps: int
    ) -> torch.Tensor:
        """Extend ``initial`` by ``n_steps`` time-steps using per-step coefficients.

        Each new step is the product of ``coefficients[..., i]`` with the
        flattened window of the last ``max_lag`` time-steps produced so far.

        Args:
            initial: ``[B, n_features, L0]`` tensor of starting time-steps.
            coefficients: ``[B, n_features, n_features * max_lag, >= n_steps]``.
            max_lag: Number of trailing time-steps in each window.
            n_steps: Number of time-steps to append.

        Returns:
            ``[B, n_features, L0 + n_steps]`` tensor (``initial`` itself when
            ``n_steps <= 0``).
        """
        if n_steps <= 0:
            return initial

        # Keep the time-steps in a list and stack once at the end, instead of
        # re-concatenating the whole growing tensor on every iteration.
        steps = list(initial.unbind(dim=-1))  # each [B, n_features]
        for i in range(n_steps):
            window = torch.stack(steps[-max_lag:], dim=-1).flatten(1)  # [B, n_features*max_lag]
            step_coefficients = coefficients[:, :, :, i]  # [B, n_features, n_features*max_lag]
            steps.append(torch.einsum('Bx,Bfx->Bf', window, step_coefficients))
        return torch.stack(steps, dim=-1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the network on a batch of sequences.

        Args:
            x: Input of shape ``[B, L, n_features]``; it is transposed to
                channels-first before ``conv_in``. ``L`` is presumably
                ``seq_len`` (TODO confirm with callers).
            t: Time-steps passed to ``get_timestep_embedding``; presumably a
                ``[B]`` tensor of diffusion step indices (TODO confirm).

        Returns:
            A pair ``(out, coefficients)``: ``out`` is the generated sequence
            transposed back to ``[B, time, n_features]``; ``coefficients`` is
            ``[B, n_features, n_features * max_lag, L - max_lag]``.
        """
        x = x.transpose(1, 2)

        # Timestep embedding
        temb = get_timestep_embedding(t, self.ch)
        temb = self.temb.dense[0](temb)
        temb = non_linearity(temb)
        temb = self.temb.dense[1](temb)

        # Input projection; ``hs`` collects every intermediate for the skip connections
        hs = [self.conv_in(x)]

        # Down-sampling
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # Middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # Up-sampling: each block consumes one stored skip tensor (LIFO order)
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb
                )
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # Output projection
        h = self.norm_out(h)
        h = non_linearity(h)  # [B, C, L] with C = ch * ch_mult[0]

        h1 = h[:, :, :self.max_lag]  # [B, C, max_lag]
        h2 = h[:, :, self.max_lag:]  # [B, C, L-max_lag]
        out = self.conv_out(h1)  # [B, n_features, max_lag]

        coefficients = torch.stack([out_coefficient(h2) for out_coefficient in self.out_coefficients], dim=1)
        # [B, n_features, n_features*max_lag, L-max_lag]
        out = self._autoregressive_rollout(out, coefficients, self.max_lag, self.seq_len - self.max_lag)

        return out.transpose(1, 2), coefficients
