import math
from inspect import isfunction
from functools import partial

import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange, reduce
from einops.layers.torch import Rearrange

import torch
from torch import nn, einsum
import torch.nn.functional as F


def exists(x):
    """Return True unless *x* is ``None`` (falsy values like 0 or "" still exist)."""
    if x is None:
        return False
    return True


def default(val, d):
    """Return *val* if it is not ``None``, otherwise the fallback *d*.

    When *d* is a plain Python function (per ``inspect.isfunction``) it is
    called with no arguments to produce the fallback lazily; any other object
    (including classes, builtins and ``functools.partial``) is returned as-is.
    """
    if not exists(val):
        return d() if isfunction(d) else d
    return val


class Residual(nn.Module):
    """Identity skip connection around ``fn``: returns ``fn(x, ...) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        # Extra positional/keyword arguments go to the wrapped callable only.
        branch = self.fn(x, *args, **kwargs)
        return branch + x


def Upsample(dim, dim_out=None):
    """Double the length axis, then mix channels.

    Nearest-neighbour repetition (scale 2) is followed by a kernel-3,
    padding-1 ``Conv1d`` that keeps the new length and maps ``dim`` channels
    to ``dim_out`` (defaults to ``dim``).
    """
    out_channels = default(dim_out, dim)
    layers = [
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv1d(dim, out_channels, 3, padding=1),
    ]
    return nn.Sequential(*layers)


def Downsample(dim, dim_out=None):
    """Halve the length axis without strided convolutions or pooling.

    Adjacent position pairs are folded into channels, ``(b, c, 2n) -> (b, 2c, n)``,
    which requires an even input length. A 1x1 ``Conv1d`` then maps the
    ``2 * dim`` channels to ``dim_out`` (defaults to ``dim``).
    """
    space_to_depth = Rearrange("b c (d p) -> b (c p) d", p=2)
    mix = nn.Conv1d(dim * 2, default(dim_out, dim), 1)
    return nn.Sequential(space_to_depth, mix)


class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal embedding of scalar positions / diffusion timesteps.

    With ``half = dim // 2`` frequencies ``w_k = exp(-k * log(10000) / max(half - 1, 1))``,
    a position ``t`` maps to
    ``[sin(t*w_0), ..., sin(t*w_{half-1}), cos(t*w_0), ..., cos(t*w_{half-1})]``.
    When ``dim`` is odd, one zero column is appended so the output always has
    exactly ``dim`` features. This matters because callers feed it straight into
    ``nn.Linear(dim, ...)``.

    Args:
        dim: Number of output features.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Embed ``time``.

        Args:
            time: 1-D tensor with one position per batch element. A 0-d tensor
                is treated as a batch of one. Integer tensors are promoted to
                floating point by the multiplication below.

        Returns:
            Tensor of shape ``(len(time), dim)`` for a 1-D (or 0-d) ``time``.
        """
        if time.dim() == 0:
            time = time[None]
        half_dim = self.dim // 2
        # Clamp the denominator: half_dim == 1 (dim 2 or 3) would otherwise divide by zero.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -log_step)
        args = time[:, None] * freqs[None, :]
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos pairs cover only 2 * half_dim features; pad to dim.
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings


class Block(nn.Module):
    """Conv1d -> LayerNorm over the length axis -> optional scale/shift -> SiLU.

    Args:
        dim: Input channels.
        dim_out: Output channels.
        feature_dim: Size of the last (length) axis. ``LayerNorm`` normalizes
            over that axis, so inputs must have exactly this length. The value
            is passed through ``int()``, so float lengths are accepted.
    """

    def __init__(self, dim, dim_out, feature_dim):
        super().__init__()
        # kernel 3 / padding 1 keeps the length unchanged, as LayerNorm needs.
        self.proj = nn.Conv1d(dim, dim_out, kernel_size=3, padding=1)
        self.norm = nn.LayerNorm(int(feature_dim))
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        """Apply the block; ``scale_shift`` is an optional ``(scale, shift)`` pair."""
        normed = self.norm(self.proj(x))
        if not exists(scale_shift):
            return self.act(normed)
        scale, shift = scale_shift
        # "+ 1" makes a zero scale leave the activations unchanged.
        return self.act(normed * (scale + 1) + shift)


