# Time : 2023/11/13 15:10
# Author : 小霸奔
# FileName: incremental_trainer.py
import copy
import random

import torch
from model.pretrain_net import FeatureExtractor, TransformerEncoder, SleepMLP
from utils.config import ModelConfig
import os
from dataloader.data_loader import Builder
from torch.utils.data import DataLoader
import numpy as np
from model.incremental_algorithm import  CPC, SimSiam, BufferPseudoLabelFinetune4

from utils.util import Evaluator, compute_aaa, compute_forget, fix_randomness
import torch.nn.functional as F
from utils.util import mmd_rbf


def evaluator(model, dl, args):
    """Run inference over ``dl`` and collect labels, predictions and raw outputs.

    Args:
        model: Indexable container of three modules. ``model[0]`` is called as
            ``model[0](eeg, eog)``; its output is passed through ``model[1]``
            and then ``model[2]``.
        dl: Iterable of batches whose first three items are the EOG tensor,
            the EEG tensor and the label tensor.
        args: Configuration dict; reads ``args["device"]`` and
            ``args["dataset"]``.

    Returns:
        A tuple ``(y_test, y_pred, predictions)``: the flattened labels as a
        list, the flattened arg-max over dim 1 of the outputs as a list, and
        the raw outputs of every batch concatenated along dim 0 (``None`` when
        ``dl`` yields no batch).
    """
    # Lists and tuple subclasses are accepted as well as plain tuples.
    if isinstance(model, (tuple, list)):
        for block in (model[0], model[1], model[2]):
            block.eval()
    else:
        model.eval()

    device = args["device"]
    for block in (model[0], model[1], model[2]):
        block.to(device)

    model_param = ModelConfig(args["dataset"])
    epoch_size = model_param.EpochLength

    y_pred = []
    y_test = []
    batch_outputs = []
    with torch.no_grad():
        for data in dl:
            eog, eeg, labels = data[0].to(device), data[1].to(device), data[2].to(device)

            eog = eog.view(-1, model_param.EogNum, epoch_size)
            eeg = eeg.view(-1, model_param.EegNum, epoch_size)

            # EEG + EOG
            eeg_eog_feature = model[0](eeg, eog)
            # Original note: "batch, 20, 512" — TODO confirm the feature shape.
            eeg_eog_feature = model[1](eeg_eog_feature)
            prediction = model[2](eeg_eog_feature)

            # Collect per-batch outputs and concatenate once after the loop
            # instead of re-copying the growing tensor on every batch.
            batch_outputs.append(prediction)

            _, predicted = torch.max(prediction, dim=1)
            y_pred.extend(torch.flatten(predicted).tolist())
            y_test.extend(torch.flatten(labels).tolist())

    predictions = torch.cat(batch_outputs, dim=0) if batch_outputs else None
    return y_test, y_pred, predictions


def get_new_task_loader(args, new_task_idx, is_buffer, shuffle):
    """Build a ``DataLoader`` over the data/label files of one new task.

    Args:
        args: Configuration dict; reads ``args["dataset"]``,
            ``args["filepath"]`` and ``args["batch"]`` and is forwarded to
            ``Builder``.
        new_task_idx: Identifier of the task; used as a directory name.
        is_buffer: When true the loader wraps ``Builder(...).BufferDataset``,
            otherwise ``Builder(...).Dataset``.
        shuffle: Forwarded to ``DataLoader``.

    Returns:
        A ``DataLoader`` with ``batch_size=args["batch"]`` and four workers.
    """
    new_task_path = [[], []]
    if args["dataset"] != "SleepEDF":
        # Files are numbered 0.npy, 1.npy, ...; stop at the first gap.
        file_path = args['filepath'] + f"/{new_task_idx}/data"
        label_path = args['filepath'] + f"/{new_task_idx}/label"
        num = 0
        while os.path.exists(file_path + f"/{num}.npy"):
            new_task_path[0].append(file_path + f"/{num}.npy")
            new_task_path[1].append(label_path + f"/{num}.npy")
            num += 1
    else:
        file_path = args['filepath'] + f"/seq/{new_task_idx}"
        label_path = args['filepath'] + f"/labels/{new_task_idx}"
        # os.listdir returns entries in arbitrary order; sort both listings so
        # the i-th data file and the i-th label file are paired consistently.
        # NOTE(review): assumes data and label files share a naming scheme
        # that sorts identically — confirm against the dataset layout.
        new_task_path[0].extend(
            file_path + f"/{name}" for name in sorted(os.listdir(file_path)))
        new_task_path[1].extend(
            label_path + f"/{name}" for name in sorted(os.listdir(label_path)))

    builder = Builder(new_task_path, args)
    new_task_builder = builder.BufferDataset if is_buffer else builder.Dataset
    new_task_loader = DataLoader(dataset=new_task_builder, batch_size=args["batch"], shuffle=shuffle, num_workers=4)

    return new_task_loader


