import numpy as np
import torch
import argparse
from experiment_launcher import run_experiment
from experiment_launcher.launcher import add_launcher_base_args, get_experiment_default_params
import os
import time
import json
import dmc2gym

import utils
from logger import Logger

from mtc import TCSacAgent


def parse_args():
    """Build the command-line parser for this experiment and parse ``sys.argv``.

    Each option is declared from a (flag, value type) table, in the same order
    the help text lists them; a ``None`` type marks a boolean ``store_true``
    switch. The launcher's own base arguments are appended by
    ``add_launcher_base_args``, and every default is taken from the keyword
    defaults of ``experiment`` via ``get_experiment_default_params``.

    Returns:
        dict mapping argparse destination names to parsed values.
    """
    option_specs = [
        # general
        ('--domain-name', str),
        ('--task-name', str),
        ('--num-train-steps', int),
        ('--results-dir', str),
        ('--init-log-kl', float),
        ('--use-aug-reward', None),
        ('--use-wandb', None),
        ('--alg-name', str),
        ('--project-name', str),
        # hyperparameters
        ('--kl-coef-lr', float),
        ('--kl-constraint', float),
        ('--horizon', int),
        # replay buffer
        ('--replay-buffer-capacity', int),
        # train
        ('--agent', str),
        ('--init-steps', int),
        ('--batch-size', int),
        ('--hidden-dim', int),
        # eval
        ('--eval-freq', int),
        ('--num-eval-episodes', int),
        # critic
        ('--critic-lr', float),
        ('--critic-beta', float),
        ('--critic-tau', float),
        ('--critic-target-update-freq', int),
        # actor
        ('--actor-lr', float),
        ('--actor-beta', float),
        ('--actor-log-std-min', float),
        ('--actor-log-std-max', float),
        ('--actor-update-freq', int),
        # encoder
        ('--encoder-type', str),
        ('--encoder-feature-dim', int),
        ('--encoder-tau', float),
        ('--num-layers', int),
        ('--num-filters', int),
        # sac
        ('--discount', float),
        ('--init-temperature', float),
        ('--alpha-lr', float),
        ('--alpha-beta', float),
        # misc
        ('--seed', int),
        ('--save-tb', None),
        ('--save-buffer', None),
        ('--save-model', None),
        ('--detach-encoder', None),
        ('--log-interval', int),
    ]

    cli = argparse.ArgumentParser()
    for flag, value_type in option_specs:
        if value_type is None:
            cli.add_argument(flag, action='store_true')
        else:
            cli.add_argument(flag, type=value_type)

    cli = add_launcher_base_args(cli)
    cli.set_defaults(**get_experiment_default_params(experiment))
    return vars(cli.parse_args())


def evaluate(env, agent, num_episodes, L, step):
    '''
    Evaluate the agent with deterministic actions and log the results.

    Each episode is rolled out with ``agent.select_action`` (called inside
    ``utils.eval_mode(agent)``) until ``env.step`` reports ``done``. The
    per-episode return, the wall-clock evaluation time, and the mean and best
    returns are logged under the ``eval/`` prefix at ``step``. The logger is
    then flushed with ``L.dump(step)``.

    Args:
        env: environment with ``reset()`` and a 4-tuple ``step(action)``.
        agent: agent exposing ``select_action`` (and ``sample_action`` for the
            stochastic variant of the inner loop).
        num_episodes: number of evaluation episodes. When 0, only the
            evaluation time is logged; no mean/best return is computed.
        L: logger with ``log(key, value, step)`` and ``dump(step)``.
        step: training step used as the x-axis for every logged value.
    '''

    def run_eval_loop(sample_stochastically=True):
        # Rewards are collected per call so that running both a stochastic
        # and a deterministic loop would not mix their statistics.
        all_ep_rewards = []
        start_time = time.time()
        prefix = 'stochastic_' if sample_stochastically else ''
        for _ in range(num_episodes):
            obs = env.reset()
            done = False
            episode_reward = 0
            while not done:
                with utils.eval_mode(agent):
                    if sample_stochastically:
                        action = agent.sample_action(obs)
                    else:
                        action = agent.select_action(obs)
                obs, reward, done, _ = env.step(action)
                episode_reward += reward

            L.log('eval/' + prefix + 'episode_reward', episode_reward, step)
            all_ep_rewards.append(episode_reward)

        L.log('eval/' + prefix + 'eval_time', time.time() - start_time, step)
        if not all_ep_rewards:
            # np.max raises ValueError on an empty sequence; nothing to summarise.
            return
        mean_ep_reward = np.mean(all_ep_rewards)
        best_ep_reward = np.max(all_ep_rewards)
        L.log('eval/' + prefix + 'mean_episode_reward', mean_ep_reward, step)
        L.log('eval/' + prefix + 'best_episode_reward', best_ep_reward, step)

    run_eval_loop(sample_stochastically=False)
    L.dump(step)

