import os

# Thread caps and the TensorFlow log level have to be in the environment
# before numpy / pandas / tensorflow are imported: the native libraries behind
# them pick these settings up when they are loaded, so assigning them after
# the imports is too late to reliably take effect.
os.environ["OMP_NUM_THREADS"] = "1"  # OpenMP
os.environ["MKL_NUM_THREADS"] = "1"  # Intel Math Kernel Library
os.environ["NUMEXPR_NUM_THREADS"] = "1"  # NumExpr
os.environ["OPENBLAS_NUM_THREADS"] = "1"  # OpenBLAS
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # TensorFlow C++ log verbosity

import ast
import copy
import math
import random
import time

import numpy as np
import pandas as pd

import tensorflow as tf
from tensorflow.keras.layers import Input, Dense

# print(f'version = {tf.__version__}')

# Flip to True to run on the first visible GPU; otherwise CUDA devices are
# hidden from TensorFlow and everything runs on the CPU.
USE_GPU = False
if USE_GPU:
    device_idx = 0
    gpus = tf.config.list_physical_devices('GPU')
    gpu_device = gpus[device_idx]  # raises IndexError when no GPU is present
    # NOTE(review): set_visible_devices is documented as returning nothing, so
    # `core_config` is presumably None and the v1 session below is created
    # with its default config -- confirm whether the set_session call is
    # still needed.
    core_config = tf.config.experimental.set_visible_devices(gpu_device, 'GPU')
    tf.config.experimental.set_memory_growth(gpu_device, True)
    tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=core_config))
else:
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

def repeat_Samples(samples_to_be_repeated, X_train, y_train):
    """Pad a training set by re-drawing existing rows, balanced across classes.

    Rows whose label is ``< 1`` form class 0 and rows whose label is ``>= 1``
    form class 1. ``samples_to_be_repeated // 2`` rows are drawn with
    replacement from each class (class 0 first, then class 1) and appended
    after the original rows. When ``samples_to_be_repeated`` is odd the
    leftover single sample is simply not added.

    Args:
        samples_to_be_repeated: Requested number of extra rows.
        X_train: Feature array, indexed along axis 0.
        y_train: Label array aligned with ``X_train`` along axis 0.

    Returns:
        Tuple ``(X_train, y_train)`` with the re-drawn rows appended.
    """
    per_class = samples_to_be_repeated // 2

    # Draw order matters for reproducibility under a fixed numpy seed:
    # class 0 is sampled before class 1.
    extra_rows = np.concatenate([
        np.random.choice(np.nonzero(class_mask)[0], per_class, replace=True)
        for class_mask in (y_train < 1, y_train >= 1)
    ])

    padded_X = np.concatenate((X_train, X_train[extra_rows]), axis=0)
    padded_y = np.concatenate((y_train, y_train[extra_rows]), axis=0)
    return padded_X, padded_y

def data_preparation(tasks_list):
    """Load the train/test splits of every task and pad the training sets.

    Files are read from ``../Dataset/<DATASETNAME>/Task_Splits`` where
    ``datasetName`` is a module-level global that must be set before calling.
    For each ``task_id`` the files ``{task_id}_X_train.npy``,
    ``{task_id}_y_train.npy``, ``{task_id}_X_test.npy`` and
    ``{task_id}_y_test.npy`` are loaded.

    When more than one task is given, the largest training-set length is
    taken (bumped by one if it is even) and every smaller training set is
    padded towards it with ``repeat_Samples``. Because ``repeat_Samples``
    adds an even number of rows, padded lengths can still differ by one
    between tasks.

    Args:
        tasks_list: Sequence of task identifiers used in the file names.

    Returns:
        Dict with keys ``Molecule_{task_id}_X_train`` / ``_y_train`` /
        ``_X_test`` / ``_y_test``; the label arrays are reshaped to
        ``(-1, 1)``.
    """
    DataPath = f'../Dataset/{datasetName.upper()}/Task_Splits'
    split_names = ('X_train', 'y_train', 'X_test', 'y_test')

    # Read every file exactly once. The previous version loaded each X_train
    # twice: once only to learn its length and once more to use it.
    loaded = {}
    for task_id in tasks_list:
        loaded[task_id] = tuple(
            np.load(f'{DataPath}/{task_id}_{split}.npy') for split in split_names
        )

    multi_task = len(tasks_list) > 1
    if multi_task:
        max_size = max(loaded[task_id][0].shape[0] for task_id in tasks_list)
        if max_size % 2 == 0:
            max_size += 1

        print(f'max size = {max_size}')

    data_param_dict_for_specific_task = {}
    for task_id in tasks_list:
        X_train, y_train, X_test, y_test = loaded[task_id]

        if multi_task:
            samples_to_be_repeated = max_size - len(X_train)
            if samples_to_be_repeated > 0:
                X_train, y_train = repeat_Samples(samples_to_be_repeated, X_train, y_train)

        # Labels are stored as column vectors to line up with the (N, 1)
        # output of the sigmoid decoders.
        data_param_dict_for_specific_task[f'Molecule_{task_id}_X_train'] = X_train
        data_param_dict_for_specific_task[f'Molecule_{task_id}_y_train'] = y_train.reshape(-1, 1)
        data_param_dict_for_specific_task[f'Molecule_{task_id}_X_test'] = X_test
        data_param_dict_for_specific_task[f'Molecule_{task_id}_y_test'] = y_test.reshape(-1, 1)

    return data_param_dict_for_specific_task


