import jax
import jax.numpy as jnp
from jax import random
from jax.nn import initializers

import inspect
import flax
from flax import linen as nn
from flax.training import train_state, checkpoints
from typing import Callable, Tuple

"""Implementation of a convolutional neural network (CNN) model."""

class CNN(nn.Module):
    """Three-stage convolutional network with a dense head.

    The forward pass applies three ``Conv -> ReLU -> 2x2 max-pool`` stages
    (32, 64 and 96 features), flattens everything but the leading axis, and
    runs two hidden dense layers (512 and 256 units, ReLU) followed by a
    linear output layer.

    Attributes:
        output_size: Number of units of the output layer.
        use_bias: Whether the conv layers and hidden dense layers use a bias.
        use_bias_output: Whether the output layer uses a bias.
        init_scale: Scale passed to ``kernel_init`` for the dense kernels.
        bias_init_scale: Scale passed to ``bias_init`` for all biases.
        kernel_init: Initializer *factory* for kernels (called here to build
            the actual initializer), e.g. ``initializers.normal``.
        bias_init: Initializer factory for biases, e.g.
            ``initializers.constant``.
        convolution_init_scale: Scale passed to ``kernel_init`` for the
            convolution kernels.
    """

    output_size: int = 14
    use_bias: bool = True
    use_bias_output: bool = False
    init_scale: float = 0.0002
    bias_init_scale: float = 0.0
    kernel_init: Callable = initializers.normal
    bias_init: Callable = initializers.constant
    convolution_init_scale: float = 0.05

    def setup(self):
        """Instantiate the layers, building each kind of initializer once."""
        conv_kernel_init = self.custom_initialisation_conv(self.kernel_init)
        dense_kernel_init = self.custom_initialisation(self.kernel_init)
        bias_init = self.custom_initialisation_bias(self.bias_init)

        # Layers created without a bias receive ``None`` as their bias
        # initializer; the hidden dense layers always receive the real one.
        conv_bias_init = bias_init if self.use_bias else None
        output_bias_init = bias_init if self.use_bias_output else None

        self.conv1 = nn.Conv(features=32, kernel_size=(5, 5),
                             use_bias=self.use_bias,
                             kernel_init=conv_kernel_init,
                             bias_init=conv_bias_init)

        self.conv2 = nn.Conv(features=64, kernel_size=(3, 3),
                             use_bias=self.use_bias,
                             kernel_init=conv_kernel_init,
                             bias_init=conv_bias_init)

        self.conv3 = nn.Conv(features=96, kernel_size=(3, 3),
                             use_bias=self.use_bias,
                             kernel_init=conv_kernel_init,
                             bias_init=conv_bias_init)

        self.dense1 = nn.Dense(features=512,
                               use_bias=self.use_bias,
                               kernel_init=dense_kernel_init,
                               bias_init=bias_init)

        self.dense2 = nn.Dense(features=256,
                               use_bias=self.use_bias,
                               kernel_init=dense_kernel_init,
                               bias_init=bias_init)

        self.output_layer = nn.Dense(features=self.output_size,
                                     use_bias=self.use_bias_output,
                                     kernel_init=dense_kernel_init,
                                     bias_init=output_bias_init)

    def __call__(self, x):
        """Run the forward pass and return the output layer's activations.

        Args:
            x: Input batch; the leading axis is preserved by the flatten
                step, every other axis is merged before the dense layers.

        Returns:
            The raw (un-activated) result of the output layer.
        """
        for conv in (self.conv1, self.conv2, self.conv3):
            x = conv(x)
            x = nn.relu(x)
            x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Flatten everything except the leading axis for the dense layers.
        x = x.reshape((x.shape[0], -1))
        for dense in (self.dense1, self.dense2):
            x = dense(x)
            x = nn.relu(x)
        return self.output_layer(x)

    @staticmethod
    def _scaled_initialiser(init_fn, scale, *args, **kwargs):
        """Call the factory ``init_fn``, injecting ``scale`` where supported.

        ``scale`` is passed as ``stddev`` if the factory's signature has such
        a parameter, otherwise as ``value`` if it has that one; a factory
        with neither is called with ``*args``/``**kwargs`` only.
        """
        accepted = inspect.signature(init_fn).parameters
        for keyword in ("stddev", "value"):
            if keyword in accepted:
                # Separate ``**`` unpackings keep the original behaviour of
                # raising TypeError if the caller also supplied ``keyword``.
                return init_fn(*args, **kwargs, **{keyword: scale})
        return init_fn(*args, **kwargs)

    def custom_initialisation(self, init_fn, *args, **kwargs):
        """Build the dense-kernel initializer, scaled by ``init_scale``.

        Args:
            init_fn: Initializer factory to call.
            *args: Extra positional arguments forwarded to ``init_fn``.
            **kwargs: Extra keyword arguments forwarded to ``init_fn``.

        Returns:
            Whatever ``init_fn`` returns.
        """
        return self._scaled_initialiser(init_fn, self.init_scale, *args, **kwargs)

    def custom_initialisation_bias(self, init_fn, *args, **kwargs):
        """Build the bias initializer, scaled by ``bias_init_scale``.

        Args:
            init_fn: Initializer factory to call.
            *args: Extra positional arguments forwarded to ``init_fn``.
            **kwargs: Extra keyword arguments forwarded to ``init_fn``.

        Returns:
            Whatever ``init_fn`` returns.
        """
        return self._scaled_initialiser(init_fn, self.bias_init_scale, *args, **kwargs)

    def custom_initialisation_conv(self, init_fn, *args, **kwargs):
        """Build the conv-kernel initializer, scaled by ``convolution_init_scale``.

        Handles ``value``-based factories the same way as the dense and bias
        helpers, so a factory that requires ``value`` no longer fails here.

        Args:
            init_fn: Initializer factory to call.
            *args: Extra positional arguments forwarded to ``init_fn``.
            **kwargs: Extra keyword arguments forwarded to ``init_fn``.

        Returns:
            Whatever ``init_fn`` returns.
        """
        return self._scaled_initialiser(init_fn, self.convolution_init_scale, *args, **kwargs)

        
