import flax.linen as nn
import jax.numpy as jnp
import flaxmodels as fm
import jax
from flax import core
import copy

class Linear_mnist(nn.Module):
    """Linear classifier: one Dense layer over flattened inputs.

    Attributes:
        num_inputs: number of features each example is flattened to.
        num_labels: number of output units (unnormalized class scores).
    """
    num_inputs: int
    num_labels: int

    @nn.compact
    def __call__(self, x):
        # Regroup into rows of `num_inputs` features; the leading dimension is
        # inferred. NOTE(review): assumes each example has exactly `num_inputs`
        # elements, otherwise rows silently straddle examples -- confirm.
        flat = x.reshape(-1, self.num_inputs)
        head = nn.Dense(self.num_labels)
        return head(flat)
    
class MLP_fmnist(nn.Module):
    """Fully connected network: hidden Dense layers + activation, then a Dense head.

    Attributes:
        num_inputs: not read by ``__call__``; inputs are flattened per example.
        num_labels: number of output units of the final Dense layer.
        units: width of each hidden Dense layer, applied in order.
        activation: name of a ``flax.linen`` function (e.g. ``'relu'``) applied
            after every hidden layer.
    """
    num_inputs: int
    num_labels: int
    units: list
    activation: str = 'relu'

    @nn.compact
    def __call__(self, x):
        # Keep the leading (batch) axis and flatten everything else.
        hidden = x.reshape((x.shape[0], -1))
        for width in self.units:
            hidden = nn.Dense(width)(hidden)
            act = getattr(nn, self.activation)
            hidden = act(hidden)
        logits = nn.Dense(self.num_labels)(hidden)
        return logits



class BasicBlockBN(nn.Module):
    """Residual block: two 3x3 convolutions, each followed by batch norm.

    This block used to be named ``BasicBlock`` as well. The batch-norm-free
    ``BasicBlock`` defined later in this module rebinds that name at import
    time, so ``ResNet18`` constructed the wrong class (which has no
    ``features`` / ``use_projection`` fields) and failed. A distinct name keeps
    both blocks usable.

    Attributes:
        features: number of output channels of every convolution in the block.
        stride: spatial stride of the first 3x3 convolution and of the projection.
        use_projection: if True, the skip path goes through a 1x1 convolution
            (with ``stride``) plus batch norm so its shape matches the main path;
            otherwise the input is added unchanged.
        train: if True (default, the original behavior), batch norm normalizes
            with the current batch's statistics and updates its running
            averages, which requires ``mutable=['batch_stats']`` in ``apply``;
            if False, the stored running averages are used instead.
    """
    features: int
    stride: int = 1
    use_projection: bool = False
    train: bool = True

    @nn.compact
    def __call__(self, x):
        use_running_average = not self.train
        residual = x

        x = nn.Conv(features=self.features, kernel_size=(3, 3),
                    strides=(self.stride, self.stride), padding='SAME')(x)
        x = nn.BatchNorm(use_running_average=use_running_average)(x)
        x = nn.relu(x)

        x = nn.Conv(features=self.features, kernel_size=(3, 3),
                    strides=(1, 1), padding='SAME')(x)
        x = nn.BatchNorm(use_running_average=use_running_average)(x)

        if self.use_projection:
            # Match channels / spatial size of the main path before the add.
            residual = nn.Conv(features=self.features, kernel_size=(1, 1),
                               strides=(self.stride, self.stride),
                               padding='SAME')(residual)
            residual = nn.BatchNorm(use_running_average=use_running_average)(residual)

        return nn.relu(x + residual)


class ResNet18(nn.Module):
    """ResNet-18-style classifier built from batch-normalized residual blocks.

    Layout: a 3x3 stride-1 convolution stem (64 channels) + batch norm + ReLU,
    then four stages of two ``BasicBlockBN`` blocks with 64/128/256/512
    channels (the first block of stages 2-4 uses stride 2), global average
    pooling over axes 1 and 2, and a Dense head.

    NOTE(review): unlike ``ResNet34_NoBN`` this model does not transpose its
    input, so it presumably expects channels-last (batch, H, W, C) data --
    confirm against the data pipeline.

    Attributes:
        num_classes: number of output units of the final Dense layer.
        in_channels: kept for interface compatibility; not read (Flax infers
            the input channel count from the data).
        train: forwarded to every batch-norm layer; True (default) uses batch
            statistics and updates running averages, False uses the stored
            running averages.
    """
    num_classes: int
    in_channels: int = 1
    train: bool = True

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1), padding='SAME')(x)
        x = nn.BatchNorm(use_running_average=not self.train)(x)
        x = nn.relu(x)

        # (channels, stride of the first block) for each of the four stages.
        for features, stride in ((64, 1), (128, 2), (256, 2), (512, 2)):
            x = self._make_layer(features, 2, stride=stride)(x)

        # Global average pooling over the two spatial axes.
        x = jnp.mean(x, axis=(1, 2))
        return nn.Dense(features=self.num_classes)(x)

    def _make_layer(self, features, num_blocks, stride):
        """Return a Sequential stage of ``num_blocks`` ``BasicBlockBN`` blocks.

        Only the first block strides and projects its skip path; the rest keep
        stride 1 and use identity skips. (The first stage also projects even
        though its shapes already match -- kept as in the original design.)
        """
        layers = []
        for i in range(num_blocks):
            if i == 0:
                layers.append(BasicBlockBN(features=features, stride=stride,
                                           use_projection=True, train=self.train))
            else:
                layers.append(BasicBlockBN(features=features, stride=1,
                                           train=self.train))
        return nn.Sequential(layers)

