import __init__
import torch
import torch.optim as optim
import torch.nn.functional as F
import statistics
import os
import random
from dataset import OGBNDataset, OGBNDatasetInductive
# from model import DeeperGCN
# from model_geq import DEQGCN
from model_rev import RevGCN, JKRevGCN, MLP, JKNonlinearRevGCN
# from model_revwt import WTRevGCN
from args import ArgsInit
import time
import numpy as np
from scipy.special import expit
from ogb.nodeproppred import Evaluator
from utils.ckpt_util import save_ckpt
from utils.data_util import intersection, process_indexes
import logging
from torch.utils.tensorboard import SummaryWriter

from ipdb import set_trace as stc
from tqdm import tqdm


def set_seed(seed):
    """Seed every random number generator used in training and make cuDNN deterministic.

    Args:
        seed: integer applied to Python's ``random``, NumPy's global RNG and
            torch (CPU, plus every CUDA device when CUDA is available).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Reproducibility over speed: no kernel autotuning, deterministic kernels only.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    # torch.use_deterministic_algorithms(True)

def train(args, data, dataset, targets_t, embs_t, model, optimizer, criterion_gl, criterion_sl, device, epoch=-1):
    """Run one training epoch over the clusters of a graph partitioning, in random order.

    The per-cluster loss combines the supervised loss with optional teacher
    distillation terms, using normalised weights that sum to one:

    * ``criterion_gl(log_softmax(pred), y)`` on the ground-truth labels (always);
    * ``criterion_sl(log_softmax(pred), teacher_logits)``, weighted by ``args.r_t``;
    * the mean squared difference between student and teacher hidden
      embeddings, weighted by ``args.r_e``. This term is only used when
      ``args.r_t != 0``.

    Clusters that contain no training node are skipped. Mean-reduced losses
    over an empty batch are NaN and would corrupt the epoch mean.

    Args:
        args: config object; reads ``r_e``, ``r_t`` and ``embs_relu``.
        data: ``(sg_nodes, sg_edges, sg_edges_index, _)``, one entry per cluster.
        dataset: provides ``x``, ``y``, ``edge_attr`` and ``train_idx``.
        targets_t: teacher outputs, indexed by node like ``dataset``.
        embs_t: teacher hidden embeddings, indexed by node like ``dataset``.
        model: callable returning ``(pred, embs_norelu, embs_relu)``.
        optimizer: optimiser stepped once per non-empty cluster.
        criterion_gl: supervised loss on log-probabilities.
        criterion_sl: distillation loss on log-probabilities.
        device: device the per-cluster tensors are moved to.
        epoch: forwarded to ``model``.

    Returns:
        Mean loss over the processed clusters, or ``nan`` (with a logged
        warning) if no cluster contained training nodes.
    """
    r_e, r_t = args.r_e, args.r_t
    loss_list = []
    model.train()
    sg_nodes, sg_edges, sg_edges_index, _ = data

    # Loop-invariant: convert once instead of once per cluster.
    train_idx_list = dataset.train_idx.tolist()
    train_y = dataset.y[dataset.train_idx]
    train_targets_t = targets_t[dataset.train_idx]
    train_embs_t = embs_t[dataset.train_idx]

    idx_clusters = np.arange(len(sg_nodes))
    np.random.shuffle(idx_clusters)

    for idx in idx_clusters:
        cluster_nodes = sg_nodes[idx]

        # Global ids of this cluster's training nodes, and their row positions in the subgraph.
        mapper = {node: local_idx for local_idx, node in enumerate(cluster_nodes)}
        inter_idx = intersection(cluster_nodes, train_idx_list)
        training_idx = [mapper[t_idx] for t_idx in inter_idx]
        if not training_idx:
            # Nothing to supervise: every mean-reduced loss below would be NaN.
            continue

        x = dataset.x[cluster_nodes].float().to(device)
        sg_nodes_idx = torch.LongTensor(cluster_nodes).to(device)
        sg_edges_ = sg_edges[idx].to(device)
        sg_edges_attr = dataset.edge_attr[sg_edges_index[idx]].to(device)

        optimizer.zero_grad()

        # NOTE(review): anomaly detection is a debugging aid that slows every
        # backward pass; consider disabling it for production runs.
        with torch.autograd.set_detect_anomaly(True):
            pred, embs_norelu, embs_relu = model(x, sg_nodes_idx, sg_edges_, sg_edges_attr, epoch=epoch)
            pred = pred.log_softmax(dim=1)
            pred_train = pred[training_idx]

            # NOTE(review): train_y / train_targets_t / train_embs_t are ordered like
            # dataset.train_idx but are indexed with global node ids (inter_idx).
            # This is only correct if train_idx == arange(len(train_idx)); confirm
            # for the inductive dataset.
            target = train_y[inter_idx].to(device)
            loss_gl = criterion_gl(pred_train, target)

            if r_t != 0:
                target_t = train_targets_t[inter_idx].to(device)
                loss_t = criterion_sl(pred_train, target_t)
                if r_e != 0:
                    emb_t = train_embs_t[inter_idx].to(device)
                    num_n, dim_h = emb_t.shape[0], emb_t.shape[1]
                    student_embs = embs_relu if args.embs_relu else embs_norelu
                    # Squared L2 distance averaged over all (node, channel) entries, i.e. an MSE.
                    loss_e = ((student_embs[training_idx] - emb_t).norm(p=2)**2) / (num_n * dim_h)
                    loss = (1 / (1 + r_t + r_e)) * loss_gl + (r_t / (1 + r_t + r_e)) * loss_t + (r_e / (1 + r_t + r_e)) * loss_e
                else:
                    loss = (1 / (1 + r_t)) * loss_gl + (r_t / (1 + r_t)) * loss_t
            else:
                # NOTE(review): r_e is ignored when r_t == 0; embedding distillation
                # is only applied together with logit distillation.
                loss = loss_gl
            loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        loss_list.append(loss.item())

    if not loss_list:
        logging.warning('No cluster contained training nodes; epoch loss is undefined (NaN).')
        return float('nan')
    return statistics.mean(loss_list)

@torch.no_grad()
def multi_evaluate(valid_data_list, dataset, model, evaluator, device, epoch=-1):
    """Score ``model`` on the train/valid/test splits, averaging over several graph partitionings.

    For each partitioning, the model is run cluster by cluster and the raw
    output rows of each split's nodes are collected. The rows are re-ordered
    with ``process_indexes`` and then averaged element-wise across
    partitionings. The arg-max class is scored by ``evaluator``.

    Args:
        valid_data_list: list of ``(sg_nodes, sg_edges, sg_edges_index, _)``
            tuples, one per partitioning.
        dataset: provides ``x``, ``y``, ``edge_attr`` and the
            ``train_idx`` / ``valid_idx`` / ``test_idx`` splits.
        model: callable returning ``(pred, _, _)``; only ``pred`` is used.
        evaluator: object whose ``eval`` accepts ``{"y_true", "y_pred"}``.
        device: device the per-cluster inputs are moved to.
        epoch: forwarded to ``model``.

    Returns:
        ``{'train': ..., 'valid': ..., 'test': ...}``, each value being the
        result of ``evaluator.eval`` for that split.
    """
    model.eval()
    target = dataset.y.detach().numpy()

    splits = (
        ('train', dataset.train_idx.tolist()),
        ('valid', dataset.valid_idx.tolist()),
        ('test', dataset.test_idx.tolist()),
    )
    # Per split: one re-ordered prediction array per partitioning.
    ordered_preds = {name: [] for name, _ in splits}

    for sg_nodes, sg_edges, sg_edges_index, _ in valid_data_list:
        cluster_preds = {name: [] for name, _ in splits}
        target_idx = {name: [] for name, _ in splits}

        for idx in range(len(sg_nodes)):
            cluster_nodes = sg_nodes[idx]
            x = dataset.x[cluster_nodes].float().to(device)
            sg_nodes_idx = torch.LongTensor(cluster_nodes).to(device)
            sg_edges_attr = dataset.edge_attr[sg_edges_index[idx]].to(device)
            mapper = {node: local_idx for local_idx, node in enumerate(cluster_nodes)}

            pred, _, _ = model(x, sg_nodes_idx, sg_edges[idx].to(device),
                               sg_edges_attr, epoch=epoch)
            pred = pred.cpu()

            for name, split_idx in splits:
                inter_idx = intersection(cluster_nodes, split_idx)
                target_idx[name] += inter_idx
                cluster_preds[name].append(pred[[mapper[node] for node in inter_idx]])

        for name, _ in splits:
            preds = torch.cat(cluster_preds[name], 0).numpy()
            ordered_preds[name].append(preds[process_indexes(target_idx[name])])

    eval_result = {}
    for name, split_idx in splits:
        # Stack into one ndarray first: torch.Tensor(list_of_ndarrays) converts
        # element by element and is extremely slow for large splits.
        stacked = torch.from_numpy(np.stack(ordered_preds[name])).float()
        final_pred = stacked.mean(dim=0)
        input_dict = {"y_true": target[split_idx].reshape(-1, 1),
                      "y_pred": final_pred.argmax(1, keepdim=True)}
        eval_result[name] = evaluator.eval(input_dict)

    return eval_result


def _load_teacher_outputs(args):
    """Load the teacher model's saved logits and hidden embeddings.

    Reads ``logits.npy`` and ``embs_t.npy`` from ``args.teacher_dir`` when that
    attribute is set and non-empty. Otherwise it reads from the default
    ogbn-arxiv teacher directory. Sets ``args.teacher_channels`` to the
    embedding width.

    Returns:
        ``(logits_t, embs_t)`` as torch tensors with the dtype stored on disk.
    """
    teacher_dir = getattr(args, 'teacher_dir', None) or './arxiv/examples/ogb_eff/ogbn_arxiv/teacher'

    print(f'Loading the teacher model\'s logits...', flush=True, end=' ')
    t = time.perf_counter()
    logits_t = torch.from_numpy(np.load(os.path.join(teacher_dir, 'logits.npy')))
    print(f'Done! [{time.perf_counter() - t:.2f}s]')

    print(f'Loading the teacher model\'s hidden embeddings...', flush=True, end=' ')
    t = time.perf_counter()
    # NOTE(review): the same file is loaded whatever ``args.embs_relu`` is; if a
    # separate post-ReLU teacher embedding file exists, it should be selected here.
    embs_t = torch.from_numpy(np.load(os.path.join(teacher_dir, 'embs_t.npy')))
    args.teacher_channels = embs_t.shape[1]
    print(f'Done! [{time.perf_counter() - t:.2f}s]')
    return logits_t, embs_t


def _build_model(args, device):
    """Instantiate the backbone named by ``args.backbone`` and move it to ``device``.

    Raises:
        NotImplementedError: for ``'deepergcn'``, whose import is disabled in this file.
        Exception: for any other unrecognised backbone name.
    """
    if args.backbone == 'deepergcn':
        raise NotImplementedError("backbone 'deepergcn' is disabled (DeeperGCN import is commented out)")
    builders = {
        'rev': RevGCN,
        'jkrev': JKRevGCN,
        'mlp': MLP,
        'jkrevnonlinear': JKNonlinearRevGCN,
    }
    try:
        builder = builders[args.backbone]
    except KeyError:
        raise Exception("unknown backbone: {}".format(args.backbone)) from None
    return builder(args).to(device)


def single_run(args):
    """Train one student model with teacher distillation and track its best-validation scores.

    A checkpoint is saved every time validation accuracy improves. Per-epoch
    accuracies are written to TensorBoard under ``args.save``.

    Args:
        args: experiment config. It is mutated: ``device``, ``teacher_channels``,
            ``num_tasks``, ``num_features``, ``num_classes`` and, for
            ogbn-proteins, ``nf_path`` / ``nf_path_ind`` are set.

    Returns:
        dict with ``highest_valid`` and ``highest_train`` (best accuracies seen)
        and ``final_train`` / ``final_test`` (accuracies at the best-validation
        epoch), all as fractions in [0, 1].
    """
    set_seed(args.seed)
    args.device = 0
    logging.getLogger().setLevel(logging.INFO)

    if args.use_gpu and torch.cuda.is_available():
        device = torch.device("cuda:" + str(args.device))
    else:
        device = torch.device("cpu")
    use_cuda = device.type == 'cuda'
    logging.info('%s' % device)

    dataset = OGBNDataset(args, dataset_name=args.dataset)
    dataset_ind = OGBNDatasetInductive(args, dataset_name=args.dataset)

    logits_t, embs_t = _load_teacher_outputs(args)

    if args.dataset == 'ogbn-proteins':
        nf_path = dataset.extract_node_features(args.aggr)
        nf_path_ind = dataset_ind.extract_node_features(args.aggr)
        args.nf_path = nf_path
        args.nf_path_ind = nf_path_ind
    args.num_tasks = dataset.num_tasks
    args.num_features = dataset.num_features
    args.num_classes = dataset.num_classes

    logging.info('%s' % args)

    evaluator = Evaluator(args.dataset)
    criterion_gl = torch.nn.NLLLoss()
    criterion_sl = torch.nn.KLDivLoss(reduction='batchmean', log_target=True)

    # Evaluation partitionings are drawn once and reused every epoch.
    valid_data_list = []
    for _ in range(args.num_evals):
        parts = dataset.random_partition_graph(dataset.total_no_of_nodes,
                                               cluster_number=args.valid_cluster_number)
        valid_data = dataset.generate_sub_graphs(parts,
                                                 cluster_number=args.valid_cluster_number)
        valid_data_list.append(valid_data)

    sub_dir = 'random-train_{}-test_{}-num_evals_{}'.format(args.cluster_number,
                                                            args.valid_cluster_number,
                                                            args.num_evals)
    logging.info(sub_dir)

    model = _build_model(args, device)

    logging.info('# of params: {}'.format(sum(p.numel() for p in model.parameters())))
    logging.info('# of learnable params: {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    results = {'highest_valid': 0,
               'final_train': 0,
               'final_test': 0,
               'highest_train': 0}

    # Context manager guarantees the event file is flushed and closed.
    with SummaryWriter(log_dir=args.save) as writer:
        start_time = time.time()

        for epoch in range(1, args.epochs + 1):
            # do random partition every epoch
            train_parts = dataset_ind.random_partition_graph(
                dataset_ind.total_no_of_nodes,
                cluster_number=args.cluster_number,
            )
            data = dataset_ind.generate_sub_graphs(
                train_parts,
                cluster_number=args.cluster_number,
            )

            epoch_loss = train(
                args,
                data,
                dataset_ind,
                logits_t,
                embs_t,
                model,
                optimizer,
                criterion_gl,
                criterion_sl,
                device,
                epoch=epoch,
            )
            logging.info('Epoch {}, training loss {:.4f}'.format(epoch, epoch_loss))

            # CUDA memory statistics raise ValueError for non-CUDA devices.
            if use_cuda:
                if epoch == 1:
                    peak_memuse = torch.cuda.max_memory_allocated(device) / float(1024 ** 3)
                    logging.info('Peak memuse {:.2f} G'.format(peak_memuse))
                torch.cuda.empty_cache()

            if args.backbone != 'mlp':
                model.print_params(epoch=epoch)

            # CUDA autocast only affects CUDA ops; disable it explicitly on CPU.
            with torch.cuda.amp.autocast(enabled=use_cuda):
                result = multi_evaluate(valid_data_list, dataset, model,
                                        evaluator, device, epoch=epoch)

            train_result = result['train']['acc']
            valid_result = result['valid']['acc']
            test_result = result['test']['acc']

            if epoch % 5 == 0:
                logging.info(f'train: {train_result * 100:.2f}, valid: {valid_result * 100:.2f}, test: {test_result * 100:.2f}.')

            writer.add_scalar('stats/train_acc', train_result, epoch)
            writer.add_scalar('stats/valid_acc', valid_result, epoch)
            writer.add_scalar('stats/test_acc', test_result, epoch)

            if valid_result > results['highest_valid']:
                results['highest_valid'] = valid_result
                results['final_train'] = train_result
                results['final_test'] = test_result

                save_ckpt(model, optimizer, round(epoch_loss, 4),
                          epoch,
                          args.model_save_path, sub_dir,
                          name_post='valid_best')

            if train_result > results['highest_train']:
                results['highest_train'] = train_result

    r1 = results['final_train'] * 100
    r2 = results['highest_valid'] * 100
    r3 = results['final_test'] * 100
    r4 = results['highest_train'] * 100
    logging.info(f'final_train: {r1:.2f}, highest_valid: {r2:.2f}, final_test: {r3:.2f}, highest_train: {r4:.2f}')

    end_time = time.time()
    total_time = end_time - start_time
    logging.info('Total time: {}'.format(time.strftime('%d-%H:%M:%S', time.gmtime(total_time))))
    return results


def main():
    """Parse the experiment config, run one training per seed and print final-test mean/std.

    With ``args.runs == 1`` a single run uses seed 0. Otherwise runs use seeds
    ``0 .. args.runs - 1`` and each seed is announced before its run.
    """
    args = ArgsInit().save_exp()
    print(args)

    if args.runs == 1:
        print(f'********** There is 1 run. **********')
        seeds = [0]
        announce_seed = False
    else:
        print(f'********** There are {args.runs} runs. **********')
        seeds = range(args.runs)
        announce_seed = True

    results_list = []
    for seed in seeds:
        args.seed = seed
        if announce_seed:
            print(f'********** The seed is {args.seed}. **********')
        results_list.append(single_run(args))

    for result in results_list:
        print(result)

    final_test_list = np.array([result['final_test'] for result in results_list])
    print(f'Mean: {final_test_list.mean() * 100: .2f}, Std: {final_test_list.std() * 100:.2f}')

    


# Script entry point: start training only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
