import argparse
import datetime
import gym
import numpy as np
import itertools
import torch
import os
# from algo.utils.sac import SACTrainer
from algo.utils.sac2 import SACTrainer2 as SACTrainer
from algo.utils.disc import DiscTrainer
from torch.utils.tensorboard import SummaryWriter
from algo.utils.buffer import ReplayMemory
from algo.utils.wsre import wsre
from utils_ro import eval
from gym_mm.envs.ant_angle import AntAngle
from gym_mm.envs.cheetah_jump import CheetahJump
from gym_mm.envs.ant_custom_env import AntCustomEnv

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is
    ``True`` because every non-empty string is truthy, so the option could
    never be switched off by writing ``False``.

    The empty string is still accepted as ``False`` so that the old
    workaround (``--eval ""``) keeps working.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean (true/false), got {!r}'.format(value))


# Command-line interface. Help texts use argparse's ``%(default)s``
# placeholder so the documented default can never drift from the real one.
parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
parser.add_argument('--env-name', default="AntCustom-v2",
                    help='Mujoco Gym environment (default: %(default)s)')
parser.add_argument('--policy', default="Gaussian",
                    help='Policy Type: Gaussian | Deterministic (default: %(default)s)')
parser.add_argument('--eval', type=_str2bool, default=True,
                    help='Periodically evaluate the policies during training '
                         '(default: %(default)s)')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor for reward (default: %(default)s)')
parser.add_argument('--tau', type=float, default=0.005, metavar='G',
                    help='target smoothing coefficient(τ) (default: %(default)s)')
parser.add_argument('--critic_lr', type=float, default=0.0003, metavar='G',
                    help='critic learning rate (default: %(default)s)')
parser.add_argument('--policy_lr', type=float, default=0.0003, metavar='G',
                    help='policy learning rate (default: %(default)s)')
parser.add_argument('--disc_lr', type=float, default=0.002, metavar='G',
                    help='discriminator learning rate (default: %(default)s)')
parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
                    help='Temperature parameter α determines the relative '
                         'importance of the entropy term against the reward '
                         '(default: %(default)s)')
parser.add_argument('--automatic_entropy_tuning', type=_str2bool, default=False, metavar='G',
                    help='Automatically adjust α (default: %(default)s)')
parser.add_argument('--seed', type=int, default=100, metavar='N',
                    help='random seed (default: %(default)s)')
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
                    help='batch size (default: %(default)s)')
parser.add_argument('--max_episode_len', type=int, default=1000, metavar='N',
                    help='maximum episode length (default: %(default)s)')
parser.add_argument('--num_steps', type=int, default=200001, metavar='N',
                    help='maximum number of steps (default: %(default)s)')
parser.add_argument('--hidden_dim', type=int, default=256, metavar='N',
                    help='hidden size (default: %(default)s)')
parser.add_argument('--updates_per_step', type=int, default=1, metavar='N',
                    help='model updates per simulator step (default: %(default)s)')
parser.add_argument('--start_steps', type=int, default=10000, metavar='N',
                    help='Steps sampling random actions (default: %(default)s)')
parser.add_argument('--target_update_interval', type=int, default=1, metavar='N',
                    help='Value target update per no. of updates per step (default: %(default)s)')
parser.add_argument('--replay_size', type=int, default=200000, metavar='N',
                    help='size of replay buffer (default: %(default)s)')
# NOTE: store_false means args.cuda is True unless the flag is passed; the
# flag name is kept for backward compatibility with existing launch scripts.
parser.add_argument('--cuda', action="store_false",
                    help='pass this flag to set args.cuda to False '
                         '(args.cuda is True when the flag is omitted)')
parser.add_argument('--num_modes', type=int, default=10, metavar='N',
                    help='number of modes (default: %(default)s)')
parser.add_argument('--save_interval', type=int, default=10000, metavar='N',
                    help='interval for saving models (default: %(default)s)')
parser.add_argument('--eval_interval', type=int, default=10000, metavar='N',
                    help='interval for evaluating models (default: %(default)s)')
parser.add_argument('--rc', type=float, default=0.0, metavar='G',
                    help='scale for rewards (default: %(default)s)')
parser.add_argument('--src', type=float, default=1.0, metavar='G',
                    help='scale for pseudo rewards (default: %(default)s)')
args = parser.parse_args()

# Environment
# env = NormalizedActions(gym.make(args.env_name))

