"""
This should results in an average return of ~3000 by the end of training.
Usually hits 3000 around epoch 80-100. Within a see, the performance will be
a bit noisy from one epoch to the next (occasionally dips dow to ~2000).
Note that one epoch = 5k steps, so 200 epochs = 1 million steps.
"""
from gym.envs.mujoco import HalfCheetahEnv

import rlkit.torch.pytorch_util as ptu
from rlkit.data_management.env_replay_buffer import EnvReplayBuffer
from rlkit.envs.wrappers import NormalizedBoxEnv
from rlkit.exploration_strategies.base import \
    PolicyWrappedWithExplorationStrategy
from rlkit.exploration_strategies.gaussian_strategy import GaussianStrategy
from rlkit.launchers.launcher_util import setup_logger
from rlkit.samplers.data_collector import MdpPathCollector
from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy
from rlkit.torch.td3.td3 import TD3Trainer
from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm

from huge import envs
from huge.algo import buffer, huge, variants, networks
import gym
import wandb

class UnWrapper(gym.Env):
    """Expose a goal-conditioned ``huge`` env as a plain single-goal gym env.

    One goal is sampled from the wrapped env at construction time and kept
    fixed. Every step is rewarded with the negative shaped distance between
    the new state and that goal. Episodes end after ``max_path_length`` steps,
    or earlier if the wrapped env reports ``done``.

    Attributes not defined on this class are forwarded to the wrapped env.
    """

    def __init__(self, env, max_path_legnth):
        """
        Args:
            env: The wrapped env. It is used through ``observation``,
                ``extract_goal``, ``sample_goal``, ``compute_shaped_distance``,
                ``goal_distance``, ``reset``, ``step`` and ``render``.
            max_path_legnth: Maximum number of steps per episode. The
                misspelled name is kept for compatibility with keyword callers.
        """
        super(UnWrapper, self).__init__()
        self._env = env

        self.state_space = self.observation_space

        # A single goal is fixed for the lifetime of this wrapper.
        self.goal = self._env.extract_goal(self._env.sample_goal())

        self.max_path_length = max_path_legnth
        self.current_timestep = 0

    def __getattr__(self, attr):
        """Forward unknown attribute lookups to the wrapped env."""
        # __getattr__ runs only after normal lookup fails. If ``_env`` itself
        # is missing (e.g. pickle/copy probing ``__setstate__`` on an instance
        # whose __init__ never ran), ``self._env`` below would re-enter this
        # method forever. Fail with a normal AttributeError instead.
        if attr == '_env':
            raise AttributeError(attr)
        return getattr(self._env, attr)

    @property
    def action_space(self, ):
        return self._env.action_space

    @property
    def observation_space(self, ):
        return self._env.observation_space

    def compute_shaped_distance(self, state, goal):
        return self._env.compute_shaped_distance(state, goal)

    def render(self):
        self._env.render()

    def reset(self):
        """
        Resets the wrapped environment and the episode step counter.
        Returns:
            The initial state, as produced by the wrapped env's ``observation``
        """
        # A rollout can be cut short externally (e.g. by the path collector's
        # own step budget) without ``done`` ever being returned. Clear the
        # counter here so the next episode gets its full length.
        self.current_timestep = 0
        return self._env.observation(self._env.reset())

    def step(self, a):
        """
        Runs 1 step of simulation
        Returns:
            A tuple containing:
                next_state
                reward: negative shaped distance from next_state to self.goal
                    (the wrapped env's own reward is discarded)
                done: True if the wrapped env is done or max_path_length
                    steps have been taken in this episode
                infos: the wrapped env's info dict, with 'reward' added
        """
        self.current_timestep += 1
        new_state, _, done, info = self._env.step(a)
        new_state = self._env.observation(new_state)
        reward = -self._env.compute_shaped_distance(new_state, self.goal)
        info['reward'] = reward
        # ``>=`` rather than ``==`` so the limit can never be stepped past.
        done = done or self.current_timestep >= self.max_path_length
        if done:
            self.current_timestep = 0

        return new_state, reward, done, info

    def observation(self, state):
        """
        Returns the observation for a given state
        Args:
            state: A numpy array representing state
        Returns:
            obs: A numpy array representing observations
        """
        return self._env.observation(state)

    def extract_goal(self, state):
        """
        Returns the goal representation for a given state
        Args:
            state: A numpy array representing state
        Returns:
            The goal representation computed by the wrapped env
        """
        return self._env.extract_goal(state)

    def goal_distance(self, state, ):
        """Distance (per the wrapped env) between ``state`` and the fixed goal."""
        return self._env.goal_distance(state, self.goal)

    def sample_goal(self):
        """Always return the single goal fixed at construction time."""
        return self.goal



