import rlkit.torch.pytorch_util as ptu
from rlkit.data_management.env_replay_buffer import EnvReplayBuffer
from rlkit.envs.wrappers import NormalizedBoxEnv
from rlkit.launchers.launcher_util import setup_logger
from rlkit.samplers.data_collector import MdpPathCollector
from rlkit.torch.sac.policies import TanhGaussianPolicy, MakeDeterministic
from rlkit.torch.sac.sac import SACTrainer
from rlkit.torch.networks import ConcatMlp
from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
from huge import envs
from huge.algo import buffer, huge, variants, networks
import gym
import wandb


class UnWrapper(gym.Env):
    """Expose a goal-based environment as a fixed-goal, time-limited gym env.

    The wrapped environment is expected to provide ``sample_goal``,
    ``extract_goal``, ``observation``, ``compute_shaped_distance`` and
    ``goal_distance`` in addition to the usual ``reset``/``step``/``render``.
    One goal is sampled at construction time and kept for the lifetime of the
    wrapper; the reward returned by ``step`` is the negative shaped distance
    between the new observation and that goal.

    Attributes not defined here are looked up on the wrapped environment.
    """

    def __init__(self, env, max_path_legnth):
        """
        Args:
            env: The environment to wrap.
            max_path_legnth: Number of steps after which ``step`` reports
                ``done`` (the misspelled name is kept so existing keyword
                callers keep working).
        """
        super(UnWrapper, self).__init__()
        self._env = env

        self.state_space = self.observation_space

        # The goal is drawn once here and never resampled afterwards.
        self.goal = self._env.extract_goal(self._env.sample_goal())

        self.max_path_length = max_path_legnth
        self.current_timestep = 0

    def __getattr__(self, attr):
        # __getattr__ only runs when normal lookup fails. If ``_env`` itself
        # is missing (e.g. an instance created through ``__new__`` by
        # copy/pickle before ``__init__`` ran), delegating to ``self._env``
        # would recurse forever, so fail with a plain AttributeError instead.
        if attr == "_env":
            raise AttributeError(attr)
        return getattr(self._env, attr)

    @property
    def action_space(self):
        """Action space of the wrapped environment."""
        return self._env.action_space

    @property
    def observation_space(self):
        """Observation space of the wrapped environment."""
        return self._env.observation_space

    def compute_shaped_distance(self, state, goal):
        """Return the wrapped environment's shaped distance from state to goal."""
        return self._env.compute_shaped_distance(state, goal)

    def render(self):
        """Render the wrapped environment and pass through whatever it returns."""
        return self._env.render()

    def reset(self):
        """
        Resets the environment and returns a state vector
        Returns:
            The observation of the initial state
        """
        # Start a fresh episode: without this, an episode abandoned before
        # ``done`` would leak its step count into the next one and truncate it.
        self.current_timestep = 0
        return self._env.observation(self._env.reset())

    def step(self, a):
        """
        Runs 1 step of simulation
        Args:
            a: The action forwarded to the wrapped environment
        Returns:
            A tuple containing:
                next_state: observation of the state reached
                reward: negative shaped distance from next_state to self.goal
                done: True if the wrapped environment reported done or the
                    step counter reached max_path_length
                infos: the wrapped environment's info dict, with the reward
                    additionally stored under the key 'reward'
        """
        self.current_timestep += 1
        raw_state, _, env_done, info = self._env.step(a)
        new_state = self._env.observation(raw_state)

        # The wrapped environment's own reward is discarded on purpose.
        reward = -self._env.compute_shaped_distance(new_state, self.goal)
        info['reward'] = reward

        # ``>=`` rather than ``==`` so the limit still fires if the counter
        # ever gets past max_path_length.
        done = env_done or self.current_timestep >= self.max_path_length
        if done:
            self.current_timestep = 0

        return new_state, reward, done, info

    def observation(self, state):
        """
        Returns the observation for a given state
        Args:
            state: The state to convert, as understood by the wrapped env
        Returns:
            obs: The wrapped environment's observation for that state
        """
        return self._env.observation(state)

    def extract_goal(self, state):
        """
        Returns the goal representation for a given state
        Args:
            state: The state to convert, as understood by the wrapped env
        Returns:
            goal: The wrapped environment's goal representation of that state
        """
        return self._env.extract_goal(state)

    def goal_distance(self, state, ):
        """Return the wrapped environment's goal distance from state to self.goal."""
        return self._env.goal_distance(state, self.goal)

    def sample_goal(self):
        """Return the fixed goal chosen at construction (no new goal is sampled)."""
        return self.goal


