"""Builds the data modules used to train the policy."""

from __future__ import annotations

from abc import ABC, abstractmethod
from cProfile import label
import os
import random
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

from d4rl import offline_env
import gym
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils import data

from rvs import step, util, bootstrap_quantile, bootstrap_qr, preemption, bootstrap_qr2
from wandb.sdk.wandb_run import Run
import wandb
import matplotlib.pyplot as plt

# Upper bound on the default number of DataLoader workers: AbstractDataModule uses
# min(available CPUs, max_num_workers) when no explicit num_workers is given.
max_num_workers = 16


def create_data_module(
    env: gym.Env,
    env_name: str,
    dataset_preprocess: str,
    store_dataset_gpu: bool,
    percent_dataset: float,
    rollout_directory: str,
    batch_size: int = 256,
    val_frac: float = 0.1,
    unconditional_policy: bool = False,
    reward_conditioning: bool = False,
    discount_factor: float = 1.0,
    bootstrap_iters: int = 0,
    bootstrap_model: str = "qr",
    bootstrap_model_args: Optional[Dict[str, Any]] = None,
    bootstrap_threshold_D: float = 0.01,
    bootstrap_noise: float = 0.0,
    bootstrap_feature_extractor: str = 'identity',
    average_reward_to_go: bool = True,
    seed: Optional[int] = None,
    wandb_run: Optional[Run] = None,
    pm: Optional[preemption.PreemptionManager] = None,
) -> AbstractDataModule:
    """Creates the data module used for training.

    The module type is chosen from the environment family and the conditioning mode:

    * D4RL env + ``unconditional_policy``  -> ``D4RLBCDataModule``
    * D4RL env + ``reward_conditioning``   -> ``D4RLRvSBDataModule``
    * D4RL env, neither flag               -> ``D4RLRvSGDataModule``
    * any other env (goal-conditioned)     -> ``GCSLDataModule``

    Args:
        env: The environment whose offline dataset is loaded (D4RL modules only).
        env_name: Name used to decide whether this is a D4RL environment
            (membership in ``step.d4rl_env_names``).
        dataset_preprocess: Forwarded to ``D4RLRvSBDataModule``.
        store_dataset_gpu: Forwarded to ``D4RLRvSBDataModule``.
        percent_dataset: Forwarded to ``D4RLRvSBDataModule``.
        rollout_directory: Directory of saved rollouts, used by ``GCSLDataModule``.
        batch_size: How many examples to return per batch.
        val_frac: What fraction of the data to hold out for validation.
        unconditional_policy: Build a plain behavior-cloning module.
        reward_conditioning: Build a (bootstrapped) reward-conditioned module.
        discount_factor: Forwarded to ``D4RLRvSBDataModule``.
        bootstrap_iters: Forwarded to ``D4RLRvSBDataModule`` as ``n_iter``.
        bootstrap_model: Forwarded to ``D4RLRvSBDataModule`` as ``model``.
        bootstrap_model_args: Forwarded to ``D4RLRvSBDataModule`` as ``model_args``.
            ``None`` (the default) means an empty dict.
        bootstrap_threshold_D: Forwarded to ``D4RLRvSBDataModule`` as ``threshold_D``.
        bootstrap_noise: Forwarded to ``D4RLRvSBDataModule`` as ``noise``.
        bootstrap_feature_extractor: Forwarded to ``D4RLRvSBDataModule`` as
            ``feature_extractor``.
        average_reward_to_go: Currently unused. Its only consumer,
            ``D4RLRvSRDataModule``, is no longer constructed here because reward
            conditioning always goes through ``D4RLRvSBDataModule``.
        seed: A seed for the random dataset samples.
        wandb_run: Forwarded to ``D4RLRvSBDataModule``.
        pm: Forwarded to ``D4RLRvSBDataModule``.

    Returns:
        The data module matching the requested configuration.

    Raises:
        ValueError: If both ``unconditional_policy`` and ``reward_conditioning`` are
            set.
        NotImplementedError: If an unconditional policy is requested for a non-D4RL
            environment.
    """
    if unconditional_policy and reward_conditioning:
        raise ValueError(
            "Cannot condition on reward with an unconditional policy.")

    # A fresh dict per call: a shared `{}` default would leak mutations between calls.
    if bootstrap_model_args is None:
        bootstrap_model_args = {}

    if env_name not in step.d4rl_env_names:
        if unconditional_policy:
            raise NotImplementedError(
                "Unconditional policies are only supported for D4RL environments.")
        # NOTE(review): this passes os.cpu_count() explicitly, which bypasses the
        # SLURM-aware, max_num_workers-capped default of AbstractDataModule.
        return GCSLDataModule(
            rollout_directory,
            batch_size=batch_size,
            val_frac=val_frac,
            seed=seed,
            num_workers=os.cpu_count(),
        )

    if unconditional_policy:
        return D4RLBCDataModule(
            env,
            batch_size=batch_size,
            val_frac=val_frac,
            seed=seed,
        )

    if reward_conditioning:
        return D4RLRvSBDataModule(
            env,
            dataset_preprocess=dataset_preprocess,
            store_dataset_gpu=store_dataset_gpu,
            percent_dataset=percent_dataset,
            model=bootstrap_model,
            model_args=bootstrap_model_args,
            threshold_D=bootstrap_threshold_D,
            noise=bootstrap_noise,
            feature_extractor=bootstrap_feature_extractor,
            batch_size=batch_size,
            val_frac=val_frac,
            discount_factor=discount_factor,
            n_iter=bootstrap_iters,
            seed=seed,
            wandb_run=wandb_run,
            pm=pm,
        )

    return D4RLRvSGDataModule(
        env,
        batch_size=batch_size,
        seed=seed,
        val_frac=val_frac,
    )


