import os

import matplotlib.pyplot as plt
import numpy as np
import torch

from gen_rl.commons.args import get_all_args
from gen_rl.commons.launcher import launch_env, launch_models
from gen_rl.policy.ddpg import DDPG

# Parse the command-line configuration into a plain dict, pin the number of
# parallel environments to two, and build the environments. launch_env hands
# back the args dict as well, so rebind it to whatever it returns.
args = vars(get_all_args())
args["num_envs"] = 2
env, eval_env, args = launch_env(args=args)

# Checkpoint directory to visualize. Uncomment one of the alternatives below
# (and comment out the active line) to inspect a different run.
# weight_dir = "./weights/DDPG-Q-1684195443.7906053"
# weight_dir = "weights/DDPG-V-1684198107.1590505"
# weight_dir = "weights/SAC-Q-1684330656.4116514"
weight_dir = "weights/SAC-V-1684332063.6435866"

# Per-step figures are written to <weight_dir>/images. With exist_ok=True this
# is a no-op when the directory already exists, so no separate exists() check
# (and no check-then-create race) is needed.
os.makedirs(f"{weight_dir}/images", exist_ok=True)

# Build the agent from the parsed config. When the run did not use an
# action-value function but did train models, recreate the state/reward
# models via launch_models and attach them before restoring weights.
# NOTE(review): presumably this is the "-V" (state-value) variant of the
# checkpoints listed above — confirm.
agent = DDPG(random_act_fn=args["random_act_fn"], args=args)

uses_learned_models = (not args["if_use_act_val_fn"]) and args["if_train_models"]
if uses_learned_models:
    state_model, reward_model = launch_models(env=env, args=args)
    agent.set_models(reward_model=reward_model, state_model=state_model, decompose_obs_fn=None)

# Restore the (hard-coded) session-1300 checkpoint from the chosen run.
agent.load(filename=f"{weight_dir}/session-1300")

# Roll out the loaded policy for 200 steps. After every step, save a figure
# with env 0's rendered frame (left) and the critic's estimate over a sweep of
# candidate actions for env 0's current observation (right).
num_samples = 100
# NOTE(review): the [-2, 2] sweep is hard-coded; presumably it matches the
# env's action bounds — confirm against env.action_space. It does not depend
# on the step, so the grid is built once outside the loop.
action_grid = np.linspace(-2.0, 2.0, num=num_samples)[:, None].astype(np.float32)

s = env.reset()
for t in range(200):
    print(t)
    a = agent.select_action(state=s)
    s, _, _, _ = env.step(a)
    frame = env.render(mode="rgb_array")
    ax_frame = plt.subplot(121)
    ax_frame.imshow(frame[0])

    # Repeat env 0's observation once per candidate action. This lives in its
    # own variable: overwriting `s` (as the loop used to) fed a
    # (num_samples, obs_dim) batch of copies back into select_action/env.step
    # on the next iteration instead of the real per-env observations.
    obs_batch = np.tile(A=s[0, :], reps=(num_samples, 1)).astype(np.float32)
    # NOTE(review): tensors are created on the CPU; confirm the agent's
    # networks live there too. The [0][0] selection of the return value is
    # kept as-is — its meaning is defined inside DDPG._eval_val_reward.
    q_out = agent._eval_val_reward(obs_t=torch.tensor(obs_batch), a_t=torch.tensor(action_grid))
    # np.float (an alias of the builtin float, i.e. float64) was removed in
    # NumPy 1.24, so spell the dtype explicitly.
    q = q_out[0][0].detach().cpu().numpy().astype(np.float64)

    # Reuse one axes handle instead of re-fetching subplot(122) per call.
    ax_q = plt.subplot(122)
    ax_q.scatter(action_grid, q)
    ax_q.set_xlabel("Action")
    ax_q.set_ylabel("Q(s, a)")
    ax_q.set_title(t)
    plt.tight_layout()
    plt.savefig(f"{weight_dir}/images/t-{t}")
    plt.clf()