def experiment(variant, env_name):
    """Build the environments, networks and SAC trainer, then run training.

    Two independent copies of the environment are created (one for
    exploration, one for evaluation), each wrapped in ``UnWrapper`` and
    ``NormalizedBoxEnv``. Twin Q-functions, their targets and a
    ``TanhGaussianPolicy`` are built, and ``TorchBatchRLAlgorithm.train`` is
    run.

    Args:
        variant: Configuration dict. The keys read here are
            ``'replay_buffer_size'``, ``'trainer_kwargs'`` (forwarded to
            ``SACTrainer``) and ``'algorithm_kwargs'`` (forwarded to
            ``TorchBatchRLAlgorithm``).
        env_name: Environment name passed to ``envs.create_env`` and
            ``envs.get_env_params``.
    """
    # Episode horizon enforced by UnWrapper.
    # NOTE(review): this is independent of variant['algorithm_kwargs']
    # ['max_path_length']; confirm the two are meant to differ.
    max_path_length = 50
    task_config = ""
    num_blocks = 0
    hidden_sizes = [400, 600, 600, 300]

    env = envs.create_env(env_name, task_config, num_blocks)
    env_params = envs.get_env_params(env_name)
    env_params.update({
        'max_trajectory_length': max_path_length,
        'network_layers': "60,60",  # TODO: useless
        'reward_model_name': '',
        'buffer_size': 1000,
        'fourier': True,
        'fourier_goal_selector': True,
        'normalize': False,
        'env_name': env_name,
        'goal_selector_name': "",
    })

    # Only the wrapped environment returned by get_params_ddl is used; the
    # policy, reward model, buffers and kwargs it also returns are discarded.
    wrapped_env, *_ = variants.get_params_ddl(env, env_params)
    unwrapped_env = UnWrapper(wrapped_env, max_path_length)

    env2 = envs.create_env(env_name, task_config, num_blocks)
    wrapped_env2, *_ = variants.get_params_ddl(env2, env_params)
    unwrapped_env2 = UnWrapper(wrapped_env2, max_path_length)

    # NOTE(review): each UnWrapper samples its own goal at construction, so
    # the exploration and evaluation goals may differ if sample_goal() is
    # random -- confirm this is intended.
    expl_env = NormalizedBoxEnv(unwrapped_env)
    eval_env = NormalizedBoxEnv(unwrapped_env2)

    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size

    # Built in the order qf1, qf2, target_qf1, target_qf2 (the generator is
    # consumed left to right), each with its own copy of the layer sizes.
    qf1, qf2, target_qf1, target_qf2 = (
        ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            hidden_sizes=list(hidden_sizes),
        )
        for _ in range(4)
    )
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=list(hidden_sizes),
    )
    eval_policy = MakeDeterministic(policy)

    eval_path_collector = MdpPathCollector(
        eval_env,
        eval_policy,
    )
    expl_path_collector = MdpPathCollector(
        expl_env,
        policy,
    )
    replay_buffer = EnvReplayBuffer(
        variant['replay_buffer_size'],
        expl_env,
    )
    trainer = SACTrainer(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **variant['trainer_kwargs']
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algorithm_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()


import argparse

if __name__ == "__main__":
    # Integer command-line options, in registration order, with their
    # defaults. Every one of them is forwarded verbatim as an entry of
    # variant['algorithm_kwargs'].
    int_options = (
        ("num_epochs", 3000),
        ("num_eval_steps_per_epoch", 1000),
        ("num_trains_per_train_loop", 1000),
        ("num_expl_steps_per_train_loop", 1000),
        ("min_num_steps_before_training", 1000),
        ("max_path_length", 1000),
        ("batch_size", 256),
    )

    cli = argparse.ArgumentParser()
    cli.add_argument("--env_name", type=str, default='pointmass_empty')
    for option_name, option_default in int_options:
        cli.add_argument(f"--{option_name}", type=int, default=option_default)
    args = cli.parse_args()

    algorithm_kwargs = {
        option_name: getattr(args, option_name)
        for option_name, _ in int_options
    }
    trainer_kwargs = {
        "discount": 0.99,
        "soft_target_tau": 5e-3,
        "target_update_period": 1,
        "policy_lr": 3e-4,
        "qf_lr": 3e-4,
        "reward_scale": 1,
        "use_automatic_entropy_tuning": True,
    }
    # noinspection PyTypeChecker
    variant = {
        "algorithm": "SAC",
        "version": "normal",
        "layer_size": 256,
        "replay_buffer_size": 1_000_000,
        "algorithm_kwargs": algorithm_kwargs,
        "trainer_kwargs": trainer_kwargs,
    }

    # The wandb run is started with an empty config; the variant is only
    # handed to rlkit's logger below.
    wandb.init(
        project=f"{args.env_name}gcsl_preferences",
        name=f"{args.env_name}_sac_0",
        config={},
    )

    setup_logger('name-of-experiment', variant=variant)
    # To train on GPU, call ptu.set_gpu_mode(True) here (it is off by default).
    experiment(variant, args.env_name)