class ResnetBlock(nn.Module):
    """Two ``Block``s with a residual connection (https://arxiv.org/abs/1512.03385).

    When ``time_emb_dim`` is given, a SiLU + Linear head maps the time embedding
    to ``2 * dim_out`` values that are split into a per-channel ``(scale, shift)``
    pair for the first block. The skip path uses a 1x1 ``Conv1d`` when
    ``dim != dim_out`` and is the identity otherwise.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, feature_dim=None):
        super().__init__()
        if exists(time_emb_dim):
            self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))
        else:
            self.mlp = None

        self.block1 = Block(dim, dim_out, feature_dim)
        self.block2 = Block(dim_out, dim_out, feature_dim)
        if dim != dim_out:
            self.res_conv = nn.Conv1d(dim, dim_out, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x, time_emb=None):
        """Run both blocks, conditioning the first on ``time_emb`` when possible."""
        modulation = None
        if exists(self.mlp) and exists(time_emb):
            # (b, 2*dim_out) -> (b, 2*dim_out, 1) so it broadcasts over length.
            cond = rearrange(self.mlp(time_emb), "b c -> b c 1")
            modulation = cond.chunk(2, dim=1)
        out = self.block2(self.block1(x, scale_shift=modulation))
        return out + self.res_conv(x)


class Attention(nn.Module):
    """Multi-head softmax self-attention over the length axis of a ``(b, c, n)`` input.

    Queries, keys and values come from one bias-free 1x1 convolution. Each head
    has ``dim_head`` features, and the heads are merged and projected back to
    ``dim`` channels by another 1x1 convolution.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner = heads * dim_head
        self.to_qkv = nn.Conv1d(dim, inner * 3, 1, bias=False)
        self.to_out = nn.Conv1d(inner, dim, 1)

    def forward(self, x):
        _, _, length = x.shape
        q, k, v = (
            rearrange(part, "b (h c) d -> b h c d", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )
        q = q * self.scale

        # Position-by-position similarity; subtract the row max before softmax.
        logits = einsum("b h d i, b h d j -> b h i j", q, k)
        logits = logits - logits.amax(dim=-1, keepdim=True).detach()
        weights = logits.softmax(dim=-1)

        mixed = einsum("b h i j, b h d j -> b h i d", weights, v)
        mixed = rearrange(mixed, "b h d h_d -> b (h h_d) d", d=length)
        return self.to_out(mixed)


class LinearAttention(nn.Module):
    """Multi-head linear attention over the length axis of a ``(b, c, n)`` input.

    Queries are softmax-normalized over the per-head feature axis and keys over
    the length axis. Keys and values are then contracted into a per-head
    ``(dim_head, dim_head)`` context matrix instead of an ``n x n`` attention map.
    The merged heads are projected back to ``dim`` channels and batch-normalized.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner = heads * dim_head
        self.to_qkv = nn.Conv1d(dim, inner * 3, 1, bias=False)

        self.to_out = nn.Sequential(
            nn.Conv1d(inner, dim, 1),
            nn.BatchNorm1d(dim),
        )

    def forward(self, x):
        _, _, length = x.shape
        q, k, v = (
            rearrange(part, "b (h c) d -> b h c d", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )

        q = q.softmax(dim=-2) * self.scale
        k = k.softmax(dim=-1)

        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
        attended = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        attended = rearrange(attended, "b h c d -> b (h c) d", h=self.heads, d=length)
        return self.to_out(attended)


class PreNorm(nn.Module):
    """Normalize with single-group ``GroupNorm`` over all non-batch dims, then apply ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        return self.fn(self.norm(x))


class UNet(nn.Module):
    """1-D U-Net conditioned on a diffusion timestep.

    Inputs are ``(batch, channels, dim)``. ``dim`` is both the expected input
    length and the width of the sinusoidal timestep embedding. Every ``Block``
    applies ``LayerNorm`` over the length axis, sized to that level's length.

    Args:
        dim: Input length and timestep-embedding width. It must be divisible by
            ``2 ** (len(layer_channels) - 2)`` because every encoder level
            except the last halves the length.
        layer_channels: Absolute channel widths, e.g. ``(32, 64, 128, 256)``.
            ``layer_channels[0]`` is the width after ``init_conv``. Consecutive
            pairs define the encoder/decoder levels, and the last entry is the
            bottleneck width. Required despite the ``None`` default.
        init_ch: Channels entering ``init_conv`` (after any self-conditioning
            concatenation). Required.
        final_ch: Output channels. Required.
        self_condition: If True, ``x_self_cond`` (or zeros shaped like ``x``)
            is concatenated in front of ``x`` along the channel axis.

    Raises:
        ValueError: If a required argument is missing or ``dim`` cannot be
            halved the required number of times.
    """

    def __init__(
            self,
            dim,
            layer_channels=None,
            init_ch=None,
            final_ch=None,
            self_condition=False,
    ):
        super().__init__()

        if layer_channels is None or len(layer_channels) == 0:
            raise ValueError("layer_channels must be a non-empty sequence of channel widths")
        if init_ch is None or final_ch is None:
            raise ValueError("init_ch and final_ch are required")

        in_out = list(zip(layer_channels[:-1], layer_channels[1:]))
        num_resolutions = len(in_out)
        num_halvings = max(num_resolutions - 1, 0)
        if dim % (2 ** num_halvings) != 0:
            raise ValueError(
                f"dim={dim} must be divisible by {2 ** num_halvings} "
                f"({num_halvings} downsampling steps)"
            )

        self.self_condition = self_condition
        self.init_ch = init_ch
        self.final_ch = final_ch

        self.init_conv = nn.Conv1d(init_ch, layer_channels[0], 1, padding=0)
        # Only used when init_ch == 3: a Linear over the length axis applied to
        # the second channel of x_self_cond (see forward).
        self.init_pos_transform = nn.Linear(dim, dim)

        # time embeddings
        time_dim = dim * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])

        # Length of the last axis at the current level. Block's LayerNorm is
        # sized to it, so keep it an exact integer (never a float).
        feature_len = dim

        for ind, (ch_in, ch_out) in enumerate(in_out):
            is_last = ind == num_resolutions - 1
            self.downs.append(
                nn.ModuleList(
                    [
                        ResnetBlock(ch_in, ch_in, time_emb_dim=time_dim, feature_dim=feature_len),
                        ResnetBlock(ch_in, ch_in, time_emb_dim=time_dim, feature_dim=feature_len),
                        nn.Identity(),  # placeholder for the `attn` step in forward
                        Downsample(ch_in, ch_out)
                        if not is_last
                        else nn.Conv1d(ch_in, ch_out, 3, padding=1),
                    ]
                )
            )
            if not is_last:
                feature_len //= 2

        mid_dim = layer_channels[-1]
        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=time_dim, feature_dim=feature_len)
        self.mid_attn = nn.Identity()
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=time_dim, feature_dim=feature_len)

        for ind, (ch_in, ch_out) in enumerate(reversed(in_out)):
            is_last = ind == num_resolutions - 1
            self.ups.append(
                nn.ModuleList(
                    [
                        # Inputs are concatenated with encoder skips of ch_in channels.
                        ResnetBlock(ch_out + ch_in, ch_out, time_emb_dim=time_dim, feature_dim=feature_len),
                        ResnetBlock(ch_out + ch_in, ch_out, time_emb_dim=time_dim, feature_dim=feature_len),
                        nn.Identity(),  # placeholder for the `attn` step in forward
                        Upsample(ch_out, ch_in)
                        if not is_last
                        else nn.Conv1d(ch_out, ch_in, 3, padding=1),
                    ]
                )
            )
            if not is_last:
                feature_len *= 2

        # Final block sees the decoder output concatenated with the init_conv output.
        self.final_res_block = ResnetBlock(
            layer_channels[0] * 2, layer_channels[0], time_emb_dim=time_dim, feature_dim=feature_len
        )
        self.final_conv = nn.Conv1d(layer_channels[0], final_ch, 1)

    def forward(self, x, time, x_self_cond=None):
        """Predict the network output for ``x`` at timesteps ``time``.

        Args:
            x: Tensor of shape ``(batch, channels, dim)``.
            time: 1-D tensor with one timestep per batch element.
            x_self_cond: Optional conditioning tensor. When ``init_ch == 3`` it
                is required and must have exactly 2 channels.

        Returns:
            Tensor of shape ``(batch, final_ch, dim)``.

        Raises:
            ValueError: If ``init_ch == 3`` and ``x_self_cond`` is missing or
                does not have 2 channels.
        """
        if self.init_ch == 3:
            if x_self_cond is None or x_self_cond.shape[1] != 2:
                raise ValueError("init_ch == 3 requires x_self_cond with exactly 2 channels")
            x_cond, pos_emb = x_self_cond.chunk(2, dim=1)
            pos_emb = self.init_pos_transform(pos_emb)
            x_self_cond = torch.cat((x_cond, pos_emb), dim=1)

        if self.self_condition:
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x_self_cond, x), dim=1)

        x = self.init_conv(x)
        r = x.clone()

        t = self.time_mlp(time)

        skips = []
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            skips.append(x)

            x = block2(x, t)
            x = attn(x)
            skips.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, skips.pop()), dim=1)
            x = block1(x, t)

            x = torch.cat((x, skips.pop()), dim=1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        x = torch.cat((x, r), dim=1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)


if __name__ == '__main__':
    # Smoke test: build a small model with the constructor's real parameters
    # and run one dummy batch. forward() needs both the signal and a timestep
    # tensor, so the model is driven directly rather than via a one-input
    # summary tool.
    seq_len = 128
    unet = UNet(
        dim=seq_len,
        layer_channels=(32, 64, 128, 256),
        init_ch=1,
        final_ch=1,
    )

    print(unet)

    n_params = sum(p.numel() for p in unet.parameters())
    print(f'parameters: {n_params:,}')

    x = torch.randn(2, 1, seq_len)
    t = torch.randint(0, 1000, (2,))
    with torch.no_grad():
        out = unet(x, t)
    print('output shape:', tuple(out.shape))