# Offset added to args['rand'] to derive the seed for each stability-run id
# (args['Stability'][1]). Ids not listed here leave the task order untouched.
_STABILITY_SEED_OFFSETS = {
    1: 1,
    2: 10,
    3: 100,
    4: 1000,
    5: 10000,
    6: 9,
    7: 99,
    8: 999,
    9: 9999,
}


def _shuffle_tasks_for_stability(new_task_idx, args):
    """Reseed and randomly reorder the tasks when a stability run is requested."""
    if not args['Stability'][0]:
        return new_task_idx
    offset = _STABILITY_SEED_OFFSETS.get(args['Stability'][1])
    if offset is None:
        return new_task_idx
    fix_randomness(args['rand'] + offset)
    print(f'Stability Seed: {args["rand"] + offset}')
    return sorted(new_task_idx, key=lambda x: random.random())


def _update_alpha(args, num):
    """Overwrite ``args['alpha']`` according to ``args['alpha_type']`` and the task counter."""
    if args['alpha_type'] == 1:
        if num >= args['train_num']:
            args['alpha'] = np.power(0.1, np.log10(num / args['train_num']) + 2)
    elif args['alpha_type'] == 2:
        if num >= args['train_num']:
            args['alpha'] = np.power(0.1, np.log10(num / args['train_num']) + 2)
        else:
            args['alpha'] = 0.1 * np.power(0.1, num / args['train_num'])
    elif args['alpha_type'] == 3:
        if num >= args['train_num'] // 2:
            args['alpha'] = np.power(0.1, np.log2((num * 2) / args['train_num']) + 2)
        else:
            args['alpha'] = 0.1 * np.power(0.1, (num * 2) / args['train_num'])


def _build_blocks(args):
    """Instantiate fresh ``(feature_extractor, feature_encoder, sleep_classifier)`` blocks."""
    feature_extractor = FeatureExtractor(args).float()
    sleep_classifier = SleepMLP(args).float()
    feature_encoder = TransformerEncoder(args).float()
    return feature_extractor, feature_encoder, sleep_classifier


def _load_blocks(blocks, path):
    """Load a state dict from ``path`` into every block, in order."""
    for block in blocks:
        block.load_state_dict(torch.load(path))


def _state_dict_on_cpu(module):
    """Return ``module.state_dict()`` with every entry moved to the CPU."""
    state = module.state_dict()
    for key in state.keys():
        state[key] = state[key].to(torch.device("cpu"))
    return state


def _record_old_task_performance(args, task_evaluator):
    """Append ACC, MF1, AAA and FR of the old task to ``args['old_task_performance']``."""
    performance = args['old_task_performance']
    old_task_acc, old_task_mf1 = task_evaluator.metric_acc(), task_evaluator.metric_mf1()
    performance['ACC'].append(old_task_acc)
    performance['MF1'].append(old_task_mf1)

    old_task_aaa = compute_aaa(performance['ACC'])
    old_task_forget = compute_forget(performance['ACC'])
    performance['AAA'].append(old_task_aaa)
    performance['FR'].append(old_task_forget)


