from __future__ import annotations

import logging
import random
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union

import gym
import h5py
import numpy as np
import torch
from omegaconf import DictConfig
from torch.utils.data import DataLoader, Dataset, TensorDataset
from tqdm import tqdm

from common.utils.evaluate import TQDM_BAR_FORMAT, save_video
from common.utils.process_dataset import (
    ActionConverter, AntMazeTaskIDManager,
    ObservationConverter, PointMazeTaskIDManager, TorchStepDataset,
    TrajDataset, _convert_to_traj_array, convert_ndarray_list_to_obj_ndarray,
    convert_traj_dataset_to_step_dataset, filter_by_goal_id,
    get_action_converter, get_goal_candidates, get_obs_converter,
    get_task_id_manager, read_env_config_yamls,
    remove_single_step_trajectories, select_n_trajectories, train_val_split)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def get_next_obs(obs: np.ndarray, dones: np.ndarray) -> np.ndarray:
    """Return the successor observation for every step of a flat rollout.

    The observation array is shifted forward by one step along the first
    axis, with the final row duplicated so the length is unchanged.  Wherever
    ``dones`` is truthy the step keeps its own observation instead of the
    shifted one.

    Args:
        obs: Observations stacked along the first axis.
        dones: Per-step flags; a trailing axis is appended so that it
            broadcasts against ``obs``.

    Returns:
        The array of next observations.
    """
    shifted = np.concatenate([obs[1:], obs[-1:]], axis=0)
    is_terminal = dones[..., np.newaxis]
    return np.where(is_terminal, obs, shifted)


@dataclass
class CDILTrajDataset:
    """Trajectory-level dataset: one entry per trajectory in every array.

    ``obs``, ``actions``, ``images`` and ``pos`` are indexed by trajectory;
    ``task_ids`` holds one id per trajectory and defaults to ``-1`` for all
    of them.  ``n_task_id``, ``domain_id`` and ``n_domain_id`` are scalar
    metadata where ``-1`` means "not set yet".
    """
    obs: np.ndarray
    actions: np.ndarray
    images: np.ndarray
    pos: np.ndarray
    task_ids: Optional[np.ndarray] = None
    n_task_id: int = -1
    domain_id: int = -1
    n_domain_id: int = -1

    def __post_init__(self):
        """Check obs/actions agree in length and fill in missing task ids."""
        assert len(self.obs) == len(self.actions)
        if self.task_ids is None:
            # -1 marks "no task assigned" for every trajectory.
            self.task_ids = np.full(len(self), -1, dtype=int)

    def __len__(self):
        """Number of trajectories."""
        return len(self.obs)

    @staticmethod
    def _get_next_obs_of_single_traj(obs: np.ndarray):
        """Shift one trajectory by a step, repeating its final observation."""
        last_step = obs[-1:]
        return np.concatenate((obs[1:], last_step))

    @staticmethod
    def _get_next_images_of_single_traj(images: np.ndarray):
        """Shift one image sequence by a step, repeating its final frame."""
        last_frame = images[-1:]
        return np.concatenate((images[1:], last_frame))

    @property
    def next_obs(self) -> np.ndarray:
        """Per-trajectory next observations (recomputed on every access)."""
        first_traj = self.obs[0]
        assert first_traj.ndim == 2

        shifted = list(map(self._get_next_obs_of_single_traj, self.obs))
        return convert_ndarray_list_to_obj_ndarray(list_of_array=shifted)

    @property
    def next_images(self) -> np.ndarray:
        """Per-trajectory next images (recomputed on every access)."""
        shifted = list(map(self._get_next_images_of_single_traj, self.images))
        return convert_ndarray_list_to_obj_ndarray(list_of_array=shifted)

    def get_onehot_task_id(self) -> np.ndarray:
        """Return ``task_ids`` as float32 rows of an identity matrix."""
        identity = np.eye(self.n_task_id)
        return identity[self.task_ids].astype(np.float32)

    def add_domain_id(self, domain_id: int, n_domain_id: int):
        """Attach domain metadata; only allowed while both are still -1."""
        assert self.domain_id == self.n_domain_id == -1
        self.domain_id, self.n_domain_id = domain_id, n_domain_id

    def apply_obs_converter(self, func: ObservationConverter):
        """Replace ``obs`` with ``func`` applied to each trajectory."""
        converted = list(map(func, self.obs))
        self.obs = np.array(converted, dtype=object)

    def apply_action_converter(self, func: ActionConverter):
        """Replace ``actions`` with ``func`` applied to each trajectory."""
        converted = list(map(func, self.actions))
        self.actions = np.array(converted, dtype=object)

    def __getitem__(self, item) -> CDILTrajDataset:
        """Index every per-trajectory array with ``item``; keep metadata."""
        assert self.task_ids is not None
        per_traj_fields = ("obs", "actions", "pos", "task_ids", "images")
        selected = {
            name: getattr(self, name)[item]
            for name in per_traj_fields
        }
        return CDILTrajDataset(
            n_task_id=self.n_task_id,
            domain_id=self.domain_id,
            n_domain_id=self.n_domain_id,
            **selected,
        )


