import pandas as pd
import copy
import numpy as np
import os
import time

# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
#
os.environ["OMP_NUM_THREADS"] = "1"  # OpenMP
os.environ["MKL_NUM_THREADS"] = "1"  # Intel Math Kernel Library
os.environ["NUMEXPR_NUM_THREADS"] = "1"  # NumExpr
os.environ["OPENBLAS_NUM_THREADS"] = "1"  # OpenBLAS
import random
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense
from sklearn.model_selection import train_test_split
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from sklearn.linear_model import LogisticRegression

# print(f'version = {tf.__version__}')

# Device selection. With USE_GPU=False every CUDA device is hidden from
# TensorFlow so the whole script runs on the CPU.
USE_GPU = False
if USE_GPU:
    device_idx = 0
    gpus = tf.config.list_physical_devices('GPU')
    if device_idx >= len(gpus):
        # Fail with an explicit message instead of an opaque IndexError.
        raise RuntimeError(
            f'USE_GPU is True but GPU index {device_idx} is not available '
            f'({len(gpus)} GPU(s) detected)')
    gpu_device = gpus[device_idx]
    # Both settings must be applied before TensorFlow initialises the GPU
    # runtime. set_visible_devices returns None, so there is no session
    # config to build; TF2 eager models need no tf.compat.v1 Session.
    tf.config.set_visible_devices(gpu_device, 'GPU')
    tf.config.experimental.set_memory_growth(gpu_device, True)
else:
    # NOTE(review): set after `import tensorflow`; this relies on TF not
    # having enumerated its devices yet at this point.
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


def repeat_Samples(samples_to_be_repeated, X_train, y_train):
    """Append ``samples_to_be_repeated`` rows re-drawn (with replacement) from the data.

    The same randomly chosen row indices are used for features and labels,
    so the appended pairs stay aligned. The original rows come first,
    unchanged.

    Args:
        samples_to_be_repeated: Number of extra rows to draw.
        X_train: Feature array; rows are indexed along axis 0.
        y_train: Label array aligned with ``X_train`` along axis 0.

    Returns:
        ``(X_train, y_train)`` extended by ``samples_to_be_repeated`` rows.
    """
    row_count = X_train.shape[0]
    picked_rows = np.random.choice(row_count, samples_to_be_repeated)

    extra_features = np.take(X_train, picked_rows, axis=0)
    extra_labels = np.take(y_train, picked_rows, axis=0)

    padded_features = np.concatenate([X_train, extra_features], axis=0)
    padded_labels = np.concatenate([y_train, extra_labels], axis=0)
    return padded_features, padded_labels


def data_preparation(tasks_list, dataset_name=None, data_root='../Dataset'):
    """Load each task's train/test split and equalise training-set sizes.

    For every task id, reads ``{id}_X_train.npy``, ``{id}_y_train.npy``,
    ``{id}_X_test.npy`` and ``{id}_y_test.npy`` from
    ``{data_root}/{DATASET_NAME}/Task_Splits``. When more than one task is
    given, every training set shorter than the target size is padded with
    rows re-drawn (with replacement, via ``repeat_Samples``) from itself.
    The target size is the largest training-set length, bumped by one if it
    is even.

    Args:
        tasks_list: Task identifiers; interpolated into the file names.
        dataset_name: Dataset folder name (upper-cased for the path).
            Defaults to the module-level global ``datasetName``.
        data_root: Directory containing the per-dataset folders.

    Returns:
        dict with keys ``School_{id}_X_train``, ``School_{id}_y_train``,
        ``School_{id}_X_test`` and ``School_{id}_y_test`` for every task id.
    """
    if dataset_name is None:
        dataset_name = datasetName  # module-level global set by the script entry point
    DataPath = f'{data_root}/{dataset_name.upper()}/Task_Splits'
    tasks_list = list(tasks_list)

    # Read every split once and reuse it, instead of re-reading X_train
    # just to measure its length.
    splits = {
        task_id: tuple(np.load(f'{DataPath}/{task_id}_{part}.npy')
                       for part in ('X_train', 'y_train', 'X_test', 'y_test'))
        for task_id in tasks_list
    }

    pad_to_common_size = len(tasks_list) > 1
    if pad_to_common_size:
        max_size = max(X_train.shape[0] for X_train, _, _, _ in splits.values())
        if max_size % 2 == 0:
            # The target size is forced to be odd; the reason is not visible
            # in this file.
            max_size += 1
        print(f'max size = {max_size}')

    data_param_dict_for_specific_task = {}
    for task_id in tasks_list:
        X_train, y_train, X_test, y_test = splits[task_id]

        if pad_to_common_size:
            samples_to_be_repeated = max_size - X_train.shape[0]
            if samples_to_be_repeated > 0:
                X_train, y_train = repeat_Samples(samples_to_be_repeated, X_train, y_train)

        data_param_dict_for_specific_task[f'School_{task_id}_X_train'] = X_train
        data_param_dict_for_specific_task[f'School_{task_id}_y_train'] = y_train
        data_param_dict_for_specific_task[f'School_{task_id}_X_test'] = X_test
        data_param_dict_for_specific_task[f'School_{task_id}_y_test'] = y_test

    return data_param_dict_for_specific_task

