import keras.losses
import keras.losses
import tensorflow as tf
import numpy as np
import keras
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import time
import tqdm
import pandas as pd
import tqdm.keras
import json
from MPCmodels import *
from mycallbacks import *


######## Get Data ##########

def generate_circle(n_x, w=np.pi, r=1, t_range=(-1, 1), disturbance=0.1):
    """
    Generate a noisy two-harmonic circle dataset.

    For every input value ``x`` the target row is
    ``[cos(w*x), sin(w*x), cos(2*w*x), sin(2*w*x)] * r * eps`` where ``eps`` is
    drawn from a normal distribution with mean 1 and standard deviation
    ``disturbance``. One ``eps`` is drawn per sample and shared by all four
    columns, i.e. the noise is multiplicative, not additive.

    Parameters:
    - n_x (int or array-like): Number of samples to draw uniformly from
      ``t_range``, or a 1-D sequence of input values to use as-is.
    - w (float, optional): Angular frequency of the circle. Default is np.pi.
    - r (float, optional): Radius of the circle. Default is 1.
    - t_range (sequence of two floats, optional): Lower and upper bound of the
      sampled inputs; only used when ``n_x`` is an integer. Default is (-1, 1).
    - disturbance (float, optional): Standard deviation of the multiplicative
      noise factor. Default is 0.1.

    Returns:
    - x (ndarray): Input values, shape (n, 1).
    - y (ndarray): Output values, shape (n, 4).
    """
    # Accept NumPy integer scalars as a sample count as well as Python ints.
    if isinstance(n_x, (int, np.integer)):
        n = int(n_x)
        x = np.random.rand(n) * (t_range[1] - t_range[0]) + t_range[0]
    else:
        # np.asarray leaves an ndarray untouched and lets lists/tuples work too.
        x = np.asarray(n_x)
        n = x.shape[0]
    eps = np.random.normal(loc=1, scale=disturbance, size=(n,))
    phase = w * x
    y = np.vstack([
        np.cos(phase) * r * eps,
        np.sin(phase) * r * eps,
        np.cos(2 * phase) * r * eps,
        np.sin(2 * phase) * r * eps,
    ]).T
    return x[:, None], y

def get_linear_data(d, n_samples=10000):
    """
    Generate linear data for training and validation.

    Inputs are Gaussian with standard deviation 1/sqrt(d). Targets are the
    inputs multiplied by one random (d, d) matrix, drawn with the same
    standard deviation and shared by the training and validation sets.

    Parameters:
    - d (int): The dimension of the data.
    - n_samples (int, optional): Number of rows in each of the training and
      validation sets. Default is 10000.

    Returns:
    - x (ndarray): Training input data of shape (n_samples, d).
    - y (ndarray): Training output data of shape (n_samples, d).
    - xval (ndarray): Validation input data of shape (n_samples, d).
    - yval (ndarray): Validation output data of shape (n_samples, d).
    """
    scale = 1 / np.sqrt(d)
    # Draw order (x, xval, truew) is kept so seeded runs stay reproducible.
    x = np.random.normal(scale=scale, size=(n_samples, d))
    xval = np.random.normal(scale=scale, size=(n_samples, d))
    truew = np.random.normal(scale=scale, size=(d, d))
    y, yval = x @ truew, xval @ truew
    return x, y, xval, yval

def get_cifar():
    """
    Loads and preprocesses the CIFAR-10 dataset.

    Only the training split returned by ``load_data`` is used: images are cast
    to float32 and divided by 255, labels are one-hot encoded with
    ``keras.utils.to_categorical``, and the last 5000 samples are held out for
    validation. The test split is not used.

    Returns:
        x: Training data features (images) excluding the last 5000 samples.
        y: Training data labels (one-hot encoded) excluding the last 5000 samples.
        xval: Validation data features (images) consisting of the last 5000 samples.
        yval: Validation data labels (one-hot encoded) consisting of the last 5000 samples.
    """
    # The test split is never returned, so it is not converted or encoded.
    (x_train, y_train), _ = tf.keras.datasets.cifar10.load_data()
    x_train = x_train.astype("float32") / 255.
    y_train = keras.utils.to_categorical(y_train)
    x, xval = x_train[:-5000], x_train[-5000:]
    y, yval = y_train[:-5000], y_train[-5000:]
    return x, y, xval, yval

######## Get Data ##########


######## Experiments #########

