import tensorflow as tf
import keras

class MPCNetwork(keras.Model):
    """A base model for MPC framework.

    The model is split into a stem (``self.layers[:block_start]``) followed by
    a sequence of blocks (``self.layers[block_start:]``). Each block has an
    associated loss block that maps the block output to a prediction.
    During training, a window of ``horizon`` blocks is slid over the blocks in
    steps of ``stride``; the loss of every window is accumulated and the
    weights are updated once, after the full run.

    Args:
        *args: Positional arguments forwarded to ``keras.Model``.
        horizon (int): The number of blocks to consider in each window.
        stride (int): The number of blocks the window advances per iteration.
        mpc_type (int): The type of MPC algorithm to use. With ``2`` the run
            ends as soon as ``s + stride`` reaches the number of blocks; with
            any other value it ends when the window reaches the last block.
        block_start (int): The index of the first block in ``self.layers``.
        lossblocks (keras.layers.Layer or list of keras.layers.Layer): The loss
            blocks to use for each block in the model. A single layer is
            shared by all blocks; a list must contain one entry per block.
            Defaults to a new ``Dense(10, activation='softmax')`` layer that is
            created per model instance.
        **kwargs: Keyword arguments forwarded to ``keras.Model``.

    Raises:
        ValueError: If ``lossblocks`` is a sequence whose length differs from
            the number of blocks.

    Attributes:
        block_start (int): The index of the first block in the model.
        blocks (list): The list of layers representing the blocks in the model.
        mpc_type (int): The type of MPC algorithm used.
        horizon (int): The number of blocks to consider in each iteration.
        stride (int): The stride value for iterating over the blocks.
        lossblocks (list): The list of loss blocks used for each block in the model.

    Methods:
        call(x, training=False): Performs a forward pass through the model.
        stem(x, training=True): Applies the stem layers to the input.
        build(input_shape): Builds the model by initializing the loss blocks.
        train_step(data): Performs a single training step on the model.

    """

    def __init__(self, *args, horizon=1, stride=1, mpc_type=1, block_start=1,
                 lossblocks=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.block_start = block_start
        # Captured before the loss blocks are attached to the model.
        # NOTE(review): presumably attaching layers as attributes can extend
        # ``self.layers``, so this order should be kept - confirm.
        self.blocks = self.layers[block_start:]
        self.mpc_type = mpc_type
        self.horizon = horizon
        self.stride = stride

        # A layer used as a default argument value would be created once, at
        # definition time, and shared (weights included) by every model that
        # relies on the default. Create a fresh one per instance instead.
        if lossblocks is None:
            lossblocks = keras.layers.Dense(10, activation='softmax')

        if isinstance(lossblocks, keras.layers.Layer):
            # The same layer instance is reused for every block.
            self.lossblocks = [lossblocks] * len(self.blocks)
        else:
            if len(lossblocks) != len(self.blocks):
                raise ValueError(
                    "Expected one loss block per block: got "
                    f"{len(lossblocks)} loss blocks for {len(self.blocks)} blocks."
                )
            self.lossblocks = lossblocks

    def call(self, x, training=False):
        """Performs a forward pass through the model.

        The input is passed through the parent model's ``call`` and the result
        is fed to the last loss block.

        Args:
            x (tf.Tensor): The input tensor.
            training (bool): Whether the model is in training mode or not.

        Returns:
            tf.Tensor: The output of the last loss block.

        """
        features = super().call(x, training=training)
        return self.lossblocks[-1](features, training=training)

    def stem(self, x, training=True):
        """Applies the stem layers to the input.

        The stem consists of the layers in ``self.layers[:self.block_start]``,
        applied in order.

        Args:
            x (tf.Tensor): The input tensor.
            training (bool): Whether the model is in training mode or not.

        Returns:
            tf.Tensor: The output tensor after applying the stem layers.

        """
        for layer in self.layers[:self.block_start]:
            x = layer(x, training=training)
        return x

    def build(self, input_shape):
        """Builds the model by initializing the loss blocks.

        Each loss block is built with the output shape of its block.

        Args:
            input_shape (tuple): The shape of the input tensor.

        """
        super().build(input_shape)
        for block, lossblock in zip(self.blocks, self.lossblocks):
            # A loss block may be shared between several blocks (or may have
            # been built by the caller); building it again would create its
            # weights a second time.
            if not lossblock.built:
                lossblock.build(block.output_shape)

    def train_step(self, data):
        """Performs a single MPC training step on the model.

        A window of ``horizon`` blocks is moved over the blocks in steps of
        ``stride``. The loss of each window (computed by the loss block of the
        window's last block) is summed, the model's additional losses are
        added, and a single gradient update is applied to all trainable
        variables.

        ``self.loss`` is called directly, so the model must have been compiled
        with a callable loss.

        Args:
            data (tuple): A tuple containing the input data and the target labels.

        Returns:
            dict: A dictionary containing the model metrics.

        """
        # Unpack the input data
        x, y = data

        # Get the number of blocks
        nblocks = len(self.blocks)

        # Index of the first block of the current window
        s = 0

        # Use GradientTape to record operations for automatic differentiation
        with tf.GradientTape() as tape:
            loss = 0
            xs = self.stem(x, training=True)

            # Loop until the last block is reached
            while True:
                lastblock = min(nblocks, s + self.horizon)
                # Every window starts from the state saved by the previous one
                x = xs

                # Determine if it is the last run based on the MPC type
                if self.mpc_type == 2:
                    last_run = s + self.stride >= nblocks
                else:
                    last_run = lastblock >= nblocks

                # Iterate over the blocks of the current window
                for i, block in enumerate(self.blocks[s:lastblock]):
                    x = block(x, training=True)

                    # Save the state after `stride` blocks as the start of the
                    # next window; stop_gradient keeps the next window's loss
                    # from back-propagating into earlier blocks.
                    if i == self.stride - 1:
                        xs = tf.stop_gradient(x)

                # Calculate the window loss and add it to the total loss.
                # The loss block runs in training mode, like the blocks above.
                window_pred = self.lossblocks[lastblock - 1](x, training=True)
                loss = loss + tf.reduce_mean(self.loss(y, window_pred))

                # Break the loop if it is the last run
                if last_run:
                    break
                s = s + self.stride

            # Add additional losses if present
            if self.losses:
                loss = loss + tf.reduce_sum(self.losses)

        # Calculate gradients and apply them to the trainable variables
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

        # Get the predicted output from the output of the last window
        y_pred = self.lossblocks[-1](x, training=True)

        # Calculate the compiled loss and update the compiled metrics
        self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        self.compiled_metrics.update_state(y, y_pred)

        # Get the model metrics
        return {m.name: m.result() for m in self.metrics}
    
class MPCNetwork2(MPCNetwork):
    """A class representing a modified version of MPCNetwork.

    This class inherits from the MPCNetwork class and introduces additional
    functionality for updating weights during training: instead of one update
    after the full run, the weights can be updated after every window.

    The constructor has no ``*args``; arguments for the parent model
    constructor must be passed by keyword through ``**kwargs``.

    Args:
        horizon (int): The number of blocks to consider in each training step.
        stride (int): The number of blocks to skip between each training step.
        block_start (int): The index of the first block to consider in the training process.
        update_together (bool or None): Whether to update all weights after a
            full run (delegates to ``MPCNetwork.train_step``). If ``None``, it
            is ``True`` only when both ``update_state`` and ``update_stride``
            are false.
        update_stride (bool): Whether to update only the weights of the stride length.
        update_state (bool): Whether to run the stride after weight update.
        lossblocks (keras.layers.Layer or list of keras.layers.Layer): The loss
            block layer(s). Defaults to a new ``Dense(10, activation='softmax')``
            layer that is created per model instance.
        **kwargs: Additional keyword arguments to be passed to the parent class constructor.

    Attributes:
        horizon (int): The number of blocks to consider in each training step.
        update_together (bool): Whether to update all weights after a full run.
        update_stride (bool): Whether to update only the weights of the stride length.
        update_state (bool): Whether to run the stride after weight update.

    Methods:
        train_one_step: Perform one step of training.
        train_step: Perform a full training step.

    """

    def __init__(self, horizon=1, stride=1, block_start=1,
                 update_together=None, update_stride=False, update_state=True,
                 lossblocks=None, **kwargs):
        # A layer used as a default argument value would be created once, at
        # definition time, and shared by every model relying on the default.
        if lossblocks is None:
            lossblocks = keras.layers.Dense(10, activation='softmax')

        # The parent constructor stores horizon, stride and block_start.
        super().__init__(horizon=horizon, stride=stride, block_start=block_start,
                         mpc_type=2 if update_stride and update_state else 1,
                         lossblocks=lossblocks, **kwargs)

        # When not given explicitly, update everything together only if
        # neither per-window option is enabled.
        if update_together is None:
            update_together = not (update_state or update_stride)
        self.update_together = update_together
        self.update_stride = update_stride
        self.update_state = update_state

    def train_one_step(self, x, y, blocks, lossblock, first_run=False, last_run=False, stride=None, training=True):
        """Perform submodel training step.

        Runs ``blocks`` (preceded by the stem on the first run) on ``x``,
        computes the loss through ``lossblock`` and returns the gradients for
        the weights selected for this step. No weights are updated here.

        Args:
            x: The input data.
            y: The target data.
            blocks: The list of blocks to consider in this step.
            lossblock: The loss block layer.
            first_run (bool): Whether this is the first run. If so, the stem is
                applied to ``x`` and its weights are included in the update.
            last_run (bool): Whether this is the last run. If so, the loss is
                computed with ``self.compiled_loss`` (including the model's
                additional losses) instead of ``self.loss``.
            stride (int or None): Number of leading blocks after which the
                intermediate output is saved. Defaults to ``self.stride``.
            training (bool): Training flag passed to the stem, the blocks and
                the loss block.

        Returns:
            grad_weight: An iterator of (gradient, weight) pairs.
            xs: The output after ``stride`` blocks, or a constant ``False``
                tensor if ``blocks`` has fewer than ``stride`` entries.
            y_pred: The predicted output.
            loss: The loss value.

        """
        # The signature default is None; fall back to the model's stride so
        # that `stride - 1` below is always well defined.
        if stride is None:
            stride = self.stride

        # Placeholder for the intermediate output
        xs = tf.constant(False)

        # Start recording operations for automatic differentiation
        with tf.GradientTape() as tape:
            # If it's the first run, apply the stem layer to the input data
            if first_run:
                x = self.stem(x, training=training)

            # Iterate over the blocks and apply them to the input data
            for i, block in enumerate(blocks):
                x = block(x, training=training)

                # If it's the last block of the stride, save the intermediate output
                if i == stride - 1:
                    xs = x

            # Compute the predicted output using the loss block, in the same
            # mode as the blocks
            y_pred = lossblock(x, training=training)

            # Compute the loss value
            if last_run:
                loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
            else:
                loss = tf.reduce_mean(self.loss(y, y_pred))

        # Collect the trainable variables for gradient computation; the stem
        # is only trained on the first run
        if first_run:
            weights = [w for l in self.layers[:self.block_start] for w in l.trainable_variables]
        else:
            weights = []

        # With update_stride only the first `stride` blocks are updated,
        # except on the last run where all blocks of the window are updated
        if self.update_stride and not last_run:
            updated_blocks = blocks[:stride]
        else:
            updated_blocks = blocks
        weights += [w for l in updated_blocks for w in l.trainable_variables] + lossblock.trainable_variables

        # Compute the gradients of the loss with respect to the selected weights
        grads = tape.gradient(loss, weights)
        grad_weight = zip(grads, weights)

        return grad_weight, xs, y_pred, loss

    def train_step(self, data):
        """Perform a full training step.

        If ``update_together`` is set, the parent implementation is used.
        Otherwise a window of ``horizon`` blocks is moved over the blocks in
        steps of ``stride`` and the optimizer is applied after every window.
        Between windows the state is either recomputed with the updated
        weights (``update_state``) or taken from the window's forward pass.

        Args:
            data: A tuple of input data and target data.

        Returns:
            modelmetrics: The metrics of the model.

        """
        # Check if the model should be updated together
        if self.update_together:
            return super().train_step(data)

        # Unpack the input data
        xs, y = data

        # Get the number of blocks
        nblocks = len(self.blocks)

        # Initialize variables
        s = 0
        first_run = True

        # Loop until all blocks have been processed
        while True:
            # Determine the last block to process
            lastblock = min(nblocks, s + self.horizon)

            # Check if this is the last run
            if self.update_stride and self.update_state:
                last_run = s + self.stride >= nblocks
            else:
                last_run = lastblock >= nblocks

            # Perform one training step
            grad_weight, tmpxs, y_pred, loss = self.train_one_step(
                xs, y, self.blocks[s:lastblock],
                lossblock=self.lossblocks[lastblock - 1],
                first_run=first_run, last_run=last_run, stride=self.stride)

            # Update weights
            self.optimizer.apply_gradients(grad_weight)

            # Stop after the last run
            if last_run:
                break

            if self.update_state:
                # Recompute the state with the freshly updated weights
                if first_run:
                    xs = self.stem(xs, training=True)
                for block in self.blocks[s:s + self.stride]:
                    xs = block(xs, training=True)
            else:
                # Reuse the intermediate output of the forward pass
                xs = tmpxs

            # Advance the window
            s = s + self.stride

            # Only the first iteration is the first run
            first_run = False

        # Update metrics with the prediction of the last run
        self.compiled_metrics.update_state(y, y_pred)
        modelmetrics = {m.name: m.result() for m in self.metrics}

        return modelmetrics