import math
from inspect import isfunction
from functools import partial

import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange, reduce
from einops.layers.torch import Rearrange

from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat


# from ldm.modules.diffusionmodules.util import checkpoint


def uniq(arr):
    """Return a keys view over the distinct elements of ``arr``.

    Elements are used as dictionary keys, so they must be hashable; the view
    iterates in first-occurrence order.
    """
    return dict.fromkeys(arr, True).keys()


def max_neg_value(t):
    """Return the most negative finite value representable in ``t``'s dtype.

    ``t`` must have a floating-point dtype, since ``torch.finfo`` is used.
    """
    dtype_info = torch.finfo(t.dtype)
    return -dtype_info.max


def init_(tensor):
    """Fill ``tensor`` in place with uniform noise and return it.

    Values are drawn from ``[-b, b)`` with ``b = 1 / sqrt(tensor.shape[-1])``,
    i.e. the bound is derived from the size of the last dimension.
    """
    bound = 1 / math.sqrt(tensor.shape[-1])
    return tensor.uniform_(-bound, bound)


# feedforward
class GEGLU(nn.Module):
    """GELU-gated linear unit.

    A single linear layer produces ``2 * dim_out`` features; the first half is
    the value and the second half, after GELU, is the multiplicative gate.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # One projection yields both the value and the gate halves.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = torch.chunk(self.proj(x), 2, dim=-1)
        return value * F.gelu(gate)


class FeedForward(nn.Module):
    """Position-wise feed-forward network.

    Layout: input projection to ``int(dim * mult)`` features, dropout, then a
    linear projection to ``dim_out`` (``dim`` when ``dim_out`` is None).
    The input projection is ``Linear + GELU`` by default, or a ``GEGLU`` gate
    when ``glu`` is truthy.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden_dim = int(dim * mult)
        dim_out = default(dim_out, dim)

        if glu:
            input_stage = GEGLU(dim, hidden_dim)
        else:
            input_stage = nn.Sequential(
                nn.Linear(dim, hidden_dim),
                nn.GELU(),
            )

        self.net = nn.Sequential(
            input_stage,
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim_out),
        )

    def forward(self, x):
        return self.net(x)


