import os
import pickle
import numpy as np
import torch
from Backbones.model_factory import get_model
from dataset.utils import accuracy, evaluate, NodeLevelDataset
from training.utils import mkdir_if_missing, shuffle_list
from dataset.utils import semi_task_manager
import importlib
import copy
import dgl

# Every spelling of ``args.method`` that is routed to the joint-training pipeline.
joint_alias = [
    'joint',
    'Joint',
    'joint_replay_all',
    'jointtrain',
]
def get_pipeline(args):
    """Return the pipeline function that matches ``args.method``.

    ``pipeline_joint`` is returned when ``args.method`` is one of the names
    listed in ``joint_alias``; every other method gets ``pipeline_online_IL``.
    The function object itself is returned; it is not called here.
    """
    wants_joint = args.method in joint_alias
    return pipeline_joint if wants_joint else pipeline_online_IL

def data_prepare(args):
    """Make sure a preprocessed graph pickle exists for every task.

    For each task in the class-incremental sequence this checks the cached
    file ``{args.data_path}/{inter_tsk_edge|no_inter_tsk_edge}/{dataset}_{task_cls}.pkl``:

    * if ``args.load_check`` is true, the file is actually unpickled and
      unpacked, so a corrupt or malformed file also counts as missing;
    * otherwise only the existence of the file is checked.

    When the check fails, the graph retaining all classes seen up to and
    including the current task is built with ``dataset.get_graph`` and dumped.

    Side effects: sets ``args.d_data``, ``args.n_cls``, ``args.task_seq`` and
    ``args.n_tasks``, selects the CUDA device ``args.gpu``, and may create
    directories/files under ``args.data_path``.

    Args:
        args: experiment namespace; reads ``gpu``, ``dataset``,
            ``ratio_valid_test``, ``n_cls_per_task``, ``inter_task_edges``,
            ``load_check`` and ``data_path``.
    """
    torch.cuda.set_device(args.gpu)
    dataset = NodeLevelDataset(args.dataset, ratio_valid_test=args.ratio_valid_test, args=args)
    args.d_data, args.n_cls = dataset.d_data, dataset.n_cls
    # NOTE(review): the stop bound is n_cls - 1, so with n_cls_per_task == 1 the
    # last class never gets a task -- confirm this is intended.
    args.task_seq = [list(range(i, i + args.n_cls_per_task))
                     for i in range(0, args.n_cls - 1, args.n_cls_per_task)]
    args.n_tasks = len(args.task_seq)

    str_int_tsk = 'inter_tsk_edge' if args.inter_task_edges else 'no_inter_tsk_edge'
    for task, task_cls in enumerate(args.task_seq):
        cached_path = f'{args.data_path}/{str_int_tsk}/{args.dataset}_{task_cls}.pkl'
        if args.load_check:
            try:
                # Unpacking validates the stored structure, not just readability.
                # pickle is only safe here because the files are produced locally.
                with open(cached_path, 'rb') as f:
                    _subgraph, _ids_per_cls, [_train_ids, _valid_ids, _test_ids] = pickle.load(f)
                needs_rebuild = False
            except Exception:
                needs_rebuild = True
        else:
            needs_rebuild = not os.path.isfile(cached_path)

        if not needs_rebuild:
            continue

        # Missing or unreadable: create new processed data.
        print(f'preparing data for task {task}')
        mkdir_if_missing(f'{args.data_path}/inter_tsk_edge')
        mkdir_if_missing(f'{args.data_path}/no_inter_tsk_edge')
        cls_retain = [c for clss in args.task_seq[0:task + 1] for c in clss]
        subgraph, ids_per_cls_all, [train_ids, valid_ids, test_ids] = dataset.get_graph(
            tasks_to_retain=cls_retain)
        # NOTE(review): the file is always written under 'inter_tsk_edge', while
        # the check above looks under `str_int_tsk`; with inter_task_edges off
        # the cache therefore never hits. Left as-is because pipeline_joint
        # reads from 'inter_tsk_edge' -- confirm the intended layout.
        with open(f'{args.data_path}/inter_tsk_edge/{args.dataset}_{task_cls}.pkl', 'wb') as f:
            pickle.dump([subgraph, ids_per_cls_all, [train_ids, valid_ids, test_ids]], f)

