from flax import linen as nn
import jax.numpy as jnp
import jax
from jax.tree_util import tree_map, Partial

class ResNetModel(nn.Module):
    """ResNet-style convolutional classifier with a 10-way linear head.

    Layout:
    - a two-conv stem;
    - a residual block at 128 base channels;
    - two conv/norm/act/pool stages at 256 and 512 base channels;
    - a residual block at 512 base channels;
    - a final 4x4 pool and a Dense(10) head.

    Every base channel count is multiplied by ``width_multiplier``.
    """

    width_multiplier: float  # channel multiplier; resulting counts are int()-truncated
    use_fast_variance: bool  # forwarded to every nn.BatchNorm
    bn_eps: float  # BatchNorm epsilon
    bn_momentum: float  # BatchNorm running-average momentum
    final_bias: bool  # whether the final Dense layer has a bias
    residual_scale: float  # multiplier on the residual branch before the skip add
    final_scale: float  # multiplier applied to the output logits
    batchnorm_before_act: bool  # True: conv -> BN -> act; False: conv -> act -> BN
    init_scale: float  # variance_scaling `scale` for conv kernels (Dense uses its default init)
    init_mode: str  # variance_scaling `mode` for conv kernels
    init_distribution: str  # variance_scaling `distribution` for conv kernels
    activation_fn: str  # 'gelu' selects nn.gelu; any other value selects nn.relu
    max_pool: bool  # True: nn.max_pool; False: nn.avg_pool
    big_first_conv: bool  # replace the stem's 3x3 conv/norm/act/pool with one 5x5 stride-2 conv

    @nn.compact
    def __call__(self, x, train=True, inverse_padder=None):
        """Run the network and return scaled logits.

        Args:
            x: tuple of arrays, concatenated along axis 0 before the forward
                pass. Presumably NHWC images: construct_model initializes with
                a (1, 32, 32, 3) array. TODO confirm for other callers.
            train: when True, BatchNorm normalizes with batch statistics;
                when False, it uses its running averages.
            inverse_padder: optional index applied to the concatenated batch
                (``x[inverse_padder]``). Presumably undoes padding or
                reordering done by the caller; verify against the data pipeline.

        Returns:
            The Dense(10) outputs multiplied by ``final_scale``.

        Raises:
            TypeError: if ``x`` is not a tuple.
        """
        if not isinstance(x, tuple):
            # Explicit raise instead of `assert`, which is stripped under -O.
            raise TypeError('x must be a tuple')

        act = nn.gelu if self.activation_fn == 'gelu' else nn.relu
        pooler = nn.max_pool if self.max_pool else nn.avg_pool
        initializer = jax.nn.initializers.variance_scaling(
            scale=self.init_scale,
            mode=self.init_mode,
            distribution=self.init_distribution,
        )
        conv_layer = Partial(nn.Conv, kernel_init=initializer)

        def width(base):
            # Scale a base channel count; int() truncates fractional widths.
            return int(base * self.width_multiplier)

        def batch_norm():
            return nn.BatchNorm(use_running_average=not train,
                                use_fast_variance=self.use_fast_variance,
                                epsilon=self.bn_eps,
                                momentum=self.bn_momentum)

        def conv_norm_act(h, base):
            # 3x3 conv, then BatchNorm and the activation in the configured order.
            # Submodules are created in exactly the order of the previously
            # unrolled code, so Flax's auto-generated names (Conv_i, BatchNorm_j)
            # and therefore existing checkpoints stay compatible.
            h = conv_layer(features=width(base), kernel_size=(3, 3))(h)
            if self.batchnorm_before_act:
                return act(batch_norm()(h))
            return batch_norm()(act(h))

        def residual_block(h, base):
            # Two conv/norm/act units; the branch is scaled before the skip add.
            y = conv_norm_act(h, base)
            y = conv_norm_act(y, base)
            return y * self.residual_scale + h

        def downsample(h):
            return pooler(h, (2, 2), strides=(2, 2))

        x = jnp.concatenate(x, axis=0)
        if inverse_padder is not None:
            x = x[inverse_padder]

        # Stem: 3x3 conv + activation (no normalization), then a second stage.
        x = act(conv_layer(features=width(64), kernel_size=(3, 3))(x))
        if self.big_first_conv:
            # NOTE(review): unlike the else-branch, this strided conv's output
            # feeds layer1 with no BatchNorm or activation in between.
            # Confirm this is intended.
            x = conv_layer(features=width(128), kernel_size=(5, 5), strides=(2, 2))(x)
        else:
            x = downsample(conv_norm_act(x, 128))

        x = residual_block(x, 128)             # layer1
        x = downsample(conv_norm_act(x, 256))  # layer2
        x = downsample(conv_norm_act(x, 512))  # layer3
        x = residual_block(x, 512)

        # Head. The 4x4 pool presumably reduces the map to 1x1 only for 32x32
        # inputs (three 2x downsamplings with SAME-padded convs). TODO confirm if
        # other resolutions are used; otherwise Dense sees more features.
        x = pooler(x, (4, 4), strides=(4, 4))
        x = x.reshape(x.shape[0], -1)
        x = nn.Dense(10, use_bias=self.final_bias)(x)
        return x * self.final_scale

