# Equal episode sampling
# Single mode sampling
# eval = True
import argparse
import h5py
import gym
import numpy as np
import torch
import os
# from algo.utils.sac import SACTrainer
from algo.utils.sac2 import SACTrainer2 as SACTrainer
from gym_mm.envs.ant_angle import AntAngle
from gym_mm.envs.cheetah_jump import CheetahJump
from gym_mm.envs.ant_custom_env import AntCustomEnv

def _str2bool(value):
    """Parse a command-line boolean such as "true"/"false", "1"/"0", "yes"/"no".

    Used instead of ``type=bool``: argparse would call ``bool("False")``, which is
    True, so a boolean flag could never be switched off from the command line.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "on", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "off", "0", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
parser.add_argument('--env-name', default="AntCustom-v2",
                    help='Mujoco Gym environment (default: AntCustom-v2)')
parser.add_argument('--policy', default="Gaussian",
                    help='Policy Type: Gaussian | Deterministic (default: Gaussian)')
parser.add_argument('--eval', type=_str2bool, default=True,
                    help='Evaluates a policy every 10 episodes (default: True)')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor for reward (default: 0.99)')
parser.add_argument('--tau', type=float, default=0.005, metavar='G',
                    help='target smoothing coefficient(τ) (default: 0.005)')
parser.add_argument('--critic_lr', type=float, default=0.0003, metavar='G',
                    help='critic learning rate (default: 0.0003)')
parser.add_argument('--policy_lr', type=float, default=0.0003, metavar='G',
                    help='policy learning rate (default: 0.0003)')
parser.add_argument('--disc_lr', type=float, default=0.002, metavar='G',
                    help='learning rate (default: 0.002)')
parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
                    help='Temperature parameter α determines the relative importance of the entropy '
                         'term against the reward (default: 0.2)')
parser.add_argument('--automatic_entropy_tuning', type=_str2bool, default=False, metavar='G',
                    help='Automatically adjust α (default: False)')
parser.add_argument('--seed', type=int, default=100, metavar='N',
                    help='random seed (default: 100)')
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
                    help='batch size (default: 256)')
parser.add_argument('--max_episode_len', type=int, default=1000, metavar='N',
                    help='maximum episode length (default: 1000)')
parser.add_argument('--num_steps', type=int, default=1000000, metavar='N',
                    help='maximum number of steps (default: 1000000)')
parser.add_argument('--hidden_dim', type=int, default=256, metavar='N',
                    help='hidden size (default: 256)')
parser.add_argument('--updates_per_step', type=int, default=1, metavar='N',
                    help='model updates per simulator step (default: 1)')
parser.add_argument('--start_steps', type=int, default=10000, metavar='N',
                    help='Steps sampling random actions (default: 10000)')
parser.add_argument('--target_update_interval', type=int, default=1, metavar='N',
                    help='Value target update per no. of updates per step (default: 1)')
parser.add_argument('--replay_size', type=int, default=200000, metavar='N',
                    help='size of replay buffer (default: 200000)')
# NOTE: store_false — args.cuda defaults to True and passing --cuda sets it to False.
parser.add_argument('--cuda', action="store_false",
                    help='pass this flag to run WITHOUT CUDA (CUDA is used by default)')
parser.add_argument('--num_modes', type=int, default=10, metavar='N',
                    help='number of modes (default: 10)')
parser.add_argument('--target_size', type=int, default=1000000, metavar='G',
                    help='target number of transitions for collection (default: 1000000)')
args = parser.parse_args()

# Environment
# env = NormalizedActions(gym.make(args.env_name))
# Environments shipped in gym_mm are instantiated directly; any other name is
# resolved through the gym registry.
local_env_classes = {
    "AntAngle-v2": AntAngle,
    "CheetahJump-v2": CheetahJump,
    "AntCustom-v2": AntCustomEnv,
}
if args.env_name in local_env_classes:
    env = local_env_classes[args.env_name]()
else:
    env = gym.make(args.env_name)

# Seed the environment, its action space, and the global torch / numpy RNGs.
env.seed(args.seed)
env.action_space.seed(args.seed)
for seed_fn in (torch.manual_seed, np.random.seed):
    seed_fn(args.seed)

# Locate the trained models and their reward matrix. Paths are built with
# os.path.join so UDG_DATA_PATH works with or without a trailing slash.
path_prefix = os.environ['UDG_DATA_PATH']
model_prefix = os.path.join(path_prefix, "url-data", "models", "")
model_type = "wasserstein"
model_time = "2023-04-29_12-37-47"
model_step = "final"
# modes = [x for x in range(args.num_modes)]
datasets_path = os.path.join(path_prefix, "url-data", "data", model_type, args.env_name, model_time, "")
# reward_matrix is 2-D: one row per mode, one column per angle bin.
reward_matrix = np.load(datasets_path + "reward_matrix.np", allow_pickle=True)
num_modes, num_angles = reward_matrix.shape
# Width of one angle bin in degrees. requested_angle is rounded to the nearest
# bin and wrapped with % so angles close to 360 map back to bin 0.
min_interval = 360 // num_angles
requested_angle = 300
requested_index = ((requested_angle + min_interval // 2) // min_interval) % num_angles
# k = 2
# modes = np.argpartition(reward_matrix[:,requested_index], -k)[-k:].tolist()
# Sample from the highest-scoring and the lowest-scoring mode for the requested angle.
angle_scores = reward_matrix[:, requested_index]
modes = [int(np.argmax(angle_scores)), int(np.argmin(angle_scores))]
# Fail fast, before any model is loaded or the output file is truncated, if the
# reward matrix selects a mode that has no trainer loaded below.
unloaded_modes = [m for m in modes if m >= args.num_modes]
if unloaded_modes:
    raise ValueError(
        "reward_matrix has {} modes but only {} trainers are loaded (--num_modes); "
        "selected mode(s) {} have no trainer".format(num_modes, args.num_modes, unloaded_modes))
compound_prefix = os.path.join(model_prefix, model_type, args.env_name, model_time, model_step)
# Agent
trainers = []
for i in range(args.num_modes):
    trainers.append(SACTrainer(env.observation_space.shape[0], env.action_space, args))
    trainers[i].load_model('{}/sac_actor_{}_{}'.format(compound_prefix, args.env_name, i), '{}/sac_critic_{}_{}'.format(compound_prefix, args.env_name, i))
    # args is mutated in place, so trainer i is constructed with
    # hidden_dim = base + i (presumably mirroring how the saved models were
    # trained — confirm). After the loop args.hidden_dim == base + args.num_modes.
    args.hidden_dim += 1

# Training Loop
total_numsteps = 0
episodes_num = 0
episode_rewards = []

# NOTE(review): unlike `datasets_path`, this output directory is not prefixed
# with UDG_DATA_PATH, so data is written relative to the current working
# directory — confirm this is intended.
data_path = "url-data/data/{}/{}/{}/{}_mixed_{}_contra/".format(model_type, args.env_name, model_time, model_step, requested_angle)
os.makedirs(data_path, exist_ok=True)

# Truncated obs for ant: only the first 29 entries of an AntCustom observation are stored.
ANT_OBS_DIM = 29
truncate_obs = args.env_name == "AntCustom-v2"
obs_shape = (ANT_OBS_DIM,) if truncate_obs else tuple(env.observation_space.shape)

# Extra per-step scalars copied from `info`; each gets its own float dataset.
info_keys = {
    "CheetahJump-v2": ("z_position", "x_velocity", "additional_r"),
    "AntAngle-v2": ("x_velocity", "y_velocity", "additional_r"),
}.get(args.env_name, ())

# The transition budget is split evenly across `modes`. The mode only changes at
# an episode boundary. max(1, ...) avoids a ZeroDivisionError when
# target_size < len(modes).
steps_per_mode = max(1, args.target_size // len(modes))

try:
    with h5py.File(data_path + "experience.h5", "w") as f, \
            open(data_path + "log.txt", "w") as text_file:
        observations = f.create_dataset("observations", (args.target_size,) + obs_shape, 'f')
        actions = f.create_dataset("actions", (args.target_size,) + env.action_space.shape, 'f')
        rewards = f.create_dataset("rewards", (args.target_size,), 'f')
        terminals = f.create_dataset("terminals", (args.target_size,), 'b')
        extras = {key: f.create_dataset(key, (args.target_size,), 'f') for key in info_keys}

        terminate = False
        mode_index = 0
        while not terminate:
            episode_reward = 0
            episode_steps = 0
            obs = env.reset()

            # Clamp to the last mode. When target_size is not a multiple of
            # len(modes), an episode can start after len(modes) * steps_per_mode
            # steps, which would otherwise index past the end of `modes`.
            wanted_mode = min(total_numsteps // steps_per_mode, len(modes) - 1)
            if wanted_mode > mode_index:
                mode_index = wanted_mode
                print("Mode changes to {}.".format(mode_index))
            trainer = trainers[modes[mode_index]]

            while True:
                action, _ = trainer.act(obs, eval=True)  # Sample action from policy
                new_obs, reward, done, info = env.step(action)  # Step

                observations[total_numsteps] = obs[:ANT_OBS_DIM] if truncate_obs else obs
                actions[total_numsteps] = action
                rewards[total_numsteps] = reward
                terminals[total_numsteps] = done
                for key, dataset in extras.items():
                    dataset[total_numsteps] = info[key]

                obs = new_obs
                episode_steps += 1
                total_numsteps += 1
                episode_reward += reward

                # Stop at the budget even mid-episode; the partial episode still counts.
                terminate = total_numsteps >= args.target_size
                if done or terminate:
                    episodes_num += 1
                    episode_rewards.append(episode_reward)
                    print("Episode: {}, total numsteps: {}, episode steps: {}, reward: {}".format(episodes_num, total_numsteps, episode_steps, round(episode_reward, 2)))
                    break

        summary = "Avg return: {}, max: {}, min: {}, std: {}. Total episodes: {}. Total samples: {}".format(
            round(np.mean(episode_rewards), 2), round(np.max(episode_rewards), 2),
            round(np.min(episode_rewards), 2), round(np.std(episode_rewards), 2),
            episodes_num, total_numsteps)
        print("Data collection finished. " + summary)
        text_file.write(summary + "\n")
finally:
    # Release the environment even if collection fails part-way.
    env.close()