from __future__ import annotations

import copy
import logging
import pickle
import random
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import gym
import h5py
import numpy as np
import torch
from omegaconf import DictConfig, ListConfig, OmegaConf
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split
from tqdm import tqdm

# Progress-bar layout string in tqdm's ``bar_format`` syntax: tqdm's standard
# left/right info sections around a bar fixed at 50 characters.
# NOTE(review): presumably passed as ``bar_format=`` to tqdm elsewhere — confirm.
TQDM_BAR_FORMAT = "{l_bar}{bar:50}{r_bar}"

# Module-level logger.
logger = logging.getLogger(__name__)


def convert_ndarray_list_to_obj_ndarray(list_of_array: List[np.ndarray]):
    """Pack a sequence of arrays into a 1-D ``dtype=object`` array.

    ``np.array(seq, dtype=object)`` alone would build a multi-dimensional
    array whenever the inputs happen to share a shape. To avoid that, each
    element is assigned into a pre-allocated 1-D object array instead. The
    input sequence is never modified, and the element objects are stored
    as-is (not copied).

    Args:
        list_of_array: Sequence of arrays, e.g. one per trajectory. May be
            empty or hold a single array.

    Returns:
        Array of shape ``(len(list_of_array),)`` and ``dtype=object`` whose
        i-th entry is ``list_of_array[i]``.
    """
    result = np.empty(len(list_of_array), dtype=object)
    for i, arr in enumerate(list_of_array):
        # Integer-index assignment stores the array object itself in the slot
        # instead of broadcasting its contents.
        result[i] = arr
    return result


@dataclass
class TrajDataset:
    """A collection of trajectories together with task/domain labels.

    Indexing with an int returns a ``TrajDataset`` whose fields hold that one
    trajectory (so ``len()`` becomes its step count). Indexing with a slice,
    index array or boolean mask returns a sub-collection of trajectories.

    Fields:
        obs, actions, images: Indexed per trajectory. ``__post_init__`` only
            checks that ``obs`` and ``actions`` have equal length.
        task_ids: Task id per trajectory; defaults to all ``-1``.
        n_task_id: Number of task ids (width of the one-hot encoding).
        domain_id, n_domain_id: Domain label and count; ``-1`` until
            ``add_domain_id`` is called.
    """
    obs: np.ndarray
    actions: np.ndarray
    images: np.ndarray
    task_ids: Optional[np.ndarray] = None
    n_task_id: int = -1
    domain_id: int = -1
    n_domain_id: int = -1

    def __post_init__(self):
        assert len(self.obs) == len(self.actions)
        if self.task_ids is None:
            # -1 marks "no task assigned".
            self.task_ids = np.full(len(self), -1, dtype=int)

    def __len__(self):
        return len(self.obs)

    @staticmethod
    def _shift_one_step(seq: np.ndarray) -> np.ndarray:
        """Advance ``seq`` by one step along axis 0, repeating its last entry."""
        return np.concatenate([seq[1:], seq[-1:]], axis=0)

    @staticmethod
    def _get_next_obs_of_single_traj(obs: np.ndarray):
        return TrajDataset._shift_one_step(obs)

    @staticmethod
    def _get_next_images_of_single_traj(images: np.ndarray):
        return TrajDataset._shift_one_step(images)

    @property
    def next_obs(self) -> np.ndarray:
        """Per-trajectory next observations (last step repeated), as an object array."""
        assert self.obs[0].ndim == 2
        shifted = [self._get_next_obs_of_single_traj(traj) for traj in self.obs]
        return convert_ndarray_list_to_obj_ndarray(list_of_array=shifted)

    @property
    def next_images(self) -> np.ndarray:
        """Per-trajectory next images when each entry is 4-D, else ``images`` unchanged."""
        if self.images[0].ndim == 4:
            shifted = [
                self._get_next_images_of_single_traj(seq) for seq in self.images
            ]
            return convert_ndarray_list_to_obj_ndarray(list_of_array=shifted)
        return self.images

    def get_onehot_task_id(self) -> np.ndarray:
        """One-hot (float32) encoding of ``task_ids`` with ``n_task_id`` classes."""
        identity = np.eye(self.n_task_id, dtype=np.float32)
        return identity[self.task_ids]

    def add_domain_id(self, domain_id: int, n_domain_id: int):
        """Assign the domain label; allowed only once (both must still be -1)."""
        assert self.domain_id == self.n_domain_id == -1
        self.domain_id, self.n_domain_id = domain_id, n_domain_id

    def apply_obs_converter(self, func: ObservationConverter):
        """Replace every trajectory's observations with ``func(obs)``."""
        self.obs = convert_ndarray_list_to_obj_ndarray(list(map(func, self.obs)))

    def apply_action_converter(self, func: ActionConverter):
        """Replace every trajectory's actions with ``func(actions)``."""
        self.actions = convert_ndarray_list_to_obj_ndarray(
            list(map(func, self.actions)))

    def __getitem__(self, item) -> TrajDataset:
        assert self.task_ids is not None
        return TrajDataset(
            self.obs[item],
            self.actions[item],
            self.images[item],
            task_ids=self.task_ids[item],
            n_task_id=self.n_task_id,
            domain_id=self.domain_id,
            n_domain_id=self.n_domain_id,
        )


def concat_traj_datasets(dataset1: TrajDataset,
                         dataset2: TrajDataset) -> TrajDataset:
    """Concatenate two trajectory datasets that have no domain assigned yet.

    Both inputs must have ``domain_id == n_domain_id == -1`` and the same
    ``n_task_id``. ``obs``, ``actions``, ``images`` and ``task_ids`` are joined
    along axis 0 with ``dataset1`` first.
    """
    assert dataset1.domain_id == dataset2.domain_id == -1
    assert dataset1.n_domain_id == dataset2.n_domain_id == -1
    assert dataset1.n_task_id == dataset2.n_task_id

    def _joined(name: str) -> np.ndarray:
        return np.concatenate(
            (getattr(dataset1, name), getattr(dataset2, name)))

    return TrajDataset(
        obs=_joined('obs'),
        actions=_joined('actions'),
        images=_joined('images'),
        task_ids=_joined('task_ids'),
        n_task_id=dataset1.n_task_id,
        domain_id=dataset1.domain_id,
        n_domain_id=dataset1.n_domain_id,
    )


@dataclass
class StepDataset:
    """Flat, step-level dataset in which every array field is indexed by step.

    Indexing returns a new ``StepDataset`` restricted to the selected steps.
    ``n_task_id`` and ``n_domain_id`` are carried over unchanged.
    """
    obs: np.ndarray
    actions: np.ndarray
    images: np.ndarray
    task_ids: np.ndarray
    n_task_id: int
    domain_ids: np.ndarray
    n_domain_id: int
    next_obs: np.ndarray
    next_images: np.ndarray

    def __len__(self):
        return len(self.obs)

    @staticmethod
    def _one_hot(ids, n_classes: int) -> np.ndarray:
        """Rows of a float32 identity matrix selected by ``ids``."""
        return np.eye(n_classes, dtype=np.float32)[ids]

    def get_onehot_task_id(self) -> np.ndarray:
        return self._one_hot(self.task_ids, self.n_task_id)

    def get_onehot_domain_id(self) -> np.ndarray:
        return self._one_hot(self.domain_ids, self.n_domain_id)

    def __getitem__(self, item) -> StepDataset:
        return StepDataset(
            obs=self.obs[item],
            actions=self.actions[item],
            images=self.images[item],
            task_ids=self.task_ids[item],
            n_task_id=self.n_task_id,
            domain_ids=self.domain_ids[item],
            n_domain_id=self.n_domain_id,
            next_obs=self.next_obs[item],
            next_images=self.next_images[item],
        )