class TTAModel(nn.Module):
    """ResNetModel wrapper that optionally averages over test-time augmentations.

    In training mode, or when ``tta`` is False, the call passes straight through
    to the wrapped ResNetModel. Otherwise the output is a weighted average of
    three views of the input, each already averaged with its flip along axis 2:
    - the input itself (weight 0.5);
    - the input rolled by +1 along axis 2 (weight 0.25);
    - the input rolled by -1 along axis 2 (weight 0.25).
    """

    width_multiplier: float  # forwarded to ResNetModel
    bn_eps: float = 1e-5
    bn_momentum: float = 0.99
    use_fast_variance: bool = True
    final_bias: bool = False
    residual_scale: float = 1.
    final_scale: float = 0.125
    batchnorm_before_act: bool = True
    init_scale: float = 1.
    init_mode: str = 'fan_in'
    init_distribution: str = 'truncated_normal'
    activation_fn: str = 'gelu'
    max_pool: bool = False
    big_first_conv: bool = False
    tta: bool = True  # enable test-time augmentation when called with train=False

    def setup(self):
        # Every field except `tta` is forwarded unchanged to the backbone.
        self.model = ResNetModel(width_multiplier=self.width_multiplier,
                                 use_fast_variance=self.use_fast_variance,
                                 bn_eps=self.bn_eps,
                                 bn_momentum=self.bn_momentum,
                                 final_bias=self.final_bias,
                                 residual_scale=self.residual_scale,
                                 final_scale=self.final_scale,
                                 batchnorm_before_act=self.batchnorm_before_act,
                                 init_scale=self.init_scale,
                                 init_mode=self.init_mode,
                                 init_distribution=self.init_distribution,
                                 activation_fn=self.activation_fn,
                                 max_pool=self.max_pool,
                                 big_first_conv=self.big_first_conv)

    def simple_tta(self, x, train, inverse_padder):
        """Average the model's outputs on ``x`` and on ``x`` flipped along axis 2.

        ``x`` is the tuple of arrays the backbone expects; the flip is applied
        to every array in it. Axis 2 is presumably the width axis of NHWC
        images, i.e. a horizontal flip. TODO confirm.
        """
        flipped_x = tree_map(Partial(jnp.flip, axis=2), x)
        return (self.model(flipped_x, train=train, inverse_padder=inverse_padder) +
                self.model(x, train=train, inverse_padder=inverse_padder)) / 2

    def __call__(self, x, train=True, inverse_padder=None):
        """Return backbone logits, TTA-averaged when ``train`` is False and ``tta`` is set."""
        if train or not self.tta:
            return self.model(x, train=train, inverse_padder=inverse_padder)
        # One-pixel translations along axis 2 (the same axis simple_tta flips).
        # Both shifts are along that single axis; neither is vertical. jnp.roll
        # wraps the edge column around instead of padding it.
        rolled_neg = tree_map(lambda a: jnp.roll(a, -1, axis=2), x)
        rolled_pos = tree_map(lambda a: jnp.roll(a, 1, axis=2), x)
        return (self.simple_tta(x, train, inverse_padder) * 0.5
                + self.simple_tta(rolled_pos, train, inverse_padder) * 0.25
                + self.simple_tta(rolled_neg, train, inverse_padder) * 0.25)

def construct_model(seed=0,
                    dtype=jnp.float32,
                    init_params=False,
                    **kwargs):
    """Build a TTAModel and optionally initialize its variables.

    Args:
        seed: integer seed for the initialization PRNG key.
        dtype: dtype of the all-ones dummy input used to trace ``model.init``.
        init_params: when False, skip initialization and return ``None`` in
            place of the variables.
        **kwargs: forwarded to the TTAModel constructor.

    Returns:
        ``(model, variables)``. ``variables`` is ``None`` when ``init_params``
        is False; otherwise it is a dict holding the ``'params'`` and
        ``'batch_stats'`` collections.
    """
    model = TTAModel(**kwargs)
    if not init_params:
        return model, None
    # Reuse reinit_model rather than duplicating the initialization code; its
    # default dummy input matches the (1, 32, 32, 3) shape used here before.
    return model, reinit_model(model, seed=seed, dtype=dtype)

def reinit_model(model,
                 seed=0,
                 dtype=jnp.float32,
                 input_shape=(1, 32, 32, 3)):
    """(Re)initialize ``model``'s variables from a fresh PRNG key.

    Args:
        model: object exposing ``init(rng, inputs)``. The dummy input is
            wrapped in a 1-tuple, matching the tuple input that ResNetModel
            requires.
        seed: integer seed for ``jax.random.PRNGKey``.
        dtype: dtype of the all-ones dummy input.
        input_shape: shape of the dummy input used to trace ``init``. Defaults
            to the previously hard-coded ``(1, 32, 32, 3)``.

    Returns:
        dict with the ``'params'`` and ``'batch_stats'`` collections; any other
        collections returned by ``init`` are dropped.

    Raises:
        KeyError: if ``init`` does not return one of those two collections.
    """
    rng = jax.random.PRNGKey(seed)
    dummy_input = jnp.ones(input_shape, dtype)
    variables = model.init(rng, (dummy_input,))
    return {name: variables[name] for name in ('params', 'batch_stats')}
