import numpy as np

from rlkit.samplers.util import rollout
from rlkit.torch.sac.policies import MakeDeterministic


class InPlacePathSampler(object):
    """
    A sampler that does not do any serialization for sampling. Instead, it
    just uses the current policy and environment as-is.

    WARNING: This will affect the environment! So
    ```
    sampler = InPlacePathSampler(env, ...)
    sampler.obtain_samples(...)  # this has side-effects: env will change!
    ```

    The constructor only stores its arguments; all of them are later handed
    to ``rollout`` by ``obtain_samples``.
    """

    def __init__(self, env, policy, max_path_length, adapter, gpu_id,
                 adapt_steps):
        """
        Store the objects and settings used for every rollout.

        :param env: environment passed to ``rollout`` as its first argument.
        :param policy: policy passed to ``rollout`` (optionally wrapped in
            ``MakeDeterministic``, see ``obtain_samples``).
        :param max_path_length: forwarded to ``rollout`` as
            ``max_path_length``.
        :param adapter: forwarded to ``rollout`` as ``adapter``.
        :param gpu_id: forwarded to ``rollout`` as ``gpu_id``.
        :param adapt_steps: forwarded to ``rollout`` as ``adapt_steps``.
        """
        self.env = env
        self.policy = policy
        self.adapter = adapter
        self.gpu_id = gpu_id
        self.adapt_steps = adapt_steps

        self.max_path_length = max_path_length

    def start_worker(self):
        """No-op: nothing needs to be started for this sampler."""
        pass

    def shutdown_worker(self):
        """No-op: nothing needs to be shut down for this sampler."""
        pass

    def obtain_samples(self, index, max_trajs=np.inf, deterministic=False,
                       max_samples=np.inf, testing=False, action_dim=None,
                       hyperparam_dim=None):
        """
        Collect rollouts until either ``max_samples`` transitions or
        ``max_trajs`` trajectories have been gathered, whichever comes first.

        The limits are checked *before* each rollout, and a rollout is always
        kept whole, so the returned step count may exceed ``max_samples`` by
        up to the length of the last path.

        :param index: forwarded to ``rollout`` as its third positional
            argument.
        :param max_trajs: upper bound on the number of rollouts.
        :param deterministic: if true, the policy is wrapped in
            ``MakeDeterministic`` before sampling.
        :param max_samples: upper bound on the total number of transitions,
            counted as ``len(path['observations'])`` per rollout.
        :param testing: forwarded to ``rollout`` as ``testing``.
        :param action_dim: forwarded to ``rollout`` as ``action_dim``.
        :param hyperparam_dim: forwarded to ``rollout`` as
            ``hyperparam_dim``.
        :return: tuple ``(paths, n_steps_total)`` - the list of rollout
            results and the total number of transitions they contain.
        :raises AssertionError: if neither ``max_samples`` nor ``max_trajs``
            is finite, because the sampling loop would never terminate.
        """
        # At least one limit must be finite for the loop below to stop.
        # Raised explicitly (not via ``assert``) so the guard survives
        # ``python -O``; AssertionError is kept so callers see the same
        # exception type as before.
        if not (max_samples < np.inf or max_trajs < np.inf):
            raise AssertionError(
                "either max_samples or max_trajs must be finite")

        policy = MakeDeterministic(self.policy) if deterministic else self.policy
        paths = []
        n_steps_total = 0
        n_trajs = 0
        while n_steps_total < max_samples and n_trajs < max_trajs:
            path = rollout(
                self.env,
                policy,
                index,
                max_path_length=self.max_path_length,
                adapter=self.adapter,
                testing=testing,
                action_dim=action_dim,
                hyperparam_dim=hyperparam_dim,
                gpu_id=self.gpu_id,
                adapt_steps=self.adapt_steps,
            )
            paths.append(path)
            n_steps_total += len(path['observations'])
            n_trajs += 1

        return paths, n_steps_total

