import os
import random
import logging
import numpy as np
import pickle
import torch
import torch.nn.functional as F

from time import time
from torch_geometric.data import Data
from torch_geometric.utils import coalesce, degree, to_undirected
import networkx as nx
from torch_geometric.utils import from_networkx, to_networkx
from typing import TypeVar

from arrow_diff.utils import read_config_file, initialize_logging, save_config_to_file
from arrow_diff.unet import UNetAdapter

from arrow_diff.graph_generation.trainer import GNNTrainer
from arrow_diff.graph_generation.network import GCN
from baselines.DiGress.src.analysis.spectre_utils import degree_stats, clustering_stats, orbit_stats_all
from arrow_diff.graph_generation.graph_generation import generate_graph

# Type variables used as annotations for "a float / long tensor on either CPU or CUDA".
# Each one is constrained to exactly the legacy CPU and CUDA tensor type listed below.
FloatTensor = TypeVar(
    'FloatTensor',
    torch.FloatTensor,
    torch.cuda.FloatTensor,
)
LongTensor = TypeVar(
    'LongTensor',
    torch.LongTensor,
    torch.cuda.LongTensor,
)


def _seed_everything(seed: int) -> None:
    """Seed Python, NumPy and PyTorch (CPU and, if available, CUDA) RNGs and make cuDNN deterministic."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True


def _load_pickle(path: str):
    """Unpickle and return the object stored at ``path``."""
    # NOTE(review): pickle executes arbitrary code on load; only use this on trusted, locally produced files.
    with open(path, 'rb') as f:
        return pickle.load(f)


def _sample_seed(seed, subgraph_idx: int, sample_idx: int, num_samples: int):
    """
    Return a distinct seed for sample ``sample_idx`` (1-based) of subgraph ``subgraph_idx``, or ``None`` if no base
    seed is configured.

    Every (subgraph, sample) pair gets its own seed, so repeated samples of one subgraph are not forced to share an
    RNG state. (The previous ``subgraph_idx * seed`` crashed for ``seed=None``, gave subgraph 0 the seed 0 regardless
    of the configured seed, and reused one seed for all samples of a subgraph.)
    """
    if seed is None:
        return None
    return seed + subgraph_idx * num_samples + (sample_idx - 1)


def main() -> None:
    """
    Sample and build graphs for the dataset configured in ``config_multi_graph.yaml``. Examples are community-small
    and citeseer-small; one dataset is handled per run. Then compute graph statistics on the generated graphs.

    For each training graph stored in the GraphRNN-format pickle, the function:
      1. loads the pre-trained UNet model ``./results/{dataset}/{dataset}_subgraph_{i}_model.pt``,
      2. trains a GCN on that graph, using one-hot node features,
      3. generates ``config['graph_generation']['num_samples']`` graphs with ``generate_graph``.

    All generated graphs are pooled and compared against the test graphs via degree, orbit and clustering
    statistics. The results are logged.

    Side effects: changes the working directory to ``./arrow_diff``. Everything is written below
    ``./results/sampling/{dataset}``: the config, the log, per-subgraph GNN weights and loss histories, and the
    generated graphs (``graphs.pt``).
    """
    # All following relative paths are resolved against the arrow_diff package directory.
    os.chdir('./arrow_diff')
    # Read config file
    config = read_config_file('./configs/config_multi_graph.yaml')

    dataset_name = config['data']['dataset']

    save_path = f'./results/sampling/{dataset_name}'
    os.makedirs(save_path, exist_ok=True)

    # Initialize logging
    initialize_logging(save_path, experiment_name='logging')

    logging.info(f'Config:\n{config}')

    # Save the config to a file
    save_config_to_file(config, save_path)

    seed = config['seed']

    if seed is not None:
        # Set seed for reproducibility
        _seed_everything(seed)

    # Initialize the device once; it is reused for every subgraph.
    device = torch.device('cuda' if config['training']['device'] == 'cuda' and torch.cuda.is_available() else 'cpu')

    training_graphs_list = _load_pickle(
        f'./GraphRNN_output_new/graphs/GraphRNN_RNN_{dataset_name}_4_64_train_0.dat')

    # Load True/test graphs
    real_graphs_list_networkx = _load_pickle(
        f'./GraphRNN_output_new/graphs/GraphRNN_RNN_{dataset_name}_4_64_test_0.dat')

    logging.info(f'Number of real/test graphs to be considered for {dataset_name}: {len(real_graphs_list_networkx)}')

    num_graphs = config['graph_generation']['num_samples']
    num_steps = config['graph_generation']['num_steps']
    walk_length = config['training']['random_walks']['walk_length']

    our_gen_graphs_from_all = []
    start_run_time = time()
    for i, subgraph in enumerate(training_graphs_list):
        logging.info(f'Sampling {num_graphs} graphs for subgraph {i}: \n')
        # Transform to torch.geometric.Data object
        dataset = from_networkx(subgraph)
        num_nodes = dataset.num_nodes

        # Set batch_size = num_walks to be sampled = num start nodes to sample from
        config['training']['batch_size'] = num_nodes
        batch_size = num_nodes

        # Initialize the model
        model = UNetAdapter(config['network']['hidden_channels'], num_nodes, config['network']['node_embedding_dim'],
                            config['training']['num_diffusion_steps'], config['network']['time_embedding_dim'],
                            num_res_blocks=config['network']['num_res_blocks'],
                            kernel_size=config['network']['kernel_size'])

        # Load the trained model. map_location='cpu' lets checkpoints saved on a GPU load on CPU-only machines;
        # the model is moved to `device` right after.
        state_dict = torch.load(f'./results/{dataset_name}/{dataset_name}_subgraph_{i}_model.pt', map_location='cpu')
        model.load_state_dict(state_dict)

        model = model.to(device)
        model.eval()

        # Initialize the GNN trainer
        gnn_trainer = GNNTrainer(config['gnn_training'])

        if seed is not None:
            # Use a seed for reproducible initialization of the GCN
            torch.manual_seed(seed)

        # Set the node features to one-hot vectors (one class per node).
        dataset.x = F.one_hot(torch.arange(0, num_nodes)).type(torch.float)

        # Initialize the GNN
        gnn = GCN(dataset.num_node_features, config['gnn']['hidden_channels'], config['gnn']['out_channels'])

        # Training of the GNN
        logging.info('\n\n\nGNN Training:')

        start_time_gnn = time()

        loss_history_train, loss_history_val = gnn_trainer.train(gnn, dataset)

        logging.info(f'\n\n\nGNN Training finished in : {time() - start_time_gnn:.3f} seconds')

        gnn = gnn.cpu()

        # Save the model and loss histories per subgraph. A single shared file name would be overwritten on every
        # iteration, keeping only the last subgraph's GNN.
        torch.save(gnn.state_dict(), f'{save_path}/{dataset_name}_subgraph_{i}_gnn.pt')
        torch.save({'train_loss': loss_history_train, 'val_loss': loss_history_val},
                   f'{save_path}/gnn_loss_history_subgraph_{i}.pt')

        #############  Graph Generation  #################

        logging.info('\n\n\nGraph Generation:')

        gnn.eval()

        gnn.to(device)

        x = dataset.x.to(device)

        # TODO: While torch.any(max(0, deg_gt - deg) > 0)?

        # Compute the node degrees of the nodes in the original graph
        # TODO: Training graph or complete graph?
        deg_gt = degree(dataset.edge_index[0], num_nodes).to(device)

        our_generated_graphs = []
        with torch.no_grad():
            for j in range(1, num_graphs + 1):
                logging.info(f'\nGenerating graph {j}:')
                start_time = time()

                # NOTE(review): assumes generate_graph accepts seed=None when no seed is configured — confirm.
                edge_index_pred = generate_graph(model, batch_size, walk_length, gnn, x, deg_gt, num_steps,
                                                 device=device, seed=_sample_seed(seed, i, j, num_graphs))

                logging.info(f'\t\tGenerating graph {j} for training graph {i} took : {time() - start_time:.3f} seconds')
                edge_index_pred = edge_index_pred.cpu()

                logging.info(f'Final graph contains {edge_index_pred.size(1)} edges')

                data_pred = Data(edge_index=edge_index_pred, num_nodes=num_nodes)

                G = to_networkx(data_pred, to_undirected=True, remove_self_loops=True)

                our_generated_graphs.append(G)
        our_gen_graphs_from_all.extend(our_generated_graphs)

    torch.save(our_gen_graphs_from_all, f'{save_path}/graphs.pt')
    logging.info(f'\n Time to generate all graphs from all models {time() - start_run_time:.3f} seconds')
    # Calculate the graph metrics against the real/test graphs
    degree_ours = degree_stats(real_graphs_list_networkx, our_gen_graphs_from_all, is_parallel=True,
                               compute_emd=True)
    orbit_ours = orbit_stats_all(real_graphs_list_networkx, our_gen_graphs_from_all, compute_emd=True)
    clustering_ours = clustering_stats(real_graphs_list_networkx, our_gen_graphs_from_all, bins=100,
                                       is_parallel=True,
                                       compute_emd=True)

    logging.info('\n Metrics calculated for all generated graphs of all sub-graphs:')
    logging.info(f'Degree: {degree_ours}')
    logging.info(f'Orbit: {orbit_ours}')
    logging.info(f'Clustering: {clustering_ours}\n')


if __name__ == '__main__':
    # Run the sampling/evaluation pipeline only when executed as a script, not when imported.
    main()
