import argparse
import datetime
import gym
import numpy as np
import itertools
import torch
import os
# from algo.utils.sac import SACTrainer
from algo.utils.sac2 import SACTrainer2 as SACTrainer
from algo.utils.disc import DiscTrainer
from torch.utils.tensorboard import SummaryWriter
from algo.utils.buffer import ReplayMemory
from algo.utils.wsre import wsre
from gym_mm.envs.ant_angle import AntAngle
from gym_mm.envs.cheetah_jump import CheetahJump
from gym_mm.envs.ant_custom_env import AntCustomEnv

# Reduced obs for ant angle envs
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    because any non-empty string is truthy. This converter parses the usual
    spellings instead, so ``--eval False`` really disables evaluation.

    The empty string maps to ``False`` to stay compatible with the previous
    ``type=bool`` behaviour (``bool("") is False``).

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value))


# Command-line interface. Help strings use ``%(default)s`` so the documented
# default can never drift away from the actual one.
parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
parser.add_argument('--env-name', default="AntCustom-v2",
                    help='Mujoco Gym environment (default: %(default)s)')
parser.add_argument('--policy', default="Gaussian",
                    help='Policy Type: Gaussian | Deterministic (default: %(default)s)')
parser.add_argument('--eval', type=_str2bool, default=True,
                    help='Evaluate the policies every eval_interval * num_modes '
                         'environment steps (default: %(default)s)')
parser.add_argument('--reduced-obs', type=_str2bool, default=False,
                    help='Reducing obs dimensions for discriminability and distance '
                         'computing (default: %(default)s)')
parser.add_argument('--include-base-reward', type=_str2bool, default=False,
                    help='Includes base rewards when unsupervised (default: %(default)s)')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor for reward (default: %(default)s)')
parser.add_argument('--tau', type=float, default=0.005, metavar='G',
                    help='target smoothing coefficient(τ) (default: %(default)s)')
parser.add_argument('--critic_lr', type=float, default=0.0003, metavar='G',
                    help='critic learning rate (default: %(default)s)')
parser.add_argument('--policy_lr', type=float, default=0.0003, metavar='G',
                    help='policy learning rate (default: %(default)s)')
parser.add_argument('--disc_lr', type=float, default=0.002, metavar='G',
                    help='discriminator learning rate (default: %(default)s)')
parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
                    help='Temperature parameter α determines the relative importance '
                         'of the entropy term against the reward (default: %(default)s)')
parser.add_argument('--automatic_entropy_tuning', type=_str2bool, default=False, metavar='G',
                    help='Automatically adjust α (default: %(default)s)')
parser.add_argument('--seed', type=int, default=100, metavar='N',
                    help='random seed (default: %(default)s)')
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
                    help='batch size (default: %(default)s)')
parser.add_argument('--max_episode_len', type=int, default=1000, metavar='N',
                    help='maximum episode length (default: %(default)s)')
parser.add_argument('--num_steps', type=int, default=200001, metavar='N',
                    help='maximum number of steps (default: %(default)s)')
parser.add_argument('--hidden_dim', type=int, default=256, metavar='N',
                    help='hidden size (default: %(default)s)')
parser.add_argument('--updates_per_step', type=int, default=1, metavar='N',
                    help='model updates per simulator step (default: %(default)s)')
parser.add_argument('--start_steps', type=int, default=10000, metavar='N',
                    help='Steps sampling random actions (default: %(default)s)')
parser.add_argument('--target_update_interval', type=int, default=1, metavar='N',
                    help='Value target update per no. of updates per step (default: %(default)s)')
parser.add_argument('--replay_size', type=int, default=200000, metavar='N',
                    help='size of replay buffer (default: %(default)s)')
# NOTE: ``store_false`` means args.cuda defaults to True and *passing* --cuda
# sets it to False. The flag semantics are kept as-is for compatibility with
# existing launch scripts; only the help text is made accurate.
parser.add_argument('--cuda', action="store_false",
                    help='passing this flag sets cuda=False, i.e. disables CUDA '
                         '(default: %(default)s)')
