import tensorflow as tf
import tensorflow_addons as tfa
import keras
import numpy as np
import pandas as pd
import tqdm
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
from MPCmodels import *
from mycallbacks import *
import tensorflow_datasets as tfds
from vit_keras import vit
import json

def load_dataset(config, split="train"):
    """
    Loads and prepares the dataset for training or testing.

    Reads the module-level globals ``data``, ``ds_info`` and ``BATCH_SIZE``,
    which the ``__main__`` section of this file defines before calling in.

    Args:
        config (dict): Configuration parameters. ``config['image_size']`` is
            passed to ``tf.image.resize`` as the target size.
        split (str): Split of the dataset (train or test).

    Returns:
        tf.data.Dataset: Batched and prefetched dataset of
        ``(image, one-hot label)`` pairs.

    Raises:
        ValueError: If ``split`` is neither "train" nor "test".

    """
    # Fail fast with a clear message instead of silently returning None for an
    # unsupported split.
    if split not in ("train", "test"):
        raise ValueError(
            f"Unsupported split {split!r}; expected 'train' or 'test'."
        )

    def normalize_img(image, label):
        """Resizes the image, one-hot encodes the label, divides pixels by 255."""
        image = tf.image.resize(image, config['image_size'])
        label = tf.one_hot(label, ds_info.features['label'].num_classes)
        return tf.cast(image, tf.float32) / 255., label

    def prepare_dataset(dataset, split):
        """
        Prepares the dataset for a specific split.

        Args:
            dataset (tf.data.Dataset): Raw dataset.
            split (str): Split of the dataset (train or test).

        Returns:
            tf.data.Dataset: Prepared dataset.

        """
        # Normalize and resize the images (deterministic, so safe to cache).
        prepared = dataset.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
        if split == "train":
            prepared = (
                prepared
                # Cache the preprocessed examples BEFORE shuffling: caching
                # after the shuffle would freeze the first epoch's order and
                # replay it on every following epoch.
                .cache()
                # Reshuffle the whole training set on every epoch.
                .shuffle(ds_info.splits['train'].num_examples)
            )
        # Batch the dataset and prefetch the next batch.
        return prepared.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

    return prepare_dataset(data[split], split)

class lossblock(keras.layers.Layer):
    """Custom layer for the loss block.

    Applies ``GlobalAveragePooling2D`` to its input and feeds the result to a
    ``Dense`` layer with softmax activation.

    Args:
        output_shape: Number of units of the Dense layer.
        trainable: Passed to the Dense layer as its ``trainable`` flag.
    """

    def __init__(self, output_shape, trainable=True):
        super().__init__()
        self.avg = keras.layers.GlobalAveragePooling2D()
        self.linear = keras.layers.Dense(
            units=output_shape,
            activation='softmax',
            trainable=trainable,
        )

    def build(self, input_shape):
        """Build both sub-layers; does nothing when the layer is already built."""
        if not self.built:
            self.avg.build(input_shape)
            pooled_shape = self.avg.compute_output_shape(input_shape)
            self.linear.build(pooled_shape)
            super().build(input_shape)

    def call(self, x):
        """Return the softmax Dense output of the average-pooled input."""
        return self.linear(self.avg(x))
    
def create_mpc_resnet50(config, horizon, stride=1):
    """Create an MPC model using the module-level ``resnet50`` as the base model.

    The base model is cut into a stem model plus one sub-model per residual
    block; every residual block of a stage is paired with that stage's single
    shared ``lossblock`` instance.

    Args:
        config (dict): Reads ``stem_output_name``, ``conv_num``, ``repeats``,
            ``classes`` and ``mpc_config`` (forwarded to ``MPCNetwork2`` as
            keyword arguments).
        horizon: Forwarded to ``MPCNetwork2``.
        stride: Forwarded to ``MPCNetwork2``.

    Returns:
        The constructed ``MPCNetwork2`` instance.
    """
    # Reload the pre-trained weights on every call.
    resnet50.load_weights('./fine_tune_resnet50/pretrain_resnet50')

    # First block: network input -> stem output.
    stem_output = resnet50.get_layer(config['stem_output_name']).output
    blocks = [keras.Model(inputs=resnet50.input, outputs=stem_output)]
    lossblocks = []

    for i, conv in enumerate(config['conv_num']):
        # One loss block per stage, listed once for each of its repeats.
        stage_head = lossblock(config['classes'])
        n_repeats = config['repeats'][i]
        lossblocks.extend([stage_head] * n_repeats)

        for blk in range(n_repeats):
            if blk > 0:
                # Later repeats continue from the previous repeat of this stage.
                source_name = f'conv{conv}_block{blk}_out'
            elif i > 0:
                # First repeat of a later stage continues from the last
                # repeat of the previous stage.
                source_name = f"conv{config['conv_num'][i-1]}_block{config['repeats'][i-1]}_out"
            else:
                # Very first residual block continues from the stem.
                source_name = config['stem_output_name']

            blocks.append(
                keras.Model(
                    inputs=resnet50.get_layer(source_name).output,
                    outputs=resnet50.get_layer(f'conv{conv}_block{blk+1}_out').output,
                )
            )

    # Chain all sub-models to obtain the output tensor of the full network.
    network_input = resnet50.input
    network_output = network_input
    for sub_model in blocks:
        network_output = sub_model(network_output)

    return MPCNetwork2(
        inputs=network_input,
        outputs=network_output,
        horizon=horizon,
        stride=stride,
        block_start=2,
        lossblocks=lossblocks,
        **config['mpc_config'],
    )

