from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Union

import h5py
import numpy as np
import torch
import torch.utils.data as data

from .trans_fn import get_trans_observations_fns, trans_action_fn
from .utils import task_id_map_m2p


class DatasetWriter(object):
    """Accumulate transitions in Python lists and dump them to an HDF5 file.

    Each field lives in ``self.data`` under one of the keys
    ``observations``, ``actions``, ``next_observations``, ``terminals``,
    ``rewards``, ``infos/goal`` and ``infos/goal_id``. ``len(writer)``
    reports the number of samples added through ``append_data``,
    ``extend_data`` and ``merge``.
    """

    def __init__(self):
        self.data = self._reset_data()
        self._num_samples = 0

    def _reset_data(self):
        """Return a fresh mapping with one empty list per stored field."""
        return {
            "observations": [],
            "actions": [],
            "next_observations": [],
            "terminals": [],
            "rewards": [],
            "infos/goal": [],
            "infos/goal_id": [],
        }

    def __len__(self):
        return self._num_samples

    def append_data(self, s, a, s_, r, done, goal, goal_id):
        """Store a single transition."""
        self._num_samples += 1
        self.data["observations"].append(s)
        self.data["actions"].append(a)
        self.data["next_observations"].append(s_)
        self.data["rewards"].append(r)
        self.data["terminals"].append(done)
        self.data["infos/goal"].append(goal)
        self.data["infos/goal_id"].append(goal_id)

    def extend_data(self, s, a, s_, r, done, goal, goal_id):
        """Store a batch of transitions; the sample count grows by ``len(s)``."""
        self._num_samples += len(s)
        self.data["observations"].extend(s)
        self.data["actions"].extend(a)
        self.data["next_observations"].extend(s_)
        self.data["rewards"].extend(r)
        self.data["terminals"].extend(done)
        self.data["infos/goal"].extend(goal)
        self.data["infos/goal_id"].extend(goal_id)

    def merge(self, writer):
        """Append every sample held by another ``DatasetWriter`` to this one."""
        self._num_samples += len(writer)
        for k in self.data:
            self.data[k].extend(writer.data[k])

    def _build_arrays(self, max_size=None):
        """Convert each stored list to a typed array of at most ``max_size`` rows."""
        arrays = {}
        for key, values in self.data.items():
            if key == "terminals":
                dtype = np.bool_
            elif key == "infos/goal_id":
                dtype = np.int64
            else:
                dtype = np.float32
            array = np.array(values, dtype=dtype)
            if max_size is not None:
                array = array[:max_size]
            arrays[key] = array
        return arrays

    def write_dataset(self, fname, max_size=None, compression="gzip"):
        """Write the accumulated samples to the HDF5 file ``fname``.

        Args:
            fname: Destination path; the file is opened in ``"w"`` mode.
            max_size: If not ``None``, only the first ``max_size`` samples
                of every field are written.
            compression: Passed through to ``create_dataset``.
        """
        # Convert first so a conversion error never leaves a truncated file.
        arrays = self._build_arrays(max_size)

        # The context manager closes the file even if a write fails.
        with h5py.File(fname, "w") as dataset:
            for key, array in arrays.items():
                dataset.create_dataset(key, data=array, compression=compression)


def write_dataset(self, fname, max_size=None, compression="gzip"):
    """Write the lists held in ``self.data`` to the HDF5 file ``fname``.

    NOTE(review): this module-level function mirrors
    ``DatasetWriter.write_dataset`` and still takes ``self`` as its first
    argument; it looks like a duplicated copy of the method -- confirm
    whether anything calls it before removing it.

    Args:
        self: Any object exposing a ``data`` mapping of field name to list.
        fname: Destination path; the file is opened in ``"w"`` mode.
        max_size: If not ``None``, only the first ``max_size`` entries of
            every field are written.
        compression: Passed through to ``create_dataset``.
    """
    arrays = {}
    for key, values in self.data.items():
        if key == "terminals":
            dtype = np.bool_
        elif key == "infos/goal_id":
            dtype = np.int64
        else:
            dtype = np.float32
        # Named ``array`` rather than ``data`` so the module-level
        # ``torch.utils.data as data`` alias is not shadowed.
        array = np.array(values, dtype=dtype)
        if max_size is not None:
            array = array[:max_size]
        arrays[key] = array

    # The context manager closes the file even if a write fails.
    with h5py.File(fname, "w") as dataset:
        for key, array in arrays.items():
            dataset.create_dataset(key, data=array, compression=compression)