class PairedTrajDataset(Dataset):
    """Torch dataset of randomly paired, same-task trajectories from two domains.

    ``__getitem__`` ignores its index. Each call samples two distinct domains,
    a task id with trajectories in both, and one trajectory from each. It then
    zero-pads them to ``seq_len`` steps, and observations/actions to
    ``obs_dim``/``action_dim`` features.

    Args:
        traj_datasets: One ``TrajDataset`` per domain; their count must equal
            ``traj_datasets[0].n_domain_id``.
        obs_dim: Feature width observations are padded to.
        action_dim: Feature width actions are padded to.
        seq_len: Number of steps every sequence is padded to.
        sa_demo: If True, the first trajectory's padded actions are appended
            to its observations along the feature axis.
    """

    def __init__(
        self,
        traj_datasets: List[TrajDataset],
        obs_dim: int,
        action_dim: int,
        seq_len: int = 400,
        sa_demo: bool = False,
    ):
        self.datasets = traj_datasets
        self.seq_len = seq_len
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.sa_demo = sa_demo
        self.n_domains = len(self.datasets)
        assert self.n_domains == self.datasets[0].n_domain_id

        self.n_traj = sum(len(data) for data in self.datasets)
        logger.info(f'{self.n_traj} trajectories available in total.')
        # Candidate task ids come from domain 0; _get_traj_pair further keeps
        # only ids that also occur in both sampled domains.
        self.task_id_list = np.unique(self.datasets[0].task_ids)

    def __len__(self):
        return self.n_traj

    def sample_seq_of_task_id(self, task_id: int, source: bool = True):
        """Return one random trajectory of ``task_id`` from domain 0 (source) or 1.

        Raises:
            ValueError: if the chosen domain has no trajectory for ``task_id``.
        """
        domain_id = 0 if source else 1
        dataset = self.datasets[domain_id]
        selected_trajs = dataset[dataset.task_ids == task_id]
        if len(selected_trajs) == 0:
            raise ValueError(
                f'No trajectory with task id {task_id} in domain {domain_id}.')

        idx = np.random.randint(len(selected_trajs))
        return selected_trajs[idx]

    def _get_traj_pair(self, domain1: int, domain2: int):
        """Sample one trajectory from each domain sharing a random task id.

        The task id is drawn uniformly from the ids in ``task_id_list`` that
        have trajectories in both domains. That is the same distribution as
        rejection-sampling from ``task_id_list``, but it terminates even when
        no such id exists.

        Raises:
            ValueError: if no task id has trajectories in both domains.
        """
        dataset1 = self.datasets[domain1]
        dataset2 = self.datasets[domain2]

        common_task_ids = np.intersect1d(
            np.intersect1d(self.task_id_list, dataset1.task_ids),
            dataset2.task_ids)
        if len(common_task_ids) == 0:
            raise ValueError(f'No task id has trajectories in both domain '
                             f'{domain1} and domain {domain2}.')
        selected_task_id = np.random.choice(common_task_ids)

        selected_trajs1 = dataset1[dataset1.task_ids == selected_task_id]
        selected_trajs2 = dataset2[dataset2.task_ids == selected_task_id]

        traj1 = selected_trajs1[np.random.randint(len(selected_trajs1))]
        traj2 = selected_trajs2[np.random.randint(len(selected_trajs2))]
        return traj1, traj2

    @staticmethod
    def _get_padded_seq(arr,
                        target_length: int) -> Tuple[np.ndarray, np.ndarray]:
        """Zero-pad ``arr`` along axis 0 to ``target_length`` rows.

        Returns:
            ``(padded, pad_mask)`` where ``pad_mask`` is a bool vector of length
            ``target_length`` that is True on the appended padding rows.

        Raises:
            ValueError: if ``arr`` already has more than ``target_length`` rows.
        """
        length = arr.shape[0]
        pad_size = target_length - length
        if pad_size < 0:
            raise ValueError(f'Sequence of length {length} exceeds '
                             f'target_length={target_length}.')
        zero_array = np.zeros((pad_size, *arr.shape[1:]), dtype=arr.dtype)
        pad_mask = np.arange(target_length) >= length
        return np.concatenate((arr, zero_array)), pad_mask

    @staticmethod
    def _get_domain_id_array(length: int,
                             domain_id: int,
                             n_domain_id: int = 2) -> np.ndarray:
        """Return ``length`` copies of the float32 one-hot vector for ``domain_id``."""
        domain_ids = np.zeros((length, n_domain_id), dtype=np.float32)
        domain_ids[..., domain_id] = 1.
        return domain_ids

    def __getitem__(self, item):
        # ``item`` is unused: every call draws a fresh random pair.
        sampled_domains = random.sample(range(self.n_domains), 2)
        traj1, traj2 = self._get_traj_pair(domain1=sampled_domains[0],
                                           domain2=sampled_domains[1])

        obs_seq1, pad_mask1 = self._get_padded_seq(traj1.obs,
                                                   target_length=self.seq_len)
        obs_seq2, pad_mask2 = self._get_padded_seq(traj2.obs,
                                                   target_length=self.seq_len)
        action_seq2, _ = self._get_padded_seq(traj2.actions,
                                              target_length=self.seq_len)
        domain_ids1 = self._get_domain_id_array(length=len(obs_seq1),
                                                domain_id=sampled_domains[0],
                                                n_domain_id=self.n_domains)
        domain_ids2 = self._get_domain_id_array(length=len(obs_seq2),
                                                domain_id=sampled_domains[1],
                                                n_domain_id=self.n_domains)
        image_seq1, _ = self._get_padded_seq(traj1.images,
                                             target_length=self.seq_len)
        image_seq2, _ = self._get_padded_seq(traj2.images,
                                             target_length=self.seq_len)

        obs_seq1, _ = get_padded_batch_and_valid_mask(obs_seq1,
                                                      target_dim=self.obs_dim)
        obs_seq2, _ = get_padded_batch_and_valid_mask(obs_seq2,
                                                      target_dim=self.obs_dim)
        if self.sa_demo:
            # State-action demo: append trajectory 1's actions to its observations.
            action_seq1, _ = self._get_padded_seq(traj1.actions,
                                                  target_length=self.seq_len)
            action_seq1, _ = get_padded_batch_and_valid_mask(
                action_seq1, target_dim=self.action_dim)
            obs_seq1 = np.concatenate((obs_seq1, action_seq1), axis=-1)

        action_seq2, _ = get_padded_batch_and_valid_mask(
            action_seq2, target_dim=self.action_dim)

        return {
            'obs1': torch.from_numpy(obs_seq1),
            'obs2': torch.from_numpy(obs_seq2),
            'actions2': torch.from_numpy(action_seq2),
            'seq_pad_masks1': torch.from_numpy(pad_mask1),
            'seq_pad_masks2': torch.from_numpy(pad_mask2),
            'domain_ids1': torch.from_numpy(domain_ids1),
            'domain_ids2': torch.from_numpy(domain_ids2),
            'image1': torch.from_numpy(image_seq1),
            'image2': torch.from_numpy(image_seq2),
        }