class CNN_mutliclass(nn.Module):
    """Convolutional network with a dense head (same layer stack as ``CNN``).

    Three ``Conv -> ReLU -> 2x2 max-pool`` stages (32, 64 and 96 features)
    are followed by a flatten step, two hidden dense layers (512 and 256
    units, ReLU) and a linear output layer.

    Attributes:
        output_size: Number of units of the output layer.
        use_bias: Whether the conv layers and hidden dense layers use a bias.
        use_bias_output: Whether the output layer uses a bias.
        init_scale: Scale passed to ``kernel_init`` for the dense kernels.
        bias_init_scale: Scale passed to ``bias_init`` for all biases.
        kernel_init: Initializer *factory* for kernels (called here to build
            the actual initializer), e.g. ``initializers.normal``.
        bias_init: Initializer factory for biases, e.g.
            ``initializers.constant``.
        convolution_init_scale: Scale passed to ``kernel_init`` for the
            convolution kernels.
    """

    output_size: int = 14
    use_bias: bool = True
    use_bias_output: bool = False
    init_scale: float = 0.0002
    bias_init_scale: float = 0.0
    kernel_init: Callable = initializers.normal
    bias_init: Callable = initializers.constant
    convolution_init_scale: float = 0.05

    def setup(self):
        """Instantiate the layers, building each kind of initializer once."""
        conv_kernel_init = self.custom_initialisation_conv(self.kernel_init)
        dense_kernel_init = self.custom_initialisation(self.kernel_init)
        bias_init = self.custom_initialisation_bias(self.bias_init)

        # Layers created without a bias receive ``None`` as their bias
        # initializer; the hidden dense layers always receive the real one.
        conv_bias_init = bias_init if self.use_bias else None
        output_bias_init = bias_init if self.use_bias_output else None

        self.conv1 = nn.Conv(features=32, kernel_size=(5, 5),
                             use_bias=self.use_bias,
                             kernel_init=conv_kernel_init,
                             bias_init=conv_bias_init)

        self.conv2 = nn.Conv(features=64, kernel_size=(3, 3),
                             use_bias=self.use_bias,
                             kernel_init=conv_kernel_init,
                             bias_init=conv_bias_init)

        self.conv3 = nn.Conv(features=96, kernel_size=(3, 3),
                             use_bias=self.use_bias,
                             kernel_init=conv_kernel_init,
                             bias_init=conv_bias_init)

        self.dense1 = nn.Dense(features=512,
                               use_bias=self.use_bias,
                               kernel_init=dense_kernel_init,
                               bias_init=bias_init)

        self.dense2 = nn.Dense(features=256,
                               use_bias=self.use_bias,
                               kernel_init=dense_kernel_init,
                               bias_init=bias_init)

        self.output_layer = nn.Dense(features=self.output_size,
                                     use_bias=self.use_bias_output,
                                     kernel_init=dense_kernel_init,
                                     bias_init=output_bias_init)

    def __call__(self, x):
        """Run the forward pass and return the output layer's activations.

        Args:
            x: Input batch; the leading axis is preserved by the flatten
                step, every other axis is merged before the dense layers.

        Returns:
            The raw (un-activated) result of the output layer.
        """
        for conv in (self.conv1, self.conv2, self.conv3):
            x = conv(x)
            x = nn.relu(x)
            x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Flatten everything except the leading axis for the dense layers.
        x = x.reshape((x.shape[0], -1))
        for dense in (self.dense1, self.dense2):
            x = dense(x)
            x = nn.relu(x)
        return self.output_layer(x)

    @staticmethod
    def _scaled_initialiser(init_fn, scale, *args, **kwargs):
        """Call the factory ``init_fn``, injecting ``scale`` where supported.

        ``scale`` is passed as ``stddev`` if the factory's signature has such
        a parameter, otherwise as ``value`` if it has that one; a factory
        with neither is called with ``*args``/``**kwargs`` only.
        """
        accepted = inspect.signature(init_fn).parameters
        for keyword in ("stddev", "value"):
            if keyword in accepted:
                # Separate ``**`` unpackings keep the original behaviour of
                # raising TypeError if the caller also supplied ``keyword``.
                return init_fn(*args, **kwargs, **{keyword: scale})
        return init_fn(*args, **kwargs)

    def custom_initialisation(self, init_fn, *args, **kwargs):
        """Build the dense-kernel initializer, scaled by ``init_scale``.

        Args:
            init_fn: Initializer factory to call.
            *args: Extra positional arguments forwarded to ``init_fn``.
            **kwargs: Extra keyword arguments forwarded to ``init_fn``.

        Returns:
            Whatever ``init_fn`` returns.
        """
        return self._scaled_initialiser(init_fn, self.init_scale, *args, **kwargs)

    def custom_initialisation_bias(self, init_fn, *args, **kwargs):
        """Build the bias initializer, scaled by ``bias_init_scale``.

        Args:
            init_fn: Initializer factory to call.
            *args: Extra positional arguments forwarded to ``init_fn``.
            **kwargs: Extra keyword arguments forwarded to ``init_fn``.

        Returns:
            Whatever ``init_fn`` returns.
        """
        return self._scaled_initialiser(init_fn, self.bias_init_scale, *args, **kwargs)

    def custom_initialisation_conv(self, init_fn, *args, **kwargs):
        """Build the conv-kernel initializer, scaled by ``convolution_init_scale``.

        Handles ``value``-based factories the same way as the dense and bias
        helpers, so a factory that requires ``value`` no longer fails here.

        Args:
            init_fn: Initializer factory to call.
            *args: Extra positional arguments forwarded to ``init_fn``.
            **kwargs: Extra keyword arguments forwarded to ``init_fn``.

        Returns:
            Whatever ``init_fn`` returns.
        """
        return self._scaled_initialiser(init_fn, self.convolution_init_scale, *args, **kwargs)

        
if __name__ == "__main__":
    # Smoke test: build the default CNN, run one forward pass on a dummy
    # batch and print a summary of the model.
    key = random.PRNGKey(0)
    # One 28x28 image with a single channel: (batch, height, width, channels).
    input_shape = (1, 28, 28, 1)
    dummy_input = jnp.ones(input_shape)

    # Initialize the model and its parameters
    model = CNN()
    params = model.init(key, dummy_input)

    # Apply the model to the input data (checks that the forward pass runs)
    output = model.apply(params, dummy_input)

    # Print the layer-by-layer summary table, then the module itself;
    # the already-built model is reused instead of constructing a second one.
    tabulate_fn = nn.tabulate(model, key)
    print(tabulate_fn(dummy_input))
    print(model)
    # print("Model Output:", output)
    # print("Model Parameters:", params)