import argparse
import datetime
import gym
import numpy as np
import itertools
import torch
import os
# from algo.utils.sac import SACTrainer
from algo.utils.sac2 import SACTrainer2 as SACTrainer
from algo.utils.disc import DiscTrainer
from torch.utils.tensorboard import SummaryWriter
from algo.utils.buffer import ReplayMemory
from algo.utils.wsre import wsre
from utils import eval

def _str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` converts every non-empty string -- including
    ``"False"`` and ``"0"`` -- to ``True``, so boolean options must be parsed
    explicitly.

    Args:
        value: Raw option string from the command line (or an already-bool default).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
parser.add_argument('--env-name', default="HalfCheetah-v2",
                    help='Mujoco Gym environment (default: HalfCheetah-v2)')
parser.add_argument('--policy', default="Gaussian",
                    help='Policy Type: Gaussian | Deterministic (default: Gaussian)')
parser.add_argument('--eval', type=_str2bool, default=True,
                    help='Evaluate the policies every eval_interval * num_modes steps (default: True)')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor for reward (default: 0.99)')
parser.add_argument('--tau', type=float, default=0.005, metavar='G',
                    help='target smoothing coefficient(τ) (default: 0.005)')
parser.add_argument('--critic_lr', type=float, default=0.0003, metavar='G',
                    help='critic learning rate (default: 0.0003)')
parser.add_argument('--policy_lr', type=float, default=0.0003, metavar='G',
                    help='policy learning rate (default: 0.0003)')
parser.add_argument('--disc_lr', type=float, default=0.002, metavar='G',
                    help='discriminator learning rate (default: 0.002)')
parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
                    help='Temperature parameter α determines the relative importance of the entropy '
                         'term against the reward (default: 0.2)')
parser.add_argument('--automatic_entropy_tuning', type=_str2bool, default=False, metavar='G',
                    help='Automatically adjust α (default: False)')
parser.add_argument('--seed', type=int, default=100, metavar='N',
                    help='random seed (default: 100)')
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
                    help='batch size (default: 256)')
parser.add_argument('--max_episode_len', type=int, default=1000, metavar='N',
                    help='maximum episode length (default: 1000)')
# The training loop stops once total steps exceed num_steps * num_modes,
# i.e. this is a per-mode step budget.
parser.add_argument('--num_steps', type=int, default=1000001, metavar='N',
                    help='maximum number of steps per mode (default: 1000001)')
parser.add_argument('--hidden_dim', type=int, default=256, metavar='N',
                    help='hidden size (default: 256)')
parser.add_argument('--updates_per_step', type=int, default=1, metavar='N',
                    help='model updates per simulator step (default: 1)')
parser.add_argument('--start_steps', type=int, default=10000, metavar='N',
                    help='Steps sampling random actions (default: 10000)')
parser.add_argument('--target_update_interval', type=int, default=1, metavar='N',
                    help='Value target update per no. of updates per step (default: 1)')
parser.add_argument('--replay_size', type=int, default=200000, metavar='N',
                    help='size of replay buffer (default: 200000)')
# store_false: args.cuda is True unless the flag is given, so passing --cuda
# turns CUDA *off*. The flag name is kept for backward compatibility.
parser.add_argument('--cuda', action="store_false",
                    help='pass this flag to DISABLE CUDA (CUDA is used by default)')
parser.add_argument('--num_modes', type=int, default=5, metavar='N',
                    help='number of modes (default: 5)')
parser.add_argument('--save_interval', type=int, default=10000, metavar='N',
                    help='interval for saving models (default: 10000)')
parser.add_argument('--eval_interval', type=int, default=10000, metavar='N',
                    help='interval for evaluating models (default: 10000)')
parser.add_argument('--rc', type=float, default=1.0, metavar='G',
                    help='scale for rewards (default: 1.0)')
parser.add_argument('--src', type=float, default=0.0, metavar='G',
                    help='scale for pseudo rewards (default: 0.0)')
args = parser.parse_args()

def _make_seeded_env(env_id, seed):
    """Create a Gym environment and seed it.

    Calls ``seed`` on the environment itself and on its action space, so both
    ``reset``/``step`` randomness and ``action_space.sample()`` are reproducible.
    """
    new_env = gym.make(env_id)
    new_env.seed(seed)
    new_env.action_space.seed(seed)
    return new_env


# Environments: one for training rollouts and a separate, identically seeded
# instance used for evaluation.
# env = NormalizedActions(gym.make(args.env_name))
env = _make_seeded_env(args.env_name, args.seed)
test_env = _make_seeded_env(args.env_name, args.seed)

# Global RNG seeding (PyTorch, then NumPy).
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Agent
# One SAC learner per mode, plus a single discriminator shared by all modes.
# trainers = [SACTrainer(env.observation_space.shape[0], env.action_space, args) for _ in range(args.num_modes)]
trainers = []
for i in range(args.num_modes):
    # NOTE: args.hidden_dim is bumped by 1 after each construction. Presumably
    # SACTrainer reads it at construction time, giving trainer k a hidden size of
    # (--hidden_dim + k) -- confirm against SACTrainer.
    trainers.append(SACTrainer(env.observation_space.shape[0], env.action_space, args))
    args.hidden_dim += 1
# NOTE(review): the mutation above persists, so DiscTrainer (and anything else
# that later reads args, e.g. eval(), which is passed args) sees
# hidden_dim == --hidden_dim + num_modes rather than the CLI value.
# Confirm this is intended.
disc_trainer = DiscTrainer(env.observation_space.shape[0], args)

# Root directory for run artefacts (TensorBoard logs and model checkpoints).
# All paths below are built by string concatenation (path_prefix + '...'), so
# the prefix is normalised to end with a path separator. Otherwise an
# UDG_DATA_PATH such as "/data" would yield "/dataurl-data/...". An empty value
# stays empty, i.e. paths are relative to the working directory.
# Raises KeyError if UDG_DATA_PATH is not set.
path_prefix = os.path.join(os.environ['UDG_DATA_PATH'], '')

# Tensorboard
start_T = datetime.datetime.now()
writer = SummaryWriter(path_prefix+'url-data/logs/{}/{}_wasserstein_{}_{}'.format(args.env_name, start_T.strftime("%Y-%m-%d_%H-%M-%S"),
                                                             args.policy, "autotune" if args.automatic_entropy_tuning else ""))

# Hyperparameter logging is currently disabled:
# writer.add_hparams(hparam_dict=vars(args), metric_dict={})

# Replay storage: one buffer per mode, sampled by that mode's SAC learner,
# plus one shared buffer sampled by the discriminator. All have the same capacity.
capacity = args.replay_size
memories = [ReplayMemory(capacity) for _mode in range(args.num_modes)]
disc_memory = ReplayMemory(capacity)

# Training-loop counters: environment steps taken and gradient updates performed.
total_numsteps, updates = 0, 0

# Lower clamp value equal to 10 * log(1 / num_modes). It is not used by the
# active code path; it was historically a floor for the discriminator score.
SR_MIN = np.log(1 / args.num_modes) * 10
for i_episode in itertools.count(1):
    # One episode per iteration. Modes are visited round-robin: episode k is
    # collected by, and stored for, mode k % num_modes.
    episode_reward = 0
    episode_steps = 0
    # Per-step records for this episode. Transitions are pushed to
    # memories[label] only after the episode ends, because the stored reward
    # mixes in an episode-level pseudo-reward computed below.
    _state = []
    _action = []
    _reward = []
    _next_state = []
    _mask = []
    done = False
    state = env.reset()
    label = i_episode % args.num_modes
    while not done:
        if args.start_steps > len(memories[label]):
            action = env.action_space.sample()  # Warm-up: sample a random action
        else:
            action, _ = trainers[label].act(state)  # Sample action from this mode's policy

        if len(memories[label]) > args.batch_size:
            # Number of updates per step in environment
            for _ in range(args.updates_per_step):
                # One SAC update for the active mode, then one discriminator update.
                samples = memories[label].sample(args.batch_size)
                critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = trainers[label].update_parameters(samples, updates)
                disc_samples = disc_memory.sample(args.batch_size)
                d_loss = disc_trainer.update_parameters(disc_samples)

                # Per-update loss logging (disabled):
                # writer.add_scalar('loss/critic_1', critic_1_loss, updates)
                # writer.add_scalar('loss/critic_2', critic_2_loss, updates)
                # writer.add_scalar('loss/policy', policy_loss, updates)
                # writer.add_scalar('loss/entropy_loss', ent_loss, updates)
                # writer.add_scalar('entropy_temprature/alpha', alpha, updates)
                # writer.add_scalar('loss/disc_loss', d_loss, updates)

                updates += 1

        next_state, reward, done, _ = env.step(action)  # Step
        episode_steps += 1
        total_numsteps += 1
        episode_reward += reward

        # Ignore the "done" signal if it comes from hitting the time horizon.
        # (https://github.com/openai/spinningup/blob/master/spinup/algos/sac/sac.py)
        mask = 1 if episode_steps == env._max_episode_steps else float(not done)

        # The discriminator buffer is filled immediately; the SAC transition is
        # buffered locally until its pseudo-reward is known.
        disc_memory.push((np.array([label]), next_state))
        _state.append(state)
        _action.append(action)
        _reward.append(reward)
        _next_state.append(next_state)
        _mask.append(mask)

        state = next_state

        if total_numsteps % (args.eval_interval * args.num_modes) == 0 and args.eval is True:
            eval(args, test_env, trainers, disc_trainer, memories, writer, total_numsteps//1000, algo="wasserstein")

        # Save models (also once right after the first step)
        if total_numsteps % (args.save_interval * args.num_modes) == 0 or total_numsteps == 1:
            prefix = path_prefix+"url-data/models/wasserstein/{}/{}/step{}/".format(args.env_name, start_T.strftime("%Y-%m-%d_%H-%M-%S"), total_numsteps//1000)
            for mode in range(args.num_modes):
                trainers[mode].save_model(env_name=args.env_name, prefix=prefix, suffix="{}".format(mode))
                print("Saving model {} to /models...".format(mode))

    assert len(_state) == episode_steps
    state_batch = np.stack(_next_state, axis=0)

    # Discriminator "accuracy": mean of exp(score) over this episode's next-states
    # for the episode's own label.
    labels = np.full(episode_steps, label)
    score = disc_trainer.score(state_batch, labels)
    avg_acc = np.mean(np.exp(score)).tolist()

    # Pseudo-reward: for every *other* mode, compute a per-step distance (wsre,
    # presumably a Wasserstein-based estimate -- confirm) between this episode's
    # next-states and a batch from that mode's buffer (the first field returned by
    # dump()). Modes whose buffer does not yet exceed max_episode_len contribute an
    # all-zero distance. The distance to the mode with the smallest total is used.
    _dist = []
    sum_dist = []
    for mode in range(args.num_modes):
        if mode == label:
            continue
        dist = np.zeros(episode_steps)
        if len(memories[mode]) > args.max_episode_len:  # Ensure we have a non-empty target batch
            target_state_batch = list(memories[mode].dump(args.max_episode_len))[0]
            dist = wsre(state_batch, target_state_batch)
        sum_dist.append(np.sum(dist))
        _dist.append(dist)
    if _dist:
        srs = _dist[int(np.argmin(sum_dist))]
    else:
        # Single-mode run: no other mode to be distinct from, so there is no
        # pseudo-reward. np.argmin on an empty list would raise ValueError.
        srs = np.zeros(episode_steps)

    # Composite reward: scaled pseudo-reward plus scaled environment reward.
    rewards = np.stack(_reward, axis=0)
    rs = srs * args.src + rewards * args.rc
    episode_sr = sum(srs.tolist())
    episode_r = sum(rs.tolist())

    for s, a, r, ns, m in zip(_state, _action, rs, _next_state, _mask):
        memories[label].push((s, a, np.array([1.0]), r, ns, m))

    # Log before the termination check so the final episode is recorded too.
    writer.add_scalar('train/reward', episode_reward, i_episode)
    writer.add_scalar('train/sr', episode_sr, i_episode)
    writer.add_scalar('train/r', episode_r, i_episode)
    writer.add_scalar('train/acc', avg_acc, i_episode)
    print("Episode: {}, total numsteps: {}, episode steps: {}, reward: {}, sr: {}, r: {}, acc: {}".format(i_episode, total_numsteps, episode_steps, round(episode_reward, 2), round(episode_sr, 2), round(episode_r, 2), round(avg_acc, 2)))

    # num_steps is a per-mode budget.
    if total_numsteps > args.num_steps * args.num_modes:
        break

# Final checkpoint: write every mode's trainer under ".../<run timestamp>/final/".
run_stamp = start_T.strftime("%Y-%m-%d_%H-%M-%S")
prefix = path_prefix+"url-data/models/wasserstein/{}/{}/final/".format(args.env_name, run_stamp)
for mode_idx, trainer in enumerate(trainers):
    trainer.save_model(env_name=args.env_name, prefix=prefix, suffix=str(mode_idx))
    print("Saving model {} to /models...".format(mode_idx))

# Release the training and evaluation environments.
for gym_env in (env, test_env):
    gym_env.close()