def _online_task_accuracy(model, subgraph, labels, task_classes, cls_retain, ids_per_cls_all,
                          test_ids, n_layers, n_cls_so_far, args):
    """Return the accuracy of ``model`` on the evaluation nodes of one task.

    The evaluation nodes are the members of ``test_ids`` that belong to the
    classes in ``task_classes``. They are pushed through ``model.forward_batch``
    in chunks of ``args.batch_size`` and scored with ``accuracy`` on the logit
    columns ``[0, n_cls_so_far)``.

    Raises whatever the underlying calls raise; in particular, when the task
    has no evaluation node the empty output tensor cannot be column-sliced and
    an ``IndexError`` is raised (callers decide whether that is fatal).
    """
    cls_ids_new = [cls_retain.index(c) for c in task_classes]
    ids_per_cls_current_task = [ids_per_cls_all[i] for i in cls_ids_new]
    test_id_set = set(test_ids)
    ids_per_cls_test = [list(set(ids).intersection(test_id_set)) for ids in ids_per_cls_current_task]
    test_ids_t = [nid for ids in ids_per_cls_test for nid in ids]
    labels_t = labels[test_ids_t]

    # The sampler does not depend on the batch, so build it once per call.
    if args.sample_nbs:
        nb_sampler = dgl.dataloading.NeighborSampler(args.n_nbs_sample)
    else:
        nb_sampler = dgl.dataloading.MultiLayerFullNeighborSampler(n_layers)
    device = 'cuda:{}'.format(args.gpu)

    output = torch.tensor([]).cuda(args.gpu)
    for start in range(0, len(test_ids_t), args.batch_size):
        batch = test_ids_t[start:start + args.batch_size]
        _, _, blocks = nb_sampler.sample_blocks(subgraph, torch.tensor(batch).to(device=device))
        input_features = blocks[0].srcdata['feat']
        output_predictions, _ = model.forward_batch(blocks, input_features)
        output = torch.cat((output, output_predictions), dim=0)
        torch.cuda.empty_cache()
    logits = output[:, 0:n_cls_so_far]
    return accuracy(logits, labels_t, cls_balance=args.cls_balance, ids_per_cls=ids_per_cls_test)


