from __future__ import annotations

import logging
import multiprocessing as mp
import os
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from multiprocessing import Manager, Process
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union

import comet_ml
import gym
import h5py
import imageio.v2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm import tqdm

from common.ours.models import Policy
from common.utils.process_dataset import (AntMazeTaskIDManager,
                                          PairedTrajDataset,
                                          PointMazeTaskIDManager)

# Module-level logger; its own level is set to INFO (handlers still decide
# what is actually emitted).
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# NOTE(review): presumably read by d4rl at import time to silence its
# optional-dependency import errors; it can only affect imports that run
# after this line — confirm.
os.environ["D4RL_SUPPRESS_IMPORT_ERROR"] = "1"


@dataclass
class PLPConfig:
    """Run configuration: data, model, loss, training and evaluation settings.

    Fields are grouped by the section comments below. Fields under the
    "set in the script" section default to ``None`` and are meant to be
    filled in by the launching script rather than by the user.
    """
    config: Optional[Path] = None
    domains: List[Dict] = field(default_factory=list)
    adapt_domains: List[Dict] = field(default_factory=list)

    n_domains: int = 2
    add_domain: bool = False
    complex_task: bool = False

    # Data / model
    name: Optional[str] = None
    goal: int = 6
    n_tasks: int = -1
    n_traj: int = 1000
    train_ratio: float = 0.9
    batch_size: int = 256
    latent_dim: int = 192
    hid_dim: int = 192
    policy_num_layers: Tuple[int, int, int] = (3, 4, 3)
    decode_with_state: bool = False
    activation: str = 'gelu'
    repr_activation: str = 'gelu'
    z_norm: str = ''
    target_adapt: bool = False
    device: str = 'cuda:0'

    # MMD settings
    mmd_coef: float = 0.1
    mmd_kernel: str = 'gaussian'  # 'gaussian', 'laplacian', 'l2'
    mmd_sigma: float = 1.0
    mmd_norm: bool = True
    mmd_linear: bool = False

    # TCC settings
    tcc: bool = True
    tcc_coef: float = 0.2
    tcc_frame_interval: int = 16
    tcc_batch_size: int = 64
    tcc_prob: float = 1.0
    tcc_validation: bool = False

    # r2r params
    r2r_single_layer: bool = False

    comet: bool = False
    bc: bool = True
    naive_bc: bool = False

    # Hausdorff settings
    hausdorff_coef: float = 0.0
    hausdorff_soft: bool = False
    hausdorff_norm: bool = True

    norm_hinge_coef: float = 0.0
    distreg_coef: float = 0.0

    # adversarial
    adversarial_coef: float = 0.0
    disc_num_layers: int = 4
    disc_lr: float = 5e-4

    # State-only demonstration
    state_pred: bool = False  # bc loss is invalidated when idm_coef > 0.0
    next_state_reg_coef: float = 1.0
    idm_coef: float = 1.0

    discrete: bool = False
    discrete_bins: int = 21

    # Training parameters
    epochs: int = 20
    eval_interval: int = 5
    n_eval_episodes: int = 20
    n_render_episodes: int = 10

    lr: float = 5e-4
    adapt_lr: Optional[float] = None
    weight_decay: float = 0.0

    adapt_epochs: int = 50
    adapt_eval_interval: int = 25
    adapt_n_eval_episodes: int = 50
    adapt_n_traj: int = -1
    adapt_with_all_tasks: bool = True
    adapt_align_data_size_rate: float = 1.0  # use len(adapt_data) * rate align data

    state_noise: float = 0.0
    action_noise: float = 0.0

    # Image observations
    image_observation: bool = False
    image_state_dim: int = 1024
    input_image_state_into_decoder: bool = True
    use_image_decoder: bool = False
    image_recon_coef: float = 1.0
    use_coord_conv: bool = True
    evaluate: bool = True
    evaluate_parallel: bool = False  # avoid bug in parallelization of P2A
    amp: bool = False
    num_data_workers: int = 0
    robot: bool = False

    multienv: bool = False
    n_task_ids: Optional[int] = None
    target_goal_id: Optional[int] = None
    task_id_offset_list: Optional[List[int]] = None
    target_task_id_offset: Optional[int] = None

    # set in the script
    max_obs_dim: Optional[int] = None
    max_action_dim: Optional[int] = None
    max_seq_len: Optional[int] = None
    train_goal_ids: Optional[List[int]] = None
    savedir_root: Optional[str] = None


