import logging
import os
import random
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from time import time
from typing import Callable, Dict, List, Optional, Tuple

import comet_ml
import gym
import h5py
import imageio.v2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from common.models import Policy
from PIL import Image
from pygifsicle import gifsicle
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split
from tqdm import tqdm

from custom.transformer_modules import TransformerPredictor

# Module-level logger. INFO is set explicitly on this logger so the progress
# messages below pass the level check (output still depends on handlers).
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def visualize_move(env, obs, goals):
    """Replay (observation, goal) pairs in `env`, rendering one frame per pair.

    For each pair the agent is teleported (without noise) to the first two
    entries of the observation, the goal is set and marked, and the env is
    rendered. Iteration stops at the shorter of `obs` and `goals`.
    """
    def _show(observation, target):
        env.reset_to_location(observation[:2], no_noise=True)
        env.set_target(target)
        env.set_marker()
        env.render()

    for observation, target in zip(obs, goals):
        _show(observation, target)


def convert_obs_to_img_obs(env,
                           obs: np.ndarray,
                           resolution: Tuple[int, int] = (64, 64)):
    """Render each observation's (x, y) position as an RGB image.

    `obs` is an [N, D] array; the result stacks one resized frame per row,
    i.e. [N, H, W, C]. `resolution` is passed to PIL's `Image.resize`, which
    interprets it as (width, height).

    TODO: currently only supports maze-medium (camera placement is hard-coded).
    """
    # Render once so the viewer and its camera exist before adjusting them.
    env.render('rgb_array')
    camera = env.env.viewer.cam
    camera.elevation = -90
    camera.distance = 8
    # NOTE(review): these offsets are added to the current look-at point, so
    # they accumulate if this function is called repeatedly on the same env.
    camera.lookat[0] += 0.5
    camera.lookat[1] += 0.5

    def _render_at(observation):
        env.reset_to_location(observation[:2], no_noise=True)
        frame = env.render('rgb_array')
        return np.array(Image.fromarray(frame).resize(size=resolution))

    return np.array([_render_at(observation) for observation in tqdm(obs)])


def save_model(model: torch.nn.Module, path: Path, epoch=None):
    """Save `model.state_dict()` to `<path>/model.pt`, optionally recording the epoch.

    Args:
        model: Module whose parameters/buffers are saved.
        path: Target directory (str or Path); created if missing.
        epoch: If not None, written as text to `<path>/epoch.txt`.
    """
    path = Path(path)  # accept plain strings as well as Path objects
    path.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), path / 'model.pt')
    logger.info(f'The model is saved to {path}')

    # `is not None` (not truthiness) so that epoch 0 is recorded as well.
    if epoch is not None:
        with open(path / 'epoch.txt', 'w') as f:
            f.write(str(epoch))


def read_dataset(path: str,
                 source_trans_fn: Callable,
                 source_action_trans_fn: Callable,
                 shift: bool = False):
    """Load an HDF5 dataset and derive task ids plus source/target domain views.

    Reads `observations`, `actions`, `infos/goal` and `timeouts` from `path`.
    Goals are rounded to integers; each distinct goal gets a task id starting
    at 1 (column 0 of the one-hot encoding is therefore never set here).

    Args:
        path: HDF5 file path.
        source_trans_fn: Maps observations into the source domain; called as
            `source_trans_fn(obs, shift=shift)`.
        source_action_trans_fn: Maps actions into the source domain.
        shift: Forwarded to `source_trans_fn`.

    Returns:
        (task_ids, task_ids_onehot, source_obs, source_actions, target_obs,
        target_actions, dones, goal_to_task_id); observation/action arrays
        are float32 and `dones` is the raw `timeouts` array.
    """
    with h5py.File(path, 'r') as h5:
        raw = {
            key: np.array(h5[key])
            for key in ('observations', 'actions', 'infos/goal', 'timeouts')
        }
    obs = raw['observations']
    actions = raw['actions']
    dones = raw['timeouts']

    int_goals = raw['infos/goal'].round().astype(int)
    unique_goals = np.unique(int_goals, axis=0)
    goal_to_task_id = {
        tuple(goal): task_id
        for task_id, goal in enumerate(unique_goals, start=1)
    }
    logger.info(f'Goal to task ID: {goal_to_task_id}')
    task_ids = np.array([goal_to_task_id[tuple(goal)] for goal in int_goals])
    identity = np.eye(task_ids.max() + 1)
    task_ids_onehot = identity[task_ids].astype(np.float32)

    # The caller-supplied transforms define the source domain (in the original
    # setup: x <-> y and vx <-> vy swapped).
    source_obs = source_trans_fn(obs, shift=shift).astype(np.float32)
    source_actions = source_action_trans_fn(actions).astype(np.float32)

    return (task_ids, task_ids_onehot, source_obs, source_actions,
            obs.astype(np.float32), actions.astype(np.float32), dones,
            goal_to_task_id)


