import os
import pickle
import time
from collections import Counter
import matplotlib.pyplot as plt
import torch
import networkx as nx
import numpy as np
import envs.core as ising_env
from experiments.utils import test_network, load_graph_set
from model.modules import N2Node
from mpnn import MPNN
from envs.util import ErdosRenyiGraphGenerator,SetGraphGenerator
import cp_solver
from training_setting import args
try:
    # seaborn is optional; its presence only decides whether the seaborn-look
    # matplotlib style is applied below.
    import seaborn as sns  # noqa: F401
except ImportError:
    pass
else:
    # matplotlib 3.6 renamed the bundled "seaborn" style to "seaborn-v0_8" and
    # later releases removed the old alias (use() then raises OSError, which the
    # ImportError handler above would not catch). Use whichever name exists.
    for _style_name in ("seaborn-v0_8", "seaborn"):
        if _style_name in plt.style.available:
            plt.style.use(_style_name)
            break

def run(save_loc=r"./experiments/ER_n60_p15",
        graph_save_loc=r"./graphs/test/ER_n60_p15_100graphs.pkl",
        batched=True,
        max_batch_size=None,
        model_path=None):
    """Evaluate a pre-trained ``N2Node`` network on a saved test graph set.

    Loads the graphs stored at ``graph_save_loc``, builds an ``N2Node``
    network from the hyper-parameters in ``training_setting.args``, restores
    its weights, runs ``test_network`` over the graphs, prints summary
    statistics, and compares each result against the reference values
    returned by ``cp_solver.get_opt``.

    Args:
        save_loc: Experiment directory; used to locate the checkpoint when
            ``model_path`` is not given.
        graph_save_loc: Path of the test graph set passed to
            ``load_graph_set`` and ``cp_solver.get_opt``.
        batched: Forwarded to ``test_network``.
        max_batch_size: Forwarded to ``test_network``.
        model_path: Explicit checkpoint path. Defaults to
            ``<save_loc>/network_N2_best_ac.pth``.

    Returns:
        The fraction of results for which ``reference >= result``, or
        ``None`` when ``test_network`` returned no results.
    """
    print("\n----- Running {} -----\n".format(os.path.basename(__file__)))
    # Passed positionally to test_network; presumably the per-graph step
    # budget -- confirm against experiments.utils.test_network.
    max_step = 60
    graphs_test = load_graph_set(graph_save_loc)

    # The device is applied explicitly below (map_location and .to(device));
    # no global torch default is changed here.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Set torch default device to {}.".format(device))

    network = N2Node(T=args.nlayers,
                     d_in=66,
                     d_ein=0,
                     d_model=args.d_model,
                     nclass=1,
                     q_dim=args.q_dim,
                     n_q=args.n_q,
                     n_c=1,
                     n_pnode=30,
                     task_type="single-class",
                     dropout=args.dropout,
                     # Logical negation: bitwise ``~`` on a bool gives -1/-2,
                     # both truthy, so self-loops could never be turned off.
                     self_loop=not args.wo_selfloop,
                     pre_encoder=None,
                     pos_encoder=None)

    if model_path is None:
        model_path = os.path.join(save_loc, "network_N2_best_ac.pth")
    # torch.load unpickles the file: only load trusted checkpoints.
    network.load_state_dict(torch.load(model_path, map_location=device))
    network.to(device)
    # Inference only: freeze weights and disable dropout etc.
    for param in network.parameters():
        param.requires_grad = False
    network.eval()

    print("Sucessfully created agent with pre-trained MPNN.\nMPNN architecture\n\n{}".format(repr(network)))
    t0 = time.time()
    results = test_network(network, graphs_test, device, max_step,
                           return_raw=True, return_history=True,
                           batched=batched, max_batch_size=max_batch_size, IAP=False)
    elapsed = time.time() - t0

    if not results:
        print("test_network returned no results; nothing to evaluate.")
        return None

    # Average wall-clock time per result (previously divided by a fixed 100).
    print("solved_time:", elapsed / len(results))
    print(sum(results) / len(results))
    print(Counter(results))

    opt = cp_solver.get_opt(graph_save_loc)
    if len(opt) != len(results):
        print("Warning: {} reference values vs {} results; comparing the first {} pairs.".format(
            len(opt), len(results), min(len(opt), len(results))))

    # A result counts as matching when it does not exceed the reference value.
    eq = 0
    mismatched = []
    for i, (reference, result) in enumerate(zip(opt, results)):
        if reference >= result:
            eq += 1
        else:
            mismatched.append(i)
    if mismatched:
        # Same "i," list as before, but terminated so "acc" starts a new line.
        print("".join("{},".format(i) for i in mismatched))

    # Results without a reference value count as non-matching.
    accuracy = eq / len(results)
    print("acc", accuracy)
    return accuracy


if __name__ == "__main__":
    # Run the evaluation with run()'s default arguments only when this file is
    # executed as a script, not when it is imported.
    run()