def s_g_pair_iter(
    s_obs_vecs: np.ndarray,
    s_ach_goal_vecs: np.ndarray,
    a_vecs: np.ndarray,
) -> Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """Use hindsight to iterate over all states and future achieved goals.

    For every episode and every timestep ``i`` that has an action, the state and
    action at ``i`` are paired with the achieved goal of every strictly later
    timestep ``j > i`` of the same episode.

    Args:
        s_obs_vecs: Per-episode sequences of state observations.
        s_ach_goal_vecs: Per-episode sequences of achieved goals, indexed like
            ``s_obs_vecs``.
        a_vecs: Per-episode sequences of actions.

    Yields:
        ``(state, action, goal)`` triples. Note the order: the action comes second.
    """
    for episode, a_vec in enumerate(a_vecs):
        s_obs_vec = s_obs_vecs[episode]
        s_ach_goal_vec = s_ach_goal_vecs[episode]
        n_states = len(s_obs_vec)
        for i, a in enumerate(a_vec):
            # Only timesteps after i can serve as hindsight goals for step i.
            for j in range(i + 1, n_states):
                yield s_obs_vec[i], a, s_ach_goal_vec[j]


def make_s_g_tensor(
    states: Union[np.ndarray, torch.Tensor],
    goals: Union[np.ndarray, torch.Tensor],
) -> torch.Tensor:
    """Combine observations and goals into the same tensor.

    Args:
        states: Batch of observations, one row per example.
        goals: Batch of conditioning values (goals or returns), one row per example.

    Returns:
        A new tensor with ``goals`` appended to ``states`` along dimension 1.
    """
    # as_tensor accepts both arrays and tensors without the copy-construct warning
    # that torch.tensor(tensor) emits; torch.cat allocates the result, so the output
    # never aliases the inputs.
    s_tensor = torch.as_tensor(states)
    g_tensor = torch.as_tensor(goals)

    return torch.cat((s_tensor, g_tensor), dim=1)


def to_tensor_dataset(
    data_vec: List[Tuple[np.ndarray, np.ndarray, np.ndarray]],
) -> data.TensorDataset:
    """Convert a list of data into a tensor dataset.

    Args:
        data_vec: ``(state, goal, action)`` triples, one per example.

    Returns:
        A dataset of ``(state-and-goal, action)`` pairs, where the first tensor is
        produced by ``make_s_g_tensor``.
    """
    state_seq, goal_seq, action_seq = zip(*data_vec)

    s_g_tensor = make_s_g_tensor(np.array(state_seq), np.array(goal_seq))
    a_tensor = torch.tensor(action_seq)

    # Plain data must not be tracked by autograd.
    for tensor in (s_g_tensor, a_tensor):
        assert not tensor.requires_grad

    return data.TensorDataset(s_g_tensor, a_tensor)


def load_tensor_dataset(rollout_directory: str) -> data.TensorDataset:
    """Load saved GCSL rollouts and return a tensor dataset.

    Args:
        rollout_directory: Directory containing the observation, achieved-goal and
            action arrays under the file names defined in ``step``.

    Returns:
        A tensor dataset built from every hindsight (state, goal, action) example.
    """

    def _load(file_name: str) -> np.ndarray:
        return np.load(os.path.join(rollout_directory, file_name))

    s_obs_vecs = _load(step.s_obs_vecs_file)
    s_ach_goal_vecs = _load(step.s_ach_goal_vecs_file)
    a_vecs = _load(step.a_vecs_file)

    # s_g_pair_iter yields (state, action, goal); to_tensor_dataset expects
    # (state, goal, action), so the last two entries are swapped here.
    data_vec = []
    for state, action, goal in s_g_pair_iter(s_obs_vecs, s_ach_goal_vecs, a_vecs):
        data_vec.append((state, goal, action))

    return to_tensor_dataset(data_vec)


def seed_worker(worker_id: int) -> None:
    """Seed NumPy and `random` in a data worker so workers do not repeat batches.

    Args:
        worker_id: Index of the worker (unused; the seed comes from torch).
    """
    # Inside a DataLoader worker, torch.initial_seed() is the base seed plus a
    # per-worker offset. Reduce it to 32 bits, the range np.random.seed accepts.
    seed_32bit = torch.initial_seed() % (1 << 32)
    for seed_fn in (np.random.seed, random.seed):
        seed_fn(seed_32bit)


