import tensorflow as tf
import keras
import numpy as np
from MPCclass import *

class ResDense(keras.layers.Dense):
    """
    Dense layer with a scaled residual connection: ``x + scale * Dense(x)``.

    Inherits from keras.layers.Dense. Because the Dense output is added to the
    layer input, the input's last dimension must match ``unit`` (or be
    broadcast-compatible with it).

    Args:
        unit (int): Positive integer, dimensionality of the output space.
        layer_scale (float, optional): Initial value of the scalar multiplying the Dense
            branch before it is added to the input. Defaults to 1e-1.
        scale_trainable (bool, optional): Whether the scaling factor is trainable. Defaults to False.
        **kwarg: Additional keyword arguments to be passed to the parent class constructor.
    """
    def __init__(self, unit, layer_scale=1e-1, scale_trainable=False, **kwarg):
        super().__init__(unit, **kwarg)
        self.scale_trainable = scale_trainable
        # Keep the initial float separately: ``self.scale`` is replaced by a weight in
        # ``build``, but ``get_config`` still needs the original number.
        self.layer_scale = layer_scale
        self.scale = layer_scale

    def build(self, input_shape):
        # Scalar weight initialised to ``layer_scale``; frozen unless scale_trainable.
        self.scale = self.add_weight(
            name='scale',
            shape=(),
            initializer=keras.initializers.Constant(value=self.layer_scale),
            trainable=self.scale_trainable
        )
        super().build(input_shape)

    def call(self, x):
        fx = super().call(x)
        return x + self.scale * fx

    def get_config(self):
        """Return a config that ``from_config`` can feed back into ``__init__``.

        ``Dense.get_config`` stores the width under ``units`` while this constructor
        names it ``unit``, so the key is renamed and the residual settings are added.
        """
        config = super().get_config()
        config['unit'] = config.pop('units')
        config.update(layer_scale=self.layer_scale, scale_trainable=self.scale_trainable)
        return config

class Resblock(keras.layers.Layer):
    """
    A residual block computing ``act(shortcut(s) + scale * block(s))`` with ``s = stem(x)``.

    Args:
        block: Layer or callable forming the main branch of the block.
        residual_connect_block: Layer or callable applied on the shortcut path. If not provided, it defaults to tf.identity.
        layer_scale: Initial value of the scalar multiplying the main branch. Defaults to 1e-1.
        scale_trainable: A boolean indicating whether the scale factor is trainable. Defaults to False.
        activation: Name of a Keras activation (e.g. 'relu') or a callable, resolved with
            ``keras.activations.get``. If not provided, it defaults to tf.identity.
        stem: Layer or callable applied to the input before both branches. If not provided, it defaults to tf.identity.
        **kwargs: Extra keyword arguments forwarded to ``keras.layers.Layer`` (e.g. ``name``).

    Attributes:
        block: The main branch.
        residual_connect_block: The shortcut branch.
        scale_trainable: A boolean indicating whether the scale factor is trainable.
        layer_scale: The initial scale value passed to the constructor.
        scale: The initial float before ``build``; the scalar weight afterwards.
        _stem: The stem applied before both branches.
        act: The activation applied to the summed output.

    Methods:
        build: Creates the scale weight.
        stem: Applies the stem block to the input.
        residual: Computes the residual combination.
        call: Applies stem then residual.
    """

    def __init__(self, block, residual_connect_block=None, layer_scale=1e-1, scale_trainable=False,
                 activation=None, stem=None, **kwargs):
        super().__init__(**kwargs)
        self.block = block
        self.residual_connect_block = residual_connect_block if residual_connect_block else tf.identity
        self.scale_trainable = scale_trainable
        # Kept separately because ``self.scale`` is replaced by a weight in ``build``.
        self.layer_scale = layer_scale
        self.scale = layer_scale
        self._stem = stem if stem else tf.identity
        # keras.activations.get resolves a registered activation name or passes a
        # callable through, instead of eval-ing a constructed source string.
        self.act = keras.activations.get(activation) if activation else tf.identity

    def build(self, input_shape):
        # Scalar weight initialised to ``layer_scale``; frozen unless scale_trainable.
        self.scale = self.add_weight(
            name='scale',
            shape=(),
            initializer=keras.initializers.Constant(value=self.layer_scale),
            trainable=self.scale_trainable
        )
        super().build(input_shape)

    def stem(self, x):
        """
        Applies the stem block to the input.

        Args:
            x: The input tensor.

        Returns:
            The output tensor after applying the stem block.
        """
        return self._stem(x)

    def residual(self, x):
        """
        Computes ``act(residual_connect_block(x) + scale * block(x))``.

        Args:
            x: The (stemmed) input tensor.

        Returns:
            The output tensor after applying the residual connection.
        """
        fx = self.block(x)
        return self.act(self.residual_connect_block(x) + self.scale * fx)

    def call(self, x):
        """
        Calls the Residual Block.

        Args:
            x: The input tensor.

        Returns:
            The output tensor after applying the stem and the residual connection.
        """
        x = self.stem(x)
        return self.residual(x)
        
    
