import argparse
import datetime
import gym
import numpy as np
import itertools
import torch
import os
# from algo.utils.sac import SACTrainer
from algo.utils.sac2 import SACTrainer2 as SACTrainer
from algo.utils.disc import DiscTrainer
from torch.utils.tensorboard import SummaryWriter
from algo.utils.buffer import ReplayMemory

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so ``--eval False`` would
    silently keep evaluation switched on.

    Args:
        value: the raw string from the command line (or an actual ``bool``).

    Returns:
        The parsed boolean. The empty string maps to ``False``, which is what
        ``bool("")`` produced with the previous ``type=bool`` converter.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got {!r}'.format(value))


# Command-line interface. Every "(default: ...)" note below mirrors the actual
# `default=` value of its argument.
parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
parser.add_argument('--env-name', default="HalfCheetah-v2",
                    help='Mujoco Gym environment (default: HalfCheetah-v2)')
parser.add_argument('--policy', default="Gaussian",
                    help='Policy Type: Gaussian | Deterministic (default: Gaussian)')
parser.add_argument('--eval', type=_str2bool, default=True,
                    help='Evaluate the policy every --eval_interval episodes (default: True)')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor for reward (default: 0.99)')
parser.add_argument('--tau', type=float, default=0.005, metavar='G',
                    help='target smoothing coefficient(τ) (default: 0.005)')
parser.add_argument('--critic_lr', type=float, default=0.0003, metavar='G',
                    help='learning rate (default: 0.0003)')
parser.add_argument('--policy_lr', type=float, default=0.0003, metavar='G',
                    help='learning rate (default: 0.0003)')
parser.add_argument('--disc_lr', type=float, default=0.002, metavar='G',
                    help='learning rate (default: 0.002)')
parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
                    help='Temperature parameter α determines the relative importance of the entropy\
                            term against the reward (default: 0.2)')
parser.add_argument('--automatic_entropy_tuning', type=_str2bool, default=False, metavar='G',
                    help='Automaically adjust α (default: False)')
parser.add_argument('--seed', type=int, default=100, metavar='N',
                    help='random seed (default: 100)')
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
                    help='batch size (default: 256)')
parser.add_argument('--num_steps', type=int, default=1000001, metavar='N',
                    help='maximum number of steps (default: 1000001)')
parser.add_argument('--hidden_dim', type=int, default=256, metavar='N',
                    help='hidden size (default: 256)')
parser.add_argument('--updates_per_step', type=int, default=1, metavar='N',
                    help='model updates per simulator step (default: 1)')
parser.add_argument('--start_steps', type=int, default=10000, metavar='N',
                    help='Steps sampling random actions (default: 10000)')
parser.add_argument('--target_update_interval', type=int, default=1, metavar='N',
                    help='Value target update per no. of updates per step (default: 1)')
parser.add_argument('--replay_size', type=int, default=200000, metavar='N',
                    help='size of replay buffer (default: 200000)')
# `store_false`: args.cuda is True unless the flag is given on the command line.
parser.add_argument('--cuda', action="store_false",
                    help='run on CUDA (default: True); passing this flag sets it to False')
parser.add_argument('--num_modes', type=int, default=5, metavar='N',
                    help='number of modes (default: 5)')
parser.add_argument('--save_interval', type=int, default=10, metavar='N',
                    help='interval for saving models (default: 10)')
parser.add_argument('--eval_interval', type=int, default=50, metavar='N',
                    help='interval for evaluating models (default: 50)')
parser.add_argument('--rc', type=float, default=1.0, metavar='G',
                    help='scale for rewards (default: 1.0)')
parser.add_argument('--src', type=float, default=0.0, metavar='G',
                    help='scale for pseudo rewards (default: 0.0)')
args = parser.parse_args()

# Environment
# env = NormalizedActions(gym.make(args.env_name))
env = gym.make(args.env_name)

# Seed every random source with the same value, in this order: the
# environment, its action space (used for the random warm-up actions),
# PyTorch, and NumPy.
for seed_fn in (env.seed, env.action_space.seed, torch.manual_seed, np.random.seed):
    seed_fn(args.seed)

