import numpy as np
import torch
import logging
import resource
import time
from sacred import Experiment
import pickle
import seml
import os
from numba.typed import List
from data.data_preparation import check_consistence, load_data, graph_preprocess, get_partitions, get_ppr_mat
from data.customed_dataset import MYDataset
from neighboring import get_neighbors
from batching import get_loader
from models import DeeperGCN, GAT
from train.trainer import Trainer
from copy import deepcopy


# Sacred experiment object; the decorators further down register the
# post-run hook, the config scope and the main function on it.
ex = Experiment()
# Pass the experiment to seml's logger setup helper.
seml.setup_logger(ex)


@ex.post_run_hook
def collect_stats(_run):
    """Post-run hook: forward the finished run object to ``seml.collect_exp_stats``."""
    seml.collect_exp_stats(_run)


@ex.config
def config():
    # NOTE(review): this is a sacred config scope, so the local names below are
    # presumably picked up as config entries -- avoid adding stray locals here.
    overwrite = None
    db_collection = None
    # Attach a MongoDB observer only when a collection name has been supplied.
    if db_collection is not None:
        ex.observers.append(seml.create_mongodb_observer(db_collection, overwrite=overwrite))


@ex.automain
def run(dataset_name,
        graphmodel,
        mode,
        neighbor_sampling,
        diffusion_param,
        ppr_params,
        num_batches,
        hidden_channels,
        heads,
        part_topk,

        micro_batch=1,
        batch_size=1,
        small_trainingset=1,
        batch_order={'ordered': False, 'sampled': False},
        store_adj=True,
        inference=True,

        n_sampling_params=None,
        rw_sampling_params=None,
        ladies_params=None,

        epoch_min=300,
        epoch_max=800,
        patience=100,
        lr=1e-3,
        num_layers=3):
    """Build validation/test loaders and run inference with pretrained models.

    The function loads and preprocesses the graph, prepares the validation
    loader (and the test loader when ``inference`` is true), wraps them in a
    ``MYDataset`` and calls ``Trainer.inference`` once for every file found in
    ``../pretrained/{graphmodel}_{dataset_name}/``.

    Args:
        dataset_name: Dataset identifier passed to ``load_data``. The value
            ``'products'`` makes the test PPR matrix come from a pickle file.
        graphmodel: ``'gcn'`` (``DeeperGCN``) or ``'gat'`` (``GAT``).
        mode: Batching mode, forwarded to the partition/neighbor/loader helpers.
        neighbor_sampling: Neighbor sampling scheme, forwarded like ``mode``.
        diffusion_param: Forwarded to ``get_loader``.
        ppr_params: Mapping that must contain ``'neighbor_topk'`` (int or
            indexable) and, when ``'ppr'`` is the mode or neighbor sampling,
            ``'merge_max_size'`` (may be ``None``).
        num_batches: Indexable; ``[1]`` is used for validation and ``[2]`` for
            test. Indices 0..2 are read for the default ``merge_max_size``.
        hidden_channels: Hidden width of the model.
        heads: Number of attention heads (``GAT`` only).
        part_topk: Indexable; ``[1]`` is passed to both the validation and the
            test loader.
        micro_batch, batch_size, epoch_min, epoch_max, patience: Forwarded to
            ``Trainer``.
        small_trainingset: Forwarded to ``load_data``.
        batch_order: Mapping with ``'ordered'`` and ``'sampled'`` flags. The
            default dict is only read, never mutated.
        store_adj: Forwarded to ``MYDataset``.
        inference: When false, no test loader is built.
        n_sampling_params, rw_sampling_params: Forwarded to ``get_loader``
            (``n_sampling_params`` as a deep copy).
        ladies_params: Mapping with ``'sample_size'``; only read when
            ``'ladies'`` is the mode or neighbor sampling.
        lr: Not used by this function.
        num_layers: Number of model layers, also forwarded to ``get_loader``.

    Returns:
        dict: The seed, timing and memory measurements, plus a
        ``<key>_record`` / ``<key>_stats`` (mean, std) pair for each entry of
        ``trainer.database`` except ``'training_curves'``.

    Raises:
        ValueError: If ``graphmodel`` is neither ``'gcn'`` nor ``'gat'``.
    """
    # Fail fast: previously an unknown model name only surfaced as a NameError
    # on ``model`` after all the expensive preprocessing had already run.
    if graphmodel not in ('gcn', 'gat'):
        raise ValueError(f"Unknown graphmodel {graphmodel!r}; expected 'gcn' or 'gat'.")

    seed = np.random.choice(2**16)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)

    check_consistence(mode, neighbor_sampling, batch_order['ordered'], batch_order['sampled'])

    logging.info(
        f'dataset: {dataset_name}, graphmodel: {graphmodel}, mode: {mode}, neighbor_sampling: {neighbor_sampling}, '
        f'num_batches: {num_batches}, batch_order: {batch_order}, part_topk: {part_topk}, micro_batch: {micro_batch}, '
        f'rw_sampling_params: {rw_sampling_params}, n_sampling_params: {n_sampling_params}, ladies_params: {ladies_params}')

    # ``is_available`` must be called: the bare function object is always
    # truthy, which used to select 'cuda' even on CPU-only machines.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    N_sampling_params = deepcopy(n_sampling_params)

    # common preprocess
    uses_ladies = 'ladies' in [mode, neighbor_sampling]
    cache_data = not (mode in ['n_sampling', 'rand', 'rw_sampling'] or uses_ladies)

    start_time = time.time()
    graph, (_, val_indices, test_indices) = load_data(dataset_name, small_trainingset)
    logging.info("Graph loaded!\n")
    disk_loading_time = time.time() - start_time

    if uses_ladies:
        merge_max_size = ladies_params['sample_size']
    else:
        merge_max_size = None
        if 'ppr' in [mode, neighbor_sampling]:
            merge_max_size = ppr_params['merge_max_size']
        if merge_max_size is None:
            # Default: one entry per split, sized as an even share of the nodes.
            merge_max_size = [graph.num_nodes // num_batches[i] + 1 for i in range(3)]

    if isinstance(ppr_params['neighbor_topk'], int):
        neighbor_topk = [ppr_params['neighbor_topk']] * 3
    else:
        neighbor_topk = ppr_params['neighbor_topk']

    start_time = time.time()
    graph_preprocess(graph)
    logging.info("Graph processed!\n")
    graph_preprocess_time = time.time() - start_time

    trainer = Trainer(mode,
                      neighbor_sampling,
                      num_batches,
                      micro_batch=micro_batch,
                      batch_size=batch_size,
                      epoch_max=epoch_max,
                      epoch_min=epoch_min,
                      patience=patience)

    scipy_adj = graph.adj_t.to_scipy('csr')

    # train & val
    start_time = time.time()
    val_partitions = get_partitions(mode, graph.adj_t, num_batches[1])
    logging.info("Val partitioned!\n")

    # Computed for every dataset: it used to be assigned only for 'products',
    # leaving ``val_mat`` unbound (NameError) for all other datasets.
    val_mat = get_ppr_mat(mode, neighbor_sampling, val_indices, scipy_adj, topk=neighbor_topk[1])
    logging.info("Val ppr mat!\n")

    val_neighbors = get_neighbors(mode, neighbor_sampling, val_indices, scipy_adj, val_mat, topk=neighbor_topk[1])
    logging.info("Val neighbors!\n")

    val_loader = get_loader(mode,
                            neighbor_sampling,
                            graph.adj_t,
                            graph.num_nodes,
                            merge_max_size[1],
                            num_batches[1],
                            num_layers,
                            diffusion_param,
                            N_sampling_params,
                            rw_sampling_params,
                            train=False,
                            partitions=val_partitions,
                            part_topk=part_topk[1],
                            prime_indices=val_indices,
                            neighbors=val_neighbors,
                            ppr_mat=val_mat)
    logging.info("Val loader!\n")

    val_prep_time = time.time() - start_time

    # inference
    start_time = time.time()
    if inference:
        if dataset_name == 'products':
            # NOTE(review): placeholder path. pickle.load executes arbitrary
            # code, so this file must come from a trusted source.
            with open('/path/to/models', 'rb') as handle:
                test_mat = pickle.load(handle)
        else:
            test_mat = get_ppr_mat(mode, neighbor_sampling, test_indices, scipy_adj, topk=neighbor_topk[2])
        logging.info("Test ppr mat!\n")

        # Reuse the validation partitions when the batch counts match.
        if num_batches[2] == num_batches[1] and val_partitions is not None:
            test_partitions = val_partitions
        else:
            test_partitions = get_partitions(mode, graph.adj_t, num_batches[2])
        logging.info("test partitioned!\n")

        test_neighbors = get_neighbors(mode, neighbor_sampling, test_indices, scipy_adj, test_mat, topk=neighbor_topk[2])

        # NOTE(review): part_topk[1] (not [2]) is passed here, same as for the
        # validation loader -- confirm this is intentional.
        test_loader = get_loader(mode,
                                 neighbor_sampling,
                                 graph.adj_t,
                                 graph.num_nodes,
                                 merge_max_size[2],
                                 num_batches[2],
                                 num_layers,
                                 diffusion_param,
                                 N_sampling_params,
                                 rw_sampling_params,
                                 train=False,
                                 partitions=test_partitions,
                                 part_topk=part_topk[1],
                                 prime_indices=test_indices,
                                 neighbors=test_neighbors,
                                 ppr_mat=test_mat)
        logging.info("Test loader!\n")
    else:
        test_loader = [None, None]

    infer_prep_time = time.time() - start_time

    # common preprocess
    start_time = time.time()
    dataset = MYDataset(graph.x.cpu().detach().numpy(),
                        graph.y.cpu().detach().numpy(),
                        scipy_adj,
                        train_loader=None,
                        val_loader=val_loader,
                        test_loader=test_loader,
                        batch_order=batch_order,
                        store_adj=store_adj,
                        cache=cache_data)
    caching_time = time.time() - start_time

    num_classes = graph.y.max().item() + 1
    if graphmodel == 'gcn':
        model = DeeperGCN(num_node_features=graph.num_node_features,
                          num_classes=num_classes,
                          hidden_channels=hidden_channels,
                          num_layers=num_layers).to(device)
    else:  # 'gat' -- the model name was validated at the top of the function
        model = GAT(in_channels=graph.num_node_features,
                    hidden_channels=hidden_channels,
                    out_channels=num_classes,
                    num_layers=num_layers,
                    heads=heads).to(device)

    # One inference pass per checkpoint file; the run number is the second
    # '_'-separated token of the file name without its extension.
    for _file in os.listdir(f'../pretrained/{graphmodel}_{dataset_name}/'):
        no = _file.split('.')[0].split('_')[1]
        trainer.inference(dataset=dataset,
                          model=model,
                          val_nodes=val_indices,
                          test_nodes=test_indices,
                          adj=graph.adj_t,
                          x=graph.x,
                          y=graph.y,
                          file_dir='../pretrained',
                          comment=f'{graphmodel}_{dataset_name}',
                          run_no=no,
                          full_infer=False,
                          record_numbatch=True)

    results = {
        'seed': seed,
        'disk_loading_time': disk_loading_time,
        'graph_preprocess_time': graph_preprocess_time,
        'val_prep_time': val_prep_time,
        'infer_prep_time': infer_prep_time,
        'caching_time': caching_time,
        'gpu_memory': torch.cuda.max_memory_allocated(),
        'max_memory': 1024 * resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    }

    # Store each raw record together with its (mean, std) summary.
    for key, item in trainer.database.items():
        if key != 'training_curves':
            results[f'{key}_record'] = item
            item = np.array(item)
            results[f'{key}_stats'] = (item.mean(), item.std(),) if len(item) else (0., 0.,)

    return results