def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place and hand the module back.
    """
    # The zeroing is an initialisation step and must not be tracked by autograd.
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module


def Normalize(in_channels, eps=1e-5):
    """Build the normalisation layer used in front of the attention blocks.

    Args:
        in_channels: number of channels (``num_features``) of the input.
        eps: value added to the variance for numerical stability.

    Returns:
        An ``nn.BatchNorm1d`` over ``in_channels`` features.

    The second positional argument of ``nn.BatchNorm1d`` is ``eps``, not an
    output width. The previous call ``nn.BatchNorm1d(in_channels, in_channels)``
    therefore used the channel count as epsilon, which heavily damps the
    normalisation. ``eps`` is now passed by keyword with the PyTorch default.
    NOTE: to reproduce a checkpoint trained with the old behaviour, pass
    ``eps=float(in_channels)`` explicitly.
    """
    # Earlier variant kept for reference:
    # torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
    return nn.BatchNorm1d(in_channels, eps=eps)


# class LinearAttention(nn.Module):
#     def __init__(self, dim, heads=4, dim_head=32):
#         super().__init__()
#         self.heads = heads
#         hidden_dim = dim_head * heads
#         self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
#         self.to_out = nn.Conv2d(hidden_dim, dim, 1)
#
#     def forward(self, x):
#         b, c, h, w = x.shape
#         qkv = self.to_qkv(x)
#         q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
#         k = k.softmax(dim=-1)
#         context = torch.einsum('bhdn,bhen->bhde', k, v)
#         out = torch.einsum('bhde,bhdn->bhen', context, q)
#         out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
#         return self.to_out(out)
#
#
# class SpatialSelfAttention(nn.Module):
#     def __init__(self, in_channels):
#         super().__init__()
#         self.in_channels = in_channels
#
#         self.norm = Normalize(in_channels)
#         self.q = torch.nn.Conv2d(in_channels,
#                                  in_channels,
#                                  kernel_size=1,
#                                  stride=1,
#                                  padding=0)
#         self.k = torch.nn.Conv2d(in_channels,
#                                  in_channels,
#                                  kernel_size=1,
#                                  stride=1,
#                                  padding=0)
#         self.v = torch.nn.Conv2d(in_channels,
#                                  in_channels,
#                                  kernel_size=1,
#                                  stride=1,
#                                  padding=0)
#         self.proj_out = torch.nn.Conv2d(in_channels,
#                                         in_channels,
#                                         kernel_size=1,
#                                         stride=1,
#                                         padding=0)
#
#     def forward(self, x):
#         h_ = x
#         h_ = self.norm(h_)
#         q = self.q(h_)
#         k = self.k(h_)
#         v = self.v(h_)
#
#         # compute attention
#         b, c, h, w = q.shape
#         q = rearrange(q, 'b c h w -> b (h w) c')
#         k = rearrange(k, 'b c h w -> b c (h w)')
#         w_ = torch.einsum('bij,bjk->bik', q, k)
#
#         w_ = w_ * (int(c) ** (-0.5))
#         w_ = torch.nn.functional.softmax(w_, dim=2)
#
#         # attend to values
#         v = rearrange(v, 'b c h w -> b c (h w)')
#         w_ = rearrange(w_, 'b i j -> b j i')
#         h_ = torch.einsum('bij,bjk->bik', v, w_)
#         h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
#         h_ = self.proj_out(h_)
#
#         return x + h_


class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries come from ``x``; keys and values come from ``context``. When no
    context is supplied the layer attends over ``x`` itself (self-attention).

    Args:
        query_dim: feature size of ``x`` (and of the output).
        context_dim: feature size of ``context``; defaults to ``query_dim``.
        heads: number of attention heads.
        dim_head: feature size per head.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        context_dim = default(context_dim, query_dim)
        inner_dim = dim_head * heads

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` (b, n, query_dim) over ``context`` (b, m, context_dim).

        ``mask``, if given, is flattened to (b, m) and must be boolean: positions
        where it is False are excluded from the softmax.
        """
        n_heads = self.heads

        def split_heads(t):
            # Fold the head axis into the batch axis so each head attends independently.
            return rearrange(t, 'b n (h d) -> (b h) n d', h=n_heads)

        queries = self.to_q(x)
        context = default(context, x)
        keys = self.to_k(context)
        values = self.to_v(context)
        queries, keys, values = split_heads(queries), split_heads(keys), split_heads(values)

        scores = einsum('b i d, b j d -> b i j', queries, keys) * self.scale

        if exists(mask):
            flat_mask = rearrange(mask, 'b ... -> b (...)')
            fill_value = -torch.finfo(scores.dtype).max
            # Broadcast the per-sample mask across heads and query positions.
            flat_mask = repeat(flat_mask, 'b j -> (b h) () j', h=n_heads)
            scores.masked_fill_(~flat_mask, fill_value)

        weights = scores.softmax(dim=-1)

        attended = einsum('b i j, b j d -> b i d', weights, values)
        merged = rearrange(attended, '(b h) n d -> b n (h d)', h=n_heads)
        return self.to_out(merged)


class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention, cross-attention, feed-forward.

    Each sub-layer is preceded by its own LayerNorm and wrapped in a residual
    connection. The second attention falls back to self-attention when no
    context is passed. The ``checkpoint`` flag is stored but currently has no
    effect: ``forward`` always calls ``_forward`` directly.
    """

    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
        super().__init__()
        # Self-attention: no context_dim, so keys/values share the query width.
        self.attn1 = CrossAttention(
            query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
        )
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        # Cross-attention over the conditioning sequence (self-attention if context is None).
        self.attn2 = CrossAttention(
            query_dim=dim, context_dim=context_dim,
            heads=n_heads, dim_head=d_head, dropout=dropout,
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        # Gradient checkpointing is disabled; delegate straight to the implementation.
        return self._forward(x, context)

    def _forward(self, x, context=None):
        self_attended = self.attn1(self.norm1(x)) + x
        cross_attended = self.attn2(self.norm2(self_attended), context=context) + self_attended
        return self.ff(self.norm3(cross_attended)) + cross_attended


class SpatialTransformer(nn.Module):
    """
    Transformer block for channel-first sequence data of shape (b, c, d).

    The input is normalised, projected to ``n_heads * d_head`` channels with a
    1x1 convolution and transposed to (b, d, c) so every position is a token.
    ``depth`` transformer blocks are applied, the result is transposed back and
    projected to ``in_channels``, and the original input is added.
    The output projection starts at zero, so a freshly built module returns
    its input unchanged.
    """

    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv1d(
            in_channels, inner_dim, kernel_size=1, stride=1, padding=0,
        )

        blocks = [
            BasicTransformerBlock(inner_dim, n_heads, d_head,
                                  dropout=dropout, context_dim=context_dim)
            for _ in range(depth)
        ]
        self.transformer_blocks = nn.ModuleList(blocks)

        out_projection = nn.Conv1d(
            inner_dim, in_channels, kernel_size=1, stride=1, padding=0,
        )
        self.proj_out = zero_module(out_projection)

    def forward(self, x, context=None):
        # Without a context, the cross-attention inside each block is self-attention.
        residual = x
        tokens = self.proj_in(self.norm(x))
        tokens = rearrange(tokens, 'b c d -> b d c')
        for layer in self.transformer_blocks:
            tokens = layer(tokens, context=context)
        features = rearrange(tokens, 'b d c -> b c d')
        return self.proj_out(features) + residual


# ---------------------------- unet_conv

def exists(x):
    """Return True unless ``x`` is None (falsy values such as 0 or '' still exist)."""
    return not (x is None)


def default(val, d):
    """Return ``val`` unless it is None, otherwise the fallback ``d``.

    When ``d`` is a plain Python function (per ``inspect.isfunction``) it is
    called with no arguments and its result is used, so expensive defaults can
    be built lazily.
    """
    if not exists(val):
        return d() if isfunction(d) else d
    return val


class Residual(nn.Module):
    """Wrap ``fn`` so that its output is added to its first input."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        # Extra positional/keyword arguments are forwarded to the wrapped callable untouched.
        transformed = self.fn(x, *args, **kwargs)
        return transformed + x