def visualize_move(env, obs, goals):
    """Replay observations in ``env``, drawing each paired goal.

    For every ``(observation, goal)`` pair, the agent is teleported to the
    observation's first two coordinates (without noise), the goal is set and
    its marker drawn, and a frame is rendered.
    """
    pairs = zip(obs, goals)
    for step_obs, step_goal in pairs:
        position = step_obs[:2]
        env.reset_to_location(position, no_noise=True)
        env.set_target(step_goal)
        env.set_marker()
        env.render()


def convert_obs_to_img_obs(env,
                           obs: np.ndarray,
                           resolution: Tuple[int, int] = (64, 64)):
    """Render one top-down frame per observation's xy position.

    Per the original note, ``obs`` is ``[N, D]`` and the result is
    ``[N, H, W, C]``.

    Args:
        env: Env supporting ``render('rgb_array')``, ``env.env.viewer.cam``
            and ``reset_to_location``.
        obs: Observations; only the first two entries of each row are used.
        resolution: Forwarded to ``PIL.Image.resize`` (which takes
            ``(width, height)``).

    Returns:
        The resized frames stacked into one numpy array.
    """
    # TODO currently only support maze-medium

    # Presumably the first render call creates the viewer so that
    # `env.env.viewer` exists below — confirm.
    env.render('rgb_array')
    camera = env.env.viewer.cam
    camera.elevation = -90
    camera.distance = 8
    # NOTE(review): these shifts are relative; if the viewer is reused, every
    # call on the same env moves the camera another 0.5 — confirm intended.
    camera.lookat[0] += 0.5
    camera.lookat[1] += 0.5

    frames = []
    for ob in tqdm(obs):
        env.reset_to_location(ob[:2], no_noise=True)
        raw = env.render('rgb_array')
        resized = Image.fromarray(raw).resize(size=resolution)
        frames.append(np.array(resized))

    return np.array(frames)


def save_model(model: torch.nn.Module, path: Path, epoch=None):
    """Save ``model.state_dict()`` to ``<path>/model.pt``.

    Args:
        model: Module whose ``state_dict`` is saved.
        path: Output directory, created if missing. A ``str`` is accepted too.
        epoch: Optional epoch number written to ``<path>/epoch.txt``. ``0`` is
            a valid epoch and is written as well.
    """
    path = Path(path)
    os.makedirs(path, exist_ok=True)
    torch.save(model.state_dict(), path / 'model.pt')
    logger.info(f'The model is saved to {path}')

    # `if epoch:` would silently skip epoch 0 (falsy), e.g. when the best
    # validation loss is reached in the very first epoch.
    if epoch is not None:
        with open(path / 'epoch.txt', 'w') as f:
            f.write(str(epoch))


@dataclass
class DataContainer:
    """Step-aligned arrays of a dataset made of consecutive trajectories.

    Every array field is indexed by step along its first axis. A step whose
    ``dones`` entry is truthy is the last step of its trajectory.
    """
    obs: np.ndarray
    actions: np.ndarray
    dones: np.ndarray
    goal_ids: np.ndarray
    start_ids: Optional[np.ndarray] = None
    task_ids: Optional[np.ndarray] = None
    max_task_id: Optional[int] = None

    @property
    def next_obs(self) -> np.ndarray:
        """Observation of the following step (the step itself at a trajectory end)."""
        return get_next_obs(obs=self.obs, dones=self.dones)

    def get_onehot_task_id(self, max_task_id: Optional[int] = None) -> np.ndarray:
        """Return ``task_ids`` as float32 one-hot rows of width ``max_task_id``.

        Args:
            max_task_id: One-hot width; defaults to ``self.max_task_id``.

        Raises:
            ValueError: if ``task_ids`` or the width is not set. Previously
                ``np.eye(k)[None]`` silently produced a ``(1, k, k)`` array.
        """
        if max_task_id is None:
            max_task_id = self.max_task_id
        if self.task_ids is None or max_task_id is None:
            raise ValueError('task_ids / max_task_id are not set; call '
                             'calculate_task_ids_with_starts first.')
        return np.eye(max_task_id)[self.task_ids].astype(np.float32)

    def calculate_task_ids_with_starts(self, env_id: str) -> None:
        """Fill ``start_ids``, ``task_ids`` and ``max_task_id`` using ``env_id``'s maze.

        Each trajectory's start group comes from ``env.pos_to_start_id`` on the
        xy of its first observation. The task id is
        ``n_start_groups * goal_id + start_id``.
        Note: marks the last step as terminal in place (``self.dones[-1] = True``).
        """
        env = gym.make(env_id)
        self.dones[-1] = True
        # start_bit[i] is True when step i (i > 0) begins a new trajectory.
        start_bit = np.concatenate(
            ([False], self.dones))[:-1]  # other than idx 0
        traj_ids = np.cumsum(start_bit)
        start_idx = np.concatenate(([0], np.where(start_bit)[0]))
        starts = self.obs[start_idx][..., :2]
        start_pos_ids = np.array([env.pos_to_start_id(pos) for pos in starts])
        self.start_ids = start_pos_ids[traj_ids]

        assert len(start_idx) == self.dones.sum()

        n_start_ids = env.n_start_groups
        n_goals = len(env.id_to_xy)
        self.max_task_id = n_goals * n_start_ids
        logger.info(
            f'The maze has {n_start_ids} start position groups. '
            f'Totally {n_goals * n_start_ids} tasks.')

        self.task_ids = n_start_ids * self.goal_ids + self.start_ids
        assert self.task_ids.max() < self.max_task_id

    def __getitem__(self, item) -> DataContainer:
        """Index every array field with ``item`` (mask, slice or indices).

        Optional fields that are still ``None`` stay ``None`` instead of
        raising ``TypeError``.
        """
        def _take(arr: Optional[np.ndarray]) -> Optional[np.ndarray]:
            return None if arr is None else arr[item]

        return DataContainer(
            obs=self.obs[item],
            actions=self.actions[item],
            dones=self.dones[item],
            goal_ids=self.goal_ids[item],
            start_ids=_take(self.start_ids),
            task_ids=_take(self.task_ids),
            max_task_id=self.max_task_id,
        )


