from __future__ import print_function

import random
import time
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import os
from sklearn.preprocessing import MinMaxScaler
from train_func import *
from utils import *
from backbone_models import conv_EEG, conv_EEG_Feature, PiCO, PaPi, PG
from sklearn.metrics import confusion_matrix
import parsing
import pandas as pd

# On CPU-only machines restrict OpenMP to a single thread.
if not torch.cuda.is_available():
    os.environ["OMP_NUM_THREADS"] = "1"

# Parse the command line before anything below reads ``args``
# (the dataset paths right after depend on ``args.num_class``).
parser = parsing.create_parser()
args = parser.parse_args()
state = {k: v for k, v in args._get_kwargs()}

# "a,b" on the command line -> [a, b] as floats.
args.conf_ema_range = [float(item) for item in args.conf_ema_range.split(',')]

# Dataset file locations and class names, selected by the number of classes.
if args.num_class == 3:
    # SEED: separate train/test files per subject and session.
    train_de = '../EEGDATA/SEED/train/de/{}_{}.npy'  # Subject_No, Session_No
    test_de = '../EEGDATA/SEED/test/de/{}_{}.npy'  # Subject_No, Session_No
    train_label = '../EEGDATA/SEED/train/label/{}_{}.npy'
    test_label = '../EEGDATA/SEED/test/label/{}_{}.npy'
    labels = ['negative', 'neutral', 'positive']
elif args.num_class == 4:
    # SEED-IV: one file per subject and pre-defined fold.
    data_addr = '../EEGDATA/SEED_IV/EEG/de_{}_{}.npy'  # subject_No, Fold_No
    label_addr = '../EEGDATA/SEED_IV/EEG/label_{}_{}.npy'  # subject_No, Fold_No
    labels = ['Neutral', 'Sad', 'Fear', 'Happy']
elif args.num_class == 5:
    # SEED-V: one file per subject and pre-defined fold.
    data_addr = '../EEGDATA/SEED_V/EEG/de_{}_{}.npy'  # subject_No, Fold_No
    label_addr = '../EEGDATA/SEED_V/EEG/label_{}_{}.npy'  # subject_No, Fold_No
    labels = ['Disgust', 'Fear', 'Sad', 'Neutral', 'Happy']

# Method name (``args.method``) -> training-step class.
train_func_dict = {'PGNA_PL':T_PGNA_PL,
                   'DNPL': T_DNPL, 'PRODEN': T_PRODEN, 'CAVL': T_CAVL, 'LW': T_LW, 'CR': T_CR, 'PiCO': T_PiCO, 'PaPI': T_PaPI,  #baselines
                   'PGNA_PL_wtihout_PG': T_PGNA_PL_wtihout_PG, 'PGNA_PL_wtihout_NA': T_PGNA_PL_wtihout_NA,  # ablations
                   'PG_Other_NA__Mask': T_PG_Other_NA__Mask, 'PG_Other_NA__Gauss': T_PG_Other_NA__Gauss, 'PGNA_Other_OriginMixup':T_PGNA_Other_OriginMixup,
                   'PG_Other_NA__PepperSalt': T_PG_Other_NA__PepperSalt, "FullySupervision":T_FullySupervision, 'PGNA_PL_FullySupervision':T_PGNA_PL_FullySupervision   # Compare with other noise augmentation methods
                   }

# Methods in this list train a model without ``Net.prototypes``.
methods_without_prototypes = ["DNPL", "PRODEN", "CAVL", "LW", "CR", "FullySupervision"]
# Methods built on the PG model; evaluated with ``args``/``eval_only`` keyword arguments.
PGNA_Methods = ["PG_Other_NA__Mask", "PG_Other_NA__labelMixup", "PG_Other_NA__Gauss", "PG_Other_NA__PepperSalt", "PGNA_PL_wtihout_NA", "PGNA_PL", "PGNA_PL_wtihout_PG", "PGNA_Other_OriginMixup", "PGNA_PL_FullySupervision"]
using_prototypes = args.method not in methods_without_prototypes