class D4RLIterableDataset(data.IterableDataset):
    """Used for goal-conditioned learning in D4RL.

    Each example pairs an observation and its action with a goal taken from the
    observation at the same or a later timestep of the same trajectory.
    """

    def __init__(
        self,
        observations: np.ndarray,
        actions: np.ndarray,
        dones: np.ndarray,
        epoch_size: int = 2450000,
        index_batch_size: int = 64,
        goal_columns: Optional[Union[Tuple[int],
                                     List[int], np.ndarray]] = None,
    ):
        """Initializes the dataset.

        Args:
            observations: The observations for the dataset.
            actions: The actions for the dataset.
            dones: The dones for the dataset.
            epoch_size: For PyTorch Lightning to count epochs.
            index_batch_size: This has no effect on the functionality of the dataset,
                but it is used internally as the batch size to fetch random indices.
            goal_columns: If not None, then only use these columns of the
                observation_space for the goal conditioning.
        """
        super().__init__()

        self.observations = observations
        self.actions = actions
        self.dones = dones
        self.epoch_size = epoch_size
        self.index_batch_size = index_batch_size
        self.goal_columns = goal_columns

    def _sample_indices(self) -> Tuple[np.ndarray, np.ndarray]:
        """Sample ``index_batch_size`` (start, goal) timestep index pairs.

        Trajectories are drawn with replacement; within each one, two time offsets
        are drawn and the smaller becomes the start, the larger the goal.
        """
        # NOTE(review): the done markers are recomputed on every call although
        # self.dones is never changed in this class; consider caching if
        # util.extract_done_markers is pure. Presumably it returns per-trajectory
        # start indices, end indices and lengths.
        starts, ends, lengths = util.extract_done_markers(self.dones)

        # Credit to Dibya Ghosh's GCSL codebase for the logic in the following block:
        # https://github.com/dibyaghosh/gcsl/blob/
        # cfae5609cee79e5a2228fb7653451023c41a64cb/gcsl/algo/buffer.py#L78
        trajectory_indices = np.random.choice(
            len(starts), self.index_batch_size)
        proportional_indices_1 = np.random.rand(self.index_batch_size)
        proportional_indices_2 = np.random.rand(self.index_batch_size)
        time_indices_1 = np.floor(
            proportional_indices_1 * (lengths[trajectory_indices] - 1),
        ).astype(int)
        time_indices_2 = np.floor(
            proportional_indices_2 * lengths[trajectory_indices],
        ).astype(int)

        trajectory_starts = starts[trajectory_indices]
        start_indices = trajectory_starts + np.minimum(
            time_indices_1,
            time_indices_2,
        )
        goal_indices = trajectory_starts + np.maximum(
            time_indices_1,
            time_indices_2,
        )

        return start_indices, goal_indices

    def _sample_batch(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Gather observations, goals and actions for one batch of sampled indices."""
        start_indices, goal_indices = self._sample_indices()

        observation_batch = self.observations[start_indices]
        goal_batch = self.observations[goal_indices]
        if self.goal_columns is not None:
            # Condition only on the selected observation columns.
            goal_batch = np.take(goal_batch, self.goal_columns, axis=1)
        action_batch = self.actions[start_indices]

        return observation_batch, goal_batch, action_batch

    def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
        """Yield each training example.

        Yields:
            ``(observation concatenated with goal, action)`` pairs, exactly
            ``epoch_size`` of them per iterator.
        """
        remaining = self.epoch_size
        while remaining > 0:
            observation_batch, goal_batch, action_batch = self._sample_batch()

            examples = zip(
                torch.tensor(observation_batch),
                torch.tensor(goal_batch),
                torch.tensor(action_batch),
            )
            for observation, goal, action in examples:
                yield torch.cat((observation, goal), dim=0), action
                remaining -= 1
                if remaining <= 0:
                    # The last index batch may overshoot the epoch; drop the rest.
                    return

    def __len__(self) -> int:
        """The number of examples in an epoch. Used by the trainer to count epochs."""
        return self.epoch_size


class AbstractDataModule(pl.LightningDataModule, ABC):
    """Abstract class that serves as parent for all DataModules."""

    def __init__(
        self,
        batch_size: int = 256,
        val_frac: float = 0.1,
        num_workers: Optional[int] = None,
        seed: Optional[int] = None,
    ):
        """Initialization for the abstract class.

        Args:
            batch_size: How many examples to return per batch.
            val_frac: What fraction of examples to use as a validation set.
            num_workers: How many cpu workers to fetch data. If not specified, takes
                the minimum of the available CPUs (SLURM_CPUS_PER_TASK if set,
                otherwise os.cpu_count()) and max_num_workers (defined at the top of
                this file). Note that 0 is treated as "not specified".
            seed: A seed for the random dataset samples. If None, one is drawn from
                NumPy's global random state.
        """
        super().__init__()
        self.batch_size = batch_size
        self.val_frac = val_frac
        # os.cpu_count() may return None when the count is undeterminable; fall back
        # to a single CPU instead of crashing in int(None).
        available_cpus = os.cpu_count() or 1
        slurm_cpus = int(os.environ.get("SLURM_CPUS_PER_TASK", available_cpus))
        self.num_workers = num_workers or min(slurm_cpus, max_num_workers)

        # These should be created in self.setup()
        self.data_train = None
        self.data_val = None

        if seed is None:
            seed = np.random.randint(2**31 - 1)
        self.seed = seed
        self.generator = torch.Generator()
        self.generator.manual_seed(seed)

    @abstractmethod
    def setup(self, *args, **kwargs) -> None:
        """Create the training and validation data."""
        pass

    ####################
    # DATA RELATED HOOKS
    ####################

    def _build_dataloader(self, dataset: Any) -> data.DataLoader:
        """Wrap a dataset in a DataLoader with this module's shared settings."""
        return data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            generator=self.generator,
            worker_init_fn=seed_worker,
        )

    def train_dataloader(self) -> data.DataLoader:
        """Make the training dataloader."""
        return self._build_dataloader(self.data_train)

    def val_dataloader(self) -> data.DataLoader:
        """Make the validation dataloader."""
        return self._build_dataloader(self.data_val)


