from typing import Tuple, Optional, Any
import math
from functools import partial
import jax
import jax.numpy as jnp
import flax.linen as nn
from ..utils import f_psum, g_psum, create_g_all_gather
from .. import sharding
from .. import transformer


# Shorthand for Flax's all-zeros initializer; used in this file for the
# zero-initialised output convolutions / projections.
zeros = nn.initializers.zeros


# Gather helpers built once at import time for axis 3 and axis 4 respectively.
# Every use in this file is followed by a reshape that merges the last two
# axes -- presumably the gather stacks per-shard outputs along a new axis at
# the given index; confirm against `create_g_all_gather` in `..utils`.
g_all_gather_3 = create_g_all_gather(axis=3)
g_all_gather_4 = create_g_all_gather(axis=4)


class UNetShard(nn.Module):
    """U-Net over video-shaped inputs (B, T, H, W, C) with sharded layers.

    Per-frame 2-D layers (convolutions, `ResBlock`, `Downsample`, `Upsample`)
    are applied by `jax.vmap`-ing over axis 1 of the activations, while
    `FactorizedAttentionBlock` receives the full 5-D tensor. Layers whose
    feature width is divided by `num_shards` are bracketed by the project's
    `f_psum` / `g_psum` / all-gather helpers (their exact semantics live in
    `..utils`).

    `model_spec` walks the same encoder / middle / decoder structure as
    `__call__` and must stay in lock-step with it: its keys are the
    auto-generated Flax submodule names (`ResBlock_0`, `Conv_1`, ...).
    """
    # Number of model-parallel shards; sharded feature widths are divided by it.
    num_shards: int
    # Base channel width; level `i` uses `channel_mult[i] * model_channels`.
    model_channels: int
    # Channels produced by the final (zero-initialised) convolution.
    out_channels: int
    # Residual blocks per level. A plain int is also accepted and is
    # broadcast to every level.
    num_res_blocks: Tuple[int, ...]
    # Resolutions at which attention is used; converted to downsampling
    # factors via `H // resolution`.
    attention_resolutions: Tuple[int, ...]
    # The action embedding table has `action_dim + 1` rows.
    action_dim: int
    action_embed_dim: int
    dropout: float = 0
    channel_mult: Tuple[int, ...] = (1, 2, 4, 8)
    # True: strided / post-upsample convolutions; False: parameter-free resampling.
    conv_resample: bool = True
    num_head_dim: int = 64
    use_scale_shift_norm: bool = True # not used

    @nn.compact
    def __call__(self, x, timesteps, actions, deterministic=False):
        """Run the U-Net.

        Args:
            x: input of shape BTHWC.
            timesteps: shape B; embedded with `timestep_embedding`.
            actions: integer action ids fed to an `nn.Embed`; the result is
                mapped over axis 1 together with the frames.
            deterministic: forwarded to the dropout layers of every ResBlock.

        Returns:
            Array with the same leading axes as `x` and `out_channels` channels
            (spatial size is restored by the decoder).
        """
        # x: BTHWC, timesteps: B
        # Checked up front: every sharded width below silently floors otherwise.
        assert self.model_channels % self.num_shards == 0, (
            'model_channels must be divisible by num_shards'
        )

        # Attention is requested by resolution; compare against the running
        # downsampling factor `ds` instead.
        attention_resolutions = [x.shape[2] // ar for ar in self.attention_resolutions]
        num_res_blocks = self.num_res_blocks
        if isinstance(num_res_blocks, int):
            num_res_blocks = [num_res_blocks] * len(self.channel_mult)

        # TODO jax.vmap -> flax.vmap
        # in_axes: activations and actions are mapped over axis 1, the timestep
        # embedding and the `deterministic` flag are shared.
        resblock = lambda **kwargs: jax.vmap(
            ResBlock(num_shards=self.num_shards, dropout=self.dropout, **kwargs), (1, None, 1, None), 1
        )
        attn = partial(FactorizedAttentionBlock, self.num_shards, self.num_head_dim)

        # Timestep MLP: first Dense has a per-shard output width, second Dense
        # is bias-free; the bias is added once after `g_psum` via AddBias.
        emb = timestep_embedding(timesteps, self.model_channels)
        emb = f_psum(emb)
        emb = nn.Sequential([
            nn.Dense(self.model_channels * 4 // self.num_shards),
            nn.swish,
            nn.Dense(self.model_channels * 4, use_bias=False,
                     kernel_init=nn.initializers.variance_scaling(
                         1.0 / self.num_shards, 'fan_in', 'normal'
                     ))
        ])(emb)
        emb = g_psum(emb)
        emb = transformer.AddBias()(emb)

        actions = nn.Embed(self.action_dim + 1, self.action_embed_dim)(actions)

        ch = self.model_channels

        # Encoder
        ds = 1
        x = f_psum(x)
        h = jax.vmap(nn.Conv(ch // self.num_shards, [3, 3]), 1, 1)(x)
        h = g_all_gather_4(h)
        h = h.reshape(*h.shape[:-2], -1)  # merge the last two axes into channels

        # Skip connections, consumed in reverse order by the decoder.
        hs = [h]
        for level, mult in enumerate(self.channel_mult):
            for _ in range(num_res_blocks[level]):
                h = resblock(out_channels=mult * ch)(h, emb, actions, deterministic)
                if ds in attention_resolutions:
                    h = attn()(h)
                hs.append(h)
            if level != len(self.channel_mult) - 1:
                h = jax.vmap(Downsample(self.num_shards, self.conv_resample), 1, 1)(h)
                hs.append(h)
                ds *= 2

        # Middle
        h = resblock()(h, emb, actions, deterministic)
        h = attn()(h)
        h = resblock()(h, emb, actions, deterministic)

        # Decoder
        for level, mult in list(enumerate(self.channel_mult))[::-1]:
            # One extra block per level consumes the skip pushed by the
            # stem / downsample step.
            for i in range(num_res_blocks[level] + 1):
                h = jnp.concatenate([h, hs.pop()], axis=-1)
                h = resblock(out_channels=mult * ch)(h, emb, actions, deterministic)
                if ds in attention_resolutions:
                    h = attn()(h)
                # Upsample after the last block of every level except level 0.
                if level and i == num_res_blocks[level]:
                    h = jax.vmap(Upsample(self.num_shards, self.conv_resample), 1, 1)(h)
                    ds //= 2

        out = jax.vmap(nn.Sequential([
            nn.GroupNorm(),
            nn.swish,
            nn.Conv(self.out_channels, [3, 3], kernel_init=zeros, bias_init=zeros)
        ]), 1, 1)(h)

        return out

    @staticmethod
    def model_spec(image_size, model_channels, num_res_blocks, attention_resolutions,
                   channel_mult, conv_resample, **kwargs):
        """Build the sharding spec matching the submodules created by `__call__`.

        Args:
            image_size: spatial size used to turn `attention_resolutions` into
                downsampling factors (same conversion as in `__call__`).
            model_channels: base channel width.
            num_res_blocks: int or per-level sequence of residual block counts.
            attention_resolutions: resolutions at which attention is used.
            channel_mult: per-level channel multipliers.
            conv_resample: whether Downsample / Upsample own a convolution.
            **kwargs: accepted and ignored, so a full config dict can be passed.

        Returns:
            `sharding.GenericDict` keyed by Flax auto-generated submodule names.
        """
        attention_resolutions = [image_size // ar for ar in attention_resolutions]
        if isinstance(num_res_blocks, int):
            num_res_blocks = [num_res_blocks] * len(channel_mult)

        spec = dict()
        spec.update({
            'Dense_0': sharding.Dense(use_bias=True, axis=1),
            'Dense_1': sharding.Dense(use_bias=False, axis=0),
            'AddBias_0': sharding.GenericReplicated(reduce_mode='identity'),
            'Embed_0': sharding.GenericReplicated(reduce_mode='identity'),
            'Conv_0': sharding.Conv(use_bias=True, axis=1)
        })

        ds = 1
        # `hs` mirrors the skip stack of `__call__`, storing channel counts only.
        hs = [model_channels]
        prev_ch = model_channels
        res_id, attn_id, ds_id, us_id = 0, 0, 0, 0
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks[level]):
                ch = mult * model_channels
                # A ResBlock owns a skip projection only when the width changes.
                spec[f'ResBlock_{res_id}'] = ResBlock.model_spec(has_skip=prev_ch != ch, use_conv=False)
                res_id += 1
                if ds in attention_resolutions:
                    spec[f'FactorizedAttentionBlock_{attn_id}'] = FactorizedAttentionBlock.model_spec()
                    attn_id += 1
                prev_ch = ch
                hs.append(ch)
            if level != len(channel_mult) - 1:
                # Without conv_resample the Downsample module has no parameters,
                # but it still consumes an auto-name index.
                if conv_resample:
                    spec[f'Downsample_{ds_id}'] = Downsample.model_spec()
                ds_id += 1
                ds *= 2
                # Downsampling keeps the channel count, so push the running
                # width. (Using the inner-loop `ch` here would be unbound when
                # a level has zero residual blocks.)
                hs.append(prev_ch)

        # Middle
        spec.update({
            f'ResBlock_{res_id}': ResBlock.model_spec(has_skip=False, use_conv=False),
            f'FactorizedAttentionBlock_{attn_id}': FactorizedAttentionBlock.model_spec(),
            f'ResBlock_{res_id + 1}': ResBlock.model_spec(has_skip=False, use_conv=False),
        })
        res_id += 2
        attn_id += 1

        # Decoder
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks[level] + 1):
                # Input width = current width + width of the concatenated skip.
                prev_ch += hs.pop()
                ch = mult * model_channels
                spec[f'ResBlock_{res_id}'] = ResBlock.model_spec(has_skip=prev_ch != ch, use_conv=False)
                res_id += 1
                prev_ch = ch
                if ds in attention_resolutions:
                    spec[f'FactorizedAttentionBlock_{attn_id}'] = FactorizedAttentionBlock.model_spec()
                    attn_id += 1
                if level and i == num_res_blocks[level]:
                    if conv_resample:
                        spec[f'Upsample_{us_id}'] = Upsample.model_spec()
                    us_id += 1
                    ds //= 2

        spec.update({
            'GroupNorm_0': sharding.GenericReplicated(reduce_mode='identity'),
            'Conv_1': sharding.GenericReplicated(reduce_mode='identity')
        })

        return sharding.GenericDict(spec)

                
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling of the H and W axes, optionally
    followed by a convolution whose output channels are split across shards."""
    num_shards: int
    use_conv: bool

    @nn.compact
    def __call__(self, x):
        # x: BHWC
        *leading, height, width, channels = x.shape
        doubled_shape = (*leading, height * 2, width * 2, channels)
        x = jax.image.resize(x, doubled_shape, 'nearest')
        if not self.use_conv:
            return x

        # Kernel size is capped by the (already upsampled) second-to-last axis.
        ksize = min(3, x.shape[-2])
        x = f_psum(x)
        conv = nn.Conv(channels // self.num_shards, [ksize, ksize])
        gathered = g_all_gather_3(conv(x))
        # Fold the last two axes back into a single channel axis.
        return gathered.reshape(*gathered.shape[:-2], -1)

    @staticmethod
    def model_spec():
        """Sharding spec for the optional convolution (`Conv_0`)."""
        conv_spec = sharding.Conv(axis=1, use_bias=True)
        return sharding.GenericDict({'Conv_0': conv_spec})

                
class Downsample(nn.Module):
    """2x spatial downsampling: a stride-2 convolution with per-shard output
    channels when `use_conv` is set, otherwise 2x2 average pooling."""
    num_shards: int
    use_conv: bool

    @nn.compact
    def __call__(self, x):
        if not self.use_conv:
            # Parameter-free path.
            return nn.avg_pool(x, (2, 2), strides=(2, 2))

        ksize = min(3, x.shape[-2])
        reduced = f_psum(x)
        features_per_shard = reduced.shape[-1] // self.num_shards
        conv = nn.Conv(features_per_shard, [ksize, ksize], strides=(2, 2))
        gathered = g_all_gather_3(conv(reduced))
        # Fold the last two axes back into a single channel axis.
        return gathered.reshape(*gathered.shape[:-2], -1)

    @staticmethod
    def model_spec():
        """Sharding spec for the strided convolution (`Conv_0`)."""
        conv_spec = sharding.Conv(axis=1, use_bias=True)
        return sharding.GenericDict({'Conv_0': conv_spec})


class ResBlock(nn.Module):
    """Residual block conditioned on a timestep embedding and an action
    embedding through a learned scale / shift applied after the second
    GroupNorm. Inner feature widths are `out_channels // num_shards`."""
    num_shards: int
    dropout: float
    out_channels: Optional[int] = None
    use_conv: bool = False

    @nn.compact
    def __call__(self, x, emb, actions, deterministic):
        # x: NHWC, emb: ND, actions: ND
        ksize = min(3, x.shape[-2])
        out_channels = self.out_channels or x.shape[-1]
        assert out_channels % self.num_shards == 0
        shard_channels = out_channels // self.num_shards
        residual = x

        def make_projection():
            # swish followed by a Dense with per-shard output width
            return nn.Sequential([
                nn.swish,
                nn.Dense(shard_channels)
            ])

        # Input stage: norm -> swish -> conv with per-shard output channels.
        x = f_psum(x)
        in_layers = nn.Sequential([
            nn.GroupNorm(),
            nn.swish,
            nn.Conv(shard_channels, [ksize, ksize])
        ])
        h = in_layers(x)

        # Four independent projections, created in this order:
        # scale(emb), shift(emb), scale(actions), shift(actions).
        emb = f_psum(emb)
        scale = make_projection()(emb)
        shift = make_projection()(emb)

        actions = f_psum(actions)
        scale = scale + make_projection()(actions)
        shift = shift + make_projection()(actions)

        # Insert singleton axes before the channel axis until the modulation
        # terms have the same rank as the feature map.
        for _ in range(len(h.shape) - len(scale.shape)):
            scale = scale[..., None, :]
            shift = shift[..., None, :]

        normed = nn.GroupNorm(num_groups=32 // self.num_shards)(h)
        h = normed * (1 + scale) + shift

        # Output stage: zero-initialised, bias-free conv; the bias is added
        # once after `g_psum`.
        out_layers = nn.Sequential([
            nn.swish,
            nn.Dropout(self.dropout, deterministic=deterministic),
            nn.Conv(out_channels, [ksize, ksize], kernel_init=zeros, use_bias=False)
        ])
        h = out_layers(h)
        h = g_psum(h)
        h = transformer.AddBias()(h)

        # Project the residual only when the channel count changes.
        if out_channels != x.shape[-1]:
            residual = f_psum(residual)
            if self.use_conv:
                projection = nn.Conv(shard_channels, [ksize, ksize])
            else:
                projection = nn.Dense(shard_channels)
            residual = g_all_gather_3(projection(residual))
            residual = residual.reshape(*residual.shape[:-2], -1)
        return residual + h

    @staticmethod
    def model_spec(has_skip, use_conv):
        """Sharding spec for one ResBlock.

        `has_skip` tells whether the block owns a residual projection, and
        `use_conv` whether that projection is a convolution or a Dense layer.
        """
        spec = {
            'GroupNorm_0': sharding.GenericReplicated(reduce_mode='sum'),
            'Conv_0': sharding.Conv(use_bias=True, axis=1),
            'Dense_0': sharding.Dense(use_bias=True, axis=1),
            'Dense_1': sharding.Dense(use_bias=True, axis=1),
            'Dense_2': sharding.Dense(use_bias=True, axis=1),
            'Dense_3': sharding.Dense(use_bias=True, axis=1),
            'GroupNorm_1': sharding.GroupNorm(),
            'Conv_1': sharding.Conv(use_bias=False, axis=0),
            'AddBias_0': sharding.GenericReplicated(reduce_mode='identity'),
        }

        if has_skip:
            if use_conv:
                key, layer_spec = 'Conv_2', sharding.Conv
            else:
                key, layer_spec = 'Dense_4', sharding.Dense
            spec[key] = layer_spec(use_bias=True, axis=1)

        return sharding.GenericDict(spec)


class FactorizedAttentionBlock(nn.Module):
    """Two residual `RelativeAttention` passes: the first mapped over axis 1,
    the second mapped over axes 2 and 3 of the input (for a BTHWC input these
    are attention within each frame and attention along T per pixel)."""
    num_shards: int
    num_head_dim: int

    @nn.compact
    def __call__(self, x):
        make_attention = partial(RelativeAttention, self.num_shards, self.num_head_dim)

        # spatial attention
        spatial = jax.vmap(make_attention(), 1, 1)
        x = x + spatial(x)

        # temporal attention: mapping over axis 2 twice removes the two axes
        # that follow axis 1.
        temporal = jax.vmap(jax.vmap(make_attention(), 2, 2), 2, 2)
        x = x + temporal(x)

        return x

    @staticmethod
    def model_spec():
        """Sharding spec: one `RelativeAttention` spec per attention pass."""
        names = ('RelativeAttention_0', 'RelativeAttention_1')
        return sharding.GenericDict({
            name: RelativeAttention.model_spec() for name in names
        })

        
class RelativePosition(nn.Module):
    """Learned embeddings indexed by the clipped relative offset
    `key_index - query_index`.

    Attributes:
        num_units: width of each embedding vector.
        k: maximum absolute offset; larger offsets share the boundary
            embedding. When unset (or 0) it defaults to `max(q_len, k_len)`.
    """
    num_units: int
    k: Optional[int] = None

    @nn.compact
    def __call__(self, q_len, k_len):
        """Return an array of shape (q_len, k_len, num_units)."""
        k = self.k or max(q_len, k_len)

        range_vec_q = jnp.arange(q_len, dtype=jnp.int32)
        range_vec_k = jnp.arange(k_len, dtype=jnp.int32)
        distance_mat = range_vec_k[None, :] - range_vec_q[:, None] # qk
        # Bounds are passed positionally: the `a_min` / `a_max` keyword names
        # are deprecated in newer JAX releases, while the positional form is
        # accepted by every version.
        distance_mat_clipped = jnp.clip(distance_mat, -k, k)
        # Shift offsets from [-k, k] to valid table indices [0, 2k].
        final_mat = distance_mat_clipped + k
        embeddings = nn.Embed(2 * k + 1, self.num_units)(final_mat)
        return embeddings


class RelativeAttention(nn.Module):
    """Multi-head self-attention with learned relative-position terms added
    to both the attention logits and the attended values. Each shard computes
    `num_heads // num_shards` heads; the output projection is summed with
    `g_psum` and a single bias is added afterwards."""
    num_shards: int
    num_head_dim: int

    @nn.compact
    def __call__(self, x):
        # x: B...D
        batch, *inner_shape, channels = x.shape
        total_heads = channels // self.num_head_dim
        assert total_heads % self.num_shards == 0
        heads_per_shard = total_heads // self.num_shards
        # Applied to both operands of every logit product, so each product is
        # scaled by 1 / sqrt(num_head_dim) overall.
        scale = 1 / math.sqrt(math.sqrt(self.num_head_dim))

        # Flatten all inner axes into one sequence axis.
        tokens = f_psum(x)
        tokens = tokens.reshape(batch, -1, channels)
        tokens = nn.GroupNorm()(tokens)
        qkv_proj = nn.DenseGeneral(
            axis=-1, features=(heads_per_shard, 3 * self.num_head_dim)
        )
        q, k, v = jnp.split(qkv_proj(tokens), 3, axis=-1)
        q_scaled = q * scale

        # Content logits (q . k) plus position logits (q . r_k); r_k is "qkd".
        content_logits = jnp.einsum('bqhd,bkhd->bqkh', q_scaled, k * scale)
        r_k = RelativePosition(self.num_head_dim)(q.shape[1], k.shape[1])
        position_logits = jnp.einsum('bqhd,qkd->bqkh', q_scaled, r_k * scale)
        logits = content_logits + position_logits

        # Softmax over the key axis, computed in float32 and cast back.
        probs = jax.nn.softmax(logits.astype(jnp.float32), axis=2)
        probs = probs.astype(logits.dtype)

        # Attended values plus attended relative-position value embeddings.
        r_v = RelativePosition(self.num_head_dim)(q.shape[1], v.shape[1])
        attended = (
            jnp.einsum('bqkh,bkhd->bqhd', probs, v)
            + jnp.einsum('bqkh,qkd->bqhd', probs, r_v)
        )

        # Zero-initialised, bias-free projection over the (heads, head_dim) axes.
        out_proj = nn.DenseGeneral(
            channels, axis=(-2, -1), kernel_init=zeros, use_bias=False
        )
        result = g_psum(out_proj(attended))
        result = transformer.AddBias()(result)

        return result.reshape(batch, *inner_shape, channels)

    @staticmethod
    def model_spec():
        """Sharding spec for the attention block's submodules."""
        spec = {
            'GroupNorm_0': sharding.GenericReplicated(reduce_mode='sum'),
            'DenseGeneral_0': sharding.DenseGeneral(use_bias=True, shard_mode='out'),
            'RelativePosition_0': sharding.GenericReplicated(reduce_mode='sum'),
            'RelativePosition_1': sharding.GenericReplicated(reduce_mode='sum'),
            'DenseGeneral_1': sharding.DenseGeneral(use_bias=False, shard_mode='in'),
            'AddBias_0': sharding.GenericReplicated(reduce_mode='identity'),
        }
        return sharding.GenericDict(spec)
        

def timestep_embedding(timesteps, dim, max_period=10000):
    """Create sinusoidal timestep embeddings.

    Args:
        timesteps: 1-D array of B timesteps (cast to float32 before use).
        dim: width of the returned embedding.
        max_period: sets the lowest frequency; frequencies are spaced
            geometrically from 1 down towards 1 / max_period.

    Returns:
        Array of shape (B, dim) laid out as `[cos | sin]`, followed by one
        zero column when `dim` is odd.
    """
    half = dim // 2
    freqs = jnp.exp(
        -jnp.log(max_period) * jnp.arange(0, half, dtype=jnp.float32) / half
    )
    args = timesteps[:, None].astype(jnp.float32) * freqs[None]
    embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)
    if dim % 2 == 1:
        # Pad with an explicit zero column. Slicing `embedding[:, :1]` to get
        # the column is empty when dim == 1 and would leave the result at
        # width 0 instead of 1.
        embedding = jnp.pad(embedding, ((0, 0), (0, 1)))
    return embedding