def Upsample(dim, dim_out=None):
    """Build a 2x nearest-neighbour upsampling stage followed by a 3-wide Conv1d.

    The convolution maps ``dim`` channels to ``dim_out`` (``dim`` when omitted)
    and uses padding 1, so it does not change the sequence length.
    """
    stages = [
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv1d(dim, default(dim_out, dim), 3, padding=1),
    ]
    return nn.Sequential(*stages)


def Downsample(dim, dim_out=None):
    """Build a 2x downsampling stage without strided convolution or pooling.

    Adjacent pairs of positions are folded into the channel axis
    ((b, c, 2*d) -> (b, 2*c, d)), then a 1x1 Conv1d maps the ``2 * dim``
    channels to ``dim_out`` (``dim`` when omitted). The input length must be even.
    """
    fold_pairs = Rearrange("b c (d p) -> b (c p) d", p=2)
    mix_channels = nn.Conv1d(dim * 2, default(dim_out, dim), 1)
    return nn.Sequential(fold_pairs, mix_channels)


class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal embedding of scalar timesteps.

    For a 1-D tensor ``time`` of shape (b,), returns a (b, 2 * (dim // 2))
    tensor holding the sines followed by the cosines of ``time`` scaled by
    ``dim // 2`` geometrically spaced frequencies between 1 and 1/10000.
    Note that for odd ``dim`` the output is one feature narrower than ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half_dim = self.dim // 2
        # With dim of 2 or 3 there is a single frequency and ``half_dim - 1`` is 0;
        # clamp the divisor to avoid a ZeroDivisionError (the lone frequency is
        # exp(0) = 1 regardless of the step). Unchanged for dim >= 4.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(torch.arange(half_dim, device=time.device) * -log_step)
        angles = time[:, None] * frequencies[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


class Block(nn.Module):
    """Conv1d -> BatchNorm1d -> optional scale/shift modulation -> SiLU.

    Args:
        dim: input channels.
        dim_out: output channels.
        norm_eps: epsilon of the batch-norm layer.

    The second positional argument of ``nn.BatchNorm1d`` is ``eps``; the
    previous ``nn.BatchNorm1d(dim_out, dim_out)`` used the channel count as
    epsilon, which heavily damps the normalisation. It is now passed by keyword
    with the PyTorch default.
    NOTE: to reproduce a checkpoint trained with the old behaviour, construct
    the block with ``norm_eps=float(dim_out)``.
    """

    def __init__(self, dim, dim_out, norm_eps=1e-5):
        super().__init__()
        # kernel 3 with padding 1 keeps the sequence length unchanged
        self.proj = nn.Conv1d(dim, dim_out, kernel_size=3, padding=1)
        self.norm = nn.BatchNorm1d(dim_out, eps=norm_eps)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        """Apply the block; ``scale_shift`` is an optional ``(scale, shift)`` pair
        that must broadcast against the normalised (b, dim_out, length) features."""
        x = self.norm(self.proj(x))

        if scale_shift is not None:
            scale, shift = scale_shift
            # ``scale + 1`` makes a zero scale the identity modulation.
            x = x * (scale + 1) + shift

        return self.act(x)


class ResnetBlock(nn.Module):
    """Residual block of two ``Block`` layers (https://arxiv.org/abs/1512.03385).

    When ``time_emb_dim`` is given, a small MLP turns the time embedding into a
    per-channel (scale, shift) pair that modulates the first block. The skip
    path is a 1x1 convolution when the channel count changes, identity otherwise.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None):
        super().__init__()
        if exists(time_emb_dim):
            # Produces 2 * dim_out features: one half for scale, one for shift.
            self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))
        else:
            self.mlp = None

        self.block1 = Block(dim, dim_out)
        self.block2 = Block(dim_out, dim_out)
        if dim != dim_out:
            self.res_conv = nn.Conv1d(dim, dim_out, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x, time_emb=None):
        modulation = None
        if self.mlp is not None and time_emb is not None:
            projected = self.mlp(time_emb)
            # Add a trailing length axis so the pair broadcasts over positions.
            projected = rearrange(projected, "b c -> b c 1")
            modulation = projected.chunk(2, dim=1)

        hidden = self.block1(x, scale_shift=modulation)
        hidden = self.block2(hidden)
        return hidden + self.res_conv(x)


class PreNorm(nn.Module):
    """Apply a single-group ``GroupNorm`` over ``dim`` channels, then call ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        # num_groups=1: all channels are normalised together.
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        return self.fn(self.norm(x))


class UNet(nn.Module):
    """1-D UNet with time conditioning.

    The encoder has one stage per entry of ``ch_mults``; every stage holds two
    time-conditioned ``ResnetBlock`` layers, a ``SpatialTransformer`` and a
    resampling layer (2x ``Downsample``, or a plain Conv1d on the last stage).
    The decoder mirrors this with skip connections from the encoder.

    Args:
        dim: base width; stage widths are ``dim * m`` for ``m`` in ``ch_mults``
            and the time embedding has ``dim * 4`` features.
        dim_cond: feature size of the conditioning sequence given to the
            ``SpatialTransformer`` layers as ``context_dim``.
        ch: number of channels of the input signal.
        ch_mults: per-stage width multipliers.
        init_ch: width after the input convolution; defaults to ``dim // 4``.
        final_ch: number of output channels; defaults to ``ch``.
        num_heads: attention heads of each ``SpatialTransformer``.
        self_condition: if truthy, a self-conditioning tensor is concatenated
            to the input along the channel axis.

    NOTE: the ``SpatialTransformer`` layers and ``cond_conv`` are constructed
    (so they appear in the state dict) but ``forward`` currently bypasses them.
    """

    def __init__(
            self,
            dim,  # e.g. 128
            dim_cond=768,
            ch=1,  # input channels
            ch_mults=(1, 2, 4, 8),
            init_ch=None,
            final_ch=None,  # output channels
            num_heads=4,
            self_condition=False,
    ):
        super().__init__()

        # determine dimensions
        self.ch = ch
        self.self_condition = self_condition
        # NOTE(review): with self-conditioning the input conv expects ch * 3
        # channels, while forward() concatenates only x_self_cond and x. With
        # the zeros fallback that is 2 * ch - confirm the intended width of
        # x_self_cond before relying on self_condition=True.
        input_ch = ch * (3 if self_condition else 1)

        init_ch = default(init_ch, dim // 4)
        self.init_conv = nn.Conv1d(input_ch, init_ch, 1, padding=0)

        # e.g. dim=128, ch_mults=(1, 2, 4) -> [32, 128, 256, 512]
        chs = [init_ch, *(dim * mult for mult in ch_mults)]
        # consecutive (in, out) pairs, e.g. [(32, 128), (128, 256), (256, 512)]
        in_out = list(zip(chs[:-1], chs[1:]))

        final_ch = default(final_ch, ch)

        # time embeddings
        dim_time = dim * 4

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, dim_time),
            nn.GELU(),
            nn.Linear(dim_time, dim_time),
        )

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        last_stage = len(in_out) - 1

        for ind, (ch_in, ch_out) in enumerate(in_out):
            is_last = ind >= last_stage

            self.downs.append(
                nn.ModuleList(
                    [
                        ResnetBlock(ch_in, ch_in, time_emb_dim=dim_time),
                        ResnetBlock(ch_in, ch_in, time_emb_dim=dim_time),
                        SpatialTransformer(in_channels=ch_in, n_heads=num_heads,
                                           d_head=ch_in // num_heads, context_dim=dim_cond),
                        # every stage but the last halves the sequence length
                        Downsample(ch_in, ch_out)
                        if not is_last
                        else nn.Conv1d(ch_in, ch_out, 3, padding=1),
                    ]
                )
            )

        ch_mid = chs[-1]
        self.mid_block1 = ResnetBlock(ch_mid, ch_mid, time_emb_dim=dim_time)
        self.mid_attn = SpatialTransformer(in_channels=ch_mid, n_heads=num_heads,
                                           d_head=ch_mid // num_heads, context_dim=dim_cond)
        self.mid_block2 = ResnetBlock(ch_mid, ch_mid, time_emb_dim=dim_time)

        for ind, (ch_in, ch_out) in enumerate(reversed(in_out)):
            is_last = ind == last_stage

            self.ups.append(
                nn.ModuleList(
                    [
                        # inputs are widened by the concatenated encoder skip (ch_in)
                        ResnetBlock(ch_out + ch_in, ch_out, time_emb_dim=dim_time),
                        ResnetBlock(ch_out + ch_in, ch_out, time_emb_dim=dim_time),
                        SpatialTransformer(in_channels=ch_out, n_heads=num_heads,
                                           d_head=ch_out // num_heads, context_dim=dim_cond),
                        # every stage but the last doubles the sequence length
                        Upsample(ch_out, ch_in)
                        if not is_last
                        else nn.Conv1d(ch_out, ch_in, 3, padding=1),
                    ]
                )
            )

        self.final_res_block = ResnetBlock(init_ch * 2, init_ch, time_emb_dim=dim_time)
        self.final_conv = nn.Conv1d(init_ch, final_ch, 1)

        self.cond_conv = nn.Conv1d(ch, init_ch, 1)

    def forward(self, x, time, cond=None):
        """Run the network.

        Args:
            x: input of shape (batch, channels, length). ``length`` must be
                divisible by 2 once per non-final stage, since each of those
                stages folds pairs of positions into channels.
            time: 1-D tensor of timesteps, one per batch element.
            cond: optional pair ``(x_self_cond, x_guide_cond)``. ``None`` is
                treated as ``(None, None)``. ``x_guide_cond`` is accepted but
                currently unused because the attention layers are bypassed.

        Returns:
            Tensor of shape (batch, final_ch, length).
        """
        # The signature advertises cond=None, so tolerate it instead of
        # failing on tuple unpacking.
        x_self_cond, x_guide_cond = cond if cond is not None else (None, None)

        if self.self_condition:
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x_self_cond, x), dim=1)

        x = self.init_conv(x)
        # kept for the final long skip connection
        r = x.clone()

        t = self.time_mlp(time)

        # encoder activations, consumed in reverse order by the decoder
        h = []

        for block1, block2, _attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            # attention disabled: x = _attn(x, x_guide_cond)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        # attention disabled: x = self.mid_attn(x, x_guide_cond)
        x = self.mid_block2(x, t)

        for block1, block2, _attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim=1)
            x = block2(x, t)
            # attention disabled: x = _attn(x, x_guide_cond)

            x = upsample(x)

        x = torch.cat((x, r), dim=1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)


if __name__ == '__main__':
    # Smoke test: build the model, print it, and push one dummy batch through.
    # torchsummary.summary(unet, (1, 28, 28)) cannot be used here: it calls the
    # model with a single 4-D tensor, while UNet.forward also requires `time`
    # and its Conv1d layers expect (batch, channels, length) inputs.
    unet = UNet(dim=128, dim_cond=768, ch=1, ch_mults=(1, 2, 4,), self_condition=True, num_heads=4)

    print(unet)

    n_params = sum(p.numel() for p in unet.parameters())
    print('parameters:', n_params)

    # length 28 survives the two 2x downsampling stages of ch_mults=(1, 2, 4)
    batch, length = 2, 28
    x = torch.randn(batch, unet.ch, length)
    time = torch.randint(0, 1000, (batch,)).float()
    # The input conv is sized for self-conditioning; give the self-conditioning
    # tensor whatever channels remain after x so the concatenation matches it.
    self_cond = torch.zeros(batch, unet.init_conv.in_channels - unet.ch, length)

    unet.eval()
    with torch.no_grad():
        out = unet(x, time, cond=(self_cond, None))
    print('output shape:', tuple(out.shape))
