from utils.replay_atari_data import AtariReplayDataset
from learn.config import AtariReplayDataExperimentConfig

class AtariReplayDatasetForCAStRL(AtariReplayDataset):
    """Atari replay dataset that can return paired augmented views for CAStRL.

    Samples are addressed through ``self.temporal_indices``. Depending on
    ``use_strl`` and ``low_contrast_mode``, the states of a sample are either
    a single normalized object or a dict holding two views under the keys
    ``'augmented'`` and ``'augmented_prime'``.
    """

    def __getitem__(self, index):
        """Return ``(states, targets)`` for the ``index``-th temporal sample.

        States are produced as follows:

        * ``use_strl`` truthy and ``low_contrast_mode is None``: two views,
          each independently passed through ``self.augmenter.transform`` and
          then ``self.normalize``.
        * ``use_strl`` truthy and ``low_contrast_mode == 'keep_semantics'``:
          the first view is the un-augmented states (normalized); the second
          is augmented and then normalized.
        * anything else: the states are only normalized.
        """
        frame_index = self.temporal_indices[index]
        raw_states, targets = self.process_sequential_frames(frame_index)
        raw_states, targets = self.make_ready(raw_states, targets)

        # Without STRL there is no second view to build.
        if not self.use_strl:
            return self.normalize(raw_states), targets

        mode = self.low_contrast_mode
        if mode is None:
            # The augmenter is called twice so the two views can differ
            # (presumably it is stochastic; TODO confirm). The call order is
            # view, then view_prime.
            view = self.normalize(self.augmenter.transform(raw_states))
            view_prime = self.normalize(self.augmenter.transform(raw_states))
        elif mode == 'keep_semantics':
            # Anchor view keeps the original frames; only the second view is augmented.
            view = self.normalize(raw_states)
            view_prime = self.normalize(self.augmenter.transform(raw_states))
        else:
            # NOTE(review): with use_strl enabled, any other low_contrast_mode
            # silently falls back to a single normalized view, so no dict is
            # returned. Confirm this is intended.
            return self.normalize(raw_states), targets

        return {'augmented': view, 'augmented_prime': view_prime}, targets

    def __len__(self):
        """Return the number of addressable samples, i.e. ``len(self.temporal_indices)``."""
        sample_indices = self.temporal_indices
        return len(sample_indices)

    @classmethod
    def from_config(cls, cfg: AtariReplayDataExperimentConfig):
        """Build a dataset from an experiment config.

        Most config fields are forwarded under the same name. The exceptions
        are ``seq_len`` -> ``num_frames``, ``data_ratio`` -> ``split_size`` and
        ``unknown_label`` -> ``unknown_action``. ``channel_last`` and
        ``use_dynamic_range`` are always ``False``.
        """
        data_dir = cfg.replay_data_dir
        kwargs = dict(
            num_steps=cfg.num_steps,
            start_buffer=cfg.start_buffer,
            num_buffers=cfg.num_buffers,
            stack_size=cfg.stack_size,
            channel_last=False,
            image_size=cfg.image_size,
            use_dynamic_range=False,
            num_frames=cfg.seq_len,
            low_contrast_mode=cfg.low_contrast_mode,
            use_strl=cfg.use_strl,
            split_size=cfg.data_ratio,
            overlap_ratios=cfg.overlap_ratios,
            unlabeled_ratio=cfg.unlabeled_ratio,
            unknown_action=cfg.unknown_label,
            map_ale_actions=cfg.map_ale_actions,
            cache_dir=cfg.cache_dir,
            cachefile_prefix=cfg.cachefile_prefix,
        )
        return cls(data_dir, **kwargs)
