import pandas as pd
import copy
import numpy as np
import math
import os
import time
import tqdm

# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
#
os.environ["OMP_NUM_THREADS"] = "1"  # OpenMP
os.environ["MKL_NUM_THREADS"] = "1"  # Intel Math Kernel Library
os.environ["NUMEXPR_NUM_THREADS"] = "1"  # NumExpr
os.environ["OPENBLAS_NUM_THREADS"] = "1"  # OpenBLAS
import random
import ast
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Device selection. Flip USE_GPU to True to run on one physical GPU; otherwise
# every CUDA device is hidden from this process.
USE_GPU = False
if USE_GPU:
    device_idx = 0
    gpus = tf.config.list_physical_devices('GPU')
    if device_idx >= len(gpus):
        # Fail with an explicit message instead of a bare IndexError.
        raise RuntimeError(
            f'USE_GPU is set but GPU index {device_idx} is not available '
            f'({len(gpus)} GPU(s) detected)')
    gpu_device = gpus[device_idx]
    # set_visible_devices has no return value, so there is no config object to
    # forward to the session below; the session uses its default config.
    tf.config.experimental.set_visible_devices(gpu_device, 'GPU')
    tf.config.experimental.set_memory_growth(gpu_device, True)
    tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session())
else:
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


def repeat_Samples(samples_to_be_repeated, X_train, y_train):
    """Oversample a binary-labelled training set by appending repeated rows.

    Rows are drawn with replacement through ``np.random.choice`` (so the
    result depends on NumPy's global random state), half from the rows
    labelled 0 and half from the rows labelled 1. The draws are appended
    after the original rows: class-0 draws first, then class-1 draws.

    When the requested count is odd, the one extra draw goes to the class
    with fewer rows (class 0 on a tie). When only one of the two classes is
    present, every draw comes from that class.

    Args:
        samples_to_be_repeated: Number of rows to append. Values <= 0 leave
            the data unchanged.
        X_train: Feature array indexed by sample along axis 0.
        y_train: Label array aligned with ``X_train`` along axis 0; rows are
            selected by comparing it with 0 and 1.

    Returns:
        ``(X_train, y_train)`` with exactly ``samples_to_be_repeated`` rows
        appended to each.

    Raises:
        ValueError: If rows are requested but no label equals 0 or 1.
    """
    if samples_to_be_repeated <= 0:
        return X_train, y_train

    # Row indices of each class.
    class0_indices = np.where(y_train == 0)[0]
    class1_indices = np.where(y_train == 1)[0]
    n_class0_rows, n_class1_rows = len(class0_indices), len(class1_indices)
    if n_class0_rows == 0 and n_class1_rows == 0:
        raise ValueError('repeat_Samples needs at least one label equal to 0 or 1 in y_train')

    # Split the request evenly; an odd remainder goes to the smaller class so
    # that the requested number of rows is actually appended.
    draws_class0 = draws_class1 = samples_to_be_repeated // 2
    if samples_to_be_repeated % 2:
        if n_class1_rows < n_class0_rows:
            draws_class1 += 1
        else:
            draws_class0 += 1

    # A missing class cannot be sampled from; take everything from the other.
    if n_class0_rows == 0:
        draws_class0, draws_class1 = 0, samples_to_be_repeated
    elif n_class1_rows == 0:
        draws_class0, draws_class1 = samples_to_be_repeated, 0

    sampled_parts = []
    if draws_class0 > 0:
        sampled_parts.append(np.random.choice(class0_indices, draws_class0, replace=True))
    if draws_class1 > 0:
        sampled_parts.append(np.random.choice(class1_indices, draws_class1, replace=True))
    sampled_indices = np.concatenate(sampled_parts)

    X_train = np.concatenate((X_train, X_train[sampled_indices]), axis=0)
    y_train = np.concatenate((y_train, y_train[sampled_indices]), axis=0)

    return X_train, y_train


