import matplotlib as mpl
import random
import numpy as np
import torch
from utils import helpers as utl
import matplotlib.pyplot as plt
import seaborn as sns

from gym import Env
from gym import spaces

import metaworld

from utils.helpers import get_device

class ML1(Env):
    """Gym ``Env`` wrapper around a single Meta-World ML1 environment.

    A "task" in this wrapper is an integer index into the ML1 benchmark's
    task list: ``train_tasks`` by default, or ``test_tasks`` when
    ``test=True``. Selecting a task applies it to the wrapped Meta-World
    env via its ``set_task`` and resets that env.

    Args:
        max_episode_steps: Stored as ``_max_episode_steps``. This class does
            not enforce it itself (presumably an outer time-limit/rollout
            loop reads it — confirm with callers).
        SEED: Seed passed to ``metaworld.ML1`` when building the benchmark.
        env_name: ML1 environment name, e.g. ``'reach-v2'``.
        test: If True, task indices refer to ``test_tasks`` instead of
            ``train_tasks``.
    """

    def __init__(self, max_episode_steps=500, SEED=10, env_name='reach-v2', test=False):
        ml1 = metaworld.ML1(env_name, seed=SEED)
        self.SEED = SEED
        self.env_name = env_name
        self._env = ml1.train_classes[env_name]()
        self.train_tasks = ml1.train_tasks
        self.test_tasks = ml1.test_tasks
        self.test = test

        # Pick a random initial task. This also resets the wrapped env, so
        # ``self._state`` is populated before first use.
        self.reset_task()
        self.task_dim = 1  # the task is represented by a single integer index

        self.observation_space = self._env.observation_space
        self.action_space = self._env.action_space

        self._max_episode_steps = max_episode_steps

    def _task_list(self):
        """Return the task list that task indices refer to (test or train)."""
        return self.test_tasks if self.test else self.train_tasks

    def set_task(self, task):
        """Record ``task`` (an index into the active task list) as current."""
        self._task = task

    def get_task(self):
        """Return the current task index."""
        return self._task

    def reset_task(self, task=None):
        """Select a task, apply it to the wrapped env, and reset the env.

        Args:
            task: Index into the active task list (see ``_task_list``). If
                None, an index is drawn uniformly at random with the
                ``random`` module.

        Returns:
            A copy of the observation produced by the reset.
        """
        tasks = self._task_list()
        if task is None:
            # Sample over the real number of available tasks instead of a
            # hard-coded count. randrange(n) consumes the RNG exactly like
            # choice(range(n)), so seeded runs are unchanged.
            task = random.randrange(len(tasks))
        self.set_task(task)
        self._env.set_task(tasks[self._task])
        return self.reset()

    def _reset_model(self):
        """Reset the wrapped Meta-World env and return a copy of its state.

        This resets the env under the currently applied task; it does not
        choose a new task.
        """
        self._state = self._env.reset()
        return self._get_obs()

    def reset(self, task=None):
        """Reset the env, optionally switching to ``task`` first.

        Args:
            task: Optional task index. If given, the task is applied via
                ``reset_task``, which already resets the env.

        Returns:
            A copy of the observation produced by the reset.
        """
        if task is not None:
            # reset_task() performs the reset itself; resetting again here
            # would run the wrapped env's reset twice.
            return self.reset_task(task)
        return self._reset_model()

    def _get_obs(self):
        """Return a copy of the last state so callers cannot mutate it."""
        return np.copy(self._state)

    def step(self, action):
        """Clip ``action`` to the action space and step the wrapped env.

        Returns:
            ``(obs, reward, done, info)`` where ``obs`` is a copy of the new
            state. ``info`` holds only ``'task'`` (current task index) and
            ``'success'`` (taken from the wrapped env's info dict).
        """
        # NOTE: to record frames, call self._env.render(offscreen=True) and
        # add the result to ``info`` under an 'image' key.
        action = np.clip(action, self.action_space.low, self.action_space.high)
        self._state, reward, done, env_info = self._env.step(action)
        info = {'task': self.get_task(), 'success': env_info['success']}
        return self._get_obs(), reward, done, info
