import jax
from jax import numpy as jnp

import flax
from flax import linen as nn

from typing import Any, Callable, Tuple


class ResNetUnit(nn.Module):
    """Residual unit: two 3x3 conv + batch-norm stages added to a skip path.

    The main path is conv -> norm -> ``act_fn`` -> conv -> norm. Its output
    is summed with the skip path and passed through ``act_fn`` once more.

    Attributes:
        act_fn: Activation applied after the first norm and after the sum.
        c_out: Number of output features of every convolution in the unit.
        subsample: If True, the first convolution uses stride 2 and the skip
            path goes through a stride-2 1x1 convolution plus batch norm.
            If False, the input is added to the main path unchanged, so it
            must already be shape-compatible with the main-path output.
        kernel_init: Initializer used for all convolution kernels.
        dtype: Computation dtype handed to every conv and batch-norm layer.
    """
    act_fn: Callable
    c_out: int
    subsample: bool = False
    kernel_init: Callable = nn.initializers.lecun_normal()
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x, train=True):
        """Apply the residual unit to ``x``.

        Args:
            x: Input feature map. Assumed channel-last (N, H, W, C) --
                TODO confirm against callers.
            train: If True, batch norm uses the statistics of the current
                batch; otherwise it uses its running averages.

        Returns:
            ``act_fn(main_path(x) + skip(x))``.
        """

        def conv(kernel_size, strides=(1, 1)):
            # All convolutions in the unit share width, init, dtype and have no bias.
            return nn.Conv(self.c_out,
                           kernel_size=kernel_size,
                           strides=strides,
                           kernel_init=self.kernel_init,
                           use_bias=False,
                           dtype=self.dtype)

        def norm(h):
            # dtype is forwarded so the norm layers honour the configured
            # computation dtype instead of falling back to their own default.
            bn = nn.BatchNorm(momentum=0.9, epsilon=1e-5, dtype=self.dtype)
            return bn(h, use_running_average=not train)

        first_strides = (2, 2) if self.subsample else (1, 1)

        # Layers are created in the same order as before so that Flax's
        # auto-generated submodule names (and existing checkpoints) still match.
        z = norm(conv((3, 3), first_strides)(x))
        z = self.act_fn(z)
        z = norm(conv((3, 3))(z))

        if self.subsample:
            # Project the skip path so it matches the downsampled main path.
            x = norm(conv((1, 1), (2, 2))(x))

        return self.act_fn(z + x)


class ResNetFeatures(nn.Module):
    """Convolutional feature extractor: stem, residual stages, pooling, flatten.

    A 3x3 conv + batch norm + ``act_fn`` stem is followed by
    ``len(num_blocks)`` stages. Stage ``i`` stacks ``num_blocks[i]`` blocks of
    width ``c_hidden[i]``; the first block of every stage except stage 0 is
    created with ``subsample=True``. The result is average-pooled, transposed
    to channel-first order and flattened to ``(batch, features)``.

    Attributes:
        act_fn: Activation used in the stem and forwarded to every block.
        block_class: Called as a constructor with ``c_out``, ``act_fn``,
            ``kernel_init``, ``subsample`` and ``dtype`` to build each block.
        num_blocks: Number of blocks in each stage.
        c_hidden: Width of each stage; indexed by stage number, so it needs
            at least ``len(num_blocks)`` entries (extra entries are unused).
        kernel_init: Initializer for the stem kernel, forwarded to blocks.
        dtype: Computation dtype for the stem layers, forwarded to blocks.
    """
    act_fn: Callable
    # NOTE(review): annotated as nn.Module, but the value is invoked like a
    # constructor below, i.e. it is a Module class rather than an instance.
    block_class: nn.Module
    num_blocks: Tuple = (3, 3, 3)
    c_hidden: Tuple = (16, 32, 64)
    kernel_init: Callable = nn.initializers.lecun_normal()
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x, train=True):
        """Compute flattened features for a batch of inputs.

        Args:
            x: Input batch. Assumed channel-last (N, H, W, C), which is what
                the final (0, 3, 1, 2) transpose implies -- TODO confirm.
            train: If True, batch norm uses batch statistics; otherwise it
                uses running averages. Forwarded to every block.

        Returns:
            A 2-D array of shape ``(x.shape[0], -1)``.
        """
        # Stem.
        x = nn.Conv(self.c_hidden[0],
                    kernel_size=(3, 3),
                    kernel_init=self.kernel_init,
                    use_bias=False,
                    dtype=self.dtype)(x)
        # dtype is forwarded so the stem norm honours the configured
        # computation dtype, like every other layer in this module.
        x = nn.BatchNorm(momentum=0.9,
                         epsilon=1e-5,
                         dtype=self.dtype)(x, use_running_average=not train)
        x = self.act_fn(x)

        # Residual stages.
        for stage_idx, stage_len in enumerate(self.num_blocks):
            width = self.c_hidden[stage_idx]
            for pos in range(stage_len):
                # Downsample at the entry of every stage except the first.
                subsample = (pos == 0 and stage_idx > 0)
                block = self.block_class(c_out=width,
                                         act_fn=self.act_fn,
                                         kernel_init=self.kernel_init,
                                         subsample=subsample,
                                         dtype=self.dtype)
                x = block(x, train=train)

        # NOTE(review): no `strides` is passed, so the library default applies.
        # If that default is stride 1, this matches a torch-style
        # avg_pool2d(x, 4) only when the feature map is exactly 4x4 -- confirm.
        x = nn.avg_pool(x, window_shape=(4, 4))
        # Reorder to channel-first before flattening so the flattened feature
        # order lines up with pre-trained torch models.
        x = jnp.transpose(x, (0, 3, 1, 2))
        x = jnp.reshape(x, (x.shape[0], -1))
        return x
    