def split_dataset(dataset: torch.utils.data.TensorDataset, train_ratio: float,
                  batch_size: int):
    """Randomly split `dataset` and wrap both parts in DataLoaders.

    The train part gets `int(len(dataset) * train_ratio)` samples and is
    shuffled each epoch; the remainder forms the (unshuffled) validation part.

    Returns:
        (train_loader, val_loader)
    """
    n_total = len(dataset)
    n_train = int(n_total * train_ratio)
    train_part, val_part = random_split(dataset, [n_train, n_total - n_train])
    return (DataLoader(train_part, batch_size=batch_size, shuffle=True),
            DataLoader(val_part, batch_size=batch_size))


@dataclass
class CheckPointer:
    """Save policy weights at the final epoch and whenever validation loss improves.

    The final-epoch checkpoint goes to `<savedir_root>/last_epochNNNN`; the
    best-so-far checkpoint (with its epoch number) goes to `<savedir_root>/best`.
    """
    n_epochs: int  # total epochs; epoch index n_epochs - 1 triggers the last save
    savedir_root: Path  # directory under which checkpoint folders are created
    best_val_loss: float = 1e9  # lowest validation loss observed so far

    def save_if_necessary(self, policy: Policy, val_loss: float, epoch: int):
        if epoch == self.n_epochs - 1:
            last_dir = self.savedir_root / f'last_epoch{self.n_epochs:04d}'
            save_model(model=policy, path=last_dir)

        # Written as `not <` so a NaN loss never counts as an improvement.
        if not val_loss < self.best_val_loss:
            return
        self.best_val_loss = val_loss
        save_model(model=policy, path=self.savedir_root / 'best', epoch=epoch)


def _remove_single_step_trajectories(
    source_obs: np.ndarray,
    target_obs: np.ndarray,
    task_ids_onehot: np.ndarray,
    source_actions: np.ndarray,
    target_actions: np.ndarray,
    dones: np.ndarray,
    task_ids: np.ndarray,
):
    done_idx = np.concatenate(([-1], np.where(dones)[0]))
    traj_lens = done_idx[1:] - done_idx[:-1]
    single_step_traj_ids = np.where(traj_lens == 1)[0]

    valid_steps = np.ones(len(dones), dtype=bool)
    for single_step_traj_id in single_step_traj_ids:
        invalid_step_idx = done_idx[single_step_traj_id + 1]
        valid_steps[invalid_step_idx] = False

    source_obs = source_obs[valid_steps]
    target_obs = target_obs[valid_steps]
    task_ids_onehot = task_ids_onehot[valid_steps]
    source_actions = source_actions[valid_steps]
    target_actions = target_actions[valid_steps]
    dones = dones[valid_steps]
    task_ids = task_ids[valid_steps]

    return source_obs, target_obs, task_ids_onehot, source_actions, target_actions, dones, task_ids


def _select_n_trajecotries(source_obs: np.ndarray, target_obs: np.ndarray,
                           task_ids_onehot: np.ndarray,
                           source_actions: np.ndarray,
                           target_actions: np.ndarray, dones: np.ndarray,
                           task_ids: np.ndarray, n: int):

    done_idx = np.concatenate(([-1], np.where(dones)[0]))
    start_bit = np.concatenate(([False], dones))[:-1]
    traj_ids = np.cumsum(start_bit)

    selected_traj_ids = np.random.choice(range(traj_ids.max() + 1),
                                         n,
                                         replace=False)

    selected_steps = np.zeros(len(traj_ids))
    for selected_id in selected_traj_ids:
        st_idx, gl_idx = done_idx[selected_id] + 1, done_idx[selected_id + 1]
        selected_steps[st_idx] += 1
        if gl_idx + 1 < len(selected_steps):
            selected_steps[gl_idx + 1] -= 1

    selected_steps = np.cumsum(selected_steps).astype(bool)

    source_obs = source_obs[selected_steps]
    target_obs = target_obs[selected_steps]
    task_ids_onehot = task_ids_onehot[selected_steps]
    source_actions = source_actions[selected_steps]
    target_actions = target_actions[selected_steps]
    dones = dones[selected_steps]
    task_ids = task_ids[selected_steps]

    return source_obs, target_obs, task_ids_onehot, source_actions, target_actions, dones, task_ids


