import numpy as np
import torch
class DiffSampler(object):
    """Roll a batch of states forward through an environment's ``multi_step``.

    Each iteration queries the policy for the whole batch, advances the batch
    with ``env.multi_step`` and then keeps only the entries that have not
    terminated.
    """

    def __init__(self, env, max_traj_length=1000):
        self._env = env
        self._traj_steps = 0
        self.max_traj_length = max_traj_length

    def sample(self, policy, init_obss, replay_buffer=None, iql=False, threshold=0):
        """Advance ``init_obss`` until every entry terminates or the step cap is hit.

        Args:
            policy: Callable invoked as ``policy(obs, input_is_torch=True)``.
            init_obss: Starting batch of observations (presumably a torch
                tensor, given ``input_is_torch=True`` -- confirm with callers).
            replay_buffer: Optional sink; when given, its ``add_traj`` receives
                every batched transition, including the terminating ones.
            iql: Accepted for interface compatibility; not used here.
            threshold: Forwarded to ``env.multi_step`` as ``std_threshold``.

        Returns:
            None. Results are only observable through ``replay_buffer``.
        """
        batch = init_obss
        store = replay_buffer.add_traj if replay_buffer is not None else None

        for _step in range(self.max_traj_length):
            chosen = policy(batch, input_is_torch=True)
            successors, step_rewards, finished, _info = self._env.multi_step(
                batch, chosen, std_threshold=threshold
            )

            if store is not None:
                store(batch, chosen, step_rewards, successors, finished)

            # Boolean mask over the batch: True where the entry is still running.
            still_running = (~finished).flatten()
            if still_running.sum() == 0:
                return

            # Only surviving entries are carried into the next iteration.
            batch = successors[still_running]


class StepSampler(object):
    """Collect a fixed number of transitions from a single environment.

    The sampler keeps the current observation and the per-episode step count
    between calls to :meth:`sample`, so consecutive calls continue the episode
    that is in progress. The environment is reset on construction.
    """

    def __init__(self, env, max_traj_length=1000):
        self.max_traj_length = max_traj_length
        self._env = env
        self._traj_steps = 0
        self._current_observation = self.env.reset()

    def sample(self, policy, n_steps, deterministic=False, replay_buffer=None, iql=False):
        """Step the environment ``n_steps`` times and return the transitions.

        Args:
            policy: When ``iql`` is false, a callable invoked as
                ``policy(obs[None], deterministic=...)`` whose first row is
                used as the action. When ``iql`` is true, an object whose
                ``act(tensor, deterministic=...)`` returns a torch tensor.
            n_steps: Number of environment steps to take.
            deterministic: Forwarded to the policy.
            replay_buffer: Optional sink; ``add_sample`` is called once per
                transition when provided.
            iql: Selects the torch ``policy.act`` code path.

        Returns:
            dict of float32 arrays keyed ``observations``, ``actions``,
            ``rewards``, ``next_observations`` and ``dones``.
        """
        observations = []
        actions = []
        rewards = []
        next_observations = []
        dones = []

        device = None
        if iql:
            # Prefer the first GPU as before, but stay usable on CPU-only hosts.
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        for _ in range(n_steps):
            self._traj_steps += 1
            observation = self._current_observation
            if iql:
                tensor_observation = torch.tensor(
                    observation, dtype=torch.float32, device=device
                )
                action = policy.act(tensor_observation, deterministic=deterministic).cpu().numpy()
            else:
                action = policy(
                    np.expand_dims(observation, 0), deterministic=deterministic
                )[0, :]
            next_observation, reward, done, info = self.env.step(action)

            # A truthy "diff_timeout" flag makes the stored done False while
            # the raw `done` still drives the reset below (presumably a
            # time-limit truncation -- confirm against the environment).
            # Environments that never report the key are treated as not
            # timed out instead of raising KeyError.
            stored_done = False if info.get("diff_timeout", False) else done

            observations.append(observation)
            actions.append(action)
            rewards.append(reward)
            dones.append(stored_done)
            next_observations.append(next_observation)

            if replay_buffer is not None:
                replay_buffer.add_sample(
                    observation, action, reward, next_observation, stored_done
                )

            self._current_observation = next_observation

            # Start a new episode on a real `done` or when the step cap is hit.
            if done or self._traj_steps >= self.max_traj_length:
                self._traj_steps = 0
                self._current_observation = self.env.reset()

        return dict(
            observations=np.array(observations, dtype=np.float32),
            actions=np.array(actions, dtype=np.float32),
            rewards=np.array(rewards, dtype=np.float32),
            next_observations=np.array(next_observations, dtype=np.float32),
            dones=np.array(dones, dtype=np.float32),
        )

    @property
    def env(self):
        """The wrapped environment."""
        return self._env


class TrajSampler(object):
    """Collect whole episodes from an environment, one trajectory at a time.

    When a truthy ``normalizer`` is supplied, ``normalizer['observations']``
    is applied to each observation before it is handed to the policy; the
    recorded observations are the raw ones returned by the environment.
    """

    _FIELDS = ('observations', 'actions', 'rewards', 'next_observations', 'dones')

    def __init__(self, env, max_traj_length=1000, normalizer=None):
        self._env = env
        self.normalizer = normalizer
        self.max_traj_length = max_traj_length

    @property
    def env(self):
        """The wrapped environment."""
        return self._env

    def sample(self, policy, n_trajs, deterministic=False, replay_buffer=None):
        """Run ``n_trajs`` episodes and return one dict of arrays per episode.

        Args:
            policy: Callable invoked as ``policy(obs[None], deterministic=...)``;
                the first row of its result is used as the action.
            n_trajs: Number of episodes to run.
            deterministic: Forwarded to the policy.
            replay_buffer: Optional sink; ``add_sample`` is called once per
                transition when provided.

        Returns:
            list of dicts of float32 arrays keyed ``observations``,
            ``actions``, ``rewards``, ``next_observations`` and ``dones``.
        """
        return [
            self._run_episode(policy, deterministic, replay_buffer)
            for _episode in range(n_trajs)
        ]

    def _run_episode(self, policy, deterministic, replay_buffer):
        """Play one episode, capped at ``max_traj_length`` steps."""
        columns = {field: [] for field in self._FIELDS}
        current = self.env.reset()

        for _step in range(self.max_traj_length):
            # The policy sees the normalized observation; the raw one is stored.
            policy_input = (
                self.normalizer['observations'](current) if self.normalizer else current
            )
            batched_action = policy(
                np.expand_dims(policy_input, 0), deterministic=deterministic
            )
            chosen = batched_action[0, :]

            successor, gain, finished, _info = self.env.step(chosen)

            transition = (current, chosen, gain, successor, finished)
            for field, value in zip(self._FIELDS, transition):
                columns[field].append(value)

            if replay_buffer is not None:
                replay_buffer.add_sample(current, chosen, gain, successor, finished)

            current = successor
            if finished:
                break

        return {
            field: np.array(values, dtype=np.float32)
            for field, values in columns.items()
        }