def _make_optimizer(config, learning_rate):
    """Instantiate ``keras.optimizers.<config['optimizer']>`` with ``learning_rate``.

    The optimizer keyword arguments are a copy of ``config['optimizer_config']``
    so the caller's configuration is never modified.
    """
    optimizer_cls = getattr(keras.optimizers, config['optimizer'])
    optimizer_kwargs = {**config['optimizer_config'], 'learning_rate': learning_rate}
    return optimizer_cls(**optimizer_kwargs)


def _compile_model(model, config, lr):
    """Compile ``model`` with either one optimizer or a two-optimizer MultiOptimizer."""
    loss = keras.losses.categorical_crossentropy
    metrics = [keras.metrics.categorical_accuracy]

    if config['different_lr']:
        # NOTE(review): the swept `lr` drives the loss blocks while
        # `top_learning_rate` drives the leading layers and the blocks --
        # confirm this assignment is the intended one.
        optimizers_and_layers = [
            (_make_optimizer(config, lr), model.lossblocks),
            (_make_optimizer(config, config['top_learning_rate']),
             model.layers[:model.block_start] + model.blocks),
        ]
        optimizer = tfa.optimizers.MultiOptimizer(optimizers_and_layers)
        model.compile(optimizer, loss, metrics=metrics, run_eagerly=True)
    else:
        model.compile(_make_optimizer(config, lr), loss, metrics=metrics, run_eagerly=False)


def _results_path(config):
    """Build the output file name from the configuration (no extension is added)."""
    tags = ''.join([
        '' if config['freeze'] else 'nofreeze',
        '_different_lr' if config['different_lr'] else '',
        '_usd' if config['mpc_config']['update_stride'] else '',
        '_ust' if config['mpc_config']['update_state'] else '',
    ])
    return f"df_{config['model_name']}_{config['dataset_name']}_{config['optimizer']}_{tags}"


def train_model(get_model, horizon_stride, config):
    """Train one model per (learning rate, repetition, horizon/stride) combination.

    After every run the accumulated history is written to a CSV file whose
    name is derived from ``config`` (see ``_results_path``), so partial results
    survive an interrupted sweep.

    Args:
        get_model: Callable invoked as ``get_model(config, horizon, stride)``;
            must return a model supporting ``compile`` and ``fit`` (and, when
            ``config['different_lr']`` is set, exposing ``lossblocks``,
            ``blocks``, ``layers`` and ``block_start``).
        horizon_stride: Iterable of ``(horizon, stride)`` pairs.
        config (dict): Reads ``learning_rates``, ``rep``, ``epochs``,
            ``optimizer``, ``optimizer_config``, ``different_lr``,
            ``top_learning_rate`` (only with ``different_lr``), ``model_name``,
            ``dataset_name``, ``freeze`` and ``mpc_config``. The dictionary is
            not modified.

    Returns:
        pandas.DataFrame: Per-epoch history of all runs with added ``epoch``,
        ``time``, ``horizon``, ``stride`` and ``learning_rate`` columns; empty
        when no run was executed.

    Raises:
        KeyError: If a configuration key needed for the output file name is
            missing (raised before any training starts).
    """
    # Resolve the output path up front so a bad config fails before hours of
    # training rather than after the first finished run.
    results_path = _results_path(config)

    ds_train = load_dataset(config, split='train')
    ds_test = load_dataset(config, split='test')

    df = pd.DataFrame()  # Returned as-is when the sweep is empty
    run_frames = []

    for lr in config['learning_rates']:
        for _ in range(config['rep']):
            for horizon, stride in tqdm.tqdm(horizon_stride):
                model = get_model(config, horizon, stride)
                _compile_model(model, config, lr)

                timecb = get_time_callback()  # Callback that records the training time
                history = model.fit(
                    ds_train,
                    validation_data=ds_test,
                    epochs=config['epochs'],
                    verbose=1,
                    callbacks=[
                        timecb,
                        keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=0, min_lr=1e-5),
                    ],
                )

                run_df = pd.DataFrame({**history.history, 'epoch': history.epoch, 'time': timecb.log})
                run_df['horizon'] = horizon
                run_df['stride'] = stride
                run_df['learning_rate'] = lr
                run_frames.append(run_df)

                # Checkpoint the full history after every run.
                df = pd.concat(run_frames, ignore_index=True)
                df.to_csv(results_path)

    return df

