# Time : 2023/11/13 15:10
# Author : 小霸奔
# FileName: incremental_trainer.py
import copy
import torch
from model.pretrain_net import FeatureExtractor, TransformerEncoder, SleepMLP
from utils.config import ModelConfig
import os
from dataloader.data_loader import Builder
from torch.utils.data import DataLoader
import numpy as np
from model.incremental_algorithm import Finetune, CPC, SimSiam, MMD, KL
from utils.util import Evaluator, compute_aaa, compute_forget


def evaluator(model, dl, args):
    """Run the three-stage sleep-staging model over ``dl``; collect labels and predictions.

    Args:
        model: the stages ``(feature_extractor, feature_encoder, sleep_classifier)``,
            either as a tuple/list, or as a single indexable module (e.g. an
            ``nn.Sequential``) whose items 0, 1 and 2 are those stages.
        dl: iterable of batches; items 0, 1 and 2 of each batch are the EOG
            signal, the EEG signal and the labels.
        args: config dict; ``args["device"]`` is the target device and
            ``args["dataset"]`` selects the ``ModelConfig``.

    Returns:
        ``(y_test, y_pred)``: flat Python lists of true labels and of predicted
        class indices (argmax over dim 1), in loader order.
    """
    # A plain tuple/list has no ``.eval()``, so switch each stage individually.
    if isinstance(model, (tuple, list)):
        for stage in model[:3]:
            stage.eval()
    else:
        model.eval()

    device = args["device"]
    extractor, encoder, classifier = model[0], model[1], model[2]
    for stage in (extractor, encoder, classifier):
        stage.to(device)

    model_param = ModelConfig(args["dataset"])
    epoch_size = model_param.EpochLength  # loop-invariant
    y_pred = []
    y_test = []
    with torch.no_grad():
        for data in dl:
            eog, eeg, labels = data[0].to(device), data[1].to(device), data[2].to(device)
            eog = eog.view(-1, model_param.EogNum, epoch_size)
            eeg = eeg.view(-1, model_param.EegNum, epoch_size)

            # EEG + EOG features -> encoder (batch, 20, 512 per the original note) -> classifier
            eeg_eog_feature = encoder(extractor(eeg, eog))
            prediction = classifier(eeg_eog_feature)

            _, predicted = torch.max(prediction.data, dim=1)
            y_pred.extend(torch.flatten(predicted).tolist())
            y_test.extend(torch.flatten(labels).tolist())
    return y_test, y_pred


def get_new_task_loader(args, new_task_idx, is_buffer):
    """Build a shuffled DataLoader over the data/label files of one new individual.

    Args:
        args: config dict; reads ``"dataset"``, ``"filepath"`` and ``"batch"``
            (``Builder`` may read more).
        new_task_idx: id of the individual's sub-directory under ``args['filepath']``.
        is_buffer: not used by this function; kept so existing callers still work.

    Returns:
        A ``DataLoader`` (``shuffle=True``, 4 workers) over
        ``Builder([data_files, label_files], args).Dataset``.

    Raises:
        ValueError: for SleepEDF, if the data and label directories contain a
            different number of entries (they could not be paired).
    """
    if args["dataset"] != "SleepEDF":
        # Files are named 0.npy, 1.npy, ...; stop at the first missing index.
        file_path = args['filepath'] + f"/{new_task_idx}/data"
        label_path = args['filepath'] + f"/{new_task_idx}/label"
        data_files, label_files = [], []
        num = 0
        while os.path.exists(file_path + f"/{num}.npy"):
            data_files.append(file_path + f"/{num}.npy")
            label_files.append(label_path + f"/{num}.npy")
            num += 1
    else:
        file_path = args['filepath'] + f"/seq/{new_task_idx}"
        label_path = args['filepath'] + f"/labels/{new_task_idx}"
        # os.listdir returns entries in arbitrary order, so the i-th data file and
        # the i-th label file were not guaranteed to belong together. Sorting both
        # pairs them deterministically -- assumes matching files sort to the same
        # position in their directories (TODO confirm naming scheme).
        data_files = [file_path + f"/{f}" for f in sorted(os.listdir(file_path))]
        label_files = [label_path + f"/{ll}" for ll in sorted(os.listdir(label_path))]
        if len(data_files) != len(label_files):
            raise ValueError(
                f"Cannot pair data and labels for task {new_task_idx}: "
                f"{len(data_files)} files in {file_path} vs "
                f"{len(label_files)} files in {label_path}")

    new_task_builder = Builder([data_files, label_files], args).Dataset
    return DataLoader(dataset=new_task_builder, batch_size=args["batch"], shuffle=True, num_workers=4)


def _individual_dir(args, num):
    """Return the checkpoint directory for the ``num``-th incremental individual."""
    return (f"model_parameter/{args['dataset']}/{args['algorithm']}_BufferMerge{args['buffer_merge']}"
            f"_{args['incremental_lr']}_{args['incremental_epoch']}/individual_{num}")


