import os

# Thread-count limits must be in the environment *before* numpy/pandas (and the
# BLAS/OpenMP runtimes they load) are imported; e.g. OpenBLAS reads
# OPENBLAS_NUM_THREADS when the library is first loaded, so setting it after
# `import numpy` has no effect.
os.environ["OMP_NUM_THREADS"] = "1"  # OpenMP
os.environ["MKL_NUM_THREADS"] = "1"  # Intel Math Kernel Library
os.environ["NUMEXPR_NUM_THREADS"] = "1"  # NumExpr
os.environ["OPENBLAS_NUM_THREADS"] = "1"  # OpenBLAS
# Reduce TensorFlow's C++ log noise; read by TF when it is imported.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Set to True to train on the GPU selected by `device_idx` below; otherwise all
# CUDA devices are hidden so TensorFlow runs on CPU only.
USE_GPU = False
if not USE_GPU:
    # Hide GPUs before TensorFlow is imported so it never enumerates them.
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import pandas as pd
import copy
import numpy as np
import math
import time
import tqdm
import sys
import random
import ast
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense

if USE_GPU:
    device_idx = 0
    gpus = tf.config.list_physical_devices('GPU')
    gpu_device = gpus[device_idx]
    # NOTE(review): set_visible_devices returns None, so `core_config` is None and
    # the TF1-style session below is created with a default config — confirm intent.
    core_config = tf.config.experimental.set_visible_devices(gpu_device, 'GPU')
    tf.config.experimental.set_memory_growth(gpu_device, True)
    tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=core_config))


def repeat_Samples(samples_to_be_repeated, X_train, y_train):
    """Grow X_train / y_train by re-appending randomly chosen existing rows.

    ``samples_to_be_repeated`` row indices are drawn with replacement from
    numpy's global RNG; the same indices are used for features and labels so
    the two arrays stay aligned. Returns the enlarged ``(X_train, y_train)``.
    """
    picks = np.random.choice(X_train.shape[0], samples_to_be_repeated)
    extra_features = X_train[picks]
    extra_labels = y_train[picks]
    return (np.concatenate((X_train, extra_features), axis=0),
            np.concatenate((y_train, extra_labels), axis=0))


def data_preparation(tasks_list):
    """Load every task's train/test split and, for multi-task groups, equalize training sizes.

    Arrays are read from ``{DataPath}/{task_id}_{X,y}_{train,test}.npy``, where
    ``DataPath`` is a module-level global assigned in the ``__main__`` section.
    When more than one task is given, each task's training set is padded by
    re-sampling its own rows (``repeat_Samples``) up to the largest training
    size in the group, which is bumped to the next odd number when it is even.
    Test sets are never padded.

    Args:
        tasks_list: iterable of task ids.

    Returns:
        dict mapping ``f'School_{task_id}_X_train'``, ``..._y_train``,
        ``..._X_test`` and ``..._y_test`` to the loaded numpy arrays.
    """
    data_param_dict_for_specific_task = {}

    multi_task = len(tasks_list) > 1
    max_size = None
    if multi_task:
        # Only the row counts are needed here: memory-map the files so the
        # arrays are not fully read (they are loaded again below anyway).
        max_size = max(np.load(f'{DataPath}/{task_id}_X_train.npy', mmap_mode='r').shape[0]
                       for task_id in tasks_list)
        if max_size % 2 == 0:
            max_size += 1

        print(f'max size = {max_size}')

    for task_id in tasks_list:
        X_train = np.load(f'{DataPath}/{task_id}_X_train.npy')
        y_train = np.load(f'{DataPath}/{task_id}_y_train.npy')
        X_test = np.load(f'{DataPath}/{task_id}_X_test.npy')
        y_test = np.load(f'{DataPath}/{task_id}_y_test.npy')

        if multi_task:
            samples_to_be_repeated = max_size - len(X_train)
            if samples_to_be_repeated > 0:
                X_train, y_train = repeat_Samples(samples_to_be_repeated, X_train, y_train)

        prefix = f'School_{task_id}'
        data_param_dict_for_specific_task[f'{prefix}_X_train'] = X_train
        data_param_dict_for_specific_task[f'{prefix}_y_train'] = y_train
        data_param_dict_for_specific_task[f'{prefix}_X_test'] = X_test
        data_param_dict_for_specific_task[f'{prefix}_y_test'] = y_test

    return data_param_dict_for_specific_task