parser.add_argument('--num_modes', type=int, default=10, metavar='N',
                    help='number of modes (default: %(default)s)')
parser.add_argument('--save_interval', type=int, default=10000, metavar='N',
                    help='interval for saving models (default: %(default)s)')
parser.add_argument('--eval_interval', type=int, default=10000, metavar='N',
                    help='interval for evaluating models (default: %(default)s)')
parser.add_argument('--rc', type=float, default=0.0, metavar='G',
                    help='scale for rewards (default: %(default)s)')
parser.add_argument('--src', type=float, default=1.0, metavar='G',
                    help='scale for pseudo rewards (default: %(default)s)')
args = parser.parse_args()

# Environment
# env = NormalizedActions(gym.make(args.env_name))
# Project-specific environments are built directly from their classes and
# force two settings on ``args``; any other name goes through ``gym.make``.
# Table layout: env name -> (constructor, reduced_obs, include_base_reward).
_CUSTOM_ENVS = {
    "AntAngle-v2": (AntAngle, 2, True),
    "AntCustom-v2": (AntCustomEnv, 2, False),
    "CheetahJump-v2": (CheetahJump, 1, True),
}

_custom_spec = _CUSTOM_ENVS.get(args.env_name)
if _custom_spec is None:
    env = gym.make(args.env_name)
    test_env = gym.make(args.env_name)
else:
    _env_cls, _reduced_obs, _include_base_reward = _custom_spec
    env = _env_cls()
    test_env = _env_cls()
    # Overrides whatever was given on the command line for these options.
    args.reduced_obs = _reduced_obs
    args.include_base_reward = _include_base_reward

# Seed the training and evaluation environments (and their action spaces)
# with the same seed, then the global torch / numpy generators.
for _sim in (env, test_env):
    _sim.seed(args.seed)
    _sim.action_space.seed(args.seed)

torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Agent
# One SAC trainer per mode, all sharing the same observation/action spaces.
# trainers = [SACTrainer(env.observation_space.shape[0], env.action_space, args) for _ in range(args.num_modes)]
obs_dim = env.observation_space.shape[0]
trainers = []
for _ in range(args.num_modes):
    trainers.append(SACTrainer(obs_dim, env.action_space, args))
    # NOTE(review): hidden_dim is bumped after every trainer, so each trainer
    # (and the discriminator built below) sees a different value through
    # ``args`` -- presumably intentional; confirm before changing.
    args.hidden_dim += 1

# The discriminator consumes either the first ``reduced_obs`` observation
# dimensions or the full observation; the matching ``eval`` helper is
# imported from the corresponding utils module.
disc_input_dim = args.reduced_obs if args.reduced_obs else obs_dim
disc_trainer = DiscTrainer(disc_input_dim, args)
if args.reduced_obs:
    from utils_ro import eval
else:
    from utils import eval

# Root directory for logs and model checkpoints; must be set in the
# environment (a missing variable raises KeyError) and is concatenated
# verbatim, so it needs to end with a path separator.
path_prefix = os.environ['UDG_DATA_PATH']

# Tensorboard
start_T = datetime.datetime.now()
_entropy_tag = "autotune" if args.automatic_entropy_tuning else ""
_log_dir = path_prefix + 'url-data/logs/{}/{}_wasserstein_{}_{}'.format(
    args.env_name,
    start_T.strftime("%Y-%m-%d_%H-%M-%S"),
    args.policy,
    _entropy_tag,
)
writer = SummaryWriter(_log_dir)

# Logging hyperparameters
# writer.add_hparams(hparam_dict=vars(args), metric_dict={})

