import pandas as pd
import copy
import numpy as np
import math
import os
import time
import tqdm

# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
#
os.environ["OMP_NUM_THREADS"] = "1"  # OpenMP
os.environ["MKL_NUM_THREADS"] = "1"  # Intel Math Kernel Library
os.environ["NUMEXPR_NUM_THREADS"] = "1"  # NumExpr
os.environ["OPENBLAS_NUM_THREADS"] = "1"  # OpenBLAS
import random
import ast
from sklearn.linear_model import LogisticRegression
import tensorflow as tf

from tensorflow.keras.layers import Input, Dense

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# print(f'version = {tf.__version__}')

# Device selection. With USE_GPU = False every CUDA device is hidden from
# TensorFlow so the script runs on CPU only.
USE_GPU = False
if USE_GPU:
    device_idx = 0
    gpus = tf.config.list_physical_devices('GPU')
    if device_idx < len(gpus):
        gpu_device = gpus[device_idx]
        # set_visible_devices() returns None, so there is no config object to
        # hand to a v1 Session; restricting visibility and enabling memory
        # growth is all that is needed.
        tf.config.experimental.set_visible_devices(gpu_device, 'GPU')
        tf.config.experimental.set_memory_growth(gpu_device, True)
    else:
        # Requested GPU is not present: fall back to CPU instead of failing
        # with an IndexError on gpus[device_idx].
        print(f'USE_GPU is set but GPU index {device_idx} is not available '
              f'({len(gpus)} GPU(s) found); running on CPU.')
        tf.config.experimental.set_visible_devices([], 'GPU')
else:
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


def repeat_Samples(samples_to_be_repeated, X_train, y_train):
    """Oversample a binary-labelled training set by repeating random rows.

    Rows are drawn with replacement through the global NumPy RNG
    (``np.random.choice``) and appended after the original rows.

    Args:
        samples_to_be_repeated: Requested number of extra rows. Only
            ``2 * (samples_to_be_repeated // 2)`` rows are added, so an odd
            request adds one row fewer than asked for.
        X_train: Feature array indexed by sample along axis 0.
        y_train: Label array aligned with ``X_train``; rows equal to 0 and 1
            form the two classes.

    Returns:
        ``(X_train, y_train)`` with the repeated rows concatenated at the end.
        When both classes are present, half of the added rows come from class
        0 followed by half from class 1. When only one of the two classes is
        present, all added rows come from that class (sampling from an empty
        class would otherwise raise ``ValueError``). When nothing can be
        added, the inputs are returned unchanged.
    """
    # Step 1: Separate indices by class
    class0_indices = np.where(y_train == 0)[0]
    class1_indices = np.where(y_train == 1)[0]

    # Step 2: Compute how many samples to repeat from each class
    repeat_per_class = samples_to_be_repeated // 2
    if repeat_per_class <= 0:
        return X_train, y_train

    # Step 3: Randomly sample with replacement
    if class0_indices.size and class1_indices.size:
        sampled_class0 = np.random.choice(class0_indices, repeat_per_class, replace=True)
        sampled_class1 = np.random.choice(class1_indices, repeat_per_class, replace=True)
        sampled_indices = np.concatenate([sampled_class0, sampled_class1])
    else:
        # Single-class task: draw the whole (even) quota from the class that
        # actually exists so the padded size matches the two-class case.
        available = class0_indices if class0_indices.size else class1_indices
        if available.size == 0:
            return X_train, y_train
        sampled_indices = np.random.choice(available, 2 * repeat_per_class, replace=True)

    # Step 4: Append the repeated rows to the training set
    X_train = np.concatenate((X_train, X_train[sampled_indices]), axis=0)
    y_train = np.concatenate((y_train, y_train[sampled_indices]), axis=0)

    return X_train, y_train