def make_agent(obs_shape, action_shape, args, device):
    """Construct the agent selected by ``args.agent``.

    Only ``'MTC'`` is supported; it builds a ``TCSacAgent`` from the
    hyperparameters stored on ``args``.

    Args:
        obs_shape: observation shape passed to the agent.
        action_shape: action shape passed to the agent.
        args: namespace holding ``agent`` and all agent hyperparameters.
        device: torch device the agent is created on.

    Returns:
        The constructed agent.

    Raises:
        ValueError: if ``args.agent`` names an unsupported agent. (Previously
            an always-true ``assert`` on a string let this fall through and
            return ``None``.)
    """
    if args.agent != 'MTC':
        raise ValueError('agent is not supported: %s' % args.agent)

    return TCSacAgent(
        obs_shape=obs_shape,
        action_shape=action_shape,
        device=device,
        hidden_dim=args.hidden_dim,
        discount=args.discount,
        init_temperature=args.init_temperature,
        alpha_lr=args.alpha_lr,
        alpha_beta=args.alpha_beta,
        actor_lr=args.actor_lr,
        actor_beta=args.actor_beta,
        actor_log_std_min=args.actor_log_std_min,
        actor_log_std_max=args.actor_log_std_max,
        actor_update_freq=args.actor_update_freq,
        critic_lr=args.critic_lr,
        critic_beta=args.critic_beta,
        critic_tau=args.critic_tau,
        critic_target_update_freq=args.critic_target_update_freq,
        encoder_type=args.encoder_type,
        encoder_feature_dim=args.encoder_feature_dim,
        encoder_tau=args.encoder_tau,
        num_layers=args.num_layers,
        num_filters=args.num_filters,
        log_interval=args.log_interval,
        detach_encoder=args.detach_encoder,
        kl_coef_lr=args.kl_coef_lr,
        kl_constraint=args.kl_constraint,
        horizon=args.horizon,
        init_log_kl=args.init_log_kl,
        use_aug_reward=args.use_aug_reward
    )