class CorrespondenceDataset(Dataset):
    """Torch dataset of cross-domain state pairs matched by 2-D position.

    At construction, ``n_traj`` trajectory pairs are collected. Each pair has
    the same task id, with one trajectory from domain 0 and one from domain 1,
    and its positions overlap. For each pair, selected domain-0 states are
    matched to the nearest domain-1 state, and each match is stored in
    ``state_pair_list`` as ``{'obs1': ..., 'obs2': ...}``.

    Args:
        traj_datasets: Exactly two ``TrajDataset`` objects (domain 0, domain 1).
        obs_dim: Width observations are zero-padded to in ``__getitem__``.
        n_traj: Number of overlapping trajectory pairs to collect.
        states_per_traj: Anchor states selected per pair (before neighbours).
        neighbors_per_state: Extra states taken on each side of every anchor.
        correspondence_type: ``'p2p'`` or ``'p2a'`` (domain-1 positions are
            converted with ``ant_to_maze2d`` before matching).
        correspondence_threshold: Max position distance for two states to be
            considered corresponding.
    """

    def __init__(
        self,
        traj_datasets: List[TrajDataset],
        obs_dim: int,
        n_traj: int = 10,
        states_per_traj: int = 10,
        neighbors_per_state: int = 1,
        correspondence_type: str = 'p2p',
        correspondence_threshold: float = 0.2,
    ):
        self.datasets = traj_datasets
        self.obs_dim = obs_dim
        self.n_domains = len(self.datasets)
        assert self.n_domains == 2
        assert self.n_domains == self.datasets[0].n_domain_id

        logger.info(f'Sampling {n_traj} corresponding trajectory pairs.')
        self.task_id_list = np.unique(self.datasets[0].task_ids)

        self.state_pair_list = []
        self.setup_dataset(n_traj,
                           correspondence_type,
                           states_per_traj,
                           neighbors_per_state,
                           correspondence_threshold=correspondence_threshold)

    @staticmethod
    def _are_overlapped_trajectories(traj1: np.ndarray,
                                     traj2: np.ndarray,
                                     threshold: float = 0.2):
        """Return True if either trajectory's first point lies near the other.

        "Near" means within Euclidean distance ``threshold`` of some point of
        the other trajectory. Used to judge overlap for P2P and P2A.
        """
        from_traj1 = np.linalg.norm(traj1[0] - traj2, axis=-1).min()
        if from_traj1 < threshold:
            return True

        from_traj2 = np.linalg.norm(traj2[0] - traj1, axis=-1).min()
        if from_traj2 < threshold:
            return True

        return False

    def _select_indices_from_pairwise_distance(self,
                                               dist: np.ndarray,
                                               threshold: float = 0.2,
                                               states_per_traj: int = 10,
                                               neighbors_per_state: int = 1):
        """Pick evenly spread row indices of ``dist`` that have a close match.

        Rows whose minimum distance is below ``threshold`` are valid. Up to
        ``states_per_traj`` anchors are picked at evenly spaced targets across
        the valid range (each snapped to the nearest valid row). Then
        ``neighbors_per_state`` rows on each side of every anchor are added,
        clipped to the row range. Returns the sorted unique indices.
        """
        assert len(dist.shape) == 2
        min_diff = dist.min(axis=-1)
        valid_indices = np.where(min_diff < threshold)[0]

        selected_indices = []
        target_indices = np.linspace(valid_indices[0],
                                     valid_indices[-1] + 1,
                                     states_per_traj,
                                     dtype=int)
        for target_idx in target_indices:
            selected_idx_in_candidate = np.argmin(
                abs(valid_indices - target_idx))
            selected_indices.append(valid_indices[selected_idx_in_candidate])

        # Add neighbours on both sides of every anchor.
        selected_indices = np.array(selected_indices)
        augmented_list = [selected_indices]
        for i in range(neighbors_per_state):
            augmented_list.append(selected_indices + (i + 1))
            augmented_list.append(selected_indices - (i + 1))
        selected_indices = np.concatenate(augmented_list).clip(
            0,
            len(dist) - 1)
        return np.unique(selected_indices)

    def _add_state_pairs(self, traj1, traj2, pos1: np.ndarray,
                         pos2: np.ndarray, threshold: float,
                         states_per_traj: int, neighbors_per_state: int):
        """Match selected traj1 states to their nearest traj2 state and store them."""
        dist = np.linalg.norm(pos1[:, None, :] - pos2[None], axis=-1)
        corresponding_indices = np.argmin(dist, axis=-1)

        selected_indices = self._select_indices_from_pairwise_distance(
            dist=dist,
            threshold=threshold,
            states_per_traj=states_per_traj,
            neighbors_per_state=neighbors_per_state)

        for idx in selected_indices:
            self.state_pair_list.append({
                'obs1': traj1.obs[idx],
                'obs2': traj2.obs[corresponding_indices[idx]],
            })

    def setup_dataset(self,
                      n_traj: int = 10,
                      correspondence_type: str = 'p2p',
                      states_per_traj: int = 10,
                      neighbors_per_state: int = 1,
                      correspondence_threshold: float = 0.2):
        """Fill ``state_pair_list`` from ``n_traj`` overlapping trajectory pairs.

        Raises:
            NotImplementedError: for an unknown ``correspondence_type``.
        """
        if correspondence_type == 'p2p':

            def to_positions(traj1, traj2):
                # NOTE(review): domain-0 positions are read with x/y swapped
                # ([1, 0]) while domain-1 uses [:2] — presumably the two point
                # domains store coordinates in different order; confirm.
                return traj1.obs[..., [1, 0]], traj2.obs[..., :2]

        elif correspondence_type == 'p2a':
            from common.ours.utils import ant_to_maze2d

            def to_positions(traj1, traj2):
                return traj1.obs[..., :2], ant_to_maze2d(traj2.obs[..., :2])

        else:
            raise NotImplementedError

        # NOTE(review): keeps sampling until n_traj overlapping pairs are
        # found; this never terminates if no sampled pair ever overlaps.
        ok_traj = 0
        while ok_traj < n_traj:
            traj1, traj2 = self._get_traj_pair(domain1=0, domain2=1)
            pos1, pos2 = to_positions(traj1, traj2)
            if not self._are_overlapped_trajectories(
                    pos1, pos2, threshold=correspondence_threshold):
                continue

            self._add_state_pairs(traj1,
                                  traj2,
                                  pos1,
                                  pos2,
                                  threshold=correspondence_threshold,
                                  states_per_traj=states_per_traj,
                                  neighbors_per_state=neighbors_per_state)
            ok_traj += 1

    def __len__(self):
        return len(self.state_pair_list)

    def _get_traj_pair(self, domain1: int, domain2: int):
        """Sample one trajectory from each domain sharing a random task id.

        The task id is drawn uniformly from the ids in ``task_id_list`` that
        have trajectories in both domains. That is the same distribution as
        rejection sampling, but it terminates even when no such id exists.

        Raises:
            ValueError: if no task id has trajectories in both domains.
        """
        dataset1 = self.datasets[domain1]
        dataset2 = self.datasets[domain2]

        common_task_ids = np.intersect1d(
            np.intersect1d(self.task_id_list, dataset1.task_ids),
            dataset2.task_ids)
        if len(common_task_ids) == 0:
            raise ValueError(f'No task id has trajectories in both domain '
                             f'{domain1} and domain {domain2}.')
        selected_task_id = np.random.choice(common_task_ids)

        selected_trajs1 = dataset1[dataset1.task_ids == selected_task_id]
        selected_trajs2 = dataset2[dataset2.task_ids == selected_task_id]

        traj1 = selected_trajs1[np.random.randint(len(selected_trajs1))]
        traj2 = selected_trajs2[np.random.randint(len(selected_trajs2))]
        return traj1, traj2

    @staticmethod
    def _get_domain_id_array(length: int,
                             domain_id: int,
                             n_domain_id: int = 2) -> np.ndarray:
        """Return ``length`` copies of the float32 one-hot vector for ``domain_id``."""
        domain_ids = np.zeros((length, n_domain_id), dtype=np.float32)
        domain_ids[..., domain_id] = 1.
        return domain_ids

    def __getitem__(self, item):
        state_pair = self.state_pair_list[item]
        obs1, obs2 = state_pair['obs1'], state_pair['obs2']

        domain_ids1 = self._get_domain_id_array(length=1,
                                                domain_id=0,
                                                n_domain_id=self.n_domains)[0]
        domain_ids2 = self._get_domain_id_array(length=1,
                                                domain_id=1,
                                                n_domain_id=self.n_domains)[0]
        obs1, _ = get_padded_batch_and_valid_mask(obs1,
                                                  target_dim=self.obs_dim)
        obs2, _ = get_padded_batch_and_valid_mask(obs2,
                                                  target_dim=self.obs_dim)

        return {
            'obs1': obs1,
            'obs2': obs2,
            'domain_id1': domain_ids1,
            'domain_id2': domain_ids2
        }


def get_padded_batch_and_valid_mask(arr: np.ndarray, target_dim: int):
    """Zero-pad the last axis of ``arr`` to ``target_dim`` and build a validity mask.

    Works for any array rank ≥ 1: all leading axes are kept unchanged, and a
    1-D input is treated as a single feature vector.

    Args:
        arr: Array whose last axis is padded.
        target_dim: Desired size of the last axis (must be ≥ ``arr.shape[-1]``).

    Returns:
        ``(padded, mask)``, both shaped ``(*arr.shape[:-1], target_dim)`` with
        ``arr``'s dtype. ``mask`` is 1 at original positions and 0 on padding.

    Raises:
        ValueError: if ``arr`` already has more than ``target_dim`` entries on
            its last axis.
    """
    arr = np.asarray(arr)
    pad_size = target_dim - arr.shape[-1]
    if pad_size < 0:
        raise ValueError(f'Last axis has size {arr.shape[-1]}, which exceeds '
                         f'target_dim={target_dim}.')
    # Padding block matches every leading axis so concatenation works for any rank.
    pad_array = np.zeros(arr.shape[:-1] + (pad_size,), dtype=arr.dtype)
    result_array = np.concatenate([arr, pad_array], axis=-1)
    result_mask = np.concatenate([np.ones_like(arr), pad_array], axis=-1)
    return result_array, result_mask