def read_dataset(
    path: Path,
    env_id: str,
    trans_observations_fn: Optional[Callable] = None,
    trans_actions_fn: Optional[Callable] = None,
) -> DataContainer:
    """Load an HDF5 maze dataset into a :class:`DataContainer` with task ids.

    Trajectory ends are read from ``terminals`` for ant envs (``'ant'`` in
    ``env_id``) and from ``timeouts`` otherwise. Goal positions are rounded
    and mapped to goal ids via the env's ``xy_to_id``. Task ids are computed
    on the raw observations; only afterwards are the optional transforms
    applied, with the results cast to float32.
    """
    done_key = 'terminals' if 'ant' in env_id else 'timeouts'
    with h5py.File(path, "r") as h5_file:
        observations = np.array(h5_file['observations'])
        actions = np.array(h5_file['actions'])
        raw_goals = np.array(h5_file['infos/goal'])
        dones = np.array(h5_file[done_key])

    env = gym.make(env_id)
    xy_to_id = env.xy_to_id
    rounded_goals = raw_goals.round().astype(int)
    goal_ids = np.array([xy_to_id[tuple(xy)] for xy in rounded_goals])
    logger.info(
        f'Goal position to goal ID ({len(xy_to_id)} goals) : {xy_to_id}')

    data = DataContainer(obs=observations,
                         actions=actions,
                         dones=dones,
                         goal_ids=goal_ids)
    data.calculate_task_ids_with_starts(env_id=env_id)

    transforms = (('obs', trans_observations_fn),
                  ('actions', trans_actions_fn))
    for attr_name, transform in transforms:
        if transform:
            transformed = transform(getattr(data, attr_name))
            setattr(data, attr_name, transformed.astype(np.float32))

    return data


def split_dataset(dataset: torch.utils.data.Dataset, train_ratio: float,
                  batch_size: int):
    """Randomly split ``dataset`` and wrap both parts in DataLoaders.

    ``int(len(dataset) * train_ratio)`` samples go to training; the rest go
    to validation. Only the training loader shuffles.

    Returns:
        ``(train_loader, val_loader)``.
    """
    n_total = len(dataset)
    n_train = int(n_total * train_ratio)
    train_part, val_part = random_split(dataset, [n_train, n_total - n_train])

    return (
        DataLoader(train_part, batch_size=batch_size, shuffle=True),
        DataLoader(val_part, batch_size=batch_size),
    )


@dataclass
class CheckPointer:
    """Save the policy at the final epoch and whenever validation loss improves."""
    n_epochs: int
    savedir_root: Path
    best_val_loss: float = 1e9

    def save_if_necessary(self, policy: Policy, val_loss: float, epoch: int):
        """Checkpoint ``policy`` after ``epoch`` when warranted.

        On the last epoch (``n_epochs - 1``), the policy is saved to
        ``last_epoch<n_epochs>``. If ``val_loss`` beats the best value so far,
        the best value is updated and the policy is saved to ``best`` together
        with the epoch number.
        """
        if epoch == self.n_epochs - 1:
            last_dir = self.savedir_root / f'last_epoch{self.n_epochs:04d}'
            save_model(model=policy, path=last_dir)

        # `not <` (rather than `>=`) so that a NaN loss never counts as an improvement.
        if not val_loss < self.best_val_loss:
            return
        self.best_val_loss = val_loss
        save_model(model=policy, path=self.savedir_root / 'best', epoch=epoch)