def filter_dataset(
    source_obs: np.ndarray,
    target_obs: np.ndarray,
    task_ids_onehot: np.ndarray,
    source_actions: np.ndarray,
    target_actions: np.ndarray,
    dones: np.ndarray,
    task_ids: np.ndarray,
    filter_by_id_fn: Callable,
    n_traj: Optional[int] = None,
):
    """Filter per-step arrays by task, drop 1-step trajectories, optionally subsample.

    Args:
        source_obs, target_obs, task_ids_onehot, source_actions,
        target_actions, dones, task_ids: Aligned per-step arrays; a
            trajectory ends at each True in `dones`.
        filter_by_id_fn: Called with `task_ids`; its result is used as a
            row mask/index to keep steps.
        n_traj: If truthy, exactly this many whole trajectories are randomly
            kept at the end.

    Returns:
        (source_obs, target_obs, task_ids_onehot, source_actions,
        target_actions, dones, task_ids) after filtering. The caller's input
        arrays are not modified.
    """
    # Copy first: forcing the last step to be terminal must not mutate the
    # caller's array.
    dones = np.copy(dones)
    if len(dones):
        # Close the final (possibly truncated) trajectory.
        dones[-1] = True
    logger.info(f'At first, {dones.sum()} trajectories available.')

    select_flag = filter_by_id_fn(task_ids)
    (source_obs, target_obs, task_ids_onehot, source_actions, target_actions,
     dones, task_ids) = (arr[select_flag] for arr in (
         source_obs, target_obs, task_ids_onehot, source_actions,
         target_actions, dones, task_ids))
    task_id_list = np.unique(task_ids)
    logger.info(f'After task ID filtering, {dones.sum()} trajectories remain.')
    logger.info(
        f'{len(task_id_list)} tasks are going to be used. (ID={task_id_list})')

    # remove single_step trajectory
    source_obs, target_obs, task_ids_onehot, source_actions, target_actions, dones, task_ids = _remove_single_step_trajectories(
        source_obs=source_obs,
        target_obs=target_obs,
        task_ids_onehot=task_ids_onehot,
        source_actions=source_actions,
        target_actions=target_actions,
        dones=dones,
        task_ids=task_ids,
    )
    logger.info(
        f'After removing single-step trajectories, {dones.sum()} trajectories remain.'
    )

    # select final trajectories
    if n_traj:
        source_obs, target_obs, task_ids_onehot, source_actions, target_actions, dones, task_ids = _select_n_trajecotries(
            source_obs=source_obs,
            target_obs=target_obs,
            task_ids_onehot=task_ids_onehot,
            source_actions=source_actions,
            target_actions=target_actions,
            dones=dones,
            task_ids=task_ids,
            n=n_traj,
        )

    logger.info(f'Finally, {dones.sum()} trajectories are selected.')
    if n_traj:
        assert dones.sum() == n_traj

    return source_obs, target_obs, task_ids_onehot, source_actions, target_actions, dones, task_ids


def create_savedir_root(phase_tag: str, env_tag) -> Path:
    """Create and return a fresh results directory for one run.

    The directory is `custom/results/<env_tag>/<phase_tag>/<YYYYmmdd-HHMMSS>-<NNNNNN>`
    (relative to the current working directory). The 6-digit suffix is drawn
    from the global NumPy RNG to avoid collisions between runs started in the
    same second.
    """
    run_name = '{}-{:06d}'.format(
        datetime.now().strftime('%Y%m%d-%H%M%S'),
        np.random.randint(low=0, high=int(1e6) - 1),
    )
    savedir_root = Path(f'custom/results/{env_tag}/{phase_tag}') / run_name
    savedir_root.mkdir(parents=True, exist_ok=True)
    return savedir_root


def get_processed_data(dataset: str,
                       task_id_zero: bool,
                       filter_by_id_fn: Callable,
                       trans_into_source_obs: Callable,
                       trans_into_source_action: Callable,
                       n_traj: Optional[int] = None,
                       shift: bool = False):
    """Load, filter and augment a dataset with next-observation arrays.

    The next observation of a step is the following step's observation, except
    at terminal steps (`dones` True), where it is the step's own observation.

    Args:
        dataset: HDF5 dataset path.
        task_id_zero: If True, the task one-hots are replaced by zeros
            (the original comment says this is used for adaptation).
        filter_by_id_fn, n_traj: Forwarded to `filter_dataset`.
        trans_into_source_obs, trans_into_source_action, shift: Forwarded to
            `read_dataset`.

    Returns:
        (source_obs, source_actions, target_obs, target_actions,
        task_ids_onehot, goal_to_task_id, source_next_obs, target_next_obs)
    """
    (task_ids, task_ids_onehot, source_obs, source_actions, target_obs,
     target_actions, dones, goal_to_task_id) = read_dataset(
        path=dataset,
        source_trans_fn=trans_into_source_obs,
        source_action_trans_fn=trans_into_source_action,
        shift=shift,
    )

    (source_obs, target_obs, task_ids_onehot, source_actions, target_actions,
     dones, task_ids) = filter_dataset(
        source_obs,
        target_obs,
        task_ids_onehot,
        source_actions,
        target_actions,
        dones,
        task_ids,
        filter_by_id_fn=filter_by_id_fn,
        n_traj=n_traj,
    )

    def _next_obs(seq):
        # Shift by one step (repeating the last row), then keep the current
        # observation wherever the trajectory ends.
        shifted = np.concatenate((seq[1:], seq[-1:]))
        return np.where(dones[..., None], seq, shifted)

    source_next_obs = _next_obs(source_obs)
    target_next_obs = _next_obs(target_obs)

    if task_id_zero:
        task_ids_onehot = np.zeros_like(task_ids_onehot, dtype=np.float32)
    return (source_obs, source_actions, target_obs, target_actions,
            task_ids_onehot, goal_to_task_id, source_next_obs,
            target_next_obs)