def generate_df_fix(index, config):
    """
    Code for fix horizon experiment for a given index and configuration and store in a CSV file.

    For every repeat, one set of initial weights is taken from a
    ``create_model(1, 1, config)`` model and reused for each
    ``(horizon, stride)`` pair in ``config['horizon_stride']``, so all pairs
    within a repeat start from the same weights. After every training run the
    accumulated results are rewritten to ``df_fix_horizon_<index>_<model>``.

    Args:
        index (int): The index value, used in the output file name.
        config (dict): A dictionary containing the configuration parameters.
            Keys read here: 'x_train', 'y_train', 'x_val', 'y_val',
            'batch_size', 'repeats', 'horizon_stride', 'epochs', 'model'.

    Returns:
        pandas.DataFrame: The generated DataFrame (training history, per-epoch
        time log, intermediate-loss columns 'loss_l1', 'loss_l2', ..., plus
        'iter', 'horizon' and 'stride').

    """
    df2 = pd.DataFrame()
    train_dataset = tf.data.Dataset.from_tensor_slices((config['x_train'], config['y_train'])).batch(config['batch_size'])
    val_dataset = tf.data.Dataset.from_tensor_slices((config['x_val'], config['y_val'])).batch(config['batch_size'])
    for iters in tqdm.trange(config['repeats']):
        # Reference model whose weights seed every (horizon, stride) run of this repeat
        model = create_model(1, 1, config)
        weights = model.get_weights()
        for horizon, stride in config['horizon_stride']:
            # Create a new model with the specified horizon and stride
            model = create_model(horizon, stride, config)

            # Set the weights of the model to the initial weights
            model.set_weights(weights)

            # Create callbacks to track time usage, loss, and gradient information
            timecb = get_time_callback()
            losscb = get_intermedium_loss_callback(config['x_val'], config['y_val'])
            select_cb = Select_horizon_Callback(train_dataset, mode='adjust_lr')

            # Train the model and collect information
            history = model.fit(train_dataset, validation_data=val_dataset,
                                epochs=config['epochs'], batch_size=config['batch_size'],
                                verbose=1, callbacks=[timecb, losscb, select_cb,
                                                      MyReduceLR(monitor='loss', factor=0.9, patience=0, min_lr=1e-5,)])
            history.history.update(dict(epoch=history.epoch, time=timecb.log))

            # Create a DataFrame with time usage, loss, and gradient information
            dfi = pd.DataFrame(history.history)
            # NOTE(review): assumes losscb.log is a 2-D array (one column per intermediate loss) - confirm
            loss_columns = ['loss_l'+str(i+1) for i in range(losscb.log.shape[1])]
            dfi = pd.concat([dfi, pd.DataFrame(losscb.log, columns=loss_columns)], axis=1)
            dfi['iter'] = iters
            dfi['horizon'] = horizon
            dfi['stride'] = stride

            # Concatenate the current DataFrame with the main DataFrame
            df2 = pd.concat([df2, dfi], ignore_index=True)

            # Save after every run so partial results survive an interrupted experiment
            df2.to_csv('df_fix_horizon_'+str(index)+f"_{config['model']}")

    # Return the accumulated results, as the documented contract states
    return df2

    
def generate_df_storage(index, config, eager=False):
    """
    Code for storage experiment for a given index and configuration and store in a CSV file.

    For every repeat and every ``(horizon, stride)`` pair in
    ``config['horizon_stride']`` a model is trained for 2 epochs with a memory
    callback attached. The callback's log and the wall-clock duration of the
    ``fit`` call are appended to a DataFrame that is rewritten to
    ``df_storage_<index>_<model>_eager`` or ``..._static`` after each run.

    Args:
        index (int): The index of the DataFrame, used in the output file name.
        config (dict): A dictionary containing configuration parameters.
            Keys read here: 'x_train', 'y_train', 'x_val', 'y_val',
            'batch_size', 'repeats', 'horizon_stride', 'model', and, when
            ``eager`` is true, 'loss' and 'metrics'.
        eager (bool, optional): Whether to run the model eagerly. Defaults to False.

    Returns:
        None
    """
    df = pd.DataFrame()
    train_dataset = tf.data.Dataset.from_tensor_slices((config['x_train'], config['y_train'])).batch(config['batch_size'])
    val_dataset = tf.data.Dataset.from_tensor_slices((config['x_val'], config['y_val'])).batch(config['batch_size'])
    df_name = f"df_storage_{index}_{config['model']}{'_eager' if eager else '_static'}"
    for _ in tqdm.trange(config['repeats']):
        # Reference model whose weights seed every (horizon, stride) run of this repeat
        model = create_model(1, 1, config)
        weights = model.get_weights()
        for horizon, stride in config['horizon_stride']:
            # Create a memory callback to track memory usage
            memory_cb = get_memory_callback(period=10)

            # Create a new model with the specified horizon and stride
            model = create_model(horizon, stride, config)

            # Set the weights of the model to the initial weights
            model.set_weights(weights)

            # Recompile with run_eagerly=True only when eager execution is requested
            if eager:
                model.compile(model.optimizer, config['loss'], metrics=config['metrics'], run_eagerly=True)

            # Start the timer right before training; the original read an
            # undefined `time0`, so the elapsed time could not be computed.
            time0 = time.perf_counter()

            # Train the model and collect memory usage and time information
            model.fit(train_dataset, epochs=2, verbose=1, validation_data=val_dataset, callbacks=[memory_cb])
            elapsed = time.perf_counter() - time0

            # Create a DataFrame with memory usage and time information
            # (the scalar elapsed time is repeated on every row of the memory log)
            dfi = pd.DataFrame(dict(total_size=memory_cb.log, time=elapsed))
            dfi['horizon'] = horizon
            dfi['stride'] = stride

            # Concatenate the current DataFrame with the main DataFrame
            df = pd.concat([df, dfi], ignore_index=True)

            # Save after every run so partial results survive an interrupted experiment
            df.to_csv(df_name)

