# from https://huggingface.co/blog/annotated-diffusion

import math
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
from inspect import isfunction
from functools import partial

'''
Helper functions
'''
def exists(x):
    """Return True for any value other than None (0, "" and [] all count as existing)."""
    if x is None:
        return False
    return True

def default(val, d):
    """Return *val* unless it is None, otherwise fall back to *d*.

    When *d* is a plain Python function (per ``inspect.isfunction``) it is called
    lazily and its result is returned, so costly fallbacks are only built on demand.
    Any other *d* (including classes and builtins) is returned unchanged.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d


def num_to_groups(num, divisor):
    """Split *num* into as many full chunks of *divisor* as fit, plus one chunk for the remainder.

    >>> num_to_groups(10, 4)
    [4, 4, 2]
    """
    full_chunks, leftover = divmod(num, divisor)
    tail = [leftover] if leftover > 0 else []
    return [divisor] * full_chunks + tail


class Residual(nn.Module):
    """Skip connection: returns ``fn(x, *args, **kwargs) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        branch = self.fn(x, *args, **kwargs)
        return branch + x


def Upsample(dim, dim_out=None):
    """Double H and W with nearest-neighbour upsampling, then apply a 3x3 conv (padding 1).

    Maps ``(B, dim, H, W)`` to ``(B, dim_out, 2H, 2W)``; ``dim_out`` defaults to ``dim``.
    """
    out_channels = dim if dim_out is None else dim_out
    layers = [
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(dim, out_channels, 3, padding=1),
    ]
    return nn.Sequential(*layers)


def Downsample(dim, dim_out=None):
    """Halve H and W without strided convolutions or pooling.

    Every 2x2 spatial patch is folded into the channel axis
    (``(B, dim, 2H, 2W) -> (B, 4 * dim, H, W)``), then a 1x1 convolution projects
    to ``dim_out`` channels (``dim`` when omitted).
    """
    # No More Strided Convolutions or Pooling
    out_channels = dim if dim_out is None else dim_out
    space_to_depth = Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2)
    project = nn.Conv2d(dim * 4, out_channels, 1)
    return nn.Sequential(space_to_depth, project)

class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal embedding of scalar positions (e.g. diffusion timesteps or ages).

    With ``half = dim // 2`` and frequencies ``f_k = exp(-k * ln(10000) / (half - 1))``
    for ``k = 0 .. half - 1``, the output's last axis is
    ``[sin(t * f_0), ..., sin(t * f_{half-1}), cos(t * f_0), ..., cos(t * f_{half-1})]``.
    For odd ``dim`` a single zero feature is appended so the last axis is always
    exactly ``dim`` wide, matching downstream ``nn.Linear(dim, ...)`` layers.

    A 1-D ``time`` of shape ``(B,)`` yields ``(B, dim)``. The frequency axis is
    combined as ``time[:, None] * freqs[None, :]``, so higher-rank inputs must
    broadcast against ``(1, dim // 2)``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        # Guard the denominator: for dim < 4 (half_dim <= 1) the original formula
        # divided by zero.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -scale)
        args = time[:, None] * freqs[None, :]
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos halves only cover 2 * (dim // 2) features; pad to exactly dim.
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings

class PreNorm(nn.Module):
    """Normalize with a single-group GroupNorm over ``dim`` channels, then apply ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        return self.fn(self.norm(x))

'''
ResNet block
'''
class WeightStandardizedConv2d(nn.Conv2d):
    """Conv2d whose kernel is standardized per output channel on every forward pass.

    https://arxiv.org/abs/1903.10520
    weight standardization purportedly works synergistically with group normalization

    For each output channel the kernel is shifted to zero mean and scaled to unit
    (biased) variance before the convolution; the stored ``self.weight`` is not modified.
    """

    def forward(self, x):
        # float32 inputs use eps=1e-5; every other dtype uses the looser 1e-3.
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3

        weight = self.weight
        # Reduce over every axis except the output-channel axis -> shape (out, 1, 1, 1).
        reduce_dims = tuple(range(1, weight.ndim))
        mean = weight.mean(dim=reduce_dims, keepdim=True)
        var = weight.var(dim=reduce_dims, unbiased=False, keepdim=True)
        # Divide by the std, i.e. MULTIPLY by rsqrt(var + eps). Dividing by rsqrt
        # would scale by the std instead of normalizing.
        normalized_weight = (weight - mean) * (var + eps).rsqrt()

        return F.conv2d(
            x,
            normalized_weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )


class Block(nn.Module):
    """Weight-standardized 3x3 conv -> GroupNorm -> optional scale/shift -> SiLU."""

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        out = self.norm(self.proj(x))
        if scale_shift is not None:
            scale, shift = scale_shift
            # (scale + 1) so that a zero scale leaves the normalized features unchanged.
            out = out * (scale + 1) + shift
        return self.act(out)


class ResnetBlock(nn.Module):
    """Two ``Block``s plus a residual path, optionally conditioned on a time embedding.

    https://arxiv.org/abs/1512.03385

    When ``time_emb_dim`` is given, the embedding passes through SiLU + Linear to
    ``2 * dim_out`` values, split into a per-channel scale and shift for the first block.
    The residual path is a 1x1 conv when ``dim != dim_out`` and identity otherwise.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        if time_emb_dim is not None:
            self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))
        else:
            self.mlp = None

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        if dim != dim_out:
            self.res_conv = nn.Conv2d(dim, dim_out, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x, time_emb=None):
        cond = None
        if self.mlp is not None and time_emb is not None:
            projected = rearrange(self.mlp(time_emb), "b c -> b c 1 1")
            cond = projected.chunk(2, dim=1)

        hidden = self.block1(x, scale_shift=cond)
        hidden = self.block2(hidden)
        return hidden + self.res_conv(x)