class GCSLDataModule(AbstractDataModule):
    """The data module used for GCSL envs.

    Examples are loaded from rollouts saved on disk and split at random into a
    training and a validation subset.
    """

    def __init__(
        self,
        rollout_directory: str,
        batch_size: int = 32,
        val_frac: float = 0.2,
        num_workers: Optional[int] = None,
        seed: Optional[int] = None,
    ):
        """Custom initialization for the GCSL data module.

        Args:
            rollout_directory: Directory the saved rollouts are loaded from.
            batch_size: How many examples to return per batch.
            val_frac: What fraction of examples to use as a validation set.
            num_workers: How many cpu workers to fetch data.
            seed: A seed for the random dataset samples.
        """
        super().__init__(
            batch_size=batch_size,
            val_frac=val_frac,
            num_workers=num_workers,
            seed=seed,
        )
        self.rollout_directory = rollout_directory

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the training and validation data.

        Args:
            stage: The datasets are only stored when this is "fit" or None.
        """
        full_dataset = load_tensor_dataset(self.rollout_directory)
        n_total = len(full_dataset)
        n_val = int(self.val_frac * n_total)
        split_sizes = [n_total - n_val, n_val]

        data_train, data_val = data.random_split(full_dataset, split_sizes)
        if n_val == 0:
            # An empty validation subset is represented as "no validation data".
            data_val = None

        if stage in ("fit", None):
            self.data_train = data_train
            self.data_val = data_val


def d4rl_trajectory_split(
    dones: np.ndarray,
    val_frac: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """Divides the D4RL trajectories into training and validation splits.

    The split is made over whole trajectories, which are assigned at random, so no
    trajectory contributes timesteps to both sets.

    Args:
        dones: Flags indicating the ends of trajectories.
        val_frac: What fraction of the trajectories to put in the validation set.

    Returns:
        Indices indicating which timesteps go in the training set and which go in the
            validation set.
    """
    assert 0 <= val_frac <= 1

    starts, _, _ = util.extract_done_markers(dones)
    n_trajectories = len(starts)
    n_val = int(val_frac * n_trajectories)
    n_train = n_trajectories - n_val

    # avoid biased splits when trajectories are ordered, e.g., in combined datasets
    trajectory_order = np.arange(n_trajectories)
    np.random.shuffle(trajectory_order)
    train_traj_indices = trajectory_order[:n_train]
    val_traj_indices = trajectory_order[n_train:]

    train_indices = util.collect_timestep_indices(dones, train_traj_indices)
    val_indices = util.collect_timestep_indices(dones, val_traj_indices)

    return train_indices.astype(int), val_indices.astype(int)


def reward_to_go(dataset: Dict[str, np.ndarray], average: bool = True) -> np.ndarray:
    """Compute the reward to go for each timestep.

    The implementation is iterative because when I wrote a vectorized version, np.cumsum
    caused numerical instability.

    Args:
        dataset: Must contain "rewards", "terminals" and "timeouts"; a timestep ends
            its trajectory when either flag is set.
        average: If True, divide each cumulative reward to go by the number of steps
            remaining in a maximum-length episode (the longest trajectory length minus
            the steps elapsed in the current trajectory). If False, return the plain
            cumulative reward to go.

    Returns:
        An array shaped like ``dataset["rewards"]`` with one value per timestep.
    """
    rewards = dataset["rewards"]
    dones = np.logical_or(dataset["terminals"], dataset["timeouts"])

    # Walk backwards so each step adds its reward to the running sum of its
    # trajectory; a done flag marks the last step of a trajectory, so reset there.
    reverse_reward_to_go = np.inf * np.ones_like(rewards)
    running_reward = 0
    for i, (reward, done) in enumerate(zip(rewards[::-1], dones[::-1])):
        if done:
            running_reward = 0
        running_reward += reward
        reverse_reward_to_go[i] = running_reward
    cum_reward_to_go = reverse_reward_to_go[::-1].copy()

    if not average:
        # Skip the averaging pass (and the done-marker extraction) when unused.
        return cum_reward_to_go

    _, _, lengths = util.extract_done_markers(dones)
    max_episode_steps = np.max(lengths)

    avg_reward_to_go = np.inf * np.ones_like(cum_reward_to_go)
    elapsed_time = 0
    for i, (cum_reward, done) in enumerate(zip(cum_reward_to_go, dones)):
        avg_reward_to_go[i] = cum_reward / (max_episode_steps - elapsed_time)
        elapsed_time += 1
        if done:
            elapsed_time = 0

    return avg_reward_to_go

# Iterable dataset that draws random examples from a TensorDataset, with a fixed
# number of examples per epoch.


class TensorDatasetIterableDataset(data.IterableDataset):
    """Iterable view of a TensorDataset with a fixed number of examples per epoch.

    Examples are drawn uniformly at random, with replacement, from the wrapped
    dataset. With several DataLoader workers the epoch is divided among them so that
    together they yield exactly ``epoch_length`` examples.
    """

    # NOTE(review): re-created from fresh entropy on every __iter__, so sampling is
    # not governed by the data module's seed.
    rand: np.random.Generator

    def __init__(self, dataset, epoch_length):
        """Store the wrapped dataset and the epoch length.

        Args:
            dataset: The TensorDataset to sample from.
            epoch_length: How many examples make up one epoch.
        """
        super().__init__()
        self.dataset = dataset
        self.epoch_length = epoch_length
        # Exposed so callers can reach the underlying tensors directly.
        self.tensors = self.dataset.tensors

    @staticmethod
    def _worker_share(epoch_length: int, num_workers: int, worker_id: int) -> int:
        """Number of examples worker ``worker_id`` yields out of ``epoch_length``.

        The first ``epoch_length % num_workers`` workers take one extra example so
        that the shares sum to ``epoch_length`` instead of dropping the remainder.
        """
        base, remainder = divmod(int(epoch_length), num_workers)
        return base + (1 if worker_id < remainder else 0)

    def segment_generator(self, epoch_length):
        """Yield ``epoch_length`` examples sampled at random from the dataset."""
        n_examples = len(self.dataset)
        for _ in range(epoch_length):
            idx = self.rand.integers(n_examples)
            yield self.dataset[idx]

    def __iter__(self):
        """Return a generator over this process's share of the epoch."""
        worker_info = torch.utils.data.get_worker_info()
        self.rand = np.random.default_rng(None)
        if worker_info is None:  # single-process data loading, return the full iterator
            n_examples = int(self.epoch_length)
        else:  # in a worker process: take only this worker's share of the epoch
            n_examples = self._worker_share(
                self.epoch_length, worker_info.num_workers, worker_info.id)
        return self.segment_generator(n_examples)

    def __len__(self):
        """The number of examples in an epoch."""
        return self.epoch_length


