from typing import Any, Optional, Tuple
import math
from functools import partial
import jax
import jax.numpy as jnp
import flax.linen as nn

# Short alias for Flax's zero initializer; passed as kernel_init / bias_init to the
# layers below that should start out at zero.
zeros = nn.initializers.zeros


# TODO(review): original note read "vmap rng bug". `jax.vmap` is applied directly to
# Flax modules below (per-frame convs, ResBlocks, resampling); presumably the RNG is not
# split per frame, so e.g. dropout masks would be shared across frames. A Flax-aware
# `nn.vmap` with `split_rngs` is the likely fix -- verify before relying on dropout.
class UNet(nn.Module):
    """Video U-Net conditioned on diffusion timesteps and per-frame actions.

    Per-frame layers (convolutions, ResBlocks, up/downsampling) are mapped over the
    time axis (axis 1) with ``jax.vmap``; FactorizedAttentionBlock mixes information
    across space and time.

    Attributes:
        model_channels: base channel width; level ``i`` uses
            ``channel_mult[i] * model_channels`` channels.
        out_channels: number of output channels.
        num_res_blocks: ResBlocks per encoder level (the decoder uses one more per
            level). A single int is also accepted and broadcast to every level.
        attention_resolutions: spatial sizes (compared against ``H``) at which
            attention blocks are inserted.
        action_dim: number of discrete actions; the embedding table has
            ``action_dim + 1`` rows (presumably one extra id for a null/padding
            action -- confirm with the data pipeline).
        action_embed_dim: width of the action embedding.
        dropout: dropout rate used inside the ResBlocks.
        channel_mult: per-level channel multipliers.
        conv_resample: use learned convolutions for up/downsampling.
        num_head_dim: channels per attention head.
        use_scale_shift_norm: conditioning mode forwarded to the ResBlocks.
    """
    model_channels: int
    out_channels: int
    num_res_blocks: Tuple[int, ...]  # a plain int is also accepted (see __call__)
    attention_resolutions: Tuple[int, ...]
    action_dim: int
    action_embed_dim: int
    dropout: float = 0
    channel_mult: Tuple[int, ...] = (1, 2, 4, 8)
    conv_resample: bool = True
    num_head_dim: int = 64
    use_scale_shift_norm: bool = True

    @nn.compact
    def __call__(self, x, timesteps, actions, deterministic=False, idxs=None):
        """Run the U-Net on a batch of videos.

        Args:
            x: video batch, BTHWC.
            timesteps: diffusion timesteps, one per batch element (B,).
            actions: integer action ids, looked up in an ``nn.Embed`` and mapped over
                axis 1 together with ``x`` (presumably shape (B, T) -- confirm).
            deterministic: passed to Dropout; True disables dropout.
            idxs: optional index into the time axis. When given, only the selected
                frames are returned. Frames are sub-selected as soon as no remaining
                decoder level applies attention, because from then on every layer is
                per-frame.

        Returns:
            Array (B, T', H', W', out_channels), where T' is the number of frames
            selected by ``idxs`` (or T when ``idxs`` is None). H', W' equal H, W
            assuming both are divisible by ``2 ** (len(channel_mult) - 1)``.
        """
        # Express attention resolutions as downsampling factors so they can be
        # compared with the running factor `ds` below.
        attention_resolutions = [x.shape[2] // ar for ar in self.attention_resolutions]
        num_res_blocks = self.num_res_blocks
        if isinstance(num_res_blocks, int):
            num_res_blocks = [num_res_blocks] * len(self.channel_mult)

        def resblock(**kwargs):
            # Per-frame ResBlock: x and actions are mapped over time (axis 1); the
            # timestep embedding and `deterministic` flag are shared by all frames.
            block = ResBlock(
                dropout=self.dropout,
                use_scale_shift_norm=self.use_scale_shift_norm,
                **kwargs,
            )
            return jax.vmap(block, (1, None, 1, None), 1)

        attn = partial(FactorizedAttentionBlock, self.num_head_dim)

        emb = nn.Sequential([
            nn.Dense(self.model_channels * 4),
            nn.swish,
            nn.Dense(self.model_channels * 4),
        ])(timestep_embedding(timesteps, self.model_channels))

        actions = nn.Embed(self.action_dim + 1, self.action_embed_dim)(actions)

        ch = self.model_channels

        # Encoder: every intermediate activation is kept in `hs` for the skips.
        ds = 1
        h = jax.vmap(nn.Conv(ch, [3, 3]), 1, 1)(x)
        hs = [h]
        for level, mult in enumerate(self.channel_mult):
            for _ in range(num_res_blocks[level]):
                h = resblock(out_channels=mult * ch)(h, emb, actions, deterministic)
                if ds in attention_resolutions:
                    h = attn()(h)
                hs.append(h)
            if level != len(self.channel_mult) - 1:
                h = jax.vmap(Downsample(self.conv_resample), 1, 1)(h)
                hs.append(h)
                ds *= 2

        # Middle
        h = resblock()(h, emb, actions, deterministic)
        h = attn()(h)
        h = resblock()(h, emb, actions, deterministic)

        # Decoder. `ds` only shrinks from here on, so once it is below the smallest
        # attention factor no later layer mixes frames and `idxs` can be applied.
        # `default=inf` keeps an empty `attention_resolutions` from crashing `min`.
        min_attn_ds = min(attention_resolutions, default=math.inf)
        applied_idxs = False
        for level, mult in reversed(list(enumerate(self.channel_mult))):
            if idxs is not None and not applied_idxs and ds < min_attn_ds:
                hs = [h_i[:, idxs] for h_i in hs]
                h = h[:, idxs]
                actions = actions[:, idxs]
                applied_idxs = True

            for i in range(num_res_blocks[level] + 1):
                h = jnp.concatenate([h, hs.pop()], axis=-1)
                h = resblock(out_channels=mult * ch)(h, emb, actions, deterministic)
                if ds in attention_resolutions:
                    h = attn()(h)
                if level and i == num_res_blocks[level]:
                    h = jax.vmap(Upsample(self.conv_resample), 1, 1)(h)
                    ds //= 2

        # If attention ran all the way up (e.g. at full resolution) the selection
        # above never fired; apply it now so the output always has the requested
        # frames. The output head below is per-frame, so this is equivalent.
        if idxs is not None and not applied_idxs:
            h = h[:, idxs]

        out = jax.vmap(nn.Sequential([
            nn.GroupNorm(),
            nn.swish,
            nn.Conv(self.out_channels, [3, 3], kernel_init=zeros, bias_init=zeros),
        ]), 1, 1)(h)

        return out

        
class Upsample(nn.Module):
    """Double the two spatial axes (-3 and -2) with nearest-neighbour resizing.

    Attributes:
        use_conv: when True, follow the resize with a same-width convolution whose
            kernel is at most 3x3 (smaller when the map is narrower than 3).
    """
    use_conv: bool

    @nn.compact
    def __call__(self, x):
        # x: (..., H, W, C); callers in this file pass BHWC.
        *lead, height, width, channels = x.shape
        upsampled = jax.image.resize(
            x, (*lead, height * 2, width * 2, channels), 'nearest'
        )
        if not self.use_conv:
            return upsampled
        kernel = min(3, upsampled.shape[-2])
        return nn.Conv(channels, [kernel, kernel])(upsampled)

        
class Downsample(nn.Module):
    """Reduce the spatial resolution with stride 2 along H and W.

    Attributes:
        use_conv: when True use a learned stride-2 convolution (kernel at most
            3x3, channel count unchanged); otherwise use 2x2 average pooling.
    """
    use_conv: bool

    @nn.compact
    def __call__(self, x):
        # x: BHWC
        if not self.use_conv:
            return nn.avg_pool(x, (2, 2), strides=(2, 2))
        kernel = min(3, x.shape[-2])
        conv = nn.Conv(x.shape[-1], (kernel, kernel), strides=(2, 2))
        return conv(x)
            

class ResBlock(nn.Module):
    """Residual convolutional block conditioned on a timestep and an action embedding.

    Attributes:
        dropout: dropout rate applied before the second convolution.
        out_channels: output width; the input width is used when None.
        use_conv: when the width changes, project the skip path with a kxk
            convolution instead of a per-position Dense layer.
        use_scale_shift_norm: when True the conditioning is split into a
            (scale, shift) pair applied after the second GroupNorm; otherwise it is
            added to the features before that GroupNorm.
    """
    dropout: float
    out_channels: Optional[int] = None
    use_conv: bool = False
    use_scale_shift_norm: bool = False

    @nn.compact
    def __call__(self, x, emb, actions, deterministic):
        # x: NHWC, emb: ND, actions: ND
        width = self.out_channels or x.shape[-1]
        kernel = min(3, x.shape[-2])  # shrink the kernel for very narrow maps
        cond_width = 2 * width if self.use_scale_shift_norm else width

        h = nn.Sequential([nn.GroupNorm(), nn.swish, nn.Conv(width, [kernel, kernel])])(x)

        # Project the timestep and action conditioning separately, then sum them.
        cond = nn.Sequential([nn.swish, nn.Dense(cond_width)])(emb)
        cond = cond + nn.Sequential([nn.swish, nn.Dense(cond_width)])(actions)

        # Insert singleton axes just before the channel axis so `cond` broadcasts
        # over every spatial position of `h`.
        missing = len(h.shape) - len(cond.shape)
        cond = cond.reshape(*cond.shape[:-1], *([1] * missing), cond.shape[-1])

        if self.use_scale_shift_norm:
            scale, shift = jnp.split(cond, 2, axis=-1)
            h = nn.GroupNorm()(h) * (1 + scale) + shift
            tail = [nn.swish]
        else:
            h = h + cond
            tail = [nn.GroupNorm(), nn.swish]
        # Final conv is zero-initialised, so the residual branch starts at zero.
        tail += [
            nn.Dropout(self.dropout, deterministic=deterministic),
            nn.Conv(width, [kernel, kernel], kernel_init=zeros, bias_init=zeros),
        ]
        h = nn.Sequential(tail)(h)

        skip = x
        if width != x.shape[-1]:
            projection = nn.Conv(width, [kernel, kernel]) if self.use_conv else nn.Dense(width)
            skip = projection(x)
        return skip + h

        
class FactorizedAttentionBlock(nn.Module):
    """Residual spatial attention followed by residual temporal attention.

    Attributes:
        num_head_dim: channels per attention head, forwarded to RelativeAttention.
    """
    # Previously annotated `bool`; the value is a per-head channel count.
    num_head_dim: int

    @nn.compact
    def __call__(self, x):
        # x: BTHWD
        # Spatial attention: map over time (axis 1); each frame's H*W positions
        # attend to one another.
        h = jax.vmap(RelativeAttention(self.num_head_dim), 1, 1)(x)
        x = x + h

        # Temporal attention: map over H (axis 2 of BTHWD), then over W (axis 2 of
        # the remaining BTWD), so each spatial location attends over the T axis.
        h = jax.vmap(jax.vmap(RelativeAttention(self.num_head_dim), 2, 2), 2, 2)(x)
        x = x + h

        return x

        
class RelativePosition(nn.Module):
    """Learned embeddings of clipped relative offsets between key and query positions.

    Attributes:
        num_units: embedding width.
        k: largest |offset| that gets its own embedding; larger offsets are
            clipped to +/-k. When None, ``max(q_len, k_len)`` is used.
    """
    num_units: int
    k: Optional[int] = None

    @nn.compact
    def __call__(self, q_len, k_len):
        """Return a (q_len, k_len, num_units) table; entry [i, j] embeds offset j - i."""
        # `is None` instead of `or`, so an explicit k=0 is honoured rather than being
        # silently replaced by the default.
        k = self.k if self.k is not None else max(q_len, k_len)

        range_vec_q = jnp.arange(q_len, dtype=jnp.int32)
        range_vec_k = jnp.arange(k_len, dtype=jnp.int32)
        distance_mat = range_vec_k[None, :] - range_vec_q[:, None]  # (q_len, k_len)
        # Bounds passed positionally: the `a_min` / `a_max` keywords are deprecated
        # in newer JAX releases, while positional arguments work in old and new ones.
        distance_mat_clipped = jnp.clip(distance_mat, -k, k)
        final_mat = distance_mat_clipped + k  # shift into [0, 2k] for the lookup
        embeddings = nn.Embed(2 * k + 1, self.num_units)(final_mat)
        return embeddings


class RelativeAttention(nn.Module):
    """Multi-head self-attention with learned relative-position terms.

    All axes between the first (batch) and last (channels) are flattened into a
    single sequence, attended over, and restored on output. Both the attention
    logits and the attended values receive a relative-position contribution.

    Attributes:
        num_head_dim: channels per head; the head count is ``C // num_head_dim``.
    """
    num_head_dim: int

    @nn.compact
    def __call__(self, x):
        # x: B...D
        batch, *inner_shape, channels = x.shape
        head_dim = self.num_head_dim
        n_heads = channels // head_dim
        # 1/sqrt(head_dim) split evenly across both operands of each logit product.
        qk_scale = 1 / math.sqrt(math.sqrt(head_dim))

        tokens = nn.GroupNorm()(x.reshape(batch, -1, channels))
        projected = nn.DenseGeneral([n_heads, 3 * head_dim])(tokens)
        q, k, v = jnp.split(projected, 3, axis=-1)
        q_scaled = q * qk_scale
        seq_q, seq_k = q.shape[1], k.shape[1]

        # Content logits plus relative-position logits, both laid out as (b, q, k, h).
        rel_k = RelativePosition(head_dim)(seq_q, seq_k)
        logits = (
            jnp.einsum('bqhd,bkhd->bqkh', q_scaled, k * qk_scale)
            + jnp.einsum('bqhd,qkd->bqkh', q_scaled, rel_k * qk_scale)
        )
        # Softmax over the key axis in float32, then cast back.
        probs = jax.nn.softmax(logits.astype(jnp.float32), axis=2).astype(logits.dtype)

        rel_v = RelativePosition(head_dim)(seq_q, v.shape[1])
        mixed = (
            jnp.einsum('bqkh,bkhd->bqhd', probs, v)
            + jnp.einsum('bqkh,qkd->bqhd', probs, rel_v)
        )

        # Merge (heads, head_dim) back to `channels` with a zero-initialised kernel.
        merged = nn.DenseGeneral(channels, axis=(-2, -1), kernel_init=zeros)(mixed)
        return merged.reshape(batch, *inner_shape, channels)


def timestep_embedding(timesteps, dim, max_period=10000):
    """Sinusoidal embedding of diffusion timesteps.

    Args:
        timesteps: array-like of timesteps of any shape (the UNet passes (B,)).
        dim: output width. The first ``dim // 2`` features are cosines and the
            next ``dim // 2`` are sines, at frequencies
            ``exp(-log(max_period) * i / (dim // 2))`` for ``i = 0 .. dim//2 - 1``.
            An odd ``dim`` gets one trailing zero feature.
        max_period: sets the lowest frequency.

    Returns:
        float32 array of shape ``timesteps.shape + (dim,)``.
    """
    half = dim // 2
    freqs = jnp.exp(
        -jnp.log(max_period) * jnp.arange(0, half, dtype=jnp.float32) / half
    )
    # `[..., None]` instead of `[:, None]` so scalar and N-D timesteps also work;
    # for 1-D input the result is identical.
    args = jnp.asarray(timesteps, dtype=jnp.float32)[..., None] * freqs
    embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)
    if dim % 2 == 1:
        # Build the pad column from the shape rather than slicing `embedding`: for
        # dim == 1 the embedding is empty and a slice would add no column at all.
        pad = jnp.zeros((*embedding.shape[:-1], 1), dtype=embedding.dtype)
        embedding = jnp.concatenate([embedding, pad], axis=-1)
    return embedding