def incremental_learning_contrastive_buffer(old_task_loader, new_task_idx, args):
    """Adapt the model to each new task in turn and track old/new-task metrics.

    For every task id the function loads the current model, runs
    ``incremental_trainer`` on the task's data, saves the resulting weights,
    evaluates several model variants on the new task and on
    ``old_task_loader``, prints the results and, when ``args['buffer_merge']``
    is set, calls ``buffer_single_merge``.

    Args:
        old_task_loader: Loader passed to ``evaluator`` to measure old-task
            performance.
        new_task_idx: Iterable of task identifiers; reordered randomly when
            ``args['Stability'][0]`` is truthy and ``args['Stability'][1]`` is
            one of the known run ids (1-9).
        args: Configuration dict. ``args['alpha']``,
            ``args['old_task_performance']`` and
            ``args['new_task_performance'][task_id]`` are updated in place.

    Returns:
        None. Results are written into ``args`` and printed.
    """
    # NOTE(review): PATH is a placeholder. Every block is loaded from and saved
    # to this single location, so the three torch.save calls below overwrite
    # one another — confirm the intended per-block checkpoint files.
    PATH = 'path'
    num = 1
    new_task_idx = _shuffle_tasks_for_stability(new_task_idx, args)

    for new_task_id in new_task_idx:
        print("New Task Id", new_task_id)
        _update_alpha(args, num)

        new_task_loader = get_new_task_loader(args, new_task_id, False, True)
        cur_blocks = _build_blocks(args)

        if not os.path.exists(PATH):
            os.makedirs(PATH)

        # The first task starts from the stored model; later tasks reload the
        # model saved at the end of the previous iteration.
        _load_blocks(cur_blocks, PATH)
        if num == 1:
            # First individual: record old-task performance before any update.
            old_task_ans = evaluator(cur_blocks, old_task_loader, args)
            _record_old_task_performance(args, Evaluator(old_task_ans[0], old_task_ans[1]))

        teacher_blocks = copy.deepcopy(cur_blocks)
        # Snapshot of the model before training on this task.
        last_blocks = copy.deepcopy(cur_blocks)
        tmp_blocks, tmp_blocks_teacher = incremental_trainer(cur_blocks, teacher_blocks, args, new_task_loader, new_task_id, num)

        # Store newest model.
        state_f = _state_dict_on_cpu(tmp_blocks[0])
        state_encoder = _state_dict_on_cpu(tmp_blocks[1])
        state_sleep = _state_dict_on_cpu(tmp_blocks[2])
        torch.save(state_f, PATH)
        torch.save(state_encoder, PATH)
        torch.save(state_sleep, PATH)

        # Initial model.
        initial_blocks = _build_blocks(args)
        _load_blocks(initial_blocks, PATH)

        # Metrics on the new task.
        new_task_initial_ans = evaluator(initial_blocks, new_task_loader, args)
        new_task_before_ans = evaluator(last_blocks, new_task_loader, args)
        new_task_after_teacher_ans = evaluator(tmp_blocks_teacher, new_task_loader, args)
        new_task_after_ans = evaluator(tmp_blocks, new_task_loader, args)

        new_initial_evaluator = Evaluator(new_task_initial_ans[0], new_task_initial_ans[1])
        new_before_evaluator = Evaluator(new_task_before_ans[0], new_task_before_ans[1])
        new_after_teacher_evaluator = Evaluator(new_task_after_teacher_ans[0], new_task_after_teacher_ans[1])
        new_after_evaluator = Evaluator(new_task_after_ans[0], new_task_after_ans[1])

        new_task_initial_acc, new_task_initial_mf1 = new_initial_evaluator.metric_acc(), new_initial_evaluator.metric_mf1()
        new_task_before_acc, new_task_before_mf1 = new_before_evaluator.metric_acc(), new_before_evaluator.metric_mf1()
        new_task_after_acc, new_task_after_mf1 = new_after_evaluator.metric_acc(), new_after_evaluator.metric_mf1()

        new_performance = args['new_task_performance'][new_task_id]
        new_performance['ACC'] = [new_task_initial_acc, new_task_before_acc, new_task_after_acc]
        new_performance['MF1'] = [new_task_initial_mf1, new_task_before_mf1, new_task_after_mf1]

        print(f"=========New Task {new_task_id}=========")
        print(" ACC Initial                    ", new_performance['ACC'][0], "\n",
              "ACC Before                    ", new_performance['ACC'][1], "\n",
              "ACC After Contrastive Learning ", new_after_teacher_evaluator.metric_acc(), "\n",
              "ACC After Joint Training       ", new_performance['ACC'][2], "\n")

        # Metrics on the old task.
        old_task_ans = evaluator(tmp_blocks, old_task_loader, args)
        old_task_teacher_ans = evaluator(tmp_blocks_teacher, old_task_loader, args)

        old_task_evaluator = Evaluator(old_task_ans[0], old_task_ans[1])
        old_task_evaluator_teacher = Evaluator(old_task_teacher_ans[0], old_task_teacher_ans[1])
        _record_old_task_performance(args, old_task_evaluator)

        old_performance = args["old_task_performance"]
        print(old_performance)
        print("=========Old Task=========")
        # "Before" is the first recorded entry for the first task and the
        # previous entry for every later task.
        before_idx = 0 if num == 1 else -2
        print(" ACC Before                    ", old_performance['ACC'][before_idx], "\n",
              "ACC After Contrastive Learning ", old_task_evaluator_teacher.metric_acc(), "\n",
              "ACC After Joint Training       ", old_performance['ACC'][-1], "\n",
              "MF1 Before                     ", old_performance['MF1'][before_idx], "\n",
              "MF1 After Contrastive Learning ", old_task_evaluator_teacher.metric_mf1(), "\n",
              "MF1 After Joint Training       ", old_performance['MF1'][-1], "\n")

        # Merge the new individual's data into the buffer when requested.
        if args['buffer_merge']:
            buffer_single_merge(args, new_task_id, num, tmp_blocks)
        print("Buffer_Length", len(args["train_path"][0]), len(args["train_path"][1]))
        num += 1