# Best-epoch prototypes of every trained model, appended by the main experiment loop.
all_prototypes = []

def eval_step(inputs, labels, model):
    """Score ``model`` on a single evaluation batch.

    ``labels`` is a 2-D tensor with one row of class scores per sample; the
    target class of a sample is the index of its row maximum. Prototype-based
    methods (PiCO, PaPI and every entry of ``PGNA_Methods``) are called with
    ``args=args, eval_only=True``; all other methods receive ``inputs`` only.

    Returns:
        ``(loss, accuracy, conf_matrix)``: the cross-entropy loss as a Python
        float, the batch accuracy as a numpy value, and this batch's
        ``(args.num_class, args.num_class)`` confusion matrix.
    """
    prototype_method = args.method in ('PiCO', 'PaPI') or args.method in PGNA_Methods
    if prototype_method:
        logits = model(inputs, args=args, eval_only=True)
    else:
        logits = model(inputs)

    _, predicted_idx = torch.max(logits, 1)
    _, target_idx = torch.max(labels, 1)

    loss = nn.CrossEntropyLoss()(logits, target_idx)
    n_correct = (predicted_idx == target_idx).float().sum()
    batch_acc = n_correct / labels.shape[0]

    # Fixed label list so every batch yields a full num_class x num_class matrix.
    batch_conf = confusion_matrix(
        target_idx.cpu().numpy(),
        predicted_idx.cpu().numpy(),
        labels=list(range(args.num_class)),
    )
    return loss.item(), batch_acc.detach().cpu().clone().numpy(), batch_conf


def _evaluate(Net, loader):
    """Evaluate ``Net`` on every batch of ``loader`` without gradient tracking.

    Returns ``(accuracy, conf_matrix)``. ``accuracy`` is averaged over samples:
    each batch is weighted by its size, so a smaller final batch is not
    over-weighted. ``conf_matrix`` is the un-normalised sum of the per-batch
    confusion matrices, of shape ``(args.num_class, args.num_class)``.
    """
    Net.eval()
    batch_accs, batch_sizes = [], []
    conf_matrix_total = np.zeros((args.num_class, args.num_class))
    with torch.no_grad():
        for _, features, targets, _ in loader:
            features = features.to(device)
            targets = targets.to(device)
            _, batch_acc, batch_conf = eval_step(features, targets, Net)
            batch_accs.append(batch_acc)
            batch_sizes.append(targets.shape[0])
            conf_matrix_total += batch_conf
    return float(np.average(batch_accs, weights=batch_sizes)), conf_matrix_total