class D4RLTensorDatasetDataModule(AbstractDataModule):
    """Abstract class for D4RL datasets that can be stored as a TensorDataset.

    Subclasses define how the model input is built from the raw dataset by
    implementing ``_get_observation_tensor``.
    """

    def __init__(
        self,
        env: offline_env.OfflineEnv,
        batch_size: int,
        val_frac: float = 0.1,
        num_workers: Optional[int] = None,
        seed: Optional[int] = None,
        wandb_run: Optional[Run] = None,
    ):
        """Custom initialization that saves the environment.

        Args:
            env: The environment whose dataset is loaded in ``setup``.
            batch_size: How many examples to return per batch.
            val_frac: What fraction of trajectories to use as a validation set.
            num_workers: How many cpu workers to fetch data.
            seed: A seed for the random dataset samples.
            wandb_run: Stored on the instance for use by subclasses.
        """
        super().__init__(
            batch_size=batch_size,
            val_frac=val_frac,
            num_workers=num_workers,
            seed=seed,
        )
        self.env = env
        self.wandb_run = wandb_run

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the training and validation data.

        Args:
            stage: The datasets are only stored when this is "fit" or None.
        """
        raw_dataset = self.env.get_dataset()
        observation_tensor = self._get_observation_tensor(raw_dataset)
        action_tensor = torch.tensor(raw_dataset["actions"])
        dones = np.logical_or(raw_dataset["terminals"], raw_dataset["timeouts"])

        train_indices, val_indices = d4rl_trajectory_split(
            dones, self.val_frac)

        # One epoch is as long as the un-augmented dataset.
        epoch_len = action_tensor.shape[0]
        n_actions = action_tensor.shape[1]

        # the observation tensor may be n times larger than the action tensor due to data
        # augmentation, so we need to repeat the action tensor
        n_repeats = observation_tensor.shape[0] // epoch_len
        repeated = action_tensor.reshape(-1, 1, n_actions).expand(
            -1, n_repeats, n_actions)
        action_tensor = repeated.reshape(-1, n_actions)
        assert observation_tensor.shape[0] == action_tensor.shape[0]

        train_iterable_dataset = TensorDatasetIterableDataset(
            data.TensorDataset(
                observation_tensor[train_indices], action_tensor[train_indices]),
            epoch_length=epoch_len,
        )

        val_iterable_dataset = None
        if self.val_frac > 0:
            val_iterable_dataset = TensorDatasetIterableDataset(
                data.TensorDataset(
                    observation_tensor[val_indices], action_tensor[val_indices]),
                epoch_length=epoch_len,
            )

        if stage in ("fit", None):
            self.data_train = train_iterable_dataset
            self.data_val = val_iterable_dataset

    @abstractmethod
    def _get_observation_tensor(self, dataset: Dict[str, np.ndarray]) -> torch.Tensor:
        """Build the model-input tensor from the raw dataset."""
        pass

    # Output tensors
    def mean_std(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the mean and standard deviation of the training data.

        Only the first tensor of the training dataset (the model inputs) is used. A
        small constant is added to the standard deviation.
        """
        inputs = self.data_train.tensors[0]
        mean = inputs.mean(dim=0)
        std = inputs.std(dim=0) + 1e-6
        return mean, std