'''
Attention module
'''
class Attention(nn.Module):
    """Multi-head softmax self-attention over all spatial positions of a feature map.

    Maps ``(B, dim, H, W)`` to ``(B, dim, H, W)``. Queries, keys and values come from
    one bias-free 1x1 conv; ``heads`` heads of width ``dim_head`` are projected back
    to ``dim`` channels by another 1x1 conv.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head**-0.5
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, x):
        _, _, height, width = x.shape
        n_heads = self.heads
        q, k, v = [
            rearrange(t, "b (h c) x y -> b h c (x y)", h=n_heads)
            for t in self.to_qkv(x).chunk(3, dim=1)
        ]
        q = q * self.scale

        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        # Shift by the detached row max for numerical stability; softmax is invariant to it.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=height, y=width)
        return self.to_out(out)

class LinearAttention(nn.Module):
    """Multi-head attention whose cost is linear in the number of pixels.

    Keys (softmaxed over positions) are contracted with values into a per-head
    ``dim_head x dim_head`` context, which is then applied to the queries
    (softmaxed over the feature axis). The result is projected back to ``dim``
    channels by a 1x1 conv followed by single-group GroupNorm.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head**-0.5
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.GroupNorm(1, dim),
        )

    def forward(self, x):
        _, _, height, width = x.shape
        n_heads = self.heads
        q, k, v = [
            rearrange(t, "b (h c) x y -> b h c (x y)", h=n_heads)
            for t in self.to_qkv(x).chunk(3, dim=1)
        ]

        q = q.softmax(dim=-2)  # over the per-head feature axis
        k = k.softmax(dim=-1)  # over spatial positions
        q = q * self.scale

        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=n_heads, x=height, y=width)
        return self.to_out(out)