def _pretrain_dir(args):
    """Return the directory holding the pretrained checkpoints."""
    return f"model_parameter/{args['dataset']}/Pretrain"


def _checkpoint_paths(directory, args):
    """Return the (extractor, encoder, classifier) checkpoint paths inside ``directory``."""
    rand = args['rand']
    return (f"{directory}/feature_extractor_parameter_{rand}.pkl",
            f"{directory}/feature_encoder_parameter_{rand}.pkl",
            f"{directory}/sleep_classifier_parameter_{rand}.pkl")


def _new_blocks(args):
    """Instantiate fresh (extractor, encoder, classifier) modules.

    Construction order (extractor, classifier, encoder) matches the original code
    so random-number consumption is unchanged.
    """
    feature_extractor = FeatureExtractor(args).float()
    sleep_classifier = SleepMLP(args).float()
    feature_encoder = TransformerEncoder(args).float()
    return feature_extractor, feature_encoder, sleep_classifier


def _load_blocks(blocks, directory, args):
    """Load each block's state dict from ``directory`` in place and return ``blocks``."""
    for block, path in zip(blocks, _checkpoint_paths(directory, args)):
        block.load_state_dict(torch.load(path))
    return blocks


def _cpu_state_dict(module):
    """Return ``module.state_dict()`` with every tensor moved to the CPU.

    The returned dict is mutated in place (rather than rebuilt) so any metadata
    attached to it by ``state_dict()`` is kept.
    """
    state = module.state_dict()
    for key in state.keys():
        state[key] = state[key].to(torch.device("cpu"))
    return state


def _acc_mf1(ans):
    """Compute (accuracy, macro-F1) from an ``evaluator`` result ``(y_true, y_pred)``."""
    metric = Evaluator(ans[0], ans[1])
    return metric.metric_acc(), metric.metric_mf1()


def _record_old_task(args, acc, mf1):
    """Append old-task ACC/MF1 and the derived AAA/forgetting to ``args['old_task_performance']``."""
    perf = args['old_task_performance']
    perf['ACC'].append(acc)
    perf['MF1'].append(mf1)
    old_task_aaa = compute_aaa(perf['ACC'])
    old_task_forget = compute_forget(perf['ACC'])
    perf['AAA'].append(old_task_aaa)
    perf['FR'].append(old_task_forget)