class TorchStepDataset(Dataset):
    """Torch ``Dataset`` view over a ``StepDataset`` that returns one dict per step.

    Args:
        dataset: Step-level data. ``dataset[item]`` must expose ``obs``,
            ``next_obs``, ``actions``, ``images``, ``next_images``,
            ``get_onehot_task_id()`` and ``get_onehot_domain_id()``.
        obs_dim: Width observations are zero-padded to when ``padding``.
        action_dim: Width actions are zero-padded to when ``padding``.
        padding: Pad observations/actions; otherwise the action mask is all ones.
    """

    def __init__(self,
                 dataset: StepDataset,
                 obs_dim: int,
                 action_dim: int,
                 padding: bool = True):
        self.data = dataset
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.padding = padding

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        data = self.data[item]
        if self.padding:
            obs, _ = get_padded_batch_and_valid_mask(data.obs,
                                                     target_dim=self.obs_dim)
            next_obs, _ = get_padded_batch_and_valid_mask(
                data.next_obs, target_dim=self.obs_dim)
            actions, action_masks = get_padded_batch_and_valid_mask(
                data.actions, target_dim=self.action_dim)
        else:
            obs = data.obs
            next_obs = data.next_obs
            actions = data.actions
            action_masks = np.ones_like(actions)

        return {
            'observations': obs.astype(np.float32),
            'actions': actions.astype(np.float32),
            'images': data.images.astype(np.float16),
            'action_masks': action_masks.astype(np.float32),
            'task_ids': data.get_onehot_task_id(),
            'domain_ids': data.get_onehot_domain_id(),
            # Cast like 'observations' so both collate to the same tensor dtype.
            'next_observations': next_obs.astype(np.float32),
            'next_images': data.next_images.astype(np.float16),
        }


def convert_traj_dataset_to_step_dataset(traj_dataset: TrajDataset):
    """Flatten a ``TrajDataset`` into a ``StepDataset`` with one row per step.

    Per-trajectory arrays are concatenated along axis 0. Each trajectory's
    task id is repeated once per step, and every step is labelled with the
    dataset's single ``domain_id``.
    """
    steps_per_traj = [len(traj) for traj in traj_dataset.obs]
    flat_obs = np.concatenate(traj_dataset.obs)

    return StepDataset(
        obs=flat_obs,
        actions=np.concatenate(traj_dataset.actions),
        images=np.concatenate(traj_dataset.images),
        task_ids=np.repeat(traj_dataset.task_ids, steps_per_traj),
        n_task_id=traj_dataset.n_task_id,
        domain_ids=np.full(len(flat_obs), traj_dataset.domain_id, dtype=int),
        n_domain_id=traj_dataset.n_domain_id,
        next_obs=np.concatenate(traj_dataset.next_obs),
        next_images=np.concatenate(traj_dataset.next_images),
    )


def _convert_to_traj_array(arr: np.ndarray, dones: np.ndarray):
    ret = []
    start_idx = np.concatenate(([-1], np.where(dones)[0])) + 1

    for i in range(len(start_idx) - 1):
        st, gl = start_idx[i], start_idx[i + 1]
        ret.append(arr[st:gl])

    all_same_len = all([len(arr) == len(ret[0]) for arr in ret])

    if len(arr.shape) == 4 and all_same_len:  # image
        ret = np.array(ret)
    else:
        ret = convert_ndarray_list_to_obj_ndarray(ret)

    return ret


class PointMazeTaskIDManager:
    """Assign integer task ids to point-maze trajectories.

    A task id encodes a (goal id, start-group id) pair as
    ``goal * n_starts + start``. Goal/start lookups are delegated to the gym
    environment built from ``env_id`` (its ``xy_to_id``, ``id_to_xy``,
    ``n_start_groups`` and ``pos_to_start_id``).
    """

    def __init__(self, env_id: str):
        self.env = gym.make(env_id)
        logger.info(f'Goal-ID dict: : {self.env.xy_to_id}')
        logger.info(
            f'The maze has {self.n_starts} start groups and {self.n_goals} goals. '
            f'{self.n_task_id} tasks in total.')

    @property
    def n_starts(self):
        """Number of start groups reported by the environment."""
        return self.env.n_start_groups

    @property
    def n_goals(self):
        """Number of goals in the environment's ``id_to_xy`` table."""
        return len(self.env.id_to_xy)

    @property
    def n_task_id(self):
        """Total number of (goal, start) combinations."""
        return self.n_starts * self.n_goals

    def calc_task_id(self, goal: int, start: int):
        return goal * self.n_starts + start

    def goal_to_goal_id(self, goal_pos: np.ndarray):
        """Round ``goal_pos`` to integer cell coordinates and look up its id."""
        cell = tuple(np.round(goal_pos).astype(int))
        return self.env.xy_to_id[cell]

    def start_to_start_id(self, start_pos: np.ndarray):
        return self.env.pos_to_start_id(start_pos)

    def traj_to_task_id(self, obs: np.ndarray):
        """Task id from the first (start) and last (goal) x/y of ``obs``."""
        start_id = self.start_to_start_id(start_pos=obs[0][:2])
        goal_id = self.goal_to_goal_id(goal_pos=obs[-1][:2])
        return self.calc_task_id(goal=goal_id, start=start_id)

    def task_id_to_goal_id(self, task_id: Union[int, np.ndarray]):
        return task_id // self.n_starts

    def task_id_to_start_id(self, task_id: Union[int, np.ndarray]):
        return task_id % self.n_starts

    def add_task_id_to_traj_dataset(self,
                                    dataset: TrajDataset,
                                    task_id: int = -1):
        """Set ``dataset.task_ids``, computed per trajectory unless ``task_id >= 0``."""
        if task_id < 0:
            ids = [self.traj_to_task_id(obs=traj_obs) for traj_obs in dataset.obs]
        else:
            ids = [task_id for _ in dataset.obs]
        dataset.task_ids = np.array(ids)

    def goal_id_to_task_id_list(self, goal_id: int):
        """All task ids whose goal is ``goal_id`` (one per start group)."""
        first = goal_id * self.n_starts
        return [first + start for start in range(self.n_starts)]


class AntMazeTaskIDManager:
    """Assign integer task ids to ant-maze trajectories.

    Same encoding as the point-maze manager (``goal * n_starts + start``), but
    goal positions are snapped to multiples of 4 before the ``xy_to_id``
    lookup. It also records ``dataset.max_task_id`` when labelling a dataset.
    """

    def __init__(self, env_id: str):
        # TODO
        self.env = gym.make(env_id)
        logger.info(f'Goal-ID dict: : {self.env.xy_to_id}')
        logger.info(
            f'The maze has {self.n_starts} start groups and {self.n_goals} goals. '
            f'{self.n_task_id} tasks in total.')

    @property
    def n_starts(self):
        """Number of start groups reported by the environment."""
        return self.env.n_start_groups

    @property
    def n_goals(self):
        """Number of goals in the environment's ``id_to_xy`` table."""
        return len(self.env.id_to_xy)

    @property
    def n_task_id(self):
        """Total number of (goal, start) combinations."""
        return self.n_starts * self.n_goals

    def calc_task_id(self, goal: int, start: int):
        return goal * self.n_starts + start

    @staticmethod
    def round_by_num(x, num: int = 4):
        """Round ``x`` to the nearest multiple of ``num`` (halves round up)."""
        half = num / 2
        return (x + half) // num * num

    def goal_to_goal_id(self, goal_pos: np.ndarray):
        cell = tuple(self.round_by_num(goal_pos).astype(int))
        return self.env.xy_to_id[cell]

    def start_to_start_id(self, start_pos: np.ndarray):
        return self.env.pos_to_start_id(start_pos)

    def traj_to_task_id(self, obs: np.ndarray):
        """Task id from the first (start) and last (goal) x/y of ``obs``."""
        start_id = self.start_to_start_id(start_pos=obs[0][:2])
        goal_id = self.goal_to_goal_id(goal_pos=obs[-1][:2])
        return self.calc_task_id(goal=goal_id, start=start_id)

    def task_id_to_goal_id(self, task_id: Union[int, np.ndarray]):
        return task_id // self.n_starts

    def task_id_to_start_id(self, task_id: Union[int, np.ndarray]):
        return task_id % self.n_starts

    def add_task_id_to_traj_dataset(self,
                                    dataset: TrajDataset,
                                    task_id: int = -1):
        """Set ``dataset.task_ids`` (per trajectory unless ``task_id >= 0``) and ``max_task_id``."""
        if task_id < 0:
            ids = [self.traj_to_task_id(obs=traj_obs) for traj_obs in dataset.obs]
        else:
            ids = [task_id for _ in dataset.obs]
        dataset.task_ids = np.array(ids)
        dataset.max_task_id = self.n_task_id

    def goal_id_to_task_id_list(self, goal_id: int):
        """All task ids whose goal is ``goal_id`` (one per start group)."""
        first = goal_id * self.n_starts
        return [first + start for start in range(self.n_starts)]