def _remove_single_step_trajectories(data: DataContainer, ) -> DataContainer:
    done_idx = np.concatenate(([-1], np.where(data.dones)[0]))
    traj_lens = done_idx[1:] - done_idx[:-1]
    single_step_traj_ids = np.where(traj_lens == 1)[0]

    valid_steps = np.ones(len(data.dones), dtype=bool)
    for single_step_traj_id in single_step_traj_ids:
        invalid_step_idx = done_idx[single_step_traj_id + 1]
        valid_steps[invalid_step_idx] = False

    data = data[valid_steps]
    return data


def _select_n_trajecotries(data: DataContainer, n: int):
    done_idx = np.concatenate(([-1], np.where(data.dones)[0]))
    start_bit = np.concatenate(([False], data.dones))[:-1]
    traj_ids = np.cumsum(start_bit)

    selected_traj_ids = np.random.choice(range(traj_ids.max() + 1),
                                         n,
                                         replace=False)

    selected_steps = np.zeros(len(traj_ids))
    for selected_id in selected_traj_ids:
        st_idx, gl_idx = done_idx[selected_id] + 1, done_idx[selected_id + 1]
        selected_steps[st_idx] += 1
        if gl_idx + 1 < len(selected_steps):
            selected_steps[gl_idx + 1] -= 1

    selected_steps = np.cumsum(selected_steps).astype(bool)
    data = data[selected_steps]
    return data


def filter_dataset(
    data: DataContainer,
    filter_by_goal_id_fn: Callable,
    n_traj: Optional[int] = None,
):
    """Filter trajectories by goal ID, drop single-step ones, optionally subsample.

    Args:
        data: Dataset to filter. Its last step is marked terminal in place.
        filter_by_goal_id_fn: Maps the per-step ``goal_ids`` array to an index
            (e.g. a boolean step mask) used to select steps.
        n_traj: Number of trajectories to keep at random. ``None``, ``0`` or a
            negative value (e.g. the ``-1`` "use all" defaults in the config)
            keeps every trajectory.

    Returns:
        The filtered :class:`DataContainer`.
    """
    data.dones[-1] = True
    logger.info(f'At first, {data.dones.sum()} trajectories available.')

    select_flag = filter_by_goal_id_fn(data.goal_ids)
    data = data[select_flag]
    goal_id_list = np.unique(data.goal_ids)
    logger.info(
        f'After goal ID filtering, {data.dones.sum()} trajectories remain.')
    logger.info(
        f'{len(goal_id_list)} goals are going to be used. (ID={goal_id_list})')

    # remove single_step trajectory
    data = _remove_single_step_trajectories(data=data)
    logger.info(
        f'After removing single-step trajectories, {data.dones.sum()} trajectories remain.'
    )

    # Select final trajectories. A negative n_traj used to reach
    # np.random.choice and raise; non-positive values now mean "keep all".
    if n_traj is not None and n_traj > 0:
        data = _select_n_trajecotries(data=data, n=n_traj)
        assert data.dones.sum() == n_traj
    logger.info(f'Finally, {data.dones.sum()} trajectories are selected.')

    return data


def create_savedir_root(
    phase_tag: str,
    name: str = '',
    source_env: str = '',
    target_env: str = '',
) -> Path:
    """Create and return a fresh run directory under ``results/`` (relative to cwd).

    The layout is ``results/<source_env>_<target_env>/<name>/<stamp>-<rand>``
    when both env names are non-empty, otherwise
    ``results/<name>/<stamp>-<rand>``. ``<rand>`` is a zero-padded 6-digit
    number drawn from the global ``np.random`` state.

    Note: ``phase_tag`` is accepted but currently unused.
    """
    stamp = datetime.now().strftime('%Y%m%d-%H%M%S')
    suffix = np.random.randint(low=0, high=int(1e6) - 1)

    subdir = (f'{source_env}_{target_env}/{name}'
              if source_env and target_env else name)
    savedir_root = Path(f'results/{subdir}') / f'{stamp}-{suffix:06d}'

    os.makedirs(savedir_root, exist_ok=True)
    return savedir_root