# Memory
# One replay buffer per mode for the SAC trainers, plus a shared buffer of
# (label, state) pairs for the discriminator.
memories = []
for _ in range(args.num_modes):
    memories.append(ReplayMemory(args.replay_size))
disc_memory = ReplayMemory(args.replay_size)

# Training Loop
total_numsteps = 0  # environment steps taken across all episodes
updates = 0         # gradient updates performed so far

# Lower bound for a per-step discriminator score; only referenced by
# commented-out code in the training loop.
SR_MIN = np.log(1 / args.num_modes) * 10
# Main training loop. Each episode is played by exactly one mode, chosen
# round-robin from the episode index. Transitions are buffered for the whole
# episode because the pseudo-reward is computed from the complete trajectory
# once the episode has finished; only then are they pushed to replay memory.
for i_episode in itertools.count(1):
    episode_reward = 0
    episode_steps = 0
    # Per-step records of the current episode (parallel lists).
    _state = []
    _action = []
    _reward = []
    _next_state = []
    _mask = []
    _base_r = []
    done = False
    state = env.reset()
    label = i_episode % args.num_modes
    while not done:
        if args.start_steps > len(memories[label]):
            action = env.action_space.sample()  # Sample random action
        else:
            action, _ = trainers[label].act(state)  # Sample action from policy

        if len(memories[label]) > args.batch_size:
            # Number of updates per step in environment
            for _ in range(args.updates_per_step):
                # Update parameters of all the networks
                samples = memories[label].sample(args.batch_size)
                critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = trainers[label].update_parameters(samples, updates)
                disc_samples = disc_memory.sample(args.batch_size)
                d_loss = disc_trainer.update_parameters(disc_samples)

                # writer.add_scalar('loss/critic_1', critic_1_loss, updates)
                # writer.add_scalar('loss/critic_2', critic_2_loss, updates)
                # writer.add_scalar('loss/policy', policy_loss, updates)
                # writer.add_scalar('loss/entropy_loss', ent_loss, updates)
                # writer.add_scalar('entropy_temprature/alpha', alpha, updates)
                # writer.add_scalar('loss/disc_loss', d_loss, updates)

                updates += 1

        next_state, reward, done, info = env.step(action)  # Step
        episode_steps += 1
        total_numsteps += 1
        episode_reward += reward

        # Ignore the "done" signal if it comes from hitting the time horizon.
        # (https://github.com/openai/spinningup/blob/master/spinup/algos/sac/sac.py)
        mask = 1 if episode_steps == env._max_episode_steps else float(not done)

        # The discriminator is trained on (mode label, visited state) pairs,
        # restricted to the leading dimensions when reduced_obs is set.
        if args.reduced_obs:
            disc_memory.push((np.array([label]), next_state[:args.reduced_obs]))
        else:
            disc_memory.push((np.array([label]), next_state))
        _state.append(state)
        _action.append(action)
        _reward.append(reward)
        _next_state.append(next_state)
        _mask.append(mask)
        _base_r.append(info['additional_r'])

        state = next_state

        if total_numsteps % (args.eval_interval * args.num_modes) == 0 and args.eval is True:
            eval(args, test_env, trainers, disc_trainer, memories, writer, total_numsteps//env._max_episode_steps, algo="wasserstein")

        # Save models
        if total_numsteps % (args.save_interval * args.num_modes) == 0 or total_numsteps == 1:
            prefix = path_prefix + "url-data/models/wasserstein/{}/{}/step{}/".format(args.env_name, start_T.strftime("%Y-%m-%d_%H-%M-%S"), total_numsteps//env._max_episode_steps)
            for i in range(args.num_modes):
                trainers[i].save_model(env_name=args.env_name, prefix=prefix, suffix="{}".format(i))
                print("Saving model {} to /models...".format(i))

    assert len(_state) == episode_steps

    # Visited states of this episode, stacked once and reused below for both
    # the discriminator score and the distance computation.
    state_batch = np.stack(_next_state, axis=0)
    if args.reduced_obs:
        state_batch = state_batch[:, :args.reduced_obs]

    labels = np.full(episode_steps, label)
    score = disc_trainer.score(state_batch, labels)
    # exp(score) is logged as 'acc' -- presumably score is a log-probability
    # of the true label; confirm against DiscTrainer.score.
    avg_acc = np.mean(np.exp(score)).tolist()

    # Per-step distance from this trajectory to a recent batch of states of
    # every *other* mode. A mode whose buffer is not yet larger than
    # max_episode_len contributes an all-zero distance.
    _dist = []
    for i in range(args.num_modes):
        if i == label:
            continue
        dist = np.zeros(episode_steps)
        if len(memories[i]) > args.max_episode_len:  # Ensure we have a non-empty target batch
            target_state_batch = list(memories[i].dump(args.max_episode_len))[0]
            if args.reduced_obs:
                target_state_batch = target_state_batch[:, :args.reduced_obs]
            dist = wsre(state_batch, target_state_batch)
        _dist.append(dist)

    if _dist:
        # The pseudo-reward is the distance to the closest other mode
        # (smallest summed distance over the episode).
        min_dist_idx = np.argmin([np.sum(dist) for dist in _dist])
        srs = _dist[min_dist_idx]
    else:
        # With a single mode there is nothing to compare against; np.argmin
        # on an empty list would raise ValueError, so use a zero pseudo-reward.
        srs = np.zeros(episode_steps)

    rewards = np.stack(_reward, axis=0)
    base_rs = np.stack(_base_r, axis=0)
    # rewards are not modified for the first policy
    if args.include_base_reward:
        # Keep the base reward unscaled; only the remainder is scaled by rc.
        if label == 0:
            rs = base_rs + (rewards - base_rs) * args.rc
        else:
            rs = base_rs + (rewards - base_rs) * args.rc + srs * args.src
    else:
        if label == 0:
            rs = rewards * args.rc
        else:
            rs = srs * args.src + rewards * args.rc
    episode_sr = sum(srs.tolist())
    episode_r = sum(rs.tolist())
    episode_base_r = sum(base_rs.tolist())

    # Now that the composite rewards are known, commit the episode to the
    # replay memory of the mode that generated it.
    for step_state, step_action, step_r, step_next_state, step_mask in zip(
            _state, _action, rs, _next_state, _mask):
        memories[label].push((step_state, step_action, np.array([1.0]), step_r, step_next_state, step_mask))

    # Log before checking the stop condition so the final episode is not
    # silently dropped from the training curves.
    writer.add_scalar('train/reward', episode_reward, i_episode)
    writer.add_scalar('train/sr', episode_sr, i_episode)
    writer.add_scalar('train/r', episode_r, i_episode)
    writer.add_scalar('train/base_r', episode_base_r, i_episode)
    writer.add_scalar('train/acc', avg_acc, i_episode)
    print("Episode: {}, total numsteps: {}, episode steps: {}, reward: {}, sr: {}, r: {}, acc: {}".format(i_episode, total_numsteps, episode_steps, round(episode_reward, 2), round(episode_sr, 2), round(episode_r, 2), round(avg_acc, 2)))

    if total_numsteps > args.num_steps * args.num_modes:
        break

# Final checkpoint: save every mode's trainer under a "final" directory that
# shares the run timestamp with the logs and intermediate checkpoints.
prefix = path_prefix + "url-data/models/wasserstein/{}/{}/final/".format(args.env_name, start_T.strftime("%Y-%m-%d_%H-%M-%S"))
for i in range(args.num_modes):
    trainers[i].save_model(env_name=args.env_name, prefix=prefix, suffix="{}".format(i))
    print("Saving model {} to /models...".format(i))

# Release resources. Closing the SummaryWriter makes sure buffered events are
# written to the event file before the process exits.
writer.close()
env.close()
test_env.close()

