import gym
import torch
import numpy as np


class EnvSampler:
    """Step an environment one transition at a time on behalf of an agent.

    The sampler carries the in-progress episode's observation across calls to
    :meth:`sample`, resets the environment lazily when no episode is active,
    and appends each finished episode's summed reward to ``path_rewards``.
    Observations are passed through ``agent.embedder`` before action
    selection; both embedded and raw states are returned to the caller.
    """

    def __init__(self, env, max_path_length=1000):
        # Wrapped environment; used via reset() and step(action, eval=...),
        # where step returns a (next_state, reward, terminal) triple.
        self.env = env
        # Episode length cap; reaching it ends the episode even if not terminal.
        self.max_path_length = max_path_length

        # Per-episode bookkeeping.
        self.current_state = None  # batched observation, or None between episodes
        self.path_length = 0       # steps taken in the current episode
        self.sum_reward = 0        # reward accumulated in the current episode
        # Summed reward of each completed episode, in completion order.
        self.path_rewards = []

    @staticmethod
    def _embed(agent, batch):
        """Run ``agent.embedder`` on *batch* without autograd; return a NumPy array."""
        with torch.no_grad():
            as_tensor = torch.tensor(batch, dtype=torch.float32)
            return np.asarray(agent.embedder(as_tensor))

    def _finish_path(self):
        """Close the current episode: clear per-episode state and log its reward."""
        self.current_state = None
        self.path_length = 0
        self.path_rewards.append(self.sum_reward)
        self.sum_reward = 0

    def sample(self, agent, eval_t=False):
        """Advance the environment by one step using *agent*'s policy.

        Args:
            agent: object providing ``embedder`` (called on a float32 tensor
                holding a one-row batch of observations) and
                ``select_action(emb_state, eval_t)``, whose element 0 is
                used as the action.
            eval_t: forwarded to ``agent.select_action`` and, as the ``eval``
                keyword, to ``env.step``.

        Returns:
            Tuple ``(emb_state, action, reward, emb_next_state, terminal,
            raw_state)``. The three state entries are row 0 of the embedded
            current state, the embedded next state, and the raw current
            observation, respectively.
        """
        if self.current_state is None:
            # No episode in progress: start one, wrapping the observation
            # in a leading batch dimension of size 1.
            self.current_state = np.asarray([self.env.reset()])

        obs = self.current_state
        obs_emb = self._embed(agent, obs)

        action = agent.select_action(obs_emb, eval_t)[0]
        raw_next, reward, terminal = self.env.step(action, eval=eval_t)
        next_obs = np.asarray([raw_next])
        next_emb = self._embed(agent, next_obs)

        self.path_length += 1
        self.sum_reward += reward

        # TODO: Save the path to the env_pool
        hit_cap = self.path_length >= self.max_path_length
        if terminal or hit_cap:
            # The env signalled termination, or the length cap was reached.
            self._finish_path()
        else:
            self.current_state = next_obs

        return obs_emb[0], action, reward, next_emb[0], terminal, obs[0]