class BasicBlock(nn.Module):
    """Residual block of two bias-free 3x3 convolutions, without normalization.

    The first convolution applies ``stride`` and a ReLU; the second keeps the
    spatial size. If the skip path's shape differs from the main path's output,
    the skip path is projected by a bias-free 1x1 convolution with the same
    stride. The sum is passed through a final ReLU.

    Attributes:
        out_channels: number of output channels of every convolution.
        stride: spatial stride of the first 3x3 convolution and of the projection.
    """
    out_channels: int
    stride: int = 1

    @nn.compact
    def __call__(self, x):
        step = (self.stride, self.stride)

        def conv(kernel, strides):
            # Every convolution in this block shares padding and omits the bias.
            return nn.Conv(
                features=self.out_channels,
                kernel_size=kernel,
                strides=strides,
                padding='SAME',
                use_bias=False,
            )

        # Flax auto-names inline submodules by creation order (Conv_0, Conv_1,
        # Conv_2), so the optional projection must be created last.
        shortcut = x
        out = conv((3, 3), step)(x)
        out = nn.relu(out)
        out = conv((3, 3), (1, 1))(out)

        if shortcut.shape != out.shape:
            shortcut = conv((1, 1), step)(shortcut)

        return nn.relu(out + shortcut)

class ResNet34_NoBN(nn.Module):
    """ResNet-34-style classifier built from ``BasicBlock`` (no batch norm).

    Layout: transpose to channels-last, a bias-free 3x3 stride-1 convolution
    stem (64 channels) + ReLU, four stages of 3/4/6/3 residual blocks with
    64/128/256/512 channels (the first block of stages 2-4 uses stride 2),
    global average pooling, a 256-unit Dense + ReLU, and a Dense output layer.

    Attributes:
        num_classes: number of output units of the final Dense layer.
    """
    num_classes: int = 11

    @nn.compact
    def __call__(self, x):
        # Flax convolutions expect NHWC. NOTE(review): assumes channels-first
        # input such as (batch, 1, 28, 28) -- confirm for every caller.
        h = jnp.transpose(x, (0, 2, 3, 1))

        # Stem: small 3x3 stride-1 kernel, no pooling.
        stem = nn.Conv(
            features=64,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding='SAME',
            use_bias=False,
        )
        h = nn.relu(stem(h))

        # (channels, number of blocks, stride of the first block) per stage.
        stages = ((64, 3, 1), (128, 4, 2), (256, 6, 2), (512, 3, 2))
        for width, depth, first_stride in stages:
            h = self._make_layer(h, out_channels=width, blocks=depth, stride=first_stride)

        # Average over the two spatial axes -> (batch, 512).
        h = jnp.mean(h, axis=(1, 2))

        h = nn.relu(nn.Dense(features=256)(h))
        return nn.Dense(features=self.num_classes)(h)

    def _make_layer(self, x, out_channels, blocks, stride):
        """Apply one stage of residual blocks to ``x`` and return the result.

        The first block uses ``stride``; the rest use stride 1. At least one
        block is always applied, even if ``blocks`` < 1.
        """
        per_block_strides = [stride] + [1] * (blocks - 1)
        for block_stride in per_block_strides:
            x = BasicBlock(out_channels=out_channels, stride=block_stride)(x)
        return x

class KMNIST(nn.Module):
    """Two-hidden-layer MLP (256 then 128 units, ReLU) with a Dense head.

    Attributes:
        num_inputs: number of features each example is flattened to.
        num_labels: number of output units of the final Dense layer.
    """
    num_inputs: int
    num_labels: int

    @nn.compact
    def __call__(self, x):
        # NOTE(review): assumes each example has exactly `num_inputs` elements.
        h = x.reshape(-1, self.num_inputs)
        for width in (256, 128):
            h = nn.Dense(features=width)(h)
            h = nn.relu(h)
        return nn.Dense(features=self.num_labels)(h)
def get_model(config):
    """Build the Flax model that matches ``config.dataset_name``.

    Args:
        config: object exposing ``dataset_name`` and, for the MLP/linear
            datasets, ``num_inputs`` and ``num_labels``.

    Returns:
        An unbound ``flax.linen.Module`` instance.

    Raises:
        ValueError: if ``config.dataset_name`` is not one of the supported names.
    """
    if config.dataset_name in ['mnist']:
        return Linear_mnist(num_inputs=config.num_inputs, num_labels=config.num_labels)
    elif config.dataset_name in ['kmnist']:
        return KMNIST(num_inputs=config.num_inputs, num_labels=config.num_labels)
    elif config.dataset_name in ['fmnist']:
        return MLP_fmnist(num_inputs=config.num_inputs, num_labels=config.num_labels,
                          units=[64, 64], activation='relu')
    elif config.dataset_name in ['organamnist']:
        # Previously referenced an undefined `ResNet18_NoBN` (NameError). The
        # batch-norm-free ResNet defined in this module is ResNet34_NoBN, whose
        # default is also 11 classes. It transposes channels-first input itself.
        return ResNet34_NoBN(num_classes=11)
    else:
        raise ValueError(f"Unknown dataset: {config.dataset_name}")
    