import os
import yaml
import torch
import networkx as nx
import numpy as np
from dotmap import DotMap
from data_utils import get_pyg_data
from trainer import get_encoder_trainer, get_decoder_trainer
from eval_utils import evaluate_model

from encoder.encoder_factory import EncoderFactory
from decoder.decoder_factory import DecoderFactory

from sampler import Sampler
import matplotlib.pyplot as plt
import logging
from datetime import datetime
import pickle
import random


logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s:%(message)s'
)

def set_seed(seed: int = 42):
    """Seed every random number generator the script draws from.

    Seeds, in order: Python's ``random`` module, NumPy's global RNG,
    PyTorch's default generator, and PyTorch's current CUDA device generator.

    Args:
        seed: Value passed to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

def read_config(path="./config.yaml"):
    """Load the YAML experiment configuration into a ``DotMap``.

    Args:
        path: Location of the YAML file. Defaults to ``./config.yaml``,
            resolved relative to the current working directory (the
            original hard-coded location).

    Returns:
        A ``DotMap`` built from the parsed YAML document; empty when the
        file contains no document.
    """
    with open(path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # safe_load returns None for an empty document; hand DotMap an empty
    # dict instead so callers always get a DotMap back.
    return DotMap(config or {})


def _adjacency_to_graph(adj_pred, threshold, keep_largest_cc):
    """Convert a predicted adjacency tensor into a NetworkX graph.

    Args:
        adj_pred: Tensor of edge scores returned by ``decoder.construct``;
            assumed square (num_nodes x num_nodes) — TODO confirm.
        threshold: Scores strictly greater than this become edges.
        keep_largest_cc: When truthy, keep only the largest connected
            component.

    Returns:
        An ``nx.Graph`` with self-loops and isolated nodes removed. It may be
        empty when no score exceeds ``threshold``.
    """
    adj_numpy = (adj_pred > threshold).int().cpu().numpy()
    np.fill_diagonal(adj_numpy, 0)  # drop self-loops

    graph = nx.from_numpy_array(adj_numpy)
    graph.remove_nodes_from(list(nx.isolates(graph)))

    # max() over zero components raises ValueError, so skip pruning when the
    # prediction produced no edges. copy() detaches the result from the
    # full-graph view that subgraph() returns.
    if keep_largest_cc and graph.number_of_nodes() > 0:
        largest_cc = max(nx.connected_components(graph), key=len)
        graph = graph.subgraph(largest_cc).copy()
    return graph


def _save_graph_figure(graph, filename):
    """Draw ``graph`` and save it to ``filename``; the figure is always closed."""
    fig = plt.figure()
    try:
        nx.draw(graph, node_size=20)
        fig.savefig(filename)
    finally:
        plt.close(fig)


if __name__ == "__main__":
    set_seed()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    current_time = datetime.now().strftime('%H-%M')

    config = read_config()
    graphs, data_list = get_pyg_data(config)

    # Both depend only on config fields the loop below never modifies.
    folder_path = f"./generated_figures/{config.graph}_{config.encoder_model}_{config.sampler_model}_{config.decoder_model}"
    os.makedirs(folder_path, exist_ok=True)
    if config.decoder_model == "dot_product":
        threshold = config.decoder.dot_product.threshold
    else:
        threshold = config.decoder.mlp.train.threshold

    all_generated = []
    spectral_list, orbit_list, motif_list = [], [], []

    for graph_idx, graph in enumerate(graphs):
        data = data_list[graph_idx]

        # Store per-graph statistics on the shared config before building models.
        config.construction.avg_degree = (2 * graph.number_of_edges()) / graph.number_of_nodes()
        config.num_nodes = graph.number_of_nodes()

        # Train the encoder, then take its final node embeddings.
        encoder = EncoderFactory().get_encoder(config, data)
        encoder_trainer = get_encoder_trainer(encoder, config, graph)
        encoder_trainer.train()
        embeddings = encoder_trainer.get_final_embeddings()

        # Only the "mlp" decoder is constructed with the embeddings.
        decoder_factory = DecoderFactory()
        if config.decoder_model == "mlp":
            decoder = decoder_factory.get_decoder(config, embeddings)
        else:
            decoder = decoder_factory.get_decoder(config)

        decoder_trainer = get_decoder_trainer(decoder, config, embeddings)
        decoder = decoder_trainer.train(data)

        sampler = Sampler(embeddings, config)

        print("Generation started.")

        for sample_idx in range(config.graphs_to_generate):
            # Sample new node embeddings based on the learned ones.
            sampled = torch.from_numpy(sampler.sample()).to(device)

            # Inference only: no autograd graph is needed to build topology.
            with torch.no_grad():
                adj_pred = decoder.construct(sampled)

            generated = _adjacency_to_graph(adj_pred, threshold, config.keep_largest_cc)

            # Include the source-graph index so samples generated from
            # different input graphs do not overwrite each other's figures.
            filename = os.path.join(folder_path, f"{current_time}_{graph_idx}_{sample_idx}.png")
            _save_graph_figure(generated, filename)

            all_generated.append(generated)

    metrics_list = ["spectral", "orbit", "motif"]
    mmd_scores = evaluate_model(graphs, all_generated, metrics_list)
    spectral_list.append(mmd_scores["spectral"])
    orbit_list.append(mmd_scores["orbit"])
    motif_list.append(mmd_scores["motif"])

    # Inner keys use single quotes: reusing the f-string's own quote
    # character is a SyntaxError before Python 3.12.
    print(f"Spectral  : {mmd_scores['spectral']:.4f}")
    print(f"Orbit     : {mmd_scores['orbit']:.4f}")
    print(f"Motif     : {mmd_scores['motif']:.4f}")
    