# Defaults for environments that do not override them below. Both flags are
# read unconditionally later in this script, so leaving them unset for
# CheetahJump-v2 or a plain gym environment raised AttributeError.
#   reduced_obs:         discriminator is fed only the first two observation
#                        entries instead of the full observation.
#   include_base_reward: the env's 'additional_r' term is kept unscaled in
#                        the training reward.
args.reduced_obs = False
args.include_base_reward = False

if args.env_name == "AntAngle-v2":
    env = AntAngle()
    test_env = AntAngle()
    args.reduced_obs = True
    args.include_base_reward = True
elif args.env_name == "AntCustom-v2":
    env = AntCustomEnv()
    test_env = AntCustomEnv()
    args.reduced_obs = True
elif args.env_name == "CheetahJump-v2":
    env = CheetahJump()
    test_env = CheetahJump()
else:
    # Any other name is resolved through the gym registry.
    env = gym.make(args.env_name)
    test_env = gym.make(args.env_name)

# Seed both environments (dynamics and action sampling) and the numeric
# libraries with the same seed.
env.seed(args.seed)
env.action_space.seed(args.seed)
test_env.seed(args.seed)
test_env.action_space.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Agent
# One SAC trainer per mode, plus a single discriminator shared by all modes.
# trainers = [SACTrainer(env.observation_space.shape[0], env.action_space, args) for _ in range(args.num_modes)]
trainers = []
for _mode in range(args.num_modes):
    mode_trainer = SACTrainer(env.observation_space.shape[0], env.action_space, args)
    trainers.append(mode_trainer)
    # NOTE(review): args.hidden_dim is bumped after every construction, so it
    # ends up num_modes larger than the CLI value and is what DiscTrainer
    # receives below. Presumably intentional (it replaced the comprehension
    # above) -- confirm before simplifying.
    args.hidden_dim += 1

if not args.reduced_obs:
    # Discriminator sees the full observation vector.
    disc_trainer = DiscTrainer(env.observation_space.shape[0], args)
    from utils import eval
else:
    # Discriminator sees only two observation entries (next_state[:2]).
    disc_trainer = DiscTrainer(2, args)
    from utils_ro import eval

# Root directory for logs and checkpoints (raises KeyError when the
# UDG_DATA_PATH environment variable is not set).
# Joining with '' appends a trailing separator when one is missing, so the
# plain string concatenations that build on path_prefix stay inside that
# directory even if the variable was exported without a trailing slash.
# An empty value is left empty, as before.
path_prefix = os.path.join(os.environ['UDG_DATA_PATH'], '')

# Tensorboard
# One log directory per run, keyed by environment name and start time.
start_T = datetime.datetime.now()
writer = SummaryWriter(path_prefix+'url-data/logs/{}/{}_diayn2_{}_{}'.format(args.env_name, start_T.strftime("%Y-%m-%d_%H-%M-%S"),
                                                             args.policy, "autotune" if args.automatic_entropy_tuning else ""))

# Logging hyperparameters
# writer.add_hparams(hparam_dict=vars(args), metric_dict={})

# Memory
# One replay buffer per mode for the SAC updates, and one shared buffer of
# (label, observation) pairs for the discriminator.
memories = [ReplayMemory(args.replay_size) for _ in range(args.num_modes)]
disc_memory = ReplayMemory(args.replay_size)

# Training Loop
total_numsteps = 0
updates = 0

# Lower bound applied to the discriminator score before it is used as a
# pseudo-reward: ten times the log-probability of a uniform guess over modes.
SR_MIN = np.log(1 / args.num_modes) * 10

# Episode horizon used for the time-limit mask and for the step index passed
# to eval / checkpoint paths. Environments that do not expose
# ``_max_episode_steps`` fall back to --max_episode_len instead of raising
# AttributeError on the first step.
max_episode_steps = getattr(env, '_max_episode_steps', args.max_episode_len)