def train(Net, train_dataset, val_dataset, test_dataset):
    """Train ``Net`` with the configured PLL method and test its best checkpoint.

    After every epoch the validation accuracy is measured. Whenever it improves,
    an independent copy of the weights is kept, together with a numpy copy of
    ``Net.prototypes`` when the method uses prototypes. Those weights are
    loaded back into ``Net`` before the test evaluation.

    Args:
        Net: model to train. It is modified in place and ends up holding the
            best-validation weights.
        train_dataset: loader yielding
            ``(index, input, input_w, input_s, origin_y, part_y)``. Its
            ``dataset.partial_labels`` seed the label confidence.
        val_dataset: loader yielding ``(_, features, labels, _)``.
        test_dataset: loader yielding ``(_, features, labels, _)``.

    Returns:
        ``(test_acc, conf_matrix, best_prototypes)``:
        - ``test_acc`` is the sample-weighted test accuracy.
        - ``conf_matrix`` is the confusion matrix normalised per true class.
          Classes absent from the test set get an all-zero row instead of NaN.
        - ``best_prototypes`` holds the prototypes at the best epoch, or ``""``
          for methods without prototypes.

    Raises:
        Exception: if ``args.optimizer`` is neither ``'adam'`` nor ``'sgd'``.
    """
    train_loss_epoch = np.ones((args.epochs, 1))  # mean training loss per epoch
    val_acc_epoch = np.zeros((args.epochs, 1))  # accuracy on the validation set per epoch
    best_val_acc = -np.inf  # -inf so the first epoch always produces a checkpoint
    best_model_weights = None
    best_prototypes = ""

    # Label confidence initialization: each candidate-label row normalised to sum to 1.
    confidence = copy.deepcopy(train_dataset.dataset.partial_labels)
    confidence = confidence / confidence.sum(axis=1)[:, None]
    confidence = torch.FloatTensor(confidence).to(device)

    # Choice of optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(Net.parameters(), lr=args.lr)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(Net.parameters(), momentum=0.9, lr=args.lr, weight_decay=1e-4)
    else:
        raise Exception('Need to choose the optimizer')

    # Explicit ``== True``: only a genuine True value enables the scheduler.
    use_scheduler = args.use_scheduler == True
    scheduler = None
    if use_scheduler:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 20], last_epoch=-1)

    # Only PiCO and PaPI pass an external loss object to their training step.
    loss_func = None
    if args.method == 'PiCO':
        loss_func = partial_loss(confidence)
    if args.method == 'PaPI':
        loss_func = PaPiLoss(predicted_score_cls=confidence, pseudo_label_weight=0.99)

    for epoch in range(args.epochs):
        train_loss_batch = []
        Net.train()
        if args.method == 'PaPI':
            loss_func.set_alpha(epoch, args)
        for index, inputs, input_w, input_s, origin_y, part_y in train_dataset:
            index, inputs, input_w, input_s, origin_y, part_y = map(
                lambda x: x.to(device), (index, inputs, input_w, input_s, origin_y, part_y))
            trainer = train_func_dict[args.method]()
            if args.method in ('PiCO', 'PaPI'):
                loss, confidence = trainer.train_step(index, confidence, inputs, input_w, input_s,
                                                      part_y, Net, optimizer, epoch, loss_func)
            elif args.method == "PGNA_PL_FullySupervision":
                # The true labels (origin_y) replace the candidate labels (part_y).
                loss, confidence = trainer.train_step(index, confidence, inputs, input_w, input_s,
                                                      origin_y, Net, optimizer, epoch)
            elif args.method == "FullySupervision":
                loss = trainer.train_step(inputs, origin_y, Net, optimizer)
            else:
                # PGNA-family methods and the remaining PLL baselines share this signature.
                loss, confidence = trainer.train_step(index, confidence, inputs, input_w, input_s,
                                                      part_y, Net, optimizer, epoch)
            train_loss_batch.append(loss)
        if use_scheduler:
            scheduler.step()

        if args.method == 'PiCO':
            loss_func.set_conf_ema_m(epoch, args)
        if args.method == 'PaPI':
            loss_func.set_pseudo_label_weight(epoch, args)
            Net.set_prototype_update_weight(epoch, args)

        train_loss_epoch[epoch] = Average(train_loss_batch)

        val_acc, _ = _evaluate(Net, val_dataset)
        val_acc_epoch[epoch] = val_acc
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Deep copy: state_dict() tensors share storage with the live parameters,
            # so a shallow dict copy would keep changing as training continues.
            best_model_weights = copy.deepcopy(Net.state_dict())
            if using_prototypes:
                best_prototypes = Net.prototypes.detach().clone().cpu().numpy()

    if best_model_weights is not None:  # None only when args.epochs == 0
        Net.load_state_dict(best_model_weights)  # Load the best model

    test_acc, conf_matrix_epoch = _evaluate(Net, test_dataset)
    # Normalise each row (true class) to sum to 1. Rows of classes missing from the
    # test set stay all-zero instead of becoming 0/0 = NaN.
    row_sums = conf_matrix_epoch.sum(axis=1, keepdims=True)
    conf_matrix_epoch = np.divide(conf_matrix_epoch, row_sums,
                                  out=np.zeros_like(conf_matrix_epoch), where=row_sums > 0)
    return test_acc, conf_matrix_epoch, best_prototypes