def buffer_single_merge(args, new_task_id, num, tmp_blocks):
    """Add a new individual's data, labelled with model predictions, to the buffer.

    The individual's files are evaluated with ``tmp_blocks``; the predicted
    labels are written to ``.npy`` files and the data/label paths are appended
    to ``args["train_path"]``.

    Args:
        args: Configuration dict; reads ``args['filepath']``, ``args['batch']``
            and ``args['DCB']`` and appends to ``args["train_path"][0]`` (data
            paths) and ``args["train_path"][1]`` (label paths).
        new_task_id: Identifier of the individual; used as a directory name.
        num: Unused; kept for interface compatibility.
        tmp_blocks: Model blocks passed to ``evaluator`` to produce the
            pseudo-labels.

    Returns:
        None.
    """
    new_task_path = [[], []]
    file_path = args['filepath'] + f"/{new_task_id}/data"
    label_path = args['filepath'] + f"/{new_task_id}/label"
    idx = 0
    while os.path.exists(file_path + f"/{idx}.npy"):
        new_task_path[0].append(file_path + f"/{idx}.npy")
        new_task_path[1].append(label_path + f"/{idx}.npy")
        idx += 1
    new_task_builder = Builder(new_task_path, args).Dataset
    # shuffle=False so that prediction i corresponds to new_task_path[0][i].
    new_task_loader = DataLoader(dataset=new_task_builder, batch_size=args["batch"], shuffle=False, num_workers=4)
    new_task_after_ans = evaluator(tmp_blocks, new_task_loader, args)

    # NOTE(review): 'PATH' is a placeholder and label files are named only by
    # sequence index, so labels saved for one individual are overwritten by the
    # next — confirm the intended per-individual directory.
    save_path = 'PATH'
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    if not args['DCB']:
        # Directly add all of the new individual's data to the buffer pool.
        new_task_pseudo_label = np.array(new_task_after_ans[1])
        # Predictions are grouped into sequences of 20 epochs.
        seq_num = new_task_pseudo_label.shape[0] // 20
        for j in range(seq_num):
            save_label_path = save_path + f"/{j}.npy"
            save_label = new_task_pseudo_label[j * 20: j * 20 + 20]
            np.save(save_label_path, save_label)
            args["train_path"][1].append(save_label_path)
        args["train_path"][0].extend(new_task_path[0])
    else:
        # Only add the sequences in which at least `confident_epoch_n` epochs
        # have a predicted probability of at least `confidence_level`.
        confident_epoch_n = 15
        confidence_level = 0.9
        class_prob = torch.softmax(new_task_after_ans[2], dim=1)
        # Reduce over the class dimension only. A blanket squeeze() would also
        # drop the leading dimension when there is a single sequence and break
        # the 2-D indexing below.
        # NOTE(review): assumes the raw outputs are (sequence, class, epoch) —
        # confirm against the classifier's output shape.
        pred_prob, pred_label = class_prob.max(dim=1)

        pred_prob = pred_prob.cpu().numpy()
        pred_label = pred_label.cpu().numpy()
        for bh in range(pred_prob.shape[0]):
            confident_epoch_num_per_seq = np.sum(pred_prob[bh, :] >= confidence_level)
            if confident_epoch_num_per_seq >= confident_epoch_n:
                confident_label = np.squeeze(pred_label[bh, :])
                save_label_path = save_path + f"/{bh}.npy"
                np.save(save_label_path, confident_label)
                args["train_path"][1].append(save_label_path)
                args["train_path"][0].append(new_task_path[0][bh])