def SplitLabels(Target):
    """Return the first entry of every row of ``Target`` as a float column.

    Args:
        Target: Indexable sequence of rows; ``Target[i][0]`` is read for
            every ``i`` in ``range(len(Target))``.

    Returns:
        ``numpy`` array of shape ``(len(Target), 1)`` and default float dtype.
    """
    column = np.zeros((len(Target), 1))
    for row_idx in range(column.shape[0]):
        first_entry = Target[row_idx][0]
        column[row_idx] = first_entry
    return column


def Splitting_Values(Labels):
    """Flatten one level of nesting: concatenate the items of every row.

    Args:
        Labels: Iterable of iterables.

    Returns:
        A new list holding the inner items in row-major order.
    """
    return [value for row in Labels for value in row]



def decay_lr(step, optimizer):
    """Halve ``optimizer.lr`` on every 75th step.

    Args:
        step: Zero-based step counter; the decay fires when ``step + 1`` is
            a multiple of 75.
        optimizer: Object exposing a readable and writable ``lr`` attribute.
    """
    decay_every = 75
    if (step + 1) % decay_every != 0:
        return
    optimizer.lr = optimizer.lr / 2.

def permute_list_limit(lst, max_len=2):
    """Return every combination of tasks of size 1..max_len as joined names.

    Despite the name, order-insensitive combinations (not permutations) are
    generated, in the order produced by ``itertools.combinations``: all
    groups of size 1 first, then size 2, and so on.

    Args:
        lst: Iterable of task names (strings).
        max_len: Largest group size to generate.

    Returns:
        List of strings, each the group's task names joined with ``"_"``.
    """
    # Imported locally: the module only imports itertools inside its
    # ``__main__`` guard, so the name is missing when this file is imported.
    import itertools

    task_lst = list(lst)
    print(f'task_lst = {task_lst}')
    rtn = [
        "_".join(group)
        for group_len in range(1, max_len + 1)
        for group in itertools.combinations(task_lst, group_len)
    ]
    print(f'rtn = {rtn}')
    return rtn