def plot_return_histogram(ret_np, wandb_run, num_bins=20, tag=''):
    """Plot a matplotlib histogram with the right number of bins and log to wandb.

    Args:
        ret_np: The return values to histogram.
        wandb_run: The run to log the image to. If None, nothing is plotted.
        num_bins: Number of histogram bins.
        tag: Suffix appended to the plot title.
    """
    if wandb_run is None:
        return

    fig = plt.figure()
    try:
        ax = fig.gca()
        ax.hist(ret_np, bins=num_bins)
        ax.set_title(f"Return Histogram {tag}")
        ax.set_xlabel("Return")
        ax.set_ylabel("Frequency")
        wandb_run.log({"return_distribution": wandb.Image(fig)})
    finally:
        # pyplot keeps every figure alive until it is closed; this function is
        # called repeatedly, so release the figure once it has been logged.
        plt.close(fig)


class D4RLBCDataModule(D4RLTensorDatasetDataModule):
    """Data module for unconditional behavior cloning in D4RL."""

    def _get_observation_tensor(self, dataset: Dict[str, np.ndarray]) -> torch.Tensor:
        """Use the raw observations, without any conditioning, as model input."""
        observations = dataset["observations"]
        return torch.tensor(observations)


class D4RLRvSRDataModule(D4RLTensorDatasetDataModule):
    """Data module for RvS-R (reward-conditioned) learning in D4RL."""

    def __init__(
        self,
        env: offline_env.OfflineEnv,
        batch_size: int,
        val_frac: float = 0.1,
        num_workers: Optional[int] = None,
        average_reward_to_go: bool = True,
        seed: Optional[int] = None,
        wandb_run: Optional[Run] = None,
    ):
        """Custom initialization that sets the average_reward_to_go.

        Args:
            env: The environment whose dataset is loaded.
            batch_size: How many examples to return per batch.
            val_frac: What fraction of trajectories to use as a validation set.
            num_workers: How many cpu workers to fetch data.
            average_reward_to_go: Whether to condition on the average rather than the
                cumulative reward to go.
            seed: A seed for the random dataset samples.
            wandb_run: If given, a histogram of the conditioning returns is logged
                to it.
        """
        super().__init__(
            env,
            batch_size,
            val_frac=val_frac,
            num_workers=num_workers,
            seed=seed,
            wandb_run=wandb_run,
        )
        self.average_reward_to_go = average_reward_to_go

    def _get_observation_tensor(self, dataset: Dict[str, np.ndarray]) -> torch.Tensor:
        """Append each timestep's reward to go as an extra observation column."""
        rets = reward_to_go(
            dataset, average=self.average_reward_to_go).reshape(-1, 1)
        # wandb_run is optional; only log the histogram when a run was provided.
        if self.wandb_run is not None:
            plot_return_histogram(
                rets.reshape(-1), self.wandb_run, num_bins=20)
        return make_s_g_tensor(
            dataset["observations"],
            rets,
        )


