import logging
import os
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import comet_ml
import gym
import h5py
import imageio.v2
import numpy as np
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split

# Module logger. Its level is pinned to INFO so the progress messages logged
# below are not filtered out at this logger (no handler is attached here).
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)


def read_dataset(
    path: str,
    source_trans_fn: Callable = lambda x: x,
    source_action_trans_fn: Callable = lambda x: x,
):
    """Load an HDF5 dataset and derive task IDs plus source/target views.

    The ``observations``, ``actions``, ``infos/goal`` and ``timeouts`` arrays
    are read from ``path``. Goals are rounded to integers and every distinct
    goal row is assigned a task ID, counting from 1 in ``np.unique`` order.
    The source-domain arrays are whatever ``source_trans_fn`` /
    ``source_action_trans_fn`` return for the raw data; the target-domain
    arrays are the raw data. All four are cast to float32.

    Returns:
        ``(task_ids, task_ids_onehot, source_obs, source_actions, target_obs,
        target_actions, dones, goal_to_task_id, goals)`` where ``dones`` is
        the raw ``timeouts`` array, ``goals`` are the rounded integer goals and
        ``task_ids_onehot`` has ``max(task_ids) + 1`` columns (column 0 unused
        by real tasks).
    """
    with h5py.File(path, 'r') as h5_file:
        raw_obs = np.array(h5_file['observations'])
        raw_actions = np.array(h5_file['actions'])
        raw_goals = np.array(h5_file['infos/goal'])
        timeouts = np.array(h5_file['timeouts'])

    int_goals = raw_goals.round().astype(int)
    unique_goals = np.unique(int_goals, axis=0)
    goal_to_task_id = {
        tuple(goal_row): task_id
        for task_id, goal_row in enumerate(unique_goals, start=1)
    }
    logger.info(f'Goal to task ID: {goal_to_task_id}')
    task_ids = np.array([goal_to_task_id[tuple(row)] for row in int_goals])
    n_columns = task_ids.max() + 1
    task_ids_onehot = np.eye(n_columns, dtype=np.float32)[task_ids]

    # Source domain = caller-supplied transform of the raw data.
    source_obs = source_trans_fn(raw_obs).astype(np.float32)
    source_actions = source_action_trans_fn(raw_actions).astype(np.float32)
    # Target domain = raw data.
    target_obs = raw_obs.astype(np.float32)
    target_actions = raw_actions.astype(np.float32)

    return (task_ids, task_ids_onehot, source_obs, source_actions, target_obs,
            target_actions, timeouts, goal_to_task_id, int_goals)


def split_dataset(dataset: torch.utils.data.TensorDataset, train_ratio: float,
                  batch_size: int):
    """Randomly split ``dataset`` into train/validation DataLoaders.

    The train part holds ``int(len(dataset) * train_ratio)`` samples and is
    shuffled each epoch; the validation part holds the remainder and is not
    shuffled.

    Returns:
        ``(train_loader, val_loader)``.
    """
    n_total = len(dataset)
    n_train = int(n_total * train_ratio)
    train_subset, val_subset = random_split(dataset,
                                            [n_train, n_total - n_train])
    return (
        DataLoader(train_subset, batch_size=batch_size, shuffle=True),
        DataLoader(val_subset, batch_size=batch_size),
    )


def calc_accuracy(pred_logit, gt_logit):
    """Fraction of rows whose arg-max class in ``pred_logit`` matches ``gt_logit``.

    Both tensors are compared along dimension 1 (e.g. logits vs. one-hot
    labels); the result is a Python float.
    """
    with torch.no_grad():
        predicted = pred_logit.max(dim=1).indices
        expected = gt_logit.max(dim=1).indices
        n_correct = torch.eq(expected, predicted).sum().item()
        n_total = len(expected)

    return n_correct / n_total