class LoraLayer(keras.layers.Layer):
    """Wrap an existing layer and add a low-rank (LoRA) term to its output.

    The wrapped layer is frozen. Two bias-free Dense layers, ``lora_A``
    (``rank`` units) and ``lora_B`` (as many units as the second dimension of
    the wrapped layer's first weight), form the low-rank branch whose output
    is scaled by ``alpha / rank``.

    Args:
        original_layer: Already-built layer to wrap; its first weight must
            have a two-dimensional shape.
        rank: Number of units of ``lora_A``.
        alpha: Numerator of the scaling factor ``alpha / rank``.
        trainable: Trainable flag of this layer and of both LoRA Dense layers.
        **kwargs: Forwarded to ``keras.layers.Layer``; a ``name`` entry is
            discarded because the wrapped layer's name is reused.
    """

    def __init__(
        self,
        original_layer,
        rank=8,
        alpha=32,
        trainable=False,
        **kwargs,
    ):
        # This layer takes over the name of the layer it wraps.
        inherited_name = original_layer.get_config()["name"]
        kwargs.pop("name", None)
        super().__init__(name=inherited_name, trainable=trainable, **kwargs)

        self.rank = rank
        self.alpha = alpha
        self._scale = alpha / rank

        kernel_shape = original_layer.weights[0].shape
        self.in_shape, self.out_shape = kernel_shape

        # The wrapped layer stays frozen regardless of this layer's own flag.
        self.original_layer = original_layer
        self.original_layer.trainable = False

        # Down-projection to `rank` units. The official LoRA implementation
        # uses Kaiming/He initialization rather than the normal distribution
        # mentioned in the paper.
        kaiming_uniform = keras.initializers.VarianceScaling(
            scale=np.sqrt(5), mode="fan_in", distribution="uniform"
        )
        self.A = keras.layers.Dense(
            units=rank,
            use_bias=False,
            kernel_initializer=kaiming_uniform,
            trainable=trainable,
            name="lora_A",
        )

        # Up-projection back to the wrapped layer's output width; zero
        # initialization makes the LoRA branch output zero at the start.
        self.B = keras.layers.Dense(
            units=self.out_shape,
            use_bias=False,
            kernel_initializer="zeros",
            trainable=trainable,
            name="lora_B",
        )

    def call(self, inputs):
        """Return the wrapped layer's output, plus the scaled LoRA term when trainable."""
        base_output = self.original_layer(inputs)
        if not self.trainable:
            # Non-trainable: only the wrapped layer contributes.
            return base_output

        low_rank_term = self.B(self.A(inputs)) * self._scale
        return base_output + low_rank_term
    
def get_lora_model(vit, rank, alpha):
    """Insert LoRA layers into a model in place and freeze everything else.

    For every layer in ``vit.layers`` that has an ``att`` attribute, the
    ``query_dense`` and ``value_dense`` attributes of ``att`` are replaced by
    trainable ``LoraLayer`` wrappers. Afterwards every leaf layer is made
    trainable only if it is named ``lora_A`` or ``lora_B``.

    Note: inside this function the ``vit`` parameter shadows the module-level
    ``vit`` import.

    Args:
        vit: Model to modify.
        rank: ``rank`` passed to each ``LoraLayer``.
        alpha: ``alpha`` passed to each ``LoraLayer``.

    Returns:
        The same ``vit`` object, modified in place.
    """
    # Wrap the query and value projections of every attention-bearing layer.
    for layer in vit.layers:
        if hasattr(layer, 'att'):
            attention = layer.att
            for projection in ('query_dense', 'value_dense'):
                wrapped = LoraLayer(
                    getattr(attention, projection),
                    rank=rank,
                    alpha=alpha,
                    trainable=True,
                )
                setattr(attention, projection, wrapped)

    # Only the LoRA Dense layers remain trainable among the leaves.
    lora_names = ("lora_A", "lora_B")
    for layer in vit._flatten_layers():
        # A leaf is a layer whose flattened layer list contains only itself.
        is_leaf = len(list(layer._flatten_layers())) == 1
        if is_leaf:
            layer.trainable = layer.name in lora_names
    return vit