def data_preparation(tasks_list):
    """Load the train/test splits of every task and pad the training sets.

    Files are read from ``../Dataset/<DATASETNAME>/Task_Splits`` where the
    dataset name comes from the module-level ``datasetName``. For each task id
    the four files ``<id>_X_train.npy``, ``<id>_y_train.npy``,
    ``<id>_X_test.npy`` and ``<id>_y_test.npy`` are loaded.

    When more than one task is given, training sets shorter than the target
    size are extended with ``repeat_Samples``. The target is the largest
    training-set length, bumped by one if it is even.

    Args:
        tasks_list: Sized sequence of task ids.

    Returns:
        Dict with keys ``Landmine_<id>_X_train``, ``Landmine_<id>_y_train``,
        ``Landmine_<id>_X_test`` and ``Landmine_<id>_y_test`` for every task.
        Label arrays are reshaped to a single column (``reshape(-1, 1)``).
    """
    DataPath = f'../Dataset/{datasetName.upper()}/Task_Splits'

    # Read every split from disk exactly once; the training-set sizes needed
    # for the padding target are taken from the arrays already in memory.
    splits = []
    for task_id in tasks_list:
        prefix = f'{DataPath}/{task_id}'
        splits.append((
            task_id,
            np.load(f'{prefix}_X_train.npy'),
            np.load(f'{prefix}_y_train.npy'),
            np.load(f'{prefix}_X_test.npy'),
            np.load(f'{prefix}_y_test.npy'),
        ))

    pad_to_common_size = len(tasks_list) > 1
    if pad_to_common_size:
        max_size = max(X_train.shape[0] for _, X_train, _, _, _ in splits)
        # NOTE(review): the target is forced to be odd; the reason is not
        # evident from this file. repeat_Samples only ever adds an even number
        # of rows, so even-length training sets end one row short of max_size.
        if max_size % 2 == 0:
            max_size += 1

        print(f'max size = {max_size}')

    data_param_dict_for_specific_task = {}
    for task_id, X_train, y_train, X_test, y_test in splits:
        if pad_to_common_size:
            samples_to_be_repeated = max_size - len(X_train)
            if samples_to_be_repeated > 0:
                X_train, y_train = repeat_Samples(samples_to_be_repeated, X_train, y_train)

        # Labels are stored as column vectors.
        y_train = y_train.reshape(-1, 1)
        y_test = y_test.reshape(-1, 1)

        data_param_dict_for_specific_task[f'Landmine_{task_id}_X_train'] = X_train
        data_param_dict_for_specific_task[f'Landmine_{task_id}_y_train'] = y_train
        data_param_dict_for_specific_task[f'Landmine_{task_id}_X_test'] = X_test
        data_param_dict_for_specific_task[f'Landmine_{task_id}_y_test'] = y_test

    return data_param_dict_for_specific_task




def decay_lr(step, optimizer):
    """Halve ``optimizer.lr`` whenever ``step + 1`` is a multiple of 75.

    Args:
        step: Zero-based step counter.
        optimizer: Object exposing a readable and writable ``lr`` attribute.

    Returns:
        None. ``optimizer`` is modified in place.
    """
    completed_steps = step + 1
    if completed_steps % 75 != 0:
        return
    optimizer.lr = optimizer.lr / 2.