class LiftTaskIDManager:
    """Task-id manager for Lift: 27 goals, a single start group.

    Per-step goal ids and done flags are supplied via ``set_goal_id`` and
    ``set_dones``. A trajectory's task id is the goal id recorded at its done
    step.
    """

    def __init__(self, env_id: str):
        self.goal_ids = None
        self.dones = None

    @property
    def n_starts(self):
        return 1

    @property
    def n_goals(self):
        return 27

    @property
    def n_task_id(self):
        return self.n_starts * self.n_goals

    def set_goal_id(self, goal_ids: np.ndarray):
        self.goal_ids = goal_ids

    def set_dones(self, dones: np.ndarray):
        self.dones = dones

    def calc_task_id(self, goal: int, start: int):
        return goal * self.n_starts + start

    def calc_traj_task_id(self):
        """Goal ids at every position where ``dones`` is set."""
        return self.goal_ids[np.nonzero(self.dones)]

    def goal_to_goal_id(self, goal_pos: np.ndarray):
        # Ignores its argument and returns the stored per-step goal ids.
        return self.goal_ids

    def start_to_start_id(self, start_pos: np.ndarray):
        return 1

    def task_id_to_goal_id(self, task_id: Union[int, np.ndarray]):
        return task_id // self.n_starts

    def task_id_to_start_id(self, task_id: Union[int, np.ndarray]):
        return task_id % self.n_starts

    def add_task_id_to_traj_dataset(self, dataset: TrajDataset):
        """Set ``dataset.task_ids`` (int32) from the done-step goal ids and ``max_task_id``."""
        per_traj = self.calc_traj_task_id()
        dataset.task_ids = np.array(per_traj, dtype=np.int32)
        dataset.max_task_id = self.n_task_id

    def goal_id_to_task_id_list(self, goal_id: int):
        return list(range(goal_id * self.n_starts, (goal_id + 1) * self.n_starts))


class StackTaskIDManager:
    """Task-id manager for Stack: 8 goals, a single start group.

    Per-step goal ids and done flags are supplied via ``set_goal_id`` and
    ``set_dones``. A trajectory's task id is the goal id recorded at its done
    step.
    """

    def __init__(self, env_id: str):
        self.goal_ids = None
        self.dones = None

    @property
    def n_starts(self):
        return 1

    @property
    def n_goals(self):
        return 8

    @property
    def n_task_id(self):
        return self.n_starts * self.n_goals

    def set_goal_id(self, goal_ids: np.ndarray):
        self.goal_ids = goal_ids

    def set_dones(self, dones: np.ndarray):
        self.dones = dones

    def calc_task_id(self, goal: int, start: int):
        return goal * self.n_starts + start

    def calc_traj_task_id(self):
        """Goal ids at every position where ``dones`` is set."""
        return self.goal_ids[np.nonzero(self.dones)]

    def goal_to_goal_id(self, goal_pos: np.ndarray):
        # Ignores its argument and returns the stored per-step goal ids.
        return self.goal_ids

    def start_to_start_id(self, start_pos: np.ndarray):
        return 1

    def task_id_to_goal_id(self, task_id: Union[int, np.ndarray]):
        # With a single start group the task id is the goal id.
        return task_id

    def task_id_to_start_id(self, task_id: Union[int, np.ndarray]):
        # NOTE(review): returns 1 although n_starts == 1 would imply start id 0
        # under ``task_id % n_starts`` — confirm intended.
        return 1

    def add_task_id_to_traj_dataset(self, dataset: TrajDataset):
        """Set ``dataset.task_ids`` (int32) from the done-step goal ids and ``max_task_id``."""
        per_traj = self.calc_traj_task_id()
        dataset.task_ids = np.array(per_traj, dtype=np.int32)
        dataset.max_task_id = self.n_task_id

    def goal_id_to_task_id_list(self, goal_id: int):
        return list(range(goal_id * self.n_starts, (goal_id + 1) * self.n_starts))


class ReachGoalTaskIDManager:
    """Task-id manager for reach-goal / reach-color: 4 goals, a single start group.

    Per-step goal ids and done flags are supplied via ``set_goal_id`` and
    ``set_dones``. A trajectory's task id is the goal id recorded at its done
    step.
    """

    def __init__(self, env_id: str):
        self.goal_ids = None
        self.dones = None

    @property
    def n_starts(self):
        return 1

    @property
    def n_goals(self):
        return 4

    @property
    def n_task_id(self):
        return self.n_starts * self.n_goals

    def set_goal_id(self, goal_ids: np.ndarray):
        self.goal_ids = goal_ids

    def set_dones(self, dones: np.ndarray):
        self.dones = dones

    def calc_task_id(self, goal: int, start: int):
        return goal * self.n_starts + start

    def calc_traj_task_id(self):
        """Goal ids at every position where ``dones`` is set."""
        return self.goal_ids[np.nonzero(self.dones)]

    def goal_to_goal_id(self, goal_pos: np.ndarray):
        # Ignores its argument and returns the stored per-step goal ids.
        return self.goal_ids

    def start_to_start_id(self, start_pos: np.ndarray):
        return 1

    def task_id_to_goal_id(self, task_id: Union[int, np.ndarray]):
        # With a single start group the task id is the goal id.
        return task_id

    def task_id_to_start_id(self, task_id: Union[int, np.ndarray]):
        # NOTE(review): returns 1 although n_starts == 1 would imply start id 0
        # under ``task_id % n_starts`` — confirm intended.
        return 1

    def add_task_id_to_traj_dataset(self, dataset: TrajDataset):
        """Set ``dataset.task_ids`` (int32) from the done-step goal ids and ``max_task_id``."""
        per_traj = self.calc_traj_task_id()
        dataset.task_ids = np.array(per_traj, dtype=np.int32)
        dataset.max_task_id = self.n_task_id

    def goal_id_to_task_id_list(self, goal_id: int):
        return list(range(goal_id * self.n_starts, (goal_id + 1) * self.n_starts))


def get_task_id_manager(env_id: str):
    """Build the task-id manager matching ``env_id``.

    Matching is by substring and is checked in this order: 'maze2d', 'ant',
    'Lift', 'Stack', then 'reach-goal'/'reach-color'. The first match wins.

    Raises:
        ValueError: if none of the known substrings occurs in ``env_id``.
    """
    if 'maze2d' in env_id:
        manager_cls = PointMazeTaskIDManager
    elif 'ant' in env_id:
        manager_cls = AntMazeTaskIDManager
    elif 'Lift' in env_id:
        manager_cls = LiftTaskIDManager
    elif 'Stack' in env_id:
        manager_cls = StackTaskIDManager
    elif any(tag in env_id for tag in ('reach-goal', 'reach-color')):
        manager_cls = ReachGoalTaskIDManager
    else:
        raise ValueError(f'Task ID manager for {env_id} is not implemented')
    return manager_cls(env_id=env_id)