def calc_alignment_score(encoder: torch.nn.Module, states, domain_ids):
    """Mean squared distance between encodings of original and swapped inputs.

    The swapped input exchanges state features 0<->1 and 2<->3 and the two
    domain-ID columns. Each input is ``cat(states, domain_ids)`` along the
    last dimension; the encoder runs without gradient tracking. Lower values
    mean the encoder maps the two views to closer embeddings.
    """
    swapped_states = states[..., [1, 0, 3, 2]]
    swapped_domains = domain_ids[..., [1, 0]]

    with torch.no_grad():
        z = encoder(torch.cat((states, domain_ids), dim=-1))
        z_swapped = encoder(torch.cat((swapped_states, swapped_domains),
                                      dim=-1))

    return torch.nn.functional.mse_loss(z, z_swapped).item()


def _remove_single_step_trajectories(dones: np.ndarray, *arrays: np.ndarray):
    done_idx = np.concatenate(([-1], np.where(dones)[0]))
    traj_lens = done_idx[1:] - done_idx[:-1]
    single_step_traj_ids = np.where(traj_lens == 1)[0]

    valid_steps = np.ones(len(dones), dtype=bool)
    for single_step_traj_id in single_step_traj_ids:
        invalid_step_idx = done_idx[single_step_traj_id + 1]
        valid_steps[invalid_step_idx] = False

    new_arrays = []
    for array in arrays:
        new_arrays.append(array[valid_steps])
    dones = dones[valid_steps]

    return (dones, *new_arrays)


def _select_n_trajecotries(source_obs: np.ndarray, target_obs: np.ndarray,
                           task_ids_onehot: np.ndarray,
                           source_actions: np.ndarray,
                           target_actions: np.ndarray, dones: np.ndarray,
                           n: int):

    done_idx = np.concatenate(([-1], np.where(dones)[0]))
    start_bit = np.concatenate(([False], dones))[:-1]
    traj_ids = np.cumsum(start_bit)

    selected_traj_ids = np.random.choice(range(traj_ids.max() + 1),
                                         n,
                                         replace=False)

    selected_steps = np.zeros(len(traj_ids))
    for selected_id in selected_traj_ids:
        st_idx, gl_idx = done_idx[selected_id] + 1, done_idx[selected_id + 1]
        selected_steps[st_idx] += 1
        if gl_idx + 1 < len(selected_steps):
            selected_steps[gl_idx + 1] -= 1

    selected_steps = np.cumsum(selected_steps).astype(bool)

    source_obs = source_obs[selected_steps]
    target_obs = target_obs[selected_steps]
    task_ids_onehot = task_ids_onehot[selected_steps]
    source_actions = source_actions[selected_steps]
    target_actions = target_actions[selected_steps]
    dones = dones[selected_steps]

    return source_obs, target_obs, task_ids_onehot, source_actions, target_actions, dones


def _select_trajectory_mask(dones: np.ndarray, n: int) -> np.ndarray:
    """Return a per-step boolean mask covering ``n`` random complete trajectories.

    A complete trajectory is a run of steps ending at a set ``dones`` flag.
    Steps after the last set flag are never selected, so the mask covers
    exactly ``n`` set flags.

    Raises:
        ValueError: (from ``np.random.choice``) if fewer than ``n`` complete
            trajectories exist.
    """
    dones = np.asarray(dones, dtype=bool)
    # Number of done flags strictly before each step = its trajectory index.
    traj_ids = np.cumsum(dones) - dones
    chosen = np.random.choice(int(dones.sum()), n, replace=False)
    return np.isin(traj_ids, chosen)