def prepare_dataset(args, filter_by_id_fn: Callable,
                    trans_into_source_obs: Callable,
                    trans_into_source_action: Callable,
                    dataset_concat_fn: Callable, task_id_zero: bool):
    """Build train/val DataLoaders of single-step transitions from `args.dataset`.

    `dataset_concat_fn` merges the source/target arrays (with one-hot domain
    labels source=[1, 0], target=[0, 1]) into five arrays: observations,
    conditions, domains, actions, next observations. An all-ones action mask
    is appended as the sixth tensor of every sample.

    Uses `args.dataset`, `args.n_traj`, `args.shift`, `args.train_ratio`,
    `args.batch_size` and, if present, `args.source_only` / `args.target_only`
    (default False).

    Returns:
        (train_loader, val_loader, goal_to_task_id)
    """
    logger.info('Start creating dataset...')
    (source_obs, source_actions, target_obs, target_actions, task_ids_onehot,
     goal_to_task_id, source_next_obs, target_next_obs) = get_processed_data(
        dataset=args.dataset,
        task_id_zero=task_id_zero,
        filter_by_id_fn=filter_by_id_fn,
        trans_into_source_obs=trans_into_source_obs,
        trans_into_source_action=trans_into_source_action,
        n_traj=args.n_traj,
        shift=args.shift,
    )

    source_domain_id = np.array([[1, 0]], dtype=np.float32)
    target_domain_id = np.array([[0, 1]], dtype=np.float32)
    obs_np, cond_np, domains_np, actions_np, next_obs_np = dataset_concat_fn(
        source_obs=source_obs,
        target_obs=target_obs,
        source_actions=source_actions,
        target_actions=target_actions,
        source_domain_id=source_domain_id,
        target_domain_id=target_domain_id,
        task_ids_onehot=task_ids_onehot,
        source_next_obs=source_next_obs,
        target_next_obs=target_next_obs,
        source_only=getattr(args, 'source_only', False),
        target_only=getattr(args, 'target_only', False),
    )

    # create torch dataset
    tensors = [
        torch.from_numpy(arr)
        for arr in (obs_np, cond_np, domains_np, actions_np, next_obs_np)
    ]
    # All-ones mask: no action entries are masked out for these samples.
    action_masks = torch.ones_like(tensors[3])
    dataset = TensorDataset(*tensors, action_masks)
    train_loader, val_loader = split_dataset(dataset=dataset,
                                             train_ratio=args.train_ratio,
                                             batch_size=args.batch_size)
    logger.info('Dataset has been successfully created.')
    return train_loader, val_loader, goal_to_task_id


def get_success(obs, target, threshold=0.1):
    """Return whether the first two entries of `obs` are within `threshold` of `target`.

    Distance is Euclidean; the boundary (distance == threshold) counts as success.
    """
    planar_position = obs[:2]
    distance = np.linalg.norm(planar_position - target)
    return distance <= threshold