if __name__ == '__main__':
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Dataset-specific result root and subject count, selected by num_class.
    if args.num_class == 3:
        dataset_result_root, subjects = './SEED_result/', 15
    elif args.num_class == 4:
        dataset_result_root, subjects = './SEED_IV_result/', 15
    elif args.num_class == 5:
        dataset_result_root, subjects = './SEED_V_result/', 16
    else:
        # Fail fast; otherwise main_path/subjects would be unbound further down.
        raise ValueError('Unsupported num_class: {} (expected 3, 4 or 5)'.format(args.num_class))

    # Results save path, depends on method and the arguments in parsing
    main_path = dataset_result_root + 'PLL_confusion/main/' + args.method + '/scheduler_{}'.format(
        args.use_scheduler) + '/optimizer_{}/lr_{}/'.format(args.optimizer, args.lr)

    # Method-specific sub-directory encoding that method's hyper-parameters.
    if args.method in ('DNPL', 'FullySupervision'):
        directory = main_path
    elif args.method in ('PRODEN', 'CAVL'):
        directory = main_path + 'confidence_{}/'.format(args.use_confidence)
    elif args.method == 'LW':
        directory = main_path + 'confidence_{}/'.format(args.use_confidence) + '{}/'.format(
            args.loss) + 'beta_{}/'.format(args.beta)
    elif args.method == 'CR':
        directory = main_path + 'confidence_{}/'.format(args.use_confidence) + 'weight_{}_{}_{}/'.format(
            args.c_weight, args.c_weight_w, args.c_weight_s)
    elif args.method in ('PGNA_PL', 'PGNA_PL_wtihout_PG', 'PGNA_Other_OriginMixup', 'PGNA_PL_FullySupervision'):
        directory = main_path + 'confidence_{}/'.format(args.use_confidence) + 'beta_parameter_{}/'.format(
            args.beta_parameter)
    elif args.method in PGNA_Methods:
        directory = main_path + 'confidence_{}/'.format(args.use_confidence) + 'q_aug_{}/'.format(args.q_aug)
    elif args.method == 'PiCO':
        directory = main_path + 'confidence_{}/'.format(args.use_confidence) + 'contrast_weight_{}/'.format(
            args.gamma) + 'queue_size_{}/'.format(args.moco_queue)
    elif args.method == 'PaPI':
        directory = main_path + 'confidence_{}/'.format(args.use_confidence)
    else:
        raise Exception('Need to choose the method')

    directory = directory + 'run_{}/'.format(args.run_idx)
    directory_matrices = directory + 'confusion_matrices/'

    if args.partial_type == 'Semantic_Distribution':
        time.sleep(2)

    # Repeat Experiment Five Times: run_idx (1-based) selects one fixed seed.
    random_seed_arr = [100, 42, 19, 57, 598]
    # Reject out-of-range values instead of letting run_idx=0 wrap around to the
    # last seed (silently duplicating run 5).
    if not 1 <= args.run_idx <= len(random_seed_arr):
        raise ValueError('run_idx must be between 1 and {}, got {}'.format(
            len(random_seed_arr), args.run_idx))
    seed = random_seed_arr[args.run_idx - 1]

    # exist_ok=True already tolerates existing directories.
    os.makedirs(directory, exist_ok=True)
    os.makedirs(directory_matrices, exist_ok=True)

    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Reproducibility needs deterministic cuDNN kernels and no benchmark-driven
    # algorithm selection (benchmark=True may choose different kernels per run).
    cudnn.deterministic = True
    cudnn.benchmark = False

    if args.partial_type == 'Russel_Distribution' or args.partial_type == 'Semantic_Distribution':
        prob_arr = [0]
    else:
        raise Exception('Need to choose the parital label generation method')
    # One row per subject; 3 columns = sessions (SEED) or folds (SEED-IV / SEED-V).
    acc_array = np.zeros((subjects, 3))
    # prob_arr is [0] (set above), so this is a single pass; ``prob`` itself is unused.
    for prob in prob_arr:
        # Skip the whole experiment when this run's accuracy CSV already exists (resume support).
        if args.partial_type == 'Russel_Distribution':
            if os.path.exists(os.path.join(directory, "Russel.csv")):
                continue
        elif args.partial_type == 'Semantic_Distribution':
            if os.path.exists(os.path.join(directory, "Semantic.csv")):
                continue
        else:
            raise Exception('Need to choose the parital label generation method')

        if args.num_class == 3:
            # SEED: every subject has 3 sessions, each with its own train/test files;
            # the test accuracy is stored in acc_array[subject, session].
            sessions = 3
            for subject_id in range(1, subjects + 1):
                for session_num in range(1, sessions + 1):
                    # Fresh model per session. NOTE(review): the conv_EEG built first is
                    # discarded for PGNA/PiCO/PaPI methods, but constructing it presumably
                    # draws from the seeded torch RNG; removing it may change seeded results.
                    Net = conv_EEG(args.num_class).to(device)
                    if args.method in PGNA_Methods:
                        Net = PG(args, conv_EEG_Feature).to(device)
                    if args.method == 'PiCO':
                        Net = PiCO(args, conv_EEG_Feature).to(device)
                    if args.method == 'PaPI':
                        Net = PaPi(args, conv_EEG_Feature).to(device)

                    Net.apply(WeightInit)
                    Net.apply(WeightClipper)

                    X_train = np.load(train_de.format(subject_id, session_num))
                    X_test = np.load(test_de.format(subject_id, session_num))
                    '''Normalize EEG features to the range of [0,1] before fed into model'''
                    # NOTE(review): the scaler is fit on train and test rows together, so the
                    # test set's min/max influence the normalisation.
                    X = np.vstack((X_train, X_test))

                    scaler = MinMaxScaler()
                    X = scaler.fit_transform(X)

                    # Split the normalised matrix back into its train and test rows.
                    X_train_origin = X[0: X_train.shape[0]]
                    X_test = X[X_train.shape[0]:]

                    Y_train_origin = np.load(train_label.format(subject_id, session_num))
                    Y_test = np.load(test_label.format(subject_id, session_num))

                    # Carve a validation split out of the training data
                    # (train_rate=0.8, presumably 80% train / 20% validation).
                    X_train, Y_train, X_val, Y_val = split_balance_class(
                        data=X_train_origin, label=Y_train_origin, train_rate=0.8, random=False
                    )
                    # Keep the training labels as they are before the to_categorical conversion below.
                    Y_train_single = copy.deepcopy(Y_train)

                    Y_train, Y_val, Y_test = map(lambda x: to_categorical(np.ravel(x)), (Y_train, Y_val, Y_test))

                    # Candidate-label sets are generated for the training split only.
                    if args.partial_type == 'Russel_Distribution':
                        partial_label_train, avgC = partialize_Russel_Distribution(Y_train, Y_train_single)  # generation of candidate labels depends on Russel_Distribution similarities
                    elif args.partial_type == 'Semantic_Distribution':
                        partial_label_train, avgC = partialize_Semantic_Distribution(Y_train, Y_train_single)
                    else:
                        break  # unreachable: partial_type is validated before this loop

                    # Val/test candidate labels exist only to match the inputs of
                    # load_augmented_dataset_to_device (p=0.0: presumably no extra candidates added).
                    partial_label_val, _ = partialize(Y_val, p=0.0)
                    partial_label_test, avgC = partialize(Y_test, p=0.0)

                    # Insert a singleton axis 1 into every feature matrix before loading.
                    data_train, data_val, data_test = np.expand_dims(X_train, axis=1), np.expand_dims(X_val, axis=1), np.expand_dims(X_test, axis=1)
                    label_train, label_val, label_test = Y_train, Y_val, Y_test

                    # data loader (only the training loader is shuffled and augmented)
                    train_dataset = load_augmented_dataset_to_device(data_train, label_train, partial_label_train,
                                                                     batch_size=8, shuffle_flag=True,
                                                                     augmentation_flag=True)

                    val_dataset = load_augmented_dataset_to_device(data_val, label_val, partial_label_val, batch_size=8,
                                                                    shuffle_flag=False, augmentation_flag=False)
                    test_dataset = load_augmented_dataset_to_device(data_test, label_test, partial_label_test, batch_size=8,
                                                                    shuffle_flag=False, augmentation_flag=False)

                    acc_array[subject_id - 1, session_num -1], best_epoch_conf_matrix, prototypes = train(Net, train_dataset, val_dataset,
                                                                                                    test_dataset)
                    all_prototypes.append(prototypes)
                    torch.cuda.empty_cache()
                    # save the current confusion matrix (file name uses 0-based subject/session indices)
                    if args.partial_type == 'Russel_Distribution':
                        confusion_file = "confusion_matrix_Russel_"+str(subject_id-1)+"_"+str(session_num-1)+".csv"
                    elif args.partial_type == 'Semantic_Distribution':
                        confusion_file = "confusion_matrix_Semantic_"+str(subject_id-1)+"_"+str(session_num-1)+".csv"
                    df1 = pd.DataFrame(best_epoch_conf_matrix, index=labels, columns=labels)
                    csv_file = os.path.join(directory_matrices, confusion_file)
                    df1.to_csv(csv_file)
        elif args.num_class == 4 or args.num_class == 5:
            # SEED-IV / SEED-V: three pre-defined folds per subject, evaluated by 3-fold
            # cross-validation; the test accuracy is stored in acc_array[subject, fold].
            for subject_id in range(1, subjects + 1):
                # data and labels load: the subject's three folds stacked in order
                X1 = np.load(data_addr.format(subject_id, 1))
                X2 = np.load(data_addr.format(subject_id, 2))
                X3 = np.load(data_addr.format(subject_id, 3))
                X = np.vstack((X1, X2, X3))
                Y1 = np.load(label_addr.format(subject_id, 1))
                Y2 = np.load(label_addr.format(subject_id, 2))
                Y3 = np.load(label_addr.format(subject_id, 3))
                Y = np.vstack((Y1, Y2, Y3))
                # data normalization (NOTE(review): fit on all three folds, test fold included)
                scaler = MinMaxScaler()
                X = scaler.fit_transform(X)

                for fold_num in range(3):
                    # Fresh model per fold; see the RNG note in the SEED branch about the
                    # conv_EEG that is built and then possibly discarded.
                    Net = conv_EEG(args.num_class).to(device)
                    if args.method in PGNA_Methods:
                        Net = PG(args, conv_EEG_Feature).to(device)
                    if args.method == 'PiCO':
                        Net = PiCO(args, conv_EEG_Feature).to(device)
                    if args.method == 'PaPI':
                        Net = PaPi(args, conv_EEG_Feature).to(device)

                    Net.apply(WeightInit)
                    Net.apply(WeightClipper)
                    # print('model parameters:', sum(param.numel() for param in Net.parameters()))

                    # Row ranges of each fold inside the stacked X / Y (identical for every fold_num).
                    fold_1_index = [i for i in range(0, len(X1))]
                    fold_2_index = [i for i in range(len(X1), len(X1) + len(X2))]
                    fold_3_index = [i for i in range(len(X1) + len(X2), len(X1) + len(X2) + len(X3))]

                    # Three-fold cross-validation based on pre-defined folds: two folds train, one tests.
                    if fold_num == 0:
                        train_index, test_index = fold_1_index + fold_2_index, fold_3_index
                    elif fold_num == 1:
                        train_index, test_index = fold_2_index + fold_3_index, fold_1_index
                    else:
                        train_index, test_index = fold_3_index + fold_1_index, fold_2_index

                    X_train_origin, X_test, Y_train_origin, Y_test = X[train_index], X[test_index], Y[train_index], Y[test_index]

                    # Carve a validation split out of the training folds
                    # (train_rate=0.8, presumably 80% train / 20% validation).
                    X_train, Y_train, X_val, Y_val = split_balance_class(
                        data=X_train_origin, label=Y_train_origin, train_rate=0.8, random=False
                    )
                    # Keep the training labels as they are before the to_categorical conversion below.
                    Y_train_single = copy.deepcopy(Y_train)

                    Y_train, Y_val, Y_test = map(lambda x: to_categorical(np.ravel(x)), (Y_train, Y_val, Y_test))

                    # Candidate-label sets are generated for the training split only.
                    if args.partial_type == 'Russel_Distribution':
                        partial_label_train, avgC = partialize_Russel_Distribution(Y_train, Y_train_single)  # generation of candidate labels depends on Russel_Distribution similarities
                    elif args.partial_type == 'Semantic_Distribution':
                        partial_label_train, avgC = partialize_Semantic_Distribution(Y_train, Y_train_single)
                    else:
                        break  # unreachable: partial_type is validated before this loop

                    # Val/test candidate labels exist only to match the inputs of
                    # load_augmented_dataset_to_device (p=0.0: presumably no extra candidates added).
                    partial_label_val, _ = partialize(Y_val, p=0.0)
                    partial_label_test, avgC = partialize(Y_test, p=0.0)

                    # Insert a singleton axis 1 into every feature matrix before loading.
                    data_train, data_val, data_test = np.expand_dims(X_train, axis=1), np.expand_dims(X_val, axis=1), np.expand_dims(
                        X_test, axis=1)
                    label_train, label_val, label_test = Y_train, Y_val, Y_test

                    # data loader (only the training loader is shuffled and augmented)
                    train_dataset = load_augmented_dataset_to_device(data_train, label_train, partial_label_train,
                                                                     batch_size=8, shuffle_flag=True,
                                                                     augmentation_flag=True)
                    val_dataset = load_augmented_dataset_to_device(data_val, label_val, partial_label_val, batch_size=8,
                                                                    shuffle_flag=False, augmentation_flag=False)
                    test_dataset = load_augmented_dataset_to_device(data_test, label_test, partial_label_test, batch_size=8,
                                                                    shuffle_flag=False, augmentation_flag=False)

                    acc_array[subject_id - 1, fold_num], best_epoch_conf_matrix, prototypes = train(Net, train_dataset, val_dataset,
                                                                                                    test_dataset)
                    all_prototypes.append(prototypes)
                    torch.cuda.empty_cache()

                    # save the current confusion matrix (file name uses 0-based subject index and fold number)
                    if args.partial_type == 'Russel_Distribution':
                        confusion_file = "confusion_matrix_Russel_"+str(subject_id-1)+"_"+str(fold_num)+".csv"
                    elif args.partial_type == 'Semantic_Distribution':
                        confusion_file = "confusion_matrix_Semantic_"+str(subject_id-1)+"_"+str(fold_num)+".csv"
                    df1 = pd.DataFrame(best_epoch_conf_matrix, index=labels, columns=labels)
                    csv_file = os.path.join(directory_matrices, confusion_file)
                    df1.to_csv(csv_file)

        # save results: the per-subject accuracy table for this run
        if args.partial_type == 'Russel_Distribution':
            np.savetxt(os.path.join(directory, "Russel.csv"), acc_array, delimiter=",")
        elif args.partial_type == 'Semantic_Distribution':
            np.savetxt(os.path.join(directory, "Semantic.csv"), acc_array, delimiter=",")
        else:
            raise Exception('Need to choose the parital label generation method')

    # Save prototypes: the element-wise mean of the best-epoch prototypes of every
    # model trained in this invocation. Nothing is written when no model was trained
    # (e.g. every experiment was skipped because its results CSV already existed).
    if using_prototypes and all_prototypes:
        prototypes_mean = np.mean(all_prototypes, axis=0)
        if args.partial_type == 'Russel_Distribution':
            np.save(directory + 'prototypes_Russel.npy', prototypes_mean)
        elif args.partial_type == 'Semantic_Distribution':
            np.save(directory + 'prototypes_Semantic.npy', prototypes_mean)