@dataclass
class CDILStepDataset:
    """Step-level dataset: every array is indexed by transition.

    ``task_ids`` / ``domain_ids`` hold one integer id per step, while
    ``n_task_id`` / ``n_domain_id`` give the width of the corresponding
    one-hot encodings.
    """
    obs: np.ndarray
    actions: np.ndarray
    next_obs: np.ndarray
    pos: np.ndarray
    images: np.ndarray
    next_images: np.ndarray
    task_ids: np.ndarray
    n_task_id: int
    domain_ids: np.ndarray
    n_domain_id: int

    @staticmethod
    def _one_hot(ids, width) -> np.ndarray:
        """Select rows of a ``width`` x ``width`` identity and cast to f32."""
        return np.eye(width)[ids].astype(np.float32)

    def __len__(self):
        """Number of steps."""
        return len(self.obs)

    def get_onehot_task_id(self) -> np.ndarray:
        """One-hot encode ``task_ids`` over ``n_task_id`` classes."""
        return self._one_hot(self.task_ids, self.n_task_id)

    def get_onehot_domain_id(self) -> np.ndarray:
        """One-hot encode ``domain_ids`` over ``n_domain_id`` classes."""
        return self._one_hot(self.domain_ids, self.n_domain_id)

    def __getitem__(self, item) -> CDILStepDataset:
        """Index every per-step array with ``item``; keep the class counts."""
        per_step_fields = ("obs", "actions", "next_obs", "pos", "task_ids",
                           "domain_ids", "images", "next_images")
        selected = {
            name: getattr(self, name)[item]
            for name in per_step_fields
        }
        return CDILStepDataset(
            n_task_id=self.n_task_id,
            n_domain_id=self.n_domain_id,
            **selected,
        )


class CDILTorchStepDataset(Dataset):
    """Map-style torch ``Dataset`` serving transitions as plain dicts.

    Wraps a ``CDILStepDataset`` and, on indexing, casts each field to the
    dtype used downstream (float32 for vectors, uint8 for images) and
    one-hot encodes the task and domain ids.
    """

    def __init__(self, dataset: CDILStepDataset, obs_dim: int,
                 action_dim: int):
        # obs_dim / action_dim are only stored here; nothing in this class
        # reads them back.
        self.data = dataset
        self.obs_dim = obs_dim
        self.action_dim = action_dim

    def __len__(self):
        """Number of steps in the wrapped dataset."""
        return len(self.data)

    def __getitem__(self, item):
        """Return the sample(s) selected by ``item`` as a dict of arrays."""
        step = self.data[item]

        float_fields = (
            ('observations', step.obs),
            ('actions', step.actions),
            ('next_observations', step.next_obs),
            ('pos', step.pos),
        )
        sample = {
            key: value.astype(np.float32)
            for key, value in float_fields
        }
        sample['images'] = step.images.astype(np.uint8)
        sample['next_images'] = step.next_images.astype(np.uint8)
        sample['task_ids'] = step.get_onehot_task_id()
        sample['domain_ids'] = step.get_onehot_domain_id()
        return sample