for i_episode in itertools.count(1):
    episode_reward = 0
    episode_sr = 0
    episode_r = 0
    episode_steps = 0
    episode_base_r = 0
    avg_acc = 0.
    done = False
    state = env.reset()
    # Modes take turns: each episode is rolled out (and trained) by one mode.
    label = i_episode % args.num_modes
    trainer = trainers[label]
    memory = memories[label]
    while not done:
        if args.start_steps > len(memory):
            action = env.action_space.sample()  # Sample random action
        else:
            action, _ = trainer.act(state)  # Sample action from policy

        if len(memory) > args.batch_size:
            # Number of updates per step in environment
            for _update in range(args.updates_per_step):
                # Update parameters of all the networks
                samples = memory.sample(args.batch_size)
                critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = trainer.update_parameters(samples, updates)
                disc_samples = disc_memory.sample(args.batch_size)
                d_loss = disc_trainer.update_parameters(disc_samples)

                # writer.add_scalar('loss/critic_1', critic_1_loss, updates)
                # writer.add_scalar('loss/critic_2', critic_2_loss, updates)
                # writer.add_scalar('loss/policy', policy_loss, updates)
                # writer.add_scalar('loss/entropy_loss', ent_loss, updates)
                # writer.add_scalar('entropy_temprature/alpha', alpha, updates)
                # writer.add_scalar('loss/disc_loss', d_loss, updates)

                updates += 1

        next_state, reward, done, info = env.step(action) # Step
        episode_steps += 1
        total_numsteps += 1
        episode_reward += reward
        # Environments that do not report 'additional_r' contribute no base
        # reward instead of aborting the run with a KeyError.
        base_r = info.get('additional_r', 0.0)
        episode_base_r += base_r

        # The discriminator scores (and is later trained on) either the first
        # two observation entries or the full observation.
        disc_obs = next_state[:2] if args.reduced_obs else next_state
        score = disc_trainer.score(disc_obs, np.array([label]))
        sr = max(score, SR_MIN)

        # Environment part of the training reward: with include_base_reward
        # the base term stays unscaled and only the remainder is scaled by rc.
        if args.include_base_reward:
            r = base_r + (reward - base_r) * args.rc
        else:
            r = reward * args.rc
        # Mode 0 never receives the discriminator pseudo-reward.
        if label != 0:
            r = r + sr * args.src
        episode_sr += sr
        episode_r += r # reward for environment reward, r for composite reward
        avg_acc += np.exp(score)

        # Ignore the "done" signal if it comes from hitting the time horizon.
        # (https://github.com/openai/spinningup/blob/master/spinup/algos/sac/sac.py)
        mask = 1 if episode_steps == max_episode_steps else float(not done)

        memory.push((state, action, np.array([1.0]), r, next_state, mask)) # Append transition to memory
        disc_memory.push((np.array([label]), disc_obs))

        state = next_state

        if total_numsteps % (args.eval_interval * args.num_modes) == 0 and args.eval is True:
            eval(args, test_env, trainers, disc_trainer, memories, writer, total_numsteps//max_episode_steps, algo="diayn2")

        # Save models
        if total_numsteps % (args.save_interval * args.num_modes) == 0 or total_numsteps == 1:
            prefix = path_prefix+"url-data/models/diayn2/{}/{}/step{}/".format(args.env_name, start_T.strftime("%Y-%m-%d_%H-%M-%S"), total_numsteps//max_episode_steps)
            for mode_idx, mode_trainer in enumerate(trainers):
                mode_trainer.save_model(env_name=args.env_name, prefix=prefix, suffix="{}".format(mode_idx))
                print("Saving model {} to /models...".format(mode_idx))

    # Mean of exp(score) over the episode's steps.
    avg_acc /= episode_steps

    # The step budget is shared by all modes, hence the num_modes factor.
    if total_numsteps > args.num_steps * args.num_modes:
        break

    writer.add_scalar('train/reward', episode_reward, i_episode)
    writer.add_scalar('train/sr', episode_sr, i_episode)
    writer.add_scalar('train/r', episode_r, i_episode)
    writer.add_scalar('train/acc', avg_acc, i_episode)
    print("Episode: {}, total numsteps: {}, episode steps: {}, reward: {}, sr: {}, r: {}, acc: {}".format(i_episode, total_numsteps, episode_steps, round(episode_reward, 2), round(episode_sr, 2), round(episode_r, 2), round(avg_acc, 2)))

# Save the final policy of every mode, then release resources.
prefix = path_prefix+"url-data/models/diayn2/{}/{}/final/".format(args.env_name, start_T.strftime("%Y-%m-%d_%H-%M-%S"))
for mode_idx, mode_trainer in enumerate(trainers):
    mode_trainer.save_model(env_name=args.env_name, prefix=prefix, suffix="{}".format(mode_idx))
    print("Saving model {} to /models...".format(mode_idx))
# Close the TensorBoard writer so pending events are flushed to disk before
# the process exits.
writer.close()
env.close()
test_env.close()