def pipeline_online_IL(args, valid=False):
    """Run the online (streaming) class-incremental pipeline.

    Tasks are visited in order; the nodes of each task are shuffled with a
    fixed seed and fed to the learner in mini-batches of
    ``args.n_nodes_per_batch`` nodes. With ``valid=True`` the learner is
    trained for ``args.epochs`` passes per mini-batch, optionally evaluated
    after every mini-batch (``args.end_batch_test``), and the model is saved
    after each task. With ``valid=False`` no training happens: the model saved
    for each task is loaded (and its ``.pkl`` file removed) before evaluation.

    Tasks whose index is ``>= args.n_validation_tasks`` are skipped.

    Args:
        args: experiment namespace (mutated: ``d_data``, ``n_cls``,
            ``task_seq`` and ``n_tasks`` are set).
        valid: whether to train/evaluate in validation mode (see above).

    Returns:
        ``(acc_mean, mean_backward, acc_matrix)``: the mean accuracy (in %)
        after the last processed task, the mean of
        ``acc_matrix[n_tasks - 1][t] - acc_matrix[t][t]`` over all but the last
        task, and the ``n_tasks x n_tasks`` accuracy matrix (in %).
    """
    epochs = args.epochs if valid else 0
    torch.cuda.set_device(args.gpu)
    device = 'cuda:{}'.format(args.gpu)
    dataset = NodeLevelDataset(args.dataset, ratio_valid_test=args.ratio_valid_test, args=args)
    args.d_data, args.n_cls = dataset.d_data, dataset.n_cls
    args.task_seq = [list(range(i, i + args.n_cls_per_task))
                     for i in range(0, args.n_cls - 1, args.n_cls_per_task)]
    args.n_tasks = len(args.task_seq)

    model = get_model(dataset, args).cuda(args.gpu)
    life_model = importlib.import_module(f'Baselines.{args.method}_model')
    life_model_ins = life_model.NET(model, args) if valid else None

    n_layers = len(model.gat_layers)

    acc_matrix = np.zeros([args.n_tasks, args.n_tasks])
    n_cls_so_far = 0
    all_accs_batch = []
    for task, task_cls in enumerate(args.task_seq):
        if task >= args.n_validation_tasks:
            continue
        name, ite = args.current_model_save_path
        config_name = name.split('/')[-1]
        subfolder_c = name.split(config_name)[-2]
        save_model_name = f'{config_name}_{ite}_{task_cls}'
        save_model_path = f'{args.result_path}/{subfolder_c}val_models/{save_model_name}.pkl'
        n_cls_so_far += len(task_cls)
        cls_retain = [c for clss in args.task_seq[0:task + 1] for c in clss]
        num_nodes_task = sum(dataset.cls_sizes[task_cls[i]] for i in range(args.n_cls_per_task))

        # Gather the train/valid/test node ids of every class of this task and
        # shuffle them with a fixed seed to obtain the node stream.
        new_nodes_ids = [nid
                         for i in range(args.n_cls_per_task)
                         for j in range(3)
                         for nid in dataset.tr_va_te_split[task_cls[i]][j]]
        new_nodes_ids = shuffle_list(new_nodes_ids, random_seed=42)

        # Positions of the current task's classes inside `cls_retain`.
        cls_ids_new = [cls_retain.index(c) for c in task_cls]

        for stream_batch in range(0, num_nodes_task, args.n_nodes_per_batch):
            stream_node_ids = new_nodes_ids[stream_batch:stream_batch + args.n_nodes_per_batch]

            subgraph, ids_per_cls_all, test_ids, train_ids_current_batch = dataset.update_subgraph(
                node_ids=stream_node_ids, tasks_to_retain=cls_retain, valid=valid, device=device)
            if len(train_ids_current_batch) == 0:
                continue
            features, labels = subgraph.srcdata['feat'], subgraph.dstdata['label'].view(-1)
            torch.cuda.empty_cache()

            ids_per_cls_current_batch = [ids_per_cls_all[i] for i in cls_ids_new]

            for _ in range(epochs):
                life_model_ins.observe_minibatch(args, subgraph, features, labels, train_ids_current_batch,
                                                 ids_per_cls_current_batch)

            # Snapshot the model for LwF. There is no learner instance when
            # valid is False, so the snapshot is skipped in that case.
            if (life_model_ins is not None and args.method == 'lwf'
                    and stream_batch % args.lwf_args['save_every'] == 0):
                prev_model = get_model(dataset, args).cuda(args.gpu)
                prev_model.load_state_dict(model.state_dict())
                life_model_ins.prev_model = prev_model

            # anytime evaluation
            if valid and args.end_batch_test:
                labels = subgraph.dstdata['label'].squeeze()
                accs_per_task = np.zeros(args.n_tasks)
                model.eval()
                with torch.no_grad():
                    for t in range(task + 1):
                        try:
                            accs_per_task[t] = _online_task_accuracy(
                                model, subgraph, labels, args.task_seq[t], cls_retain,
                                ids_per_cls_all, test_ids, n_layers, n_cls_so_far, args)
                        except Exception:
                            # Best effort: early in the stream a task may not
                            # have any evaluation node yet; its entry stays 0.
                            print("No validation nodes yet")
                all_accs_batch.append(accs_per_task)

        # test
        labels = subgraph.dstdata['label'].squeeze()
        if not valid:
            # pickle is only safe here because the file was written by this
            # pipeline; never point it at untrusted data.
            try:
                with open(save_model_path, 'rb') as f:
                    model = pickle.load(f).cuda(args.gpu)
            except Exception:
                model.load_state_dict(torch.load(save_model_path.replace('.pkl', '.pt')))
            os.remove(save_model_path)
        acc_mean = []

        model.eval()
        with torch.no_grad():
            for t in range(task + 1):
                acc = _online_task_accuracy(
                    model, subgraph, labels, args.task_seq[t], cls_retain,
                    ids_per_cls_all, test_ids, n_layers, n_cls_so_far, args)
                acc_matrix[task][t] = round(acc * 100, 2)
                acc_mean.append(acc)
                print(f"T{t:02d} {acc * 100:.2f}|", end="")

        acc_mean = round(np.mean(acc_mean) * 100, 2)
        print(f"acc_mean: {acc_mean}", end=" ", flush=True)
        print()
        if valid:
            mkdir_if_missing(f'{args.result_path}/{subfolder_c}/val_models')
            try:
                with open(save_model_path, 'wb') as f:
                    pickle.dump(model, f) # save the best model for each hyperparameter composition
            except Exception:
                # Fall back to the state dict when the model cannot be pickled.
                torch.save(model.state_dict(), save_model_path.replace('.pkl', '.pt'))

    print('AP: ', acc_mean)
    backward = []
    for t in range(args.n_tasks - 1):
        b = acc_matrix[args.n_tasks - 1][t] - acc_matrix[t][t]
        backward.append(round(b, 2))
    mean_backward = round(np.mean(backward), 2)
    print('AF: ', mean_backward)
    print('\n')

    if valid and args.end_batch_test:
        batch_accs = np.vstack(all_accs_batch)
        # Make sure the output folder exists before writing into it.
        mkdir_if_missing(f'{args.result_path}/batch_accs')
        np.save(f'{args.result_path}/batch_accs/{save_model_name}.npy', batch_accs)

    return acc_mean, mean_backward, acc_matrix

