import bisect
import pickle
import traceback
from pathlib import Path
from typing import Dict, List, Optional, Union

import numpy as np
import torch
import torch.utils.data as data
from numpy.lib.stride_tricks import sliding_window_view
from tqdm import tqdm

from .dataset_utils import filter_dataset


class TrajectoryDataset(data.Dataset):
    """Single-domain transition dataset held fully in memory as tensors.

    Reads the ``"observations"``, ``"next_observations"``, ``"actions"`` and
    ``"infos/goal_id"`` entries of ``dataset``, shuffles the rows once at
    construction time and optionally keeps only the first ``max_size`` rows
    of that shuffle.

    Args:
        dataset: Mapping of array name to array. Observations are unpacked
            as a 2-D ``(n, s)`` array.
        num_task_ids: Width of the one-hot task encoding. Goal ids are
            shifted by ``-1`` before one-hot lookup.
        max_size: If truthy, number of shuffled rows to keep.
        state_dim: If truthy, observations are right-padded with zeros up to
            this width; the raw width must not exceed it.
        domain_id: Row of the ``domain_dim`` identity matrix used as the
            domain one-hot vector.
        domain_dim: Width of the domain one-hot vector. When either
            ``domain_id`` or ``domain_dim`` is ``None`` an all-zero vector of
            width 2 is used instead.
    """

    def __init__(
        self,
        dataset: Dict[str, np.ndarray],
        num_task_ids: int,
        max_size: Optional[int] = None,
        state_dim: int = None,
        domain_id: int = None,
        domain_dim: int = None,
    ):
        super().__init__()

        # Observations (optionally zero-padded to a common width) ----------
        observations = dataset["observations"]
        next_observations = dataset["next_observations"]
        num_samples, obs_dim = observations.shape

        if state_dim:
            assert obs_dim <= state_dim
            if obs_dim < state_dim:
                zeros = np.zeros((num_samples, state_dim - obs_dim))
                observations = np.hstack((observations, zeros))
                next_observations = np.hstack((next_observations, zeros))

        # Actions ----------
        actions = dataset["actions"]

        # Domain one-hot, identical for every row ----------
        if domain_dim is None or domain_id is None:
            domain_onehot = np.zeros((num_samples, 2))
        else:
            domain_row = np.eye(domain_dim)[domain_id]
            domain_onehot = np.tile(domain_row, (num_samples, 1))

        # Task one-hot (goal ids are shifted by -1 before the lookup) ----------
        goal_ids = dataset["infos/goal_id"]
        task_onehot = np.eye(num_task_ids)[goal_ids - 1]

        # One shared random permutation keeps all arrays aligned.
        order = np.arange(len(observations))
        np.random.shuffle(order)
        if max_size:
            order = order[:max_size]

        def select(array):
            return torch.Tensor(array[order].copy())

        self.obs_array = select(observations)
        self.act_array = select(actions)
        self.next_obs_array = select(next_observations)
        self.task_ids_onehot = select(task_onehot)
        self.domain_ids_onehot = select(domain_onehot)

        print("Dataset length:", len(self.obs_array))

    def __len__(self):
        """Return the number of stored transitions."""
        return len(self.obs_array)

    def __getitem__(self, index):
        """Return ``(obs, task_id, domain_id, act, next_obs, act_mask)``.

        ``act_mask`` is an all-ones tensor shaped like the action.
        """
        act = self.act_array[index]
        return (
            self.obs_array[index],
            self.task_ids_onehot[index],
            self.domain_ids_onehot[index],
            act,
            self.next_obs_array[index],
            torch.ones_like(act),
        )