def eval_policy(
    env: gym.Env,
    policy: Policy,
    device,
    source_trans_fn: Callable,
    source_action_type: str = 'normal',
    times=10,
    target_center=(1, 1),
    source_flag=False,
    render_episodes: int = 0,
    video_path: Path = '.',
    experiment: Optional[comet_ml.Experiment] = None,
    shift: bool = False,
):
    """Roll out `policy` in `env` for `times` episodes and measure goal reaching.

    Each episode:
    1. Places the target at `target_center` plus uniform noise in [-0.1, 0.1].
    2. Re-resets while the start is within 0.5 of the target.
    3. Steps the policy until `get_success` (default threshold) or until the
       env reports `done`.

    Args:
        env: Environment exposing `set_target`/`get_target`, `np_random`,
            `model.nq` and `render('rgb_array')`.
        policy: Called as `policy(s=..., c=..., d=...)`; its first output is
            the batched action. `policy.cond_dim` sizes the all-zero
            condition vector.
        device: Torch device for the policy inputs.
        source_trans_fn: Maps an observation into the source domain; used
            only when `source_flag` is True (called with `shift=shift`).
        source_action_type: 'normal' or 'inv'. With 'inv' (and `source_flag`),
            actions are negated before stepping.
        times: Number of episodes.
        target_center: Nominal target position.
        source_flag: If True, feed source-domain observations with the source
            domain id; otherwise raw observations with the target domain id.
        render_episodes: The first this-many episodes are rendered to video.
        video_path: Output path for `save_video` (str or Path).
        experiment: Optional Comet experiment to log the video to.
        shift: Forwarded to `source_trans_fn`.

    Returns:
        (success_rate, mean_steps) over all episodes.
    """
    assert source_action_type in ['normal', 'inv']

    results = []
    steps = []
    images = []

    # Float32 one-hots, consistent with the np.float32 domain ids this module
    # uses when building training datasets.
    source_domain_id = torch.tensor([[1, 0]], dtype=torch.float32, device=device)
    target_domain_id = torch.tensor([[0, 1]], dtype=torch.float32, device=device)
    domain_input = source_domain_id if source_flag else target_domain_id
    # The task condition is all zeros during evaluation.
    task_id = torch.zeros((1, policy.cond_dim),
                          dtype=torch.float32,
                          device=device)

    for t in range(times):
        obs = env.reset()
        target_location = np.array(target_center) + env.np_random.uniform(
            low=-.1, high=.1, size=env.model.nq)
        env.set_target(target_location)

        # Avoid trivially easy episodes that start close to the goal.
        while get_success(obs=obs, target=env.get_target(), threshold=0.5):
            obs = env.reset()

        done = False
        step = 0
        success = False
        while not done:
            step += 1
            if source_flag:
                # The source domain sees the transformed observation.
                state = source_trans_fn(obs, shift=shift)[None]
            else:  # normal evaluation
                state = obs[None]
            state_input = torch.from_numpy(state.astype(np.float32)).to(device)
            with torch.no_grad():
                action, _, _ = policy(
                    s=state_input,
                    c=task_id,
                    d=domain_input,
                )
            action = action.detach().cpu().numpy()[0]

            if source_flag and source_action_type == 'inv':
                action = -action

            obs, _, done, _ = env.step(action)

            if t < render_episodes:
                images.append(env.render('rgb_array'))

            if get_success(obs=obs, target=env.get_target()):
                success = True
                if t < render_episodes:
                    # Hold the final frame briefly in the video.
                    for _ in range(10):
                        images.append(env.render('rgb_array'))
                break

        results.append(success)
        steps.append(step)
        logger.debug(f'Trial {t}: success={success}; steps={step}')

    success_rate = np.array(results).mean()
    steps_mean = np.array(steps).mean()

    if images:
        # The default video_path is a str; save_video needs Path methods.
        save_video(path=Path(video_path),
                   images=images,
                   fps=20,
                   skip_rate=5,
                   experiment=experiment)

    return success_rate, steps_mean


def save_video(path: Path,
               images: List[np.ndarray],
               fps: int = 20,
               skip_rate: int = 5,
               experiment: Optional[comet_ml.Experiment] = None):
    """Write a subsampled MP4 of `images` next to `path` and optionally log it.

    The video is written to `path` with its suffix replaced by `.mp4`. Frames
    are taken as `images[1::skip_rate]`, so the first frame is skipped.

    Args:
        path: Base output path (str or Path). Its parent directory is created
            if missing.
        images: RGB frames.
        fps: Frames per second of the written video.
        skip_rate: Keep every `skip_rate`-th frame.
        experiment: If given, the MP4 is logged as a Comet asset. The asset is
            named after the parent directory. The file stem is used as the
            step when it parses as an integer; otherwise no step is attached.
    """
    path = Path(path)  # accept plain strings as well as Path objects
    os.makedirs(path.parent, exist_ok=True)

    mp4_path = path.with_suffix('.mp4')
    imageio.v2.mimsave(mp4_path, images[1::skip_rate], fps=fps)
    logger.info(f'video is saved to {mp4_path}')
    if experiment:
        # A non-numeric stem previously raised here, after the video had
        # already been written; log without a step instead.
        try:
            step = int(path.stem)
        except ValueError:
            step = None
        experiment.log_asset(mp4_path,
                             file_name=path.parent.name,
                             step=step)