def decay_lr(step, optimizer, period=75, factor=2.):
    """Divide the optimizer's learning rate by ``factor`` every ``period`` steps.

    The decay fires when ``step + 1`` is a multiple of ``period``, i.e. for a
    0-based step counter after every ``period`` completed steps.

    Args:
        step: 0-based step index.
        optimizer: Object exposing a settable ``learning_rate`` attribute
            (e.g. a Keras optimizer with a non-schedule learning rate).
        period: Number of steps between decays (default 75).
        factor: Divisor applied at each decay (default 2, i.e. halving).
    """
    if (step + 1) % period == 0:
        # `learning_rate` is the canonical Keras attribute; the `lr` alias
        # is deprecated and not available on newer optimizers.
        optimizer.learning_rate = optimizer.learning_rate / factor
        # print('Decreasing the learning rate by 1/2. New Learning Rate: {}'.format(optimizer.lr))

def estimation_func(shared_hyperparameters, tasks_list, data_param_dict_for_specific_task):
    """Estimate per-task test losses of a perturbed multitask model from saved gradients.

    Steps, in order:
      1. Build a shared Dense encoder and one Dense(1, tanh) decoder per task.
      2. Load weights from ``{datasetName}_model_weights/run_{run}`` when the
         checkpoint files exist (otherwise a message is printed and the
         freshly initialised weights are kept).
      3. Load per-task gradients from
         ``{datasetName}_gradients_run_{run}/{task}_train_gradients.npy``,
         sign-flip rows with randomly drawn 0/1 labels, and fit an
         L2-regularised logistic regression on them.
      4. Map the regression coefficients back to parameter space with a
         random +/-1/sqrt(project_dim) matrix, rescale the result to L2 norm 2,
         and add it to the loaded weights.
      5. Evaluate the perturbed model's MSE on every task's test set.

    Reads the module-level globals ``batch_size``, ``datasetName`` and ``run``
    (set in the ``__main__`` section). Draws from ``np.random`` (projection
    matrix, labels), so results depend on the global NumPy seed.

    Args:
        shared_hyperparameters: dict with ``'shared_FF_Layers'`` (number of
            shared Dense layers), ``'shared_FF_Neurons'`` (units per layer)
            and ``'learning_rate'`` (only used for an optimizer that is never
            stepped here).
        tasks_list: task ids; must match the keys in
            ``data_param_dict_for_specific_task``.
        data_param_dict_for_specific_task: dict as returned by
            ``data_preparation`` (``School_{task}_X_train`` etc.).

    Returns:
        ``(results_loss, time_taken)``: ``results_loss`` maps each task id to
        a list of 5 test-MSE values (the weights are not changed between the
        5 evaluations, so the values are expected to coincide);
        ``time_taken`` is wall-clock seconds from weight loading to the end
        of evaluation.
    """

    train_data = []
    train_label = []
    test_data = []
    test_label = []

    # Collect each task's arrays in tasks_list order; the i-th entry of every
    # list belongs to tasks_list[i].
    for sch_id in tasks_list:
        X_train = data_param_dict_for_specific_task[f'School_{sch_id}_X_train']
        y_train = data_param_dict_for_specific_task[f'School_{sch_id}_y_train']



        train_data.append(X_train)
        train_label.append(y_train)


        test_data.append(data_param_dict_for_specific_task[f'School_{sch_id}_X_test'])
        test_label.append(data_param_dict_for_specific_task[f'School_{sch_id}_y_test'])


    class SharedEncoder(tf.keras.Model):
        # Stack of ReLU Dense layers shared by all tasks; sizes come from
        # shared_hyperparameters (captured from the enclosing scope).
        def __init__(self):
            super(SharedEncoder, self).__init__()
            self.shared_layers = []
            for h in range(shared_hyperparameters['shared_FF_Layers']):
                self.shared_layers.append(Dense(shared_hyperparameters['shared_FF_Neurons'][h], activation='relu'))

        def call(self, inputs):
            x = inputs
            for layer in self.shared_layers:
                x = layer(x)
            return x

    class TaskDecoder(tf.keras.Model):
        # Per-task head: a single Dense unit with tanh activation.
        def __init__(self):
            super(TaskDecoder, self).__init__()
            self.output_layer = Dense(1, activation='tanh')

        def call(self, shared_representation):
            return self.output_layer(shared_representation)

    # Create instances of the Shared Encoder and Task Decoders
    shared_encoder = SharedEncoder()
    task_decoders = {sch_id: TaskDecoder() for sch_id in tasks_list}

    # NOTE(review): global_step, init_lr and optimizer are never used below.
    global_step = tf.Variable(0, trainable=False)
    init_lr = shared_hyperparameters['learning_rate']
    optimizer = tf.keras.optimizers.SGD(init_lr, momentum=0.9, nesterov=False)
    '''mean squared error'''
    loss_fn = tf.keras.losses.MeanSquaredError()

    def test_step(x_batch_test, y_batch_test):
        # Inference-mode forward pass for every task at once. x_batch_test /
        # y_batch_test are lists aligned with tasks_list. Returns the summed
        # loss, the per-task MSE list and the per-task predictions.
        shared_representations = [shared_encoder(input_data, training=False) for input_data in x_batch_test]
        predictions = [task_decoders[sch_id](shared_rep, training=False) for sch_id, shared_rep in
                       zip(tasks_list, shared_representations)]
        eval_losses = [loss_fn(y_true, y_pred) for y_true, y_pred in zip(y_batch_test, predictions)]
        eval_loss = tf.reduce_sum(eval_losses)
        return eval_loss, eval_losses, predictions


    '''new parts for estimation'''

    '''pass dummy data to get the gradients'''
    # Despite the note above, no gradients are computed here. One forward
    # pass on the first batch (size = global `batch_size`) builds the layers'
    # variables: Dense layers created without an input shape have no weights
    # until first called. load_weights/trainable_weights below need them.
    # y_batch_train and predictions are not used afterwards.
    for batch_idx in range(0, len(train_data[0]), batch_size):
        x_batch_train = [data[batch_idx:batch_idx + batch_size] for data in train_data]
        y_batch_train = [label[batch_idx:batch_idx + batch_size] for label in train_label]
        shared_representations = [shared_encoder(input_data, training=False) for input_data in x_batch_train]
        predictions = [task_decoders[sch_id](shared_rep, training=False) for sch_id, shared_rep in
                       zip(tasks_list, shared_representations)]
        break

    timeStart = time.time()
    '''load best weights and model from a file'''

    model_base_dir = f'{datasetName}_model_weights'
    gradients_dir = f'{datasetName}_gradients_run_{run}'

    model_dir = f'{model_base_dir}/run_{run}'
    # Check if the files exist and load the weights for shared_encoder.
    # Missing checkpoints are only reported; the model then keeps its random
    # initial weights and the run continues.
    encoder_weight_path = os.path.join(model_dir, 'shared_encoder')
    if os.path.exists(encoder_weight_path + '.index'):
        shared_encoder.load_weights(encoder_weight_path)
    else:
        print(f'Encoder weights not found at {encoder_weight_path}')

    # Check if the files exist and load the weights for task_decoders
    for sch_id in tasks_list:
        decoder_weight_path = os.path.join(model_dir, f'decoder_{sch_id}')
        if os.path.exists(decoder_weight_path + '.index'):
            task_decoders[sch_id].load_weights(decoder_weight_path)
        else:
            print(f'Decoder weights not found at {decoder_weight_path}')

    # Only the shared encoder's trainable weights define the gradient space.
    grad_params = []
    for params in shared_encoder.trainable_weights:
        grad_params.append(params)

    # Total number of scalar encoder parameters (rows of project_matrix).
    gradient_dim = 0
    for param in grad_params:
        gradient_dim += param.numpy().size
    # print(f"len(grad_params): {len(grad_params)} \t Gradient Dim: {gradient_dim}")

    # Random sign matrix of shape (gradient_dim, project_dim) scaled by
    # 1/sqrt(project_dim).
    # NOTE(review): drawn fresh from np.random on every call. Unless it equals
    # the matrix that projected the saved gradients (cf. the commented-out
    # np.load of projection_matrix below), coef ends up in a different basis.
    # Confirm.
    project_dim = 200  # project_dim = 202 (200 after dimensionality reduction and 2 task-individual parameters)
    project_matrix = (2 * np.random.randint(2, size=(gradient_dim, project_dim)) - 1).astype(float)
    project_matrix *= 1 / np.sqrt(project_dim)
    # print("Project Dim: {}".format(project_dim))

    # Load gradients and stack all tasks' rows into one array.
    # Assumes each file holds 2-D arrays with project_dim columns (needed for
    # project_matrix @ coef below). TODO confirm against the gradient dump.
    gradients = []
    for sch_id in tasks_list:
        gradient_file = f"{gradients_dir}/{sch_id}_train_gradients.npy"
        tmp_gradients = np.load(gradient_file)
        gradients.append(tmp_gradients)
    # print(f'fold = {fold}, shape of gradients {gradients[0].shape}')
    gradients = np.concatenate(gradients, axis=0)

    # # randomly assign labels as 0 or 1
    # labels = np.random.binomial(n=1, p=0.7, size=gradients.shape[0])
    # mask = np.copy(labels)
    # mask[labels == 0] = -1
    # mask = mask.reshape(-1, 1)
    # gradients = gradients * mask
    # train_num = int(len(gradients) * 0.8)
    # train_gradients, train_labels = gradients[:train_num], labels[:train_num]
    # test_gradients, test_labels = gradients[train_num:], labels[train_num:]
    #
    # # train a logistic regression model
    #
    # clf = LogisticRegression(random_state=0, penalty='l2', C=1e-4)  #
    # clf.fit(gradients, labels)
    # print(f'score = {clf.score(gradients, labels)}')



    # Example: randomly assign labels but ensure both classes exist
    def generate_labels(n, p=0.7):
        # Redraw Bernoulli(p) labels until both 0 and 1 occur; LogisticRegression
        # needs at least two classes. Loops forever if n < 2.
        while True:
            labels = np.random.binomial(n=1, p=p, size=n)
            if np.any(labels == 0) and np.any(labels == 1):
                return labels

    labels = generate_labels(gradients.shape[0])

    # create mask: +1 for label 1, -1 for label 0; each gradient row is
    # multiplied by its sign.
    mask = np.copy(labels)
    mask[labels == 0] = -1
    mask = mask.reshape(-1, 1)
    gradients_masked = gradients * mask

    # train/test split
    # NOTE(review): train_*/test_* are never used; the classifier below is fit
    # and scored on ALL rows, so the printed score is training accuracy.
    train_num = int(len(gradients) * 0.8)
    train_gradients, train_labels = gradients_masked[:train_num], labels[:train_num]
    test_gradients, test_labels = gradients_masked[train_num:], labels[train_num:]

    # train logistic regression
    clf = LogisticRegression(random_state=0, penalty='l2', C=1e-4)
    clf.fit(gradients_masked, labels)
    print(f'Score = {clf.score(gradients_masked, labels)}')

    ## %%
    # projection_matrix = np.load(f"./gradients/{args.dataset_key}_{args.model_key}_{args.preset_key}_{args.project_dim}/projection_matrix_{args.run}.npy")
    # Lift the project_dim-sized coefficient vector back to a gradient_dim
    # vector and rescale it to L2 norm 2.
    proj_coef = clf.coef_.copy().flatten().reshape(-1, 1)
    coef = project_matrix @ proj_coef.flatten()
    # print("L2 norm", np.linalg.norm(coef))
    coef = coef * 2 / np.linalg.norm(coef)

    # print("L2 norm", np.linalg.norm(coef))

    def generate_state_dict(model, state_dict, coef, removing_keys=["pred_head", "bn"]):
        """Return a copy of ``state_dict`` with consecutive slices of ``coef`` added.

        Walks ``model.weights`` in order. A weight whose name contains any of
        ``removing_keys`` is copied unchanged and consumes no ``coef`` entries.
        Every other weight gets the next ``param.size`` entries of ``coef``,
        reshaped to its shape, added to it. ``removing_keys`` is only read,
        never mutated, so the list default is not shared-state-unsafe here.
        """
        # Convert coef to TensorFlow tensor if it's not already
        coef = tf.convert_to_tensor(coef, dtype=tf.float32)

        new_state_dict = {}
        cur_len = 0

        # Iterate over ALL of the model's weights (model.weights, which also
        # includes any non-trainable ones), in model order.
        for param in model.weights:  # Only get the parameter (TensorFlow's structure)
            param_name = param.name

            # If the parameter matches removing keys, skip updating
            # NOTE(review): "pred_head"/"bn" do not appear to match the Keras
            # Dense weight names used in this file, so nothing is skipped.
            if any([rkey in param_name for rkey in removing_keys]):
                new_state_dict[param_name] = state_dict[
                    f'{param_name}'].copy()  # No `.numpy()` needed for dict access
            else:
                param_len = np.prod(param.shape)  # Get the number of elements in the param tensor

                # Reshape the portion of `coef` to the size of the param and update it
                new_param_value = state_dict[f'{param_name}'].copy() + \
                                  np.reshape(coef[cur_len:cur_len + param_len], param.shape)

                new_state_dict[param_name] = new_param_value
                cur_len += param_len

        return new_state_dict

    # Snapshot current (loaded) weights as numpy arrays keyed by weight name.
    # NOTE(review): assumes param.name is unique within each model (true for
    # "layer/kernel:0"-style names; newer Keras may return bare "kernel",
    # which would collide). Confirm for the installed Keras version.
    state_dict_encoder = {}
    for param in shared_encoder.weights:
        state_dict_encoder[f'{param.name}'] = param.numpy()

    state_dict_decoders = {}
    for sch_id in tasks_list:
        state_dict_decoders[sch_id] = {}
        for param in task_decoders[sch_id].weights:
            state_dict_decoders[sch_id][f'{param.name}'] = param.numpy()

    # print(f'len(state_dict_encoder) = {len(state_dict_encoder)}, len(state_dict_decoders) = {len(state_dict_decoders)}')
    '''how many param in new state dict'''

    # NOTE(review): param_count is computed but only used in a commented print.
    param_count = 0
    for key in state_dict_encoder.keys():
        param_count += state_dict_encoder[key].size
    # print(f'param_count state_dict_encoder: {param_count}')

    # Update state dict for encoder
    new_state_dict_encoder = generate_state_dict(shared_encoder, state_dict_encoder, coef)
    # print(f'len(new_state_dict_encoder) = {len(new_state_dict_encoder)}')
    # NOTE(review): each decoder is perturbed with the leading entries of the
    # same coef, which was derived from encoder gradients only (grad_params),
    # so the decoder offsets are taken from encoder coordinates. Confirm this
    # is intended.
    new_state_dict_decoders = {}
    for sch_id in tasks_list:
        new_state_dict_decoders[sch_id] = generate_state_dict(task_decoders[sch_id], state_dict_decoders[sch_id],
                                                                coef)
        # print(f'len(new_state_dict_decoders) = {len(new_state_dict_decoders)}')

    pretrain_state_dict_encoder = state_dict_encoder
    pretrain_state_dict_decoders = state_dict_decoders
    finetuned_state_dict_encoder = new_state_dict_encoder
    finetuned_state_dict_decoders = new_state_dict_decoders
    # print(f'pretrain_state_dict_restower keys = {pretrain_state_dict_decoders.keys()}')
    # print(f'finetuned_state_dict_ResTowers keys = {finetuned_state_dict_decoders.keys()}')

    # Write the perturbed weights back, in the same key order they were read
    # from model.weights (dict insertion order).
    shared_encoder.set_weights([finetuned_state_dict_encoder[key] for key in pretrain_state_dict_encoder.keys()])
    for sch_id in tasks_list:
        task_decoders[sch_id].set_weights(
            [finetuned_state_dict_decoders[sch_id][key] for key in pretrain_state_dict_decoders[sch_id].keys()])

    # Evaluate the perturbed model 5 times on the full test sets; no training
    # happens in between (the training code below is commented out).
    results_loss = {task: [] for task in tasks_list}
    for estimate_run in range(5):
        # optimizer = tf.keras.optimizers.SGD(learning_rate=shared_hyperparameters['learning_rate'], momentum=0.9)
        # if name_suffix == 'Fine_Tuned':
        #     '''single estimate_run'''
        #     # batch_train_loss = {task: 0. for task in tasks_list}
        #     for batch_idx in range(0, len(train_data[0]), batch_size):
        #         x_batch_train = [data[batch_idx:batch_idx + batch_size] for data in train_data]
        #         y_batch_train = [label[batch_idx:batch_idx + batch_size] for label in train_label]
        #         train_loss, shared_weights, decoder_weights = train_step(x_batch_train, y_batch_train)
        #         # for task, loss in train_loss.items():
        #         #     batch_train_loss[task] += loss.numpy() / math.ceil(TRAIN_SIZE / batch_size)
        #     print(f'single estimate run ends here')
        '''evaluation'''
        test_loss, indiv_losses, y_pred = test_step(test_data, test_label)
        indiv_losses = [each_loss.numpy() for each_loss in indiv_losses]


        test_metrics_loss = {tasks_list[task_idx]: metric for task_idx, metric in enumerate(indiv_losses)}
        for key, value in test_metrics_loss.items():
            results_loss[key].append(value)

    # for key, values in results_loss.items():
    #     print("Test Loss {}:  {:1.4f} +/- {:1.4f}".format(key, np.mean(values), np.std(values)))
    time_taken = time.time() - timeStart
    return results_loss, time_taken


