from typing import Tuple, Optional, Any
import jax
import jax.numpy as jnp
import flax.linen as nn
from . import transformer


class ResNetDecoder(nn.Module):
    """Decoder that alternates stacks of ``ResNetBlock`` with 2x upsampling.

    Attributes:
        image_size: Declared on the module but not read inside ``__call__``.
        depths: Per-stage channel widths; the stages are traversed in
            reverse order.
        blocks: Number of ``ResNetBlock`` modules applied at every stage.
        out_dim: Size of the last axis produced by the final ``Dense``.
        dtype: dtype forwarded to every sub-module.
    """
    image_size: int
    depths: Tuple
    blocks: int
    out_dim: int
    dtype: Optional[Any] = jnp.float32

    @nn.compact
    def __call__(self, deter, embeddings=None):
        """Decode ``deter`` (optionally joined with ``embeddings``).

        Args:
            deter: Input array. It must have at least four axes whenever
                more than one depth is configured, because the upsampling
                step indexes ``shape[0]`` .. ``shape[3]``; presumably the
                layout is (batch, height, width, channels) -- TODO confirm.
            embeddings: Optional array concatenated to ``deter`` along the
                last axis before the first convolution.

        Returns:
            The decoded array after ``LayerNorm`` and a ``Dense`` projection
            to ``out_dim`` on the last axis.
        """
        stage_widths = list(reversed(self.depths))

        if embeddings is None:
            x = deter
        else:
            x = jnp.concatenate([deter, embeddings], axis=-1)

        # NOTE(review): the stem width comes from the *un-reversed*
        # ``self.depths`` while the stages below use the reversed list; the
        # first block's 1x1 skip conv absorbs any mismatch. Confirm this is
        # intended before changing it -- doing so alters parameter shapes.
        x = nn.Conv(self.depths[0], [3, 3], dtype=self.dtype)(x)

        # Every stage except the last one ends with a 2x upsampling.
        for width in stage_widths[:-1]:
            for _ in range(self.blocks):
                x = ResNetBlock(width, dtype=self.dtype)(x)
            # Double axes 1 and 2 only; axes 0 and 3 keep their size.
            doubled_shape = (
                x.shape[0],
                2 * x.shape[1],
                2 * x.shape[2],
                x.shape[3],
            )
            x = jax.image.resize(
                x, doubled_shape, jax.image.ResizeMethod.NEAREST)

        # The last stage runs its blocks without a trailing resize.
        for _ in range(self.blocks):
            x = ResNetBlock(stage_widths[-1], dtype=self.dtype)(x)

        x = nn.LayerNorm(dtype=self.dtype)(x)
        return nn.Dense(self.out_dim, dtype=self.dtype)(x)
     

class ResNetBlock(nn.Module):
    """Pre-activation residual block with a down-weighted residual branch.

    Attributes:
        depth: Number of output features of both 3x3 convolutions (and of
            the 1x1 shortcut projection when one is needed).
        dtype: dtype forwarded to every sub-module.
    """
    depth: int
    dtype: Optional[Any] = jnp.float32

    @nn.compact
    def __call__(self, x):
        """Return ``shortcut(x) + 0.1 * residual(x)``.

        Args:
            x: Input array whose last axis holds the features.

        Returns:
            An array with ``depth`` features on the last axis.
        """
        # Project the shortcut with a bias-free 1x1 conv only when the
        # incoming feature count differs from the target depth.
        if x.shape[-1] == self.depth:
            shortcut = x
        else:
            shortcut = nn.Conv(
                self.depth, [1, 1], use_bias=False,
                dtype=self.dtype, name='skip',
            )(x)

        # Residual branch: (GroupNorm -> ELU -> 3x3 conv) applied twice.
        h = nn.GroupNorm(dtype=self.dtype)(x)
        h = nn.elu(h)
        h = nn.Conv(self.depth, [3, 3], dtype=self.dtype)(h)

        h = nn.GroupNorm(dtype=self.dtype)(h)
        h = nn.elu(h)
        # The second conv carries no bias of its own; a separate AddBias
        # module from the sibling ``transformer`` module follows it
        # (presumably a learned additive bias -- verify there).
        h = nn.Conv(self.depth, [3, 3], dtype=self.dtype, use_bias=False)(h)
        h = transformer.AddBias(dtype=self.dtype)(h)

        # The residual branch is scaled by 0.1 before being added back.
        return shortcut + 0.1 * h