def experiment(
        domain_name: str = 'walker',
        task_name: str = 'walk',
        replay_buffer_capacity: int = 1000000,
        agent: str = 'MTC',
        init_steps: int = 5000,
        num_train_steps: int = 1000000,
        batch_size: int = 256,
        hidden_dim: int = 1024,
        eval_freq: int = 10000,
        num_eval_episodes: int = 10,
        critic_lr: float = 1e-4,
        critic_beta: float = 0.9,
        critic_tau: float = 0.01,
        critic_target_update_freq: int = 2,
        actor_lr: float = 1e-4,
        actor_beta: float = 0.9,
        actor_log_std_min: float = -10,
        actor_log_std_max: float = 2,
        actor_update_freq: int = 1,
        encoder_type: str = 'prop',
        encoder_feature_dim: int = 30,
        encoder_tau: float = 0.05,
        num_layers: int = 4,
        num_filters: int = 32,
        discount: float = 0.99,
        init_temperature: float = 0.1,
        alpha_lr: float = 1e-4,
        alpha_beta: float = 0.5,
        save_tb: bool = True,
        save_buffer: bool = False,
        save_model: bool = False,
        detach_encoder: bool = False,
        kl_coef_lr: float = 1e-4,
        kl_constraint: float = -0.5,
        horizon: int = 8,
        init_log_kl: float = 1e-6,
        use_aug_reward: bool = True,
        log_interval: int = 100,
        seed: int = 0,
        results_dir: str = './logs'
):
    """Train an agent on a dmc2gym task, evaluating it every ``eval_freq`` steps.

    All keyword arguments are hyperparameters. They are collected into a
    ``utils.Namespace`` (``args``) and written to ``<results_dir>/args.json``.
    The same namespace is used to build the replay buffer, the agent
    (``make_agent``) and the ``Logger``.

    Behaviour worth knowing:
      * ``encoder_feature_dim == 'obs'`` replaces the feature dim with the
        size of the first observation dimension.
      * The first ``init_steps`` actions are sampled uniformly from the action
        space. After that, actions come from ``agent.sample_action`` and one
        ``agent.update`` runs per environment step.
      * Training and evaluation share one environment instance. After each
        evaluation a fresh training episode is therefore started (see the
        loop body).
      * ``num_train_steps`` is cast with ``int()``, so float values are still
        accepted.
    """

    os.makedirs(results_dir, exist_ok=True)
    utils.set_seed_everywhere(seed)

    env = dmc2gym.make(
        domain_name=domain_name,
        task_name=task_name,
        seed=seed,
        visualize_reward=False
    )

    env.seed(seed)
    action_shape = env.action_space.shape
    obs_shape = env.observation_space.shape

    if encoder_feature_dim == 'obs':
        encoder_feature_dim = obs_shape[0]

    args = utils.Namespace(
        domain_name=domain_name,
        task_name=task_name,
        replay_buffer_capacity=replay_buffer_capacity,
        agent=agent,
        init_steps=init_steps,
        num_train_steps=num_train_steps,
        batch_size=batch_size,
        hidden_dim=hidden_dim,
        eval_freq=eval_freq,
        num_eval_episodes=num_eval_episodes,
        critic_lr=critic_lr,
        critic_beta=critic_beta,
        critic_tau=critic_tau,
        critic_target_update_freq=critic_target_update_freq,
        actor_lr=actor_lr,
        actor_beta=actor_beta,
        actor_log_std_min=actor_log_std_min,
        actor_log_std_max=actor_log_std_max,
        actor_update_freq=actor_update_freq,
        encoder_type=encoder_type,
        encoder_feature_dim=encoder_feature_dim,
        encoder_tau=encoder_tau,
        num_layers=num_layers,
        num_filters=num_filters,
        discount=discount,
        init_temperature=init_temperature,
        alpha_lr=alpha_lr,
        alpha_beta=alpha_beta,
        save_tb=save_tb,
        save_buffer=save_buffer,
        save_model=save_model,
        detach_encoder=detach_encoder,
        log_interval=log_interval,
        kl_coef_lr=kl_coef_lr,
        kl_constraint=kl_constraint,
        horizon=horizon,
        init_log_kl=init_log_kl,
        use_aug_reward=use_aug_reward,
        seed=seed,
        results_dir=results_dir
    )

    utils.make_dir(args.results_dir)
    model_dir = utils.make_dir(os.path.join(args.results_dir, 'model'))
    buffer_dir = utils.make_dir(os.path.join(args.results_dir, 'buffer'))

    # Record the exact hyperparameters of this run next to its logs.
    with open(os.path.join(args.results_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, sort_keys=True, indent=4)

    print('result_dir:{}'.format(args.results_dir))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Set after args.json is written, so the device is not serialised.
    args.device = device

    replay_buffer = utils.ReplayBuffer(
        obs_shape=obs_shape,
        action_shape=action_shape,
        capacity=args.replay_buffer_capacity,
        batch_size=args.batch_size,
        device=device
    )

    agent = make_agent(
        obs_shape=obs_shape,
        action_shape=action_shape,
        args=args,
        device=device
    )

    # tensorboard visualization
    L = Logger(args.results_dir, use_tb=args.save_tb)

    # done=True forces an env.reset() on the very first iteration.
    episode, episode_reward, done = 0, 0, True
    start_time = time.time()

    for step in range(int(args.num_train_steps)):
        # evaluate agent periodically
        if step % args.eval_freq == 0:
            L.log('eval/episode', episode, step)
            evaluate(env, agent, args.num_eval_episodes, L, step)
            # NOTE(review): checkpointing is gated on the episode count, not
            # the step - confirm this is intended.
            if args.save_model and episode % 500 == 0:
                agent.save(model_dir, step)
            if args.save_buffer:
                replay_buffer.save(buffer_dir)
            # `evaluate` reset and stepped the shared `env`, so `obs` and
            # `episode_step` no longer describe its state. Start a fresh
            # training episode rather than storing a transition that spans
            # the evaluation. When evaluation lands on an episode boundary,
            # `done` is already True and this changes nothing.
            done = True

        if done:
            if step > 0:
                if step % args.log_interval == 0:
                    L.log('train/duration', time.time() - start_time, step)
                    # flush the accumulated metrics
                    L.dump(step)
                start_time = time.time()
            if step % args.log_interval == 0:
                L.log('train/episode_reward', episode_reward, step)

            obs = env.reset()
            episode_reward = 0
            episode_step = 0
            episode += 1
            if step % args.log_interval == 0:
                L.log('train/episode', episode, step)

        # Uniform random exploration until init_steps, then the agent's policy.
        if step < args.init_steps:
            action = env.action_space.sample()
        else:
            with utils.eval_mode(agent):
                action = agent.sample_action(obs)

        # run training update
        if step >= args.init_steps:
            num_updates = 1
            for _ in range(num_updates):
                agent.update(replay_buffer, L, step)

        next_obs, reward, done, _ = env.step(action)

        if episode_step + 1 == env._max_episode_steps:
            # Distinguish time-limit truncation from true termination: the
            # transition is stored as non-terminal (done_bool=0) while
            # done_max still marks the end of the episode.
            done_bool = 0
            done_max = 1
        else:
            done_bool = float(done)
            done_max = float(done)

        episode_reward += reward
        replay_buffer.add(obs, action, reward, next_obs, done_bool, done_max)

        obs = next_obs
        episode_step += 1


if __name__ == '__main__':
    # The multiprocessing start method must be chosen before any child
    # process is created, so it is set before handing control to the launcher.
    # NOTE(review): 'spawn' is presumably required for CUDA use in workers.
    start_method = 'spawn'
    torch.multiprocessing.set_start_method(start_method)
    run_experiment(experiment)
