import numpy as np
import torch
from infreastructure.pytorch_util import to_numpy,from_numpy


class StepSampler(object):
    """Collects a fixed number of environment transitions per call.

    The sampler keeps the current observation and the number of steps taken
    in the running episode between calls, so consecutive ``sample`` calls
    continue the same episode until it ends or hits ``max_traj_length``.
    """

    def __init__(self, env, max_traj_length=1000):
        """Store the environment and reset it to obtain the first observation.

        Args:
            env: Object exposing ``reset()`` and ``step(action)``.
            max_traj_length: Step count at which an episode is force-reset.
        """
        self._env = env
        self.max_traj_length = max_traj_length
        self._traj_steps = 0
        self._current_observation = self.env.reset()

    def sample(self, policy, n_steps, deterministic=False, replay_buffer=None):
        """Step the environment ``n_steps`` times and return the transitions.

        Args:
            policy: Object whose ``call(obs_batch, deterministic=...)`` returns
                a 2-D result; row 0 is used as the action.
            n_steps: Number of environment steps to take.
            deterministic: Forwarded to ``policy.call``.
            replay_buffer: Optional object with ``add_sample``; when given,
                every transition is also pushed into it.

        Returns:
            dict of float32 arrays keyed by ``observations``, ``actions``,
            ``rewards``, ``next_observations`` and ``dones``.
        """
        # Insertion order matches the transition tuple built below.
        batch = {
            'observations': [],
            'actions': [],
            'rewards': [],
            'next_observations': [],
            'dones': [],
        }

        for _ in range(n_steps):
            self._traj_steps += 1
            obs = self._current_observation
            # The policy expects a leading batch axis; strip it from the result.
            batched_obs = np.expand_dims(obs, 0)
            act = policy.call(batched_obs, deterministic=deterministic)[0, :]
            next_obs, rew, terminal, _info = self.env.step(act)

            transition = (obs, act, rew, next_obs, terminal)
            for column, value in zip(batch.values(), transition):
                column.append(value)

            if replay_buffer is not None:
                replay_buffer.add_sample(*transition)

            if terminal or self._traj_steps >= self.max_traj_length:
                # Episode finished (or was truncated): start a fresh one.
                self._traj_steps = 0
                self._current_observation = self.env.reset()
            else:
                self._current_observation = next_obs

        return {
            key: np.array(values, dtype=np.float32)
            for key, values in batch.items()
        }

    @property
    def env(self):
        """The wrapped environment."""
        return self._env


class TrajSampler(object):
    """Collects whole episodes (trajectories) from an environment."""

    def __init__(self, env, max_traj_length=1000):
        """Store the environment and the per-episode step limit.

        Args:
            env: Object exposing ``reset()`` and ``step(action)``.
            max_traj_length: Maximum number of steps taken in one episode.
        """
        self.max_traj_length = max_traj_length
        self._env = env

    def sample(self, policy, n_trajs, deterministic=False, replay_buffer=None):
        """Roll out ``n_trajs`` episodes and return one dict per episode.

        Each episode starts from ``env.reset()`` and ends when the
        environment reports ``done`` or after ``max_traj_length`` steps.

        Args:
            policy: Object whose ``call(obs_batch, deterministic=...)`` returns
                a 2-D result; row 0 is used as the action.
            n_trajs: Number of episodes to collect.
            deterministic: Forwarded to ``policy.call``.
            replay_buffer: Optional object with ``add_sample``; when given,
                every transition is also pushed into it.

        Returns:
            list of dicts, each holding float32 arrays keyed by
            ``observations``, ``actions``, ``rewards``, ``next_observations``
            and ``dones``.
        """
        trajs = []
        # The episode count is independent of the per-episode step limit, so
        # it must not be capped by max_traj_length.
        for _ in range(n_trajs):
            observations = []
            actions = []
            rewards = []
            next_observations = []
            dones = []

            observation = self.env.reset()

            for _step in range(self.max_traj_length):
                # The policy expects a leading batch axis; strip it afterwards.
                action = policy.call(
                    np.expand_dims(observation, 0), deterministic=deterministic
                )[0, :]
                next_observation, reward, done, _info = self.env.step(action)
                observations.append(observation)
                actions.append(action)
                rewards.append(reward)
                dones.append(done)
                next_observations.append(next_observation)

                if replay_buffer is not None:
                    replay_buffer.add_sample(
                        observation, action, reward, next_observation, done
                    )

                observation = next_observation

                if done:
                    break

            trajs.append(dict(
                observations=np.array(observations, dtype=np.float32),
                actions=np.array(actions, dtype=np.float32),
                rewards=np.array(rewards, dtype=np.float32),
                next_observations=np.array(next_observations, dtype=np.float32),
                dones=np.array(dones, dtype=np.float32),
            ))

        return trajs

    @property
    def env(self):
        """The wrapped environment."""
        return self._env