if __name__ == "__main__":
    # Script entry point: for each random task-subset file and each seed,
    # estimate per-group test losses with estimation_func and write one CSV
    # per (run, subset file), rewritten after every group as a checkpoint.
    # NOTE: datasetName, batch_size and run must stay module-level globals
    # here, because data_preparation/estimation_func read them.
    datasetName = 'School'
    DataPath = f'../Dataset/{datasetName.upper()}/'

    Method_name = 'SimpleMTL'
    # group_type = 'test'
    group_type = 'RANDOM'

    ResultPath = 'Results'
    # DataFrame.to_csv does not create missing directories.
    os.makedirs(ResultPath, exist_ok=True)

    SchoolData = pd.read_csv(f'{DataPath}Task_Information_School.csv', low_memory=False)
    TASKS = [str(task) for task in SchoolData['Task_Name']]

    Arch_Name = 'Arch_1'
    if Arch_Name == 'Arch_1':
        num_epochs = 1000
        batch_size = 64
        initial_shared_architecture = {'shared_FF_Layers': 3,
                                       'shared_FF_Neurons': [20, 10, 32],
                                       'learning_rate': 0.005}
        MAX_PATIENCE = 50

    if group_type == 'ALL':
        TASK_Group = [tuple(TASKS)]
        name_suffix = 'ALL'

    import ast
    if group_type == 'RANDOM':
        for subsetNum in range(3):
            random_subsets = pd.read_csv(f'../RESULTS/{datasetName}_Random_Subsets_for_GRADTAE_run_{subsetNum}.csv', low_memory=False)
            print(len(random_subsets))
            # Each cell holds a Python literal (e.g. a tuple of task ids).
            TASK_Group = [ast.literal_eval(tg) for tg in random_subsets['Random_Subsets']]
            name_suffix = 'Fast_Estimation_w_Time'

            for run in [1, 2, 3, 4, 5]:
                # Re-seed every RNG so each run is reproducible on its own.
                seed_value = run
                tf.random.set_seed(seed_value)
                np.random.seed(seed_value)
                random.seed(seed_value)

                Task_group = []
                Total_Loss = []
                Individual_Task_Score = []
                Time_taken_for_Fast_training = []

                for count, task_group in enumerate(TASK_Group):
                    print(f'Initial Training for {datasetName}-partition {count}, {task_group}')

                    data_param_dict_for_specific_task = data_preparation(task_group)
                    all_scores, time_taken = estimation_func(initial_shared_architecture, task_group, data_param_dict_for_specific_task)
                    print(all_scores)

                    # Mean test loss per task over the evaluation repeats,
                    # summed across the group.
                    task_loss_mean = {task: np.mean(loss) for task, loss in all_scores.items()}
                    tot_loss = np.sum(list(task_loss_mean.values()))
                    print(f'tot_loss = {tot_loss}')

                    Task_group.append(task_group)
                    Total_Loss.append(tot_loss)
                    Individual_Task_Score.append(copy.deepcopy(task_loss_mean))
                    Time_taken_for_Fast_training.append(time_taken / 60)  # minutes

                    print(len(Total_Loss), len(Task_group), len(Individual_Task_Score))

                    temp_res = pd.DataFrame({'Total_Loss': Total_Loss,
                                             'Task_group': Task_group,
                                             'Individual_Task_Score': Individual_Task_Score,
                                             'Time_taken_for_Fast_training': Time_taken_for_Fast_training,
                                             })
                    temp_res.to_csv(f'{ResultPath}/{datasetName}_{name_suffix}_run_{run}_SGD_Arch_{Arch_Name}_subsetNum_{subsetNum}.csv',
                                    index=False)



