import os
import time
from collections import Counter
import matplotlib.pyplot as plt
import torch
import networkx as nx
import numpy as np
import envs.core as ising_env
from experiments.utils import test_network, load_graph_set
from model.modules import N2Node
from mpnn import MPNN
from envs.util import ErdosRenyiGraphGenerator,SetGraphGenerator
import cp_solver
from training_setting import args
# Optional plotting style: applied only when the seaborn package is installed.
try:
    import seaborn as sns
except ImportError:
    pass
else:
    # Matplotlib 3.6 renamed the bundled 'seaborn' style to 'seaborn-v0_8'
    # and 3.8 removed the old alias (plt.style.use then raises OSError), so
    # try the original name first and fall back to the renamed one; if
    # neither exists, keep Matplotlib's default style instead of crashing.
    for _style in ("seaborn", "seaborn-v0_8"):
        try:
            plt.style.use(_style)
            break
        except OSError:
            pass

def run(save_loc="./experiments/ER_n200_p15/",
        graph_save_loc=r"./graphs/test/ER_n200_p15_100graphs.pkl",
        batched=True,
        max_batch_size=None):
    """Evaluate a pre-trained N2Node network on a saved set of test graphs.

    Loads the test graphs from ``graph_save_loc``, builds an ``N2Node``
    network, restores its weights from ``network_N2_best_ac.pth`` inside
    ``save_loc``, runs ``test_network`` on the graphs and prints:

    * the average wall-clock solve time per result,
    * a ``Counter`` of the returned results and their mean ("avg colors"),
    * the fraction of graphs whose result is no larger than the reference
      value returned by ``cp_solver.get_opt`` for the same graph file.

    Args:
        save_loc: Directory that contains ``network_N2_best_ac.pth``.
        graph_save_loc: Path of the pickled test graph set; it is passed to
            both ``load_graph_set`` and ``cp_solver.get_opt``.
        batched: Forwarded to ``test_network``.
        max_batch_size: Forwarded to ``test_network``.

    Raises:
        ValueError: if ``cp_solver.get_opt`` returns a different number of
            reference values than ``test_network`` returned results.
    """
    print("\n----- Running {} -----\n".format(os.path.basename(__file__)))

    max_step = 200

    graphs_test = load_graph_set(graph_save_loc)

    # Choose the device once and pass it explicitly everywhere below;
    # constructing torch.device(...) on its own changes no global default.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using torch device {}.".format(device))

    network = N2Node(T=3,
                     d_in=206,
                     d_ein=0,
                     d_model=args.d_model,
                     nclass=1,
                     q_dim=args.q_dim,
                     n_q=args.n_q,
                     n_c=1,
                     n_pnode=100,
                     task_type="single-class",
                     dropout=args.dropout,
                     # Bitwise `~` on a bool yields -1/-2 (always truthy);
                     # logical negation is what the flag name intends.
                     self_loop=not args.wo_selfloop,
                     pre_encoder=None,
                     pos_encoder=None)

    # Same path as before for the default save_loc, but now honours save_loc.
    model_path = os.path.join(save_loc, "network_N2_best_ac.pth")
    network.load_state_dict(torch.load(model_path, map_location=device))

    # Inference only: freeze all weights and switch the module to eval mode.
    for param in network.parameters():
        param.requires_grad = False
    network.eval()
    # Use the detected device instead of forcing CUDA (which crashes on
    # CPU-only machines).
    network = network.to(device)
    print("Sucessfully created agent with pre-trained MPNN.\nMPNN architecture\n\n{}".format(repr(network)))

    t1 = time.perf_counter()
    results = test_network(network, graphs_test, device, max_step,
                           return_raw=True, return_history=True,
                           batched=batched, max_batch_size=max_batch_size,
                           IAP=False)
    elapsed = time.perf_counter() - t1

    n_results = len(results)
    if n_results == 0:
        print("No results returned by test_network; nothing to report.")
        return

    # Average time per result (the old hard-coded divisor of 100 was only
    # correct for 100-graph test sets).
    print("solve_time", elapsed / n_results)
    print(Counter(results))
    print("avg colors:", sum(results) / n_results)

    opt = cp_solver.get_opt(graph_save_loc)
    if len(opt) != n_results:
        raise ValueError(
            "cp_solver.get_opt returned {} reference values but test_network "
            "returned {} results".format(len(opt), n_results))

    # A graph counts as matched when the network's result does not exceed
    # the reference value computed by cp_solver for that graph.
    eq = sum(1 for ref, res in zip(opt, results) if ref >= res)
    accuracy = eq / n_results
    print("acc", accuracy)

def _main():
    """Command-line entry point: run the evaluation with default arguments."""
    run()


if __name__ == "__main__":
    _main()