'''
Unet
'''
class Unet(nn.Module):
    """Conditional denoising U-Net on 10x10 maps with a parallel "pattern" pyramid.

    forward() concatenates the noisy input with:
    - a learned encoding of ``pattern``,
    - an aggregated (history, delta-age) map,
    - a start-age map and the target delta-age map.

    It then runs a U-Net whose encoder and decoder levels are each offset by a
    matching-resolution projection of ``pattern`` (``downs_pattern`` / ``ups_pattern``).

    Args:
        dim: base channel width; also the width of the sinusoidal embeddings.
        init_dim: channels after ``init_conv`` (defaults to ``dim``).
        out_dim: channels of the prediction (defaults to ``channels``).
        dim_mults: per-level channel multipliers of ``dim``.
        channels: channels of the noisy input ``x``.
        self_condition: if True, a self-conditioning tensor is concatenated to the input.
        resnet_block_groups: GroupNorm groups inside every ResnetBlock.
        n_frames: number of history frames in ``x_c``.

    NOTE(review): the age embeddings are viewed as 10x10 maps and concatenated with
    ``x``, so inputs must be 10x10 spatially. At that size only one Downsample
    (10 -> 5) is possible, so the default ``dim_mults=(1, 2, 4, 8)`` would fail —
    confirm the intended configuration.
    """

    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),  # multipliers for increasing channels
        channels=1,
        self_condition=False,
        resnet_block_groups=4,
        n_frames=3,
    ):
        super().__init__()
        self.n_frames = n_frames

        # determine dimensions
        pattern_enc_channels = 3
        # noisy input (channels) + pattern encoding (3) + aggregated history (1)
        # + start-age map (1) + target delta-age map (1)
        self.channels = channels + pattern_enc_channels + 3
        self.self_condition = self_condition
        input_channels = self.channels * (2 if self_condition else 1)

        init_dim = default(init_dim, dim)
        # 1x1 kernel, padding 0 (changed from kernel 7, padding 3)
        self.init_conv = nn.Conv2d(input_channels, init_dim, 1, padding=0)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]  # e.g. dim=20, dim_mults=(1, 2): [20, 20, 40]
        in_out = list(zip(dims[:-1], dims[1:]))  # e.g. [(20, 20), (20, 40)]

        block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        time_dim = dim * 4

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),  # B, dim_in, H, W
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),  # B, dim_in, H, W
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),  # B, dim_in, H, W
                        Downsample(dim_in, dim_out)  # B, dim_out, H/2, W/2
                        if not is_last
                        else nn.Conv2d(dim_in, dim_out, 3, padding=1),  # B, dim_out, H, W
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        # each block input is concatenated with a dim_in-channel encoder skip
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Upsample(dim_out, dim_in)
                        if not is_last
                        else nn.Conv2d(dim_out, dim_in, 3, padding=1),
                    ]
                )
            )

        self.out_dim = default(out_dim, channels)

        # The decoder ends with init_dim channels and is concatenated with the
        # init_dim-channel stem output `r`, so the final block takes init_dim * 2
        # channels (equal to dim * 2 when init_dim is left at its default).
        self.final_res_block = block_klass(init_dim * 2, dim, time_emb_dim=time_dim)
        self.final_conv = nn.Conv2d(dim, self.out_dim, 1)

        # Pattern pyramid: project `pattern` to init_dim channels, then mirror the
        # encoder/decoder resolution changes so it can be added at every level.
        self.init_pattern = nn.Conv2d(1, init_dim, 1, padding=0)
        in_out_pattern = list(zip(dims[:-1], dims[1:]))
        self.downs_pattern = nn.ModuleList([])
        self.ups_pattern = nn.ModuleList([])

        for ind, (dim_in, dim_out) in enumerate(in_out_pattern):
            is_last = ind >= (num_resolutions - 1)

            self.downs_pattern.append(
                Downsample(dim_in, dim_out)  # B, dim_out, H/2, W/2
                if not is_last
                else nn.Conv2d(dim_in, dim_out, 3, padding=1),
            )
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out_pattern)):
            is_last = ind == (len(in_out_pattern) - 1)

            self.ups_pattern.append(
                Upsample(dim_out, dim_in)
                if not is_last
                else nn.Conv2d(dim_out, dim_in, 3, padding=1),
            )

        # Sinusoidal embedding + MLP for the ages; the 100-d output is viewed as a
        # 10x10 map in forward(), so age_time_dim must stay 100.
        age_time_dim = 100
        self.start_age_mlp = nn.Sequential(  # start age: global age information
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, age_time_dim),
            nn.GELU(),
            nn.Linear(age_time_dim, age_time_dim),
        )

        self.delta_age_mlp = nn.Sequential(  # age deltas: relative positioning in time
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, age_time_dim),
            nn.GELU(),
            nn.Linear(age_time_dim, age_time_dim),
        )

        # history embedding: 1x1 conv mixing the 2 * n_frames interleaved
        # (history frame, delta-age) maps into one channel
        self.history_mix = nn.Conv2d(self.n_frames * 2, 1, 1)

        # pattern embedding
        self.pattern_enc = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 8, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(8, pattern_enc_channels, 3, padding=1),
            nn.SiLU()
        )

    def forward(self, x, x_c, start_age, age_deltas, pattern, time, x_self_cond=None):
        """Predict a ``(B, out_dim, 10, 10)`` map from the noisy input and conditioning.

        Args:
            x: noisy input, ``(B, channels, 10, 10)``.
            x_c: history frames, ``(B, n_frames, 10, 10)`` (interleaved with ``da[:, :-1]``).
            start_age: embedded by ``start_age_mlp``; the result is viewed as
                ``(B, 1, 10, 10)`` (presumably a ``(B, 1)`` tensor — confirm with caller).
            age_deltas: embedded by ``delta_age_mlp``; the result is viewed as
                ``(B, n_frames + 1, 10, 10)``. The first ``n_frames`` maps pair with the
                history frames and the last one is the prediction's delta age.
            pattern: ``(B, 1, 10, 10)`` pattern map.
            time: diffusion timesteps, ``(B,)``.
            x_self_cond: optional self-conditioning tensor shaped like the
                concatenated input; zeros when omitted (only used if ``self_condition``).
        """
        B = x.size(0)

        # start_age embedding -> B, 1, 10, 10
        sa = self.start_age_mlp(start_age).view(B, 1, 10, 10)

        # delta_age embedding -> B, n_frames + 1, 10, 10
        da = self.delta_age_mlp(age_deltas).view(B, -1, 10, 10)

        # History conditioning: interleave x_c and da (x_c[0], da[0], x_c[1], da[1], ...)
        # and aggregate them into one channel with a 1x1 conv.
        history = torch.stack((x_c, da[:, :-1, :, :]), dim=2).view(-1, 2 * self.n_frames, 10, 10)
        history_agg = self.history_mix(history)  # B, 1, 10, 10

        # pattern encoding
        penc = self.pattern_enc(pattern)  # B, 1, 10, 10 -> B, 3, 10, 10
        p = self.init_pattern(pattern)  # B, 1, 10, 10 -> B, init_dim, 10, 10

        # concatenate all inputs -> B, self.channels, 10, 10
        x = torch.cat((x, penc, history_agg, sa, da[:, -1, :, :].unsqueeze(1)), dim=1)

        if self.self_condition:
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x_self_cond, x), dim=1)

        x = self.init_conv(x)  # B, C, 10, 10 -> B, init_dim, 10, 10
        r = x.clone()

        t = self.time_mlp(time)  # B, dim*4

        h = []

        for i, (block1, block2, attn, downsample) in enumerate(self.downs):
            x = x + p
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)
            p = self.downs_pattern[i](p)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for i, (block1, block2, attn, upsample) in enumerate(self.ups):
            x = x + p
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim=1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)
            p = self.ups_pattern[i](p)

        x = x + p
        x = torch.cat((x, r), dim=1)  # B, init_dim * 2, 10, 10

        x = self.final_res_block(x, t)
        return self.final_conv(x)

