import argparse
import datetime
import gym
import numpy as np
import itertools
import torch
import os
import h5py
# from algo.utils.sac import SACTrainer
from algo.utils.sac2 import SACTrainer2 as SACTrainer
from algo.utils.disc import DiscTrainer
from torch.utils.tensorboard import SummaryWriter
from algo.utils.buffer import ReplayMemory
from algo.utils.wsre import wsre
from utils import eval

def _str2bool(value):
    """Convert a command-line token to a bool.

    ``type=bool`` must not be used with argparse: it calls ``bool(token)`` and
    every non-empty string -- including "False" and "0" -- is truthy, so such a
    flag could never be switched off from the command line.

    Args:
        value: the raw command-line token (or an already-bool default).

    Returns:
        True for 1/true/t/yes/y/on, False for 0/false/f/no/n/off
        (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: for any other token, so argparse reports a
            clean usage error instead of silently picking a value.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if token in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean (true/false), got {!r}".format(value))


parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
parser.add_argument('--env-name', default="HalfCheetah-v2",
                    help='Mujoco Gym environment (default: HalfCheetah-v2)')
parser.add_argument('--policy', default="Gaussian",
                    help='Policy Type: Gaussian | Deterministic (default: Gaussian)')
parser.add_argument('--eval', type=_str2bool, default=True, metavar='BOOL',
                    help='Evaluate the policies every eval_interval * num_modes steps (default: True)')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor for reward (default: 0.99)')
parser.add_argument('--tau', type=float, default=0.005, metavar='G',
                    help='target smoothing coefficient(τ) (default: 0.005)')
parser.add_argument('--critic_lr', type=float, default=0.0003, metavar='G',
                    help='critic learning rate (default: 0.0003)')
parser.add_argument('--policy_lr', type=float, default=0.0003, metavar='G',
                    help='policy learning rate (default: 0.0003)')
parser.add_argument('--disc_lr', type=float, default=0.002, metavar='G',
                    help='discriminator learning rate (default: 0.002)')
parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
                    help='Temperature parameter α determines the relative importance of the entropy '
                         'term against the reward (default: 0.2)')
parser.add_argument('--automatic_entropy_tuning', type=_str2bool, default=False, metavar='BOOL',
                    help='Automatically adjust α (default: False)')
parser.add_argument('--seed', type=int, default=100, metavar='N',
                    help='random seed (default: 100)')
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
                    help='batch size (default: 256)')
parser.add_argument('--max_episode_len', type=int, default=1000, metavar='N',
                    help='maximum episode length; also the number of states drawn from each other '
                         "mode's memory for the Wasserstein pseudo reward (default: 1000)")
parser.add_argument('--num_steps', type=int, default=100000, metavar='N',
                    help='environment steps per mode; training runs for num_steps * num_modes '
                         'steps in total (default: 100000)')
parser.add_argument('--hidden_dim', type=int, default=256, metavar='N',
                    help='hidden size (default: 256)')
parser.add_argument('--updates_per_step', type=int, default=1, metavar='N',
                    help='model updates per simulator step (default: 1)')
parser.add_argument('--start_steps', type=int, default=10000, metavar='N',
                    help='number of transitions per mode collected with random actions before '
                         'the policy is used (default: 10000)')
parser.add_argument('--target_update_interval', type=int, default=1, metavar='N',
                    help='Value target update per no. of updates per step (default: 1)')
parser.add_argument('--replay_size', type=int, default=200000, metavar='N',
                    help='size of each replay buffer (default: 200000)')
# store_false: args.cuda defaults to True, and passing --cuda sets it to False.
parser.add_argument('--cuda', action="store_false",
                    help='disable CUDA (args.cuda is True unless this flag is given)')
parser.add_argument('--num_modes', type=int, default=5, metavar='N',
                    help='number of modes, i.e. per-mode SAC policies (default: 5)')
parser.add_argument('--save_interval', type=int, default=10000, metavar='N',
                    help='save models every save_interval * num_modes steps (default: 10000)')
parser.add_argument('--eval_interval', type=int, default=10000, metavar='N',
                    help='evaluate models every eval_interval * num_modes steps (default: 10000)')
parser.add_argument('--rc', type=float, default=1.0, metavar='G',
                    help='scale for environment rewards (default: 1.0)')
parser.add_argument('--src', type=float, default=0.1, metavar='G',
                    help='scale for Wasserstein pseudo rewards (default: 0.1)')
args = parser.parse_args()

# Environment
# Training uses `env`; a second, identically built and seeded instance
# (`test_env`) is the one handed to eval(), so evaluation never steps the
# training environment.
def _make_seeded_env(env_name, seed):
    """Create ``env_name`` and seed both its dynamics and its action-space sampler."""
    new_env = gym.make(env_name)
    new_env.seed(seed)
    new_env.action_space.seed(seed)
    return new_env


env = _make_seeded_env(args.env_name, args.seed)
test_env = _make_seeded_env(args.env_name, args.seed)

# Global torch and numpy RNGs share the same seed.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Agent: one SAC trainer per mode, plus one discriminator shared by all modes.
# Each successive trainer is built with a hidden_dim one larger than the
# previous one (base, base+1, ..., base+num_modes-1).
# NOTE(review): args is mutated in place, so after this block
# args.hidden_dim == base + num_modes. That inflated value is what DiscTrainer
# (and anything reading args later, e.g. eval) sees -- confirm this is intended.
_obs_dim = env.observation_space.shape[0]
trainers = []
while len(trainers) < args.num_modes:
    trainers.append(SACTrainer(_obs_dim, env.action_space, args))
    args.hidden_dim += 1
disc_trainer = DiscTrainer(_obs_dim, args)

# Root directory for every artefact of this run: TensorBoard logs, recorded
# experience and checkpoints. Later paths are formed by plain string
# concatenation (path_prefix + "url-data/..."), so the prefix is normalised to
# end with a path separator. Otherwise a value such as "/data" would silently
# produce "/dataurl-data/...". A value that already ends in "/" is unchanged.
_udg_data_root = os.environ.get('UDG_DATA_PATH')
if _udg_data_root is None:
    raise SystemExit("UDG_DATA_PATH is not set; point it at the directory that should contain url-data/")
path_prefix = os.path.join(_udg_data_root, '')

# Tensorboard
# start_T doubles as the run id in the logs/, data/ and models/ paths.
start_T = datetime.datetime.now()
writer = SummaryWriter(path_prefix+'url-data/logs/{}/{}_wasserstein_{}_{}'.format(args.env_name, start_T.strftime("%Y-%m-%d_%H-%M-%S"),
                                                             args.policy, "autotune" if args.automatic_entropy_tuning else ""))

# Logging hyperparameters
# writer.add_hparams(hparam_dict=vars(args), metric_dict={})

# Memory: one replay buffer per mode for SAC, plus a shared buffer of
# (mode label, visited state) pairs used to train the discriminator.
memories = [ReplayMemory(args.replay_size) for _ in range(args.num_modes)]
disc_memory = ReplayMemory(args.replay_size)

# Training loop counters
total_numsteps = 0  # environment steps taken across all modes
updates = 0  # gradient update steps performed across all trainers

# Open dataset file to store replay memory. Every environment transition
# (observation, action, raw env reward, done flag) is written at row
# total_numsteps - 1.
data_path = path_prefix+"url-data/data/wasserstein/{}/{}/mixed_{}/".format(args.env_name, start_T.strftime("%Y-%m-%d_%H-%M-%S"), args.num_modes)
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(data_path, exist_ok=True)
f = h5py.File(data_path+"experience.h5", "w")
# Training stops after exactly num_steps * num_modes environment steps, which
# fixes the number of rows.
target_size = args.num_steps * args.num_modes
f_observations = f.create_dataset("observations", (target_size,)+env.observation_space.shape, 'f')
f_actions = f.create_dataset("actions", (target_size,)+env.action_space.shape, 'f')
f_rewards = f.create_dataset("rewards", (target_size,), 'f')
f_terminals = f.create_dataset("terminals", (target_size,), 'b')

# Main training loop. Episodes cycle through the modes (label = i_episode %
# num_modes). An episode's transitions are buffered locally. Only after the
# episode ends are their rewards relabelled with a Wasserstein-based pseudo
# reward and pushed into that mode's replay memory. The loop stops once
# num_steps * num_modes environment steps have been taken in total.
terminate = False  # set when the global step budget is exhausted
for i_episode in itertools.count(1):
    episode_reward = 0  # raw environment return of this episode
    episode_sr = 0  # summed pseudo reward (recomputed after the episode)
    episode_r = 0  # summed composite reward stored in memory (recomputed after the episode)
    episode_steps = 0
    # Per-episode transition buffers. The composite reward depends on the whole
    # trajectory, so transitions are held here until the episode finishes.
    _state = []
    _action = []
    _reward = []
    _next_state = []
    _mask = []
    done = False
    state = env.reset()
    # Mode (policy index) for this episode. i_episode starts at 1, so the
    # first episode uses mode 1 % num_modes.
    label = i_episode % args.num_modes
    while not done:
        # Warm-up: act uniformly at random until this mode's memory holds
        # start_steps transitions.
        if args.start_steps > len(memories[label]):
            action = env.action_space.sample()  # Sample random action
        else:
            action, _ = trainers[label].act(state)  # Sample action from policy

        # Only the current mode's trainer (plus the shared discriminator) is
        # updated, and only once this mode's memory holds more than one batch.
        if len(memories[label]) > args.batch_size:
            # Number of updates per step in environment
            for i in range(args.updates_per_step):
                # Update this mode's SAC networks and the shared discriminator
                
                samples = memories[label].sample(args.batch_size)
                critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = trainers[label].update_parameters(samples, updates)
                disc_samples = disc_memory.sample(args.batch_size)
                d_loss = disc_trainer.update_parameters(disc_samples)

                # Per-update loss logging is currently disabled, so the losses
                # above are not used.
                # writer.add_scalar('loss/critic_1', critic_1_loss, updates)
                # writer.add_scalar('loss/critic_2', critic_2_loss, updates)
                # writer.add_scalar('loss/policy', policy_loss, updates)
                # writer.add_scalar('loss/entropy_loss', ent_loss, updates)
                # writer.add_scalar('entropy_temprature/alpha', alpha, updates)
                # writer.add_scalar('loss/disc_loss', d_loss, updates)

                updates += 1

        next_state, reward, done, _ = env.step(action) # Step
        episode_steps += 1
        total_numsteps += 1
        episode_reward += reward

        # Disabled per-step pseudo reward. The pseudo reward is now computed
        # once per episode, after this while loop.
        # sr = max(disc_trainer.score(next_state, np.array([label])), SR_MIN)
        # r = sr * args.src + reward * args.rc
        # episode_sr += sr
        # episode_r += r # reward for environment reward, r for composite reward

        # Ignore the "done" signal if it comes from hitting the time horizon.
        # (https://github.com/openai/spinningup/blob/master/spinup/algos/sac/sac.py)
        # mask is 1 for time-limit endings and non-terminal steps, 0 for true terminals.
        mask = 1 if episode_steps == env._max_episode_steps else float(not done)

        # Disabled per-step push; transitions are pushed after relabelling below.
        # memories[label].push((state, action, np.array([1.0]), r, next_state, mask)) # Append transition to memory
        # Discriminator data: (mode label, next state) for every step.
        disc_memory.push((np.array([label]), next_state))
        _state.append(state)
        _action.append(action)
        _reward.append(reward)
        _next_state.append(next_state)
        _mask.append(mask)

        # Store replay memory to file. Stores the raw environment reward (not
        # the composite reward) and the unmasked done flag.
        f_observations[total_numsteps-1] = state
        f_actions[total_numsteps-1] = action
        f_rewards[total_numsteps-1] = reward
        f_terminals[total_numsteps-1] = done

        state = next_state

        # Evaluate every eval_interval * num_modes steps; total_numsteps // 1000
        # is passed as the step index.
        if total_numsteps % (args.eval_interval * args.num_modes) == 0 and args.eval is True:
            eval(args, test_env, trainers, disc_trainer, memories, writer, total_numsteps//1000, algo="wasserstein")

        # Save models every save_interval * num_modes steps, and also at step 1
        # (the initial weights), under step{total_numsteps // 1000}/.
        if total_numsteps % (args.save_interval * args.num_modes) == 0 or total_numsteps == 1:
            prefix = path_prefix+"url-data/models/wasserstein/{}/{}/step{}/".format(args.env_name, start_T.strftime("%Y-%m-%d_%H-%M-%S"), total_numsteps//1000)
            for i in range(args.num_modes):
                trainers[i].save_model(env_name=args.env_name, prefix=prefix, suffix="{}".format(i))
                print("Saving model {} to /models...".format(i))

        # Global step budget reached: stop mid-episode. The partial episode is
        # still relabelled and stored below before the outer loop exits.
        if total_numsteps >= args.num_steps * args.num_modes:
            terminate = True
            break

    assert len(_state) == episode_steps
    labels = np.array([label for _ in range(episode_steps)])
    # Discriminator score of each visited next state under this episode's label.
    # NOTE(review): the exp() suggests score is a log-probability, making
    # avg_acc the mean probability assigned to the correct mode -- confirm
    # against DiscTrainer.score.
    score = disc_trainer.score(np.stack(_next_state, axis=0), labels)
    avg_acc = np.mean(np.exp(score)).tolist()
    # Pseudo reward per step (srs). With a single mode it is zero.
    if args.num_modes == 1:
        srs = np.array([0 for _ in range(episode_steps)])
    else:
        # For every other mode, compute wsre between this episode's next states
        # and max_episode_len entries dumped from that mode's memory. Keep the
        # per-step distances of the mode with the smallest total distance.
        # min_dist_idx indexes _dist (other modes only), not a mode id.
        # NOTE(review): a mode whose memory is not yet large enough contributes
        # an all-zero distance. If wsre distances are non-negative, that mode
        # wins the argmin and the pseudo reward is zero for this episode.
        state_batch = np.stack(_next_state, axis=0)
        _dist = []
        sum_dist = []
        for i in range(args.num_modes):
            if i!=label:
                dist = np.zeros(episode_steps)
                if len(memories[i]) > args.max_episode_len: # Ensure we have a non-empty target batch
                    # Presumably element [0] of the dump holds the states -- confirm against ReplayMemory.dump.
                    target_state_batch = list(memories[i].dump(args.max_episode_len))[0]
                    dist = wsre(state_batch, target_state_batch)
                sum_dist.append(np.sum(dist))
                _dist.append(dist)
        min_dist_idx = np.argmin(sum_dist)
        srs = _dist[min_dist_idx]

    rewards = np.stack(_reward, axis=0)
    # rewards are not modified for the first policy
    # (mode 0 stores only rc * env reward; every other mode adds src * pseudo
    # reward). episode_sr is still logged for mode 0 even though it is unused there.
    if label == 0:
        rs = rewards * args.rc
    else:
        rs = srs * args.src + rewards * args.rc
    episode_sr = sum(srs.tolist())
    episode_r = sum(rs.tolist())

    # Push the relabelled transitions into this mode's replay memory.
    # NOTE(review): the meaning of the constant np.array([1.0]) field is defined
    # by ReplayMemory / SACTrainer2 and is not visible here -- confirm.
    for i in range(episode_steps):
        memories[label].push((_state[i], _action[i], np.array([1.0]), rs[i], _next_state[i], _mask[i]))

    writer.add_scalar('train/reward', episode_reward, i_episode)
    writer.add_scalar('train/sr', episode_sr, i_episode)
    writer.add_scalar('train/r', episode_r, i_episode)
    writer.add_scalar('train/acc', avg_acc, i_episode)
    print("Episode: {}, total numsteps: {}, episode steps: {}, reward: {}, sr: {}, r: {}, acc: {}".format(i_episode, total_numsteps, episode_steps, round(episode_reward, 2), round(episode_sr, 2), round(episode_r, 2), round(avg_acc, 2)))

    if terminate:
        break

# Flush and close the HDF5 experience file.
f.close()
print("Replay memory saved to file: " + data_path)

# Final checkpoint of every mode's trainer, next to the periodic step{N}/ ones.
_run_stamp = start_T.strftime("%Y-%m-%d_%H-%M-%S")
prefix = path_prefix + "url-data/models/wasserstein/{}/{}/final/".format(args.env_name, _run_stamp)
for mode_idx, trainer in enumerate(trainers):
    trainer.save_model(env_name=args.env_name, prefix=prefix, suffix="{}".format(mode_idx))
    print("Saving model {} to /models...".format(mode_idx))

# Release the training and evaluation environments, in that order.
for _gym_env in (env, test_env):
    _gym_env.close()