def generate_df_grad(index, config):
    """
    Run the gradient experiment and write its results to a CSV file.

    Each repeat trains a fresh model with a ``Select_horizon_Callback`` in
    'test' mode and collects the DataFrame that callback exposes as ``.df``.
    The frames of all repeats are stacked, tagged with an 'iter' column, and
    rewritten to ``df_grad_<index>_<model>`` after every repeat.

    Args:
        index (int): The index value, used in the output file name.
        config (dict): A dictionary containing configuration parameters.
            Keys read here: 'x_train', 'y_train', 'x_val', 'y_val',
            'repeats', 'horizon_stride', 'epochs', 'model'.

    Returns:
        None
    """
    def _batched(features, targets):
        # Both datasets use a fixed batch size of 100
        return tf.data.Dataset.from_tensor_slices((features, targets)).batch(100)

    train_dataset = _batched(config['x_train'], config['y_train'])
    val_dataset = _batched(config['x_val'], config['y_val'])

    # Results of all repeats gathered so far
    results = pd.DataFrame()

    for repeat in range(config['repeats']):
        # Fresh model; its horizon is then set to the number of blocks it has
        model = create_model(1000, 1, config)
        model.horizon = len(model.blocks)

        # Callback that records the per-horizon information during training
        select_cb = Select_horizon_Callback(
            train_dataset,
            mode='test',
            period=5,
            init_select=True,
            horizon_stride=config['horizon_stride'],
        )
        reduce_lr_cb = MyReduceLR(monitor='loss', factor=0.9, patience=0, min_lr=1e-5)

        model.fit(
            train_dataset,
            epochs=config['epochs'],
            verbose=1,
            validation_data=val_dataset,
            callbacks=[reduce_lr_cb, select_cb],
        )

        # Tag this repeat's records and append them to the running results
        run_df = select_cb.df
        run_df['iter'] = repeat
        results = pd.concat([results, run_df], ignore_index=True)

        # Rewrite the CSV after every repeat so partial results are kept
        results.to_csv(f"df_grad_{index}_{config['model']}")
        
        
######## Experiments #########
            
            
if __name__ =='__main__':
    # Entry point: read ./experiment_config.json, pick the GPU, build the
    # dataset that matches the model type, then run the requested experiment.
    with open('./experiment_config.json') as j:
        config = json.load(j)

    # Reject an unknown experiment before any data is generated or downloaded;
    # previously a typo here made the script finish without running anything.
    if config['experiment'] not in ('fix', 'storage', 'grad'):
        raise ValueError(
            f"Unknown experiment {config['experiment']!r}: expected 'fix', 'storage' or 'grad'")

    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        # The 'gpu' key is only needed when a GPU is actually present
        use_index=config['gpu']
        # Restrict TensorFlow to only use the specified GPU
        try:
            tf.config.set_visible_devices(gpus[use_index], 'GPU')
            tf.config.experimental.set_memory_growth(gpus[use_index], True)
        except RuntimeError as e:
            # Visible devices must be set before GPUs have been initialized
            print(e)
    else:
        print('no gpu!')

    # Check the model type and set the corresponding configurations
    if 'linear' in config['model']:
        config['model_config']['activation']=None
        x,y,xval,yval=get_linear_data(config['d'])
        config.update({'loss':keras.losses.MSE,'metrics':[]})
    elif 'res' in config['model']:
        x,y=generate_circle(10000,t_range=[-2,2],w=np.pi,r=1,disturbance=0.03)
        xval,yval=generate_circle(10000,t_range=[-2,2],w=np.pi,r=1,disturbance=0.03)
        config.update({'loss':keras.losses.MSE,'metrics':[]})
    elif 'conv' in config['model']:
        x,y,xval,yval=get_cifar()
        config.update({'loss':keras.losses.categorical_crossentropy,'metrics':[]})
    else:
        # Without this branch the script failed later with a NameError on `x`
        raise ValueError(
            f"Unsupported model {config['model']!r}: expected a name containing 'linear', 'res' or 'conv'")

    config.update({'x_train':x,'y_train':y,'x_val':xval,'y_val':yval,})

    # Perform the experiment based on the specified experiment type
    if config['experiment']=='fix':
        generate_df_fix(config['index'],config)
    elif config['experiment']=='storage':
        # A missing 'eager' key falls back to the function's own default (False)
        generate_df_storage(config['index'],config,eager=config.get('eager', False))
    elif config['experiment']=='grad':
        generate_df_grad(config['index'],config)

