import numpy as np
import torch
import logging
import resource
import time
from sacred import Experiment
import pickle
import seml
import os
from numba.typed import List
from data.data_preparation import check_consistence, load_data, graph_preprocess, get_partitions, get_ppr_mat
from data.customed_dataset import MYDataset
from neighboring import get_neighbors
from neighboring.ppr_power_iteration import ppr_power_iter
from batching import get_loader
from models import DeeperGCN, GAT
from train.trainer import Trainer
from copy import deepcopy


# Sacred experiment object; the post-run hook, config scope and main function
# in this file are registered on it via their decorators.
ex = Experiment()
# Let seml configure logging for this experiment.
seml.setup_logger(ex)


@ex.post_run_hook
def collect_stats(_run):
    """Post-run hook: pass the finished sacred run to ``seml.collect_exp_stats``."""
    seml.collect_exp_stats(_run)


@ex.config
def config():
    # Sacred config scope: local variables assigned here become config entries,
    # so keep this body to comments only when documenting it.
    overwrite = None  # forwarded to seml.create_mongodb_observer(overwrite=...)
    db_collection = None  # if set, a seml MongoDB observer for this collection is attached
    if db_collection is not None:
        ex.observers.append(seml.create_mongodb_observer(db_collection, overwrite=overwrite))


