import numpy as np
from tqdm import tqdm
from sgcrl.environments.envs.procgen_env import ProcgenWrappedEnv, get_procgen_dataset
from sgcrl.data.episodes_readers.base_episodes_readers import EpisodesReader

class ProcgenEpisodesReader(EpisodesReader):
    """Split a flat offline Procgen dataset into zero-padded episodes.

    The dataset is loaded with ``get_procgen_dataset``. Episode boundaries are
    derived from its ``masks`` entry (``dones_float = 1 - masks``), with the
    final transition forced to be terminal so the last episode is always
    closed. Items returned by ``__getitem__`` are dicts of arrays padded with
    zeros along the first axis up to ``max_episode_length``.
    """

    # Evaluation levels ("large levels having >=20 border states"). Levels
    # inside the dataset's training range form the 'train' split, levels above
    # it form the 'test' split.
    _LARGE_LEVELS = (
        12, 34, 35, 55, 96, 109, 129, 140, 143, 163, 176, 204, 234, 338, 344,
        369, 370, 374, 410, 430, 468, 470, 476, 491,
        5034, 5046, 5052, 5080, 5082, 5142, 5244, 5245, 5268, 5272, 5283,
        5335, 5342, 5366, 5375, 5413, 5430, 5474, 5491,
    )

    # env_name -> (dataset file, highest training level id). Training levels
    # always start at 0.
    _DATASETS = {
        'procgen-500': ('datasets/procgen/level500.npz', 499),
        'procgen-1000': ('datasets/procgen/level1000.npz', 999),
    }

    def __init__(
        self,
        env_name: str = None,
        max_episode_length: int = None
    ):
        """Load the dataset for ``env_name`` and index its episodes.

        Args:
            env_name: either ``'procgen-500'`` or ``'procgen-1000'``.
            max_episode_length: length every episode is padded to (longer
                episodes are truncated by ``__getitem__``). Defaults to the
                longest episode found in the dataset.

        Raises:
            NotImplementedError: if ``env_name`` is not a supported dataset.
            ValueError: if ``max_episode_length`` is given and is not positive.
        """
        try:
            dataset_path, max_level = self._DATASETS[env_name]
        except KeyError:
            raise NotImplementedError(
                f"Unsupported procgen dataset: {env_name!r}; "
                f"expected one of {sorted(self._DATASETS)}"
            ) from None
        # Validate before the (potentially expensive) dataset load.
        if max_episode_length is not None and max_episode_length < 1:
            raise ValueError(
                f"max_episode_length must be positive, got {max_episode_length}"
            )
        min_level = 0

        self.dataset = get_procgen_dataset(dataset_path, state_based=('state' in env_name))

        # Evaluation level splits relative to this dataset's training range.
        self.goal_infos = [
            {
                'eval_level': [lvl for lvl in self._LARGE_LEVELS if min_level <= lvl <= max_level],
                'eval_level_name': 'train',
            },
            {
                'eval_level': [lvl for lvl in self._LARGE_LEVELS if lvl > max_level],
                'eval_level_name': 'test',
            },
        ]

        # Terminal flags; the last transition is forced terminal so a trailing
        # unterminated episode is still emitted.
        dones_float = 1.0 - self.dataset['masks']
        dones_float[-1] = 1.0
        self.dataset = self.dataset.copy({
            'dones_float': dones_float
        })

        self._index_episodes(self.dataset['dones_float'], self.dataset['rewards'])

        if max_episode_length is None:
            self.max_episode_length = max(self.lengths)
        else:
            self.max_episode_length = max_episode_length

    def _index_episodes(self, dones_float, rewards):
        """Record [start, end) bounds, undiscounted returns and lengths of every episode."""
        self.lengths = []
        self.returns = []
        self.episodes_idxs = []
        start = 0
        # Visit only terminal transitions instead of looping over every step.
        for done_idx in tqdm(np.flatnonzero(dones_float), desc='reading episodes'):
            end = int(done_idx) + 1
            # Episode occupies the half-open range [start, end).
            self.episodes_idxs.append((start, end))
            self.returns.append(np.sum(rewards[start:end]))
            self.lengths.append(end - start)
            start = end

    def __getitem__(self, i):
        """Return episode ``i`` as a dict of arrays zero-padded to ``max_episode_length``.

        Episodes longer than ``max_episode_length`` (only possible when it was
        passed explicitly) are truncated to their first ``max_episode_length``
        transitions rather than failing on a negative padding size.
        """
        start, end = self.episodes_idxs[i]
        end = min(end, start + self.max_episode_length)

        episode = {
            'observations': self.dataset['images'][start:end],
            'actions': self.dataset['actions'][start:end],
            'rewards': self.dataset['rewards'][start:end],
            'dones': self.dataset['dones_float'][start:end],
            'masks': self.dataset['masks'][start:end],
            'next_observations': self.dataset['next_images'][start:end]
        }

        # Zero-pad every field along the time axis to a common length.
        rest = self.max_episode_length - len(episode['actions'])
        padded_episode = {}
        for k, v in episode.items():
            padding = np.zeros(shape=(rest, *v.shape[1:]), dtype=v.dtype)
            padded_episode[k] = np.concatenate([v, padding], axis=0)

        return padded_episode

    def __len__(self):
        """Number of episodes in the dataset."""
        return len(self.episodes_idxs)