def get_next_obs(obs: np.ndarray, dones: np.ndarray) -> np.ndarray:
    """Return the next-step observation for every step.

    ``next_obs[i] = obs[i + 1]``, except where ``dones[i]`` is truthy (the
    last step of a trajectory), in which case ``obs[i]`` itself is used. The
    final step always maps to itself.

    Works for observations of any rank ``(N, ...)``, e.g. ``(N,)``,
    ``(N, D)`` or image stacks ``(N, H, W, C)``.
    """
    obs = np.asarray(obs)
    next_obs = np.concatenate((obs[1:], obs[-1:]))
    # Broadcast dones over every non-step axis. `dones[..., None]` only lined
    # up with 2-D obs: 1-D obs broadcast to (N, N), and image obs aligned the
    # step axis with the wrong dimension.
    terminal = np.asarray(dones).reshape((-1,) + (1,) * (obs.ndim - 1))
    return np.where(terminal, obs, next_obs)


def get_processed_data(
    source_dataset_path: Path,
    target_dataset_path: Path,
    filter_by_goal_id_fn: Callable,
    trans_into_source_obs: Callable,
    trans_into_source_action: Callable,
    n_traj: Optional[int] = None,
    source_ant_flag: bool = False,
    target_ant_flag: bool = False,
) -> Tuple[DataContainer, DataContainer, Dict]:
    """Load, transform and filter a source and a target dataset.

    The source dataset is mapped with ``trans_into_source_obs`` /
    ``trans_into_source_action``; the target dataset is left untransformed.
    Both are then filtered with ``filter_by_goal_id_fn`` and, if ``n_traj``
    is given, subsampled to ``n_traj // 2`` trajectories each.

    Returns:
        ``(source_data, target_data, goal_to_task_id)``.

    NOTE(review): this looks stale relative to ``read_dataset`` in this
    module. ``read_dataset`` has no ``ant_flag`` parameter, requires
    ``env_id``, and returns a single ``DataContainer`` rather than a 2-tuple,
    so the calls below raise ``TypeError`` as written. Confirm whether this
    function is still used before relying on it.
    """

    source_data, goal_to_task_id = read_dataset(
        path=source_dataset_path,
        trans_observations_fn=trans_into_source_obs,
        trans_actions_fn=trans_into_source_action,
        ant_flag=source_ant_flag,
    )

    target_data, _ = read_dataset(
        path=target_dataset_path,
        trans_observations_fn=None,
        trans_actions_fn=None,
        ant_flag=target_ant_flag,
    )

    # Each domain contributes half of the requested trajectories.
    source_data = filter_dataset(
        data=source_data,
        filter_by_goal_id_fn=filter_by_goal_id_fn,
        n_traj=n_traj // 2 if n_traj is not None else None,
    )

    target_data = filter_dataset(
        data=target_data,
        filter_by_goal_id_fn=filter_by_goal_id_fn,
        n_traj=n_traj // 2 if n_traj is not None else None,
    )

    return source_data, target_data, goal_to_task_id


def _create_domain_id_array(domain_id: int, domain_id_dim: int,
                            length: int) -> np.ndarray:
    domain_id = np.eye(domain_id_dim)[domain_id][None]
    domain_id_array = np.tile(domain_id, (length, 1)).astype(np.float32)
    return domain_id_array


def create_dataset_from_data_container(data: DataContainer, domain_id: int,
                                       domain_id_dim: int) -> DictDataset:
    """Wrap ``data`` into a :class:`DictDataset` ready for a DataLoader.

    Each item carries ``observations``, ``actions``, a one-hot ``domain_ids``
    row for ``domain_id``, one-hot ``task_ids`` and ``next_observations``.
    """
    n_steps = len(data.obs)
    domain_ids = _create_domain_id_array(domain_id=domain_id,
                                         domain_id_dim=domain_id_dim,
                                         length=n_steps)
    columns = dict(
        observations=data.obs,
        actions=data.actions,
        domain_ids=domain_ids,
        task_ids=data.get_onehot_task_id(max_task_id=data.max_task_id),
        next_observations=data.next_obs,
    )
    return DictDataset(data=columns)