def estimation_func(shared_hyperparameters, molecule_list, data_param_dict_for_specific_task):
    """Estimate per-task test loss/accuracy of a task group without training.

    The function rebuilds the shared-encoder / per-task-decoder model, loads
    previously saved weights, shifts the weights by a vector derived from a
    logistic regression fitted on stored gradients, and then evaluates the
    shifted model on the test split of every task.

    Besides its arguments it reads three module-level globals that must be
    defined by the caller: ``batch_size``, ``datasetName`` and ``run``.

    Args:
        shared_hyperparameters: Dict with keys ``'shared_FF_Layers'`` (number
            of shared Dense layers), ``'shared_FF_Neurons'`` (units per
            layer) and ``'learning_rate'``.
        molecule_list: Task identifiers; one decoder is created per entry.
        data_param_dict_for_specific_task: Dict holding
            ``Molecule_{id}_X_train`` / ``_y_train`` / ``_X_test`` /
            ``_y_test`` for every id in ``molecule_list``.

    Returns:
        Tuple ``(results_loss, results_acc, time_taken)``:
        ``results_loss[task]`` is a list with one loss value per evaluation
        pass, ``results_acc[task]`` is a list with one single-element list
        ``[accuracy]`` per evaluation pass, and ``time_taken`` is the
        wall-clock time in seconds measured from just before the weights are
        loaded until the evaluation finishes.
    """
    # Gather the per-task arrays in the order given by molecule_list.
    train_data = []
    train_label = []

    test_data = []
    test_label = []

    for molecule in molecule_list:
        train_data.append(data_param_dict_for_specific_task[f'Molecule_{molecule}_X_train'])
        train_label.append(data_param_dict_for_specific_task[f'Molecule_{molecule}_y_train'])


        test_data.append(data_param_dict_for_specific_task[f'Molecule_{molecule}_X_test'])
        test_label.append(data_param_dict_for_specific_task[f'Molecule_{molecule}_y_test'])

    class SharedEncoder(tf.keras.Model):
        """Stack of ReLU Dense layers shared by all tasks."""

        def __init__(self):
            super(SharedEncoder, self).__init__()
            self.shared_layers = []
            # One Dense layer per configured depth, sized from the hyperparameters.
            for h in range(shared_hyperparameters['shared_FF_Layers']):
                self.shared_layers.append(Dense(shared_hyperparameters['shared_FF_Neurons'][h], activation='relu'))

        def call(self, inputs):
            """Apply the shared layers in sequence."""
            x = inputs
            for layer in self.shared_layers:
                x = layer(x)
            return x

    class TaskDecoder(tf.keras.Model):
        """Task-specific head: a single sigmoid unit."""

        def __init__(self):
            super(TaskDecoder, self).__init__()
            self.output_layer = Dense(1, activation='sigmoid')

        def call(self, shared_representation):
            """Map the shared representation to a probability."""
            return self.output_layer(shared_representation)

    # Create instances of the Shared Encoder and Task Decoders
    shared_encoder = SharedEncoder()
    task_decoders = {molecule: TaskDecoder() for molecule in molecule_list}

    # NOTE(review): `optimizer` is created but never used in this function.
    # optimizer = tf.keras.optimizers.Adam(learning_rate=shared_hyperparameters['learning_rate'])
    optimizer = tf.keras.optimizers.SGD(learning_rate=shared_hyperparameters['learning_rate'], momentum=0.9)
    loss_fn = tf.keras.losses.BinaryCrossentropy()


    # @tf.function
    def test_step(x_batch_test, y_batch_test):
        """Evaluate every task on its own inputs.

        Both arguments are lists aligned with ``molecule_list``. Returns the
        summed loss, the list of per-task losses and the list of per-task
        predictions.
        """
        shared_representations = [shared_encoder(input_data, training=False) for input_data in x_batch_test]
        predictions = [task_decoders[molecule](shared_rep, training=False) for molecule, shared_rep in
                       zip(molecule_list, shared_representations)]
        eval_losses = [loss_fn(y_true, y_pred) for y_true, y_pred in zip(y_batch_test, predictions)]
        eval_loss = tf.reduce_sum(eval_losses)
        return eval_loss, eval_losses, predictions


    '''new parts for estimation'''

    '''pass dummy data to get the gradients'''
    # Only the first batch is pushed through the models (note the `break`);
    # the outputs are discarded. Presumably this is here so the Keras layers
    # create their variables before weights are loaded and read below.
    for batch_idx in range(0, len(train_data[0]), batch_size):
        x_batch_train = [data[batch_idx:batch_idx + batch_size] for data in train_data]
        y_batch_train = [label[batch_idx:batch_idx + batch_size] for label in train_label]
        shared_representations = [shared_encoder(input_data, training=False) for input_data in x_batch_train]
        predictions = [task_decoders[molecule](shared_rep, training=False) for molecule, shared_rep in
                       zip(molecule_list, shared_representations)]
        break

    '''load best weights and model from a file'''
    # The reported time starts here, i.e. it excludes the data gathering and
    # the warm-up forward pass above.
    timeStart = time.time()
    model_base_dir = f'{datasetName}_model_weights'
    gradients_dir = f'{datasetName}_gradients_run_{run}'

    model_dir = f'{model_base_dir}/run_{run}'
    # Check if the files exist and load the weights for shared_encoder
    # (a missing checkpoint only prints a message; execution continues with
    # the weights the model currently has).
    encoder_weight_path = os.path.join(model_dir, 'shared_encoder')
    if os.path.exists(encoder_weight_path + '.index'):
        shared_encoder.load_weights(encoder_weight_path)
    else:
        print(f'Encoder weights not found at {encoder_weight_path}')

    # Check if the files exist and load the weights for task_decoders
    for molecule in molecule_list:
        decoder_weight_path = os.path.join(model_dir, f'decoder_{molecule}')
        if os.path.exists(decoder_weight_path + '.index'):
            task_decoders[molecule].load_weights(decoder_weight_path)
        else:
            print(f'Decoder weights not found at {decoder_weight_path}')

    # Only the shared encoder's trainable weights define the gradient dimension.
    grad_params = []
    for params in shared_encoder.trainable_weights:
        grad_params.append(params)

    gradient_dim = 0
    for param in grad_params:
        gradient_dim += param.numpy().size
    # print(f"len(grad_params): {len(grad_params)} \t Gradient Dim: {gradient_dim}")

    # Random sign matrix with entries +/- 1/sqrt(project_dim), drawn from the
    # global numpy RNG.
    # NOTE(review): the matrix is regenerated here rather than loaded; it only
    # matches the projection used when the gradients were saved if the RNG
    # state is the same at both points -- confirm.
    project_dim = 200  # project_dim = 202 (200 after dimensionality reduction and 2 task-individual parameters)
    project_matrix = (2 * np.random.randint(2, size=(gradient_dim, project_dim)) - 1).astype(float)
    project_matrix *= 1 / np.sqrt(project_dim)
    # print("Project Dim: {}".format(project_dim))


    # Load gradients (one .npy file per task) and stack them along axis 0.
    gradients = []
    for molecule in molecule_list:
        gradient_file = f"{gradients_dir}/{molecule}_train_gradients.npy"
        tmp_gradients = np.load(gradient_file)
        gradients.append(tmp_gradients)
    # print(f'fold = {fold}, shape of gradients {gradients[0].shape}')
    gradients = np.concatenate(gradients, axis=0)

    # randomly assign labels as 0 or 1 (label 1 with probability 0.7)
    labels = np.random.binomial(n=1, p=0.7, size=gradients.shape[0])
    # if fold ==0:
    #     print(f'shape of gradients and labels = {gradients.shape}, {labels.shape}')

    # reverse the gradients for the 0 labels
    mask = np.copy(labels)
    mask[labels == 0] = -1
    mask = mask.reshape(-1, 1)
    gradients = gradients * mask
    # NOTE(review): the 80/20 split below is computed but unused; the
    # classifier is fitted and scored on the full set.
    train_num = int(len(gradients) * 0.8)
    train_gradients, train_labels = gradients[:train_num], labels[:train_num]
    test_gradients, test_labels = gradients[train_num:], labels[train_num:]

    # train a logistic regression model
    from sklearn.linear_model import LogisticRegression
    clf = LogisticRegression(random_state=0, penalty='l2', C=1e-4)  # small C => strong L2 regularisation
    clf.fit(gradients, labels)
    print(f'score = {clf.score(gradients, labels)}')

    # # Train a logistic regression model on the training set
    # clf = LogisticRegression(random_state=0, penalty='l2', C=1e-4)
    # clf.fit(train_gradients, train_labels)
    #
    # # Evaluate on the test set
    # train_score = clf.score(train_gradients, train_labels)
    # test_score = clf.score(test_gradients, test_labels)
    # print(f'fold +{fold}, train_score = {train_score}, test_score = {test_score}')
    #
    # exit(0)

    ## %%
    # projection_matrix = np.load(f"./gradients/{args.dataset_key}_{args.model_key}_{args.preset_key}_{args.project_dim}/projection_matrix_{args.run}.npy")
    # Map the fitted coefficients through the projection matrix to a vector of
    # length gradient_dim, then rescale it to L2 norm 2.
    # Assumes the stored gradients have project_dim columns, otherwise the
    # matrix product below fails -- TODO confirm.
    proj_coef = clf.coef_.copy().flatten().reshape(-1, 1)
    coef = project_matrix @ proj_coef.flatten()
    # print("L2 norm", np.linalg.norm(coef))
    coef = coef * 2 / np.linalg.norm(coef)

    # print("L2 norm", np.linalg.norm(coef))

    def generate_state_dict(model, state_dict, coef, removing_keys=["pred_head", "bn"]):
        """Return ``{weight name: value}`` with slices of ``coef`` added.

        Walks ``model.weights`` in order. A weight whose name contains one of
        ``removing_keys`` is copied unchanged; every other weight has the
        next ``prod(shape)`` entries of ``coef`` (starting at offset 0)
        reshaped to its shape and added to the value from ``state_dict``.
        """
        # Convert coef to TensorFlow tensor if it's not already
        coef = tf.convert_to_tensor(coef, dtype=tf.float32)

        new_state_dict = {}
        cur_len = 0  # running offset into coef

        # Iterate over all of the model's weights (model.weights, not only
        # the trainable ones)
        for param in model.weights:  # Only get the parameter (TensorFlow's structure)
            param_name = param.name

            # If the parameter matches removing keys, skip updating
            if any([rkey in param_name for rkey in removing_keys]):
                new_state_dict[param_name] = state_dict[
                    f'{param_name}'].copy()  # No `.numpy()` needed for dict access
            else:
                param_len = np.prod(param.shape)  # Get the number of elements in the param tensor

                # Reshape the portion of `coef` to the size of the param and update it
                new_param_value = state_dict[f'{param_name}'].copy() + \
                                  np.reshape(coef[cur_len:cur_len + param_len], param.shape)

                new_state_dict[param_name] = new_param_value
                cur_len += param_len

        return new_state_dict

    # Snapshot the current (loaded) weights as numpy arrays keyed by name.
    state_dict_encoder = {}
    for param in shared_encoder.weights:
        state_dict_encoder[f'{param.name}'] = param.numpy()

    state_dict_decoders = {}
    for molecule in molecule_list:
        state_dict_decoders[molecule] = {}
        for param in task_decoders[molecule].weights:
            state_dict_decoders[molecule][f'{param.name}'] = param.numpy()

    # print(f'len(state_dict_encoder) = {len(state_dict_encoder)}, len(state_dict_decoders) = {len(state_dict_decoders)}')
    '''how many param in new state dict'''

    # param_count is only referenced by the commented-out print below.
    param_count = 0
    for key in state_dict_encoder.keys():
        param_count += state_dict_encoder[key].size
    # print(f'param_count state_dict_encoder: {param_count}')

    # Update state dict for encoder
    new_state_dict_encoder = generate_state_dict(shared_encoder, state_dict_encoder, coef)
    # print(f'len(new_state_dict_encoder) = {len(new_state_dict_encoder)}')
    # NOTE(review): each decoder is passed the same `coef` and starts at
    # offset 0, so unless its weight names match removing_keys the decoder
    # weights are shifted by the leading entries of the encoder-sized vector
    # -- confirm this is intended.
    new_state_dict_decoders = {}
    for molecule in molecule_list:
        new_state_dict_decoders[molecule] = generate_state_dict(task_decoders[molecule], state_dict_decoders[molecule],
                                                                coef)
        # print(f'len(new_state_dict_decoders) = {len(new_state_dict_decoders)}')

    pretrain_state_dict_encoder = state_dict_encoder
    pretrain_state_dict_decoders = state_dict_decoders
    finetuned_state_dict_encoder = new_state_dict_encoder
    finetuned_state_dict_decoders = new_state_dict_decoders
    # print(f'pretrain_state_dict_restower keys = {pretrain_state_dict_decoders.keys()}')
    # print(f'finetuned_state_dict_ResTowers keys = {finetuned_state_dict_decoders.keys()}')

    # Write the shifted weights back. The pretrain dicts were filled by
    # iterating `.weights`, so their key order is the order used for set_weights.
    shared_encoder.set_weights([finetuned_state_dict_encoder[key] for key in pretrain_state_dict_encoder.keys()])
    for molecule in molecule_list:
        task_decoders[molecule].set_weights(
            [finetuned_state_dict_decoders[molecule][key] for key in pretrain_state_dict_decoders[molecule].keys()])

    results_acc = {task: [] for task in molecule_list}
    results_loss = {task: [] for task in molecule_list}
    # NOTE(review): nothing modifies the weights inside this loop, so the five
    # evaluation passes look like they produce the same numbers -- confirm
    # whether the repetition is still needed.
    for estimate_run in range(5):
        # optimizer = tf.keras.optimizers.SGD(learning_rate=shared_hyperparameters['learning_rate'], momentum=0.9)
        # if name_suffix == 'Fine_Tuned':
        #     '''single estimate_run'''
        #     # batch_train_loss = {task: 0. for task in molecule_list}
        #     for batch_idx in range(0, len(train_data[0]), batch_size):
        #         x_batch_train = [data[batch_idx:batch_idx + batch_size] for data in train_data]
        #         y_batch_train = [label[batch_idx:batch_idx + batch_size] for label in train_label]
        #         train_loss, shared_weights, decoder_weights = train_step(x_batch_train, y_batch_train)
        #         # for task, loss in train_loss.items():
        #         #     batch_train_loss[task] += loss.numpy() / math.ceil(TRAIN_SIZE / batch_size)
        #     print(f'single estimate run ends here')
        '''evaluation'''
        test_loss, indiv_losses, y_pred = test_step(test_data, test_label)
        indiv_losses = [each_loss.numpy() for each_loss in indiv_losses]
        # print(f'shape of y_pred = {y_pred[0].shape}, len(y_pred) = {len(y_pred)}')
        '''get individual error_rate'''
        indi_acc = {task: [] for task in molecule_list}
        for task_idx, task_pred in enumerate(y_pred):
            task_pred = task_pred.numpy()
            task_test_label = test_label[task_idx]
            # A prediction counts as class 1 only at probability >= 0.75.
            predicted_val = (task_pred >= 0.75).astype(int)
            error_rate = np.mean(predicted_val != task_test_label)
            indi_acc[molecule_list[task_idx]].append(1 - error_rate)

        # exit(0)
        # y_pred = [pred.numpy() for pred in y_pred]
        # y_test = [label for label in test_label]
        # y_pred = np.concatenate(y_pred, axis=0)
        # y_test = np.concatenate(y_test, axis=0)
        # # print(f'y_pred = {y_pred[:50]}, y_test = {y_test[:50]}')
        #
        # predicted_val = (y_pred >= 0.75).astype(int)
        # error_rate = np.mean(predicted_val != y_test)

        test_metrics_loss = {molecule_list[task_idx]: metric for task_idx, metric in enumerate(indiv_losses)}
        for key, value in test_metrics_loss.items():
            results_loss[key].append(value)
        # Each value appended here is the single-element list built above, so
        # results_acc[task] ends up as a list of one-element lists.
        test_metrics_accu = {task_idx: metric for task_idx, metric in indi_acc.items()}
        for key, value in test_metrics_accu.items():
            results_acc[key].append(value)

    # for key, values in results_acc.items():
    #     print("fold = {} Test Acc {}:  {:1.4f} +/- {:1.4f}".format(fold, key, np.mean(values), np.std(values)))
    # for key, values in results_loss.items():
    #     print("Test Loss {}:  {:1.4f} +/- {:1.4f}".format(key, np.mean(values), np.std(values)))
    time_taken = time.time() - timeStart
    return results_loss, results_acc,time_taken




