import multiprocessing as mp

import gym
import torch
from rlkit.envs.wrappers import NormalizedBoxEnv
from rlkit.envs import ENVS
# from envs.subproc_vec_env import SubprocVecEnv
from .episode import BatchEpisodes
import random
import numpy as np

def make_env(env_name):
    """Return a zero-argument factory for the gym environment ``env_name``.

    Nothing is instantiated here: ``gym.make(env_name)`` is evaluated each
    time the returned callable is invoked, so every call yields whatever
    ``gym.make`` returns for that name at that moment.
    """
    # Deferred construction: the closure only captures the name.
    return lambda: gym.make(env_name)


class BatchSampler(object):
    """Collect batches of episodes from a single, sequentially-stepped env.

    The environment is built from the ``ENVS`` registry under ``env_name``,
    wrapped in ``NormalizedBoxEnv`` and seeded once at construction.
    Episodes are rolled out one after another in the calling process.
    """

    def __init__(self, env_name, batch_size, device, seed):
        """Create and seed the environment.

        Args:
            env_name: Key into ``ENVS`` selecting the environment class.
            batch_size: Default number of episodes collected by ``sample``.
            device: Torch device used for observation tensors and handed
                to ``BatchEpisodes``.
            seed: Value forwarded to ``env.set_seed``.
        """
        self.env_name = env_name
        self.batch_size = batch_size
        # Kept as a public attribute; nothing in this class reads it.
        self.num_workers = 1
        self.device = device

        self.env = NormalizedBoxEnv(ENVS[env_name]())
        self.env.set_seed(seed)

    def set_test_task(self, test_env, seed):
        """Configure the env for an evaluation task and reseed it.

        For a ``test_env`` that matches none of the branches below, no
        task-specific configuration is applied; the env is still reseeded.

        Args:
            test_env: Name of the evaluation setting.
            seed: Value forwarded to ``env.set_seed`` after configuration.

        Returns:
            A list built from ``env.get_all_task_idx()``.
        """
        if test_env == 'cheetah-vel':
            # Target velocity of -2.0.
            self.env.set_velocity(-2.0)
        elif test_env == 'cheetah-dir':
            # NOTE(review): the original comment labels -1 as "forward";
            # confirm against the env's sign convention.
            self.env.set_direction(-1)
        elif test_env == 'ant-goal':
            # Per the original author's note: angle = 1.5 pi, radius = 3.
            self.env.set_goal_position(1.5 * np.pi, 3)
        elif test_env == 'ant-dir':
            # Per the original author's note: direction angle = 1.5 pi.
            self.env.set_direction(1.5 * np.pi)
        elif 'params' in test_env:
            # Substring match: any name containing 'params' lands here.
            self.env.set_test_task()
        self.env.set_seed(seed)
        return list(self.env.get_all_task_idx())

    def sample(self, policy, params=None, gamma=0.95, batch_size=None,
               get_steps=False, max_steps=200):
        """Roll out ``batch_size`` episodes with ``policy``.

        Each episode starts from ``env.reset()`` and runs until the env
        reports ``done`` or ``max_steps`` steps have been taken, whichever
        comes first. Every transition is recorded as
        ``(observation before the step, action, reward, episode index)``;
        the observation returned by the last step is not stored.

        Args:
            policy: Callable invoked as ``policy(obs_tensor, params=params)``
                whose result exposes ``.sample()`` returning a tensor.
            params: Forwarded unchanged to ``policy``.
            gamma: Forwarded to ``BatchEpisodes``.
            batch_size: Number of episodes; defaults to ``self.batch_size``.
            get_steps: When true, also return the total environment steps.
            max_steps: Per-episode cap on environment steps. Defaults to
                200, the value that used to be hard-coded.

        Returns:
            The ``BatchEpisodes`` object, or ``(episodes, total_steps)``
            when ``get_steps`` is true.
        """
        if batch_size is None:
            batch_size = self.batch_size
        episodes = BatchEpisodes(batch_size=batch_size, gamma=gamma,
                                 device=self.device)
        total_steps = 0
        for idx in range(batch_size):
            observation = self.env.reset()
            done = False
            steps = 0
            while not done and steps < max_steps:
                steps += 1
                # Acting needs no gradients; skip autograd bookkeeping.
                with torch.no_grad():
                    observation_tensor = torch.from_numpy(observation).to(
                        device=self.device)
                    action_tensor = policy(observation_tensor,
                                           params=params).sample()
                    action = action_tensor.cpu().numpy()
                new_observation, reward, done, _ = self.env.step(action)
                episodes.append(observation, action, reward, idx)
                observation = new_observation
            total_steps += steps
        if get_steps:
            return episodes, total_steps
        return episodes

    def reset_task(self, task_idx):
        """Forward ``task_idx`` to ``env.reset_task``."""
        self.env.reset_task(task_idx)

    def sample_tasks(self, num_tasks):
        """Return ``num_tasks`` distinct entries of ``env.get_all_task_idx()``.

        Uses ``random.sample`` (global ``random`` state), so it raises
        ``ValueError`` when ``num_tasks`` exceeds the number of tasks.
        """
        train_tasks = list(self.env.get_all_task_idx())
        return random.sample(train_tasks, k=num_tasks)

    def set_train_task(self):
        """Call ``env.set_train_task`` with the fixed argument 30."""
        # NOTE(review): the meaning of 30 is defined by the env and not
        # visible here - presumably a training-task count; confirm.
        self.env.set_train_task(30)