def get_goal_candidates(n_goals: int,
                        target_goal: int,
                        align: bool = True,
                        complex_task: bool = False,
                        is_r2r: bool = False,
                        r2r_single_layer=False,
                        n_tasks: int = -1) -> List[int]:
    """Return candidate goal ids relative to ``target_goal``.

    If ``align`` is True, ``target_goal`` is excluded (unless ``complex_task``).
    Otherwise only ``target_goal`` is selected.

    With ``align`` and ``is_r2r``, goals are treated as 9 positions per layer.
    Goals at the target's position (``g % 9``) are always dropped. Then either
    only the target's layer is kept (``r2r_single_layer``, 8 goals expected)
    or the target's layer is dropped (16 goals expected). If ``n_tasks > 0``,
    that many candidates are sampled with Python's ``random``.
    """
    if not align:
        goal_candidates = [target_goal]
    else:
        goal_candidates = list(range(n_goals))
        if not complex_task:
            goal_candidates.remove(target_goal)

    if align and is_r2r:
        target_layer, target_pos = divmod(target_goal, 9)
        # Drop goals at the same in-layer position.
        goal_candidates = [g for g in goal_candidates if g % 9 != target_pos]

        if r2r_single_layer:
            # Keep only the target's layer.
            goal_candidates = [
                g for g in goal_candidates if g // 9 == target_layer
            ]
            assert len(goal_candidates) == 8
        else:
            # Drop the target's layer.
            goal_candidates = [
                g for g in goal_candidates if g // 9 != target_layer
            ]
            assert len(goal_candidates) == 16  # 27 - 9 - 3 + 1

    if n_tasks > 0:
        goal_candidates = random.sample(goal_candidates, k=n_tasks)

    return goal_candidates


def filter_by_goal_id(dataset: TrajDataset, goal_ids: List[int],
                      task_id_manager):
    """Keep only the trajectories whose goal id is in ``goal_ids``.

    Without a ``task_id_manager`` the dataset's ``task_ids`` are compared
    directly. Otherwise they are first mapped through
    ``task_id_manager.task_id_to_goal_id``.
    """
    if task_id_manager is not None:
        ids_to_match = task_id_manager.task_id_to_goal_id(dataset.task_ids)
    else:
        ids_to_match = dataset.task_ids
    keep = np.isin(ids_to_match, goal_ids)
    return dataset[keep]


def select_n_trajectories(dataset: TrajDataset, n_traj: int = -1):
    """Return ``n_traj`` randomly chosen trajectories, or all if ``n_traj <= 0``.

    The selection uses NumPy's global RNG (``np.random.shuffle``).
    """
    if n_traj > 0:
        order = np.arange(len(dataset))
        np.random.shuffle(order)
        return dataset[order[:n_traj]]
    return dataset


def train_val_split(dataset: TrajDataset,
                    train_ratio: float) -> Tuple[TrajDataset, TrajDataset]:
    """Randomly partition trajectories into train and validation subsets.

    The first ``int(len(dataset) * train_ratio)`` entries of a shuffled index
    go to train, the rest to validation. The shuffle uses NumPy's global RNG.
    """
    n_total = len(dataset)
    n_train = int(n_total * train_ratio)
    logger.info(f'Train: {n_train} trajs; Val: {n_total - n_train} trajs.')

    order = np.arange(n_total).astype(np.int32)
    np.random.shuffle(order)
    return dataset[order[:n_train]], dataset[order[n_train:]]


def read_dataset(
    path: Path,
    env_id: str,
    n_additional_tasks: int = 0,
    image_observation: Optional[bool] = False,
    domain_id: Optional[int] = -1,
    args: Optional[DictConfig] = None,
) -> Tuple[TrajDataset, Union[PointMazeTaskIDManager, AntMazeTaskIDManager]]:
    """Load a single-environment HDF5 dataset and split it into trajectories.

    Args:
        path: HDF5 file to read. Accepts ``str`` or ``pathlib.Path``.
        env_id: environment id. Substrings of it ("v2", "ant", "Lift",
            "Stack", "reach-goal", "reach-color") decide which HDF5 keys
            are read and how observations are built.
        n_additional_tasks: added to the task-id manager's ``n_task_id``.
        image_observation: also load images. For "v2" environments a single
            camera is read; otherwise two image keys are stacked on the last
            axis.
        domain_id: for "v2" image loading, selects the camera ("corner3" for
            0, "corner" otherwise); must be >= 0 in that case.
        args: unused; kept for interface compatibility.

    Returns:
        ``(dataset, task_id_manager)`` where the manager comes from
        ``get_task_id_manager(env_id)``.
    """
    import time

    # ``path`` may be a ``pathlib.Path``; ``'ood' in Path(...)`` raises
    # TypeError, so substring checks are done on the string form.
    path_str = str(path)
    # Environments whose files carry a per-step ``infos/goal_id``.
    has_goal_ids = ('Lift' in env_id or 'Stack' in env_id
                    or 'reach-goal' in env_id or 'reach-color' in env_id)
    uses_robot_state = 'Lift' in env_id or 'Stack' in env_id

    with h5py.File(path, "r") as f:
        observations = np.array(f['observations'])
        if "v2" in env_id:
            # Keep only columns 0-3 and 18-21 of the stored observation.
            self_state = observations[:, :4]
            prev_self_state = observations[:, 18:22]
            observations = np.concatenate((self_state, prev_self_state),
                                          axis=-1)
        actions = np.array(f['actions'])
        # Trajectory boundaries: 'terminals' for ant / "ood" datasets,
        # 'timeouts' otherwise.
        if 'ant' in env_id or 'ood' in path_str:
            dones = np.array(f['terminals'])
        else:
            dones = np.array(f['timeouts'])

        if has_goal_ids:
            goal_ids = np.array(f['infos/goal_id'])

        if uses_robot_state:
            # The robot proprioceptive state replaces the stored observation.
            observations = np.array(f['infos/robot0_proprio-state'])

        if image_observation and "v2" in env_id:
            assert domain_id >= 0
            print("Start loading image...")
            start = time.time()
            camera_name = "corner3" if domain_id == 0 else "corner"
            images = np.array(f[f'infos/{camera_name}_image'])
            end = time.time()
            print(f"Loading takes {end - start:.2f} seconds.")
        elif image_observation:
            # Two 128x128 three-channel views stacked into 6 channels.
            images = np.empty((len(observations), 128, 128, 6),
                              dtype=np.float16)
            images[..., :3] = np.array(f['infos/agentview_image'],
                                       dtype=np.float16)
            images[..., 3:] = np.array(f['infos/sideview_image'],
                                       dtype=np.float16)
        else:
            # Zero-width placeholder so images are split like the other
            # per-step arrays.
            images = np.empty((len(observations), 0))

    obs_trajs = _convert_to_traj_array(arr=observations, dones=dones)
    action_trajs = _convert_to_traj_array(arr=actions, dones=dones)
    image_trajs = _convert_to_traj_array(arr=images, dones=dones)

    task_id_manager = get_task_id_manager(env_id=env_id)
    if has_goal_ids:
        task_id_manager.set_goal_id(goal_ids)
        task_id_manager.set_dones(dones)

    data = TrajDataset(
        obs=obs_trajs,
        actions=action_trajs,
        images=image_trajs,
        n_task_id=task_id_manager.n_task_id + n_additional_tasks,
    )

    return data, task_id_manager