class PairedTrajectoryDataset(Dataset):
    """Pairs a source-domain demonstration with a target-domain trajectory of the same task.

    The constructor splits aligned per-step arrays into per-trajectory
    sequences; a trajectory ends at each True in `dones`. `__getitem__`
    ignores its index. Each call draws a random task, independently samples
    one source and one target trajectory of that task, and zero-pads both to
    `seq_len` (pad masks are True on padded positions).
    """

    def __init__(
        self,
        source_obs: np.ndarray,
        target_obs: np.ndarray,
        source_action: np.ndarray,
        target_action: np.ndarray,
        dones: np.ndarray,
        task_ids: np.ndarray,
        seq_len: int = 400,
        goal_demo: bool = False,
        sa_demo: bool = False,
    ):
        """
        Args:
            source_obs, target_obs, source_action, target_action, task_ids:
                Aligned per-step arrays.
            dones: Per-step terminal flags delimiting trajectories.
            seq_len: Padded sequence length.
            goal_demo: If True, the source demonstration is only its last step.
            sa_demo: If True, source observations are concatenated with source
                actions along the last axis.
        """
        self.seq_len = seq_len
        self.goal_demo = goal_demo
        self.sa_demo = sa_demo
        self.source_domain_id, self.target_domain_id = np.array(
            [[1, 0]], dtype=np.float32), np.array([[0, 1]], dtype=np.float32)

        self.source_obs_seqs = []
        self.source_action_seqs = []
        self.target_obs_seqs = []
        self.target_action_seqs = []
        self.source_task_ids = []
        self.target_task_ids = []

        self.n_traj = dones.sum()
        logger.info(
            f'start constructing PairedTrajectoryDataset ({dones.sum()} trajectories)'
        )
        st = time()

        done_idx = np.concatenate(([-1], np.where(dones)[0]))
        # Consecutive done indices bound one trajectory:
        # steps done_idx[i] + 1 .. done_idx[i + 1] (inclusive). A slice costs
        # O(trajectory length), unlike a full-length boolean mask per trajectory.
        for i, (prev_end, end) in enumerate(zip(done_idx[:-1], done_idx[1:])):
            if i % 1000 == 0:
                logger.info(f'traj {i}')
            steps = slice(prev_end + 1, end + 1)

            self.source_obs_seqs.append(source_obs[steps].copy())
            self.source_action_seqs.append(source_action[steps].copy())
            self.target_obs_seqs.append(target_obs[steps].copy())
            self.target_action_seqs.append(target_action[steps].copy())

            seq_task_ids = task_ids[steps]
            # A trajectory is expected to belong to a single task.
            assert seq_task_ids[0] == seq_task_ids[-1]
            task_id = seq_task_ids[0]
            self.source_task_ids.append(task_id)
            self.target_task_ids.append(task_id)

        self.source_task_ids = np.array(self.source_task_ids)
        self.target_task_ids = np.array(self.target_task_ids)
        self.task_id_list = np.unique(self.source_task_ids)
        gl = time()
        logger.info(f'Finish constructing dataset, took {gl - st:.2f} sec.')

    def get_padded_seq(self, arr):
        """Zero-pad a 2-D `arr` along axis 0 to `seq_len`; also return the pad mask.

        Raises:
            ValueError: if `arr` is longer than `seq_len`.
        """
        pad_size = self.seq_len - arr.shape[0]
        if pad_size < 0:
            raise ValueError(
                f'sequence of length {arr.shape[0]} exceeds seq_len={self.seq_len}')
        pad_mask = np.concatenate(
            (np.zeros(arr.shape[0]), np.ones(pad_size))).astype(bool)
        return np.pad(arr, pad_width=((0, pad_size), (0, 0))), pad_mask

    def sample_seq_of_task_id(self, task_id: int,
                              source: bool) -> Tuple[np.ndarray, np.ndarray]:
        """Sample one (obs_seq, action_seq) trajectory of `task_id`.

        For the source domain, `sa_demo` and `goal_demo` modify the returned
        observation sequence (see `__init__`). Raises ValueError (from
        `np.random.choice`) when no trajectory of `task_id` exists.
        """
        if source:
            source_candidate_idx = np.where(self.source_task_ids == task_id)[0]
            source_selected_idx = np.random.choice(source_candidate_idx)
            source_obs_seq = self.source_obs_seqs[source_selected_idx]
            source_action_seq = self.source_action_seqs[source_selected_idx]

            if self.sa_demo:
                source_obs_seq = np.concatenate(
                    (source_obs_seq, source_action_seq), axis=-1)

            if self.goal_demo:
                return source_obs_seq[-1:], source_action_seq[-1:]
            return source_obs_seq, source_action_seq

        target_candidate_idx = np.where(self.target_task_ids == task_id)[0]
        target_selected_idx = np.random.choice(target_candidate_idx)
        target_obs_seq = self.target_obs_seqs[target_selected_idx]
        target_action_seq = self.target_action_seqs[target_selected_idx]
        return target_obs_seq, target_action_seq

    def __len__(self):
        return int(self.n_traj)

    def __getitem__(self, item):
        # `item` is unused: each sample is a fresh random task/trajectory pair.
        selected_task_id = np.random.choice(self.task_id_list)
        source_obs_seq, _ = self.sample_seq_of_task_id(selected_task_id,
                                                       source=True)
        target_obs_seq, target_action_seq = self.sample_seq_of_task_id(
            selected_task_id, source=False)

        source_obs_seq, source_pad_mask = self.get_padded_seq(source_obs_seq)
        target_obs_seq, target_pad_mask = self.get_padded_seq(target_obs_seq)
        target_action_seq, _ = self.get_padded_seq(target_action_seq)

        return source_obs_seq, target_obs_seq, target_action_seq, source_pad_mask, target_pad_mask