def pipeline_joint(args, valid=False):
    """Run the joint-training pipeline.

    ``args.method`` is overwritten with ``'joint_replay_all'``. After
    ``data_prepare`` has ensured the per-task graph pickles exist, each task's
    graph (which retains all classes seen so far) is loaded from
    ``{args.data_path}/inter_tsk_edge``. With ``valid=True`` the learner
    observes it for ``args.epochs`` passes and is evaluated on the validation
    ids; the model is saved after each task when ``args.perform_testing`` is
    set. With ``valid=False`` no training happens: the saved model of each task
    is loaded and evaluated on the test ids.

    Args:
        args: experiment namespace (mutated: ``method``, ``d_data``, ``n_cls``,
            ``task_seq`` and ``n_tasks`` are set).
        valid: whether to train/evaluate in validation mode (see above).

    Returns:
        ``(acc_mean, mean_backward, acc_matrix)``: the mean accuracy (in %)
        after the last task, the mean of
        ``acc_matrix[n_tasks - 1][t] - acc_matrix[t][t]`` over all but the last
        task, and the ``n_tasks x n_tasks`` accuracy matrix (in %).
    """
    args.method = 'joint_replay_all'
    epochs = args.epochs if valid else 0
    torch.cuda.set_device(args.gpu)
    dataset = NodeLevelDataset(args.dataset, ratio_valid_test=args.ratio_valid_test, args=args)
    args.d_data, args.n_cls = dataset.d_data, dataset.n_cls
    args.task_seq = [list(range(i, i + args.n_cls_per_task))
                     for i in range(0, args.n_cls - 1, args.n_cls_per_task)]
    args.n_tasks = len(args.task_seq)

    task_manager = semi_task_manager()

    model = get_model(dataset, args).cuda(args.gpu)
    life_model = importlib.import_module(f'Baselines.{args.method}')
    life_model_ins = life_model.NET(model, task_manager, args) if valid else None

    acc_matrix = np.zeros([args.n_tasks, args.n_tasks])
    data_prepare(args)

    # Register every task with the cumulative number of classes seen so far.
    n_cls_so_far = 0
    for task, task_cls in enumerate(args.task_seq):
        n_cls_so_far += len(task_cls)
        task_manager.add_task(task, n_cls_so_far)

    for task, task_cls in enumerate(args.task_seq):
        name, ite = args.current_model_save_path
        config_name = name.split('/')[-1]
        subfolder_c = name.split(config_name)[-2]
        save_model_name = f'{config_name}_{ite}_{task_cls}'
        save_model_path = f'{args.result_path}/{subfolder_c}val_models/{save_model_name}.pkl'

        cls_retain = [c for clss in args.task_seq[:task + 1] for c in clss]

        # pickle is only safe here because the files are produced locally by
        # data_prepare; never point it at untrusted data.
        with open(f'{args.data_path}/inter_tsk_edge/{args.dataset}_{task_cls}.pkl', 'rb') as f:
            subgraph, ids_per_cls_all, [train_ids, valid_ids_, test_ids_] = pickle.load(f)
        test_ids = valid_ids_ if valid else test_ids_
        subgraph = subgraph.to(device='cuda:{}'.format(args.gpu))
        features, labels = subgraph.srcdata['feat'], subgraph.dstdata['label'].squeeze()

        for _ in range(epochs):
            life_model_ins.observe(args, subgraph, features, labels, task, train_ids, ids_per_cls_all, dataset)

        if not valid:
            try:
                with open(save_model_path, 'rb') as f:
                    model = pickle.load(f).cuda(args.gpu)
            except Exception:
                model.load_state_dict(torch.load(save_model_path.replace('.pkl', '.pt')))

        acc_mean = []
        label_offset1, label_offset2 = task_manager.get_label_offset(task)
        test_id_set = set(test_ids)
        for t in range(task + 1):
            # Positions of task t's classes inside `cls_retain`.
            cls_ids_new = [cls_retain.index(i) for i in args.task_seq[t]]
            if cls_ids_new != args.task_seq[t]:
                print(
                    '-------------------------------sequence is not as default--------------------------------------------------------')
            ids_per_cls_current_task = [ids_per_cls_all[i] for i in cls_ids_new]
            ids_per_cls_test = [list(set(ids).intersection(test_id_set)) for ids in ids_per_cls_current_task]
            acc = evaluate(model, subgraph, features, labels, test_ids, label_offset1, label_offset2,
                           cls_balance=args.cls_balance, ids_per_cls=ids_per_cls_test)
            acc_matrix[task][t] = round(acc * 100, 2)
            acc_mean.append(acc)
            print(f"T{t:02d} {acc * 100:.2f}|", end="", flush=True)

        acc_mean = round(np.mean(acc_mean) * 100, 2)
        print(f"acc_mean: {acc_mean}", end="", flush=True)
        print()
        if valid and args.perform_testing:
            mkdir_if_missing(f'{args.result_path}/{subfolder_c}/val_models')
            try:
                with open(save_model_path, 'wb') as f:
                    pickle.dump(model, f) # save the best model for each hyperparameter composition
            except Exception:
                # Fall back to the state dict when the model cannot be pickled.
                torch.save(model.state_dict(), save_model_path.replace('.pkl', '.pt'))

    print('AP: ', acc_mean)
    backward = []
    for t in range(args.n_tasks - 1):
        b = acc_matrix[args.n_tasks - 1][t] - acc_matrix[t][t]
        backward.append(round(b, 2))
    mean_backward = round(np.mean(backward), 2)
    print('AF: ', mean_backward)
    print('\n')
    return acc_mean, mean_backward, acc_matrix