def convert_traj_dataset_to_task_id_wise_step_dataset(
        traj_dataset: CDILTrajDataset) -> Dict[int, CDILStepDataset]:
    """Flatten a trajectory dataset into one step dataset per task id.

    For every task id in ``0 .. max(task_ids)`` that has at least one
    trajectory, the matching trajectories are concatenated along the step
    axis and wrapped in a ``CDILStepDataset``.

    Args:
        traj_dataset: Trajectory dataset whose ``task_ids`` and ``domain_id``
            have already been assigned.

    Returns:
        Mapping ``task_id -> CDILStepDataset``.  Task ids without any
        trajectory are omitted; negative task ids are never visited.
    """
    max_task_id = np.max(traj_dataset.task_ids)

    datasets = {}
    for task_id in range(max_task_id + 1):
        mask = (traj_dataset.task_ids == task_id)
        if not mask.any():
            continue
        obs = np.concatenate(traj_dataset.obs[mask])
        actions = np.concatenate(traj_dataset.actions[mask])
        next_obs = np.concatenate(traj_dataset.next_obs[mask])
        pos = np.concatenate(traj_dataset.pos[mask])
        images = np.concatenate(traj_dataset.images[mask])
        # `next_images` is a required CDILStepDataset field; leaving it out
        # made the constructor below raise TypeError.
        next_images = np.concatenate(traj_dataset.next_images[mask])
        task_ids = np.repeat([task_id], len(obs))
        domain_ids = np.ones(len(obs), dtype=int) * traj_dataset.domain_id

        datasets[task_id] = CDILStepDataset(
            obs=obs,
            actions=actions,
            next_obs=next_obs,
            pos=pos,
            images=images,
            next_images=next_images,
            task_ids=task_ids,
            n_task_id=traj_dataset.n_task_id,
            domain_ids=domain_ids,
            n_domain_id=traj_dataset.n_domain_id,
        )

    return datasets


def calc_pos(dones: np.ndarray, gamma: float):
    """Compute a discounted "distance to trajectory end" value per step.

    The flat step sequence is split into trajectories at every truthy entry
    of ``dones``.  Within a trajectory of length ``L`` the step at offset
    ``t`` (0-based) receives ``gamma ** (L - t)``, so the last step of each
    trajectory gets ``gamma`` and earlier steps get smaller values (for
    ``gamma < 1``).

    Args:
        dones: 1-D per-step end-of-trajectory flags.  **Modified in place**:
            the last entry is set to True so the final trajectory is closed.
            The readers in this module pass the same array on to
            ``_convert_to_traj_array`` afterwards, so this side effect is
            kept on purpose.
        gamma: Discount factor.

    Returns:
        float32 array with the same shape as ``dones``.
    """
    if len(dones) == 0:
        # Nothing to split; avoids an IndexError on dones[-1].
        return np.empty_like(dones, dtype=np.float32)

    dones[-1] = True
    done_idxs = np.where(dones)[0]
    start_idxs = np.concatenate([[0], done_idxs + 1])
    traj_lens = np.diff(start_idxs)

    # Size the lookup table to the longest trajectory actually present.  A
    # fixed 999-entry table failed with a broadcast error for anything
    # longer.  Layout: [gamma**longest, ..., gamma**2, gamma**1].
    longest = int(traj_lens.max())
    discounts = np.array([gamma**k for k in range(1, longest + 1)])[::-1]

    pos = np.empty_like(dones, dtype=np.float32)
    for start, traj_len in zip(start_idxs[:-1], traj_lens):
        pos[start:start + traj_len] = discounts[-traj_len:]

    return pos