# Private helpers for ``run``. They must be defined above ``@ex.automain``:
# when this file is executed as a script, ``automain`` starts the experiment
# at decoration time, so anything defined below it would not exist yet.
def _default_merge_max_size(num_nodes, num_batches):
    """Return ``[train, val, test]`` merge sizes ``num_nodes // num_batches[i] + 1``."""
    return [num_nodes // num_batches[split] + 1 for split in range(3)]


def _mean_or_zero(values):
    """Return the arithmetic mean of ``values``, or ``0.`` when it is empty."""
    return sum(values) / len(values) if values else 0.


@ex.automain
def run(dataset_name,
        graphmodel,
        mode,
        neighbor_sampling,
        diffusion_param,
        ppr_params,
        small_trainingset,
        batch_size,
        micro_batch,
        num_batches,
        batch_order,
        part_topk,
        reg,
        hidden_channels,
        heads,
        store_adj,
        inference,
        LBMB_val,
        n_sampling_params=None,
        rw_sampling_params=None,
        ladies_params=None,
        epoch_min=300,
        epoch_max=800,
        patience=100,
        lr=1e-3,
        num_layers=3):
    """Run one experiment: load/preprocess the graph, build loaders, train, optionally infer.

    All arguments are injected by sacred from the experiment configuration.

    Args:
        dataset_name: Dataset identifier passed to ``load_data`` and ``ppr_power_iter``.
        graphmodel: ``'gcn'`` (``DeeperGCN``) or ``'gat'`` (``GAT``).
        mode, neighbor_sampling: Batching / neighbor-selection strategy. If either
            is ``'ppr'``, a global PPR power iteration supplies the per-split
            matrices and neighbors; otherwise ``get_ppr_mat``/``get_neighbors`` do.
        ppr_params: Dict with ``'neighbor_topk'`` (int or per-split list) and
            ``'merge_max_size'`` (per-split list or ``None``).
        num_batches: Per-split batch counts, indexed ``[0]`` train, ``[1]`` val, ``[2]`` test.
        batch_order: Dict with at least ``'ordered'`` and ``'sampled'`` keys.
        part_topk: Indexed ``[0]`` for train and ``[1]`` for val/test loaders.
        ladies_params: Dict with per-split ``'sample_size'``; used when LADIES is selected.
        inference: If true, also build the test loader and run ``trainer.inference``.
        LBMB_val: Forwarded as ``force=`` when partitioning the validation split.

    Returns:
        dict with the seed, preprocessing timings, mean per-epoch runtimes, GPU and
        process peak memory, the raw training curves, and ``<key>_record`` /
        ``<key>_stats`` entries for every other trainer database record.

    Raises:
        ValueError: If ``graphmodel`` is neither ``'gcn'`` nor ``'gat'``.
    """
    # Cast to a plain int so ``results['seed']`` is a Python int, not a NumPy scalar.
    seed = int(np.random.choice(2**16))
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)

    check_consistence(mode, neighbor_sampling, batch_order['ordered'], batch_order['sampled'])
    # Fail fast: an unknown model name used to leave ``model = None`` and only
    # crash inside the trainer after all the expensive preprocessing.
    if graphmodel not in ('gcn', 'gat'):
        raise ValueError(f"Unknown graphmodel {graphmodel!r}; expected 'gcn' or 'gat'.")

    logging.info(
        f'dataset: {dataset_name}, graphmodel: {graphmodel}, mode: {mode}, neighbor_sampling: {neighbor_sampling}, '
        f'num_batches: {num_batches}, batch_order: {batch_order}, part_topk: {part_topk}, micro_batch: {micro_batch}, '
        f'rw_sampling_params: {rw_sampling_params}, n_sampling_params: {n_sampling_params}, ladies_params: {ladies_params}')

    # ``is_available`` must be *called*: the bare function object is always truthy,
    # which previously selected 'cuda' even on CPU-only machines.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    N_sampling_params = deepcopy(n_sampling_params)

    use_ladies = 'ladies' in [mode, neighbor_sampling]
    use_ppr = 'ppr' in [mode, neighbor_sampling]

    # Dataset caching is disabled for the sampling-based modes.
    cache_data = not (mode in ['n_sampling', 'rand', 'rw_sampling'] or use_ladies)

    start_time = time.time()
    graph, (train_indices, val_indices, test_indices) = load_data(dataset_name, small_trainingset)
    logging.info("Graph loaded!\n")
    disk_loading_time = time.time() - start_time

    # Per-split (train/val/test) ``merge_max_size`` handed to ``get_loader``.
    if use_ladies:
        merge_max_size = ladies_params['sample_size']
    elif use_ppr and ppr_params['merge_max_size'] is not None:
        merge_max_size = ppr_params['merge_max_size']
    else:
        merge_max_size = _default_merge_max_size(graph.num_nodes, num_batches)

    neighbor_topk = ppr_params['neighbor_topk']
    if isinstance(neighbor_topk, int):
        neighbor_topk = [neighbor_topk] * 3

    start_time = time.time()
    graph_preprocess(graph)
    logging.info("Graph processed!\n")
    graph_preprocess_time = time.time() - start_time

    trainer = Trainer(mode,
                      neighbor_sampling,
                      num_batches,
                      micro_batch=micro_batch,
                      batch_size=batch_size,
                      epoch_max=epoch_max,
                      epoch_min=epoch_min,
                      patience=patience)

    scipy_adj = graph.adj_t.to_scipy('csr')

    comment = '_'.join([dataset_name,
                        graphmodel,
                        mode,
                        neighbor_sampling,
                        str(small_trainingset),
                        str(batch_size),
                        str(micro_batch),
                        str(merge_max_size[0]),
                        str(part_topk[0]),
                        str(store_adj)])

    # train & val
    start_time = time.time()
    train_partitions = get_partitions(mode, graph.adj_t, num_batches[0])
    logging.info("Train partitioned!\n")

    # Reuse the training partitions when the validation split uses the same batch count.
    if num_batches[1] == num_batches[0] and train_partitions is not None:
        val_partitions = train_partitions
    else:
        val_partitions = get_partitions(mode, graph.adj_t, num_batches[1], force=LBMB_val)
    logging.info("Val partitioned!\n")

    ppr_neighbors = pprmat = None
    if use_ppr:
        # One global PPR run; each split uses row slices of its result.
        # NOTE(review): only neighbor_topk[0] is passed, so val/test share the
        # training top-k on this path — confirm that is intended.
        ppr_neighbors, pprmat = ppr_power_iter(graph.adj_t, dataset_name, topk=neighbor_topk[0])

    def split_ppr_mat(indices):
        """Return the PPR matrix for ``indices`` (row slice of the global one on the PPR path)."""
        if use_ppr:
            return pprmat[indices, :]
        # Non-PPR modes previously crashed with NameError on ``pprmat``.
        # NOTE(review): call signature taken from the earlier implementation — confirm.
        return get_ppr_mat(mode, neighbor_sampling, indices, scipy_adj)

    def split_neighbors(indices, ppr_mat, topk):
        """Return the neighbor lists for ``indices``; ``topk`` is only used off the PPR path."""
        if use_ppr:
            return list(ppr_neighbors[indices])
        return get_neighbors(mode, neighbor_sampling, indices, scipy_adj, ppr_mat, topk=topk)

    train_mat = split_ppr_mat(train_indices)
    logging.info("Train ppr mat!\n")

    val_mat = split_ppr_mat(val_indices)
    logging.info("Val ppr mat!\n")

    train_neighbors = split_neighbors(train_indices, train_mat, neighbor_topk[0])
    logging.info("Train neighbors!\n")

    val_neighbors = split_neighbors(val_indices, val_mat, neighbor_topk[1])
    logging.info("Val neighbors!\n")

    train_loader = get_loader(mode,
                              neighbor_sampling,
                              graph.adj_t,
                              graph.num_nodes,
                              merge_max_size[0],
                              num_batches[0],
                              num_layers,
                              diffusion_param,
                              N_sampling_params,
                              rw_sampling_params,
                              train=True,
                              partitions=train_partitions,
                              part_topk=part_topk[0],
                              prime_indices=train_indices,
                              neighbors=train_neighbors,
                              ppr_mat=train_mat)
    logging.info("Train loader!\n")

    val_loader = get_loader(mode,
                            neighbor_sampling,
                            graph.adj_t,
                            graph.num_nodes,
                            merge_max_size[1],
                            num_batches[1],
                            num_layers,
                            diffusion_param,
                            N_sampling_params,
                            rw_sampling_params,
                            train=False,
                            partitions=val_partitions,
                            part_topk=part_topk[1],
                            prime_indices=val_indices,
                            neighbors=val_neighbors,
                            ppr_mat=val_mat)
    logging.info("Val loader!\n")

    train_prep_time = time.time() - start_time

    # inference
    start_time = time.time()
    if inference:
        test_mat = split_ppr_mat(test_indices)
        logging.info("Test ppr mat!\n")

        # Reuse existing partitions when the test batch count matches train or val.
        if num_batches[2] == num_batches[0] and train_partitions is not None:
            test_partitions = train_partitions
        elif num_batches[2] == num_batches[1] and val_partitions is not None:
            test_partitions = val_partitions
        else:
            test_partitions = get_partitions(mode, graph.adj_t, num_batches[2])
        logging.info("test partitioned!\n")

        test_neighbors = split_neighbors(test_indices, test_mat, neighbor_topk[2])

        test_loader = get_loader(mode,
                                 neighbor_sampling,
                                 graph.adj_t,
                                 graph.num_nodes,
                                 merge_max_size[2],
                                 num_batches[2],
                                 num_layers,
                                 diffusion_param,
                                 N_sampling_params,
                                 rw_sampling_params,
                                 train=False,
                                 partitions=test_partitions,
                                 # NOTE(review): the other test-split arguments use
                                 # index 2; confirm part_topk[1] is intended here.
                                 part_topk=part_topk[1],
                                 prime_indices=test_indices,
                                 neighbors=test_neighbors,
                                 ppr_mat=test_mat)
        logging.info("Test loader!\n")
    else:
        test_loader = [None, None]

    infer_prep_time = time.time() - start_time

    # common preprocess
    start_time = time.time()
    dataset = MYDataset(graph.x.cpu().detach().numpy(),
                        graph.y.cpu().detach().numpy(),
                        scipy_adj,
                        train_loader=train_loader,
                        val_loader=val_loader,
                        test_loader=test_loader,
                        batch_order=batch_order,
                        store_adj=store_adj,
                        cache=cache_data)
    caching_time = time.time() - start_time

    stamp = ''.join(str(time.time()).split('.'))
    logging.info(f'model info: {comment}/model_{stamp}.pt')

    num_classes = graph.y.max().item() + 1
    if graphmodel == 'gcn':
        model = DeeperGCN(num_node_features=graph.num_node_features,
                          num_classes=num_classes,
                          hidden_channels=hidden_channels,
                          num_layers=num_layers).to(device)
    else:  # 'gat' — the model name was validated at the top of the function
        model = GAT(in_channels=graph.num_node_features,
                    hidden_channels=hidden_channels,
                    out_channels=num_classes,
                    num_layers=num_layers,
                    heads=heads).to(device)

    if len(dataset.train_loader) > 1:
        trainer.train(dataset=dataset,
                      model=model,
                      lr=lr,
                      reg=reg,
                      train_nodes=train_indices,
                      val_nodes=val_indices,
                      comment=comment,
                      run_no=stamp)
    else:
        trainer.train_single_batch(dataset=dataset,
                                   model=model,
                                   lr=lr,
                                   reg=reg,
                                   val_per_epoch=5,
                                   comment=comment,
                                   run_no=stamp)

    logging.info(f'after train: {torch.cuda.memory_allocated()}')
    logging.info(f'after train: {torch.cuda.memory_reserved()}')

    gpu_memory = torch.cuda.max_memory_allocated()
    if inference:
        trainer.inference(dataset=dataset,
                          model=model,
                          val_nodes=val_indices,
                          test_nodes=test_indices,
                          adj=graph.adj_t,
                          x=graph.x,
                          y=graph.y,
                          comment=comment,
                          run_no=stamp)

    runtime_train_lst, runtime_self_val_lst, runtime_LBMB_val_lst = [], [], []
    for curves in trainer.database['training_curves']:
        runtime_train_lst += curves['per_train_time']
        runtime_self_val_lst += curves['per_self_val_time']
        runtime_LBMB_val_lst += curves['per_LBMB_val_time']

    results = {
        'seed': seed,
        'disk_loading_time': disk_loading_time,
        'graph_preprocess_time': graph_preprocess_time,
        'train_prep_time': train_prep_time,
        'infer_prep_time': infer_prep_time,
        'caching_time': caching_time,
        # Empty timing lists yield 0. instead of ZeroDivisionError, matching
        # the ``(0., 0.)`` fallback used for the ``*_stats`` entries below.
        'runtime_train_perEpoch': _mean_or_zero(runtime_train_lst),
        'runtime_selfval_perEpoch': _mean_or_zero(runtime_self_val_lst),
        'runtime_LBMBval_perEpoch': _mean_or_zero(runtime_LBMB_val_lst),
        'gpu_memory': gpu_memory,
        # ru_maxrss is reported in kilobytes on Linux (bytes on macOS).
        'max_memory': 1024 * resource.getrusage(resource.RUSAGE_SELF).ru_maxrss,
        'curves': trainer.database['training_curves'],
    }

    for key, item in trainer.database.items():
        if key != 'training_curves':
            results[f'{key}_record'] = item
            item = np.array(item)
            results[f'{key}_stats'] = (item.mean(), item.std(),) if len(item) else (0., 0.,)

    return results
