from rlf.algos.il.base_il import ExperienceGenerator
from goal_prox.policies.grid_world_expert import GridWorldExpert
import gym
from goal_prox.envs.gw_helper import *
from goal_prox.policies.grid_world_expert import GridWorldExpert
import torch

class GwExpGenerator(ExperienceGenerator):
    """Generate expert (state, action) transitions from a grid-world expert.

    The environment is ``gym.make(args.env_name)`` wrapped in
    ``FullyObsWrapper`` and then ``GoalCheckerWrapper``; actions come from a
    ``GridWorldExpert``. At most ``exp_gen_num_trans`` transitions are emitted
    between calls to :meth:`reset`, after which :meth:`get_batch` returns
    ``None``.
    """

    def init(self, policy, args, exp_gen_num_trans):
        """Build the environment and expert and reset the transition budget.

        Args:
            policy: Not used by this generator. NOTE(review): presumably
                accepted to match the ``ExperienceGenerator`` ``init``
                interface -- confirm.
            args: Config namespace. ``env_name`` and ``traj_batch_size`` are
                read here; it is also forwarded to ``GoalCheckerWrapper`` and
                ``GridWorldExpert.init``.
            exp_gen_num_trans: Maximum number of transitions to emit before
                the generator reports exhaustion.
        """
        self.args = args
        self.env = GoalCheckerWrapper(FullyObsWrapper(gym.make(args.env_name)), args)
        self.expert = GridWorldExpert()
        self.expert.init(self.env.observation_space, self.env.action_space,
                args)
        self.obs = self.env.reset()
        # Mask handed to the expert: 0.0 right after an episode ended, else 1.0.
        self.masks = torch.ones(1)
        self.num_trans = exp_gen_num_trans
        self.cur_trans = 0

    def _trans_obs(self, obs):
        """Move the last axis of ``obs`` to the front and add a leading batch axis.

        Presumably turns an (H, W, C) grid observation into a (1, C, H, W)
        float tensor -- TODO confirm the layout produced by FullyObsWrapper.
        """
        return torch.FloatTensor(np.expand_dims(obs.transpose(2, 0, 1), axis=0))

    def _generate_exp(self):
        """Step the environment once using the expert's action.

        Returns:
            ``(state, action)``, where ``state`` is the transformed observation
            the expert acted on and ``action`` is the expert output's
            ``.action``. Returns ``None`` once ``num_trans`` transitions have
            been produced since the last :meth:`reset`; in that case the
            environment is not stepped.
        """
        # Check the budget *before* stepping. Stepping first and then
        # discarding the transition would leave ``self.obs`` out of sync with
        # the environment once the budget is reset. ``>=`` (not ``==``) keeps
        # extra calls past the limit from silently resuming generation.
        if self.cur_trans >= self.num_trans:
            return None

        state = self._trans_obs(self.obs)
        ac = self.expert.get_action(state, None, self.masks, None)
        obs, _, done, _ = self.env.step(ac.take_action)
        self.masks = torch.FloatTensor([not done])
        if done:
            obs = self.env.reset()
        self.obs = obs
        self.cur_trans += 1

        # Label the action with the observation it was chosen for, not with
        # the observation that resulted from taking it (or the reset
        # observation after an episode ends).
        return state, ac.action

    def get_batch(self):
        """Collect ``args.traj_batch_size`` expert transitions.

        Returns:
            A dict with ``"state"`` (states stacked along a new first axis,
            with the singleton axis added by ``_trans_obs`` squeezed out) and
            ``"actions"`` (stacked expert actions). Returns ``None`` if the
            transition budget runs out before the batch is full; any
            partially collected batch is discarded.
        """
        exp = []
        for _ in range(self.args.traj_batch_size):
            trans = self._generate_exp()
            if trans is None:
                # Budget exhausted: stop instead of calling the generator
                # again for the rest of the batch.
                return None
            exp.append(trans)
        states, actions = zip(*exp)

        return {
                "state": torch.stack(states).squeeze(1),
                "actions": torch.stack(actions),
                }

    def reset(self):
        """Restart the transition budget. The environment itself is not reset."""
        self.cur_trans = 0
