import copy
import random

import torch
import os
import numpy as np
from utils.util import fix_randomness, analysis
from torch.utils.data import DataLoader
from dataloader.data_loader import Builder
from trainer.pretrain import pretraining
from trainer.pretrain_face import pretraining_face
from trainer.pretrain_bci2000 import pretraining_bci2000
from trainer.incremental_trainer import incremental_learning
from trainer.incremental_trainer_contrastive_buffer import incremental_learning_contrastive_buffer
from trainer.incremental_trainer_contrastive_buffer_face import incremental_learning_contrastive_buffer_face
from trainer.incremental_trainer_contrastive_buffer_bci2000 import incremental_learning_contrastive_buffer_bci2000

def get_filepath(dataset):
    """Return the root directory configured for ``dataset``.

    The per-dataset locations in this file are placeholders (``None``), so
    every call currently returns ``None`` -- for the listed datasets as well
    as for any other name. Presumably the placeholders are meant to be
    replaced with local directories before running; confirm with the setup
    instructions.

    Args:
        dataset: Dataset name, e.g. ``"ISRUC"``, ``"FACE"`` or ``"BCI2000"``.

    Returns:
        The configured root directory, or ``None`` when nothing is configured.
    """
    placeholder = None
    dataset_roots = {
        "ISRUC": placeholder,
        "FACE": placeholder,
        'BCI2000': placeholder,
    }
    return dataset_roots.get(dataset)


def get_path_loader(parser):
    """Collect the per-subject data and label file paths of a dataset.

    Args:
        parser: Run configuration dict. Reads ``parser["dataset"]`` and
            ``parser["filepath"]`` (dataset root directory) and writes
            ``parser["info"]`` as ``{subject_id: [[], []]}``.

    Returns:
        A ``(path, path_name)`` tuple: ``path`` is the list of subject ids of
        the dataset and ``path_name`` maps every id to
        ``[data_files, label_files]``.

    Raises:
        ValueError: If no subject list is registered for the dataset, or if
            ``parser["filepath"]`` is ``None``.
    """
    dataset = parser["dataset"]

    bci2000_subjects = [i for i in range(109)
                        if i not in {38, 88, 89, 92, 100, 104}]
    subjects_by_dataset = {
        "ISRUC": [i for i in range(1, 101) if i not in {8, 40}],
        "Hang7": list(range(1, 68)),
        "HMC": list(range(1, 148)),
        "FACE": list(range(123)),
        "BCI2000": bci2000_subjects,
        "BCI2000_2": bci2000_subjects,
        "MDD": list(range(63)),
        "TUEV": list(range(1, 251)),
    }
    path = subjects_by_dataset.get(dataset)
    if path is None:
        # Previously this fell through with path=None and crashed with an
        # unhelpful "'NoneType' object is not iterable".
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of "
            f"{sorted(subjects_by_dataset)}"
        )

    root = parser["filepath"]
    if root is None:
        raise ValueError(
            f"parser['filepath'] is not set for dataset {dataset!r}"
        )

    path_name = {subject: [[], []] for subject in path}
    parser["info"] = {subject: [[], []] for subject in path}

    for subject in path:
        data_files, label_files = path_name[subject]
        if dataset == "SleepEDF":
            # NOTE: not reachable until a "SleepEDF" subject list is
            # registered above; kept so the directory layout is not lost.
            seq_dir = f"{root}/seq/{subject}"
            label_dir = f"{root}/labels/{subject}"
            data_files.extend(f"{seq_dir}/{name}" for name in os.listdir(seq_dir))
            label_files.extend(f"{label_dir}/{name}" for name in os.listdir(label_dir))
        else:
            data_dir = f"{root}/{subject}/data"
            label_dir = f"{root}/{subject}/label"
            # Samples are numbered 0.npy, 1.npy, ...; stop at the first gap.
            # Only the data file is probed, the label path is derived from it.
            num = 0
            while os.path.exists(f"{data_dir}/{num}.npy"):
                data_files.append(f"{data_dir}/{num}.npy")
                label_files.append(f"{label_dir}/{num}.npy")
                num += 1

    return path, path_name


def get_idx(parser, path):
    """Randomly split the subject ids into four groups.

    After re-seeding through ``fix_randomness(parser["rand"])`` the subjects
    are drawn without replacement: 20% of all subjects form the old-task
    group, 50% of all subjects (taken from the remainder) form the new-task
    group, and whatever is left is divided 80/20 into train and validation.

    Args:
        parser: Run configuration dict. Reads ``parser["rand"]``; writes
            ``parser["train_num"]``, ``parser["old_task_performance"]`` and
            ``parser["new_task_performance"]`` (one entry per new-task id).
        path: Sequence of subject ids to split.

    Returns:
        ``(train_idx, val_idx, old_task_idx, new_task_idx)`` as lists.
    """
    fix_randomness(parser["rand"])
    subjects = path
    total = len(subjects)

    # The candidate pools are built with the same set expressions as before
    # so the order handed to np.random.choice (and thus the split) is kept.
    old_task_idx = list(np.random.choice(subjects, int(total * 0.2), replace=False))
    not_old = list(set(subjects) - set(old_task_idx))
    new_task_idx = list(np.random.choice(not_old, int(total * 0.5), replace=False))

    train_val_idx = list(set(subjects) - set(old_task_idx) - set(new_task_idx))
    n_train = int(len(train_val_idx) * 0.8)
    train_idx = list(np.random.choice(train_val_idx, n_train, replace=False))
    parser['train_num'] = len(train_idx)
    val_idx = [subject for subject in train_val_idx if subject not in train_idx]

    parser["old_task_performance"] = {"ACC": [], "MF1": [], "AAA": [], "FR": []}
    new_task_performance = {}
    for subject in new_task_idx:
        new_task_performance[subject] = {"ACC": [], "MF1": []}
    parser["new_task_performance"] = new_task_performance

    # Report each group as: label, size, sorted members (one group per line).
    groups = (
        (" Train Idx  ", train_idx),
        ("Val Idx  ", val_idx),
        ("Old Task Idx", old_task_idx),
        ("New Task Idx", new_task_idx),
    )
    report = []
    for label, members in groups:
        if report:
            report.append("\n")
        report.extend((label, len(members), sorted(members)))
    print(*report)

    return train_idx, val_idx, old_task_idx, new_task_idx