def experiment(variant, env_name, max_path_length=50):
    """Train rlkit TD3 on a fixed-goal version of a ``huge`` environment.

    Two env instances are built, one for exploration and one for evaluation.
    Each is wrapped by ``variants.get_params_ddl``, converted to a fixed-goal
    env by ``UnWrapper`` and normalized with ``NormalizedBoxEnv``.

    Args:
        variant: Config dict with keys ``qf_kwargs``, ``policy_kwargs``,
            ``trainer_kwargs``, ``algorithm_kwargs`` and ``replay_buffer_size``.
        env_name: Name passed to ``envs.create_env`` / ``envs.get_env_params``.
        max_path_length: Episode length enforced by ``UnWrapper``, also stored
            as ``max_trajectory_length`` in the env params. Defaults to 50.
    """
    task_config = ""
    num_blocks = 0

    env = envs.create_env(env_name, task_config, num_blocks)
    env_params = envs.get_env_params(env_name)
    env_params.update(
        max_trajectory_length=max_path_length,
        network_layers="60,60",  # TODO: useless
        reward_model_name='',
        buffer_size=1000,
        fourier=True,
        fourier_goal_selector=True,
        normalize=False,
        env_name=env_name,
        goal_selector_name="",
    )

    def _to_single_goal_env(raw_env):
        """Wrap ``raw_env`` via get_params_ddl and fix one goal with UnWrapper."""
        # get_params_ddl also returns a policy, reward model, buffers and GCSL
        # kwargs. TD3 builds its own networks, so only the wrapped env is kept.
        wrapped_env, _, _, _, _, _ = variants.get_params_ddl(raw_env, env_params)
        return UnWrapper(wrapped_env, max_path_length)

    unwrapped_env = _to_single_goal_env(env)
    unwrapped_env2 = _to_single_goal_env(
        envs.create_env(env_name, task_config, num_blocks))
    # The TD3 policy below only sees observations, so it must be evaluated on
    # the goal it is trained for. Each UnWrapper samples its own goal; if
    # sample_goal is stochastic the two would otherwise differ.
    unwrapped_env2.goal = unwrapped_env.goal

    expl_env = NormalizedBoxEnv(unwrapped_env)
    eval_env = NormalizedBoxEnv(unwrapped_env2)
    obs_dim = expl_env.observation_space.low.size
    action_dim = expl_env.action_space.low.size

    def _make_qf():
        """Q-network over the concatenated (observation, action) -> 1 output."""
        return ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )

    def _make_policy():
        """Policy network mapping observations -> actions."""
        return TanhMlpPolicy(
            input_size=obs_dim,
            output_size=action_dim,
            **variant['policy_kwargs']
        )

    # Same construction order as before (four Q-nets, then policy and target
    # policy), so weight initialisation consumes the RNG in the same sequence.
    qf1, qf2, target_qf1, target_qf2 = [_make_qf() for _ in range(4)]
    policy = _make_policy()
    target_policy = _make_policy()

    # Constant-sigma Gaussian exploration noise (max_sigma == min_sigma).
    es = GaussianStrategy(
        action_space=expl_env.action_space,
        max_sigma=0.1,
        min_sigma=0.1,
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    # Evaluation rolls out the policy without the exploration wrapper.
    eval_path_collector = MdpPathCollector(eval_env, policy)
    expl_path_collector = MdpPathCollector(expl_env, exploration_policy)
    replay_buffer = EnvReplayBuffer(variant['replay_buffer_size'], expl_env)
    trainer = TD3Trainer(
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        target_policy=target_policy,
        **variant['trainer_kwargs']
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algorithm_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()

import argparse
if __name__ == "__main__":
    # Script entry point: parse CLI flags into an rlkit ``variant``, start a
    # wandb run and the rlkit logger, then train.
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_name", type=str, default='pointmass_empty')
    parser.add_argument("--num_epochs", type=int, default=3000)
    parser.add_argument("--num_eval_steps_per_epoch", type=int, default=1000)
    parser.add_argument("--num_trains_per_train_loop", type=int, default=1000)
    parser.add_argument("--num_expl_steps_per_train_loop", type=int, default=1000)
    parser.add_argument("--min_num_steps_before_training", type=int, default=1000)
    # NOTE: this only caps rlkit's rollouts. experiment() also ends episodes
    # inside UnWrapper after its own max_path_length (50 by default), so the
    # effective episode length is presumably the smaller of the two.
    parser.add_argument("--max_path_length", type=int, default=1000)
    parser.add_argument("--batch_size", type=int, default=256)
    args = parser.parse_args()

    variant = dict(
        algorithm_kwargs=dict(
            num_epochs=args.num_epochs,
            num_eval_steps_per_epoch=args.num_eval_steps_per_epoch,
            num_trains_per_train_loop=args.num_trains_per_train_loop,
            num_expl_steps_per_train_loop=args.num_expl_steps_per_train_loop,
            min_num_steps_before_training=args.min_num_steps_before_training,
            max_path_length=args.max_path_length,
            batch_size=args.batch_size,
        ),
        trainer_kwargs=dict(
            discount=0.99,
        ),
        qf_kwargs=dict(
            hidden_sizes=[400, 600, 600, 300],
        ),
        policy_kwargs=dict(
            hidden_sizes=[400, 600, 600, 300],
        ),
        replay_buffer_size=int(1E6),
    )
    # Record the run's hyperparameters with wandb so runs are comparable.
    wandb.init(
        project=args.env_name + "gcsl_preferences",
        name=f"{args.env_name}_td3_0",
        config={"env_name": args.env_name, **variant},
    )
    # ptu.set_gpu_mode(True)  # optionally set the GPU (default=False)
    setup_logger('rlkit-post-refactor-td3-half-cheetah', variant=variant)
    experiment(variant, args.env_name)