def decay_lr(step, optimizer, every=75, factor=2.):
    """Divide ``optimizer.lr`` by ``factor`` whenever ``(step + 1) % every == 0``.

    ``step`` is zero-based (the training loop passes the epoch index). The
    defaults reproduce the original fixed schedule: halve every 75 steps.
    ``optimizer.lr`` is read and reassigned in place; nothing is returned.
    """
    if (step + 1) % every == 0:
        optimizer.lr = optimizer.lr / factor

def final_model(shared_hyperparameters, tasks_list, data_param_dict_for_specific_task, val=False):
    """Train a hard-parameter-sharing MTL model (shared encoder + one head per task) and evaluate it.

    The shared encoder is ``shared_hyperparameters['shared_FF_Layers']`` ReLU Dense
    layers (``shared_FF_Neurons[h]`` units each); every task gets its own
    ``Dense(1, activation='tanh')`` head. The sum of per-task MSE losses is minimised
    with SGD (momentum 0.9, initial lr ``shared_hyperparameters['learning_rate']``);
    once epoch > 100 the lr is halved on epochs where ``(epoch + 1) % 75 == 0``.
    Training stops early after 20 epochs without improvement of the monitored loss
    (validation loss when ``val`` is True, otherwise the last training batch's loss),
    and the weights from the best epoch are restored.

    When the module-level flag ``ITA_baseline`` is set, every batch also measures
    TAG-style inter-task affinities (``train_step_ITA``) and the per-epoch averages are
    written to a CSV. When ``tasks_list`` covers all of ``TASKS``, the restored weights
    and random-projected per-batch shared-encoder gradients are saved to disk.

    Args:
        shared_hyperparameters: dict with 'shared_FF_Layers', 'shared_FF_Neurons', 'learning_rate'.
        tasks_list: task ids in this group; their order fixes the order of the per-task losses.
        data_param_dict_for_specific_task: arrays keyed as produced by ``data_preparation``.
        val: hold out 10% of each task's training data (``random_state=42``) for early stopping.

    Returns:
        (total test loss, list of per-task test losses, training wall time in seconds).

    Reads module-level globals set in ``__main__``: batch_size, num_epochs, ITA_baseline,
    Method_name, TASKS, datasetName, run, ResultPath, Arch_Name, name_suffix.

    Raises:
        ValueError: if ``ITA_baseline`` is set but ``Method_name`` is not 'ITA'.
    """
    if ITA_baseline and Method_name != 'ITA':
        # Fail before training instead of hitting an unbound `task_gains` / `ita_file` later.
        raise ValueError(f"Unsupported ITA method {Method_name!r}; only 'ITA' is implemented")

    train_data, train_label = [], []
    test_data, test_label = [], []
    if val:
        val_data, val_label = [], []

    for sch_id in tasks_list:
        X_train = data_param_dict_for_specific_task[f'School_{sch_id}_X_train']
        y_train = data_param_dict_for_specific_task[f'School_{sch_id}_y_train']

        if val:
            # Hold out a small validation split from the training data.
            # NOTE(review): rows duplicated by data_preparation's padding can land in
            # both the train and validation splits.
            X_train, X_val, y_train, y_val = train_test_split(
                X_train, y_train, test_size=0.1, random_state=42)
            val_data.append(X_val)
            val_label.append(y_val)

        train_data.append(X_train)
        train_label.append(y_train)
        test_data.append(data_param_dict_for_specific_task[f'School_{sch_id}_X_test'])
        test_label.append(data_param_dict_for_specific_task[f'School_{sch_id}_y_test'])

    class SharedEncoder(tf.keras.Model):
        """Stack of ReLU Dense layers shared by all tasks."""

        def __init__(self):
            super(SharedEncoder, self).__init__()
            self.shared_layers = []
            for h in range(shared_hyperparameters['shared_FF_Layers']):
                self.shared_layers.append(Dense(shared_hyperparameters['shared_FF_Neurons'][h], activation='relu'))

        def call(self, inputs):
            x = inputs
            for layer in self.shared_layers:
                x = layer(x)
            return x

    class TaskDecoder(tf.keras.Model):
        """Per-task head: a single tanh unit on top of the shared representation."""

        def __init__(self):
            super(TaskDecoder, self).__init__()
            self.output_layer = Dense(1, activation='tanh')

        def call(self, shared_representation):
            return self.output_layer(shared_representation)

    shared_encoder = SharedEncoder()
    task_decoders = {sch_id: TaskDecoder() for sch_id in tasks_list}

    init_lr = shared_hyperparameters['learning_rate']
    # Kept in a local so the ITA momentum lookahead uses the same coefficient: on Keras
    # (OptimizerV2) SGD, `optimizer._momentum` is only a boolean "momentum enabled" flag.
    momentum = 0.9
    optimizer = tf.keras.optimizers.SGD(init_lr, momentum=momentum, nesterov=False)
    loss_fn = tf.keras.losses.MeanSquaredError()  # mean squared error per task

    @tf.function
    def train_step(x_batch_train, y_batch_train):
        """One SGD step on the summed task losses; returns the loss and post-update weight copies."""
        with tf.GradientTape() as tape:
            shared_representations = [shared_encoder(input_data, training=True) for input_data in x_batch_train]
            predictions = [task_decoders[sch_id](shared_rep, training=True) for sch_id, shared_rep in
                           zip(tasks_list, shared_representations)]
            losses = [loss_fn(y_true, y_pred) for y_true, y_pred in zip(y_batch_train, predictions)]
            tot_loss = tf.reduce_sum(losses)

        trainable_vars = shared_encoder.trainable_variables + sum(
            [decoder.trainable_variables for decoder in task_decoders.values()], [])
        gradients = tape.gradient(tot_loss, trainable_vars)
        optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Snapshot the weights *after* this update; the caller keeps them as the
        # best weights when the epoch's monitored loss improves.
        shared_weights = [tf.identity(weight) for weight in shared_encoder.trainable_weights]
        decoder_weights = {sch_id: [tf.identity(weight) for weight in decoder.trainable_weights]
                           for sch_id, decoder in task_decoders.items()}
        return tot_loss, shared_weights, decoder_weights

    @tf.function
    def test_step(x_batch_test, y_batch_test):
        """Evaluate without updating; returns (summed loss, per-task losses, predictions)."""
        shared_representations = [shared_encoder(input_data, training=False) for input_data in x_batch_test]
        predictions = [task_decoders[sch_id](shared_rep, training=False) for sch_id, shared_rep in
                       zip(tasks_list, shared_representations)]
        eval_losses = [loss_fn(y_true, y_pred) for y_true, y_pred in zip(y_batch_test, predictions)]
        eval_loss = tf.reduce_sum(eval_losses)
        return eval_loss, eval_losses, predictions

    @tf.function
    def train_step_ITA(x_batch_train, y_batch_train, first_step=False):  # per-batch calculation
        """Regular joint update plus TAG-style inter-task affinity measurement.

        For each base task the shared encoder is temporarily moved by a lookahead step
        along that task's gradient alone, all task losses are recomputed, and
        ``task_gains[base][other] = (1 - L_after / L_before) / lr``. The encoder is then
        restored and the regular update (per-task decoder gradients plus the summed
        shared gradient) is applied. ``first_step`` must be True while the optimizer's
        momentum slots do not exist yet.
        """
        task_gains = {task: {task: {} for task in tasks_list}
                      for task in tasks_list}

        with tf.GradientTape(persistent=True) as tape:
            shared_representations = [shared_encoder(input_data, training=True) for input_data in x_batch_train]
            predictions = [task_decoders[sch_id](shared_rep, training=True) for sch_id, shared_rep in
                           zip(tasks_list, shared_representations)]
            losses = [loss_fn(y_true, y_pred) for y_true, y_pred in zip(y_batch_train, predictions)]
            losses_dict = {task: loss for task, loss in zip(tasks_list, losses)}
            tot_loss = tf.reduce_sum(losses)

        # Gradient of each task's loss w.r.t. the shared encoder only (taken outside the
        # tape context so the gradient ops are not recorded on the persistent tape).
        single_task_specific_gradients = [
            (single_task, tape.gradient(losses_dict[single_task], shared_encoder.trainable_weights))
            for single_task in tasks_list]

        # Regular shared-encoder update direction = sum of the single-task gradients.
        all_tasks_gradients = [tf.add_n([task_gradient[i] for _, task_gradient in single_task_specific_gradients])
                               for i in range(len(shared_encoder.trainable_weights))]

        before_update_losses = dict(losses_dict)
        original_shared_weights = [tf.identity(weight) for weight in shared_encoder.trainable_weights]
        original_decoder_weights = {sch_id: [tf.identity(weight) for weight in decoder.trainable_weights]
                                    for sch_id, decoder in task_decoders.items()}

        for base_task, task_grad in single_task_specific_gradients:
            if first_step:
                # Plain SGD lookahead: no momentum slots exist before the first update.
                base_updated = [param - optimizer.lr * grad
                                for param, grad in zip(shared_encoder.trainable_weights, task_grad)]
            else:
                # Momentum lookahead mirroring Keras SGD: v' = m * v - lr * g; param' = param + v'.
                base_updated = [param + (momentum * optimizer.get_slot(param, 'momentum') - optimizer.lr * grad)
                                for param, grad in zip(shared_encoder.trainable_weights, task_grad)]

            # Temporarily move the shared encoder to the lookahead point.
            for param, new_value in zip(shared_encoder.trainable_weights, base_updated):
                param.assign(new_value)

            shared_representations = [shared_encoder(input_data) for input_data in x_batch_train]
            predictions = [task_decoders[sch_id](shared_rep, training=True) for sch_id, shared_rep in
                           zip(tasks_list, shared_representations)]

            after_update_losses_list = [loss_fn(y_true, y_pred) for y_true, y_pred in zip(y_batch_train, predictions)]
            after_update_losses = {task: loss for task, loss in zip(tasks_list, after_update_losses_list)}

            # Task gain: relative loss reduction of every task caused by the base task's step.
            task_gains[base_task] = {
                second_task: (1.0 - after_update_losses[second_task] / before_update_losses[second_task]) / optimizer.lr
                for second_task in tasks_list}

            # Restore the shared encoder before the next base task.
            for param, saved in zip(shared_encoder.trainable_weights, original_shared_weights):
                param.assign(saved)

        # Revert the whole model to its pre-lookahead state.
        for original_weight, weight in zip(original_shared_weights, shared_encoder.trainable_weights):
            weight.assign(original_weight)

        for sch_id, original_weights in original_decoder_weights.items():
            for original_weight, weight in zip(original_weights, task_decoders[sch_id].trainable_weights):
                weight.assign(original_weight)

        # Apply the regular model update.
        for task, decoder in task_decoders.items():
            task_grads = tape.gradient(losses_dict[task], decoder.trainable_weights)
            optimizer.apply_gradients(zip(task_grads, decoder.trainable_weights))
        optimizer.apply_gradients(zip(all_tasks_gradients, shared_encoder.trainable_weights))

        # Snapshot the post-update weights for the caller's best-model bookkeeping.
        updated_shared_weights = [tf.identity(weight) for weight in shared_encoder.trainable_weights]
        updated_decoder_weights = {sch_id: [tf.identity(weight) for weight in decoder.trainable_weights]
                                   for sch_id, decoder in task_decoders.items()}

        del tape
        return tot_loss, task_gains, updated_shared_weights, updated_decoder_weights

    max_patience = 20
    Patience = max_patience

    min_loss_to_consider = math.inf
    # Stay None if the monitored loss never improves (e.g. NaN from the first epoch).
    best_shared_weights = None
    best_decoder_weights = None

    TRAIN_SIZE = len(train_data[0])
    num_batches = math.ceil(TRAIN_SIZE / batch_size)

    gradient_metrics = {task: [] for task in tasks_list}

    timeStart = time.time()
    print(f'total batches {TRAIN_SIZE / batch_size}')
    TRAIN_LOSS = []
    TEST_LOSS = []
    for epoch in range(num_epochs):
        if epoch > 100:
            decay_lr(epoch, optimizer)

        batch_grad_metrics = {combined_task: {task: 0. for task in tasks_list} for combined_task in
                              gradient_metrics}

        for batch_idx in range(0, TRAIN_SIZE, batch_size):
            x_batch_train = [data[batch_idx:batch_idx + batch_size] for data in train_data]
            y_batch_train = [label[batch_idx:batch_idx + batch_size] for label in train_label]

            if ITA_baseline:
                # Momentum slots do not exist until the optimizer's first update.
                train_loss, task_gains, shared_weights, decoder_weights = train_step_ITA(
                    x_batch_train, y_batch_train, first_step=(len(optimizer.variables()) == 0))

                # Accumulate the epoch-average gain for every (base, target) task pair.
                for first_task, task_gain_map in task_gains.items():
                    for second_task, gain in task_gain_map.items():
                        batch_grad_metrics[first_task][second_task] += gain.numpy() / num_batches
            else:
                train_loss, shared_weights, decoder_weights = train_step(x_batch_train, y_batch_train)

        if val:
            val_loss, _, _ = test_step(val_data, val_label)
            loss_to_consider = val_loss
        else:
            # Without a validation split this is only the last batch's training loss.
            loss_to_consider = train_loss
        if epoch % 20 == 0:
            print(f'Epoch {epoch + 1}/{num_epochs}, loss = {train_loss.numpy()}, Patience = {Patience}')

        TRAIN_LOSS.append(train_loss.numpy())
        test_loss, _, _ = test_step(test_data, test_label)
        TEST_LOSS.append(test_loss.numpy())

        current_loss = loss_to_consider.numpy()
        if current_loss < min_loss_to_consider:
            min_loss_to_consider = current_loss
            Patience = max_patience
            best_shared_weights = copy.deepcopy(shared_weights)
            best_decoder_weights = copy.deepcopy(decoder_weights)
        else:
            Patience -= 1
            if Patience == 0:
                print(f'Stopping Training at Epoch {epoch + 1}')
                break

        if ITA_baseline:
            # Record epoch-level gradient metrics.
            for combined_task, task_gain_map in batch_grad_metrics.items():
                gradient_metrics[combined_task].append(task_gain_map)

            if epoch % 100 == 0:
                print(f'epoch {epoch}, gradient_metrics = {len(gradient_metrics)}')

    time_taken = time.time() - timeStart

    # Restore the weights from the best epoch.
    if best_shared_weights is None:
        print('Warning: monitored loss never improved; keeping the final weights')
    else:
        for best_weight, curr_weight in zip(best_shared_weights, shared_encoder.trainable_weights):
            curr_weight.assign(best_weight)

        for sch_id, decoder_specific_weights in best_decoder_weights.items():
            for best_weight, curr_weight in zip(decoder_specific_weights, task_decoders[sch_id].trainable_weights):
                curr_weight.assign(best_weight)

    print(f'stopping training at epoch {epoch + 1}')

    def save_gradients():
        """Save each task's per-batch shared-encoder gradients, random-projected, as .npy files.

        Uses `project_matrix` and `gradients_dir`, which are assigned in the enclosing
        scope before this function is called.
        """
        task_gradients = {task: [] for task in tasks_list}

        for batch_idx in range(0, len(train_data[0]), batch_size):
            x_batch_train = [data[batch_idx:batch_idx + batch_size] for data in train_data]
            y_batch_train = [label[batch_idx:batch_idx + batch_size] for label in train_label]
            with tf.GradientTape(persistent=True) as tape:
                shared_representations = [shared_encoder(input_data, training=True) for input_data in x_batch_train]
                predictions = [task_decoders[molecule](shared_rep, training=True) for molecule, shared_rep in
                               zip(tasks_list, shared_representations)]
                losses = [loss_fn(y_true, y_pred) for y_true, y_pred in zip(y_batch_train, predictions)]
                losses_dict = {task: loss for task, loss in zip(tasks_list, losses)}

            # Gradient of each task-specific loss w.r.t. the shared encoder, taken
            # outside the tape context; the tape is released once per batch.
            single_task_specific_gradients = [
                (single_task, tape.gradient(losses_dict[single_task], shared_encoder.trainable_weights))
                for single_task in tasks_list]
            del tape

            for task, tmp_gradients in single_task_specific_gradients:
                # Flatten and concatenate all parameter gradients into one vector.
                flat_gradient = tf.concat([tf.reshape(g, [-1]) for g in tmp_gradients], axis=0).numpy()

                if flat_gradient.size != project_matrix.shape[0]:
                    raise ValueError(
                        f"Gradient size {flat_gradient.size} does not match expected size {project_matrix.shape[0]}")

                task_gradients[task].append((flat_gradient.reshape(1, -1) @ project_matrix).flatten())

        for task_name, gradients in task_gradients.items():
            np.save(f"{gradients_dir}/{task_name}_train_gradients.npy", gradients)

    if len(tasks_list) == len(TASKS):
        # Save the restored weights and the projected training gradients.
        model_base_dir = f'{datasetName}_model_weights'
        os.makedirs(model_base_dir, exist_ok=True)

        gradients_dir = f'{datasetName}_gradients_run_{run}'
        os.makedirs(gradients_dir, exist_ok=True)

        model_dir = f'{model_base_dir}/run_{run}'
        os.makedirs(model_dir, exist_ok=True)
        shared_encoder.save_weights(f'{model_dir}/shared_encoder')
        for molecule, decoder in task_decoders.items():
            decoder.save_weights(f'{model_dir}/decoder_{molecule}')

        grad_params = list(shared_encoder.trainable_weights)
        print(f'len(grad_params): {len(grad_params)}', end=' ')
        gradient_dim = sum(param.numpy().size for param in grad_params)
        print("Gradient Dim: {}".format(gradient_dim), end=' ')

        # Random +/-1 projection scaled by 1/sqrt(project_dim).
        project_dim = 200
        project_matrix = (2 * np.random.randint(2, size=(gradient_dim, project_dim)) - 1).astype(float)
        project_matrix *= 1 / np.sqrt(project_dim)
        print("Project Dim: {}".format(project_dim))

        save_gradients()

    test_loss, indiv_losses, y_pred = test_step(test_data, test_label)
    indiv_losses = [each_loss.numpy() for each_loss in indiv_losses]

    print(f'test_loss = {test_loss}')
    if ITA_baseline:
        # Method_name == 'ITA' (Fifty's method -- TAG) is guaranteed by the check at the top.
        ita_file = f'{ResultPath}/ITA/gradient_metrics_{Method_name}_run_{run}_Arch_{Arch_Name}_{name_suffix}.csv'

        with open(ita_file, 'w') as f:
            for key in gradient_metrics.keys():
                f.write("%s,%s\n" % (key, gradient_metrics[key]))

    return test_loss.numpy(), indiv_losses, time_taken