def prepare_dataset(
    args,
    query_dict: Dict,
    filter_by_goal_id_fn: Callable,
    batch_size: int,
    domain_id_dim: int = 2,
    task_id_zero: bool = False,
):
    """Build train/validation DataLoaders for every domain in ``query_dict``.

    Args:
        args: Config object; only ``args.train_ratio`` is read here.
        query_dict: ``{domain_name: query}``. Each query provides ``'path'``,
            ``'env_id'``, ``'trans_observations_fn'``, ``'trans_actions_fn'``,
            ``'n_traj'`` and ``'domain_id'``.
        filter_by_goal_id_fn: Goal filter forwarded to ``filter_dataset``.
        batch_size: Batch size of every loader.
        domain_id_dim: Width of the one-hot domain id vectors.
        task_id_zero: If set, every task id is replaced by 0.

    Returns:
        A ``defaultdict(dict)`` shaped like
        ``{'train': {domain_name: DataLoader}, 'validation': {...}}``.
        Loaders of both splits shuffle.
    """
    loaders = defaultdict(dict)

    for domain_name, query in query_dict.items():
        data = read_dataset(
            path=query['path'],
            env_id=query['env_id'],
            trans_observations_fn=query['trans_observations_fn'],
            trans_actions_fn=query['trans_actions_fn'],
        )
        data = filter_dataset(
            data=data,
            filter_by_goal_id_fn=filter_by_goal_id_fn,
            n_traj=query['n_traj'],
        )
        if task_id_zero:
            data.task_ids = np.zeros_like(data.task_ids)

        splits = train_val_split_trajs(data=data, train_ratio=args.train_ratio)
        for phase, split_data in zip(('train', 'validation'), splits):
            dataset = create_dataset_from_data_container(
                data=split_data,
                domain_id=query['domain_id'],
                domain_id_dim=domain_id_dim,
            )
            loaders[phase][domain_name] = DataLoader(
                dataset, batch_size=batch_size, shuffle=True)

    logger.info('Dataset has been successfully created.')
    return loaders


def get_success(obs, target, threshold=0.1):
    """Return whether the first two entries of ``obs`` are within ``threshold``
    (Euclidean distance, inclusive) of ``target``."""
    offset = obs[:2] - target
    distance = np.linalg.norm(offset)
    return distance <= threshold


def eval_policy(
    env: gym.Env,
    policy: Policy,
    device,
    source_trans_fn: Callable,
    source_action_type: str = 'normal',
    times=10,
    target_center=(1, 1),
    source_flag=False,
    render_episodes: int = 0,
    video_path: Path = '.',
    experiment: Optional[comet_ml.Experiment] = None,
):
    """Roll out ``policy`` in ``env`` and report success rate and mean episode length.

    Each of the ``times`` episodes works as follows:
    - The target is placed uniformly within ±0.1 of ``target_center``.
    - The env is reset again while the start is within 0.5 of the target.
    - The episode ends on success (within ``get_success``'s default threshold)
      or when the env reports ``done``.

    The task condition fed to the policy is all zeros.

    Args:
        env: Env providing ``set_target``/``get_target``, ``np_random`` and ``model.nq``.
        policy: Called as ``policy(s=..., c=..., d=...)``; the first of its
            three outputs is the action batch.
        device: Torch device for the policy inputs.
        source_trans_fn: Applied to observations when ``source_flag`` is set.
        source_action_type: ``'normal'``, or ``'inv'`` to negate actions in
            source mode.
        times: Number of episodes.
        target_center: Center of the target sampling box.
        source_flag: Feed transformed observations with the source domain
            one-hot ``[1, 0]`` instead of raw observations with ``[0, 1]``.
        render_episodes: Render the first ``render_episodes`` episodes into a video.
        video_path: Destination handed to ``save_video`` (``str`` or ``Path``).
        experiment: Optional comet experiment that receives the video.

    Returns:
        ``(success_rate, steps_mean)`` over all episodes.
    """
    assert source_action_type in ['normal', 'inv']
    video_path = Path(video_path)

    # if source_flag == True,  x and y will be inverted.
    results = []
    steps = []
    task_id_dim = policy.cond_dim
    images = []

    # Loop-invariant policy inputs. Domain one-hots are float32 like the
    # `domain_ids` used for training (see `_create_domain_id_array`); they
    # were int64 tensors before.
    source_domain_id = torch.tensor([[1, 0]], dtype=torch.float32, device=device)
    target_domain_id = torch.tensor([[0, 1]], dtype=torch.float32, device=device)
    task_id = torch.zeros((1, task_id_dim), dtype=torch.float32, device=device)

    for t in range(times):
        obs = env.reset()
        target_location = np.array(target_center) + env.np_random.uniform(
            low=-.1, high=.1, size=env.model.nq)
        env.set_target(target_location)

        # Skip starts that are already (almost) at the target.
        while get_success(obs=obs, target=env.get_target(), threshold=0.5):
            obs = env.reset()

        render = t < render_episodes
        done = False
        step = 0
        success = False
        while not done:
            step += 1
            if source_flag:
                state_input = source_trans_fn(obs)[None].astype(np.float32)
                domain_input = source_domain_id
            else:  # normal evaluation
                state_input = obs[None].astype(np.float32)
                domain_input = target_domain_id
            state_input = torch.from_numpy(state_input).to(device)
            with torch.no_grad():
                action, _, _ = policy(
                    s=state_input,
                    c=task_id,
                    d=domain_input,
                )
            action = action.detach().cpu().numpy()[0]

            if source_flag and source_action_type == 'inv':
                action = -action

            obs, _, done, _ = env.step(action)

            if render:
                images.append(env.render('rgb_array'))

            if get_success(obs=obs, target=env.get_target()):
                success = True
                if render:
                    # Linger on the success state for a few extra frames.
                    images.extend(env.render('rgb_array') for _ in range(10))
                break

        results.append(success)
        steps.append(step)
        logger.debug(f'Trial {t}: success={success}; steps={step}')

    success_rate = np.array(results).mean()
    steps_mean = np.array(steps).mean()

    if images:
        save_video(path=video_path,
                   images=images,
                   fps=20,
                   skip_rate=5,
                   experiment=experiment)

    return success_rate, steps_mean