def filter_dataset(
    source_obs: np.ndarray,
    target_obs: np.ndarray,
    task_ids_onehot: np.ndarray,
    source_actions: np.ndarray,
    target_actions: np.ndarray,
    dones: np.ndarray,
    goals: np.ndarray,
    task_ids: np.ndarray,
    filter_by_id_fn: Callable,
    n_traj: Optional[int] = None,
):
    """Filter steps by task ID, attach next observations and clean trajectories.

    Steps are kept where ``filter_by_id_fn(task_ids)`` is True. For each kept
    step ``i`` the next observation is raw row ``i + 1`` of the unfiltered
    arrays (not checked against ``dones``); if the very last row is kept it
    has no successor and is dropped. Single-step trajectories are then
    removed and, when ``n_traj`` is truthy, ``n_traj`` complete trajectories
    are sampled with ``np.random``.

    Args:
        dones: Per-step episode-end flags. Not modified; a copy is used, whose
            last entry is forced to True.
        filter_by_id_fn: Maps ``task_ids`` to a per-step boolean keep-mask.
        n_traj: Number of trajectories to keep; falsy keeps all.

    Returns:
        ``(source_obs, source_next_obs, source_actions, target_obs,
        target_next_obs, target_actions, task_ids_onehot, dones, goals)``.

    Raises:
        ValueError: If ``n_traj`` exceeds the number of complete trajectories.
    """
    # Copy first: forcing the final step terminal must not leak to the caller.
    dones = np.array(dones, copy=True)
    if len(dones):
        dones[-1] = True
    logger.info(f'At first, {dones.sum()} trajectories available.')

    select_flag = np.asarray(filter_by_id_fn(task_ids), dtype=bool)
    # next_flag[i] is True when step i - 1 is selected, so indexing with it
    # returns the row following each selected step.
    next_flag = np.zeros_like(select_flag)
    next_flag[1:] = select_flag[:-1]

    source_next_obs = source_obs[next_flag]
    target_next_obs = target_obs[next_flag]
    source_obs = source_obs[select_flag]
    target_obs = target_obs[select_flag]
    task_ids_onehot = task_ids_onehot[select_flag]
    source_actions = source_actions[select_flag]
    target_actions = target_actions[select_flag]
    dones = dones[select_flag]
    goals = goals[select_flag]

    if len(source_obs) != len(source_next_obs):
        # The last row was selected but has no successor; drop it so every
        # step stays paired with its next observation.
        source_obs = source_obs[:-1]
        source_actions = source_actions[:-1]
        target_obs = target_obs[:-1]
        target_actions = target_actions[:-1]
        task_ids_onehot = task_ids_onehot[:-1]
        dones = dones[:-1]
        goals = goals[:-1]

    task_id_list = np.unique(task_ids[select_flag])
    logger.info(f'After task ID filtering, {dones.sum()} trajectories remain.')
    logger.info(
        f'{len(task_id_list)} tasks are going to be used. (ID={task_id_list})')

    # remove single_step trajectory
    dones, source_obs, source_next_obs, source_actions, target_obs, target_next_obs, target_actions, task_ids_onehot, goals = \
    _remove_single_step_trajectories(
        dones,
        source_obs,
        source_next_obs,
        source_actions,
        target_obs,
        target_next_obs,
        target_actions,
        task_ids_onehot,
        goals,
    )
    logger.info(
        f'After removing single-step trajectories, {dones.sum()} trajectories remain.'
    )

    # select final trajectories (every array, including next obs and goals)
    if n_traj:
        keep = _select_trajectory_mask(dones, n_traj)
        source_obs = source_obs[keep]
        source_next_obs = source_next_obs[keep]
        source_actions = source_actions[keep]
        target_obs = target_obs[keep]
        target_next_obs = target_next_obs[keep]
        target_actions = target_actions[keep]
        task_ids_onehot = task_ids_onehot[keep]
        dones = dones[keep]
        goals = goals[keep]

    logger.info(f'Finally, {dones.sum()} trajectories are selected.')
    if n_traj:
        # Invariant: each sampled trajectory ends in exactly one done flag.
        assert dones.sum() == n_traj

    return source_obs, source_next_obs, source_actions, target_obs,\
            target_next_obs, target_actions, task_ids_onehot, dones, goals


def create_savedir_root(phase_tag: str, env_tag) -> Path:
    """Create and return ``custom/results/<env_tag>/<phase_tag>/<timestamp>``.

    The path is relative to the current working directory; the timestamp has
    the form ``YYYYmmdd-HHMMSS``. An existing directory is reused silently.
    """
    stamp = datetime.now().strftime('%Y%m%d-%H%M%S')
    run_dir = Path(f'custom/results/{env_tag}/{phase_tag}').joinpath(stamp)
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir


