import torch
import numpy as np
import sys, os, argparse
from os import path as osp
from continuous.models.irl_agent import MLPReward, MeshReward
import json
import gym
from continuous import envs
from continuous.models.sac import ReplayBuffer, SAC
from continuous.utils import system, collect
from continuous.plots import train_plot as plot
from sklearn import mixture, neighbors
from matplotlib import pyplot as plt

parser = argparse.ArgumentParser(
    description="Train SAC on a learned IRL reward model and plot the "
                "resulting state density.")
parser.add_argument('--dir', type=str, default='',
                    help="run directory to evaluate (must contain variant.json "
                         "and reward_model_<epoch>.pkl); when given, only this "
                         "run is evaluated instead of the built-in list")
parser.add_argument('--epoch', type=str, default='475',
                    help="epoch suffix of the reward_model_<epoch>.pkl "
                         "checkpoint to load")

args = parser.parse_args()

# Run directories to evaluate. Each must contain variant.json and the reward
# checkpoint for `--epoch`. Commented-out entries are other runs (goal /
# uniform / multi-goal targets for fkl, rkl and maxentirl) kept for reuse.
data_path = [
            # 'continuous/logs/fkl/5-5-add-t-goal_kde/2020_05_05_15_53_04_1',
            # 'continuous/logs/rkl/5-5-add-t-goal_kde/2020_05_05_16_17_41_1',
            # 'continuous/logs/maxentirl/5-5-add-t-goal_kde/2020_05_06_00_21_25_1',

            # 'continuous/logs/fkl/5-5-add-t-uniform_kde/2020_05_05_03_47_51_1',
            # 'continuous/logs/rkl/5-5-add-t-uniform_kde/2020_05_05_15_51_06_1',
            # 'continuous/logs/maxentirl/5-5-add-t-uniform_kde/2020_05_06_00_21_19_1',

            # 'continuous/logs/fkl/5-5-add-t-multi_goal_kde/2020_05_06_00_22_10_1',
            # 'continuous/logs/rkl/5-5-add-t-multi_goal_kde/2020_05_06_00_22_16_1',
            'continuous/logs/maxentirl/5-5-add-t-multi_goal_kde/2020_05_06_16_01_49_1',
]
if args.dir:
    # `--dir` used to be parsed but ignored; honor it as a single-run override.
    data_path = [args.dir]



# Defaults for keys that older variant.json files may lack (or store as null).
_VARIANT_DEFAULTS = {
    'random_born': False,
    'goal_radius': 0.5,
    'goal': (2, 2),
    'add_time': False,
}


def _load_variant(path):
    """Load ``<path>/variant.json`` and fill in defaults for missing keys.

    A key is replaced by its default when it is absent or explicitly null
    (i.e. whenever ``v.get(key) is None``).

    Args:
        path: run directory containing ``variant.json``.

    Returns:
        The variant dict with ``_VARIANT_DEFAULTS`` applied.
    """
    v_path = osp.join(path, "variant.json")
    with open(v_path, 'r') as f:  # `with` closes the handle (was leaked before)
        v = json.load(f)
    for key, default in _VARIANT_DEFAULTS.items():
        if v.get(key) is None:
            v[key] = default
    return v


def _plot_results(v, reward_func, samples, agent_density,
                  out_dir='./data/figures'):
    """Draw a 2x2 figure (reward, samples, KDE density, trajectories), then save and show it.

    The figure is written to ``<out_dir>/<task>_<obj>.png`` and closed afterwards
    so figures do not accumulate across loop iterations.
    """
    n_pts = 32
    range_lim = [0, 4]

    # construct test points
    test_grid = plot.setup_grid(range_lim, n_pts)

    fig, axs = plt.subplots(2, 2, figsize=(10, 8))
    axs = axs.reshape(-1)
    ims = [
        plot.plot_reward_fn(axs[0], test_grid, n_pts, 100, reward_func.get_scalar_reward),
        plot.plot_samples(samples[0].copy(), axs[1], range_lim, n_pts),
        plot.plot_density(agent_density, axs[2], test_grid, n_pts, 100, "kde"),
    ]
    plot.plot_traj(samples[0].copy(), axs[3])

    # format: colorbars for the three image panels, shared axis limits for all
    for ax, im in zip(axs[:3], ims):
        fig.colorbar(im, ax=ax)
    for ax in axs:
        ax.set_xlim(range_lim[0], range_lim[1])
        ax.set_ylim(range_lim[0], range_lim[1])

    plt.tight_layout()

    # savefig fails if the output directory does not exist yet
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(os.path.join(out_dir, '{}_{}.png'.format(v['task'], v['obj'])))
    plt.show()
    # cla()/clf() left the figure registered with pyplot; close releases it.
    plt.close(fig)


