import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt

from sac_models import StochasticActor, Critic, SAC
from samplers import Sampler
from buffers import ReplayBuffer
from envs.envs import (ExpertInvertedPendulumEnv, ExpertInvertedDoublePendulumEnv)
from envs.more_envs import CustomReacher2Env, CustomReacher3Env
from envs.manipulation_envs import PusherEnv, PusherHumanSimEnv, ReachEnv, ReachHumanSimEnv

def _dmcontrol_env(class_name):
    """Return a zero-argument factory for a class in ``envs.dmcontrol_envs``.

    The import happens inside the factory, so the module is only loaded
    when one of its environments is actually requested.
    """
    def factory():
        from envs import dmcontrol_envs
        return getattr(dmcontrol_envs, class_name)()
    return factory


# Maps environment name -> (zero-argument factory, episode step limit).
# Factories are callables (not instances) so that nothing is constructed,
# looked up or imported until the entry is actually selected.
_EXPERT_ENVS = {
    'InvertedDoublePendulum-v2':
        (lambda: ExpertInvertedDoublePendulumEnv(), 1000),
    'InvertedPendulum-v2': (lambda: ExpertInvertedPendulumEnv(), 1000),
    'Reacher2-v2': (lambda: CustomReacher2Env(), 50),
    'Reacher3-v2': (lambda: CustomReacher3Env(), 50),
    # ================================================== DMC
    'DMCartPoleSwingUp': (_dmcontrol_env('DMCartPoleSwingUpEnv'), 1000),
    'DMPendulum': (_dmcontrol_env('DMPendulumEnv'), 1000),
    'DMAcrobot': (_dmcontrol_env('DMAcrobotEnv'), 1000),
    'DMWalker': (_dmcontrol_env('DMWalkerEnv'), 200),
    'DMCheetah': (_dmcontrol_env('DMCheetahEnv'), 200),
    'DMHopper': (_dmcontrol_env('DMHopperEnv'), 200),
    # ================================================== Manipulation
    'Pusher-v2': (lambda: PusherEnv(), 200),
    'PusherHumanSim-v2': (lambda: PusherHumanSimEnv(), 200),
    'Reach-v2': (lambda: ReachEnv(), 200),
    'ReachHumanSim-v2': (lambda: ReachHumanSimEnv(), 200),
}


def _make_expert_env(env_name):
    """Build the environment registered under ``env_name``.

    Returns:
        A ``(env, episode_limit)`` tuple.

    Raises:
        NotImplementedError: If ``env_name`` is not a key of ``_EXPERT_ENVS``.
    """
    try:
        entry = _EXPERT_ENVS[env_name]
    except (KeyError, TypeError):
        # TypeError covers unhashable names, which the caller should see as
        # "unsupported" just like any other unknown value.
        entry = None
    if entry is None:
        raise NotImplementedError(
            'Unknown environment {!r}; supported environments: {}'.format(
                env_name, ', '.join(sorted(_EXPERT_ENVS))))
    factory, episode_limit = entry
    return factory(), episode_limit


def _run_training(agent, sampler, replay_buffer, epochs, steps_per_epoch,
                  start_training, exploration_noise, batch_size,
                  updates_per_step, update_actor_every):
    """Collect trajectories, train the agent and evaluate once per epoch.

    Returns:
        A ``(steps, mean_returns, std_returns)`` tuple of equally long lists
        with one entry per epoch.
    """
    steps = []
    mean_test_returns = []
    mean_test_std = []
    step_counter = 0
    for epoch in range(epochs):
        # An epoch ends once the cumulative step count crosses its boundary;
        # whole trajectories are added, so the boundary may be overshot.
        epoch_end = (epoch + 1) * steps_per_epoch
        while step_counter < epoch_end:
            traj_data = sampler.sample_trajectory(agent, exploration_noise)
            replay_buffer.add(traj_data)
            # The check uses the count *before* this trajectory is added.
            if step_counter > start_training:
                agent.train(replay_buffer, batch_size=batch_size,
                            n_updates=updates_per_step * traj_data['n'],
                            act_delay=update_actor_every)
            # presumably traj_data['n'] is the number of environment steps
            # in the sampled trajectory -- confirm against Sampler.
            step_counter += traj_data['n']
        print('Epoch {}/{} - total steps {}'.format(
            epoch + 1, epochs, step_counter))
        out = sampler.evaluate(agent, 10)
        mean_test_returns.append(out['mean'])
        mean_test_std.append(out['std'])
        steps.append(step_counter)
    return steps, mean_test_returns, mean_test_std