class Unet_MoCap(nn.Module):
    """Denoising U-Net for motion-capture sequences, run on per-frame 10x10 maps.

    Each per-frame vector is lifted by a small MLP to 100 features and viewed as a
    10x10 map, one channel per frame. The per-frame inputs are noisy target frames,
    history frames, rotations and pattern frames. The maps are concatenated and fed
    to a U-Net whose levels are offset by a matching-resolution projection of the
    pattern maps. The H output channels are flattened back to 100 features per
    frame and decoded by ``pred_feat`` to 17*3+2 values per frame.

    Args:
        dim: base channel width; also the sinusoidal timestep-embedding width.
        init_dim: channels after ``init_conv`` (defaults to ``dim``).
        out_dim: number of predicted frames H (required). It sets the channel count
            of the noise/pattern maps and the output channels of ``final_conv``.
        dim_mults: per-level channel multipliers of ``dim``.
        channels: unused; kept for signature compatibility.
        self_condition: if True, a self-conditioning tensor is concatenated to the input.
        resnet_block_groups: GroupNorm groups inside every ResnetBlock.
        n_frames: number of history frames h in ``x_c`` and ``rot``.

    Raises:
        TypeError: if ``out_dim`` is None.

    NOTE(review): 17*3 is presumably 17 joints x 3 coordinates; the meaning of
    the 2 extra features in the noisy input/output is not visible here — confirm.
    """

    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),  # multipliers for increasing channels
        channels=1,
        self_condition=False,
        resnet_block_groups=4,
        n_frames=10,
    ):
        super().__init__()
        if out_dim is None:
            # out_dim (the horizon H) sizes several layers below; there is no usable default.
            raise TypeError("Unet_MoCap requires out_dim (the number of predicted frames H)")
        self.n_frames = n_frames
        self.n_horizon = out_dim

        # determine dimensions
        pattern_enc_channels = 3
        # noisy input (H) + history maps (h) + rotation maps (h) + pattern encoding
        self.channels = out_dim + 2 * n_frames + pattern_enc_channels
        self.self_condition = self_condition
        input_channels = self.channels * (2 if self_condition else 1)

        init_dim = default(init_dim, dim)
        # 1x1 kernel, padding 0 (changed from kernel 7, padding 3)
        self.init_conv = nn.Conv2d(input_channels, init_dim, 1, padding=0)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]  # e.g. dim=20, dim_mults=(1, 2): [20, 20, 40]
        in_out = list(zip(dims[:-1], dims[1:]))  # e.g. [(20, 20), (20, 40)]

        block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        time_dim = dim * 4

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),  # B, dim_in, H, W
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),  # B, dim_in, H, W
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),  # B, dim_in, H, W
                        Downsample(dim_in, dim_out)  # B, dim_out, H/2, W/2
                        if not is_last
                        else nn.Conv2d(dim_in, dim_out, 3, padding=1),  # B, dim_out, H, W
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        # each block input is concatenated with a dim_in-channel encoder skip
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Upsample(dim_out, dim_in)
                        if not is_last
                        else nn.Conv2d(dim_out, dim_in, 3, padding=1),
                    ]
                )
            )

        self.out_dim = out_dim

        # The decoder ends with init_dim channels and is concatenated with the
        # init_dim-channel stem output `r`, so the final block takes init_dim * 2
        # channels (equal to dim * 2 when init_dim is left at its default).
        self.final_res_block = block_klass(init_dim * 2, dim, time_emb_dim=time_dim)
        self.final_conv = nn.Conv2d(dim, self.out_dim, 1)

        # Pattern pyramid: project the H pattern maps to init_dim channels, then mirror
        # the encoder/decoder resolution changes so it can be added at every level.
        self.init_pattern = nn.Conv2d(self.n_horizon, init_dim, 1, padding=0)
        in_out_pattern = list(zip(dims[:-1], dims[1:]))
        self.downs_pattern = nn.ModuleList([])
        self.ups_pattern = nn.ModuleList([])

        for ind, (dim_in, dim_out) in enumerate(in_out_pattern):
            is_last = ind >= (num_resolutions - 1)

            self.downs_pattern.append(
                Downsample(dim_in, dim_out)  # B, dim_out, H/2, W/2
                if not is_last
                else nn.Conv2d(dim_in, dim_out, 3, padding=1),
            )
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out_pattern)):
            is_last = ind == (len(in_out_pattern) - 1)

            self.ups_pattern.append(
                Upsample(dim_out, dim_in)
                if not is_last
                else nn.Conv2d(dim_out, dim_in, 3, padding=1),
            )

        # Per-frame feature MLPs. Their 100-d outputs are viewed as 10x10 maps in
        # forward(), so the 100 must stay in sync with that view.
        # noise encoding
        self.noise_feat = nn.Sequential(
            nn.Linear(17 * 3 + 2, 64),
            nn.GELU(),
            nn.Linear(64, 100),  # TODO: de-hardcode this output dim
        )

        # history embedding
        self.history_feat = nn.Sequential(
            nn.Linear(17 * 3, 64),
            nn.GELU(),
            nn.Linear(64, 100),  # TODO: de-hardcode this output dim
        )

        # rotation embedding
        self.rotation_feat = nn.Sequential(
            nn.Linear(2, 16),
            nn.GELU(),
            nn.Linear(16, 64),
            nn.GELU(),
            nn.Linear(64, 100),  # TODO: de-hardcode this output dim
        )

        # pattern embedding
        self.pattern_feat = nn.Sequential(
            nn.Linear(17 * 3, 64),
            nn.GELU(),
            nn.Linear(64, 100),  # TODO: de-hardcode this output dim
        )
        # NOTE: not used by forward(); left in place because removing it would
        # change the state_dict keys.
        self.pattern_mix = nn.Conv2d(self.n_horizon, 1, 1)
        self.pattern_enc = nn.Sequential(
            nn.Conv2d(self.n_horizon, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 8, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(8, pattern_enc_channels, 3, padding=1),
            nn.SiLU()
        )

        # final prediction layer: 100 features per frame -> 17*3+2 values per frame
        self.pred_feat = nn.Sequential(
            nn.Linear(100, 64),
            nn.GELU(),
            nn.Linear(64, (17 * 3 + 2)),  # TODO: de-hardcode this output dim
        )

    def forward(self, x, x_c, rot, pattern, time, x_self_cond=None):
        """Predict a ``(B, H, 17*3+2)`` tensor from the noisy input and conditioning.

        Args:
            x: noisy input, viewed as ``(B, H, 17*3+2)``.
            x_c: history, viewed as ``(B, h, 17*3)``.
            rot: rotation features, ``(B, h, 2)``.
            pattern: pattern frames, viewed as ``(B, H, 17*3)``.
            time: diffusion timesteps, ``(B,)``.
            x_self_cond: optional self-conditioning tensor shaped like the
                concatenated input; zeros when omitted (only used if ``self_condition``).
        """
        B = x.size(0)

        # noise encoding: B x H x (17*3+2) -> B x H x 100 -> B x H x 10 x 10
        x = x.view(B, self.n_horizon, -1)
        x_reshaped = self.noise_feat(x)
        x_reshaped = x_reshaped.view(B, self.n_horizon, 10, 10)

        # history conditioning: B x h x 17*3 -> B x h x 100 -> B x h x 10 x 10
        history = x_c.view(B, self.n_frames, -1)
        history_reshaped = self.history_feat(history)
        history_reshaped = history_reshaped.view(B, self.n_frames, 10, 10)

        # rotation information: B x h x 2 -> B x h x 100 -> B x h x 10 x 10
        rotation_reshaped = self.rotation_feat(rot)
        rotation_reshaped = rotation_reshaped.view(B, self.n_frames, 10, 10)

        # pattern encoding: B x H x 17*3 -> B x H x 100 -> B x H x 10 x 10
        patt = pattern.view(B, self.n_horizon, -1)
        pattern_reshaped = self.pattern_feat(patt).view(B, self.n_horizon, 10, 10)

        penc = self.pattern_enc(pattern_reshaped)  # B, H, 10, 10 -> B, 3, 10, 10
        p = self.init_pattern(pattern_reshaped)  # B, H, 10, 10 -> B, init_dim, 10, 10

        # concatenate all inputs -> B, H + 3 + h + h, 10, 10
        x = torch.cat((x_reshaped, penc, history_reshaped, rotation_reshaped), dim=1)

        if self.self_condition:
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x_self_cond, x), dim=1)

        x = self.init_conv(x)  # B, C, 10, 10 -> B, init_dim, 10, 10
        r = x.clone()

        t = self.time_mlp(time)  # B, dim*4

        h = []

        for i, (block1, block2, attn, downsample) in enumerate(self.downs):
            x = x + p
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)
            p = self.downs_pattern[i](p)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for i, (block1, block2, attn, upsample) in enumerate(self.ups):
            x = x + p
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim=1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)
            p = self.ups_pattern[i](p)

        x = x + p
        x = torch.cat((x, r), dim=1)  # B, init_dim * 2, 10, 10

        x = self.final_res_block(x, t)
        x = self.final_conv(x).view(B, self.n_horizon, -1)  # B x H x 10 x 10 -> B x H x 100
        x = self.pred_feat(x)  # B x H x (17*3+2)
        return x