class Conv2dblock(keras.layers.Conv2D):
    """Bias-free Conv2D followed by a normalization and then the activation.

    The parent's built-in activation is disabled so that the activation runs
    after the normalization rather than directly on the convolution output.

    Args:
        *args: Positional arguments passed to keras.layers.Conv2D (``filters``, ``kernel_size``, ...).
        **kwargs: Keyword arguments passed to keras.layers.Conv2D, plus:
            normalization: Callable/layer applied after the convolution (default: tf.identity).
            Any ``use_bias`` value is overridden to False.

    Attributes:
        normalization: Normalization applied after the convolution.
        act: The resolved activation, applied last in ``call``.
        activation: Set to None so the parent class does not apply the activation itself.
    """

    def __init__(self, *args, **kwargs):
        normalization = kwargs.pop('normalization', tf.identity)
        # The convolution is always built without a bias, even when no
        # normalization is given (presumably the normalization is meant to absorb it).
        kwargs['use_bias'] = False
        super().__init__(*args, **kwargs)
        # Assigned after Layer.__init__ so layer attribute tracking is already set up.
        self.normalization = normalization
        self.act = self.activation
        self.activation = None

    def call(self, x):
        """Forward pass: convolution, normalization, then activation.

        Args:
            x: Input tensor.

        Returns:
            Output tensor after convolution, normalization, and activation.
        """
        x = super().call(x)
        x = self.normalization(x)
        x = self.act(x)
        return x
    
class Conv2dblockx2(keras.layers.Layer):
    """
    Two batch-normalized Conv2dblocks applied in sequence ('same' padding).

    Only the first block applies ``activation``; the second has no activation, so
    the output is left un-activated (e.g. for an enclosing residual sum).

    Args:
        filters (sequence[int] or int): Number of filters for each convolutional block. If an int is provided, the same number of filters will be used for both blocks.
        kernel_size (int): Size of the convolutional kernel.
        activation (str): Activation of the first convolutional block.
        strides (sequence[int] or int): Strides for each convolutional block. If an int is provided, the same stride will be used for both blocks.
        **kwargs: Extra keyword arguments forwarded to ``keras.layers.Layer`` (e.g. ``name``).

    Attributes:
        conv1 (Conv2dblock): The first convolutional block (with activation).
        conv2 (Conv2dblock): The second convolutional block (no activation).
    """

    # Tuple defaults avoid the shared-mutable-default pitfall; indexing is unchanged.
    def __init__(self, filters=(16, 16), kernel_size=3, activation='relu', strides=(1, 1), **kwargs):
        super().__init__(**kwargs)
        if isinstance(filters, int):
            filters = [filters] * 2
        if isinstance(strides, int):
            strides = [strides] * 2
        self.conv1 = Conv2dblock(filters[0], kernel_size=kernel_size, strides=strides[0], padding='same', activation=activation,
                                 normalization=keras.layers.BatchNormalization(epsilon=1.001e-5))
        self.conv2 = Conv2dblock(filters[1], kernel_size=kernel_size, strides=strides[1], padding='same', activation=None,
                                 normalization=keras.layers.BatchNormalization(epsilon=1.001e-5))

    def call(self, x):
        """
        Applies the two convolutional blocks sequentially to the input tensor.

        Args:
            x (tf.Tensor): The input tensor.

        Returns:
            tf.Tensor: The output tensor after applying the convolutional blocks.
        """
        y = self.conv1(x)
        y = self.conv2(y)
        return y