# Agent
# One SAC trainer per mode (skill), plus a single discriminator trainer that
# is shared by all modes.
obs_dim = env.observation_space.shape[0]
trainers = []
for _ in range(args.num_modes):
    trainers.append(SACTrainer(obs_dim, env.action_space, args))
    # NOTE(review): args is mutated on purpose here, so each successive trainer
    # is constructed with hidden_dim one larger than the previous one, and the
    # discriminator below sees hidden_dim + num_modes. The motivation is not
    # visible in this file -- confirm before changing.
    args.hidden_dim += 1
disc_trainer = DiscTrainer(obs_dim, args)

# Root directory for logs and models, taken from the UDG_DATA_PATH environment
# variable (a KeyError is raised if it is unset). Paths in this script are built
# by appending to this prefix, so joining with '' guarantees it ends with a
# path separator; without that, a value such as "/data" would silently produce
# "/dataurl-data/...". An empty value stays empty (paths relative to the CWD).
path_prefix = os.path.join(os.environ['UDG_DATA_PATH'], '')

# Tensorboard
# The start timestamp is reused later to name the model output directory.
start_T = datetime.datetime.now()
writer = SummaryWriter(path_prefix+'url-data/logs/{}/{}_diayn_{}_{}'.format(args.env_name, start_T.strftime("%Y-%m-%d_%H-%M-%S"),
                                                             args.policy, "autotune" if args.automatic_entropy_tuning else ""))

# Logging hyperparameters
# writer.add_hparams(hparam_dict=vars(args), metric_dict={})

# Memory
# One replay buffer per mode for the SAC updates, and a separate buffer of
# (label, next_state) pairs for the discriminator.
memories = []
for _ in range(args.num_modes):
    memories.append(ReplayMemory(args.replay_size))
disc_memory = ReplayMemory(args.replay_size)

# Training Loop
# Global counters: environment steps taken, and gradient updates performed.
total_numsteps, updates = 0, 0