def read_dataset(
    path: Path,
    env_id: str,
    n_additional_tasks: int = 0,
    image_observation: Optional[bool] = False,
    domain_id: Optional[int] = -1,
    args: Optional[DictConfig] = None,
) -> Tuple[CDILTrajDataset, Union[PointMazeTaskIDManager,
                                  AntMazeTaskIDManager]]:
    """Load one HDF5 demonstration file as a trajectory dataset.

    Args:
        path: Location of the HDF5 file (``Path`` or ``str``).
        env_id: Environment identifier; substrings of it select how the file
            is interpreted (``"v2"``, ``"ant"``, ``"Lift"``, ``"Stack"``,
            ``"reach-goal"``, ``"reach-color"``).
        n_additional_tasks: Added to the task-id manager's ``n_task_id`` when
            sizing the task-id space of the returned dataset.
        image_observation: Whether to load images.  When False, an empty
            ``(N, 0)`` placeholder array is used.
        domain_id: Selects the camera for "v2" environments (``corner3`` for
            domain 0, ``corner`` otherwise); must be >= 0 in that case.
        args: Config object; ``args.gamma`` is required to compute ``pos``.

    Returns:
        ``(dataset, task_id_manager)``.  Task ids are not assigned to the
        dataset here.

    Raises:
        ValueError: If ``args`` is None.
    """
    import time

    if args is None:
        # Fail before touching the file instead of with an AttributeError on
        # `args.gamma` after everything has been loaded.
        raise ValueError(
            "read_dataset requires `args` (args.gamma is used to compute "
            "`pos`).")

    is_v2 = "v2" in env_id
    is_lift_or_stack = 'Lift' in env_id or 'Stack' in env_id
    has_goal_ids = (is_lift_or_stack or "reach-goal" in env_id
                    or "reach-color" in env_id)

    with h5py.File(path, "r") as f:
        observations = np.array(f['observations'])
        if is_v2:
            # Keep only columns 0:4 and 18:22 (named self / previous self
            # state by the original code).
            self_state = observations[:, :4]
            prev_self_state = observations[:, 18:22]
            observations = np.concatenate((self_state, prev_self_state),
                                          axis=-1)
        actions = np.array(f['actions'])
        # str() so the substring test also works when `path` is a Path
        # object (a Path does not support the `in` operator).
        if 'ant' in env_id or 'ood' in str(path):
            dones = np.array(f['terminals'])
        else:
            dones = np.array(f['timeouts'])

        if has_goal_ids:
            goal_ids = np.array(f['infos/goal_id'])

        if is_lift_or_stack:
            # These environments use the stored proprioceptive state as the
            # observation instead of `observations`.
            observations = np.array(f['infos/robot0_proprio-state'])

        if image_observation and is_v2:
            assert domain_id >= 0
            print("Start loading image...")
            start = time.time()
            camera_name = "corner3" if domain_id == 0 else "corner"
            images = np.array(f[f'infos/{camera_name}_image'])
            end = time.time()
            print(f"Loading takes {end - start:.2f} seconds.")
        elif image_observation:
            # Two RGB views stacked on the channel axis:
            # channels 0:3 = agentview, 3:6 = sideview.
            images = np.empty((len(observations), 128, 128, 6),
                              dtype=np.float16)
            images[..., :3] = np.array(f['infos/agentview_image'],
                                       dtype=np.float16)
            images[..., 3:] = np.array(f['infos/sideview_image'],
                                       dtype=np.float16)
        else:
            images = np.empty((len(observations), 0))

    # NOTE: calc_pos sets dones[-1] = True in place, so the splits below see
    # the final trajectory as terminated.
    pos = calc_pos(dones, args.gamma)

    obs_trajs = _convert_to_traj_array(arr=observations, dones=dones)
    action_trajs = _convert_to_traj_array(arr=actions, dones=dones)
    pos_trajs = _convert_to_traj_array(arr=pos, dones=dones)
    image_trajs = _convert_to_traj_array(arr=images, dones=dones)

    task_id_manager = get_task_id_manager(env_id=env_id)
    if has_goal_ids:
        task_id_manager.set_goal_id(goal_ids)
        task_id_manager.set_dones(dones)

    n_task_id = task_id_manager.n_task_id + n_additional_tasks

    # Built first and then copied field by field, as before; TrajDataset is
    # defined elsewhere, so any processing it applies is preserved.
    data = TrajDataset(
        obs=obs_trajs,
        actions=action_trajs,
        images=image_trajs,
        n_task_id=n_task_id,
    )

    data = CDILTrajDataset(
        obs=data.obs,
        actions=data.actions,
        pos=pos_trajs,
        images=data.images,
        n_task_id=n_task_id,
    )

    return data, task_id_manager