def incremental_trainer(blocks, teacher_blocks, args, new_task_loader, new_task_id, num):
    """Train the teacher contrastively, then fine-tune ``blocks`` jointly with the buffer.

    Stage 1 runs ``args["ssl_epoch"]`` epochs of the contrastive algorithm
    selected by ``args['algorithm']`` (``'cpc'`` or ``'simsiam'``) on
    ``new_task_loader``. Stage 2 runs ``args["incremental_epoch"]`` epochs of
    ``BufferPseudoLabelFinetune4`` on the buffer loader of ``new_task_id``;
    when ``args['CEA']`` is set, an additional KL-divergence step is applied
    on every epoch that is a multiple of ``args['KL_Epoch']``.

    Args:
        blocks: Indexable triple of modules to fine-tune.
        teacher_blocks: Indexable triple of modules trained in stage 1.
        args: Configuration dict (device, epochs, optimizer settings, flags).
        new_task_loader: Loader over the new task's data for stage 1.
        new_task_id: Identifier used to build the buffer loader.
        num: Running task counter; only used in log messages.

    Returns:
        ``(tmp_blocks, tmp_blocks_teacher)``: the last values returned by the
        fine-tuning and contrastive ``update`` calls. ``tmp_blocks`` is
        ``None`` when no fine-tuning step ran; ``tmp_blocks_teacher`` is
        ``teacher_blocks`` when no contrastive step ran.

    Raises:
        ValueError: If contrastive epochs are requested but
            ``args['algorithm']`` is neither ``'cpc'`` nor ``'simsiam'``.
    """
    EPOCH_NUMBER = args['KL_Epoch']

    # Confirm contrastive method.
    if args['algorithm'] == 'cpc':
        contrastive_algorithm = CPC(teacher_blocks, args)
    elif args['algorithm'] == 'simsiam':
        contrastive_algorithm = SimSiam(teacher_blocks, args)
    else:
        contrastive_algorithm = None
    if contrastive_algorithm is None and args["ssl_epoch"] > 0:
        # Fail early with a clear message instead of an AttributeError on None.
        raise ValueError(
            f"Unsupported contrastive algorithm {args['algorithm']!r}; "
            "expected 'cpc' or 'simsiam'"
        )

    device = args["device"]
    tmp_blocks = None
    # Use the untouched teacher when no contrastive update runs (zero epochs
    # or an empty loader); otherwise this is replaced by update()'s result.
    tmp_blocks_teacher = teacher_blocks

    for block in (blocks[0], blocks[1], blocks[2]):
        block.to(device)
    for block in (teacher_blocks[0], teacher_blocks[1], teacher_blocks[2]):
        block.to(device)

    # Firstly, train teacher model using contrastive method to obtain
    # confident pseudo-label.
    for epoch in range(1, args["ssl_epoch"] + 1):
        teacher_blocks[0].train()
        teacher_blocks[1].train()
        teacher_blocks[2].train()

        epoch_loss = []
        for data in new_task_loader:
            eog, eeg, label = data[0].to(device), data[1].to(device), data[2].to(device)
            loss, tmp_blocks_teacher = contrastive_algorithm.update(eeg, eog, label)
            epoch_loss.append(loss)
        print(f"New Task ID {int(num)}  Contrastive Epoch {epoch} Loss {np.mean(epoch_loss)}")

    # Secondly, using pseudo-label and train loader for joint-training.
    #
    # Original notes on the related fine-tuning variants:
    #   BufferPseudoLabelFinetune:  buffer and new data both use pseudo labels
    #                               for cross-entropy.
    #   BufferPseudoLabelFinetune2: buffer uses true labels and new data uses
    #                               pseudo labels for cross-entropy.
    # Original note on this variant: "Buffer New Both True"; new and old
    # inputs are merged at the encoder and extractor layers.
    algorithm = BufferPseudoLabelFinetune4(blocks, tmp_blocks_teacher, args)

    shuffle = True
    buffer_loader = get_new_task_loader(args, new_task_id, True, shuffle)
    print(f"New Task ID {int(num)}  Alpha={args['alpha']}")
    # Optimizer for the extra KL step; it covers blocks[0] and blocks[1] only.
    optimizer = torch.optim.Adam([{"params": list(blocks[0].parameters())},
                                  {"params": list(blocks[1].parameters())}],
                                 lr=args["incremental_lr"],
                                 betas=(args['beta'][0], args['beta'][1]),
                                 weight_decay=args['weight_decay'])
    model_param = ModelConfig(args['dataset'])

    for epoch in range(1, args["incremental_epoch"] + 1):
        blocks[0].train()
        blocks[1].train()
        blocks[2].train()

        tmp_blocks_teacher[0].eval()
        tmp_blocks_teacher[1].eval()
        tmp_blocks_teacher[2].eval()

        epoch_loss = []
        for data in buffer_loader:
            eog, eeg, label = data[0].to(device), data[1].to(device), data[2].to(device)
            loss, tmp_blocks, feature_before = algorithm.update(eeg, eog, label)
            epoch_loss.append(loss)

            if args['CEA'] and epoch % EPOCH_NUMBER == 0:
                # Re-encode the entries from index 20 onward along dim 1 with
                # the freshly updated blocks and pull them towards the
                # features returned by update().
                # NOTE(review): presumably the first 20 positions hold buffer
                # samples and the rest new-task samples — confirm.
                # Channel count and epoch length come from the dataset config,
                # as in evaluator, instead of hard-coded 2 / 3000.
                eog_train = eog[:, 20:, :, :].contiguous().view(
                    -1, model_param.EogNum, model_param.EpochLength)
                eeg_train = eeg[:, 20:, :, :].contiguous().view(
                    -1, model_param.EegNum, model_param.EpochLength)
                feature_latter = blocks[0](eeg_train, eog_train)
                feature_latter = blocks[1](feature_latter)

                optimizer.zero_grad()

                z1 = F.log_softmax(feature_latter, dim=-1)
                z2 = F.softmax(feature_before.detach(), dim=-1)
                loss_clr_new = F.kl_div(z1, z2, reduction='batchmean')
                loss_clr_new.backward()
                optimizer.step()
        print(f"New Task ID {int(num)}  Joint Fine-tuning Epoch {epoch} Loss {np.mean(epoch_loss)}")

    return tmp_blocks, tmp_blocks_teacher