def get_processed_data(dataset: str,
                       task_id_zero: bool,
                       filter_by_id_fn: Callable,
                       trans_into_source_obs: Callable,
                       trans_into_source_action: Callable,
                       n_traj: Optional[int] = None):
    """Read ``dataset`` and run it through ``filter_dataset``.

    When ``task_id_zero`` is True, the task one-hot conditioning is replaced
    by an all-zero float32 array of the same shape (i.e. no task information
    is passed on).

    Returns:
        ``(source_obs, source_next_obs, source_actions, target_obs,
        target_next_obs, target_actions, task_ids_onehot, goal_to_task_id,
        dones, goals)``.
    """
    raw = read_dataset(
        path=dataset,
        source_trans_fn=trans_into_source_obs,
        source_action_trans_fn=trans_into_source_action,
    )
    (task_ids, onehot, src_obs, src_actions, tgt_obs, tgt_actions, dones,
     goal_to_task_id, goals) = raw

    filtered = filter_dataset(
        src_obs,
        tgt_obs,
        onehot,
        src_actions,
        tgt_actions,
        dones,
        goals,
        task_ids,
        filter_by_id_fn=filter_by_id_fn,
        n_traj=n_traj,
    )
    (src_obs, src_next_obs, src_actions, tgt_obs, tgt_next_obs, tgt_actions,
     onehot, dones, goals) = filtered

    # Optionally hide the task identity from downstream consumers.
    if task_id_zero:
        onehot = np.zeros_like(onehot, dtype=np.float32)

    return (src_obs, src_next_obs, src_actions, tgt_obs, tgt_next_obs,
            tgt_actions, onehot, goal_to_task_id, dones, goals)


def prepare_dataset(args, filter_by_id_fn: Callable,
                    trans_into_source_obs: Callable,
                    trans_into_source_action: Callable,
                    dataset_concat_fn: Callable, task_id_zero: bool):
    """Build a shuffled training DataLoader over source + target domain data.

    Reads ``args.dataset``, ``args.n_traj`` and ``args.batch_size``. Source
    samples are tagged with domain ID ``[1, 0]`` and target samples with
    ``[0, 1]``; ``dataset_concat_fn`` merges both domains into five arrays.
    Each dataset item is ``(obs, action, next_obs, cond, domain,
    action_mask)`` where ``action_mask`` is all True.

    Returns:
        ``(train_loader, goal_to_task_id)``.
    """
    logger.info('Start creating dataset...')
    processed = get_processed_data(
        dataset=args.dataset,
        task_id_zero=task_id_zero,
        filter_by_id_fn=filter_by_id_fn,
        trans_into_source_obs=trans_into_source_obs,
        trans_into_source_action=trans_into_source_action,
        n_traj=args.n_traj,
    )
    (src_obs, src_next_obs, src_actions, tgt_obs, tgt_next_obs, tgt_actions,
     onehot, goal_to_task_id, _dones, _goals) = processed

    merged_obs, merged_next_obs, merged_cond, merged_domains, merged_actions = dataset_concat_fn(
        source_obs=src_obs,
        target_obs=tgt_obs,
        source_next_obs=src_next_obs,
        target_next_obs=tgt_next_obs,
        source_actions=src_actions,
        target_actions=tgt_actions,
        source_domain_id=np.array([[1, 0]], dtype=np.float32),
        target_domain_id=np.array([[0, 1]], dtype=np.float32),
        task_ids_onehot=onehot)

    obs_t, actions_t, next_obs_t, cond_t, domains_t = map(
        torch.from_numpy,
        (merged_obs, merged_actions, merged_next_obs, merged_cond,
         merged_domains),
    )
    # Every action dimension is valid here.
    mask_t = torch.ones_like(actions_t).bool()

    train_loader = DataLoader(
        TensorDataset(obs_t, actions_t, next_obs_t, cond_t, domains_t, mask_t),
        batch_size=args.batch_size,
        shuffle=True,
    )
    logger.info('Dataset has been successfully created.')
    return train_loader, goal_to_task_id


def get_success(obs, target, threshold=0.1):
    """Return whether ``obs[:2]`` is within ``threshold`` (L2 distance) of ``target``."""
    position = obs[:2]
    distance = np.linalg.norm(position - target)
    return distance <= threshold