if __name__ == "__main__":
    datasetName = 'School'

    import sys

    w_momentum = False
    ITA_baseline = 0
    if ITA_baseline:
        Method_name = 'ITA'  # sys.argv[1]
        group_len = 'ALL'
    else:
        Method_name = 'SimpleMTL'
        group_len = int(sys.argv[1])  # 1 for STL, 2 for Pairs

        if len(sys.argv) > 2:
            part = sys.argv[2]

    ResultPath = '../RESULTS/GROUPS_MTL/'

    DataPath = f'../Dataset/{datasetName.upper()}/'
    SchoolData = pd.read_csv(f'{DataPath}Task_Information_School.csv', low_memory=False)
    DataPath = f'../Dataset/{datasetName.upper()}/Task_Splits/'
    TASKS = list(SchoolData['Task_Name'])
    TASKS = [str(task) for task in TASKS]

    task_len = {}
    variance_dict = {}
    std_dev_dict = {}
    dist_dict = {}
    Single_res_dict = {}
    STL_error = {}
    STL_AP = {}

    num_folds = 10

    Arch_Name = 'Arch_1'
    if Arch_Name == 'Arch_1':
        num_epochs = 1000
        batch_size = 64
        MAX_PATIENCE = 20
        initial_shared_architecture = {'shared_FF_Layers': 3,
                                       'shared_FF_Neurons': [20, 10, 32],
                                       'learning_rate': 0.005}

    import itertools

    if ITA_baseline:
        TASK_Group = [tuple(TASKS)]

    '''Pairs'''
    if group_len == 2:  # 2 for pairs
        pairs = list(itertools.combinations(TASKS, group_len))
        TASK_Group = pairs
        name_suffix = 'pairs'
        tot_len = len(TASK_Group)

    '''STL'''
    if group_len == 1:
        Tasks_tuples = [tuple([task]) for task in TASKS]
        TASK_Group = Tasks_tuples
        name_suffix = 'STL'

    '''Groups of random number of tasks'''
    if group_len !='ALL' and group_len >=3:
        groups = list(itertools.combinations(TASKS, group_len))
        TASK_Group = groups
        name_suffix = f'G{group_len}'

    if group_len == 'ALL':
        TASK_Group = [tuple(TASKS)]
        name_suffix = 'ALL'

    if group_len == 'GroundTruth':
        random_subsets = pd.read_csv(f'../RESULTS/{datasetName}_Random_Subsets_for_GroundTruth_New.csv',
                                     low_memory=False)
        print(len(random_subsets))
        TASK_Group = list(random_subsets['Random_Subsets'])
        TASK_Group = [ast.literal_eval(grp) for grp in TASK_Group]
        print(f'Total Groups : {len(TASK_Group)}')

        tot_len = len(TASK_Group)
        name_suffix = 'GroundTruth_NEW'

        print(f'part : {part}, total Groups : {len(TASK_Group)}')

    # '''ALL'''
    # TASK_Group = [tuple(TASKS)]
    # name_suffix = 'ALL'

    RUNS = [1, 2, 3, 4, 5, 6]
    for run in RUNS:
        seed_value = run
        tf.random.set_seed(seed_value)
        np.random.seed(seed_value)
        random.seed(seed_value)

        Task_group = []
        Total_Loss = []
        Individual_Group_Score = []
        Individual_Error_Rate = []
        Individual_AP = []
        Number_of_Groups = []
        Individual_Task_Score = []
        Time_taken_for_training = []
        Prev_Groups = {}
        print(f'Total Groups : {len(TASK_Group)}')
        for count in range(len(TASK_Group)):
            print(f'Initial Training for {datasetName}-partition {count}, {TASK_Group[count]}')
            task_group = TASK_Group[count]

            args_tasks = []
            group_score = {}
            tmp_task_score = []

            data_param_dict_for_specific_task = data_preparation(task_group)
            all_scores = final_model(initial_shared_architecture, task_group, data_param_dict_for_specific_task,
                                     val=True)

            tot_loss, indi_scores = all_scores[0], all_scores[1]
            Total_time = all_scores[2]

            print(f'tot_loss = {tot_loss}')
            task_scores = {}
            for idx, task in enumerate(task_group):
                task_scores[task] = indi_scores[idx]

            Task_group.append(task_group)
            Total_Loss.append(tot_loss)
            Individual_Task_Score.append(copy.deepcopy(task_scores))

            Time_taken_for_training.append(np.mean(Total_time) / 60)

            print(len(Total_Loss), len(Task_group), len(Individual_Task_Score), len(Individual_Error_Rate))

            temp_res = pd.DataFrame({'Total_Loss': Total_Loss,
                                     'Task_group': Task_group,
                                     'Individual_Task_Score': Individual_Task_Score,
                                     'Time_taken_for_training': Time_taken_for_training,
                                     })
            if ITA_baseline:
                temp_res.to_csv(
                    f'{ResultPath}/{datasetName}_SimpleMTL_{Method_name}_run_{run}_SGD_Arch_{Arch_Name}.csv',
                    index=False)

                print(f'total_time = {Total_time}')
                print(f'avg time in minutes = {np.mean(Total_time) / 60}')

                '''save time to txt file'''
                timefile = f'{ResultPath}/School_{Method_name}_time_run_{run}_SGD_Arch_{Arch_Name}_{name_suffix}.txt'
                with open(timefile, 'a') as f:
                    f.write(f'{np.mean(Total_time) / 60}\n')
                    f.write(f'{Total_time}\n')

                f.close()
            else:
                temp_res.to_csv(f'{ResultPath}/{datasetName}_{name_suffix}_run_{run}_SGD_Arch_{Arch_Name}.csv',
                                index=False)



