import numpy as np
import copy
from utils import log_trajectory_statistics

# Default seed that ``Sampler.__init__`` passes to ``env.seed()``.
seed: int = 0
class Sampler(object):
    """Sampler to collect and evaluate agent behavior.

    Learning rollouts run on ``env``. Evaluation rollouts run on a deep copy
    of it (``self._eval_env``), so they never step the training environment.
    """

    def __init__(self, env, episode_limit=1000, init_random_samples=1000,
                 visual_env=False, env_seed=None):
        """
        Parameters
        ----------
        env : Environment to run the agent.
        episode_limit : Maximum number of timesteps per trajectory, default is 1000.
        init_random_samples : Number of initial timesteps to execute random behavior, default is 1000.
        visual_env : Environment returns visual observations, default is False.
        env_seed : Seed passed to ``env.seed`` before the evaluation copy is made;
            ``None`` (default) uses the module-level ``seed``.
        """
        self._env = env
        self._env.seed(seed if env_seed is None else env_seed)
        # The copy is taken after seeding.
        self._eval_env = copy.deepcopy(self._env)
        self._visual_env = visual_env
        self._el = episode_limit
        self._nr = init_random_samples
        # Total steps taken on the training env; while below self._nr,
        # sample_trajectory uses random actions instead of the policy.
        self._tc = 0
        # Not read anywhere in this class; kept in case external code uses them.
        self._ct = 0
        self._timestep = 1
        self._ob = None
        self._reset = True

    def _handle_ob(self, ob):
        """Return ``ob['obs']`` for visual envs, otherwise ``ob`` unchanged."""
        if self._visual_env:
            return ob['obs']
        return ob

    @staticmethod
    def _policy_action(policy, ob, noise_stddev):
        """Query ``policy`` with a batch of one observation and return the unbatched action."""
        batch = np.expand_dims(ob.astype('float32'), axis=0)
        return np.array(policy.get_action(batch, noise_stddev))[0]

    def _action_fn(self, policy, noise_stddev, action_space):
        """Return ``ob -> action``: ``policy`` if given, else uniform samples from ``action_space``."""
        if policy is None:
            return lambda ob: action_space.sample()
        return lambda ob: self._policy_action(policy, ob, noise_stddev)

    @staticmethod
    def _new_buffer(record_ims):
        """Create empty per-step lists; an 'ims' list is added only when recording images."""
        buf = {key: [] for key in ('obs', 'nobs', 'act', 'rew', 'don', 'step', 'ids')}
        if record_ims:
            buf['ims'] = []
        return buf

    def _run_episode(self, env, choose_action, buf, episode_id, record_ims,
                     visualize=False, count_timesteps=True):
        """Roll out one episode on ``env``, appending every transition to ``buf``.

        The episode ends when ``env.step`` reports ``done`` or after
        ``self._el`` steps. ``choose_action`` is called once per step with
        the current observation.

        Returns
        -------
        (ret, ct) : undiscounted sum of rewards and number of steps taken.
        """
        ret = 0
        ct = 0
        done = False
        ob = env.reset()
        if record_ims:
            # Fetched once after reset and discarded, so exactly one frame is
            # recorded per env.step (presumably flushes the reset frame — confirm).
            env.get_ims()
        while not done and ct < self._el:
            act = choose_action(ob)
            buf['obs'].append(ob)
            buf['act'].append(act)
            buf['step'].append(ct + 1)  # 1-based step index within the episode
            ob, rew, done, _ = env.step(act)
            if visualize:
                env.render()
            if record_ims:
                buf['ims'].append(env.get_ims())
            buf['nobs'].append(ob)
            buf['rew'].append(rew)
            buf['don'].append(done)
            buf['ids'].append(episode_id)
            ret += rew
            ct += 1
            if count_timesteps:
                self._tc += 1
        return ret, ct

    @staticmethod
    def _pack(buf, n, ret):
        """Stack the per-step lists in ``buf`` into arrays and attach ``n`` and ``ret``."""
        out = {'obs': np.stack(buf['obs']), 'nobs': np.stack(buf['nobs']),
               'act': np.stack(buf['act']), 'rew': np.array(buf['rew']),
               'don': np.array(buf['don']), 'n': n, 'ret': ret,
               'step': np.stack(buf['step']), 'ids': np.array(buf['ids'])}
        if 'ims' in buf:
            out['ims'] = np.stack(buf['ims'])
        return out

    def sample_trajectory(self, policy, noise_stddev, get_ims=True):
        """Collect a single trajectory on the training environment.

        While fewer than ``init_random_samples`` training steps have been
        taken (counted across all calls), actions are sampled from the
        action space. After that, ``policy`` is queried with ``noise_stddev``.

        Returns
        -------
        dict of per-step arrays 'obs', 'nobs', 'act', 'rew', 'don', 'step'
        and 'ids' (all 0, the id of the single episode). It also holds the
        episode length 'n', the return 'ret' (scalar) and, for visual envs
        with ``get_ims``, the recorded frames 'ims'.
        """
        def choose_action(ob):
            # self._tc is re-read every step, so the switch from random to
            # policy actions can happen mid-episode.
            if self._tc < self._nr:
                return self._env.action_space.sample()
            return self._policy_action(policy, ob, noise_stddev)

        record_ims = self._visual_env and get_ims
        buf = self._new_buffer(record_ims)
        ret, ct = self._run_episode(self._env, choose_action, buf, 0, record_ims)
        return self._pack(buf, ct, ret)

    def sample_learner_trajectory(self, policy, noise_stddev, n=5, get_ims=True):
        """Collect ``n`` trajectories on the training environment.

        If ``policy`` is None, actions are sampled at random from the
        training env's action space. Every step counts towards
        ``self._tc``.

        Returns
        -------
        dict with the same per-step arrays as ``sample_trajectory``; 'ids'
        holds the episode index (0..n-1) of each step. 'ret' is a list with
        one return per episode, and 'n' is the length of the last episode.
        """
        if policy is None:
            print('WARNING: running random policy')
        choose_action = self._action_fn(policy, noise_stddev, self._env.action_space)
        record_ims = self._visual_env and get_ims
        buf = self._new_buffer(record_ims)
        rets = []
        ct = 0
        for i in range(n):
            ret, ct = self._run_episode(self._env, choose_action, buf, i, record_ims)
            rets.append(ret)  # one entry per episode, not per step
        return self._pack(buf, ct, rets)

    def sample_test_trajectories(self, policy, noise_stddev, n=5, visualize=False, get_ims=False):
        """Collect ``n`` trajectories on the evaluation environment.

        Steps taken here do not count towards ``self._tc``. When
        ``visualize`` is set, ``render()`` is called after every step.

        Returns
        -------
        dict with the same per-step arrays as ``sample_trajectory``; 'ids'
        holds the episode index of each step. 'ret' is a list of per-episode
        returns, and 'n' is the length of the last episode.
        """
        if policy is None:
            print('WARNING: running random policy')
        choose_action = self._action_fn(policy, noise_stddev, self._eval_env.action_space)
        record_ims = self._visual_env and get_ims
        buf = self._new_buffer(record_ims)
        rets = []
        ct = 0
        for i in range(n):
            ret, ct = self._run_episode(self._eval_env, choose_action, buf, i, record_ims,
                                        visualize=visualize, count_timesteps=False)
            rets.append(ret)
        return self._pack(buf, ct, rets)

    def evaluate(self, policy, n=10, log=True, get_ims=False):
        """Run ``n`` noise-free evaluation episodes and pass their returns to ``log_trajectory_statistics``.

        Returns whatever ``log_trajectory_statistics(returns, log)`` returns.
        """
        traj = self.sample_test_trajectories(policy, 0.0, n, get_ims=get_ims)
        return log_trajectory_statistics(traj['ret'], log)