def compared_learning(old_task_loader, new_task_idx, args):
    """Adapt the model to each new individual in turn, checkpointing and logging metrics.

    For the ``num``-th id (1-based) in ``new_task_idx`` the model is loaded from the
    pretrained checkpoint (``num == 1``) or from the checkpoint saved for individual
    ``num - 1``. It is then trained with ``incremental_trainer``, saved under
    ``_individual_dir(args, num)``, and evaluated on the new and old data.

    Args:
        old_task_loader: loader over previously learned data; evaluated after every
            individual (and once before the first, as a baseline).
        new_task_idx: iterable of individual ids, processed in order.
        args: config dict. Reads ``dataset``, ``algorithm``, ``buffer_merge``,
            ``incremental_lr``, ``incremental_epoch``, ``rand`` (plus whatever the
            models/trainer read). Mutated in place:
            ``args['old_task_performance'][k]`` is appended to for k in
            ACC/MF1/AAA/FR, and ``args['new_task_performance'][id]['ACC'|'MF1']``
            is set to ``[initial, before, after]``.

    Raises:
        RuntimeError: if ``incremental_trainer`` returns ``None``.
    """
    for num, new_task_id in enumerate(new_task_idx, start=1):
        print("New Task Id", new_task_id)
        new_task_loader = get_new_task_loader(args, new_task_id, False)

        cur_blocks = _new_blocks(args)
        save_dir = _individual_dir(args, num)
        os.makedirs(save_dir, exist_ok=True)

        if num == 1:
            # First individual: start from the pretrained model and record its
            # old-task score as the baseline entry of the performance history.
            _load_blocks(cur_blocks, _pretrain_dir(args), args)
            old_task_acc, old_task_mf1 = _acc_mf1(evaluator(cur_blocks, old_task_loader, args))
            _record_old_task(args, old_task_acc, old_task_mf1)
        else:
            # Continue from the model saved after the previous individual.
            _load_blocks(cur_blocks, _individual_dir(args, num - 1), args)

        # Snapshot taken before training, used for the "Before" metrics.
        last_blocks = copy.deepcopy(cur_blocks)
        tmp_blocks = incremental_trainer(cur_blocks, args, new_task_loader, new_task_id, num)
        if tmp_blocks is None:
            raise RuntimeError(
                f"incremental_trainer returned no model for task {new_task_id}; "
                "check that the loader is non-empty and incremental_epoch >= 1")

        """Store Newest Model"""
        cpu_states = [_cpu_state_dict(block) for block in tmp_blocks]
        for state, path in zip(cpu_states, _checkpoint_paths(save_dir, args)):
            torch.save(state, path)

        """Initial Model"""
        initial_blocks = _load_blocks(_new_blocks(args), _pretrain_dir(args), args)

        """Metric"""
        # Evaluation order (initial, before, after) is kept as in the original.
        new_task_initial_ans = evaluator(initial_blocks, new_task_loader, args)
        new_task_before_ans = evaluator(last_blocks, new_task_loader, args)
        new_task_after_ans = evaluator(tmp_blocks, new_task_loader, args)

        new_task_initial_acc, new_task_initial_mf1 = _acc_mf1(new_task_initial_ans)
        new_task_before_acc, new_task_before_mf1 = _acc_mf1(new_task_before_ans)
        new_task_after_acc, new_task_after_mf1 = _acc_mf1(new_task_after_ans)
        print([new_task_before_acc, new_task_after_acc])

        new_perf = args['new_task_performance'][new_task_id]
        new_perf['ACC'] = [new_task_initial_acc, new_task_before_acc, new_task_after_acc]
        new_perf['MF1'] = [new_task_initial_mf1, new_task_before_mf1, new_task_after_mf1]

        print(f"=========New Task {new_task_id}=========")
        print(" ACC Initial                    ", new_perf['ACC'][0], "\n",
              "ACC Before                    ", new_perf['ACC'][1], "\n",
              "ACC After Joint Training       ", new_perf['ACC'][2], "\n")

        old_task_acc, old_task_mf1 = _acc_mf1(evaluator(tmp_blocks, old_task_loader, args))
        _record_old_task(args, old_task_acc, old_task_mf1)

        old_perf = args["old_task_performance"]
        print(old_perf)
        print("=========Old Task=========")
        # For the first individual "before" is the pretrained baseline (index 0);
        # afterwards it is the entry recorded for the previous individual.
        before = 0 if num == 1 else -2
        print(" ACC Before                    ", old_perf['ACC'][before], "\n",
              "ACC After Joint Training       ", old_perf['ACC'][-1], "\n",
              "MF1 Before                     ", old_perf['MF1'][before], "\n",
              "MF1 After Joint Training       ", old_perf['MF1'][-1], "\n")

        # NOTE: merging the new individual's data into the buffer is not
        # implemented here (the original left only a placeholder docstring).


def incremental_trainer(blocks, args, new_task_loader, new_task_id,  num):
    """Train ``blocks`` on one new individual with the configured incremental algorithm.

    Args:
        blocks: ``(feature_extractor, feature_encoder, sleep_classifier)``; moved to
            ``args["device"]`` and put in train mode each epoch.
        args: config dict; reads ``"algorithm"``, ``"device"``, ``"is_buffer"`` and
            ``"incremental_epoch"`` (the algorithm classes may read more).
        new_task_loader: loader of ``(eog, eeg, label, ...)`` batches. It is replaced
            by a freshly built loader when ``args["is_buffer"]`` is truthy.
        new_task_id: id of the individual, used to rebuild the loader.
        num: 1-based position of the individual, used only in the progress print.

    Returns:
        The blocks returned by the last ``algorithm.update`` call, or the input
        ``blocks`` unchanged if no update ran (empty loader or zero epochs).

    Raises:
        ValueError: if ``args['algorithm']`` is not a known algorithm name.
    """
    algorithms = {
        'finetune': Finetune,
        'cpc': CPC,
        'simsiam': SimSiam,
        'mmd': MMD,
        'kl': KL,
    }
    # Previously an unknown name silently produced ``None`` and crashed later
    # with an AttributeError on ``.update``; fail early with a clear message.
    if args['algorithm'] not in algorithms:
        raise ValueError(
            f"Unknown incremental algorithm {args['algorithm']!r}; "
            f"expected one of {sorted(algorithms)}")
    algorithm = algorithms[args['algorithm']](blocks, args)

    device = args["device"]
    stages = (blocks[0], blocks[1], blocks[2])
    for stage in stages:
        stage.to(device)

    if args["is_buffer"]:
        new_task_loader = get_new_task_loader(args, new_task_id, True)

    tmp_blocks = None
    for epoch in range(1, args["incremental_epoch"] + 1):
        for stage in stages:
            stage.train()

        epoch_loss = []
        for data in new_task_loader:
            eog, eeg, label = data[0].to(device), data[1].to(device), data[2].to(device)
            loss, tmp_blocks = algorithm.update(eeg, eog, label)
            epoch_loss.append(loss)
        print(f"New Task ID {int(num)}  Epoch {epoch} Loss {np.mean(epoch_loss)}")

    # No update ran: the (unchanged) input model is the trained model.
    return blocks if tmp_blocks is None else tmp_blocks