def estimation_func(shared_hyperparameters, tasks_list, data_param_dict_for_specific_task):
    """Estimate per-task test losses for a task group without training.

    The function builds a shared dense encoder plus one sigmoid decoder per
    task, loads previously saved weights, shifts the weights by a direction
    derived from stored gradients (via a logistic regression fitted on
    randomly labelled, sign-flipped gradients), and then evaluates the binary
    cross-entropy loss of every task on its test split.

    Besides its arguments it reads the module-level names ``batch_size``,
    ``datasetName`` and ``run``.

    Args:
        shared_hyperparameters: Mapping read for ``'shared_FF_Layers'`` (number
            of dense layers in the encoder), ``'shared_FF_Neurons'`` (indexed
            per layer for the layer width) and ``'learning_rate'``.
        tasks_list: Sequence of task ids (iterated and indexed by position).
        data_param_dict_for_specific_task: Dict holding, for every task id,
            the keys ``Landmine_<id>_X_train``, ``Landmine_<id>_y_train``,
            ``Landmine_<id>_X_test`` and ``Landmine_<id>_y_test``.

    Returns:
        ``(results_loss, time_taken)`` where ``results_loss`` maps each task id
        to a list with one test-loss value per evaluation iteration (five
        iterations) and ``time_taken`` is the elapsed wall-clock time in
        seconds measured from just before the weights are loaded.
    """

    train_data = []
    train_label = []
    test_data = []
    test_label = []

    # Collect the per-task arrays in tasks_list order; the position in these
    # lists is what ties a batch to its decoder further down.
    for task_id in tasks_list:
        train_data.append(data_param_dict_for_specific_task[f'Landmine_{task_id}_X_train'])
        train_label.append(data_param_dict_for_specific_task[f'Landmine_{task_id}_y_train'])

        test_data.append(data_param_dict_for_specific_task[f'Landmine_{task_id}_X_test'])
        test_label.append(data_param_dict_for_specific_task[f'Landmine_{task_id}_y_test'])

    class SharedEncoder(tf.keras.Model):
        """Stack of ReLU dense layers shared by all tasks."""

        def __init__(self):
            super(SharedEncoder, self).__init__()
            self.shared_layers = []
            # One Dense layer per configured depth, width taken from the
            # 'shared_FF_Neurons' entry at the same index.
            for h in range(shared_hyperparameters['shared_FF_Layers']):
                self.shared_layers.append(Dense(shared_hyperparameters['shared_FF_Neurons'][h], activation='relu'))

        def call(self, inputs):
            x = inputs
            for layer in self.shared_layers:
                x = layer(x)
            return x

    class TaskDecoder(tf.keras.Model):
        """Single sigmoid unit applied to the shared representation."""

        def __init__(self):
            super(TaskDecoder, self).__init__()
            self.output_layer = Dense(1, activation='sigmoid')

        def call(self, shared_representation):
            return self.output_layer(shared_representation)

    # Create instances of the Shared Encoder and Task Decoders
    shared_encoder = SharedEncoder()
    task_decoders = {task_id: TaskDecoder() for task_id in tasks_list}

    # NOTE(review): global_step, init_lr and optimizer are created but not
    # used anywhere else in this function (no training step is executed).
    global_step = tf.Variable(0, trainable=False)
    init_lr = shared_hyperparameters['learning_rate']
    optimizer = tf.keras.optimizers.SGD(init_lr, momentum=0.9, nesterov=False)
    # NOTE(review): the string below says "mean squared error", but the loss
    # actually constructed is binary cross-entropy.
    '''mean squared error'''
    loss_fn = tf.keras.losses.BinaryCrossentropy()

    # print(f'optimizer variables = {len(optimizer.variables())},')
    # print(f'how many trainable variables = {len(shared_encoder.trainable_variables)}')


    # @tf.function
    def test_step(x_batch_test, y_batch_test):
        """Run every task through encoder + its decoder in inference mode.

        Returns the summed loss, the list of per-task losses and the list of
        per-task predictions (all in tasks_list order).
        """
        shared_representations = [shared_encoder(input_data, training=False) for input_data in x_batch_test]
        predictions = [task_decoders[task_id](shared_rep, training=False) for task_id, shared_rep in
                       zip(tasks_list, shared_representations)]
        eval_losses = [loss_fn(y_true, y_pred) for y_true, y_pred in zip(y_batch_test, predictions)]
        eval_loss = tf.reduce_sum(eval_losses)
        return eval_loss, eval_losses, predictions


    '''new parts for estimation'''

    '''pass dummy data to get the gradients'''
    # Only the first batch is pushed through the models (note the `break`);
    # the outputs are discarded. Presumably this forward pass exists to create
    # the layer variables so that weights can be loaded and read below.
    for batch_idx in range(0, len(train_data[0]), batch_size):
        x_batch_train = [data[batch_idx:batch_idx + batch_size] for data in train_data]
        y_batch_train = [label[batch_idx:batch_idx + batch_size] for label in train_label]
        shared_representations = [shared_encoder(input_data, training=False) for input_data in x_batch_train]
        predictions = [task_decoders[task_id](shared_rep, training=False) for task_id, shared_rep in
                       zip(tasks_list, shared_representations)]
        break

    '''save best weights and model to a file'''
    # The timer starts here, so model construction above is not included in
    # the reported time.
    timeStart = time.time()
    model_base_dir = f'{datasetName}_model_weights'
    gradients_dir = f'{datasetName}_gradients_run_{run}'

    model_dir = f'{model_base_dir}/run_{run}'
    # Check if the files exist and load the weights for shared_encoder
    # A missing checkpoint is only reported; execution continues with the
    # weights the model currently holds.
    encoder_weight_path = os.path.join(model_dir, 'shared_encoder')
    if os.path.exists(encoder_weight_path + '.index'):
        shared_encoder.load_weights(encoder_weight_path)
    else:
        print(f'Encoder weights not found at {encoder_weight_path}')

    # Check if the files exist and load the weights for task_decoders
    for task_id in tasks_list:
        decoder_weight_path = os.path.join(model_dir, f'decoder_{task_id}')
        if os.path.exists(decoder_weight_path + '.index'):
            task_decoders[task_id].load_weights(decoder_weight_path)
        else:
            print(f'Decoder weights not found at {decoder_weight_path}')

    grad_params = []
    for params in shared_encoder.trainable_weights:
        grad_params.append(params)

    # Total number of scalar entries across the encoder's trainable weights.
    gradient_dim = 0
    for param in grad_params:
        gradient_dim += param.numpy().size
    # print(f"len(grad_params): {len(grad_params)} \t Gradient Dim: {gradient_dim}")

    # Random sign matrix with entries +/- 1/sqrt(project_dim), shape
    # (gradient_dim, project_dim), drawn from the global NumPy RNG.
    # NOTE(review): the matrix is generated here rather than loaded; presumably
    # it has to match the one used when the stored gradients were projected,
    # which depends on the RNG state at this point - confirm.
    project_dim = 200  # project_dim = 202 (200 after dimensionality reduction and 2 task-individual parameters)
    project_matrix = (2 * np.random.randint(2, size=(gradient_dim, project_dim)) - 1).astype(float)
    project_matrix *= 1 / np.sqrt(project_dim)
    # print("Project Dim: {}".format(project_dim))

    # Load gradients
    # One file per task, stacked along axis 0.
    # Assumes each file has project_dim columns (required by the matrix
    # product with clf.coef_ below) - TODO confirm.
    gradients = []
    for task_id in tasks_list:
        gradient_file = f"{gradients_dir}/{task_id}_train_gradients.npy"
        tmp_gradients = np.load(gradient_file)
        gradients.append(tmp_gradients)
    # print(f'fold = {fold}, shape of gradients {gradients[0].shape}')
    gradients = np.concatenate(gradients, axis=0)

    # randomly assign labels as 0 or 1
    # (label 1 with probability 0.7)
    labels = np.random.binomial(n=1, p=0.7, size=gradients.shape[0])
    # if fold ==0:
    #     print(f'shape of gradients and labels = {gradients.shape}, {labels.shape}')

    # reverse the gradients for the 0 labels
    # mask is +1 for label 1 and -1 for label 0, broadcast over the columns.
    mask = np.copy(labels)
    mask[labels == 0] = -1
    mask = mask.reshape(-1, 1)
    gradients = gradients * mask
    # NOTE(review): this 80/20 split is computed but unused; the classifier
    # below is fitted and scored on the full set.
    train_num = int(len(gradients) * 0.8)
    train_gradients, train_labels = gradients[:train_num], labels[:train_num]
    test_gradients, test_labels = gradients[train_num:], labels[train_num:]

    # train a logistic regression model

    clf = LogisticRegression(random_state=0, penalty='l2', C=1e-4)  #
    clf.fit(gradients, labels)
    print(f'score = {clf.score(gradients, labels)}')

    # # Train a logistic regression model on the training set
    # clf = LogisticRegression(random_state=0, penalty='l2', C=1e-4)
    # clf.fit(train_gradients, train_labels)
    #
    # # Evaluate on the test set
    # train_score = clf.score(train_gradients, train_labels)
    # test_score = clf.score(test_gradients, test_labels)
    # print(f'fold +{fold}, train_score = {train_score}, test_score = {test_score}')
    #
    # exit(0)

    ## %%
    # projection_matrix = np.load(f"./gradients/{args.dataset_key}_{args.model_key}_{args.preset_key}_{args.project_dim}/projection_matrix_{args.run}.npy")
    # Map the fitted coefficients back to a vector of length gradient_dim and
    # rescale it to L2 norm 2.
    proj_coef = clf.coef_.copy().flatten().reshape(-1, 1)
    coef = project_matrix @ proj_coef.flatten()
    # print("L2 norm", np.linalg.norm(coef))
    coef = coef * 2 / np.linalg.norm(coef)

    # print("L2 norm", np.linalg.norm(coef))

    def generate_state_dict(model, state_dict, coef, removing_keys=["pred_head", "bn"]):
        """Return a copy of ``state_dict`` with slices of ``coef`` added.

        Iterates ``model.weights`` in order. A weight whose name contains any
        of ``removing_keys`` is copied unchanged; every other weight receives
        the next ``param.size`` entries of ``coef`` (reshaped to the weight's
        shape) added to its stored value. ``removing_keys`` is never mutated.
        """
        # Convert coef to TensorFlow tensor if it's not already
        coef = tf.convert_to_tensor(coef, dtype=tf.float32)

        new_state_dict = {}
        # Read offset into coef; restarts at 0 on every call.
        cur_len = 0

        # Iterate over the model's trainable weights
        for param in model.weights:  # Only get the parameter (TensorFlow's structure)
            param_name = param.name

            # If the parameter matches removing keys, skip updating
            if any([rkey in param_name for rkey in removing_keys]):
                new_state_dict[param_name] = state_dict[
                    f'{param_name}'].copy()  # No `.numpy()` needed for dict access
            else:
                param_len = np.prod(param.shape)  # Get the number of elements in the param tensor

                # Reshape the portion of `coef` to the size of the param and update it
                new_param_value = state_dict[f'{param_name}'].copy() + \
                                  np.reshape(coef[cur_len:cur_len + param_len], param.shape)

                new_state_dict[param_name] = new_param_value
                cur_len += param_len

        return new_state_dict

    # Snapshot the current weights as NumPy arrays keyed by variable name.
    state_dict_encoder = {}
    for param in shared_encoder.weights:
        state_dict_encoder[f'{param.name}'] = param.numpy()

    state_dict_decoders = {}
    for task_id in tasks_list:
        state_dict_decoders[task_id] = {}
        for param in task_decoders[task_id].weights:
            state_dict_decoders[task_id][f'{param.name}'] = param.numpy()

    # print(f'len(state_dict_encoder) = {len(state_dict_encoder)}, len(state_dict_decoders) = {len(state_dict_decoders)}')
    '''how many param in new state dict'''

    # param_count only feeds the commented-out debug print below.
    param_count = 0
    for key in state_dict_encoder.keys():
        param_count += state_dict_encoder[key].size
    # print(f'param_count state_dict_encoder: {param_count}')

    # Update state dict for encoder
    new_state_dict_encoder = generate_state_dict(shared_encoder, state_dict_encoder, coef)
    # print(f'len(new_state_dict_encoder) = {len(new_state_dict_encoder)}')
    # NOTE(review): each decoder is shifted with the same `coef`, read from
    # offset 0, although `coef` was sized from the encoder's trainable weights
    # - confirm that perturbing the decoders this way is intended.
    new_state_dict_decoders = {}
    for task_id in tasks_list:
        new_state_dict_decoders[task_id] = generate_state_dict(task_decoders[task_id], state_dict_decoders[task_id],
                                                              coef)
        # print(f'len(new_state_dict_decoders) = {len(new_state_dict_decoders)}')

    pretrain_state_dict_encoder = state_dict_encoder
    pretrain_state_dict_decoders = state_dict_decoders
    finetuned_state_dict_encoder = new_state_dict_encoder
    finetuned_state_dict_decoders = new_state_dict_decoders

    # Write the shifted weights back, using the key order of the snapshot
    # dicts (which were filled in model.weights order).
    shared_encoder.set_weights([finetuned_state_dict_encoder[key] for key in pretrain_state_dict_encoder.keys()])
    for task_id in tasks_list:
        task_decoders[task_id].set_weights(
            [finetuned_state_dict_decoders[task_id][key] for key in pretrain_state_dict_decoders[task_id].keys()])

    results_loss = {task: [] for task in tasks_list}
    # NOTE(review): no weights are changed between these five iterations, so
    # each iteration evaluates the same models on the same test data.
    for estimate_run in range(5):
        # optimizer = tf.keras.optimizers.SGD(learning_rate=shared_hyperparameters['learning_rate'], momentum=0.9)
        # if name_suffix == 'Fine_Tuned':
        #     '''single estimate_run'''
        #     # batch_train_loss = {task: 0. for task in tasks_list}
        #     for batch_idx in range(0, len(train_data[0]), batch_size):
        #         x_batch_train = [data[batch_idx:batch_idx + batch_size] for data in train_data]
        #         y_batch_train = [label[batch_idx:batch_idx + batch_size] for label in train_label]
        #         train_loss, shared_weights, decoder_weights = train_step(x_batch_train, y_batch_train)
        #         # for task, loss in train_loss.items():
        #         #     batch_train_loss[task] += loss.numpy() / math.ceil(TRAIN_SIZE / batch_size)
        #     print(f'single estimate run ends here')
        '''evaluation'''
        test_loss, indiv_losses, y_pred = test_step(test_data, test_label)
        indiv_losses = [each_loss.numpy() for each_loss in indiv_losses]
        # Pair each loss with its task id by position.
        test_metrics_loss = {tasks_list[task_idx]: metric for task_idx, metric in enumerate(indiv_losses)}
        for key, value in test_metrics_loss.items():
            results_loss[key].append(value)

    time_taken = time.time() - timeStart
    return results_loss, time_taken




