from rlf import GailDiscrim, NestedAlgo
from rlf import PPO
import rlf.il.utils as iutils
from collections import deque
import numpy as np
import rlf.rl.utils as rutils
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
import goal_prox.method.utils as mutils

class GoodGAIL(NestedAlgo):
    """GAIL variant combining a ``GoodGailDiscrim`` with an RL agent updater.

    The discriminator and the agent updater are passed to ``NestedAlgo`` as a
    two-element list, discriminator first.
    """

    def __init__(self, agent_updater=None, get_discrim=None):
        """
        Args:
            agent_updater: RL algorithm used to update the policy. When
                omitted (or ``None``), a fresh ``PPO()`` is created for this
                instance. A ``PPO()`` default argument would be evaluated only
                once and shared by every ``GoodGAIL`` built without an
                explicit updater.
            get_discrim: forwarded unchanged to ``GoodGailDiscrim``.
        """
        if agent_updater is None:
            agent_updater = PPO()
        # NOTE(review): the trailing ``1`` is interpreted by NestedAlgo
        # (presumably the index of the RL updater in the list) — confirm.
        super().__init__([GoodGailDiscrim(get_discrim), agent_updater], 1)

class GoodGailDiscrim(GailDiscrim):
    """GAIL discriminator trained on successful vs. failed agent transitions.

    Transitions from agent trajectories judged successful (see
    ``on_traj_finished``) are mixed with expert data and used as the
    discriminator's expert side. Transitions from all other trajectories are
    used as its agent side (see ``_trans_batches``).
    """

    def init(self, policy, args):
        """Initialize the base discriminator and the two transition buffers.

        Both buffers are ``deque``s bounded to ``exp_buff_size`` per-step
        transition dicts. Once full, appending evicts the oldest entries.
        """
        super().init(policy, args)

        self.exp_buff_size = args.exp_buff_size
        if self.exp_buff_size == -1:
            # Default: num_processes * num_steps transitions (presumably one
            # on-policy rollout's worth).
            self.exp_buff_size = args.num_processes * args.num_steps
        self.failure_agent_trajs = deque(maxlen=self.exp_buff_size)
        self.success_agent_trajs = deque(maxlen=self.exp_buff_size)

    def _get_sampler(self, storage):
        """Build index samplers over the success and failure transitions.

        Side effect: stores the device-converted data as
        ``self.use_success_trajs`` and ``self.use_failure_trajs`` for use by
        ``_trans_batches``. ``storage`` is not used here.

        Returns:
            ``(success_sampler, failure_sampler)``: ``BatchSampler``s that
            yield lists of ``args.traj_batch_size`` random indices. Incomplete
            final batches are dropped.
        """
        take_count = self.exp_buff_size
        # Success transitions are blended with expert data; the expert share
        # is controlled by --exp-ratio (interpreted by iutils.mix_data).
        use_success_trajs = iutils.mix_data(self.success_agent_trajs,
                self.expert_dataset, take_count, self.args.exp_ratio)

        use_failure_trajs = self.failure_agent_trajs
        if len(use_failure_trajs) > take_count:
            # Subsample from the buffer itself (the previously cached,
            # converted data is not a list of transitions). Sample indices
            # rather than the dicts so numpy never builds an object array.
            # Note: with maxlen == take_count this branch is defensive only.
            pool = list(use_failure_trajs)
            keep_idx = np.random.choice(len(pool), take_count, replace=False)
            use_failure_trajs = [pool[i] for i in keep_idx]

        self.use_success_trajs = iutils.convert_list_dict(use_success_trajs,
                self.args.device)
        self.use_failure_trajs = iutils.convert_list_dict(use_failure_trajs,
                self.args.device)

        # Bound both index ranges by the data actually available so neither
        # sampler can emit an out-of-range index.
        # NOTE(review): assumes mix_data returns a sized sequence — confirm.
        failure_sampler = BatchSampler(SubsetRandomSampler(
                range(min(len(use_failure_trajs), take_count))),
                self.args.traj_batch_size, drop_last=True)
        success_sampler = BatchSampler(SubsetRandomSampler(
                range(min(len(use_success_trajs), take_count))),
                self.args.traj_batch_size, drop_last=True)

        return success_sampler, failure_sampler

    def _trans_batches(self, expert_batch, agent_batch):
        """Map sampled index batches to the cached transition data.

        ``expert_batch`` holds indices from the success sampler and selects
        from ``self.use_success_trajs``. ``agent_batch`` holds indices from
        the failure sampler and selects from ``self.use_failure_trajs``.
        """
        expert_batch = iutils.select_idx_from_dict(expert_batch,
                self.use_success_trajs)
        agent_batch = iutils.select_idx_from_dict(agent_batch,
                self.use_failure_trajs)
        return expert_batch, agent_batch

    def get_env_settings(self, args):
        """Extend the base env settings with the info keys this class needs.

        Adds ``ep_found_goal`` (shape ``(1,)``) and ``final_obs`` (shape of
        the env's observation space). Presumably these are what
        ``mutils.get_success`` reads from ``add_data`` — confirm.
        """
        settings = super().get_env_settings(args)
        settings.include_info_keys.extend([
                ('ep_found_goal', lambda _: (1,)),
                ('final_obs', lambda env: rutils.get_obs_shape(env.observation_space))
                ])
        return settings

    def on_traj_finished(self, trajs):
        """Route each finished trajectory's transitions into a buffer.

        Trajectories flagged successful by ``mutils.get_success`` go to
        ``success_agent_trajs``, unless ``--ignore-success`` is set. All
        others go to ``failure_agent_trajs``. One dict is stored per step.
        """
        super().on_traj_finished(trajs)
        obs, _, actions, masks, add_data, _ = iutils.traj_to_tensor(trajs,
                self.args.device)

        is_success, _ = mutils.get_success(add_data, masks)

        for i in range(len(trajs)):
            steps = zip(obs[i], masks[i], actions[i])
            if is_success[i] and not self.args.ignore_success:
                # Keyed 'actions': these dicts are mixed with expert data and
                # used as the expert batch (presumably matching the expert
                # dataset's key — confirm).
                self.success_agent_trajs.extend(
                        {'state': o, 'mask': m, 'actions': a}
                        for o, m, a in steps)
            else:
                # Keyed 'action': these dicts are used as the agent batch.
                self.failure_agent_trajs.extend(
                        {'state': o, 'mask': m, 'action': a}
                        for o, m, a in steps)

    def get_add_args(self, parser):
        """Register this discriminator's command-line options."""
        super().get_add_args(parser)
        # -1 means num_processes * num_steps, because GAIL updates on
        # on-policy data.
        parser.add_argument('--exp-buff-size', type=int, default=-1)
        # Expert-data ratio forwarded to iutils.mix_data.
        parser.add_argument('--exp-ratio', type=float, default=0.5)
        # If set, every trajectory is treated as a failure.
        parser.add_argument('--ignore-success', action='store_true')
