import torch
from dqn import DQN
import envs.core as ising_env
from envs.util import ErdosRenyiGraphGenerator,SetGraphGenerator
from experiments.utils import load_graph_set
from model.modules import N2Node
from mpnn import MPNN
from gcn import GCN
import pickle
import numpy as np
import os
import matplotlib.pyplot as plt
from training_setting import args

# --- Output locations -------------------------------------------------------
# Root directory for this experiment; handed to DQN as `path=save_loc`.
save_loc = r"./experiments/ER_n60_p15/"
# Sub-folder whose 'losses.pkl' is read back for the loss plot after training
# (presumably written there by DQN during learn() — TODO confirm in dqn.py).
network_folder = os.path.join(save_loc, 'network')
loss_save_path = os.path.join(network_folder, 'losses.pkl')

# --- Environment / graph parameters ----------------------------------------
# First argument to ising_env.make; presumably the episode step limit — verify
# against envs.core.make.
max_step = 60
node_num = 60  # number of nodes per generated Erdős–Rényi graph (also passed to ising_env.make)
p_connecting = 0.15  # edge probability `p` for ErdosRenyiGraphGenerator

# Training graphs are sampled from the Erdős–Rényi generator.
train_graph_generator = ErdosRenyiGraphGenerator(node_num=node_num, p=p_connecting)

# Fixed validation graph set loaded from disk; its size sets the number of
# test episodes handed to DQN (`test_episodes=n_test`).
graph_save_loc = r"./graphs/validation/ER_n60_p15_50graphs.pkl"   # path is also forwarded to DQN as graph_save_loc
graphs_test = load_graph_set(graph_save_loc)
n_test = len(graphs_test)
# ordered=True presumably iterates the stored graphs in a fixed order so
# evaluations are comparable across test runs — TODO confirm in envs.util.
test_graph_generator = SetGraphGenerator(graphs_test, ordered=True)

# DQN takes lists of environments; a single env is used for each role here.
train_envs = [ising_env.make(max_step, node_num, train_graph_generator)]
test_envs = [ising_env.make(max_step, node_num, test_graph_generator)]


def model():
    """Build a fresh ``N2Node`` network for the DQN agent.

    This zero-argument factory is passed to ``DQN`` as ``network=model``.
    The agent presumably calls it to build its online and target networks,
    but that is not confirmed here. Hyper-parameters come from
    ``training_setting.args``.

    Returns:
        A newly constructed ``N2Node`` instance.
    """
    return N2Node(T=args.nlayers,
                  d_in=6,  # per-node input width; presumably must match the env observation — TODO confirm
                  d_ein=0,
                  d_model=args.d_model,
                  nclass=1,
                  q_dim=args.q_dim,
                  n_q=args.n_q,
                  n_c=1,
                  n_pnode=args.n_pnode,
                  task_type="single-class",
                  dropout=args.dropout,
                  # Logical negation of the `wo_selfloop` ("without self-loop")
                  # flag. The previous `~args.wo_selfloop` was bitwise NOT:
                  # ~True == -2 and ~False == -1 are both truthy, so self-loops
                  # could never be switched off.
                  self_loop=not args.wo_selfloop,
                  pre_encoder=None,
                  pos_encoder=None)


def moving_average(values, window):
    """Return the sliding-window mean of a 1-D sequence.

    Args:
        values: 1-D sequence of numbers.
        window: Requested window length. It is clamped to
            ``[1, len(values)]``. When the kernel is longer than the signal,
            ``np.convolve(..., mode='valid')`` swaps its operands, which would
            yield a meaningless series instead of a smoothed curve.

    Returns:
        Float ndarray of length ``len(values) - window + 1`` (after
        clamping). It is empty when ``values`` is empty.
    """
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return values
    window = max(1, min(int(window), values.size))
    return np.convolve(values, np.ones(window) / window, mode='valid')


def plot_losses(loss_path, fig_stem, window=50):
    """Plot the smoothed training-loss curve, save it, and show it.

    Args:
        loss_path: Pickle file holding the loss history. Column 0 is plotted
            as the timestep and column 1 as the loss.
        fig_stem: Output path without extension. Both ``fig_stem + '.png'``
            and ``fig_stem + '.pdf'`` are written.
        window: Moving-average window length (clamped by ``moving_average``).
    """
    # Only load pickles produced by this training run: unpickling can execute
    # arbitrary code. The context manager guarantees the file is closed.
    with open(loss_path, 'rb') as f:
        data = np.array(pickle.load(f))

    # Nothing to draw if no (timestep, loss) rows were logged.
    if data.ndim != 2 or data.shape[0] == 0 or data.shape[1] < 2:
        print(f"No loss history to plot in {loss_path}")
        return

    data_x = moving_average(data[:, 0], window)
    data_y = moving_average(data[:, 1], window)

    plt.plot(data_x, data_y)
    plt.xlabel("Timestep")
    plt.ylabel("Loss")

    plt.yscale("log")
    plt.grid(True)

    plt.savefig(fig_stem + ".png", bbox_inches='tight')
    plt.savefig(fig_stem + ".pdf", bbox_inches='tight')
    plt.show()


if __name__ == '__main__':

    agent = DQN(envs=train_envs,
                path=save_loc,
                network=model,
                test_episodes=n_test,
                test_envs=test_envs,
                graph_save_loc=graph_save_loc,
                IAP=False,
                init_network_params=None,
                # init_weight_std=0.01,
                # DQN parameters
                gamma=0.99,
                double_dqn=True,
                update_exploration=True,
                initial_exploration_rate=1,
                final_exploration_rate=0.05,
                final_exploration_step=300000,
                update_target_frequency=1000,

                # Test
                evaluate=True,
                test_frequency=10000,  # 2000
                test_score_save_path='test_scores',
                test_color_save_path='color_nums',
                test_accurary_path='test_accurary',
                # Replay buffer
                replay_start_size=500,  # 50000
                replay_buffer_size=20000,  # 1000000
                minibatch_size=64,
                update_frequency=16,

                # learning rate
                update_learning_rate=True,
                initial_learning_rate=0,
                peak_learning_rate=1e-3,
                peak_learning_rate_step=10000,
                final_learning_rate=5e-5,
                final_learning_rate_step=200000,

                # regularization
                weight_decay=0,

                # Loss function
                adam_epsilon=1e-8,
                )
    agent.learn(2000000)

    ############
    # PLOT - losses
    ############
    # loss_save_path lives in the network folder; presumably DQN writes it
    # during learn() — TODO confirm against dqn.py.
    plot_losses(loss_save_path, os.path.join(network_folder, "loss"), window=50)