def creat_resnet_mpc(horizon, input_shape=(None, None), filters=(16, 32, 64), repeats=(10, 10, 10), activation='relu', strides=2, kernel_sizes=3,
              layer_scale=1e-1, scale_trainable=False, loss_trainable=True, conv_shortcut=False,
              expand_dim=True, mpc_stride=1, output_shape=10, block_start=1,
              different_loss=False, update_stride=False, update_state=True, update_together=None, **kwargs):
    """
    Creates a ResNet-based model for Model Predictive Control (MPC).

    Args:
        horizon (int): The horizon length for the MPC.
        input_shape (tuple, optional): The shape of the input tensor. Defaults to (None, None).
        filters (sequence, optional): Number of filters for each ResNet stage. Defaults to (16, 32, 64).
        repeats (sequence or int, optional): Number of residual blocks per stage; an int is used for every stage. Defaults to (10, 10, 10).
        activation (str, optional): The activation function to use. Defaults to 'relu'.
        strides (sequence or int, optional): Stride of the first conv of each stage (the stem for stage 0); an int is used for every stage. Defaults to 2.
        kernel_sizes (sequence or int, optional): Conv kernel size per stage; an int is used for every stage. Defaults to 3.
        layer_scale (float, optional): Initial residual scaling factor of every block. Defaults to 1e-1.
        scale_trainable (bool, optional): Whether the scaling factor is trainable. Defaults to False.
        loss_trainable (bool, optional): Whether the loss heads' Dense layers are trainable. Defaults to True.
        conv_shortcut (bool, optional): Use a conv shortcut instead of identity in non-downsampling blocks.
            The first block of every stage after the first always uses a strided conv shortcut. Defaults to False.
        expand_dim (bool, optional): Whether to append a trailing channel axis to the input. Defaults to True.
        mpc_stride (int, optional): The stride value passed to MPCNetwork2. Defaults to 1.
        output_shape (int, optional): Output width (number of softmax classes) of each loss head. Defaults to 10.
        block_start (int, optional): Starting block index passed to MPCNetwork2; residual blocks before it
            get no loss head. Incremented by one when expand_dim adds a Lambda layer. Defaults to 1.
        different_loss (bool, optional): One independent loss head per residual block instead of one
            head shared within each stage. Defaults to False.
        update_stride (bool, optional): Passed to MPCNetwork2. Defaults to False.
        update_state (bool, optional): Passed to MPCNetwork2. Defaults to True.
        update_together (None, optional): Passed to MPCNetwork2. Defaults to None.
        **kwargs: Optional ``model_name`` (default 'model_mpc'), ``optimizer`` (default 'sgd'),
            ``iscompile`` (default True) and ``weights`` (path passed to ``load_weights``).

    Returns:
        model_mpc: The ResNet-based MPC model.

    Raises:
        AssertionError: If the lengths of filters, repeats, strides, and kernel_sizes are not equal.
    """
    model_name = kwargs.get('model_name', 'model_mpc')
    optimizer = kwargs.get('optimizer', 'sgd')
    iscompile = kwargs.get('iscompile', True)
    weights = kwargs.get('weights')

    # Broadcast scalar per-stage settings to one value per stage.
    if isinstance(repeats, int):
        repeats = [repeats] * len(filters)
    if isinstance(strides, int):
        strides = [strides] * len(filters)
    if isinstance(kernel_sizes, int):
        kernel_sizes = [kernel_sizes] * len(filters)

    assert len(filters) == len(repeats) == len(strides) == len(kernel_sizes), (
        'filters, repeats, strides and kernel_sizes must have the same length')

    class lossblock(keras.layers.Layer):
        """Loss head: global average pooling followed by a softmax Dense layer."""

        def __init__(self, output_shape, trainable=True):
            super().__init__()
            self.avg = keras.layers.GlobalAveragePooling2D()
            self.linear = keras.layers.Dense(output_shape, activation='softmax', trainable=trainable)

        def build(self, input_shape):
            # Guard so a head shared between several blocks is only built once.
            if self.built:
                return
            self.avg.build(input_shape)
            self.linear.build(self.avg.compute_output_shape(input_shape))
            super().build(input_shape)

        def call(self, x):
            y = self.avg(x)
            y = self.linear(y)
            return y

    def batch_norm():
        return keras.layers.BatchNormalization(epsilon=1.001e-5)

    def shortcut_conv(n_filters, kernel_size, stride=1):
        # Conv2D takes (filters, kernel_size) in that order.
        return Conv2dblock(n_filters, kernel_size, strides=stride, activation=None, padding='same',
                           normalization=batch_norm())

    inputs = tf.keras.Input(shape=input_shape)

    # Number of leading residual blocks without a loss head (computed before the
    # expand_dim adjustment of block_start below).
    bs = block_start - 1

    if expand_dim:
        x = keras.layers.Lambda(lambda t: tf.expand_dims(t, -1))(inputs)
        block_start += 1
    else:
        x = inputs

    for i, (filteri, repeat, stride, kernel_size) in enumerate(zip(filters, repeats, strides, kernel_sizes)):
        if i == 0:
            # First stage: a strided stem conv, then a residual block whose shortcut
            # is identity unless conv_shortcut is set.
            x = Resblock(Conv2dblockx2(kernel_size=kernel_size, filters=filteri, activation=activation),
                         residual_connect_block=shortcut_conv(filteri, kernel_size) if conv_shortcut else None,
                         layer_scale=layer_scale, scale_trainable=scale_trainable, activation=activation,
                         stem=Conv2dblock(filteri, kernel_size, strides=stride, activation=activation, padding='same'))(x)
        else:
            # Later stages downsample inside the block, so the shortcut must be a
            # strided conv to match the main branch's shape.
            x = Resblock(Conv2dblockx2(kernel_size=kernel_size, filters=filteri, activation=activation, strides=[stride, 1]),
                         residual_connect_block=shortcut_conv(filteri, kernel_size, stride),
                         layer_scale=layer_scale, scale_trainable=scale_trainable, activation=activation)(x)
        for _ in range(repeat - 1):
            x = Resblock(Conv2dblockx2(kernel_size=kernel_size, filters=filteri, activation=activation),
                         residual_connect_block=shortcut_conv(filteri, kernel_size) if conv_shortcut else None,
                         layer_scale=layer_scale, scale_trainable=scale_trainable, activation=activation)(x)

    # One loss head per residual block, skipping the first ``bs`` blocks. Without
    # different_loss, all heads within a stage are the same shared instance.
    lossblocks = []
    for repeat in repeats:
        if bs >= repeat:
            bs -= repeat
        else:
            n_heads = repeat - bs
            if different_loss:
                lossblocks += [lossblock(output_shape, trainable=loss_trainable) for _ in range(n_heads)]
            else:
                lossblocks += [lossblock(output_shape, trainable=loss_trainable)] * n_heads
            bs = 0

    model_mpc = MPCNetwork2(inputs=inputs, outputs=x, name=model_name, horizon=horizon, stride=mpc_stride,
                            block_start=block_start, lossblocks=lossblocks,
                            update_together=update_together, update_stride=update_stride, update_state=update_state)

    # kwargs is a dict: look the key up (hasattr() on a dict never sees its keys).
    if weights is not None:
        model_mpc.load_weights(weights)

    if iscompile:
        model_mpc.compile(optimizer, keras.losses.categorical_crossentropy,
                          metrics=[keras.metrics.CategoricalAccuracy(name='acc')])

    return model_mpc