def _train_val_split_trajs(source_obs: np.ndarray, target_obs: np.ndarray,
                           source_actions: np.ndarray,
                           target_actions: np.ndarray, dones: np.ndarray,
                           task_ids: np.ndarray, train_ratio: float):
    done_idx = np.concatenate(([-1], np.where(dones)[0]))
    start_bit = np.concatenate(([False], dones))[:-1]
    traj_ids = np.cumsum(start_bit)

    total_size = traj_ids.max()
    train_size = int(total_size * train_ratio)
    val_size = total_size - train_size
    logger.info(f'Train: {train_size} trajs; Val: {val_size} trajs.')

    train_traj_ids = np.random.choice(range(traj_ids.max() + 1),
                                      train_size,
                                      replace=False)

    selected_train_steps = np.zeros(len(traj_ids))
    for selected_id in train_traj_ids:
        st_idx, gl_idx = done_idx[selected_id] + 1, done_idx[selected_id + 1]
        selected_train_steps[st_idx] += 1
        if gl_idx + 1 < len(selected_train_steps):
            selected_train_steps[gl_idx + 1] -= 1

    selected_train_steps = np.cumsum(selected_train_steps).astype(bool)

    train_source_obs = source_obs[selected_train_steps]
    train_target_obs = target_obs[selected_train_steps]
    train_source_actions = source_actions[selected_train_steps]
    train_target_actions = target_actions[selected_train_steps]
    train_dones = dones[selected_train_steps]
    train_task_ids = task_ids[selected_train_steps]

    val_source_obs = source_obs[~selected_train_steps]
    val_target_obs = target_obs[~selected_train_steps]
    val_source_actions = source_actions[~selected_train_steps]
    val_target_actions = target_actions[~selected_train_steps]
    val_dones = dones[~selected_train_steps]
    val_task_ids = task_ids[~selected_train_steps]

    return (train_source_obs, train_target_obs, train_source_actions, train_target_actions, train_dones, train_task_ids), \
             (val_source_obs, val_target_obs, val_source_actions, val_target_actions, val_dones, val_task_ids)


def prepare_paired_trajectory_dataset(
    args,
    filter_by_id_fn: Callable,
    trans_into_source_obs: Callable,
    trans_into_source_action: Callable,
    goal_demo: bool = False,
    sa_demo: bool = False,
):
    """Build train/val DataLoaders over `PairedTrajectoryDataset`s.

    Loads and filters `args.dataset`, then splits whole trajectories with
    `args.train_ratio`. Each split is wrapped in a `PairedTrajectoryDataset`
    (`args.seq_len`, `goal_demo`, `sa_demo`). Only the train loader shuffles.

    Returns:
        (train_loader, val_loader, goal_to_task_id)
    """
    (task_ids, task_ids_onehot, source_obs, source_actions, target_obs,
     target_actions, dones, goal_to_task_id) = read_dataset(
        path=args.dataset,
        source_trans_fn=trans_into_source_obs,
        source_action_trans_fn=trans_into_source_action,
        shift=args.shift,
    )

    (source_obs, target_obs, task_ids_onehot, source_actions, target_actions,
     dones, task_ids) = filter_dataset(
        source_obs,
        target_obs,
        task_ids_onehot,
        source_actions,
        target_actions,
        dones,
        task_ids,
        filter_by_id_fn=filter_by_id_fn,
        n_traj=args.n_traj,
    )

    train_split, val_split = _train_val_split_trajs(
        source_obs=source_obs,
        target_obs=target_obs,
        source_actions=source_actions,
        target_actions=target_actions,
        dones=dones,
        task_ids=task_ids,
        train_ratio=args.train_ratio,
    )

    def _to_dataset(split):
        # split = (source_obs, target_obs, source_actions, target_actions,
        #          dones, task_ids), matching the constructor's positional order.
        return PairedTrajectoryDataset(
            *split,
            seq_len=args.seq_len,
            goal_demo=goal_demo,
            sa_demo=sa_demo,
        )

    train_dataset = _to_dataset(train_split)
    val_dataset = _to_dataset(val_split)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True)
    val_loader = DataLoader(dataset=val_dataset, batch_size=args.batch_size)

    return train_loader, val_loader, goal_to_task_id


def task_id_to_target_pos(
    goal_to_task_id: Dict,
    task_id: int,
) -> Tuple[int, int]:
    """Invert `goal_to_task_id`: return the goal position mapped to `task_id`.

    If several goals share the id, the last one in iteration order wins.
    An AssertionError is raised when no goal maps to `task_id`.
    """
    matching = [pos for pos, id_ in goal_to_task_id.items() if id_ == task_id]
    target_pos = matching[-1] if matching else None
    assert target_pos is not None
    return target_pos


