import numpy as np
import torch as th
from stochastic_actor import StochasticActor
from optimize import calc_true_trajectory
import sys
import math
import random
# Fix the seed of every random number generator this script touches
# (Python's `random`, PyTorch, NumPy) so repeated runs start from the
# same generator state.
SEED = 0
for _seed_rng in (random.seed, th.manual_seed, np.random.seed):
    _seed_rng(SEED)

import argparse
import pickle
# Command-line interface of the script:
#   --log_dir         directory that actor_model.pt is loaded from
#   --depth           number of refinement rounds applied to the path
#   --optimize_steps  forwarded as `timesteps` to calc_true_trajectory
parser = argparse.ArgumentParser()
parser.add_argument("--log_dir", default="tmp")
for _flag, _default in (("--depth", 2), ("--optimize_steps", 10000)):
    parser.add_argument(_flag, type=int, default=_default)

# Read the command line once; the rest of the script pulls its settings
# from `args`.
args = parser.parse_args()

log_dir = args.log_dir
depth = args.depth
space_name = 'Panda'
episode_num = 0  # index into eval_episodes of the episode that is processed
count = 0  # incremented when an episode's largest step delta reaches 0.1

from pick_space import pick_space
space, eval_episodes = pick_space(space_name)
dim = space.dim

# The episodes returned by pick_space are replaced by a single hand-written
# (start, goal) pair; the two configurations differ only in their first
# component.
_shared_tail = [0.75, -1.2, 0.5, -1.2, 0., 0.]
eval_episodes = [[
    th.tensor([1.] + _shared_tail),
    th.tensor([-1.] + _shared_tail),
]]

# Restore the trained actor on the CPU from <log_dir>/actor_model.pt.
actor = StochasticActor(dim)
actor_weights = th.load('{}/actor_model.pt'.format(log_dir), map_location='cpu')
actor.load_state_dict(actor_weights)

# Process only the selected episode: start from its two endpoints and
# repeatedly insert the actor's output between every pair of neighbours.
for start, goal in eval_episodes[episode_num:episode_num+1]:
    states = th.zeros(2, dim)
    states[0], states[1] = start, goal

    with th.no_grad():
        for _ in range(depth):
            # One actor output per consecutive (left, right) pair.
            midpoints = actor(states[:-1], states[1:], deterministic=True)
            refined = th.zeros((states.shape[0] + midpoints.shape[0], dim))
            # Even rows keep the existing states, odd rows take the new
            # in-between points, so the sequence stays ordered.
            refined[0::2] = states
            refined[1::2] = midpoints
            states = refined

    deltas = space.calc_deltas(states[:-1], states[1:])
    # Count the episode when its largest consecutive delta reaches 0.1.
    if deltas.max() >= 0.1:
        count += 1

# Reference trajectory between the same endpoints, computed by the
# optimizer with the step budget given on the command line.
true_traj = calc_true_trajectory(
    space,
    states[0],
    states[-1],
    N=31,
    timesteps=args.optimize_steps,
)
_reference_deltas = space.calc_deltas(true_traj[:-1], true_traj[1:])
true_dist = _reference_deltas.sum().item()
true_traj = true_traj.numpy()

import roboticstoolbox as rp
import numpy as np

# Panda robot model (DH-parameter variant from the Robotics Toolbox) that
# is used below to draw each configuration.
panda = rp.models.DH.Panda()

import matplotlib.pyplot as plt
import matplotlib

# Global matplotlib settings for the exported figures. Font type 42 makes
# the PDF/PS backends embed TrueType fonts.
# NOTE(review): font.size 0 presumably suppresses all text in the
# figures - confirm this is intended.
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    "font.size": 0,
    "figure.figsize": (7, 8),
})

# Surface mesh of a thin vertical cylinder whose axis sits at x = 0.4,
# y = 0 and which spans z in [0, height]. It is drawn into every figure.
# NOTE(review): presumably this is the obstacle of the Panda task - confirm.
radius = 0.01
height = 3
theta = np.linspace(0, 2*np.pi, 100)
z = np.linspace(0, height, 100)

# Rows index z, columns index theta (same layout np.meshgrid(theta, z) gives).
theta_grid = np.tile(theta, (z.size, 1))
z_grid = np.tile(z.reshape(-1, 1), (1, theta.size))

x_grid = radius * np.cos(theta_grid) + 0.4
y_grid = np.sin(theta_grid) * radius


import os

# Convert the refined states to a NumPy array for plotting.
states = states.numpy()

# Camera orientation shared by every exported figure.
azim = -90
elev = 80

# Every figure is written to this directory. Create it up front so that
# plt.savefig does not raise FileNotFoundError on a fresh checkout.
figure_dir = '../panda_figures'
os.makedirs(figure_dir, exist_ok=True)


def _save_pose_figure(joint_state, filename):
    """Draw one arm configuration together with the cylinder and save it.

    Args:
        joint_state: configuration handed to ``panda.plot``.
        filename: name of the output file inside ``figure_dir``.
    """
    fig = panda.plot(joint_state, jointaxes=False, eeframe=False, name=False)
    ax = fig.ax

    # Identical axis limits and camera for all figures so they can be
    # compared side by side.
    ax.set_xlim(-0.15, 0.5)
    ax.set_ylim(-0.25, 0.4)
    ax.set_zlim(0., 0.65)
    ax.view_init(azim=azim, elev=elev)
    # NOTE(review): recent matplotlib releases deprecate Axes3D.dist -
    # confirm it still has an effect with the installed version.
    ax.dist = 7.7
    ax.set_box_aspect((1., 1., 1.))

    ax.plot_surface(x_grid, y_grid, z_grid, color='b', alpha=0.6)

    plt.savefig(f'{figure_dir}/{filename}', bbox_inches='tight')
    # Clear the current figure so the next pose starts from an empty canvas.
    plt.clf()


# 1) States obtained from the actor-based refinement.
for index, state in enumerate(states):
    _save_pose_figure(state, f'panda_{index}.pdf')

# 2) Every 8th point of the optimized reference trajectory.
# NOTE(review): the stride of 8 is fixed and does not follow --depth;
# confirm it yields the intended number of frames.
for index, state in enumerate(true_traj[::8]):
    _save_pose_figure(state, f'panda_truth_{index}.pdf')

# 3) Straight-line interpolation between the first and last state, sampled
#    at 2**depth + 1 evenly spaced points.
num_segments = 2**depth
for index in range(num_segments + 1):
    alpha = index / num_segments
    state = (1-alpha)*states[0]+alpha*states[-1]
    _save_pose_figure(state, f'panda_linear_{index}.pdf')