class AlignmentDataset(data.Dataset):
    """Joint source/target transition dataset for cross-domain alignment.

    Both input datasets are first passed through ``filter_dataset`` with
    ``task_ids``. Each domain is shuffled independently, optionally truncated
    to ``max_size // 2`` rows, and the two domains are stacked (source rows
    first, then target rows). Observations and actions of the narrower domain
    are right-padded with zeros so both domains share one width; the action
    mask is 1 on real action dimensions and 0 on padded ones.

    Attributes set by the constructor (all ``torch.Tensor``, row-aligned):
        observations, next_observations, trajectories, actions,
        actions_mask, task_ids, domain_ids.

    Args:
        source_dataset: Source-domain arrays keyed by ``"observations"``,
            ``"next_observations"``, ``"actions"`` and ``"infos/goal_id"``.
        target_dataset: Target-domain arrays with the same keys.
        task_ids: Forwarded to ``filter_dataset`` for both datasets.
        num_task_ids: Width of the one-hot task encoding. Goal ids are
            shifted by ``-1`` before one-hot lookup.
        max_size: If truthy, each domain keeps at most ``max_size // 2``
            shuffled rows.
        use_domain_id: If true, source rows get ``[1, 0]`` and target rows
            ``[0, 1]``; otherwise every row gets ``[0, 0]``.
        traj_len: Length of the observation window stored in
            ``self.trajectories``. With ``traj_len == 1`` the trajectories
            are the observations themselves.
    """

    def __init__(
        self,
        source_dataset: Dict[str, np.ndarray],
        target_dataset: Dict[str, np.ndarray],
        task_ids: List[int],
        num_task_ids: int,
        max_size: Optional[int] = None,
        use_domain_id: bool = True,
        traj_len: int = 1,
    ):
        super().__init__()
        source_dataset = filter_dataset(source_dataset, task_ids)
        target_dataset = filter_dataset(target_dataset, task_ids)

        # Process observations ----------
        source_observations = np.asarray(source_dataset["observations"])
        target_observations = np.asarray(target_dataset["observations"])
        source_next_observations = np.asarray(
            source_dataset["next_observations"])
        target_next_observations = np.asarray(
            target_dataset["next_observations"])

        # One permutation per domain, reused for every array below so that
        # all tensors stay row-aligned.
        source_random_idx = np.arange(len(source_observations))
        target_random_idx = np.arange(len(target_observations))
        np.random.shuffle(source_random_idx)
        np.random.shuffle(target_random_idx)
        if max_size:
            source_random_idx = source_random_idx[:max_size // 2]
            target_random_idx = target_random_idx[:max_size // 2]

        observation_size = max(source_observations.shape[1],
                               target_observations.shape[1])
        source_observations = self._pad_columns(source_observations,
                                                observation_size)
        source_next_observations = self._pad_columns(
            source_next_observations, observation_size)
        target_observations = self._pad_columns(target_observations,
                                                observation_size)
        target_next_observations = self._pad_columns(
            target_next_observations, observation_size)

        self.observations = self._stack(source_observations,
                                        target_observations,
                                        source_random_idx, target_random_idx)
        self.next_observations = self._stack(source_next_observations,
                                             target_next_observations,
                                             source_random_idx,
                                             target_random_idx)

        # Process trajectory ----------
        if traj_len == 1:
            self.trajectories = self.observations
        else:
            # Windows are computed on the original (time-ordered) rows and
            # only then selected with the same shuffled indices as every
            # other tensor, so trajectories[i] ends at observations[i].
            # NOTE(review): windows are taken over the whole array and do not
            # stop at episode boundaries - confirm this is intended.
            self.trajectories = self._stack(
                self._sliding_windows(source_observations, traj_len),
                self._sliding_windows(target_observations, traj_len),
                source_random_idx, target_random_idx)

        # Process actions ----------
        source_actions = np.asarray(source_dataset["actions"])
        target_actions = np.asarray(target_dataset["actions"])
        # Masks start as all-ones on the real action dimensions; the zero
        # padding added below marks the dimensions that do not exist.
        source_actions_mask = np.ones_like(source_actions)
        target_actions_mask = np.ones_like(target_actions)

        action_size = max(source_actions.shape[1], target_actions.shape[1])
        source_actions = self._pad_columns(source_actions, action_size)
        source_actions_mask = self._pad_columns(source_actions_mask,
                                                action_size)
        target_actions = self._pad_columns(target_actions, action_size)
        target_actions_mask = self._pad_columns(target_actions_mask,
                                                action_size)

        self.actions = self._stack(source_actions, target_actions,
                                   source_random_idx, target_random_idx)
        self.actions_mask = self._stack(source_actions_mask,
                                        target_actions_mask,
                                        source_random_idx, target_random_idx)

        # Process task ids ----------
        task_eye = np.eye(num_task_ids)
        source_task_ids = task_eye[source_dataset["infos/goal_id"] - 1]
        target_task_ids = task_eye[target_dataset["infos/goal_id"] - 1]
        self.task_ids = self._stack(source_task_ids, target_task_ids,
                                    source_random_idx, target_random_idx)

        # Process domain ids ----------
        num_source = source_observations.shape[0]
        num_target = target_observations.shape[0]
        if use_domain_id:
            source_domain_ids = np.tile(np.eye(2)[0], (num_source, 1))
            target_domain_ids = np.tile(np.eye(2)[1], (num_target, 1))
        else:
            source_domain_ids = np.zeros((num_source, 2))
            target_domain_ids = np.zeros((num_target, 2))
        self.domain_ids = self._stack(source_domain_ids, target_domain_ids,
                                      source_random_idx, target_random_idx)

        print("Source dataset length:", len(source_random_idx))
        print("Target dataset length:", len(target_random_idx))
        print("Total dataset length:", len(self.observations))

    @staticmethod
    def _pad_columns(array, width):
        """Right-pad a 2-D ``array`` with zero columns up to ``width``."""
        missing = width - array.shape[1]
        if missing <= 0:
            return array
        return np.hstack((array, np.zeros((array.shape[0], missing))))

    @staticmethod
    def _stack(source, target, source_idx, target_idx):
        """Select rows of each domain and stack them, source rows first."""
        return torch.Tensor(np.vstack((source[source_idx],
                                       target[target_idx])))

    @staticmethod
    def _sliding_windows(observations, traj_len):
        """Return the ``(n, traj_len, dim)`` window ending at each row.

        The first row is repeated ``traj_len - 1`` times in front so that
        every row, including the earliest ones, has a full-length window.
        """
        padded = np.vstack(
            (observations[0][None].repeat(traj_len - 1, axis=0),
             observations))
        windows = sliding_window_view(padded, (traj_len, padded.shape[1]))
        # The window spans the full feature axis, which leaves a singleton
        # axis at position 1; drop exactly that axis (a bare squeeze would
        # also drop a size-1 feature or batch axis).
        return windows[:, 0]

    def __len__(self):
        """Return the total number of stacked source and target rows."""
        return len(self.observations)

    def __getitem__(self, index):
        """Return ``(obs, act, next_obs, task_id, domain_id, act_mask)``."""
        obs = self.observations[index]
        act = self.actions[index]
        next_obs = self.next_observations[index]
        task_id = self.task_ids[index]
        domain_id = self.domain_ids[index]
        act_mask = self.actions_mask[index]
        return obs, act, next_obs, task_id, domain_id, act_mask


class TripletDataset(data.Dataset):
    """Dataset of (anchor, positive, negative) observation triplets.

    Observations are zero-padded to ``state_dim`` columns and concatenated
    with a one-hot domain vector. Episode starts are derived from
    ``dataset["terminals"]``: row 0 starts an episode, and so does every row
    that follows a terminal row. For a randomly drawn anchor, the positive is
    another frame of the same episode at most ``positive_margin`` steps away,
    and the negative is a frame of the same episode more than
    ``negative_margin`` steps away.

    Args:
        dataset: Arrays keyed by ``"observations"`` (2-D) and
            ``"terminals"``.
        dataset_size: Number of triplets to build (or to expect at least
            from a cached file).
        positive_margin: Maximum temporal distance of a positive sample.
        negative_margin: Temporal distance a negative sample must exceed.
        domain_id: Row of the ``domain_dim`` identity matrix appended to
            every observation.
        domain_dim: Width of the domain one-hot vector.
        state_dim: Width observations are zero-padded to; the raw width
            must not exceed it.
        dataset_path: Optional path of a file written by :meth:`save`. If
            loading fails or the file holds fewer than ``dataset_size``
            triplets, the triplets are rebuilt from scratch.
    """

    def __init__(
        self,
        dataset: Dict[str, np.ndarray],
        dataset_size: int,
        positive_margin: int,
        negative_margin: int,
        domain_id: int,
        domain_dim: int,
        state_dim: int,
        dataset_path: Optional[Union[str, Path]] = None,
    ):
        super().__init__()

        observations = dataset["observations"]
        n, s = observations.shape
        assert s <= state_dim, (
            f"observation width {s} exceeds state_dim {state_dim}")
        if s < state_dim:
            pad = np.zeros((n, state_dim - s))
            observations = np.hstack((observations, pad))

        domain_ids = np.eye(domain_dim)[domain_id][None, :]
        domain_id_onehot = np.tile(domain_ids, (n, 1))

        self.positive_margin = positive_margin
        self.negative_margin = negative_margin
        self.observations = torch.Tensor(
            np.concatenate([observations, domain_id_onehot], axis=-1))

        self.dataset_size = dataset_size

        # An episode starts at row 0 and right after every terminal row.
        terminals = np.asarray(dataset["terminals"], dtype=bool).reshape(-1)
        starts = np.zeros(len(terminals), dtype=bool)
        starts[0] = True
        starts[1:] = terminals[:-1]
        self.start_indices = np.flatnonzero(starts)

        loaded = False
        if dataset_path:
            # Best effort: any failure to use the cache falls back to
            # rebuilding the triplets below.
            try:
                self.load(dataset_path)
                # Explicit check rather than assert, so it is not stripped
                # when Python runs with -O.
                if self.dataset_size > len(self.anchor_observations):
                    raise ValueError(
                        f"cached dataset holds "
                        f"{len(self.anchor_observations)} triplets, "
                        f"{self.dataset_size} requested")
                loaded = True
                print("Loaded", dataset_path)
            except Exception:
                traceback.print_exc()
                print("Unable to load", dataset_path)

        if not loaded:
            self.build_dataset(dataset_size)

    def __len__(self):
        """Return the number of triplets exposed by the dataset."""
        return self.dataset_size

    def __getitem__(self, index):
        """Return ``(anchor_obs, positive_obs, negative_obs)`` tensors."""
        anchor_obs = self.anchor_observations[index]
        positive_obs = self.positive_observations[index]
        negative_obs = self.negative_observations[index]
        return anchor_obs, positive_obs, negative_obs

    def build_dataset(self, dataset_size):
        """Sample ``dataset_size`` triplets and store them on the instance.

        Anchors for which no positive or no negative exists are rejected and
        re-drawn, so the result normally holds exactly ``dataset_size``
        triplets. The number of draws is capped; if the cap is reached the
        dataset is shortened and ``self.dataset_size`` is updated so that
        ``__len__`` never exceeds the stored triplets.

        Raises:
            RuntimeError: If not a single valid triplet could be sampled.
        """
        print("Building triplet dataset...")
        anchor_observations = []
        positive_observations = []
        negative_observations = []

        num_observations = len(self.observations)
        # Cap on draws so that data without any valid triplet cannot make
        # the rejection loop spin forever.
        max_attempts = max(1000, 100 * dataset_size)
        attempts = 0
        with tqdm(total=dataset_size,
                  bar_format="{l_bar}{bar:50}{r_bar}") as progress:
            while (len(anchor_observations) < dataset_size
                   and attempts < max_attempts):
                attempts += 1
                index = np.random.randint(num_observations)
                try:
                    positive_idx = self.sample_positive_index(index)
                    negative_idx = self.sample_negative_index(index)
                except ValueError:
                    # No valid partner for this anchor; draw another one.
                    continue

                anchor_observations.append(self.observations[index])
                positive_observations.append(self.observations[positive_idx])
                negative_observations.append(self.observations[negative_idx])
                progress.update(1)

        if not anchor_observations:
            raise RuntimeError(
                "Could not sample any valid triplet "
                f"(dataset_size={dataset_size}, "
                f"positive_margin={self.positive_margin}, "
                f"negative_margin={self.negative_margin})")
        if len(anchor_observations) < dataset_size:
            print("Only built", len(anchor_observations), "of", dataset_size,
                  "requested triplets")

        self.anchor_observations = torch.stack(anchor_observations)
        self.positive_observations = torch.stack(positive_observations)
        self.negative_observations = torch.stack(negative_observations)
        self.dataset_size = len(anchor_observations)
        print("Finished!")

    def _search(self, index):
        """Return the position in ``start_indices`` of ``index``'s episode.

        That is the last episode start that is ``<= index``.
        """
        return bisect.bisect_right(self.start_indices, index) - 1

    def _episode_bounds(self, index):
        """Return ``(start, end)`` of the episode containing ``index``.

        ``end`` is exclusive. The last episode has no following start index,
        so it ends at the total number of observations.
        """
        episode = self._search(index)
        start = int(self.start_indices[episode])
        if episode + 1 < len(self.start_indices):
            end = int(self.start_indices[episode + 1])
        else:
            end = len(self.observations)
        return start, end

    def sample_positive_index(self, index):
        """Sample a frame within ``positive_margin`` steps of ``index``.

        Candidates lie in the same episode, on either side of ``index``
        (``index`` itself excluded), and are drawn uniformly.

        Raises:
            ValueError: If no candidate exists.
        """
        start, end = self._episode_bounds(index)

        lower_range = np.arange(max(start, index - self.positive_margin),
                                index)
        # "+ 1" makes the window symmetric: up to positive_margin steps on
        # both sides of the anchor.
        upper_range = np.arange(
            index + 1, min(index + self.positive_margin + 1, end))

        range_ = np.concatenate([lower_range, upper_range])
        if len(range_) == 0:
            raise ValueError(f"no positive sample available for {index}")
        return np.random.choice(range_)

    def sample_negative_index(self, index):
        """Sample a frame more than ``negative_margin`` steps from ``index``.

        Candidates lie in the same episode, before or after ``index``, and
        are drawn uniformly over both sides combined.

        Raises:
            ValueError: If no candidate exists.
        """
        start, end = self._episode_bounds(index)

        # Valid negatives: [start, lower_stop) and [upper_start, end).
        lower_stop = max(start, index - self.negative_margin)
        upper_start = min(index + self.negative_margin + 1, end)
        lower_len = lower_stop - start
        upper_len = end - upper_start

        total = lower_len + upper_len
        if total <= 0:
            raise ValueError(f"no negative sample available for {index}")

        # A single draw over the concatenated ranges weights each side by
        # its number of candidates.
        offset = np.random.randint(total)
        if offset < lower_len:
            return start + offset
        return upper_start + (offset - lower_len)

    def save(self, path):
        """Pickle the anchor/positive/negative tensors to ``path``."""
        data = {
            "anchor_observations": self.anchor_observations,
            "positive_observations": self.positive_observations,
            "negative_observations": self.negative_observations,
        }
        with open(path, "wb") as f:
            pickle.dump(data, f)

    def load(self, path):
        """Load tensors previously written by :meth:`save` from ``path``."""
        # SECURITY: pickle can execute arbitrary code while loading; only
        # point this at files produced by save() from a trusted location.
        with open(path, "rb") as f:
            data = pickle.load(f)
        self.anchor_observations = data["anchor_observations"]
        self.positive_observations = data["positive_observations"]
        self.negative_observations = data["negative_observations"]