def _collect_subject_files(path_name, subject_ids):
    """Concatenate the data and label file lists of the given subjects."""
    data_files, label_files = [], []
    for subject in subject_ids:
        # int() normalises NumPy integer ids to the plain-int dict keys.
        subject_data, subject_labels = path_name[int(subject)]
        data_files.extend(subject_data)
        label_files.extend(subject_labels)
    return [data_files, label_files]


def get_loader(parser, path, path_name):
    """Build the train, validation and old-task datasets from a subject split.

    The split itself comes from ``get_idx``. File gathering is the same for
    every dataset, so the former per-dataset branches (which were identical)
    are handled by one helper.

    Args:
        parser: Run configuration dict, forwarded to ``get_idx`` and
            ``Builder``. Gains ``parser["train_path"]`` and
            ``parser["train_len"]``.
        path: Subject ids, forwarded to ``get_idx``.
        path_name: Mapping ``subject_id -> [data_files, label_files]``.

    Returns:
        ``(train_builder, val_builder, old_task_builder, new_task_idx)`` where
        the first three are the ``Dataset`` attribute of a ``Builder`` and
        ``new_task_idx`` is the list of new-task subject ids.
    """
    train_idx, val_idx, old_task_idx, new_task_idx = get_idx(parser, path)

    train_path = _collect_subject_files(path_name, train_idx)
    val_path = _collect_subject_files(path_name, val_idx)
    old_path = _collect_subject_files(path_name, old_task_idx)

    # Create the datasets
    parser['train_path'] = train_path
    parser['train_len'] = len(parser['train_path'][0])
    train_builder = Builder(train_path, parser).Dataset
    val_builder = Builder(val_path, parser).Dataset
    old_task_builder = Builder(old_path, parser).Dataset
    print("Buffer_Length", len(parser["train_path"][0]), len(parser["train_path"][1]))

    return train_builder, val_builder, old_task_builder, new_task_idx


def main():
    """Configure the run, build the data loaders and start training.

    The configuration lives in a plain dict (``parser``). Depending on
    ``parser["pretrain"]`` either the dataset-specific pretraining routine is
    run, or the dataset-specific incremental-learning routine followed by
    ``analysis(parser)``.
    """
    parser = {
        "pretrain_epoch": 100,
        "incremental_epoch": 10,
        "ssl_epoch": 10,
        "algorithm": "cpc",
        "is_buffer": True,
        "is_teacher": True,
        "buffer_merge": True,
        'teacher': False,
        'CEA': True,
        'DCB': True,
        'Stability': [False, 4],
        "dataset": "FACE",
        "gpu": 7,
        'KL_Epoch': 3,
        # dataset name -> [batch size, random seed]
        "dict": {"ISRUC": [16, 4321], "FACE": [28, 432], "BCI2000": [32, 4321]},
    }
    parser["batch"], parser["rand"] = parser['dict'][parser['dataset']]
    parser["lr"] = 1e-4
    parser["contrastive_lr"] = 1e-6
    parser["incremental_lr"] = 1e-6 if parser['dataset'] == 'BCI2000' else 1e-7
    parser.update({
        "alpha": 0.01,
        'alpha_type': 1,
        'gamma': [1, 1, 1],
        "filepath": get_filepath(parser["dataset"]),
        "optimizer": "AdamW",
    })
    device_name = f"cuda:{parser['gpu']}" if torch.cuda.is_available() else "cpu"
    parser.update({
        "device": torch.device(device_name),
        "beta": [0.5, 0.99],
        "weight_decay": 3e-4,
        "num_worker": 4,
        "print_p": True,
        "pretrain": False,
    })

    # Dump the full configuration (in insertion order) for the run log.
    for key, value in parser.items():
        print(f"{key}:  {value}")

    fix_randomness(parser["rand"])
    torch.multiprocessing.set_start_method('spawn')
    path, path_name = get_path_loader(parser)

    train_dataset, val_dataset, old_dataset, new_task_idx = get_loader(parser, path, path_name)

    # Wrap the datasets in loaders (all three share the same settings)
    def make_loader(dataset):
        return DataLoader(dataset=dataset, batch_size=parser['batch'],
                          shuffle=True, num_workers=parser["num_worker"])

    train_loader = make_loader(train_dataset)
    val_loader = make_loader(val_dataset)
    old_task_loader = make_loader(old_dataset)

    if parser["pretrain"]:
        pretrain_by_dataset = {
            "FACE": pretraining_face,
            "BCI2000": pretraining_bci2000,
        }
        run_pretraining = pretrain_by_dataset.get(parser['dataset'], pretraining)
        run_pretraining(train_loader, val_loader, parser)
        return

    incremental_by_dataset = {
        "FACE": incremental_learning_contrastive_buffer_face,
        'BCI2000': incremental_learning_contrastive_buffer_bci2000,
        'ISRUC': incremental_learning_contrastive_buffer,
    }
    run_incremental = incremental_by_dataset.get(parser["dataset"])
    if run_incremental is not None:
        run_incremental(old_task_loader, new_task_idx, parser)

    analysis(parser)


# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    main()