def read_multi_dataset(
    args: DictConfig,
    domain_info: DictConfig,
    image_observation: Optional[bool] = False,
    goal_id_offset: int = 0,
) -> Tuple[TrajDataset, None]:
    """Load the HDF5 datasets of every environment in a domain and concatenate them.

    For each tag in ``domain_info.env_tags`` the file
    ``domain_info[tag].dataset`` is read. Goal ids are shifted into a shared
    task-id space, steps are split into trajectories, and all trajectories
    are combined into one ``TrajDataset`` with one task id per trajectory.

    Args:
        args: run config. ``args.n_task_ids`` is the exclusive upper bound of
            kept ids, and ``args.target_goal_id`` is subtracted from the
            target environment's goal ids.
        domain_info: domain config with ``env_tags``, ``target_env``,
            ``domain_id`` and, per tag, a sub-config with ``env``,
            ``dataset`` and ``n_goals``.
        image_observation: load camera images. This only happens for "v2"
            environments; otherwise a zero-width placeholder is used.
        goal_id_offset: starting task id for the first environment.

    Returns:
        ``(dataset, None)``; the second element is always None.
    """
    data = defaultdict(list)
    # Environments whose files are read for a per-step ``infos/goal_id``.
    v2_envs = [
        "reach-goal-v2", "reach-color-v2", "reach-color_simple_3-v2",
        "reach-color_simple_2-v2", "window-close_4-v2"
    ]
    for env_tag in domain_info.env_tags:
        env_info = domain_info[env_tag]
        env_id = env_info.env
        path = env_info.dataset
        with h5py.File(path, "r") as f:
            observations = np.array(f['observations'])
            if "v2" in env_id:
                # Keep only columns 0-3 and 18-21 of the stored observation.
                self_state = observations[:, :4]
                prev_self_state = observations[:, 18:22]
                observations = np.concatenate((self_state, prev_self_state),
                                              axis=-1)
            actions = np.array(f['actions'])
            # NOTE(review): ``goals`` is read but never used in this function.
            goals = np.array(f['infos/goal'])
            # Trajectory boundaries: 'terminals' for ant envs, else 'timeouts'.
            if 'ant' in env_id:
                dones = np.array(f['terminals']).astype(np.bool_)
            else:
                dones = np.array(f['timeouts']).astype(np.bool_)

            if 'Lift' in env_id or 'Stack' in env_id or env_id in v2_envs:
                goal_ids = np.array(f['infos/goal_id'])
            else:
                # No per-step goal ids: the whole env gets a single task id.
                goal_ids = None

            if 'Lift' in env_id or 'Stack' in env_id:
                # The robot proprioceptive state replaces the observation.
                robot_state = np.array(f['infos/robot0_proprio-state'])
                observations = robot_state

            if image_observation and "v2" in env_id:
                assert domain_info.domain_id >= 0
                print("Start loading image...")
                import time
                start = time.time()
                # Camera: "corner3" for domain 0, "corner" otherwise.
                camera_name = "corner3" if domain_info.domain_id == 0 else "corner"
                images = np.array(f[f'infos/{camera_name}_image'])
                end = time.time()
                print(f"Loading takes {end - start:.2f} seconds.")
            else:
                # Zero-width placeholder so images are split like the other
                # per-step arrays.
                images = np.empty((len(observations), 0))

        if goal_ids is not None:
            # Shift goal ids into the shared task-id space. The target env is
            # additionally shifted down by args.target_goal_id.
            # NOTE(review): if infos/goal_id has an unsigned dtype the
            # subtraction can wrap around — confirm the stored dtype.
            goal_ids += goal_id_offset
            if env_tag == domain_info.target_env:
                goal_ids -= args.target_goal_id
            # Keep only steps whose id is in [goal_id_offset, args.n_task_ids).
            mask = ((goal_id_offset <= goal_ids) &
                    (goal_ids < args.n_task_ids))
            goal_ids = goal_ids[mask]
            observations = observations[mask]
            actions = actions[mask]
            images = images[mask]
            dones = dones[mask]
            # NOTE(review): the mask is applied per step. If it can drop only
            # part of a trajectory (e.g. its done step), trajectories get
            # merged or truncated — confirm ids are constant per trajectory.

            # One id per trajectory: take the id at every done step, with the
            # last step forced to count as done so a trailing unterminated
            # trajectory also gets an id. Raises IndexError if no steps remain.
            dones_ = dones.copy()
            dones_[-1] = True
            goal_ids = goal_ids[dones_]
            data["goal_ids"].extend(goal_ids)
            # The next env's ids start after this env's n_goals ids.
            goal_id_offset += env_info.n_goals
        else:  # single task_id
            # Every trajectory of this env gets the same id; one id is used.
            goal_id = goal_id_offset
            dones_ = dones.copy()
            dones_[-1] = True
            data["goal_ids"].extend([goal_id] * dones_.sum())
            goal_id_offset += 1

        # Presumably splits the per-step arrays at ``dones``; assumes this
        # yields exactly one trajectory per id appended above — TODO confirm
        # against _convert_to_traj_array.
        obs_trajs = _convert_to_traj_array(arr=observations, dones=dones)
        action_trajs = _convert_to_traj_array(arr=actions, dones=dones)
        image_trajs = _convert_to_traj_array(arr=images, dones=dones)

        data["obs_trajs"].extend(obs_trajs)
        data["action_trajs"].extend(action_trajs)
        data["image_trajs"].extend(image_trajs)

    # Pack per-trajectory lists into object arrays and ids into a plain array.
    for k, v in data.items():
        if "trajs" in k:
            data[k] = convert_ndarray_list_to_obj_ndarray(v)
        else:
            data[k] = np.array(v)

    # n_task_id is left at the TrajDataset default.
    dataset = TrajDataset(
        obs=data["obs_trajs"],
        actions=data["action_trajs"],
        images=data["image_trajs"],
        task_ids=data["goal_ids"],
    )

    return dataset, None


def remove_single_step_trajectories(dataset: TrajDataset):
    """Drop every trajectory whose observation sequence has at most one step."""
    keep = np.fromiter((len(traj) > 1 for traj in dataset.obs), dtype=bool)
    return dataset[keep]


def _assert_contain_same_str(str1: str, str2: str, keyword: str):
    """Require ``keyword`` to appear in ``str2`` whenever it appears in ``str1``.

    Raises:
        AssertionError: if ``keyword`` is in ``str1`` but not in ``str2``.
    """
    # Raise explicitly instead of using ``assert`` so the check still runs
    # under ``python -O``; the exception type stays AssertionError.
    if keyword in str1 and keyword not in str2:
        raise AssertionError(
            f'{keyword!r} appears in {str1!r} but not in {str2!r}')


def check_dataset_and_env_consistency(dataset_name: str, env_id: str):
    """Check that each listed keyword found in ``env_id`` also occurs in ``dataset_name``."""
    for keyword in ('ant', 'maze2d', 'umaze', 'medium'):
        _assert_contain_same_str(env_id, dataset_name, keyword)


def check_source_and_target_consistency(
    source_name: str,
    target_name: str,
):
    """Check that the maze size keywords in ``source_name`` also occur in ``target_name``."""
    for keyword in ('umaze', 'medium'):
        _assert_contain_same_str(source_name, target_name, keyword)


def read_env_config_yamls(args: DictConfig):
    """Merge per-environment YAML configs into every entry of ``args.domains``.

    For each domain, ``common/config_utils/<env_tag>.yaml`` is loaded. In the
    file name, "-" in ``env_tag`` is replaced by "_", and the path is relative
    to the current working directory. The YAML is merged over the domain's
    own config, so YAML values take precedence.

    When ``args.image_observation`` is set, the merged ``obs_dim`` is
    increased by ``args.image_state_dim``. The maxima of ``obs_dim``,
    ``action_dim`` and ``seq_len`` over all domains are stored as
    ``args.max_obs_dim``, ``args.max_action_dim`` and ``args.max_seq_len``.

    If ``args.multienv`` is set, the whole job is delegated to
    ``read_multi_env_config_yamls``.

    Args:
        args: OmegaConf config. It is modified in place: each
            ``args.domains[i]`` is replaced by its merged config.

    Returns:
        The same ``args`` object.
    """
    if args.multienv:
        return read_multi_env_config_yamls(args)

    max_obs_dim, max_action_dim, max_seq_len = -1, -1, -1
    for idx, domain_info in enumerate(args.domains):
        # Config file names use "_" where env tags use "-".
        env_name = domain_info.env_tag.replace('-', '_')
        env_conf = OmegaConf.load(f'common/config_utils/{env_name}.yaml')
        # Later configs win in OmegaConf.merge: YAML values override.
        merged_conf = OmegaConf.merge(domain_info, env_conf)
        if args.image_observation:
            merged_conf.obs_dim += args.image_state_dim
        args.domains[idx] = merged_conf
        max_obs_dim = max(max_obs_dim, merged_conf.obs_dim)
        max_action_dim = max(max_action_dim, merged_conf.action_dim)
        max_seq_len = max(max_seq_len, merged_conf.seq_len)

        check_dataset_and_env_consistency(dataset_name=merged_conf.dataset,
                                          env_id=merged_conf.env)

    args.max_obs_dim = max_obs_dim
    args.max_action_dim = max_action_dim
    args.max_seq_len = max_seq_len

    # TODO check data consistency between domains

    return args


