from rlf import NestedAlgo
from rlf import PPO
from rlf.algos.il.gaifo import GaifoDiscrim
import rlf.il.utils as iutils
from collections import deque
import numpy as np
import rlf.rl.utils as rutils
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
import goal_prox.method.utils as mutils
import torch

class GoodGAIFO(NestedAlgo):
    """GAIfO variant that nests ``GoodGaifoDiscrim`` with a policy updater.

    Args:
        agent_updater: Updater placed after the discriminator in the nested
            algorithm. When omitted, a fresh ``PPO()`` is created for this
            instance.
        get_discrim: Forwarded unchanged to ``GoodGaifoDiscrim``.
    """
    def __init__(self, agent_updater=None, get_discrim=None):
        # A ``PPO()`` default argument would be constructed once at import
        # time and silently shared by every instance; build one per call.
        if agent_updater is None:
            agent_updater = PPO()
        super().__init__([GoodGaifoDiscrim(get_discrim), agent_updater], 1)

class GoodGaifoDiscrim(GaifoDiscrim):
    """GAIfO discriminator trained on the agent's own successes vs. failures.

    Finished agent trajectories are split with ``mutils.get_success``:
    transitions from successful trajectories are mixed with the expert
    dataset and used as the "expert" batch, while transitions from failed
    trajectories are used as the "agent" batch (see ``_trans_batches``).
    """

    def init(self, policy, args):
        """Initialize the parent discriminator and the transition buffers.

        ``args.exp_buff_size == -1`` sizes each buffer to
        ``args.num_processes * args.num_steps`` transitions.
        """
        super().init(policy, args)

        self.exp_buff_size = args.exp_buff_size
        if self.exp_buff_size == -1:
            self.exp_buff_size = args.num_processes * args.num_steps
        # Bounded FIFO buffers of per-transition dicts; once full, the
        # oldest transitions are discarded as new ones are appended.
        self.failure_agent_trajs = deque(maxlen=self.exp_buff_size)
        self.success_agent_trajs = deque(maxlen=self.exp_buff_size)

    def _get_sampler(self, storage):
        """Build batch samplers over the success and failure transition sets.

        Side effect: stores the tensorized sets on ``self.use_success_trajs``
        and ``self.use_failure_trajs`` so ``_trans_batches`` can index them.

        Args:
            storage: Unused; kept for the parent-class interface.

        Returns:
            ``(success_sampler, failure_sampler)``: ``BatchSampler``s that
            yield lists of indices into the respective transition sets.
        """
        take_count = self.exp_buff_size
        use_success_trajs = iutils.mix_data(self.success_agent_trajs,
                self.expert_dataset, take_count, self.args.exp_ratio)

        use_failure_trajs = self.failure_agent_trajs
        if len(use_failure_trajs) > take_count:
            # Subsample the failure buffer itself, without replacement.
            # Indices are drawn instead of passing the deque of dicts to
            # numpy directly. With maxlen == exp_buff_size this is defensive.
            pool = list(use_failure_trajs)
            keep = np.random.choice(len(pool), take_count, replace=False)
            use_failure_trajs = [pool[i] for i in keep]

        self.use_success_trajs = iutils.convert_list_dict(use_success_trajs,
                self.args.device)
        self.use_failure_trajs = iutils.convert_list_dict(use_failure_trajs,
                self.args.device)

        failure_sampler = BatchSampler(
                SubsetRandomSampler(
                    range(min(len(use_failure_trajs), take_count))),
                self.args.traj_batch_size, drop_last=True)
        # NOTE(review): assumes ``mix_data`` returns a sized sequence. Clamp
        # like the failure side so a short mix cannot yield bad indices.
        success_sampler = BatchSampler(
                SubsetRandomSampler(
                    range(min(len(use_success_trajs), take_count))),
                self.args.traj_batch_size, drop_last=True)

        return success_sampler, failure_sampler

    def _trans_batches(self, expert_batch, agent_batch):
        """Turn sampled index lists into transition batches.

        ``expert_batch`` indexes the success/expert mix and ``agent_batch``
        indexes the failure set built in ``_get_sampler``.
        """
        expert_batch = iutils.select_idx_from_dict(expert_batch,
                self.use_success_trajs)
        agent_batch = iutils.select_idx_from_dict(agent_batch,
                self.use_failure_trajs)
        return expert_batch, agent_batch

    def get_env_settings(self, args):
        """Extend the parent settings with the info keys this class reads.

        Adds ``'ep_found_goal'`` (shape ``(1,)``) and ``'final_obs'`` (the
        env's observation shape) to ``include_info_keys``.
        """
        settings = super().get_env_settings(args)
        settings.include_info_keys.extend([
                ('ep_found_goal', lambda _: (1,)),
                ('final_obs', lambda env: rutils.get_obs_shape(env.observation_space))
                ])
        return settings

    def on_traj_finished(self, trajs):
        """Split finished trajectories into the success / failure buffers.

        Each step becomes a dict with keys ``'state'``, ``'mask'``,
        ``'actions'`` and ``'next_state'``. Both buffers use the same keys.
        """
        super().on_traj_finished(trajs)
        obs, _, actions, masks, add_data, _ = iutils.traj_to_tensor(trajs,
                self.args.device)

        n_trajs = len(trajs)
        # Observation following the last step of each trajectory, reshaped
        # to (n_trajs, k, *rest) so it can be appended along dim 1.
        # NOTE(review): assumes 'final_obs' is (n_trajs, T, *obs_shape),
        # giving k == 1 -- confirm against traj_to_tensor.
        final_state = add_data['final_obs'][:, -1]
        final_state = final_state.view(n_trajs, -1, *final_state.shape[1:])
        # next_state[t] = obs[t + 1], with the final observation at the end.
        next_state = torch.cat([obs[:, 1:], final_state], dim=1)

        is_success, _ = mutils.get_success(add_data, masks)

        for i in range(n_trajs):
            target = (self.success_agent_trajs if is_success[i]
                    else self.failure_agent_trajs)
            target.extend([{'state': o, 'mask': m, 'actions': a, 'next_state': ns}
                    for o, m, a, ns in zip(obs[i], masks[i], actions[i], next_state[i])])

    def get_add_args(self, parser):
        """Register this class's command-line arguments on ``parser``.

        ``--exp-buff-size``: capacity of each transition buffer; ``-1`` means
        ``num_processes * num_steps`` (see ``init``).
        ``--exp-ratio``: ratio passed to ``iutils.mix_data`` when mixing
        successful transitions with expert data (presumably the expert
        fraction -- confirm in ``mix_data``).
        """
        super().get_add_args(parser)
        # Kept as -1 as GAIL uses on-policy data to make the update.
        parser.add_argument('--exp-buff-size', type=int, default=-1)
        parser.add_argument('--exp-ratio', type=float, default=0.5)
