import os
import pickle
import time
from collections import Counter
import matplotlib.pyplot as plt
import torch
import networkx as nx
import numpy as np
import envs.core as ising_env
from experiments.utils import test_network, load_graph_set
from model.modules import N2Node
from mpnn import MPNN
from envs.util import ErdosRenyiGraphGenerator,SetGraphGenerator
import cp_solver
from training_setting import args

# Use a seaborn-looking matplotlib style when seaborn is installed.
# matplotlib >= 3.6 renamed the bundled 'seaborn' style to 'seaborn-v0_8', and
# newer releases dropped the old name (plt.style.use then raises OSError rather
# than ImportError). Pick whichever name this matplotlib provides, so the script
# does not crash at import time.
try:
    import seaborn as sns  # imported to confirm seaborn is installed
except ImportError:
    pass
else:
    for _style_name in ("seaborn-v0_8", "seaborn"):
        if _style_name in plt.style.available:
            plt.style.use(_style_name)
            break

def _count_optimal(opt, results):
    """Count the graphs on which the network matched the reference optimum.

    Args:
        opt: per-graph reference values (as returned by ``cp_solver.get_opt``).
        results: per-graph values returned by ``test_network``.

    Returns:
        ``(n_solved, misses)``. ``n_solved`` counts indices where
        ``opt[i] >= results[i]``; ``misses`` lists the remaining indices.
        Only the first ``min(len(opt), len(results))`` indices are compared.
    """
    n_solved = 0
    misses = []
    for i, (best, found) in enumerate(zip(opt, results)):
        if best >= found:
            n_solved += 1
        else:
            misses.append(i)
    return n_solved, misses


def run(save_loc=r"./experiments/ER_n40_p15",
        graph_save_loc=r"./graphs/validation/ER_n40_p15_100graphs.pkl",
        batched=True,
        max_batch_size=None,
        model_path=None,
        max_step=40):
    """Evaluate a pre-trained N2Node network on a saved validation graph set.

    Loads the graphs, restores the network weights and runs ``test_network``
    on every graph. It then prints the colour-count histogram, the mean colour
    count and the fraction of graphs where the network is no worse than the
    optimum reported by ``cp_solver.get_opt``.

    Args:
        save_loc: experiment directory. The checkpoint
            ``network_N2_best.pth`` is read from here unless ``model_path``
            is given.
        graph_save_loc: graph-set file passed to ``load_graph_set`` and to
            ``cp_solver.get_opt``.
        batched: forwarded to ``test_network``.
        max_batch_size: forwarded to ``test_network``.
        model_path: explicit checkpoint path; defaults to
            ``<save_loc>/network_N2_best.pth``.
        max_step: step budget forwarded to ``test_network``.
    """
    print("\n----- Running {} -----\n".format(os.path.basename(__file__)))
    graphs_test = load_graph_set(graph_save_loc)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # torch.device(...) only constructs a device object and changes no global
    # default; the network is moved explicitly with .to(device) below.
    print("Using torch device {}.".format(device))

    # These hyper-parameters must match the ones the checkpoint was trained
    # with, otherwise load_state_dict (strict by default) raises.
    network = N2Node(T=args.nlayers,
                     d_in=46,
                     d_ein=0,
                     d_model=args.d_model,
                     nclass=1,
                     q_dim=args.q_dim,
                     n_q=args.n_q,
                     n_c=1,
                     n_pnode=20,
                     task_type="single-class",
                     dropout=args.dropout,
                     # `not`, not `~`: ~True == -2 and ~False == -1 are both
                     # truthy, which made the flag a no-op.
                     self_loop=not args.wo_selfloop,
                     pre_encoder=None,
                     pos_encoder=None)

    if model_path is None:
        model_path = os.path.join(save_loc, "network_N2_best.pth")
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    network.load_state_dict(torch.load(model_path, map_location=device))
    network = network.to(device)

    # Inference only: freeze the weights and disable dropout.
    for param in network.parameters():
        param.requires_grad = False
    network.eval()

    print("Sucessfully created agent with pre-trained MPNN.\nMPNN architecture\n\n{}".format(repr(network)))

    ####################################################
    # TEST NETWORK ON VALIDATION GRAPHS
    ####################################################
    t_start = time.perf_counter()
    results = test_network(network, graphs_test, device, max_step,
                           return_raw=True, return_history=True,
                           batched=batched, max_batch_size=max_batch_size, IAP=False)
    elapsed = time.perf_counter() - t_start

    n_graphs = len(results)
    if n_graphs == 0:
        print("No results returned by test_network; nothing to score.")
        return

    # Average wall-clock time per graph (previously divided by a hard-coded 100).
    print("solved_time:", elapsed / n_graphs)
    print(Counter(results))
    print("avg colors:", sum(results) / n_graphs)

    # Assumes get_opt returns one value per graph, in the same order as
    # `results` — TODO confirm against cp_solver.
    opt = cp_solver.get_opt(graph_save_loc)
    n_compared = min(len(opt), n_graphs)
    if len(opt) != n_graphs:
        print("Warning: {} reference optima for {} results; comparing the first {}."
              .format(len(opt), n_graphs, n_compared))

    n_solved, misses = _count_optimal(opt, results)
    for i in misses:
        print("idx", i)
    accuracy = n_solved / n_compared if n_compared else 0.0
    print(accuracy)

if __name__ == "__main__":
    # Script entry point: evaluate using run()'s default paths and settings.
    run()