def read_multi_env_config_yamls(args: DictConfig):
    """Resolve per-environment configs for a multi-environment run.

    Every domain must list the same ``env_tags``. For each tag, the
    single-environment config is resolved on a deep copy of ``args`` via
    ``read_env_config_yamls``. The copy's env, dataset, obs_dim, action_dim,
    seq_len and n_goals are stored under ``domain[env_tag]`` of the matching
    domain.

    Afterwards each domain's ``n_goals`` is the sum over tags of
    ``domains[0].n_goals``. Its ``obs_dim``/``action_dim`` are the maxima over
    tags, and ``dataset``/``env`` are cleared. Task-id offsets are then set
    by ``set_multienv_task_ids``.

    Args:
        args: OmegaConf config with ``multienv`` set; modified in place.

    Returns:
        The same ``args`` object.
    """
    shared_tags = args.domains[0].env_tags
    for other_domain in args.domains[1:]:
        # All domains must cover the same environments.
        assert other_domain.env_tags == shared_tags

    per_env_keys = [
        "env", "dataset", "obs_dim", "action_dim", "seq_len", "n_goals"
    ]
    total_goals = 0
    max_obs_dim = max_action_dim = max_seq_len = -1
    for env_tag in shared_tags:
        # Resolve this environment as if it were a single-env run, on a copy.
        single_args = copy.deepcopy(args)
        single_args.multienv = False
        for single_domain in single_args.domains:
            single_domain.env_tag = env_tag
        read_env_config_yamls(single_args)

        total_goals += single_args.domains[0].n_goals
        max_obs_dim = max(max_obs_dim, single_args.max_obs_dim)
        max_action_dim = max(max_action_dim, single_args.max_action_dim)
        max_seq_len = max(max_seq_len, single_args.max_seq_len)

        for domain, resolved in zip(args.domains, single_args.domains):
            domain[env_tag] = {key: resolved[key] for key in per_env_keys}

    for domain in args.domains:
        domain.n_goals = total_goals
        domain.obs_dim = max_obs_dim
        domain.action_dim = max_action_dim
        domain.dataset = ""
        domain.env = ""

    args.max_obs_dim = max_obs_dim
    args.max_action_dim = max_action_dim
    args.max_seq_len = max_seq_len
    set_multienv_task_ids(args)

    return args


class ObservationConverter:
    """Base interface for mapping observations between simulator and domain.

    Subclasses implement ``__call__`` (simulator -> domain) and ``inv``
    (domain -> simulator). Both raise NotImplementedError here.
    """

    def __call__(self, obs):
        """Map a simulator observation into this domain's observation space."""
        raise NotImplementedError

    def inv(self, obs):
        """Map a domain observation back into the simulator's observation space."""
        raise NotImplementedError


class ActionConverter:
    """Base interface for mapping actions between simulator and domain.

    Subclasses implement ``__call__`` (simulator -> domain) and ``inv``
    (domain -> simulator). Both raise NotImplementedError here.
    """

    def __call__(self, action):
        """Map a simulator action into this domain's action space."""
        raise NotImplementedError

    def inv(self, action):
        """Map a domain action back into the simulator's action space."""
        raise NotImplementedError


class IdentityObsConverter(ObservationConverter):
    """Observation converter that returns its input unchanged in both directions."""

    def __call__(self, obs):
        return obs

    # The identity is its own inverse.
    inv = __call__


class PointSwapObsConverter(ObservationConverter):
    """Swap entries 0<->1 and 2<->3 along the last axis.

    The permutation undoes itself, so ``inv`` applies the same mapping.
    """

    # Gather order for the last axis; applying it twice restores the input.
    _SWAP_ORDER = [1, 0, 3, 2]

    def __call__(self, obs):
        return obs[..., self._SWAP_ORDER]

    inv = __call__


class AffineObsConverter(ObservationConverter):
    """Rotate and scale observations made of two halves along the last axis.

    The last axis is split into two equal halves (``pos`` and ``vel`` in the
    original naming). Each half is right-multiplied by the transposed 2x2
    rotation matrix for ``angle`` degrees and then multiplied by ``scale``.
    ``inv`` applies the inverse rotation and divides by ``scale``.
    """

    def __init__(self, angle, scale):
        self.angle = angle
        self.scale = scale
        # ``inv`` divides by ``scale``, so (near-)zero values are rejected.
        assert abs(self.scale) > 1e-5
        self.rotmat = self.get_rotmat(angle)
        logger.info(f'AffineObsConverter: angle={angle}, scale={scale}')

    @staticmethod
    def get_rotmat(deg: float):
        """Return the float32 matrix ``[[cos, -sin], [sin, cos]]`` for ``deg`` degrees."""
        theta = np.deg2rad(deg)
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        return np.array([[cos_t, -sin_t], [sin_t, cos_t]], dtype=np.float32)

    def __call__(self, obs):
        rot_t = self.rotmat.T
        halves = np.split(obs, 2, axis=-1)
        return np.concatenate(
            [np.dot(half, rot_t) * self.scale for half in halves], axis=-1)

    def inv(self, obs):
        # The forward pass multiplies by R^T; for a rotation matrix the
        # inverse of R^T is R itself. Then the scale is divided out.
        halves = np.split(obs, 2, axis=-1)
        return np.concatenate(
            [np.dot(half, self.rotmat) / self.scale for half in halves],
            axis=-1)


class ReverseObsConverter(ObservationConverter):
    """Reverse the order of the last axis; the mapping is its own inverse.

    A contiguous copy is returned rather than a negative-stride view. The
    result therefore never aliases the caller's array, as with
    ``PointSwapObsConverter``'s fancy indexing, which also copies. It can
    also be passed to consumers such as ``torch.from_numpy`` that reject
    negative strides.
    """

    def __call__(self, obs):
        return obs[..., ::-1].copy()

    def inv(self, obs):
        return obs[..., ::-1].copy()


class IdentityActionConverter(ActionConverter):
    """Action converter that returns its input unchanged in both directions."""

    def __call__(self, action):
        return action

    # The identity is its own inverse.
    inv = __call__


class NegativeActionConverter(ActionConverter):
    """Negate actions; negating twice restores the input, so ``inv`` is the same map."""

    def __call__(self, action):
        return -action

    inv = __call__


def get_obs_converter(name: Optional[str] = None,
                      **kwargs) -> ObservationConverter:
    """Build the observation converter registered under ``name``.

    Args:
        name: None or "identity", "point_swap", "affine", or "reverse".
        **kwargs: constructor arguments. Only "affine" uses them (it expects
            ``angle`` and ``scale``); the other converters ignore them.

    Raises:
        ValueError: if ``name`` is not recognized.
    """
    if name is None or name == 'identity':
        return IdentityObsConverter()

    elif name == 'point_swap':
        return PointSwapObsConverter()

    elif name == 'affine':
        return AffineObsConverter(**kwargs)

    elif name == 'reverse':
        # ReverseObsConverter takes no constructor arguments; forwarding
        # kwargs would raise TypeError whenever any were supplied.
        return ReverseObsConverter()

    else:
        raise ValueError(f'obs converter "{name}" is not implemented.')


def get_action_converter(name: Optional[str] = None,
                         **kwargs) -> ActionConverter:
    """Build the action converter registered under ``name``.

    ``None`` and "identity" select the identity converter and "negative"
    selects the negating one. ``kwargs`` are accepted but not used.

    Raises:
        ValueError: if ``name`` is not recognized.
    """
    if name is None or name == 'identity':
        return IdentityActionConverter()
    if name == 'negative':
        return NegativeActionConverter()
    raise ValueError(f'action converter "{name}" is not implemented.')


def set_multienv_task_ids(args: DictConfig):
    """Lay out task ids for the environments of a multi-environment run.

    Walking ``args.domains[-1].env_tags``, every environment except the last
    gets ``n_goals`` consecutive ids starting at its offset; the last one
    adds a single id. Sets ``args.task_id_offset_list``, ``args.n_task_ids``
    and ``args.target_task_id_offset`` in place.
    """
    assert args.multienv
    ref_domain = args.domains[-1]

    offsets = []
    running = 0
    for env_tag in ref_domain.env_tags[:-1]:
        offsets.append(running)
        running += ref_domain[env_tag].n_goals
    # The final environment contributes exactly one task id.
    running += 1
    # NOTE(review): this last offset is taken *after* the +1, so it equals
    # n_task_ids rather than target_task_id_offset below — confirm that
    # consumers of task_id_offset_list expect this.
    offsets.append(running)

    args.task_id_offset_list = offsets
    args.n_task_ids = running
    args.target_task_id_offset = running - 1