class ResNetClassifier(nn.Module):
    """Linear classification head: a single biased dense layer.

    Attributes:
        num_classes: Number of output units of the dense layer.
        kernel_init: Initializer for the dense kernel.
        bias_init: Initializer for the dense bias.
        dtype: Computation dtype handed to the dense layer.
    """
    num_classes: int
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros_init()
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x):
        """Map features ``x`` to ``num_classes`` outputs.

        Args:
            x: Input features; the dense layer acts on the last axis.

        Returns:
            The dense layer's output, passed through ``jnp.asarray``.
        """
        head = nn.Dense(self.num_classes,
                        use_bias=True,
                        kernel_init=self.kernel_init,
                        bias_init=self.bias_init,
                        dtype=self.dtype)
        logits = head(x)
        return jnp.asarray(logits)
  
class ResNet(nn.Module):
    """Full model: a ``ResNetFeatures`` extractor followed by a ``ResNetClassifier``.

    Attributes:
        num_classes: Number of outputs of the classifier head.
        act_fn: Activation forwarded to the feature extractor.
        block_class: Residual block class forwarded to the feature extractor.
        num_blocks: Number of blocks per stage, forwarded to the extractor.
        c_hidden: Stage widths, forwarded to the extractor. The default has
            four entries while ``num_blocks`` has three; the extractor indexes
            widths by stage number, so the fourth is unused with the defaults.
        kernel_init: Kernel initializer shared by extractor and classifier.
        bias_init: Bias initializer for the classifier.
        dtype: Computation dtype shared by extractor and classifier.
    """
    num_classes: int
    act_fn: Callable = nn.relu
    block_class: nn.Module = ResNetUnit
    num_blocks: Tuple = (3, 3, 3)
    c_hidden: Tuple = (64, 128, 256, 512)
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros_init()
    dtype: Any = jnp.float32

    def setup(self):
        """Create the ``features`` and ``classifier`` submodules."""
        # Settings that both submodules receive unchanged.
        shared = dict(kernel_init=self.kernel_init, dtype=self.dtype)

        self.features = ResNetFeatures(act_fn=self.act_fn,
                                       block_class=self.block_class,
                                       num_blocks=self.num_blocks,
                                       c_hidden=self.c_hidden,
                                       **shared)
        self.classifier = ResNetClassifier(num_classes=self.num_classes,
                                           bias_init=self.bias_init,
                                           **shared)

    def __call__(self, x, train: bool = True):
        """Run the feature extractor, then the classifier.

        Args:
            x: Input batch, passed to the feature extractor as-is.
            train: Forwarded to the feature extractor.

        Returns:
            The classifier output for the extracted features.
        """
        return self.classifier(self.features(x, train))
