import argparse
import h5py
import gym
import numpy as np
import torch
import os
# from algo.utils.sac import SACTrainer
from algo.utils.sac2 import SACTrainer2 as SACTrainer
from gym_mm.envs.ant_angle import AntAngle
from gym_mm.envs.cheetah_jump import CheetahJump

def _str2bool(value):
    """Convert a command-line string such as "true"/"false" to a bool.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including "False"), so boolean options need an explicit
    converter. Non-string values (e.g. the ``default=True`` literal) pass
    through unchanged.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


# Command-line configuration. The SAC-related options are forwarded as `args`
# to every SACTrainer; the collection options (num_modes, target_size) drive
# the rollout loop of this script.
parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
parser.add_argument('--env-name', default="AntAngle-v2",
                    help='Mujoco Gym environment (default: AntAngle-v2)')
parser.add_argument('--policy', default="Gaussian",
                    help='Policy Type: Gaussian | Deterministic (default: Gaussian)')
parser.add_argument('--eval', type=_str2bool, default=True,
                    help='Evaluates a policy every 10 episodes (default: True)')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor for reward (default: 0.99)')
parser.add_argument('--tau', type=float, default=0.005, metavar='G',
                    help='target smoothing coefficient(τ) (default: 0.005)')
parser.add_argument('--critic_lr', type=float, default=0.0003, metavar='G',
                    help='critic learning rate (default: 0.0003)')
parser.add_argument('--policy_lr', type=float, default=0.0003, metavar='G',
                    help='policy learning rate (default: 0.0003)')
parser.add_argument('--disc_lr', type=float, default=0.002, metavar='G',
                    help='discriminator learning rate (default: 0.002)')
parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
                    help='Temperature parameter α determines the relative importance of the entropy '
                         'term against the reward (default: 0.2)')
parser.add_argument('--automatic_entropy_tuning', type=_str2bool, default=False, metavar='G',
                    help='Automatically adjust α (default: False)')
parser.add_argument('--seed', type=int, default=100, metavar='N',
                    help='random seed (default: 100)')
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
                    help='batch size (default: 256)')
parser.add_argument('--max_episode_len', type=int, default=1000, metavar='N',
                    help='maximum episode length (default: 1000)')
parser.add_argument('--num_steps', type=int, default=1000000, metavar='N',
                    help='maximum number of steps (default: 1000000)')
parser.add_argument('--hidden_dim', type=int, default=256, metavar='N',
                    help='hidden size (default: 256)')
parser.add_argument('--updates_per_step', type=int, default=1, metavar='N',
                    help='model updates per simulator step (default: 1)')
parser.add_argument('--start_steps', type=int, default=10000, metavar='N',
                    help='Steps sampling random actions (default: 10000)')
parser.add_argument('--target_update_interval', type=int, default=1, metavar='N',
                    help='Value target update per no. of updates per step (default: 1)')
parser.add_argument('--replay_size', type=int, default=200000, metavar='N',
                    help='size of replay buffer (default: 200000)')
# store_false: args.cuda is True unless this flag is given.
parser.add_argument('--cuda', action="store_false",
                    help='disable CUDA when passed (CUDA is used by default)')
parser.add_argument('--num_modes', type=int, default=5, metavar='N',
                    help='number of modes (default: 5)')
parser.add_argument('--target_size', type=int, default=1000000, metavar='N',
                    help='total number of transitions to collect across all modes (default: 1000000)')
args = parser.parse_args()

# Environment: the two custom environments are instantiated directly; any
# other name is resolved through the gym registry.
_CUSTOM_ENVS = {
    "AntAngle-v2": AntAngle,
    "CheetahJump-v2": CheetahJump,
}
_env_factory = _CUSTOM_ENVS.get(args.env_name)
env = _env_factory() if _env_factory is not None else gym.make(args.env_name)
env.seed(args.seed)
env.action_space.seed(args.seed)

# Seed the global RNGs used by numpy and torch.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Root directory for models and collected data; raises KeyError if unset.
path_prefix = os.environ['UDG_DATA_PATH']
model_prefix = path_prefix + "url-data/models/"
model_type = "diayn2"
model_time = "2023-04-18_11-51-49"
model_step = "step1500"
compound_prefix = '/'.join([model_prefix + model_type, args.env_name, model_time, model_step])

# Agent: one SAC trainer per mode, each restored from its own actor/critic checkpoint.
trainers = []
obs_dim = env.observation_space.shape[0]
for mode_idx in range(args.num_modes):
    trainer = SACTrainer(obs_dim, env.action_space, args)
    actor_ckpt = '{}/sac_actor_{}_{}'.format(compound_prefix, args.env_name, mode_idx)
    critic_ckpt = '{}/sac_critic_{}_{}'.format(compound_prefix, args.env_name, mode_idx)
    trainers.append(trainer)
    trainer.load_model(actor_ckpt, critic_ckpt)
    # Trainer k is built with hidden_dim = base + k.
    # NOTE(review): presumably this matches how the per-mode checkpoints were
    # trained — confirm. It also leaves args.hidden_dim permanently bumped.
    args.hidden_dim += 1

# Data-collection bookkeeping (counters updated by the rollout loop).
total_numsteps = 0
episodes_num = 0
episode_rewards = []

data_path = "{}url-data/data/{}/{}/{}/{}/".format(path_prefix, model_type, args.env_name, model_time, model_step)
if not os.path.exists(data_path):
    os.makedirs(data_path)

obs_shape = env.observation_space.shape
act_shape = env.action_space.shape

# Combined file: every transition from every mode, args.target_size rows.
f = h5py.File(data_path + "experience.h5", "w")
observations = f.create_dataset("observations", (args.target_size,) + obs_shape, 'f')
actions = f.create_dataset("actions", (args.target_size,) + act_shape, 'f')
rewards = f.create_dataset("rewards", (args.target_size,), 'f')
terminals = f.create_dataset("terminals", (args.target_size,), 'b')
# Environment-specific diagnostics copied from the step `info` dict.
if args.env_name == "CheetahJump-v2":
    z_position, x_velocity, additional_r = [
        f.create_dataset(key, (args.target_size,), 'f')
        for key in ("z_position", "x_velocity", "additional_r")
    ]
elif args.env_name == "AntAngle-v2":
    x_velocity, y_velocity, additional_r = [
        f.create_dataset(key, (args.target_size,), 'f')
        for key in ("x_velocity", "y_velocity", "additional_r")
    ]

text_file = open(data_path + "log.txt", "w")
# Each mode gets an equal share of the budget (integer division).
subset_target_size = args.target_size // args.num_modes
steps = [0] * args.num_modes
avg_reward = [[] for _ in range(args.num_modes)]

# Per-mode files: o/a/r/t (and the env-specific lists) hold one dataset per mode.
o, a, r, t = [], [], [], []
if args.env_name == "CheetahJump-v2":
    z, x, a_r = [], [], []
elif args.env_name == "AntAngle-v2":
    x, y, a_r = [], [], []
files = []
for mode_idx in range(args.num_modes):
    mode_file = h5py.File(data_path + "experience_{}.h5".format(mode_idx), "w")
    files.append(mode_file)
    o.append(mode_file.create_dataset("observations", (subset_target_size,) + obs_shape, 'f'))
    a.append(mode_file.create_dataset("actions", (subset_target_size,) + act_shape, 'f'))
    r.append(mode_file.create_dataset("rewards", (subset_target_size,), 'f'))
    t.append(mode_file.create_dataset("terminals", (subset_target_size,), 'b'))
    if args.env_name == "CheetahJump-v2":
        extra_series = (("z_position", z), ("x_velocity", x), ("additional_r", a_r))
    elif args.env_name == "AntAngle-v2":
        extra_series = (("x_velocity", x), ("y_velocity", y), ("additional_r", a_r))
    else:
        extra_series = ()
    for key, series in extra_series:
        series.append(mode_file.create_dataset(key, (subset_target_size,), 'f'))
# Roll out each mode's policy in turn until it has produced
# `subset_target_size` transitions. Every transition is written both to the
# combined datasets (row total_numsteps) and to the current mode's datasets
# (row steps[mode]).
# NOTE(review): no per-episode step cap is applied here; an episode ends only
# on `done`, when the current mode's quota is filled, or when collection stops.
# If target_size is not divisible by num_modes, the trailing
# target_size - num_modes * subset_target_size rows of the combined datasets
# are never written.
mode = next_mode = 0
terminate = False
while not terminate:
    episode_reward = 0
    episode_steps = 0

    obs = env.reset()
    mode = next_mode
    # `mode` indexes `trainers` (one per mode), so it must stay below num_modes.
    assert mode < args.num_modes
    episode_end = False
    while not episode_end:
        action, _ = trainers[mode].act(obs, eval=True)  # Sample action from policy

        new_obs, reward, done, info = env.step(action)  # Step

        # Combined dataset across all modes.
        observations[total_numsteps] = obs
        actions[total_numsteps] = action
        rewards[total_numsteps] = reward
        terminals[total_numsteps] = done
        # Current mode's own dataset.
        row = steps[mode]
        o[mode][row] = obs
        a[mode][row] = action
        r[mode][row] = reward
        t[mode][row] = done
        if args.env_name == "CheetahJump-v2":
            z[mode][row] = info['z_position']
            x[mode][row] = info['x_velocity']
            a_r[mode][row] = info['additional_r']
        elif args.env_name == "AntAngle-v2":
            x[mode][row] = info['x_velocity']
            y[mode][row] = info['y_velocity']
            a_r[mode][row] = info['additional_r']

        obs = new_obs

        episode_steps += 1
        total_numsteps += 1
        steps[mode] += 1
        episode_reward += reward

        mode_complete = steps[mode] >= subset_target_size
        if mode_complete:
            next_mode += 1
            print("Mode {} experience collection complete.".format(mode + 1))

        # Stop once the global budget is spent or every mode has filled its
        # quota. The second test matters when target_size % num_modes != 0:
        # the budget is then never reached, and another reset would index
        # past the end of `trainers`.
        if total_numsteps >= args.target_size or next_mode >= args.num_modes:
            terminate = True

        # Several end conditions can coincide on one step (always on the final
        # step); count the episode exactly once.
        episode_end = done or mode_complete or terminate
        if episode_end:
            episodes_num += 1
            episode_rewards.append(episode_reward)
            avg_reward[mode].append(episode_reward)
            print("Episode: {}, total numsteps: {}, episode steps: {}, reward: {}".format(episodes_num, total_numsteps, episode_steps, round(episode_reward, 2)))

# Flush and release every HDF5 output before reporting.
f.close()
for mode_file in files:
    mode_file.close()

# Overall return statistics over all recorded episodes.
summary = "Avg return: {}, max: {}, min: {}, std: {}. Total episodes: {}. Total samples: {}".format(
    round(np.mean(episode_rewards), 2),
    round(np.max(episode_rewards), 2),
    round(np.min(episode_rewards), 2),
    round(np.std(episode_rewards), 2),
    episodes_num,
    total_numsteps,
)
print("Data collection finished. " + summary)
text_file.write(summary + "\n")

# Per-mode mean episode return, echoed to stdout and the log file.
for mode_idx, mode_returns in enumerate(avg_reward):
    mode_avg = round(np.mean(mode_returns), 2)
    print("Mode {} Avg return: {}".format(mode_idx, mode_avg))
    text_file.write("Mode {} Avg return: {}\n".format(mode_idx, mode_avg))

env.close()
text_file.close()