def merge_dataset(dataset_paths, output_path):
    """Concatenate several HDF5 datasets into one file at ``output_path``.

    Top-level entries are merged under their own name. Entries inside a
    top-level ``infos`` group are merged under ``"infos/<name>"``. Pieces
    are concatenated in the order of ``dataset_paths`` and written with
    gzip compression.

    Args:
        dataset_paths: Iterable of paths to the input HDF5 files.
        output_path: Destination path; the file is opened in ``"w"`` mode.
    """
    pieces = defaultdict(list)
    for path in dataset_paths:
        # Read each input eagerly and close it straight away, so no file
        # handle outlives its loop iteration.
        with h5py.File(path, mode="r") as file:
            for key in file.keys():
                if key == "infos":
                    for ikey in file["infos"]:
                        name = "infos/" + ikey
                        pieces[name].append(np.array(file[name]))
                else:
                    pieces[key].append(np.array(file[key]))

    # The context manager closes the output even if a write fails.
    with h5py.File(output_path, "w") as new_dataset:
        for name, parts in pieces.items():
            new_dataset.create_dataset(
                name, data=np.concatenate(parts), compression="gzip")


def filter_dataset(
    dataset: Dict[str, np.ndarray],
    task_ids: List,
) -> Dict[str, np.ndarray]:
    """Keep only the rows whose goal id is one of ``task_ids``.

    Args:
        dataset: Mapping of field name to array. ``"infos/goal_id"`` must
            be present; every array is indexed with the same row mask.
        task_ids: Goal ids to keep. An empty list selects no rows.

    Returns:
        A new dict with the same keys, each array restricted to the
        selected rows.
    """
    goal_ids = dataset["infos/goal_id"]

    # Start with nothing selected, then switch on rows for each wanted id.
    keep = np.zeros_like(goal_ids, dtype=np.bool_)
    for wanted in task_ids:
        np.logical_or(keep, goal_ids == wanted, out=keep)

    return {name: values[keep] for name, values in dataset.items()}


def transform_dataset(
    env_id: str,
    dataset_file: h5py.File,
    transform_observations: bool = False,
    transform_actions: bool = False,
) -> Dict[str, np.ndarray]:
    """Load an HDF5 dataset into memory and optionally transform it.

    Args:
        env_id: Environment name. It selects the observation transform
            and, when it contains ``"point"``, enables the overwrite of
            observation columns 2 and 3 described below.
        dataset_file: Open HDF5 file. Entries of its ``infos`` group are
            stored under ``"infos/<name>"``.
        transform_observations: Apply the observation transform to
            ``observations`` and, if present, ``next_observations``.
        transform_actions: Apply ``trans_action_fn`` to ``actions``.

    Returns:
        Dict of field name to numpy array.
    """
    arrays = {}
    for name in dataset_file.keys():
        if name != "infos":
            arrays[name] = np.array(dataset_file[name])
            continue
        # Flatten the nested group into "infos/<name>" keys.
        infos = dataset_file["infos"]
        for info_name in infos.keys():
            arrays[f"infos/{info_name}"] = np.array(infos[info_name])

    # Looked up whether or not it is used, so the env_id is always
    # passed through get_trans_observations_fns.
    trans_observation_fn, _ = get_trans_observations_fns(env_id)

    if transform_observations:
        arrays["observations"] = trans_observation_fn(arrays["observations"])
        if "next_observations" in arrays:
            arrays["next_observations"] = trans_observation_fn(
                arrays["next_observations"])

    if transform_actions:
        arrays["actions"] = trans_action_fn(arrays["actions"])

    if "point" in env_id:
        # Overwrite columns 2 and 3 of every observation after the first
        # with the previous step's action, scaled by 2 and by a random
        # factor drawn from [2, 2.5).
        # NOTE(review): presumably these columns are velocities -- confirm.
        previous_actions = arrays["actions"][:-1]
        random_factor = np.random.rand(*previous_actions.shape) * 0.5 + 2
        arrays["observations"][1:, [2, 3]] = (
            previous_actions * 2 * random_factor)
    return arrays


def get_dataset(
    dataset_path: Union[str, Path],
    task_ids: List[int],
    transform_observations: bool = False,
    transform_actions: bool = False,
) -> Dict[str, np.ndarray]:
    """Load a dataset from disk, transform it and keep the requested tasks.

    The environment id handed to ``transform_dataset`` is the file name
    without its extension. When the path contains ``"maze2d"``, goal ids
    are remapped through ``task_id_map_m2p`` for each of the ``"medium"``
    and ``"umaze"`` layouts named in the path.

    Args:
        dataset_path: Path to the HDF5 dataset, as ``str`` or ``Path``.
        task_ids: Goal ids to keep after any remapping.
        transform_observations: Forwarded to ``transform_dataset``.
        transform_actions: Forwarded to ``transform_dataset``.

    Returns:
        Dict of field name to numpy array holding only the selected rows.
    """
    # Substring tests need a str; ``in`` on a Path raises TypeError.
    path_str = str(dataset_path)
    env_id = Path(dataset_path).stem

    # transform_dataset copies every entry into numpy arrays, so the file
    # can be closed as soon as it returns.
    with h5py.File(dataset_path, "r") as dataset_file:
        dataset = transform_dataset(
            env_id=env_id,
            dataset_file=dataset_file,
            transform_observations=transform_observations,
            transform_actions=transform_actions,
        )

    if "maze2d" in path_str:
        for layout in ("medium", "umaze"):
            if layout in path_str:
                id_map = task_id_map_m2p[layout]
                dataset["infos/goal_id"] = np.array(
                    [id_map[int(t)] for t in dataset["infos/goal_id"]])

    return filter_dataset(
        dataset=dataset,
        task_ids=task_ids,
    )