def read_multi_dataset(
    args: DictConfig,
    domain_info: DictConfig,
    image_observation: Optional[bool] = False,
    goal_id_offset: int = 0,
) -> Tuple[CDILTrajDataset, None]:
    """Load and merge the datasets of every environment in one domain.

    Each tag in ``domain_info.env_tags`` names an entry of ``domain_info``
    providing ``env`` and ``dataset`` (and ``n_goals`` when the file stores
    goal ids).  The trajectories of all environments are concatenated and
    every trajectory gets a task id:

    * files with ``infos/goal_id``: the stored id shifted by the running
      ``goal_id_offset`` (and by ``-args.target_goal_id`` for the target
      environment); steps whose id falls outside
      ``[goal_id_offset, args.n_task_ids)`` are dropped;
    * otherwise: the running ``goal_id_offset`` itself, which then advances
      by one.

    Args:
        args: Config providing ``gamma``, ``n_task_ids`` and
            ``target_goal_id``.
        domain_info: Per-domain config (``env_tags``, ``domain_id``,
            ``target_env`` and one entry per tag).
        image_observation: Whether to load the camera images; otherwise an
            empty ``(N, 0)`` placeholder is used.
        goal_id_offset: Initial task-id offset.

    Returns:
        ``(dataset, None)`` -- no task-id manager is produced here because
        task ids are assigned directly.
    """
    data = defaultdict(list)
    v2_envs = [
        "reach-goal-v2", "reach-color-v2", "reach-color_simple_3-v2",
        "reach-color_simple_2-v2", "window-close_4-v2"
    ]
    for env_tag in domain_info.env_tags:
        env_info = domain_info[env_tag]
        env_id = env_info.env
        path = env_info.dataset
        is_lift_or_stack = 'Lift' in env_id or 'Stack' in env_id
        with h5py.File(path, "r") as f:
            observations = np.array(f['observations'])
            if "v2" in env_id:
                self_state = observations[:, :4]
                # Previous-frame state slice, aligned with `read_dataset`
                # (columns 18:22).  This used to be 17:21, one column off
                # from the sibling reader for the same "v2" layout.
                # NOTE(review): checkpoints trained on the old slice saw
                # different features -- confirm before reusing them.
                prev_self_state = observations[:, 18:22]
                observations = np.concatenate((self_state, prev_self_state),
                                              axis=-1)
            actions = np.array(f['actions'])
            if 'ant' in env_id:
                dones = np.array(f['terminals']).astype(np.bool_)
            else:
                dones = np.array(f['timeouts']).astype(np.bool_)

            if is_lift_or_stack or env_id in v2_envs:
                goal_ids = np.array(f['infos/goal_id'])
            else:
                goal_ids = None

            if is_lift_or_stack:
                observations = np.array(f['infos/robot0_proprio-state'])

            if image_observation:
                camera_name = "corner3" if domain_info.domain_id == 0 else "corner"
                images = np.array(f[f'infos/{camera_name}_image'],
                                  dtype=np.uint8)
            else:
                images = np.empty((len(observations), 0))

        if goal_ids is not None:  # Lift to Stack
            goal_ids += goal_id_offset
            if env_tag == domain_info.target_env:
                goal_ids -= args.target_goal_id
            # Keep only the steps whose shifted goal id is in range.
            mask = ((goal_id_offset <= goal_ids) &
                    (goal_ids < args.n_task_ids))
            goal_ids = goal_ids[mask]
            observations = observations[mask]
            actions = actions[mask]
            images = images[mask]
            dones = dones[mask]

            # One goal id per trajectory: sample it at each trajectory end.
            # Work on a copy because `dones` is still needed unmodified here;
            # calc_pos below applies the same in-place closing itself.
            dones_ = dones.copy()
            dones_[-1] = True
            goal_ids = goal_ids[dones_]
            data["goal_ids"].extend(goal_ids)
            goal_id_offset += env_info.n_goals
        else:  # Meta-World
            goal_id = goal_id_offset
            dones_ = dones.copy()
            dones_[-1] = True
            data["goal_ids"].extend([goal_id] * int(dones_.sum()))
            goal_id_offset += 1

        # NOTE: calc_pos sets dones[-1] = True in place, so the splits below
        # see the final trajectory as terminated.
        pos = calc_pos(dones, args.gamma)

        obs_trajs = _convert_to_traj_array(arr=observations, dones=dones)
        action_trajs = _convert_to_traj_array(arr=actions, dones=dones)
        image_trajs = _convert_to_traj_array(arr=images, dones=dones)
        pos_trajs = _convert_to_traj_array(arr=pos, dones=dones)

        data["obs_trajs"].extend(obs_trajs)
        data["action_trajs"].extend(action_trajs)
        data["image_trajs"].extend(image_trajs)
        data["pos_trajs"].extend(pos_trajs)

    # Trajectory lists become object arrays; the goal-id list becomes a
    # regular array.
    for k, v in data.items():
        if "trajs" in k:
            data[k] = convert_ndarray_list_to_obj_ndarray(v)
        else:
            data[k] = np.array(v)

    dataset = CDILTrajDataset(
        obs=data["obs_trajs"],
        actions=data["action_trajs"],
        images=data["image_trajs"],
        task_ids=data["goal_ids"],
        pos=data["pos_trajs"],
        n_task_id=args.n_task_ids,
    )

    return dataset, None