if __name__ == "__main__":
    # Driver: for every random-subset file and every run, estimate the loss of
    # each task group and checkpoint the accumulated results to CSV.
    #
    # The names `datasetName`, `batch_size` and `run` assigned below are
    # module-level globals that data_preparation() / estimation_func() read,
    # so they must stay at module scope under these exact names.
    import sys
    import itertools

    datasetName = 'Landmine'
    DataPath = f'../Dataset/{datasetName.upper()}/'

    Method_name = 'SimpleMTL'
    # group_type = 'test'
    group_type = 'RANDOM'

    ResultPath = 'Results'
    # Create the output directory up front; otherwise the first to_csv() call
    # fails only after a full (expensive) estimation pass has completed.
    os.makedirs(ResultPath, exist_ok=True)

    Task_InfoData = pd.read_csv(f'{DataPath}Task_Information_{datasetName}.csv', low_memory=False)
    TASKS = [str(task) for task in Task_InfoData['Task_Name']]

    Arch_Name = 'Arch_1'
    if Arch_Name == 'Arch_1':
        initial_shared_architecture = {'shared_FF_Layers': 2, 'shared_FF_Neurons': [64, 32],
                                       'learning_rate': 1e-3}
        num_epochs = 1000
        batch_size = 64
        MAX_PATIENCE = 20

    if group_type == 'ALL':
        # Only the group definition is set up for this mode; no estimation
        # loop is attached to it in this script.
        TASK_Group = [tuple(TASKS)]
        name_suffix = 'ALL'

    if group_type == 'RANDOM':
        for subsetNum in range(5):
            random_subsets = pd.read_csv(f'../RESULTS/{datasetName}_Random_Subsets_for_GRADTAE_run_{subsetNum}.csv', low_memory=False)
            print(len(random_subsets))
            # Each cell holds the textual form of a Python literal describing
            # one task group.
            TASK_Group = [ast.literal_eval(tg) for tg in random_subsets['Random_Subsets']]
            name_suffix = 'Fast_Estimation_w_Time'

            RUNS = [1, 2, 3, 4, 5]
            for run in RUNS:
                # Seed every RNG with the run number.
                seed_value = run
                tf.random.set_seed(seed_value)
                np.random.seed(seed_value)
                random.seed(seed_value)

                # Per-run accumulators; one entry per processed task group.
                Task_group = []
                Total_Loss = []
                Individual_Task_Score = []
                Time_taken_for_Fast_training = []
                # Never populated; kept only so the progress print below keeps
                # its original output.
                Individual_Error_Rate = []

                for count, task_group in enumerate(TASK_Group):
                    print(f'Initial Training for {datasetName}-partition {count}, {task_group}')

                    # Never populated; kept only so the log line below keeps
                    # its original output.
                    group_score = {}

                    data_param_dict_for_specific_task = data_preparation(task_group)
                    all_scores, time_taken = estimation_func(initial_shared_architecture, task_group,
                                                             data_param_dict_for_specific_task)
                    print(all_scores)

                    # Average the repeated evaluations of every task, then sum
                    # over tasks to get the group's total loss.
                    task_loss_mean = {task: np.mean(losses) for task, losses in all_scores.items()}
                    tot_loss = np.sum(list(task_loss_mean.values()))

                    print(f'tot_loss = {tot_loss}')
                    print(f'group_score = {group_score}')

                    Task_group.append(task_group)
                    Total_Loss.append(tot_loss)
                    Individual_Task_Score.append(copy.deepcopy(task_loss_mean))
                    Time_taken_for_Fast_training.append(time_taken / 60)  # seconds -> minutes

                    print(len(Total_Loss), len(Task_group), len(Individual_Task_Score), len(Individual_Error_Rate))

                    # Rewrite the CSV after every group so partial results
                    # survive an interrupted run.
                    temp_res = pd.DataFrame({'Total_Loss': Total_Loss,
                                             'Task_group': Task_group,
                                             'Individual_Task_Score': Individual_Task_Score,
                                             'Time_taken_for_Fast_training': Time_taken_for_Fast_training
                                             })

                    temp_res.to_csv(f'{ResultPath}/{datasetName}_{name_suffix}_run_{run}_SGD_Arch_{Arch_Name}_subsetNum_{subsetNum}.csv',
                                    index=False)