# Lower clip bound applied to discriminator scores: ten times the log of the
# uniform probability over modes.
SR_MIN = 10 * np.log(1 / args.num_modes)
for i_episode in itertools.count(1):
    # Modes are visited round-robin: each training episode collects data for,
    # and updates, exactly one mode's trainer and replay buffer.
    label = i_episode % args.num_modes
    skill_trainer = trainers[label]
    skill_memory = memories[label]

    episode_reward = 0
    episode_steps = 0
    # Transitions are buffered for the whole episode; they are pushed to the
    # replay buffer only after the composite reward is computed below.
    states, actions, env_rewards, next_states, masks = [], [], [], [], []
    done = False
    state = env.reset()
    while not done:
        # Random actions until this mode's buffer holds `start_steps` items.
        if len(skill_memory) < args.start_steps:
            action = env.action_space.sample()
        else:
            action, _ = skill_trainer.act(state)

        if len(skill_memory) > args.batch_size:
            # `updates_per_step` rounds of: one SAC update for the current
            # mode, then one discriminator update. The returned losses are
            # not logged at the moment.
            for _update in range(args.updates_per_step):
                samples = skill_memory.sample(args.batch_size)
                critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = skill_trainer.update_parameters(samples, updates)
                disc_samples = disc_memory.sample(args.batch_size)
                d_loss = disc_trainer.update_parameters(disc_samples)
                updates += 1

        next_state, reward, done, _ = env.step(action)
        episode_steps += 1
        total_numsteps += 1
        episode_reward += reward

        # Ignore the "done" signal if it comes from hitting the time horizon.
        # (https://github.com/openai/spinningup/blob/master/spinup/algos/sac/sac.py)
        if episode_steps == env._max_episode_steps:
            mask = 1
        else:
            mask = float(not done)

        # The discriminator buffer is fed immediately, every step.
        disc_memory.push((np.array([label]), next_state))
        states.append(state)
        actions.append(action)
        env_rewards.append(reward)
        next_states.append(next_state)
        masks.append(mask)

        state = next_state

    assert len(states) == episode_steps

    # Score the whole episode in one batch with the current discriminator.
    # exp(score) is averaged and logged as "acc", so the score is presumably a
    # log-probability of the active label -- confirm against DiscTrainer.
    labels = np.array([label] * episode_steps)
    score = disc_trainer.score(np.stack(next_states, axis=0), labels)
    avg_acc = np.mean(np.exp(score)).tolist()
    srs = np.clip(score, SR_MIN, None)
    rewards = np.stack(env_rewards, axis=0)
    # Composite reward: scaled pseudo reward plus scaled environment reward.
    rs = srs * args.src + rewards * args.rc
    episode_sr = sum(srs.tolist())
    episode_r = sum(rs.tolist())
    for idx, (s, a, s_next, m) in enumerate(zip(states, actions, next_states, masks)):
        skill_memory.push((s, a, np.array([1.0]), rs[idx], s_next, m))

    # The step budget is num_steps per mode; the episode that crosses it is
    # stored above but not logged.
    if total_numsteps > args.num_steps * args.num_modes:
        break

    for tag, value in (('train/reward', episode_reward),
                       ('train/sr', episode_sr),
                       ('train/r', episode_r),
                       ('train/acc', avg_acc)):
        writer.add_scalar(tag, value, i_episode)
    print("Episode: {}, total numsteps: {}, episode steps: {}, reward: {}, sr: {}, r: {}, acc: {}".format(i_episode, total_numsteps, episode_steps, round(episode_reward, 2), round(episode_sr, 2), round(episode_r, 2), round(avg_acc, 2)))

    # Evaluation runs only every `eval_interval` episodes, and only when
    # enabled on the command line.
    if i_episode % args.eval_interval != 0 or args.eval is not True:
        continue

    episodes = 10
    avg_reward = 0.
    avg_sr = 0.
    avg_acc = 0.
    for eval_idx in range(episodes):
        # Evaluation episodes also cycle through the modes.
        label = eval_idx % args.num_modes
        state = env.reset()
        eval_reward = 0.
        eval_sr = 0.
        eval_steps = 0
        done = False
        eval_next_states = []
        while not done:
            action, _ = trainers[label].act(state, eval=True)

            next_state, reward, done, _ = env.step(action)
            # Pseudo reward is scored step by step here, clipped from below.
            sr = max(disc_trainer.score(next_state, np.array([label])), SR_MIN)
            eval_reward += reward
            eval_sr += sr
            eval_steps += 1

            eval_next_states.append(next_state)
            state = next_state

        assert len(eval_next_states) == eval_steps
        # Batch score of the full episode, used for the accuracy metric.
        labels = np.array([label] * eval_steps)
        score = disc_trainer.score(np.stack(eval_next_states, axis=0), labels)
        avg_acc += np.mean(np.exp(score))

        avg_reward += eval_reward
        avg_sr += eval_sr
    avg_reward /= episodes
    avg_sr /= episodes
    avg_r = avg_reward * args.rc + avg_sr * args.src
    avg_acc /= episodes

    for tag, value in (('test/avg_reward', avg_reward),
                       ('test/avg_sr', avg_sr),
                       ('test/avg_r', avg_r),
                       ('test/avg_acc', avg_acc)):
        writer.add_scalar(tag, value, i_episode)

    T = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    separator = "----------------------------------------"
    print(separator)
    print("[{}] Test Episodes: {}, Avg. Reward: {}, Avg. SR: {}, Avg. R: {}, Avg. Acc: {}".format(T, episodes, round(avg_reward, 2), round(avg_sr, 2), round(avg_r, 2), round(avg_acc, 2)))
    print(separator)

# Save the final model of every mode under a directory named after the
# environment and the run's start timestamp; the mode index is the file suffix.
prefix = path_prefix+"url-data/models/diayn/{}/{}/final/".format(args.env_name, start_T.strftime("%Y-%m-%d_%H-%M-%S"))
for i, trainer in enumerate(trainers):
    trainer.save_model(env_name=args.env_name, prefix=prefix, suffix="{}".format(i))
    print("Saving model {} to /models...".format(i))
# Close the TensorBoard writer so that buffered events are flushed to disk
# before the process exits.
writer.close()
env.close()