def read_step_dateset(args, inference=False):
    """Build per-domain train/val ``DataLoader``s of single transitions.

    For each selected domain the dataset is read, task ids are attached,
    single-step trajectories are removed, trajectories are filtered by goal
    id and sub-sampled, the domain-specific converters are applied, and the
    result is split into train/val and flattened into step datasets.

    Domain selection:
        * training: all of ``args.domains``;
        * inference: ``args.domains`` without its last entry;
        * inference with ``args.complex_task``: ``args.adapt_domains``
          (temporarily installed as ``args.domains`` and passed through
          ``read_env_config_yamls``) without its last entry.

    Args:
        args: Experiment config.
        inference: Selects the inference-time variants of the domain list,
            goal ids (``[args.goal]``) and trajectory count
            (``args.adapt_n_traj``).

    Returns:
        ``(dataloader_dict, task_id_managers)`` where ``dataloader_dict`` is
        ``{"train": {domain_id: DataLoader}, "val": {domain_id: DataLoader}}``
        and ``task_id_managers`` has one entry per processed domain (``None``
        entries when ``args.multienv`` is set).
    """
    train_step_datasets = {}
    val_step_datasets = {}
    task_id_managers = []

    swap_domains = bool(args.complex_task and inference)
    if swap_domains:
        original_domains = args.domains
        args.domains = args.adapt_domains

    # try/finally guarantees that the temporary `args.domains` override is
    # undone even when loading one of the domains raises.
    try:
        if swap_domains:
            read_env_config_yamls(args)
            domains = args.domains[:-1]
        else:
            domains = args.domains[:-1] if inference else args.domains

        for domain_info in domains:
            domain_id = domain_info.domain_id
            if args.multienv:
                dataset, task_id_manager = read_multi_dataset(
                    args=args,
                    domain_info=domain_info,
                    image_observation=args.image_observation,
                )
            else:
                dataset, task_id_manager = read_dataset(
                    args=args,
                    path=domain_info.dataset,
                    env_id=domain_info.env,
                    n_additional_tasks=1 if args.complex_task else 0,
                    image_observation=args.image_observation,
                    domain_id=domain_info.domain_id,
                )

            if inference and args.complex_task:
                # All adaptation trajectories get the extra (last) task id.
                task_id_manager.add_task_id_to_traj_dataset(
                    dataset, task_id=dataset.n_task_id - 1)
            elif task_id_manager is not None:
                task_id_manager.add_task_id_to_traj_dataset(dataset)

            dataset = remove_single_step_trajectories(dataset)
            goal_ids = [args.goal] if inference else args.train_goal_ids
            if not (inference and args.complex_task):
                dataset = filter_by_goal_id(dataset,
                                            goal_ids=goal_ids,
                                            task_id_manager=task_id_manager)
            dataset = select_n_trajectories(
                dataset,
                n_traj=args.adapt_n_traj if inference else args.n_traj)

            # Mutates `dataset` in place (converters + domain id).
            domain_specific_transform(args, domain_info, dataset)

            train_traj_dataset, val_traj_dataset = train_val_split(
                dataset, train_ratio=args.train_ratio)

            train_step_datasets[domain_id] = convert_step_dataset(
                args, train_traj_dataset)

            val_step_datasets[domain_id] = convert_step_dataset(
                args, val_traj_dataset)

            task_id_managers.append(task_id_manager)
    finally:
        if swap_domains:
            # restore original domains
            args.domains = original_domains

    train_step_loader_dict = {
        domain_id: DataLoader(
            ds,
            batch_size=args.batch_size,
            num_workers=0,
            shuffle=True,
        )
        for domain_id, ds in train_step_datasets.items()
    }

    val_step_loader_dict = {
        domain_id: DataLoader(
            ds,
            batch_size=args.batch_size,
            num_workers=0,
            shuffle=False,
        )
        for domain_id, ds in val_step_datasets.items()
    }

    dataloader_dict: Dict[str, Dict[int, DataLoader]] = {
        "train": train_step_loader_dict,
        "val": val_step_loader_dict,
    }

    return dataloader_dict, task_id_managers