def eval_cond_transformer_policy(
        env: gym.Env,
        policy: TransformerPredictor,
        device,
        times: int = 10,
        render_episodes: int = 0,
        video_path: Path = '.',
        experiment: Optional[comet_ml.Experiment] = None,
        traj_dataset: Optional[PairedTrajectoryDataset] = None,
        task_ids: List[int] = (),
        goal_to_task_id: Dict = (),
        max_sample_attempts: int = 1000,
):
    """Evaluate a demonstration-conditioned transformer policy.

    For each of `times` episodes:
    1. Pick a random task from `task_ids` and sample a source demonstration
       of it from `traj_dataset`.
    2. Place the target at that task's goal plus uniform noise in [-0.1, 0.1].
    3. Roll out the policy, feeding the full target-observation history with
       a causal (look-ahead) mask.

    Args:
        env: Environment exposing `set_target`/`get_target`, `np_random`,
            `model.nq` and `render('rgb_array')`.
        policy: Called as `policy(source_obs=..., target_obs=...,
            tgt_look_ahead_mask=...)`; the last time step of the output is
            the action.
        device: Torch device for the policy inputs.
        times: Number of episodes.
        render_episodes: The first this-many episodes are rendered to video.
        video_path: Output path for `save_video` (str or Path).
        experiment: Optional Comet experiment to log the video to.
        traj_dataset: Source of demonstrations (required).
        task_ids: Task ids to sample from.
        goal_to_task_id: Goal position -> task id mapping.
        max_sample_attempts: Maximum re-draws per episode when the chosen task
            has no demonstration.

    Returns:
        (success_rate, mean_steps) over all episodes.

    Raises:
        RuntimeError: if no task with a demonstration is found within
            `max_sample_attempts` draws (previously an infinite loop).
    """
    # Same package path as the module-level TransformerPredictor import,
    # which is known to resolve.
    from custom.transformer_modules import generate_square_subsequent_mask
    timeout_len = env._max_episode_steps if hasattr(
        env, '_max_episode_steps') else 1000
    look_ahead_mask = generate_square_subsequent_mask(
        sz=timeout_len).to(device)

    results = []
    steps = []
    images = []

    for t in range(times):
        # Rejection-sample a task that has a source demonstration;
        # sample_seq_of_task_id raises ValueError (np.random.choice on an
        # empty candidate list) when it has none.
        for _ in range(max_sample_attempts):
            try:
                selected_task_id = random.choice(task_ids)
                target_center = task_id_to_target_pos(
                    goal_to_task_id=goal_to_task_id, task_id=selected_task_id)
                source_demo, _ = traj_dataset.sample_seq_of_task_id(
                    task_id=selected_task_id, source=True)
            except ValueError:
                continue
            break
        else:
            raise RuntimeError(
                f'Could not sample a source demonstration for any of '
                f'task_ids={list(task_ids)} in {max_sample_attempts} attempts.')

        # The demonstration is fixed for the episode; convert it once.
        s_obs = torch.from_numpy(source_demo[None]).to(device)
        obs = env.reset()
        target_location = np.array(target_center) + env.np_random.uniform(
            low=-.1, high=.1, size=env.model.nq)
        env.set_target(target_location)

        # Avoid trivially easy episodes that start close to the goal.
        while get_success(obs=obs, target=env.get_target(), threshold=0.5):
            obs = env.reset()

        history = [obs]
        done = False
        step = 0
        success = False
        while not done:
            step += 1
            with torch.no_grad():
                t_obs = torch.from_numpy(
                    np.array(history, dtype=np.float32)[None]).to(device)
                seq_len = len(history)
                out = policy(
                    source_obs=s_obs,
                    target_obs=t_obs,
                    tgt_look_ahead_mask=look_ahead_mask[:seq_len, :seq_len],
                )
                # Action predicted for the most recent time step.
                action = out.detach().cpu().numpy()[0, -1]

            obs, _, done, _ = env.step(action)
            history.append(obs)

            if t < render_episodes:
                images.append(env.render('rgb_array'))

            if get_success(obs=obs, target=env.get_target()):
                success = True
                if t < render_episodes:
                    # Hold the final frame briefly in the video.
                    for _ in range(10):
                        images.append(env.render('rgb_array'))
                break

        results.append(success)
        steps.append(step)
        logger.debug(f'Trial {t}: success={success}; steps={step}')

    success_rate = np.array(results).mean()
    steps_mean = np.array(steps).mean()

    if images:
        # The default video_path is a str; save_video needs Path methods.
        save_video(path=Path(video_path),
                   images=images,
                   fps=20,
                   skip_rate=5,
                   experiment=experiment)

    return success_rate, steps_mean