def train_expert(env_name):
    """Train an expert SAC policy in the given environment.

    Builds the environment, a replay buffer, a sampler and a SAC agent with
    fixed hyperparameters, alternates trajectory collection with agent
    updates for a fixed number of epochs, evaluates the agent after every
    epoch and finally plots the evaluation returns against the step count.

    Args:
        env_name: Name of the environment; must be one of the keys of
            ``_EXPERT_ENVS`` (e.g. ``'InvertedPendulum-v2'``, ``'DMWalker'``,
            ``'Pusher-v2'``).

    Returns:
        The trained ``SAC`` agent.

    Raises:
        NotImplementedError: If ``env_name`` is not a supported environment.
    """
    expert_env, episode_limit = _make_expert_env(env_name)

    # Fixed training hyperparameters.
    buffer_size = 1000000
    init_random_samples = 1000
    exploration_noise = 0.2
    learning_rate = 1e-4
    batch_size = 128
    epochs = 150
    steps_per_epoch = 5000
    updates_per_step = 1
    update_actor_every = 1
    start_training = 512
    gamma = 0.99
    polyak = 0.995
    entropy_coefficient = 0.2
    clip_actor_gradients = False
    visual_env = True
    action_size = expert_env.action_space.shape[0]
    tune_entropy_coefficient = True
    # Target entropy is the negated action dimensionality.
    target_entropy = -1 * action_size

    def make_actor():
        # The final layer emits two values per action dimension.
        return StochasticActor([
            tf.keras.layers.Dense(256, 'relu'),
            tf.keras.layers.Dense(256, 'relu'),
            tf.keras.layers.Dense(action_size * 2),
        ])

    def make_critic():
        return Critic([
            tf.keras.layers.Dense(256, 'relu'),
            tf.keras.layers.Dense(256, 'relu'),
            tf.keras.layers.Dense(1),
        ])

    # NOTE(review): the same Adam instance is handed to SAC as both the actor
    # and the critic optimizer -- confirm that sharing is intended.
    optimizer = tf.keras.optimizers.Adam(learning_rate)

    replay_buffer = ReplayBuffer(buffer_size)
    sampler = Sampler(expert_env, episode_limit=episode_limit,
                      init_random_samples=init_random_samples,
                      visual_env=visual_env)
    agent = SAC(make_actor,
                make_critic,
                make_critic,
                actor_optimizer=optimizer,
                critic_optimizer=optimizer,
                gamma=gamma,
                polyak=polyak,
                entropy_coefficient=entropy_coefficient,
                tune_entropy_coefficient=tune_entropy_coefficient,
                target_entropy=target_entropy,
                clip_actor_gradients=clip_actor_gradients)

    # Call the agent once on a batched reset observation before summary();
    # presumably this builds the networks' weights -- confirm against SAC.
    obs = np.expand_dims(expert_env.reset(), axis=0)
    agent(obs)
    agent.summary()

    steps, mean_test_returns, mean_test_std = _run_training(
        agent, sampler, replay_buffer,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        start_training=start_training,
        exploration_noise=exploration_noise,
        batch_size=batch_size,
        updates_per_step=updates_per_step,
        update_actor_every=update_actor_every)

    # Learning curve: mean evaluation return per epoch with std error bars.
    plt.errorbar(steps, mean_test_returns, mean_test_std)
    plt.xlabel('steps')
    plt.ylabel('returns')
    plt.show()
    return agent
