import os
import pickle
import time
from collections import Counter
import matplotlib.pyplot as plt
import torch
import networkx as nx
import numpy as np
import envs.core as ising_env
from experiments.utils import test_network, load_graph_set
from model.modules import N2Node
from mpnn import MPNN
from envs.util import ErdosRenyiGraphGenerator,SetGraphGenerator
import cp_solver
from training_setting import args

# Optional plotting style: use the seaborn look when seaborn is installed.
# matplotlib 3.6 renamed the bundled "seaborn" style to "seaborn-v0_8" and
# 3.8 removed the old name (plt.style.use then raises OSError), so try the
# new name first, fall back to the old one, and never let styling abort the
# script.
try:
    import seaborn as sns
except ImportError:
    pass
else:
    for _style_name in ("seaborn-v0_8", "seaborn"):
        try:
            plt.style.use(_style_name)
        except OSError:
            continue
        break

def _load_frozen_network(model_path, device):
    """Build the N2Node model, load weights from ``model_path`` and freeze it.

    Architecture hyper-parameters are read from the shared ``args`` object
    (imported from ``training_setting``); the remaining constructor values
    are the fixed ones this script has always used.  Returns the network on
    ``device`` in eval mode with every parameter's ``requires_grad`` off.
    """
    network = N2Node(T=args.nlayers,
                     d_in=206,
                     d_ein=0,
                     d_model=args.d_model,
                     nclass=1,
                     q_dim=args.q_dim,
                     n_q=args.n_q,
                     n_c=1,
                     n_pnode=30,
                     task_type="single-class",
                     dropout=args.dropout,
                     # Logical negation: ``~`` on a bool yields -1/-2, which are
                     # both truthy, so ``~args.wo_selfloop`` could never turn
                     # self-loops off.
                     self_loop=not args.wo_selfloop,
                     pre_encoder=None,
                     pos_encoder=None)
    network.load_state_dict(torch.load(model_path, map_location=device))
    network = network.to(device)

    # Inference only: gradients are never needed.
    for param in network.parameters():
        param.requires_grad = False
    network.eval()
    return network


def run(save_loc="./experiments/BA_n200_m4",
        graph_save_loc=r"./graphs/test/BA_n200_m4_100graphs.pkl",
        batched=True,
        max_batch_size=None,
        model_file="network_N2_best_ac.pth"):
    """Evaluate a pre-trained N2Node model on a saved test graph set.

    Loads the graphs at ``graph_save_loc``, restores weights from
    ``os.path.join(save_loc, model_file)``, runs ``test_network`` and prints:
    the mean wall-clock time per result, the distribution and mean of the
    results, and the fraction of graphs whose result is <= the reference
    value returned by ``cp_solver.get_opt(graph_save_loc)``.

    Args:
        save_loc: Experiment directory holding the trained weights.
        graph_save_loc: Pickled graph set, given to both ``load_graph_set``
            and ``cp_solver.get_opt``.
        batched: Forwarded to ``test_network``.
        max_batch_size: Forwarded to ``test_network``.
        model_file: State-dict file name inside ``save_loc``; the default is
            the file this script has always loaded.
    """
    print("\n----- Running {} -----\n".format(os.path.basename(__file__)))

    print("network_best folder :", os.path.join(save_loc, 'network_best'))
    print("network folder :", os.path.join(save_loc, 'network'))

    # Resolve the weights relative to save_loc (previously a hard-coded,
    # Windows-only path) and report the file that is actually loaded.
    model_path = os.path.join(save_loc, model_file)
    print("network params :", model_path)

    # Step limit passed positionally to test_network.
    max_step = 200

    graphs_test = load_graph_set(graph_save_loc)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using torch device {}.".format(device))

    network = _load_frozen_network(model_path, device)
    print("Sucessfully created agent with pre-trained MPNN.\nMPNN architecture\n\n{}".format(repr(network)))

    t_start = time.perf_counter()
    results = test_network(network, graphs_test, device, max_step,
                           return_raw=True, return_history=True,
                           batched=batched, max_batch_size=max_batch_size, IAP=False)
    elapsed = time.perf_counter() - t_start

    n_results = len(results)
    if n_results == 0:
        print("No results returned by test_network; nothing to report.")
        return

    # Mean time per evaluated graph (was divided by a hard-coded 100).
    print("solve_time:", elapsed / n_results)
    print(Counter(results))
    print("avg colors:", sum(results) / n_results)

    opt = cp_solver.get_opt(graph_save_loc)
    if len(opt) != n_results:
        print("Warning: {} reference values for {} results; only the first {} "
              "are compared.".format(len(opt), n_results, min(len(opt), n_results)))

    # A graph counts as solved when the model needs no more than the reference.
    solved = 0
    worse = []
    for i, (reference, found) in enumerate(zip(opt, results)):
        if reference >= found:
            solved += 1
        else:
            worse.append(i)
    if worse:
        print("worse than reference:", ",".join(str(i) for i in worse))

    accuracy = solved / n_results
    print("acc:", accuracy)

    
if __name__ == "__main__":
    # Script entry point: run the evaluation with run()'s default
    # experiment directory and test graph file.
    run()