def sample_traj(env,model,policy,horizon,init_obs,device):
    """Roll a policy forward through a dynamics model instead of the env.

    Starting from ``init_obs``, each step asks ``policy`` for an action,
    records the current observation and action, and advances with
    ``model.predict``. The real environment is never stepped.

    Args:
        env: Unused; kept so existing call sites keep working.
        model: Object whose ``predict(ob, ac)`` returns the next observation.
        policy: Callable returning a 2-tuple whose first item is the action.
        horizon: Number of model steps to roll out (at least 1).
        init_obs: Tensor of starting observations; assumed to be
            (batch, obs_dim) -- TODO confirm against callers.
        device: Torch device used for the rollout and for the returned
            tensors. ``policy`` and ``model`` are assumed to live on it.

    Returns:
        dict with ``'observations'`` and ``'actions'`` tensors laid out as
        (batch, horizon, dim).
    """
    obs = []
    acs = []
    # Honour the caller's device rather than a hard-coded 'cuda', so the
    # function also works on CPU-only hosts and non-default GPUs.
    ob = init_obs.to(device)
    for _step in range(horizon):
        ac, _ = policy(ob)
        acs.append(to_numpy(ac.clone()))
        obs.append(to_numpy(ob.clone()))
        ob = model.predict(ob, ac)

    # Stack the per-step (batch, dim) arrays along axis 1 to get
    # (batch, horizon, dim). Reshaping a (horizon, batch, dim) array to that
    # shape would only reinterpret the flat buffer and mix time steps across
    # batch elements.
    obs = from_numpy(np.stack(obs, axis=1)).to(device)
    acs = from_numpy(np.stack(acs, axis=1)).to(device)
    return {
        'observations': obs,
        'actions': acs,
    }

def get_data(dataset,step,batch_size):
    """Draw a random batch of observations from a growing dataset prefix.

    The candidate pool is the first ``batch_size * (step + 1)`` observations,
    clipped to the dataset length; up to ``batch_size`` distinct entries are
    picked from it in random order.

    Args:
        dataset: Mapping with an ``'observations'`` entry that supports
            ``len`` and integer-array indexing.
        step: Zero-based step counter controlling the pool size.
        batch_size: Maximum number of observations to return.

    Returns:
        The selected observations, indexed out of ``dataset['observations']``.
    """
    observations = dataset['observations']
    pool_size = min(batch_size * (step + 1), len(observations))
    # A truncated permutation yields distinct indices without replacement.
    chosen = np.random.permutation(pool_size)[:batch_size]
    return observations[chosen]

def get_init_obs(env, num_obs=256):
    """Collect initial observations by repeatedly resetting the environment.

    Args:
        env: Object whose ``reset()`` returns an array-like with ``astype``.
        num_obs: Number of resets to perform; defaults to 256, the count
            that was previously hard-coded.

    Returns:
        float32 array holding one reset observation per row.
    """
    return np.array(
        [env.reset().astype(np.float32) for _ in range(num_obs)],
        dtype=np.float32,
    )