import os
import random
import logging
import pickle
import numpy as np
import torch

from time import time
from torch_geometric.utils import from_networkx
from torch.utils.tensorboard import SummaryWriter
from arrow_diff.utils import read_config_file, initialize_logging

from arrow_diff.trainer import Trainer
from arrow_diff.unet import UNetAdapter


def _set_seed(seed: int) -> None:
    """
    Seed the Python, NumPy and PyTorch RNGs (plus CUDA and cuDNN determinism when a GPU is available).

    :param seed: Seed value applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True


def main() -> None:
    """
    Train one UNetAdapter model per training sub-graph and save the results.

    Steps:
    - Change the working directory to ``./arrow_diff``.
    - Read ``./configs/config_multi_graph.yaml``.
    - Load the pickled GraphRNN training sub-graphs.
    - For each sub-graph:
        - re-seed the RNGs (when ``config['seed']`` is not None);
        - convert the sub-graph with ``from_networkx``;
        - train a fresh model with ``Trainer``;
        - write the model ``state_dict``, the loss history and the mean-likelihood history
          to ``./ar_results/<dataset>/``.

    TensorBoard events go to ``./event_writer_ours_ar/<dataset>_subgraph_<i>/``.
    """
    os.chdir('./arrow_diff')

    # Read config file
    config = read_config_file('./configs/config_multi_graph.yaml')

    dataset_name = config['data']['dataset']

    save_path = f'./ar_results/{dataset_name}'
    # exist_ok avoids the race between a separate isdir() check and makedirs()
    os.makedirs(save_path, exist_ok=True)

    # Initialize logging
    initialize_logging(save_path, experiment_name='logging')

    logging.info(f'Config:\n{config}')

    # Load the training sub-graphs
    training_graphs_path = f'./GraphRNN_output_new/graphs/GraphRNN_RNN_{dataset_name}_4_64_train_0.dat'

    # NOTE(review): pickle.load can execute arbitrary code while unpickling.
    # Only load files produced by your own GraphRNN run, never untrusted downloads.
    # Presumably this is a sequence of networkx graphs, since each item goes to from_networkx.
    with open(training_graphs_path, "rb") as f:
        training_graphs_list = pickle.load(f)

    seed = config['seed']

    start_run = time()
    for i, subgraph in enumerate(training_graphs_list):
        # Re-seed before every sub-graph so each training run is reproducible on its own
        if seed is not None:
            _set_seed(seed)

        # Transform to torch_geometric.data.Data object
        dataset = from_networkx(subgraph)

        # Set batch_size = num_walks to be sampled = num start nodes to sample from
        config['training']['batch_size'] = dataset.num_nodes

        logging.info(f'Dataset:\n{dataset}')

        writer = SummaryWriter(log_dir=f'./event_writer_ours_ar/{dataset_name}_subgraph_{i}/', flush_secs=10)
        # try/finally guarantees pending TensorBoard events are flushed and the writer closed
        # even when model construction or training raises.
        try:
            trainer = Trainer(config['training'], writer=writer)

            model = UNetAdapter(config['network']['hidden_channels'], dataset.num_nodes,
                                config['network']['node_embedding_dim'], config['training']['num_diffusion_steps'],
                                config['network']['time_embedding_dim'],
                                num_res_blocks=config['network']['num_res_blocks'],
                                kernel_size=config['network']['kernel_size'])

            # The third return value (first importance-sampling epoch) is not used here
            loss_history, mean_likelihood_history_time_steps, _ = trainer.train(model, dataset)
        finally:
            writer.close()

        # Move weights to CPU so the saved state_dict can be loaded without a GPU
        model.cpu()

        file_prefix = f'{save_path}/{dataset_name}_subgraph_{i}'

        # Save the model
        torch.save(model.state_dict(), f'{file_prefix}_model.pt')

        # Save the loss history and the mean log-likelihood history for all time steps
        torch.save(loss_history, f'{file_prefix}_loss_history.pt')
        torch.save(mean_likelihood_history_time_steps, f'{file_prefix}_mean_likelihood_history_time_steps.pt')

    logging.info(f'\n Time to train all graphs {time() - start_run:.3f} seconds')


if __name__ == "__main__":
    # Only start training when executed as a script, not when imported
    main()