def data_preparation(tasks_list):
    """Load the train/test splits of every task and pad the smaller training sets.

    For each task id the files ``{task_id}_X_train.npy``, ``{task_id}_y_train.npy``,
    ``{task_id}_X_test.npy`` and ``{task_id}_y_test.npy`` are read from
    ``../Dataset/<DATASETNAME>/Task_Splits`` (``datasetName`` is a module-level
    global). Each file is read once.

    When more than one task is given, the target size is the largest training
    set size, bumped to the next odd number if it is even; every task whose
    training set is smaller is padded toward that target with
    ``repeat_Samples``. Labels are reshaped to column vectors.

    Args:
        tasks_list: Iterable of task identifiers used in the split file names.

    Returns:
        dict mapping ``Landmine_{task_id}_X_train``, ``Landmine_{task_id}_y_train``,
        ``Landmine_{task_id}_X_test`` and ``Landmine_{task_id}_y_test`` to the
        corresponding arrays.
    """
    split_dir = f'../Dataset/{datasetName.upper()}/Task_Splits'
    split_names = ('X_train', 'y_train', 'X_test', 'y_test')
    pad_to_common_size = len(tasks_list) > 1

    # Read every split exactly once, keeping the order of tasks_list.
    loaded_splits = [
        tuple(np.load(f'{split_dir}/{task_id}_{split_name}.npy') for split_name in split_names)
        for task_id in tasks_list
    ]

    if pad_to_common_size:
        max_size = max(splits[0].shape[0] for splits in loaded_splits)
        if max_size % 2 == 0:
            max_size += 1

        print(f'max size = {max_size}')

    data_param_dict_for_specific_task = {}
    for task_id, (X_train, y_train, X_test, y_test) in zip(tasks_list, loaded_splits):
        if pad_to_common_size:
            samples_to_be_repeated = max_size - len(X_train)
            if samples_to_be_repeated > 0:
                X_train, y_train = repeat_Samples(samples_to_be_repeated, X_train, y_train)

        # Column-vector labels, one row per sample.
        y_train = y_train.reshape(-1, 1)
        y_test = y_test.reshape(-1, 1)

        data_param_dict_for_specific_task[f'Landmine_{task_id}_X_train'] = X_train
        data_param_dict_for_specific_task[f'Landmine_{task_id}_y_train'] = y_train
        data_param_dict_for_specific_task[f'Landmine_{task_id}_X_test'] = X_test
        data_param_dict_for_specific_task[f'Landmine_{task_id}_y_test'] = y_test

    return data_param_dict_for_specific_task


def decay_lr(step, optimizer):
    """Halve ``optimizer.lr`` in place whenever ``step + 1`` is a multiple of 75.

    Args:
        step: Zero-based step (epoch) counter.
        optimizer: Object exposing a readable and writable ``lr`` attribute.
    """
    reached_decay_boundary = (step + 1) % 75 == 0
    if not reached_decay_boundary:
        return
    optimizer.lr = optimizer.lr / 2.


