import argparse
import h5py
import gym
import numpy as np
import torch
import os
import cv2
# from algo.utils.sac import SACTrainer
from algo.utils.sac2 import SACTrainer2 as SACTrainer
from gym_mm.envs.ant_angle import AntAngle
from gym_mm.envs.ant_custom_env import AntCustomEnv
from gym_mm.envs.cheetah_jump import CheetahJump

def _str2bool(value):
    """Convert a command-line token to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so the option could never be
    switched off from the command line.

    Args:
        value: The raw token from the command line (or an actual bool).

    Returns:
        The boolean the token stands for.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised spelling
            of true/false.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got {!r}'.format(value))


# Command-line interface. The help texts quote the defaults actually set below.
parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
parser.add_argument('--env-name', default="AntCustom-v2",
                    help='Mujoco Gym environment (default: AntCustom-v2)')
parser.add_argument('--policy', default="Gaussian",
                    help='Policy Type: Gaussian | Deterministic (default: Gaussian)')
parser.add_argument('--eval', type=_str2bool, default=True,
                    help='Evaluates the policy every 10 episodes (default: True)')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor for reward (default: 0.99)')
parser.add_argument('--tau', type=float, default=0.005, metavar='G',
                    help='target smoothing coefficient(τ) (default: 0.005)')
parser.add_argument('--critic_lr', type=float, default=0.0003, metavar='G',
                    help='critic learning rate (default: 0.0003)')
parser.add_argument('--policy_lr', type=float, default=0.0003, metavar='G',
                    help='policy learning rate (default: 0.0003)')
parser.add_argument('--disc_lr', type=float, default=0.002, metavar='G',
                    help='discriminator learning rate (default: 0.002)')
parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
                    help='Temperature parameter α determines the relative importance '
                         'of the entropy term against the reward (default: 0.2)')
parser.add_argument('--automatic_entropy_tuning', type=_str2bool, default=False, metavar='G',
                    help='Automatically adjust α (default: False)')
parser.add_argument('--seed', type=int, default=100, metavar='N',
                    help='random seed (default: 100)')
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
                    help='batch size (default: 256)')
parser.add_argument('--max_episode_len', type=int, default=1000, metavar='N',
                    help='maximum episode length (default: 1000)')
parser.add_argument('--num_steps', type=int, default=1000000, metavar='N',
                    help='maximum number of steps (default: 1000000)')
parser.add_argument('--hidden_dim', type=int, default=256, metavar='N',
                    help='hidden size (default: 256)')
parser.add_argument('--updates_per_step', type=int, default=1, metavar='N',
                    help='model updates per simulator step (default: 1)')
parser.add_argument('--start_steps', type=int, default=10000, metavar='N',
                    help='Steps sampling random actions (default: 10000)')
parser.add_argument('--target_update_interval', type=int, default=1, metavar='N',
                    help='Value target update per no. of updates per step (default: 1)')
parser.add_argument('--replay_size', type=int, default=200000, metavar='N',
                    help='size of replay buffer (default: 200000)')
# store_false: args.cuda is True unless the flag is given, in which case it
# becomes False. The help text spells that out because the flag name alone
# suggests the opposite.
parser.add_argument('--cuda', action="store_false",
                    help='run on CUDA (default: True; passing this flag sets it to False)')
parser.add_argument('--num_modes', type=int, default=10, metavar='N',
                    help='number of modes (default: 10)')
parser.add_argument('--target_size', type=int, default=1000000, metavar='G',
                    help='target number of transitions for collection (default: 1000000)')
args = parser.parse_args()

# Environment
# env = NormalizedActions(gym.make(args.env_name))
# The gym_mm environments are instantiated directly from their classes; any
# other environment id is resolved through gym's registry.
_env_factories = {
    "AntAngle-v2": AntAngle,
    "CheetahJump-v2": CheetahJump,
    "AntCustom-v2": AntCustomEnv,
}
_env_factory = _env_factories.get(args.env_name)
env = gym.make(args.env_name) if _env_factory is None else _env_factory()

# Apply the same seed to the environment, its action space, torch and numpy.
_seed = args.seed
env.seed(_seed)
env.action_space.seed(_seed)

torch.manual_seed(_seed)
np.random.seed(_seed)

# Checkpoint directory layout:
#   $UDG_DATA_PATH/url-data/models/<model_type>/<env_name>/<model_time>/<model_step>
# The pieces are combined with os.path.join so that UDG_DATA_PATH works with
# or without a trailing separator (plain string concatenation silently built
# a wrong path when the separator was missing).
# A KeyError is raised here if UDG_DATA_PATH is not set.
path_prefix = os.environ['UDG_DATA_PATH']
# The trailing "" keeps the separator at the end of model_prefix.
model_prefix = os.path.join(path_prefix, "url-data", "models", "")
model_type = "wasserstein"
model_time = "2023-04-29_08-29-20"
model_step = "final"
compound_prefix = os.path.join(model_prefix, model_type, args.env_name, model_time, model_step)

# Agent: one SAC trainer per mode, each restored from its own actor/critic
# checkpoint pair (files are suffixed with the mode index).
trainers = []
for i in range(args.num_modes):
    trainer = SACTrainer(env.observation_space.shape[0], env.action_space, args)
    trainer.load_model(
        '{}/sac_actor_{}_{}'.format(compound_prefix, args.env_name, i),
        '{}/sac_critic_{}_{}'.format(compound_prefix, args.env_name, i),
    )
    trainers.append(trainer)
    # NOTE(review): every later trainer is built with a hidden size one larger
    # than the previous one (256, 257, ...), and args.hidden_dim stays mutated
    # afterwards. Presumably this mirrors how the checkpoints were trained --
    # confirm before changing.
    args.hidden_dim += 1

# Counters kept from the original script layout; nothing in the visible part
# of this script reads them.
total_numsteps = 0
episodes_num = 0
episode_rewards = []

# White 1024x1024 canvas on which the rollout trajectories are drawn.
img = np.ones((1024, 1024, 3), np.uint8)*255
color_theme = [(152, 225, 204), (152, 213, 172), (151, 201, 137), (201, 204, 132), (244, 206, 126), 
    (218, 171, 136), (203, 151, 140), (188, 130, 143), (189, 137, 170), (191, 144, 196),
    (194, 158, 241), (159, 156, 242)]
# Coordinate-to-pixel mapping: pixel = (coord + boundaries) * scale_f, so
# coordinates in [-25, 25] cover the 1024-pixel canvas (50 * 20.48 = 1024).
scale_f = 20.48
boundaries = 25
episodes_per_mode = 5
# Only the first mode is rendered; the output file name below carries the
# matching "_0" suffix.
for mode in range(0, 1):
    line_color = color_theme[(mode - 1) % len(color_theme)]
    for episode in range(episodes_per_mode):
        obs = env.reset()
        episode_reward = 0
        episode_steps = 0
        # Only the first two observation entries are tracked.
        # NOTE(review): presumably the planar x/y position -- confirm against the env.
        trajectory = [obs[:2]]
        done = False
        while not done:
            action, _ = trainers[mode].act(obs)
            if len(action.shape) > 1:
                # The policy returned a batch; use its first row.
                action = action[0]
            obs, reward, done, _ = env.step(action)
            trajectory.append(obs[:2])
            episode_reward += reward
            episode_steps += 1
            # Compare after counting the step just taken, so an episode stops
            # after max_episode_len steps rather than max_episode_len + 1.
            done = done or (episode_steps >= args.max_episode_len)

        # Convert the whole trajectory to integer pixel coordinates in one go
        # and draw it as consecutive line segments.
        pixels = ((np.asarray(trajectory) + boundaries) * scale_f).astype(np.int16).tolist()
        for start, end in zip(pixels, pixels[1:]):
            cv2.line(img, tuple(start), tuple(end), color=line_color, thickness=2)

# Flip vertically: image rows grow downwards, so this makes larger second
# coordinates appear higher in the picture.
img = cv2.flip(img, 0)
# cv2.imwrite returns False without raising when the target directory is
# missing, so create it first and check the result explicitly.
os.makedirs("fig", exist_ok=True)
out_path = "fig/demo_"+model_time+"_"+model_step+"_0.jpg"
written = cv2.imwrite(out_path, img)
env.close()
if not written:
    raise IOError("could not write trajectory image to {}".format(out_path))