def make_dataloader(batch_size=5000, phase="train"):
    """Build a shuffled DataLoader over the maze2d medium-dense dataset.

    Task 7 is held out: ``phase="train"`` keeps task IDs 1-26 except 7 and
    ``phase="test"`` keeps only task 7. The source domain reverses the
    observation vector and uses the ``"inv"`` action transform (negated,
    reversed action); the target domain is the raw data.

    Args:
        batch_size: DataLoader batch size.
        phase: ``"train"`` or ``"test"``.

    Returns:
        The training DataLoader produced by ``prepare_dataset``.

    Raises:
        ValueError: If ``phase`` is neither ``"train"`` nor ``"test"``.
    """
    # Resolve the task split first so an invalid phase fails before any
    # configuration or dataset I/O.
    train_task_ids = [task_id for task_id in range(1, 27) if task_id != 7]
    test_task_ids = [7]
    if phase == "train":
        selected_task_ids = train_task_ids
    elif phase == "test":
        selected_task_ids = test_task_ids
    else:
        raise ValueError(f'phase must be "train" or "test", got {phase!r}')

    args = OmegaConf.create({
        "dataset": "datasets/maze2d/maze2d-medium-dense-v1.hdf5",
        "n_traj": None,
        "train_ratio": 0.9,  # not read by prepare_dataset
        "batch_size": batch_size,
    })

    def trans_into_source_obs(original_obs):
        # Reverse the whole feature axis (alternative kept for reference:
        # original_obs[..., [1, 0, 3, 2]]).
        source_obs = original_obs[..., ::-1]
        return source_obs

    def get_action_translator(action_type: str):
        assert action_type in ['normal',
                               'inv'], f'invalid action_type "{action_type}"'

        def trans_into_source_action(original_action):
            if action_type == 'normal':
                source_action = original_action
            else:
                source_action = -original_action[..., ::-1]

            return source_action

        return trans_into_source_action

    def dataset_concat_fn(
        source_obs: np.ndarray,
        target_obs: np.ndarray,
        source_next_obs: np.ndarray,
        target_next_obs: np.ndarray,
        source_actions: np.ndarray,
        target_actions: np.ndarray,
        source_domain_id: np.ndarray,
        target_domain_id: np.ndarray,
        task_ids_onehot: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Stack source rows on top of target rows for every field.

        Returns ``(obs, next_obs, cond, domains, actions)``; the task
        conditioning is shared, so it is repeated for both halves.
        """

        def _create_domain_id_array(domain_id: np.ndarray, length: int):
            domain_id_array = np.tile(domain_id,
                                      (length, 1)).astype(np.float32)
            return domain_id_array

        source_domain_array = _create_domain_id_array(source_domain_id,
                                                      len(source_obs))
        target_domain_array = _create_domain_id_array(target_domain_id,
                                                      len(target_obs))

        obs_for_dataset = np.concatenate((source_obs, target_obs))
        next_obs_for_dataset = np.concatenate(
            (source_next_obs, target_next_obs))
        cond_for_dataset = np.concatenate((task_ids_onehot, task_ids_onehot))
        domains_for_dataset = np.concatenate(
            (source_domain_array, target_domain_array))
        actions_for_dataset = np.concatenate((source_actions, target_actions))
        return obs_for_dataset, next_obs_for_dataset, cond_for_dataset, domains_for_dataset, actions_for_dataset

    def _filter_by_id(task_ids: np.ndarray):
        # Per-step mask: True where the step's task ID is in the chosen split.
        return np.isin(task_ids, selected_task_ids)

    train_loader, _ = prepare_dataset(
        args=args,
        filter_by_id_fn=_filter_by_id,
        trans_into_source_obs=trans_into_source_obs,
        trans_into_source_action=get_action_translator("inv"),
        dataset_concat_fn=dataset_concat_fn,
        task_id_zero=False,
    )

    return train_loader


# def dataset_concat_fn(
#     source_obs: np.ndarray,
#     target_obs: np.ndarray,
#     source_next_obs: np.ndarray,
#     target_next_obs: np.ndarray,
#     source_actions: np.ndarray,
#     target_actions: np.ndarray,
#     source_domain_id: np.ndarray,
#     target_domain_id: np.ndarray,
#     task_ids_onehot: np.ndarray,
# ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:

#     def _create_domain_id_array(domain_id: np.ndarray, length: int):
#         domain_id_array = np.tile(domain_id, (length, 1)).astype(np.float32)
#         return domain_id_array

#     source_domain_array = _create_domain_id_array(source_domain_id,
#                                                   len(source_obs))
#     target_domain_array = _create_domain_id_array(target_domain_id,
#                                                   len(target_obs))

#     obs_for_dataset = np.concatenate((source_obs, target_obs))
#     next_obs_for_dataset = np.concatenate((source_next_obs, target_next_obs))
#     cond_for_dataset = np.concatenate((task_ids_onehot, task_ids_onehot))
#     domains_for_dataset = np.concatenate(
#         (source_domain_array, target_domain_array))
#     actions_for_dataset = np.concatenate((source_actions, target_actions))
#     return obs_for_dataset, next_obs_for_dataset, cond_for_dataset, domains_for_dataset, actions_for_dataset


def create_dataset(dataset_path="datasets/maze2d/maze2d-medium-v1.hdf5",
                   mode="proxy",
                   dataset_size=1000000,
                   output_path=None):
    """Export a task-filtered copy of a maze2d HDF5 dataset.

    The input is loaded through ``get_processed_data`` with identity source
    transforms. The first ``dataset_size`` steps are written to a new
    gzip-compressed HDF5 file with keys ``observations``, ``actions``,
    ``next_observations``, ``terminals``, ``infos/goal_id`` and
    ``infos/goal``.

    Args:
        dataset_path: Input HDF5 file.
        mode: ``"proxy"`` keeps task IDs 1-26 except 7, ``"inference"`` keeps
            only task 7; any other value keeps all task IDs 1-26.
        dataset_size: Maximum number of leading steps to write.
        output_path: Destination file. Defaults to
            ``datasets/maze2d/maze2d-<maze_type>-v1.hdf5`` (relative to the
            current directory), where ``<maze_type>`` is the second
            ``-``-separated token of the input file name.

    Raises:
        ValueError: If the output path resolves to ``dataset_path``. With the
            default naming, e.g. ``maze2d-medium-v1.hdf5`` maps onto itself,
            and writing would replace the source data with the exported subset.
    """
    all_task_ids = list(range(1, 27))
    if mode == "proxy":
        task_ids_to_keep = [task_id for task_id in all_task_ids if task_id != 7]
    elif mode == "inference":
        task_ids_to_keep = [7]
    else:
        task_ids_to_keep = all_task_ids

    def _filter_by_id(task_ids: np.ndarray):
        return np.isin(task_ids, task_ids_to_keep)

    source_path = Path(dataset_path)
    if output_path is None:
        maze_type = source_path.stem.split("-")[1]
        output_path = f"datasets/maze2d/maze2d-{maze_type}-v1.hdf5"
    # Fail fast, before the (expensive) processing step.
    if Path(output_path).resolve() == source_path.resolve():
        raise ValueError(
            f"Refusing to overwrite the input dataset {dataset_path}; "
            "pass a different output_path.")

    (source_obs, source_next_obs, source_actions, _target_obs,
     _target_next_obs, _target_actions, task_ids_onehot, _goal_to_task_id,
     dones, goals) = get_processed_data(
        dataset=dataset_path,
        task_id_zero=False,
        filter_by_id_fn=_filter_by_id,
        trans_into_source_obs=lambda x: x,
        trans_into_source_action=lambda x: x,
        n_traj=None,
    )
    print(source_obs.shape)

    goal_ids = np.argmax(task_ids_onehot, axis=1)

    np_data = {
        "observations": source_obs[:dataset_size],
        "actions": source_actions[:dataset_size],
        "next_observations": source_next_obs[:dataset_size],
        "terminals": dones[:dataset_size],
        "infos/goal_id": goal_ids[:dataset_size],
        "infos/goal": goals[:dataset_size],
    }

    print(output_path)
    # Context manager guarantees the file is closed even if a write fails.
    with h5py.File(output_path, "w") as out_file:
        for key, value in np_data.items():
            out_file.create_dataset(key, data=value, compression="gzip")


if __name__ == "__main__":
    # Export the large maze dataset; an unrecognized mode ("") keeps every
    # task ID, and the size cap is large enough to keep every step.
    create_dataset(
        dataset_path="./datasets/maze2d/maze2d-large-sparse-v1.hdf5",
        mode="",
        dataset_size=10**9,
    )