def final_model(shared_hyperparameters, tasks_list, data_param_dict_for_specific_task, Method_name, val=True):
    """Train a shared-encoder multi-task binary classifier and score it on the test splits.

    A stack of ReLU ``Dense`` layers is shared by all tasks and every task has
    its own single-unit sigmoid head. The summed per-task binary cross-entropy
    is minimised with SGD (momentum 0.9). Each task's training data is split
    90/10 (stratified, ``random_state=42``) into train and validation parts;
    training stops early once the validation loss has failed to improve for
    ``MAX_PATIENCE`` consecutive epochs, and the weights that produced the
    lowest validation loss are restored before the test evaluation.

    Reads the module-level settings ``num_epochs``, ``batch_size`` and
    ``MAX_PATIENCE``.

    Args:
        shared_hyperparameters: dict with ``shared_FF_Layers`` (number of
            shared layers), ``shared_FF_Neurons`` (units per shared layer) and
            ``learning_rate``.
        tasks_list: Task identifiers; fixes the order of all per-task lists.
        data_param_dict_for_specific_task: dict holding
            ``Landmine_{task_id}_X_train``, ``..._y_train``, ``..._X_test`` and
            ``..._y_test`` for every task.
        Method_name: Unused; kept for interface compatibility.
        val: Unused; kept for interface compatibility.

    Returns:
        ``(test_loss, indiv_losses, time_taken)``: the summed test loss, the
        list of per-task test losses in ``tasks_list`` order, and the training
        wall-clock time in seconds.
    """
    patience = MAX_PATIENCE
    train_data, train_label = [], []
    val_data, val_label = [], []
    test_data, test_label = [], []

    for task_id in tasks_list:
        X_train_full = data_param_dict_for_specific_task[f'Landmine_{task_id}_X_train']
        y_train_full = data_param_dict_for_specific_task[f'Landmine_{task_id}_y_train']

        # Hold out a small stratified validation split from the training data.
        X_train, X_val, y_train, y_val = train_test_split(
            X_train_full, y_train_full, test_size=0.1, random_state=42, stratify=y_train_full
        )

        train_data.append(X_train)
        train_label.append(y_train)
        val_data.append(X_val)
        val_label.append(y_val)

        test_data.append(data_param_dict_for_specific_task[f'Landmine_{task_id}_X_test'])
        test_label.append(data_param_dict_for_specific_task[f'Landmine_{task_id}_y_test'])

    class SharedEncoder(tf.keras.Model):
        """Feed-forward ReLU stack shared by every task."""

        def __init__(self):
            super().__init__()
            self.shared_layers = [
                Dense(shared_hyperparameters['shared_FF_Neurons'][h], activation='relu')
                for h in range(shared_hyperparameters['shared_FF_Layers'])
            ]

        def call(self, inputs):
            x = inputs
            for layer in self.shared_layers:
                x = layer(x)
            return x

    class TaskDecoder(tf.keras.Model):
        """Task-specific head: one sigmoid unit on top of the shared representation."""

        def __init__(self):
            super().__init__()
            self.output_layer = Dense(1, activation='sigmoid')

        def call(self, shared_representation):
            return self.output_layer(shared_representation)

    shared_encoder = SharedEncoder()
    task_decoders = {task_id: TaskDecoder() for task_id in tasks_list}

    # `learning_rate` is the documented keyword; the deprecated `lr` alias is
    # not honoured by every Keras release.
    optimizer = tf.keras.optimizers.SGD(learning_rate=shared_hyperparameters['learning_rate'], momentum=0.9)
    # Binary cross-entropy on sigmoid probabilities (not logits).
    loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=False)

    def _forward_losses(x_batches, y_batches, training):
        """Run all tasks through the network; return per-task losses and predictions."""
        # One encoder pass over the concatenated inputs, then split back per task.
        x_all = tf.concat(x_batches, axis=0)
        shared_representations = shared_encoder(x_all, training=training)
        splits = tf.split(shared_representations,
                          num_or_size_splits=[tf.shape(x)[0] for x in x_batches])
        predictions = [
            task_decoders[task_id](rep, training=training)
            for task_id, rep in zip(tasks_list, splits)
        ]
        # Ensure label shapes match the (n, 1) predictions.
        y_batches = [tf.reshape(y, (-1, 1)) for y in y_batches]
        losses = [loss_fn(y_true, y_pred) for y_true, y_pred in zip(y_batches, predictions)]
        return losses, predictions

    @tf.function
    def train_step(x_batch_train, y_batch_train):
        with tf.GradientTape() as tape:
            losses, _ = _forward_losses(x_batch_train, y_batch_train, training=True)
            tot_loss = tf.reduce_sum(losses)

        all_vars = shared_encoder.trainable_variables + sum(
            [decoder.trainable_variables for decoder in task_decoders.values()], [])
        gradients = tape.gradient(tot_loss, all_vars)
        optimizer.apply_gradients(zip(gradients, all_vars))
        return tot_loss

    @tf.function
    def test_step(x_batch_test, y_batch_test):
        eval_losses, predictions = _forward_losses(x_batch_test, y_batch_test, training=False)
        eval_loss = tf.reduce_sum(eval_losses)
        return eval_loss, eval_losses, predictions

    def _snapshot_weights():
        """Copy the current encoder and decoder weights into plain NumPy arrays."""
        shared = [weight.numpy().copy() for weight in shared_encoder.trainable_weights]
        decoders = {task_id: [weight.numpy().copy() for weight in decoder.trainable_weights]
                    for task_id, decoder in task_decoders.items()}
        return shared, decoders

    min_valid_loss = math.inf
    best_shared_weights = None
    best_decoder_weights = None
    last_epoch = -1

    timeStart = time.time()
    for epoch in range(num_epochs):
        last_epoch = epoch
        if epoch > 100:
            decay_lr(epoch, optimizer)

        # Batch offsets follow the first task's training set; every task is
        # sliced with the same offsets.
        for batch_idx in range(0, len(train_data[0]), batch_size):
            x_batch_train = [data[batch_idx:batch_idx + batch_size] for data in train_data]
            y_batch_train = [label[batch_idx:batch_idx + batch_size] for label in train_label]

            train_loss = train_step(x_batch_train, y_batch_train)

        val_loss, _, _ = test_step(val_data, val_label)
        if epoch % 5 == 0:
            print(
                f'Epoch {epoch + 1}/{num_epochs}, loss = {train_loss.numpy()}, val_loss = {val_loss.numpy()}, Patience = {patience}')

        val_loss_value = val_loss.numpy()
        if val_loss_value < min_valid_loss:
            min_valid_loss = val_loss_value
            patience = MAX_PATIENCE
            # Snapshot the weights that were just validated, so the restored
            # model is exactly the one that achieved this validation loss.
            best_shared_weights, best_decoder_weights = _snapshot_weights()
        else:
            patience -= 1
            if patience == 0:
                print(f'Stopping Training at Epoch {epoch + 1}')
                break

    time_taken = time.time() - timeStart

    # Restore the best checkpoint; skipped when validation never improved
    # (e.g. zero epochs or a NaN validation loss), leaving the weights as they are.
    if best_shared_weights is not None:
        for best_weight, curr_weight in zip(best_shared_weights, shared_encoder.trainable_weights):
            curr_weight.assign(best_weight)

        for task_id, decoder_specific_weights in best_decoder_weights.items():
            for best_weight, curr_weight in zip(decoder_specific_weights,
                                                task_decoders[task_id].trainable_weights):
                curr_weight.assign(best_weight)

    print(f'stopping training at Epoch {last_epoch + 1}')

    test_loss, indiv_losses, _ = test_step(test_data, test_label)
    indiv_losses = [each_loss.numpy() for each_loss in indiv_losses]

    print(f'test_loss = {test_loss}')

    return test_loss.numpy(), indiv_losses, time_taken