class D4RLRvSBDataModule(AbstractDataModule):
    """Data module for RvS-B (bootstrapped-reward-conditioned) learning in D4RL."""

    # Dataset entries that are truncated together when only a prefix of the data
    # is used (percent_dataset < 1).
    _SUBSET_KEYS = (
        "observations",
        "actions",
        "rewards",
        "terminals",
        "timeouts",
        "continuous_actions",
    )

    # Resolution (per axis) of the AntMaze value plot.
    _VALUE_GRID_SIZE = 200

    def __init__(
        self,
        env: offline_env.OfflineEnv,
        dataset_preprocess: str,
        store_dataset_gpu: bool,
        percent_dataset: float,
        model: str,
        model_args: Dict[str, Any],
        threshold_D: float,
        noise: float,
        feature_extractor: str,
        batch_size: int,
        n_iter: int,
        val_frac: float = 0.1,
        num_workers: Optional[int] = None,
        discount_factor: float = 1.0,
        seed: Optional[int] = None,
        wandb_run: Optional[Run] = None,
        pm: Optional[preemption.PreemptionManager] = None,
    ):
        """Custom initialization that stores the bootstrapping configuration.

        Args:
            env: The environment whose dataset is loaded in ``setup``.
            dataset_preprocess: Passed to ``util.preprocess_dataset``.
            store_dataset_gpu: If True, the dataset is converted to a tensor dataset
                and served by ``util.FastTensorDataLoader`` on the 'cuda' device.
            percent_dataset: If below 1.0, only this leading fraction of the dataset
                entries is used.
            model: Which bootstrapping model to use; only 'qr' is implemented.
            model_args: Arguments for the bootstrapping model. Must contain the key
                'reward_preprocessing'.
            threshold_D: Stored on the instance; not used in this class.
            noise: Stored on the instance; not used in this class.
            feature_extractor: Stored on the instance; not used in this class.
            batch_size: How many examples to return per batch.
            n_iter: Number of bootstrapping iterations.
            val_frac: Must be 0; validation is not implemented for this module.
            num_workers: How many cpu workers to fetch data.
            discount_factor: Discount factor passed to the bootstrapping routine.
            seed: A seed for the random dataset samples.
            wandb_run: If given, diagnostic plots are logged to it.
            pm: If given, used to load and save the bootstrapped dataset.
        """
        super().__init__(
            batch_size=batch_size,
            val_frac=val_frac,
            num_workers=num_workers,
            seed=seed,
        )
        self.discount_factor = discount_factor
        self.n_iter = n_iter
        self.model_args = model_args
        self.threshold_D = threshold_D
        self.noise = noise
        self.feature_extractor = feature_extractor
        self.pm = pm
        self.wandb_run = wandb_run
        self.env = env
        self.model = model
        self.dataset_preprocess = dataset_preprocess
        self.store_dataset_gpu = store_dataset_gpu
        self.percent_dataset = percent_dataset

    def setup(self, stage: Optional[str] = None, terminal_every=None) -> None:
        """Create the training data.

        Args:
            stage: The dataset is only stored when this is "fit" or None.
            terminal_every: Unused.

        Raises:
            NotImplementedError: If ``val_frac`` is positive or ``model`` is not 'qr'.
        """
        # Fail before the expensive bootstrapping rather than after it.
        if self.val_frac > 0:
            raise NotImplementedError(
                "Validation not implemented for bootstrapped data.")

        dataset = self._load_preprocessed_dataset()
        bootstrapped_dataset = self._build_bootstrapped_dataset(dataset)

        # All logging is optional: it needs a wandb run.
        if self.wandb_run is not None:
            self._log_label_histograms(bootstrapped_dataset)
            if self.n_iter > 0 and step.is_antmaze_env(self.env):
                self._log_antmaze_value_plot(bootstrapped_dataset)

        bootstrapped_dataset.epoch_len = dataset["actions"].shape[0]
        bootstrapped_dataset.combine_goal = True
        bootstrapped_dataset.seed = self.seed

        # Keep the iterable form around: mean_std() is computed from it.
        self.iter_dataset = bootstrapped_dataset

        if self.store_dataset_gpu:
            bootstrapped_dataset = bootstrapped_dataset.convert_to_tensor_dataset()

        if stage == "fit" or stage is None:
            self.data_train, self.data_val = bootstrapped_dataset, None

    def _load_preprocessed_dataset(self) -> Dict[str, np.ndarray]:
        """Load the env dataset, optionally truncate it, and preprocess it."""
        dataset = self.env.get_dataset()

        if self.percent_dataset < 1.0:
            print("Using a subset of the dataset!")
            for key in self._SUBSET_KEYS:
                if key in dataset:
                    values = dataset[key]
                    dataset[key] = values[:int(
                        self.percent_dataset * len(values))]

        dataset["rewards"] = step.preprocess_reward(
            dataset["rewards"], self.model_args['reward_preprocessing'])

        dataset = util.preprocess_dataset(dataset, self.dataset_preprocess)

        if isinstance(self.env, util.DiscretizeWrapper):
            # Keep the continuous actions and replace "actions" with what the
            # wrapper's reverse_action returns for them.
            print("Discretizing actions in the dataset...")
            dataset["continuous_actions"] = dataset["actions"]
            dataset["actions"] = self.env.reverse_action(dataset["actions"])

        return dataset

    def _build_bootstrapped_dataset(self, dataset: Dict[str, np.ndarray]) -> Any:
        """Run the bootstrapping model and sync the result with the preemption manager."""
        if self.model != 'qr':
            raise NotImplementedError

        # NOTE(review): the dataset is computed eagerly and only then handed to
        # pm.load_if_exists, so the computation runs even when a saved copy exists.
        bootstrapped_dataset = bootstrap_qr.bootstrapped_dataset_quantile(
            dataset,
            self.discount_factor,
            self.model_args,
            self.n_iter,
            self.seed,
            wandb_run=self.wandb_run,
            env=self.env,
            store_dataset_gpu=self.store_dataset_gpu,
        )

        if self.pm is not None:
            bootstrapped_dataset = self.pm.load_if_exists(
                'bootstrapped_dataset', bootstrapped_dataset)
            self.pm.save('bootstrapped_dataset',
                         bootstrapped_dataset, now=True)

        return bootstrapped_dataset

    def _log_label_histograms(self, bootstrapped_dataset: Any) -> None:
        """Log a histogram of all labels, plus one per entry of the first trajectory's labels."""
        traj_labels = bootstrapped_dataset.traj_labels

        all_labels = np.concatenate(
            [np.array(per_traj).flatten() for per_traj in traj_labels], axis=0)
        plot_return_histogram(all_labels, self.wandb_run, num_bins=20)

        if self.model == 'qr':
            for i in range(len(traj_labels[0])):
                iteration_labels = np.concatenate(
                    [np.array(per_traj[i]).flatten() for per_traj in traj_labels],
                    axis=0,
                )
                plot_return_histogram(iteration_labels, self.wandb_run,
                                      num_bins=20, tag=f'bootstrapped-{i}-times')

    def _log_antmaze_value_plot(self, bootstrapped_dataset: Any) -> None:
        """Log a heat map of the best label per (x, y) cell for each bootstrap iteration."""
        n_plots = self.n_iter + 1
        grid_size = self._VALUE_GRID_SIZE

        # The first two observation entries are used as the (x, y) position.
        positions = []
        labels_per_iter = [[] for _ in range(n_plots)]
        for traj, per_traj_labels in zip(
            bootstrapped_dataset.trajs, bootstrapped_dataset.traj_labels
        ):
            positions.extend(obs[:2] for obs in traj.obs)
            for i in range(n_plots):
                labels_per_iter[i].extend(per_traj_labels[i])

        positions = np.array(positions)
        labels_per_iter = np.array(labels_per_iter)

        # generate 2 2d grids for the x & y bounds
        minx, maxx = positions[:, 0].min() - 1, positions[:, 0].max() + 1
        miny, maxy = positions[:, 1].min() - 1, positions[:, 1].max() + 1

        x, y = np.meshgrid(np.linspace(minx, maxx, grid_size),
                           np.linspace(miny, maxy, grid_size))

        # Grid cell of every position; the same for each bootstrap iteration.
        cols = ((positions[:, 0] - minx) / (maxx - minx) * grid_size).astype(int)
        rows = ((positions[:, 1] - miny) / (maxy - miny) * grid_size).astype(int)

        min_r = -1 / (1 - self.discount_factor + 1e-8)
        z_min, z_max = min_r, 0

        fig, axes = plt.subplots(1, n_plots, figsize=(n_plots * 5, 4))
        try:
            for j in range(n_plots):
                z = np.ones_like(x) * min_r

                # np.meshgrid (default 'xy' indexing) puts x along columns and y
                # along rows, so the grid is indexed [row=y, col=x].
                for row, col, cell_label in zip(rows, cols, labels_per_iter[j]):
                    z[row, col] = np.maximum(z[row, col], cell_label)

                z = z[:-1, :-1]

                mesh = axes[j].pcolormesh(
                    x, y, z, cmap='RdBu', vmin=z_min, vmax=z_max)
                axes[j].set_title(f"Values, bootstrap = {j}")
                # set the limits of the plot to the limits of the data
                axes[j].axis([x.min(), x.max(), y.min(), y.max()])
                fig.colorbar(mesh, ax=axes[j])

            self.wandb_run.log({"values plot": wandb.Image(fig)})
        finally:
            # Release the figure; pyplot keeps it alive until closed.
            plt.close(fig)

    def train_dataloader(self) -> data.DataLoader:
        """Make the training dataloader."""
        if self.store_dataset_gpu:
            return util.FastTensorDataLoader(self.data_train, batch_size=self.batch_size, shuffle=True, device='cuda')

        return data.DataLoader(
            self.data_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            generator=self.generator,
            worker_init_fn=seed_worker,
        )

    def val_dataloader(self) -> Optional[data.DataLoader]:
        """Make the validation dataloader.

        Returns None when the dataset is stored on the GPU. Otherwise wraps
        ``self.data_val``, which ``setup`` always leaves as None because validation
        is not implemented for bootstrapped data.
        """
        if self.store_dataset_gpu:
            return None

        return data.DataLoader(
            self.data_val,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            generator=self.generator,
            worker_init_fn=seed_worker,
        )

    def mean_std(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the mean and standard deviation of the training data."""
        return self.iter_dataset.mean_std()


class D4RLRvSGDataModule(AbstractDataModule):
    """Data module for RvS-G (goal-conditioned) learning in D4RL."""

    def __init__(
        self,
        env: offline_env.OfflineEnv,
        batch_size: int = 32,
        val_frac: float = 0.1,
        num_workers: Optional[int] = None,
        seed: Optional[int] = None,
    ):
        """Custom initialization.

        Saves the environment and conditions on the (x, y) coordinate of the goal in
        AntMaze.

        Args:
            env: The environment whose dataset is loaded in ``setup``.
            batch_size: How many examples to return per batch.
            val_frac: What fraction of trajectories to use as a validation set.
            num_workers: How many cpu workers to fetch data.
            seed: A seed for the random dataset samples.
        """
        super().__init__(
            batch_size=batch_size,
            val_frac=val_frac,
            num_workers=num_workers,
            seed=seed,
        )
        self.env = env
        # In AntMaze only the first two observation columns are used as the goal;
        # elsewhere the whole observation is.
        if step.is_antmaze_env(env):
            self.goal_columns = (0, 1)
        else:
            self.goal_columns = None

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the training and validation data.

        Args:
            stage: The datasets are only stored when this is "fit" or None.
        """
        dataset = self.env.get_dataset()
        observations, actions = dataset["observations"], dataset["actions"]

        # AntMaze trajectories are delimited by timeouts only.
        if step.is_antmaze_env(self.env):
            dones = dataset["timeouts"]
        else:
            dones = np.logical_or(dataset["terminals"], dataset["timeouts"])

        train_indices, val_indices = d4rl_trajectory_split(
            dones, self.val_frac)

        def _subset(indices: np.ndarray) -> D4RLIterableDataset:
            return D4RLIterableDataset(
                observations[indices],
                actions[indices],
                dones[indices],
                goal_columns=self.goal_columns,
            )

        train_dataset = _subset(train_indices)
        val_dataset = _subset(val_indices) if self.val_frac > 0 else None

        if stage in ("fit", None):
            self.data_train = train_dataset
            self.data_val = val_dataset