def _pad_columns(array: np.ndarray, width: int) -> np.ndarray:
    """Right-pad a 2-D array with zero columns until it has ``width`` columns."""
    missing = width - array.shape[1]
    if missing <= 0:
        return array
    return np.hstack((array, np.zeros((array.shape[0], missing))))


def process_alignment_dataset(
    source_dataset: Dict[str, np.ndarray],
    target_dataset: Dict[str, np.ndarray],
    num_task_ids: int,
    max_size: Optional[int] = None,
    use_domain_id: bool = True,
) -> Dict[str, torch.Tensor]:
    """Build one shuffled tensor dataset from a source and a target domain.

    Rows of each domain are shuffled with ``np.random`` and stacked,
    source first. If the domains differ in observation or action width,
    the narrower one is zero-padded on the right.

    Args:
        source_dataset: Source arrays; reads ``observations``,
            ``next_observations``, ``actions`` and ``infos/goal_id``.
        target_dataset: Target arrays with the same keys.
        num_task_ids: Length of the one-hot task vector. Goal ids are
            shifted by -1 before being used as the one-hot index.
        max_size: If truthy, at most ``max_size // 2`` rows are taken
            from each domain.
        use_domain_id: If true, source rows get ``[1, 0]`` and target
            rows ``[0, 1]`` as domain id; otherwise all rows get zeros.

    Returns:
        Dict with tensors ``observations``, ``next_observations``,
        ``actions``, ``action_masks`` (1 for real action columns, 0 for
        padding), ``task_ids`` and ``domain_ids``.
    """
    source_observations = source_dataset["observations"]
    target_observations = target_dataset["observations"]
    source_next_observations = np.array(source_dataset["next_observations"])
    target_next_observations = np.array(target_dataset["next_observations"])

    # One permutation per domain, reused for every field so rows stay aligned.
    source_random_idx = np.arange(len(source_observations))
    target_random_idx = np.arange(len(target_observations))
    np.random.shuffle(source_random_idx)
    np.random.shuffle(target_random_idx)
    if max_size:
        source_random_idx = source_random_idx[:max_size // 2]
        target_random_idx = target_random_idx[:max_size // 2]

    def stack(source_rows: np.ndarray, target_rows: np.ndarray) -> torch.Tensor:
        """Select the sampled rows of both domains and stack them, source first."""
        # np.vstack already returns a fresh array, so no extra copy is needed.
        return torch.Tensor(
            np.vstack((source_rows[source_random_idx],
                       target_rows[target_random_idx])))

    alignment_dataset = {}

    # Process observations ----------
    observation_size = max(source_observations.shape[1],
                           target_observations.shape[1])
    alignment_dataset["observations"] = stack(
        _pad_columns(source_observations, observation_size),
        _pad_columns(target_observations, observation_size))
    alignment_dataset["next_observations"] = stack(
        _pad_columns(source_next_observations, observation_size),
        _pad_columns(target_next_observations, observation_size))

    # Process actions ----------
    source_actions = source_dataset["actions"]
    target_actions = target_dataset["actions"]
    action_size = max(source_actions.shape[1], target_actions.shape[1])
    alignment_dataset["actions"] = stack(
        _pad_columns(source_actions, action_size),
        _pad_columns(target_actions, action_size))
    # Masks are built at the original width, so padded columns read 0.
    alignment_dataset["action_masks"] = stack(
        _pad_columns(np.ones_like(source_actions), action_size),
        _pad_columns(np.ones_like(target_actions), action_size))

    # Process task ids ----------
    one_hot = np.eye(num_task_ids)
    alignment_dataset["task_ids"] = stack(
        one_hot[source_dataset["infos/goal_id"] - 1],
        one_hot[target_dataset["infos/goal_id"] - 1])

    # Process domain ids ----------
    if use_domain_id:
        source_code, target_code = np.eye(2)
    else:
        source_code = target_code = np.zeros(2)
    alignment_dataset["domain_ids"] = stack(
        np.repeat(source_code[None, :], source_observations.shape[0], axis=0),
        np.repeat(target_code[None, :], target_observations.shape[0], axis=0))

    print("Source dataset length:", len(source_random_idx))
    print("Target dataset length:", len(target_random_idx))
    print("Total dataset length:", len(alignment_dataset["observations"]))

    return alignment_dataset


def split_dataset(dataset: data.TensorDataset, train_ratio: float):
    """Randomly split ``dataset`` into a training and a validation subset.

    Args:
        dataset: Dataset to split.
        train_ratio: Fraction of samples for training. The training size
            is ``int(len(dataset) * train_ratio)``; validation gets the
            remainder.

    Returns:
        Tuple ``(train_dataset, val_dataset)``.
    """
    n_total = len(dataset)
    n_train = int(n_total * train_ratio)
    lengths = [n_train, n_total - n_train]
    return tuple(data.random_split(dataset, lengths))
