# coding=utf-8
# Adapted from Ravens - Transporter Networks, Zeng et al., 2021
# https://github.com/google-research/ravens
"""Image dataset."""

import os
import pickle

import numpy as np
from ravens_torch import tasks
from ravens_torch.tasks import cameras


# Constants shared with transporter.py, regression.py, dummy.py, task.py, etc.
PIXEL_SIZE = 0.003125
CAMERA_CONFIG = cameras.RealSenseD415.CONFIG
# NOTE(review): presumably per-axis [min, max] workspace limits (x, y, z) — confirm.
BOUNDS = np.array([
    [0.25, 0.75],
    [-0.5, 0.5],
    [0, 0.28],
])

# Registered task names in descending lexicographic order, so a name that
# extends another as a prefix (i.e. is more specific) sorts ahead of it.
TASK_NAMES = sorted(tasks.names.keys(), reverse=True)

import IPython as ipy


class Dataset:
    """A simple on-disk RGB-D episode dataset.

    Every episode is written as one pickle file per field (``color``,
    ``depth``, ``action``, ``reward``, ``info``) at
    ``<path>/<field>/{episode_index:06d}-{seed}.pkl``.
    """

    def __init__(self, path):
        """Open a dataset rooted at ``path``, indexing any episodes on disk.

        Args:
          path: root directory of the dataset. It need not exist yet; the
            per-field sub-directories are created by ``add``.
        """
        self.path = path
        self.sample_set = []
        self.max_seed = -1
        self.n_episodes = 0

        # Existing episodes are counted from the `action` directory, the same
        # directory `load` uses to locate episodes.
        action_path = os.path.join(self.path, 'action')
        if os.path.exists(action_path):
            for fname in sorted(os.listdir(action_path)):
                if fname.endswith('.pkl'):
                    self.n_episodes += 1
                    self.max_seed = max(self.max_seed, self._parse_seed(fname))

        # In-memory cache used by `load(..., cache=True)`:
        # {episode_id: {field: data}}.
        self._cache = {}

    @staticmethod
    def _parse_seed(fname):
        """Return the seed encoded in a ``{index:06d}-{seed}.pkl`` filename."""
        return int(fname[fname.find('-') + 1:-len('.pkl')])

    def _find_episode_file(self, episode_id):
        """Return the filename stored for ``episode_id``, or None if absent.

        Raises:
          FileNotFoundError: if the dataset has no ``action`` directory.
        """
        # Match the full `{index:06d}-` prefix: a bare substring test would
        # also match a seed that happens to contain the padded index
        # (e.g. id 1 inside `000000-1000001.pkl`).
        prefix = f'{episode_id:06d}-'
        action_path = os.path.join(self.path, 'action')
        for fname in sorted(os.listdir(action_path)):
            if fname.startswith(prefix) and fname.endswith('.pkl'):
                return fname
        return None

    def add(self, seed, episode):
        """Append an episode to the dataset on disk.

        Args:
          seed: random seed used to initialize the episode.
          episode: list of (obs, act, reward, info) tuples, where ``obs`` is a
            dict with ``'color'`` and ``'depth'`` entries.
        """
        color, depth, action, reward, info = [], [], [], [], []
        for obs, act, r, i in episode:
            color.append(obs['color'])
            depth.append(obs['depth'])
            action.append(act)
            reward.append(r)
            info.append(i)

        fields = {
            'color': np.uint8(color),
            'depth': np.float32(depth),
            'reward': reward,
            'info': info,
            # Written last: `__init__` and `load` discover episodes through the
            # `action` directory, so an interrupted write does not leave a
            # counted episode whose other fields are missing.
            'action': action,
        }
        fname = f'{self.n_episodes:06d}-{seed}.pkl'
        for field, data in fields.items():
            field_path = os.path.join(self.path, field)
            os.makedirs(field_path, exist_ok=True)
            with open(os.path.join(field_path, fname), 'wb') as f:
                pickle.dump(data, f)

        self.n_episodes += 1
        self.max_seed = max(self.max_seed, seed)

    def set(self, episodes):
        """Restrict `sample` and `fetch_detect_set` to the given episode ids."""
        self.sample_set = episodes

    def load(self, episode_id, images=True, cache=False):
        """Load data from a saved episode.

        Args:
          episode_id: the ID (index) of the episode to be loaded.
          images: load image data if True; otherwise each obs is ``{}`` and
            the color/depth files are not read.
          cache: keep loaded fields in memory and reuse them on later calls.

        Returns:
          episode: list of (obs, act, reward, info) tuples.
          seed: random seed used to initialize the episode.

        Raises:
          FileNotFoundError: if the episode (or the dataset) does not exist.
        """
        fname = self._find_episode_file(episode_id)
        if fname is None:
            raise FileNotFoundError(
                f'Episode {episode_id} not found in dataset {self.path}')
        seed = self._parse_seed(fname)

        def load_field(field):
            """Read one field of this episode, consulting the cache first."""
            if cache:
                cached = self._cache.setdefault(episode_id, {})
                if field in cached:
                    return cached[field]
            # NOTE: pickle can execute arbitrary code while loading; only open
            # datasets produced by a trusted source.
            with open(os.path.join(self.path, field, fname), 'rb') as f:
                data = pickle.load(f)
            if cache:
                self._cache[episode_id][field] = data
            return data

        action = load_field('action')
        reward = load_field('reward')
        info = load_field('info')
        if images:
            color = load_field('color')
            depth = load_field('depth')

        # Reconstruct episode.
        episode = []
        for i, act in enumerate(action):
            obs = {'color': color[i], 'depth': depth[i]} if images else {}
            episode.append((obs, act, reward[i], info[i]))
        return episode, seed

    def sample(self, images=True, cache=False):
        """Uniformly sample from the dataset.

        Args:
          images: load image data if True.
          cache: load data from memory if True.

        Returns:
          sample: randomly sampled (obs, act, reward, info) tuple, drawn from
            every step of the episode except the last.
          goal: the last (obs, act, reward, info) tuple in the episode.
        """
        # Choose random episode. `len(...)` rather than truthiness so that a
        # numpy array of ids also works as `sample_set`.
        if len(self.sample_set) > 0:  # pylint: disable=g-explicit-length-test
            episode_id = np.random.choice(self.sample_set)
        else:
            episode_id = np.random.choice(range(self.n_episodes))
        episode, _ = self.load(episode_id, images, cache)

        # Return random observation action pair (and goal) from episode.
        i = np.random.choice(range(len(episode) - 1))
        return episode[i], episode[-1]

    def fetch_detect_set(self, images=True, cache=False):
        """Load detection set for martingale computations.

        Args:
            images: load image data if True.
            cache: load data from memory if True.

        Returns:
            detect_set: first time-step of every episode in ``sample_set``
            (in ascending id order) or, if no sample set is configured, of
            every episode in the dataset.
        """
        if len(self.sample_set) > 0:  # pylint: disable=g-explicit-length-test
            # Unshuffle so the detection set has a deterministic order.
            episode_ids = sorted(self.sample_set)
        else:
            episode_ids = range(self.n_episodes)

        detect_set = []
        for episode_id in episode_ids:
            episode, _ = self.load(episode_id, images, cache)
            detect_set.append(episode[0])
        return detect_set

def load_data(FLAGS, only_test=False):
    """Construct the dataset(s) for ``FLAGS.task`` under ``FLAGS.data_dir``.

    Args:
      FLAGS: config object exposing ``data_dir``, ``task`` and ``verbose``.
      only_test: if True, build only the ``<task>-test`` split.

    Returns:
      The test ``Dataset`` when ``only_test`` is True, otherwise a
      ``(train_dataset, test_dataset)`` tuple.
    """
    split_dir = {
        split: os.path.join(FLAGS.data_dir, f'{FLAGS.task}-{split}')
        for split in ('test', 'train')
    }

    if FLAGS.verbose:
        announcements = []
        if not only_test:
            announcements.append(f"Loading trainset from {split_dir['train']}")
        announcements.append(f"Loading testset  from {split_dir['test']}")
        for line in announcements:
            print(line)

    if not only_test:
        return Dataset(split_dir['train']), Dataset(split_dir['test'])
    return Dataset(split_dir['test'])