# Loop-invariant setup, hoisted out of the per-run loop.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
env_name = "ContinuousVecGridEnv-v0"

for path in data_path:
    print("path: ", path)
    v = _load_variant(path)

    gym_env = gym.make(env_name, T=v['T'], random_born=v['random_born'], add_time=v['add_time'])
    state_size = gym_env.observation_space.shape[0]
    action_size = gym_env.action_space.shape[0]

    # Input dim 2 is hard-coded (presumably the (x, y) position, even when
    # add_time is set) -- TODO confirm against the training script.
    reward_func = MLPReward(2, hidden_sizes=v['reward_hidden_sizes'], device=device)
    reward_path = osp.join(path, 'reward_model_{}.pkl'.format(args.epoch))
    # map_location lets checkpoints saved on GPU load on CPU-only machines.
    reward_func.load_state_dict(torch.load(reward_path, map_location=device))
    print("reward model loaded from: ", reward_path)

    sac_train_epoch = 1000
    steps_per_epoch = v['T']
    seed = v['seed']
    sac_random_explore_episodes = 100
    batch_size = 64
    sac_lr = 3e-3
    sample_num = 1000

    # Fixed SAC temperature: 0.3 x the mean learned reward over the points
    # returned by plot.setup_grid([0, 4], 32).
    _, _, grid = plot.setup_grid([0, 4], 32)
    reward_scale = np.mean(reward_func.get_scalar_reward(grid))
    # NOTE(review): a negative mean reward yields a negative alpha; an earlier
    # version clamped with max(0.001, 0.3 * reward_scale) -- confirm intent.
    alpha = 0.3 * reward_scale

    replay_buffer = ReplayBuffer(
            state_size,
            action_size,
            device=device,
            size=sac_train_epoch * steps_per_epoch)

    env_fn = lambda: gym.make(env_name, r=reward_func.get_scalar_reward, T=v['T'],
                              random_born=v['random_born'], add_time=v['add_time'])

    # NOTE(review): update_after is T * v['sac_random_explore_episodes'] (steps),
    # but start_steps receives the local sac_random_explore_episodes (100)
    # unscaled, i.e. 100 steps rather than 100 episodes -- confirm the units.
    sac_agent = SAC(env_fn, replay_buffer,
                update_after=gym_env.T * v['sac_random_explore_episodes'],
                update_every=1, # NOTE: this can be larger to be roll out in vectorize env
                seed=seed,
                steps_per_epoch=steps_per_epoch,
                epochs=sac_train_epoch,
                start_steps=sac_random_explore_episodes,
                batch_size=batch_size,
                device=device,
                automatic_alpha_tuning=False,
                alpha=alpha,
                lr=sac_lr,
                )

    sac_test_rets, sac_alphas, sac_log_pis, sac_test_timestep = sac_agent.learn(print_out=True)
    samples = collect.collect_trajectories_policy(env_fn(), sac_agent, n = sample_num)

    # Fit a density model using the samples; flatten to (n*T, state_dim)
    agent_emp_states = samples[0].copy()
    agent_emp_states = agent_emp_states.reshape(-1, agent_emp_states.shape[2])

    agent_density = neighbors.KernelDensity(bandwidth=v['bandwidth'], kernel=v['kernel'])
    agent_density.fit(agent_emp_states)

    _plot_results(v, reward_func, samples, agent_density)

       
