from pathlib import Path
import logging
import random
import json
import numpy as np
import torch
from torch.utils.data import DataLoader
from sacred import Experiment

from dataloader.graphdist_dataset import GraphDistDataset
from dataloader.graphbatch_collator import GraphBatchCollator
from utils.metrics import Metrics
from utils.optimizer import add_weight_decay
from utils.training import train_model
from dataloader.io import load_from_npz
from utils import aggregation
from gnn import geometric_gnn
from gnn import label_scaling
import gtn

# Sacred experiment object for this training script.
ex = Experiment()

# Root-logger configuration, applied once when the module is imported.
_LOG_FORMAT = '%(asctime)s (%(levelname)s): %(message)s'
_LOG_DATEFMT = '%Y-%m-%d %H:%M:%S'
logging.basicConfig(format=_LOG_FORMAT, datefmt=_LOG_DATEFMT, level=logging.INFO)


def _one_hot_width(gcolls, attr_name):
    """Return ``int(max(value) + 1)`` of ``graph.<attr_name>`` over every graph.

    ``gcolls`` maps split names ('train', 'val', 'test') to graph collections;
    all graphs of all splits are inspected so every split shares one feature
    width. Presumably the attribute values are integer labels starting at 0
    (TODO confirm), making the result the one-hot width.
    """
    return int(max(np.max(getattr(graph, attr_name))
                   for gcoll in gcolls.values()
                   for graph in gcoll) + 1)


def _build_dataloaders(datasets, collator, batch_size, num_workers):
    """Wrap each dataset split in a shuffling ``DataLoader``.

    Keys other than 'train'/'val'/'test' are passed through unchanged, and a
    split whose value is a list gets one loader per subset.
    """
    def make_loader(dataset):
        # NOTE: val/test are shuffled too. Shuffling draws from torch's global
        # RNG, so this is kept as-is to keep seeded runs reproducible.
        return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                          collate_fn=collator, num_workers=num_workers)

    dataloaders = {}
    for key, dataset in datasets.items():
        if key not in ('train', 'val', 'test'):
            dataloaders[key] = dataset
        elif isinstance(dataset, list):
            dataloaders[key] = [make_loader(subset) for subset in dataset]
        else:
            dataloaders[key] = make_loader(dataset)
    return dataloaders


def _build_metric_trackers(metrics_list, display_step, test,
                           metric_to_stop_on, minimise_stop_on, patience):
    """Create the per-iteration and per-epoch ``Metrics`` trackers.

    Per-iteration trackers exist only when ``display_step`` is finite. The
    'val' and (if ``test``) 'test' epoch trackers carry the early-stopping
    settings; the 'train' epoch tracker does not.
    """
    trackers = {'iter': {}, 'epoch': {}}
    splits = ['train', 'val'] + (['test'] if test else [])
    for split in splits:
        if display_step != np.inf:
            trackers['iter'][split] = Metrics(metrics_list)
        if split == 'train':
            trackers['epoch'][split] = Metrics(metrics_list)
        else:
            trackers['epoch'][split] = Metrics(
                    metrics_list, metric_to_stop_on=metric_to_stop_on,
                    minimise_stop_on=minimise_stop_on, patience=patience)
    return trackers


def _resolve_act_fn(name):
    """Map an activation-function name to a callable.

    Raises:
        ValueError: if ``name`` is not one of 'linear', 'relu', 'sigmoid',
            'leaky_relu'.
    """
    activations = {
        'linear': lambda x: x,
        'relu': torch.nn.functional.relu,
        # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
        'sigmoid': torch.sigmoid,
        'leaky_relu': torch.nn.functional.leaky_relu,
    }
    try:
        return activations[name]
    except (KeyError, TypeError):
        raise ValueError(f"Invalid act_fn '{name}'.") from None


def _build_aggregation(name, emb_size, nlayers):
    """Instantiate the layer-aggregation module called ``name``.

    Only the selected module is constructed, so no extra parameters (and no
    extra RNG draws) are created for unused options.

    Raises:
        ValueError: if ``name`` is not a known aggregation.
    """
    if name == 'All':
        return aggregation.All()
    if name == 'Last':
        return aggregation.Last()
    if name == 'LayerAttention':
        return aggregation.LayerAttention(nlayers=nlayers, norm_weights=True, sum_embeddings=True)
    if name == 'FullAttention':
        return aggregation.FullAttention(emb_size=emb_size, nlayers=nlayers, norm_weights=True)
    if name == 'PPRWeighted':
        return aggregation.PPRWeighted(nlayers=nlayers, alpha=0.5)
    if name == 'Weighted':
        return aggregation.Weighted(nlayers=nlayers, weight_fn=lambda x: 3 - x/nlayers)
    if name == 'FullProjection':
        return aggregation.FullProjection(emb_size=emb_size, nlayers=nlayers, output_size=emb_size)
    if name == 'MLP':
        return aggregation.MLP(emb_size=emb_size, nlayers=nlayers, output_size=emb_size)
    raise ValueError(f"Invalid aggregate '{name}'.")