def save_video(path: Path,
               images: List[np.ndarray],
               fps: int = 20,
               skip_rate: int = 5,
               experiment: Optional[comet_ml.Experiment] = None):
    """Write every ``skip_rate``-th frame of ``images`` (from index 1) to ``<path>.mp4``.

    Args:
        path: Target path; its suffix is replaced by ``.mp4`` and its parent
            directory is created. A ``str`` is accepted too.
        images: Frames passed to ``imageio.mimsave``.
        fps: Frame rate of the written video.
        skip_rate: Keep one frame out of every ``skip_rate``.
        experiment: If given, the mp4 is uploaded with ``log_asset`` under the
            parent directory's name. The step is ``int(path.stem)`` when the
            stem parses as an int; otherwise no step is attached (previously
            this raised ``ValueError``).
    """
    path = Path(path)
    os.makedirs(path.parent, exist_ok=True)

    mp4_path = path.with_suffix('.mp4')
    imageio.mimsave(mp4_path, images[1::skip_rate], fps=fps)
    logger.info(f'video is saved to {mp4_path}')
    if experiment:
        try:
            step = int(path.stem)
        except ValueError:
            step = None
        experiment.log_asset(mp4_path,
                             file_name=path.parent.name,
                             step=step)


def train_val_split_trajs(data: DataContainer, train_ratio: float):
    """Split ``data`` into train/validation sets made of whole trajectories.

    ``int(n_trajs * train_ratio)`` trajectories, drawn without replacement
    from the global ``np.random`` state, go to train; the rest go to
    validation. Trajectories are delimited by truthy ``data.dones``; steps
    after the last ``done`` count as one more trajectory.

    Returns:
        ``(train_data, val_data)``, both obtained by boolean indexing of ``data``.
    """
    # Trajectory index of every step = number of terminal steps before it.
    start_bit = np.concatenate(([False], data.dones))[:-1]
    traj_ids = np.cumsum(start_bit)

    total_size = traj_ids.max() + 1
    train_size = int(total_size * train_ratio)
    val_size = total_size - train_size
    logger.info(f'Train: {train_size} trajs; Val: {val_size} trajs.')

    train_traj_ids = np.random.choice(range(total_size),
                                      train_size,
                                      replace=False)

    # Membership mask instead of looking up done positions: the previous
    # `done_idx[selected_id + 1]` raised IndexError when an unterminated
    # trailing trajectory was drawn for training.
    selected_train_steps = np.isin(traj_ids, train_traj_ids)

    train_data = data[selected_train_steps]
    val_data = data[~selected_train_steps]
    return train_data, val_data


def task_id_to_target_pos(
    goal_to_task_id: Dict,
    task_id: int,
) -> Tuple[int, int]:
    """Reverse-look-up the position key mapped to ``task_id``.

    If several keys map to ``task_id``, the last one in iteration order wins.
    Raises ``AssertionError`` when no key matches.
    """
    matches = [pos for pos, mapped_id in goal_to_task_id.items()
               if mapped_id == task_id]
    target_pos = matches[-1] if matches else None
    assert target_pos is not None

    return target_pos