def domain_specific_transform(args, domain_info, dataset):
    """Apply a domain's optional converters and tag ``dataset`` in place.

    If ``domain_info`` names an ``obs_converter`` and/or an
    ``action_converter``, the matching converter is looked up and applied to
    the dataset.  The domain id (out of ``args.n_domains``) is attached
    afterwards in every case.  Nothing is returned.
    """
    obs_converter_name = domain_info.get("obs_converter")
    if obs_converter_name:
        dataset.apply_obs_converter(
            get_obs_converter(name=obs_converter_name))

    action_converter_name = domain_info.get("action_converter")
    if action_converter_name:
        dataset.apply_action_converter(
            get_action_converter(name=action_converter_name))

    dataset.add_domain_id(
        domain_id=domain_info.domain_id,
        n_domain_id=args.n_domains,
    )


def convert_task_id_wise_step_dataset(args, traj_dataset):
    """Return ``{task_id: CDILTorchStepDataset}`` for a trajectory dataset.

    The trajectory dataset is first split into one step dataset per task id;
    each of those is then wrapped for torch using ``args.max_obs_dim`` and
    ``args.max_action_dim``.
    """
    per_task_steps = convert_traj_dataset_to_task_id_wise_step_dataset(
        traj_dataset)
    return {
        task_id: CDILTorchStepDataset(
            steps,
            obs_dim=args.max_obs_dim,
            action_dim=args.max_action_dim,
        )
        for task_id, steps in per_task_steps.items()
    }


def convert_step_dataset(args, traj_dataset):
    """Flatten ``traj_dataset`` into steps and wrap it for torch.

    ``convert_traj_dataset_to_step_dataset`` resolves at call time to the
    function of that name defined in this module, which rebinds the imported
    one.
    """
    return CDILTorchStepDataset(
        convert_traj_dataset_to_step_dataset(traj_dataset),
        obs_dim=args.max_obs_dim,
        action_dim=args.max_action_dim,
    )


def convert_traj_dataset_to_step_dataset(traj_dataset: CDILTrajDataset):
    """Concatenate all trajectories into a single ``CDILStepDataset``.

    Every per-trajectory array is concatenated along the step axis.  Each
    trajectory's task id is repeated once per step of that trajectory, and
    the dataset-wide ``domain_id`` is broadcast to every step.

    NOTE: this definition replaces the function of the same name imported
    from ``common.utils.process_dataset`` at the top of the module.
    """
    flat_obs = np.concatenate(traj_dataset.obs)
    n_steps = len(flat_obs)
    steps_per_traj = [len(traj) for traj in traj_dataset.obs]

    return CDILStepDataset(
        obs=flat_obs,
        actions=np.concatenate(traj_dataset.actions),
        next_obs=np.concatenate(traj_dataset.next_obs),
        pos=np.concatenate(traj_dataset.pos),
        images=np.concatenate(traj_dataset.images),
        next_images=np.concatenate(traj_dataset.next_images),
        task_ids=np.repeat(traj_dataset.task_ids, steps_per_traj),
        n_task_id=traj_dataset.n_task_id,
        domain_ids=np.ones(n_steps, dtype=int) * traj_dataset.domain_id,
        n_domain_id=traj_dataset.n_domain_id,
    )