if __name__ == "__main__":
    # Script entry point: trains one multi-task model over all Landmine tasks
    # for several seeds and writes one result CSV per run.
    # NOTE: datasetName, num_epochs, batch_size and MAX_PATIENCE are read as
    # module-level globals by data_preparation / final_model, so they must
    # stay assigned at this level.
    datasetName = 'Landmine'
    DataPath = f'../Dataset/{datasetName.upper()}/'
    import sys
    import itertools

    Method_name = 'SimpleMTL'
    group_len = 'ALL'  # 1 for STL, 2 for Pairs,

    ResultPath = '../RESULTS/GROUPS_MTL/'
    # Create the output directory up front so a finished training run cannot
    # fail when its results are written.
    os.makedirs(ResultPath, exist_ok=True)

    Task_InfoData = pd.read_csv(f'{DataPath}Task_Information_{datasetName}.csv', low_memory=False)
    TASKS = [str(task) for task in Task_InfoData['Task_Name']]

    Arch_Name = 'Arch_1'
    if Arch_Name == 'Arch_1':
        initial_shared_architecture = {'shared_FF_Layers': 2, 'shared_FF_Neurons': [64, 32],
                                       'learning_rate': 1e-3}
        num_epochs = 1000
        batch_size = 64
        MAX_PATIENCE = 20
    else:
        raise ValueError(f'Unsupported Arch_Name: {Arch_Name!r}')

    if group_len == 'ALL':
        # A single group containing every task.
        TASK_Group = [tuple(TASKS)]
        name_suffix = 'ALL'
    else:
        raise ValueError(f'Unsupported group_len: {group_len!r}')

    RUNS = [1, 2, 3, 4, 5, 6]
    for run in RUNS:
        # Seed every random number generator with the run index.
        seed_value = run
        tf.random.set_seed(seed_value)
        np.random.seed(seed_value)
        random.seed(seed_value)

        Task_group = []
        Total_Loss = []
        Individual_Error_Rate = []
        Individual_Task_Score = []
        Time_taken_for_training = []
        for count, task_group in enumerate(TASK_Group):
            print(f'Initial Training for {datasetName}-partition {count}, {task_group}')

            group_score = {}

            data_param_dict_for_specific_task = data_preparation(task_group)
            all_scores = final_model(initial_shared_architecture, task_group, data_param_dict_for_specific_task,
                                     Method_name, val=True)
            print(all_scores)
            tot_loss, indi_scores = all_scores[0], all_scores[1]
            Total_time = all_scores[-1]
            task_scores = {f'Landmine_{task}': indi_scores[idx] for idx, task in enumerate(task_group)}

            print(f'tot_loss = {tot_loss}')
            print(f'group_score = {group_score}')

            Task_group.append(task_group)
            Total_Loss.append(tot_loss)
            Individual_Task_Score.append(task_scores)
            Time_taken_for_training.append((Total_time / 60))  # seconds -> minutes

            print(len(Total_Loss), len(Task_group), len(Individual_Task_Score), len(Individual_Error_Rate))

            # Rewrite the run's CSV after every group so partial results survive.
            temp_res = pd.DataFrame({'Total_Loss': Total_Loss,
                                     'Task_group': Task_group,
                                     'Individual_Task_Score': Individual_Task_Score,
                                     'Time_taken_for_training': Time_taken_for_training,
                                     })
            temp_res.to_csv(f'{ResultPath}/{datasetName}_{name_suffix}_run_{run}_SGD_Arch_{Arch_Name}_BASELINE_GRADTAE.csv',index=False)