def _build_optimizer(name, parameters, learning_rate, lr_stepsize):
    """Return ``(optimizer, lr_scheduler)`` for the optimizer called ``name``.

    'adam' and 'amsgrad' use a StepLR that multiplies the learning rate by 0.1
    every ``lr_stepsize`` scheduler steps. 'cyclic_sgd' uses SGD with momentum
    and a CyclicLR between ``learning_rate / 50`` and ``learning_rate``.

    Raises:
        ValueError: if ``name`` is not a known optimizer.
    """
    if name == 'adam':
        optim = torch.optim.Adam(parameters, lr=learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=lr_stepsize, gamma=0.1)
    elif name == 'cyclic_sgd':
        optim = torch.optim.SGD(parameters, lr=learning_rate, momentum=0.9)
        scheduler = torch.optim.lr_scheduler.CyclicLR(optim, step_size_up=30, base_lr=learning_rate / 50., max_lr=learning_rate)
    elif name == 'amsgrad':
        optim = torch.optim.Adam(parameters, lr=learning_rate, amsgrad=True)
        scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=lr_stepsize, gamma=0.1)
    else:
        raise ValueError(f"Invalid optimizer '{name}'.")
    return optim, scheduler


@ex.automain
def run(model_type='gtn', gnn_type='gcn_geometric', device='cpu', graph_distance='ged',
        distance='sinkhorn', data_path='data', dataname='pref_att',
        learning_rate=0.01, l2_reg=5e-4, max_epochs=300, display_step=np.inf,
        emb_size=32, nlayers=3, emb_dist_p=2, dropout=0.0, act_fn='leaky_relu', aggregate='MLP',
        agg_degree_mean=True, sinkhorn_reg=0.2, sinkhorn_niter=500, sinkhorn_reg_stepval=None,
        patience=np.inf, prefetch=1, batch_size=100, virtual_batch_size=None, sparse_batching=False,
        optimizer='adam', run=0, bp_dist_matrix=True,
        nystrom=None, sparse=None,
        gnn_bilin='full', lr_stepsize=100, dist_loss='mse',
        matching_size=1, scale_embeddings=True, test=False,
        output_sim=False):
    """Train a GTN graph-distance model and return the training result.

    Data is read from ``<data_path>/<dataname>_<graph_distance>_<split>.npz``
    for the splits 'train', 'val' and 'test'.

    Args:
        model_type: only 'gtn' is supported.
        gnn_type: 'gcn_geometric' (``geometric_gnn.Net``) or 'label'
            (``label_scaling.LabelDist``).
        device: torch device string.
        graph_distance, dataname: select the data files; lower-cased.
        distance: embedding-distance type forwarded to ``gtn.GTN``; lower-cased.
        data_path: directory holding the ``.npz`` files.
        learning_rate, l2_reg, optimizer, lr_stepsize: optimisation settings;
            ``optimizer`` is 'adam', 'cyclic_sgd' or 'amsgrad'.
        max_epochs, patience, display_step, virtual_batch_size, test: forwarded
            to ``train_model`` / the ``Metrics`` trackers. Per-iteration
            metrics exist only when ``display_step`` is finite, and test-split
            metrics only when ``test`` is true.
        emb_size, nlayers, dropout, act_fn, aggregate, agg_degree_mean,
            gnn_bilin: GNN architecture. ``act_fn`` is one of 'linear',
            'relu', 'sigmoid', 'leaky_relu'.
        emb_dist_p, sinkhorn_reg, sinkhorn_niter, sinkhorn_reg_stepval,
            bp_dist_matrix, nystrom, sparse, matching_size,
            scale_embeddings: forwarded to ``gtn.GTN``.
        prefetch: used as ``DataLoader`` ``num_workers``.
        batch_size, sparse_batching: batching settings.
        dist_loss: loss identifier passed to ``train_model``.
        run, output_sim: accepted as config entries but not used here.

    Returns:
        The dict returned by ``train_model`` with its ``'best_model_wts'``
        entry removed.

    Raises:
        ValueError: for an unknown ``act_fn``, ``aggregate``, ``gnn_type``,
            ``model_type`` or ``optimizer``.
    """
    # Snapshot the arguments before any other local is bound, so the logged
    # config contains exactly the run parameters.
    my_dict = locals().copy()
    run_config = json.dumps(my_dict, indent=4, sort_keys=True)

    # The root logger is configured at module import, so calling
    # logging.basicConfig() again here would be a silent no-op.
    logger = logging.getLogger(__name__)
    logger.info('Run config:' + run_config)

    # Use sacred's per-run seed when running under sacred, otherwise 42.
    seed = ex.current_run.config['seed'] if ex.current_run else 42
    logger.info("Seed: " + str(seed))
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Debugging/determinism toggles, disabled by default:
    # torch.backends.cudnn.benchmark = False
    # torch.backends.cudnn.deterministic = True
    # torch.autograd.set_detect_anomaly(True)  # Use to debug backpropagation
    # torch.set_printoptions(threshold=1024)

    device = torch.device(device)
    graph_distance = graph_distance.lower()
    distance = distance.lower()
    dataname = dataname.lower()

    # Fixed options of this experiment (not exposed as config).
    triplets = False
    gen_match_label = False
    return_matching = False

    # Data
    data_path = Path(data_path)
    gcolls = {split: load_from_npz(data_path / f"{dataname}_{graph_distance}_{split}.npz")
              for split in ['train', 'val', 'test']}
    node_feat_size = _one_hot_width(gcolls, 'attr_matrix')
    # Edge features are treated as absent when the first training graph has none.
    if gcolls['train'][0].edge_attr_matrix is None:
        edge_feat_size = 0
    else:
        edge_feat_size = _one_hot_width(gcolls, 'edge_attr_matrix')
    datasets = {key: GraphDistDataset(gcoll, node_feat_size, edge_feat_size, edge_onehot=True)
                for key, gcoll in gcolls.items()}

    collator = GraphBatchCollator(sparse_batching=sparse_batching, triplets=triplets, gen_match_label=gen_match_label)
    dataloaders = _build_dataloaders(datasets, collator, batch_size, num_workers=prefetch)

    # Metrics
    metrics_list = ['rmse', 'cvrmse', 'label_std']
    metrics_trackers = _build_metric_trackers(
            metrics_list, display_step, test,
            metric_to_stop_on='rmse', minimise_stop_on=True, patience=patience)

    # Model. The construction order (aggregation -> gnn -> GTN) is kept
    # because module construction can consume the seeded RNG.
    act_fn = _resolve_act_fn(act_fn)
    aggregate = _build_aggregation(aggregate, emb_size, nlayers)

    if model_type == 'gtn':
        if gnn_type == 'gcn_geometric':
            gnn = geometric_gnn.Net(
                    node_feat_size=node_feat_size, edge_feat_size=edge_feat_size, emb_size=emb_size, nlayers=nlayers,
                    aggregate=aggregate, device=device, dropout=dropout, act_fn=act_fn, bilin_type=gnn_bilin,
                    agg_degree_mean=agg_degree_mean)
        elif gnn_type == 'label':
            gnn = label_scaling.LabelDist(device)
        else:
            raise ValueError(f"Invalid gnn '{gnn_type}'.")

        # Mean training distance divided by mean training graph size.
        # NOTE(review): `.A` is deprecated on scipy sparse matrices in favour
        # of `.toarray()` -- confirm the type of `dists` before relying on it.
        emb_dist_scale = np.mean(gcolls['train'].dists.A) / np.mean([graph.num_nodes() for graph in gcolls['train']])

        model = gtn.GTN(
                gnn=gnn, emb_dist_scale=emb_dist_scale, sparse_batching=sparse_batching,
                distance=distance, device=device, p_norm=emb_dist_p,
                sinkhorn_reg=sinkhorn_reg, sinkhorn_niter=sinkhorn_niter, sinkhorn_reg_stepval=sinkhorn_reg_stepval,
                return_matching=return_matching, bp_dist_matrix=bp_dist_matrix,
                nystrom=nystrom, sparse=sparse,
                matching_size=matching_size, scale_embeddings=scale_embeddings)
    else:
        raise ValueError(f"Invalid model '{model_type}'.")

    # Optimisation
    parameters = add_weight_decay(model, weight_decay=l2_reg)
    _optimizer, lr_scheduler = _build_optimizer(optimizer, parameters, learning_rate, lr_stepsize)

    logger.info("Start training")
    result = train_model(
            model, dist_loss, device, dataloaders, _optimizer, lr_scheduler, metrics_trackers, logger=logger, ex=ex,
            learning_rate=learning_rate, num_epochs=max_epochs, print_iter=display_step, config_str=run_config,
            virtual_batch_size=virtual_batch_size, test=test)

    # Drop the model weights from the returned result; pop() tolerates
    # train_model not providing the key.
    result.pop('best_model_wts', None)
    return result