class DictDataset(Dataset):
    """Torch ``Dataset`` over a dict of equally long numpy arrays.

    Item ``i`` is a dict with the same keys, holding the ``i``-th entry of
    every array as a tensor. The arrays are converted with
    ``torch.from_numpy``, which shares their memory.
    """

    def __init__(self, data: Dict[str, np.ndarray]):
        """Create a Dataset from a dict of ndarrays with the same length.

        Args:
            data: ``{'key1': data1, 'key2': data2, ...}``. Every value must be
                a numpy array with the same first-axis length. An empty dict
                gives an empty dataset.

        Raises:
            ValueError: if the arrays differ in length.
        """
        # `data` used to be declared as `data=Dict[...]`, i.e. a typing object
        # as the default value instead of an annotation.
        super().__init__()
        lens = [len(val) for val in data.values()]
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if any(length != lens[0] for length in lens):
            raise ValueError(
                'values in a single dataset should have same length.')
        self.data = self._convert_to_tensor_dict(data=data)
        self.len = lens[0] if lens else 0

    @staticmethod
    def _convert_to_tensor_dict(
            data: Dict[str, np.ndarray]) -> Dict[str, torch.Tensor]:
        """Convert every array of ``data`` to a tensor (memory is shared)."""
        return {key: torch.from_numpy(val) for key, val in data.items()}

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        """Return ``{key: tensor[index]}`` for every stored key."""
        return {key: val[index] for key, val in self.data.items()}


def ant_to_maze2d(xy_array: np.ndarray) -> np.ndarray:
    """Map ant-maze xy rows to maze2d coordinates.

    Swaps the two columns, divides by 4 and adds 1. This is the inverse of
    ``maze2d_to_ant``.
    """
    swapped = xy_array[:, [1, 0]]
    return swapped / 4 + 1


def maze2d_to_ant(xy_array: np.ndarray) -> np.ndarray:
    """Map maze2d xy rows to ant-maze coordinates.

    Subtracts 1, multiplies by 4 and swaps the two columns. This is the
    inverse of ``ant_to_maze2d``.
    """
    scaled = (xy_array - 1) * 4
    return scaled[:, [1, 0]]


def get_env_conf(env_id: str) -> Tuple[int, int, int, int]:
    """Return ``(obs_dim, action_dim, n_goals, seq_len)`` for a maze env ID.

    The robot is detected by substring, ``'ant'`` taking precedence over
    ``'maze2d'``. The layout is detected the same way, ``'umaze'`` taking
    precedence over ``'medium'``.

    Raises:
        ValueError: if the robot or the layout is not recognized.
    """
    robot_dims = {'ant': (29, 8), 'maze2d': (4, 2)}
    layout_goals = {'umaze': 7, 'medium': 26}
    seq_lens = {
        ('ant', 'umaze'): 350,
        ('ant', 'medium'): 600,
        ('maze2d', 'umaze'): 250,
        ('maze2d', 'medium'): 400,
    }

    robot = next((key for key in robot_dims if key in env_id), None)
    if robot is None:
        raise ValueError(f'Unknown env {env_id}')
    layout = next((key for key in layout_goals if key in env_id), None)
    if layout is None:
        raise ValueError(f'Unknown env {env_id}')

    obs_dim, action_dim = robot_dims[robot]
    return obs_dim, action_dim, layout_goals[layout], seq_lens[robot, layout]


def _assert_contain_same_str(str1: str, str2: str, keyword: str):
    if keyword in str1:
        assert keyword in str2


def check_dataset_and_env_consistency(dataset_name: str, env_id: str):
    """Assert every robot/layout keyword present in ``env_id`` also appears in
    ``dataset_name`` (checked in the order ant, maze2d, umaze, medium)."""
    for keyword in ('ant', 'maze2d', 'umaze', 'medium'):
        if keyword in env_id:
            assert keyword in dataset_name


def check_source_and_target_consistency(
    source_name: str,
    target_name: str,
):
    """Assert the maze layout keyword (umaze / medium) of ``source_name`` also
    appears in ``target_name``."""
    for layout in ('umaze', 'medium'):
        if layout in source_name:
            assert layout in target_name


def maze2d_trans_into_source_obs(original_obs):
    """Swap the first and second entry, and the third and fourth entry, of the
    last axis (``[0, 1, 2, 3] -> [1, 0, 3, 2]``).

    Presumably this swaps x/y in both position and velocity of a maze2d
    observation — confirm the observation layout.
    """
    swap_order = [1, 0, 3, 2]
    return original_obs[..., swap_order]


def maze2d_get_action_translator(action_type: str):
    """Return a function that maps actions into the source domain.

    ``'normal'`` returns actions unchanged; ``'inv'`` negates them. Any other
    ``action_type`` raises ``AssertionError``.
    """
    assert action_type in ['normal',
                           'inv'], f'invalid action_type "{action_type}"'
    negate = action_type != 'normal'

    def trans_into_source_action(original_action):
        return -original_action if negate else original_action

    return trans_into_source_action