if __name__ == "__main__":
    # Script entry point: for every random-subset file and every run, estimate
    # the per-task loss / error rate of each task group with estimation_func()
    # and write the accumulated results to a CSV after every group.
    #
    # NOTE: `datasetName`, `batch_size` and `run` must stay module-level names;
    # data_preparation() and estimation_func() read them as globals. Do not
    # move this block into a function without passing them explicitly.
    import itertools  # kept at module scope: permute_list_limit() uses it
    import sys

    datasetName = 'Chemical'
    DataPath = f'../Dataset/{datasetName.upper()}/'

    Method_name = 'SimpleMTL'
    group_type = 'RANDOM'

    ResultPath = 'Results'
    # Create the output directory up front so the first to_csv() call does not
    # fail on a fresh checkout where it does not exist yet.
    os.makedirs(ResultPath, exist_ok=True)

    # ChemicalData = pd.read_csv(f'{DataPath}ChemicalData_All.csv', low_memory=False)
    ChemicalData = pd.read_csv(f'{DataPath}Task_Information_Chemical.csv', low_memory=False)
    # Sort tasks by dataset size, largest first.
    ChemicalData = ChemicalData.sort_values(by=['Dataset_Size'], ascending=False)

    TASKS = list(ChemicalData['Molecule'])
    print(f'TASKS = {TASKS}')
    TASKS = [str(task) for task in TASKS]

    num_folds = 5

    Arch_Name = 'Arch_1'
    if Arch_Name == 'Arch_1':
        initial_shared_architecture = {'shared_FF_Layers': 2, 'shared_FF_Neurons': [32, 16],
                                       'learning_rate': 0.001}
        num_epochs = 1000
        batch_size = 264

    if group_type == 'RANDOM':
        for subsetNum in range(5):
            random_subsets = pd.read_csv(f'../RESULTS/{datasetName}_Random_Subsets_for_GRADTAE_run_{subsetNum}.csv', low_memory=False)
            print(len(random_subsets))
            # Each cell holds the textual form of a Python list of task ids.
            TASK_Group = [ast.literal_eval(tg) for tg in random_subsets['Random_Subsets']]
            name_suffix = 'Fast_Estimation_w_Time'

            for run in [1, 2, 3, 4, 5, 6]:
                # Seed every RNG with the run number for reproducibility.
                seed_value = run
                tf.random.set_seed(seed_value)
                np.random.seed(seed_value)
                random.seed(seed_value)

                # Per-run accumulators; one entry is appended per task group.
                Task_group = []
                Total_Loss = []
                Individual_Error_Rate = []
                Individual_Task_Score = []
                Time_taken_for_Fast_training = []

                for count, task_group in enumerate(TASK_Group):
                    print(f'Initial Training for {datasetName}-partition {count}, {task_group}')

                    data_param_dict_for_specific_task = data_preparation(task_group)
                    all_scores = estimation_func(initial_shared_architecture, task_group, data_param_dict_for_specific_task)
                    print(all_scores)
                    results_loss, results_acc, time_taken = all_scores

                    # Average over the evaluation passes returned per task.
                    task_loss_mean = {task: np.mean(loss) for task, loss in results_loss.items()}
                    # Despite the name, this holds the error rate (1 - accuracy).
                    task_acc_mean = {task: 1 - np.mean(acc) for task, acc in results_acc.items()}

                    tot_loss = np.sum(list(task_loss_mean.values()))
                    print(f'tot_loss = {tot_loss}')
                    print(f'task_loss_mean = {task_loss_mean}')
                    print(f'task_acc_mean = {task_acc_mean}')

                    print(f'tot_loss = {tot_loss}')

                    Task_group.append(task_group)
                    Total_Loss.append(tot_loss)
                    Individual_Error_Rate.append(task_acc_mean.copy())
                    Individual_Task_Score.append(copy.deepcopy(task_loss_mean))
                    Time_taken_for_Fast_training.append(time_taken / 60)  # seconds -> minutes

                    print(len(Total_Loss), len(Task_group), len(Individual_Task_Score), len(Individual_Error_Rate))

                    # Rewrite the CSV after every group so partial results
                    # survive an interrupted run.
                    temp_res = pd.DataFrame({'Total_Loss': Total_Loss,
                                             'Task_group': Task_group,
                                             'Individual_Task_Score': Individual_Task_Score,
                                             'Individual_Error_Rate': Individual_Error_Rate,
                                             'Time_taken_for_Fast_training': Time_taken_for_Fast_training,
                                             })

                    temp_res.to_csv(f'{ResultPath}/{datasetName}_{name_suffix}_run_{run}_SGD_Arch_{Arch_Name}_subsetNum_{subsetNum}.csv',
                                    index=False)