def create_model(horizon, stride, config):
    """
    Create a model for the MPC (Model Predictive Control) framework from a configuration.

    Args:
        horizon (int): The prediction horizon.
        stride (int): The stride for the MPC model. For the 'conv' model it is used as
            ``mpc_stride`` unless the config already sets ``mpc_stride``.
        config (dict): Must contain 'model', 'model_config', 'mpc_config', 'learning_rate',
            'x_train', 'loss' and 'metrics'.

    Returns:
        keras.Model: The created, built and compiled MPC model.
    """
    model_config = config['model_config']
    mpc_config = config['mpc_config']

    if config['model'] == 'conv':
        # creat_resnet_mpc has no ``stride`` parameter (it would land unused in
        # **kwargs), so forward the MPC stride under its real name, ``mpc_stride``.
        conv_mpc_config = dict(mpc_config)
        if 'mpc_stride' not in model_config:
            conv_mpc_config.setdefault('mpc_stride', stride)
        model = creat_resnet_mpc(horizon, **model_config, **conv_mpc_config)
    else:
        # Fully connected residual MLP.
        input_shape = model_config['input']
        if isinstance(input_shape, int):
            # keras.Input needs a shape tuple, not a bare int.
            input_shape = (input_shape,)
        xx = keras.Input(shape=input_shape)
        yy = keras.layers.Dense(model_config['hidden'], activation=model_config['activation'])(xx)

        if model_config['normalize']:
            class Norm_Dense(keras.layers.Dense):
                """Dense layer (including its activation) followed by BatchNormalization."""

                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)
                    self.normalize_layer = keras.layers.BatchNormalization()

                def call(self, x, training=False):
                    y = super().call(x)
                    y = self.normalize_layer(y, training=training)
                    return y

            layer = Norm_Dense
        else:
            layer = keras.layers.Dense

        # Bias is only used when the hidden layers have an activation.
        use_bias = model_config['activation'] is not None
        for _ in range(model_config['hidden_num']):
            yy = Resblock(layer(model_config['hidden'], activation=model_config['activation'], use_bias=use_bias),
                          layer_scale=model_config['layer_scale'])(yy)

        # A single shared loss head, replaced by one head per hidden block when
        # different_loss is enabled.
        lossblock = keras.layers.Dense(model_config['output'], activation=None)
        if mpc_config.get('different_loss'):
            lossblock = [keras.layers.Dense(model_config['output'], activation=None)
                         for _ in range(model_config['hidden_num'])]

        if not model_config['loss_trainable']:
            # lossblock may be a single layer or a list of layers.
            for head in (lossblock if isinstance(lossblock, list) else [lossblock]):
                head.trainable = False

        model = MPCNetwork2(inputs=xx, outputs=yy, horizon=horizon, stride=stride, lossblocks=lossblock, block_start=2,
                            update_state=mpc_config['update_state'],
                            update_stride=mpc_config['update_stride'])

    optimizer = keras.optimizers.SGD(learning_rate=config['learning_rate'])
    model.build(config['x_train'].shape)
    model.compile(optimizer, config['loss'], metrics=config['metrics'])

    return model