def create_mpc_lora_vit(config, horizon, stride=1):
    """Create an MPC model using a LoRA-adapted copy of the module-level ``vitmodel``.

    Args:
        config (dict): Reads ``rank``, ``alpha``, ``stem_output_name`` and
            ``mpc_config`` (forwarded to ``MPCNetwork2`` as keyword arguments).
        horizon: Forwarded to ``MPCNetwork2``.
        stride: Forwarded to ``MPCNetwork2``.

    Returns:
        The constructed ``MPCNetwork2`` instance.
    """
    # Work on a clone that carries the same weights as the global ViT.
    vit_copy = keras.models.clone_model(vitmodel)
    vit_copy.set_weights(vitmodel.get_weights())
    lora_vit = get_lora_model(vit_copy, config['rank'], config['alpha'])

    # Stem: network input -> output of the configured stem layer.
    stem = keras.Model(
        inputs=lora_vit.input,
        outputs=lora_vit.get_layer(config['stem_output_name']).output,
    )

    # Head used as the loss block: encoder norm input -> network output.
    head = keras.Model(
        inputs=lora_vit.get_layer('Transformer/encoder_norm').input,
        outputs=lora_vit.output,
    )
    for head_layer in head.layers:
        head_layer.trainable = True

    # One sub-model per transformer block, keeping only its first output.
    blocks = [stem]
    blocks.extend(
        keras.Model(inputs=layer.input, outputs=layer.output[0])
        for layer in lora_vit.layers
        if isinstance(layer, vit.layers.TransformerBlock)
    )

    # Chain all sub-models to obtain the output tensor of the full network.
    network_input = lora_vit.input
    network_output = network_input
    for sub_model in blocks:
        network_output = sub_model(network_output)

    return MPCNetwork2(
        inputs=network_input,
        outputs=network_output,
        horizon=horizon,
        stride=stride,
        block_start=2,
        lossblocks=head,
        **config['mpc_config'],
    )



if __name__ == '__main__':
    # Read the experiment configuration.
    with open('./finetune_config.json') as config_file:
        config = json.load(config_file)

    # Expose only the GPU selected by config['gpu'] and let its memory grow
    # on demand.
    gpus = tf.config.list_physical_devices('GPU')
    use_index = config['gpu']
    if not gpus:
        print('no gpu!')
    else:
        try:
            tf.config.set_visible_devices(gpus[use_index], 'GPU')
            tf.config.experimental.set_memory_growth(gpus[use_index], True)
        except RuntimeError as err:
            # Visible devices must be set before GPUs have been initialized
            print(err)

    # Module-level globals read by load_dataset().
    data, ds_info = tfds.load(
        'cifar100',
        as_supervised=True,
        with_info=True,
    )
    BATCH_SIZE = 32

    if config['model_name'] == 'resnet50':
        # Module-level global read by create_mpc_resnet50().
        resnet50 = keras.applications.ResNet50(
            include_top=False,
            weights='imagenet',
            input_shape=(224, 224, 3),
            pooling='avg',
            classes=100,
        )
        horizon_stride = [[i, 1] for i in config['horizons']]
        tmpdf = train_model(create_mpc_resnet50, horizon_stride, config)

    elif config['model_name'] == 'vit_lora':
        image_size = 224
        # Module-level global read by create_mpc_lora_vit().
        vitmodel = vit.vit_b16(
            image_size=image_size,
            activation='softmax',
            pretrained=True,
            include_top=True,
            pretrained_top=False,
            classes=100,
        )
        horizon_stride = [[i, 1] for i in config['horizons']]
        tmpdf = train_model(create_mpc_lora_vit, horizon_stride, config)

    else:
        print('Unknown Model!')
        
    