from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from time import sleep
from typing import Dict, List, Optional, Union

import numpy as np
import torch
from comet_ml import Experiment
from omegaconf import DictConfig
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm

from common.utils.evaluate import eval_callback
from common.utils.process_dataset import get_obs_converter


@dataclass
class CDILConfig:
    """Run configuration consumed by the training routines in this module.

    The routines below read these values as attributes of ``args``. Five
    fields are declared with ``field(init=False)`` and have no default: they
    are not constructor arguments and do not exist on the instance until the
    calling script assigns them.
    """
    # NOTE(review): presumably the path of an external config file — confirm.
    config: Optional[Path] = None
    # Per-domain settings. ``domains[0]`` is queried for the optional keys
    # 'obs_converter' and 'obs_converter_args' by calc_alignment_score.
    domains: List[Dict] = field(default_factory=list)
    # NOTE(review): presumably the domain settings of the adaptation phase;
    # not read in this part of the file — confirm.
    adapt_domains: List[Dict] = field(default_factory=list)

    n_domains: int = 2
    complex_task: bool = False

    # Dataset-related settings; they are not read in this part of the file.
    # NOTE(review): -1 presumably means "no limit" — confirm with the loader.
    n_tasks: int = -1
    n_traj: int = 1000
    train_ratio: float = 0.9
    batch_size: int = 256
    # Torch device string; every batch tensor is moved with `.to(args.device)`.
    device: str = 'cuda:0'
    adapt_n_traj: int = -1

    name: Optional[str] = None
    robot: bool = False

    # When True, batches must carry "images" (and "next_images" where next
    # observations are used); the encoded image is concatenated to the
    # observation along the last dimension.
    image_observation: bool = False
    image_state_dim: int = 1024
    # When True (together with image_observation), pretraining adds an image
    # reconstruction MSE weighted by image_recon_coef.
    use_image_decoder: bool = False
    image_recon_coef: float = 1.0
    use_coord_conv: bool = False
    # NOTE(review): presumably toggles Comet logging in the launch script.
    comet: bool = False
    # Goal id handed to eval_callback as `goal_ids=[goal]`.
    goal: int = 5
    # Epoch counts of pretrain_position_encoding, main_loop, train_idm and
    # behavioral_cloning, in that order.
    pretrain_epochs: int = 50
    epochs: int = 50
    idm_epochs: int = 50
    bc_epochs: int = 250
    # NOTE(review): presumably the optimizer learning rate; the optimizers are
    # built outside this part of the file.
    lr: float = 1e-4
    # Gates the periodic evaluation inside the behavioral-cloning loop. The
    # evaluation before the first BC epoch runs regardless of this flag.
    evaluate: bool = True
    evaluate_parallel: bool = True  # avoid bug in parallelization of P2A
    amp: bool = False
    pretrained: bool = False  # TODO remove?

    multienv: bool = False
    n_task_ids: Optional[int] = None
    target_goal_id: Optional[int] = None
    task_id_offset_list: Optional[List[int]] = None
    target_task_id_offset: Optional[int] = None

    # Behavioral cloning evaluates every `eval_interval` epochs and after the
    # final epoch.
    eval_interval: int = 50
    n_eval_episodes: int = 100
    n_render_episodes: int = 10

    gamma: float = 0.95
    latent_dim: int = 256
    # Loss weights used by main_loop:
    #   lambda_1 - source and target discriminator losses
    #   lambda_2 - sum of the source, target and latent cycle losses
    #   lambda_3 - cross-domain temporal-position loss
    #   lambda_4 - latent discriminator loss
    #   lambda_5 - both losses (cycle and position) of the adapt pass
    lambda_1: float = 2
    lambda_2: float = 1
    lambda_3: float = 1
    lambda_4: float = 1
    lambda_5: float = 1

    # Set in the script after construction (init=False, no default): reading
    # any of these before it has been assigned raises AttributeError.
    max_obs_dim: int = field(init=False)
    max_action_dim: int = field(init=False)
    max_seq_len: int = field(init=False)
    train_goal_ids: List[int] = field(init=False)
    # Output directory; the converter checkpoints are written here and it is
    # passed to eval_callback as `savedir_root`.
    logdir: Union[str, Path] = field(init=False)

    # Network hyper-parameter dictionaries. The models are built outside this
    # part of the file, so the meaning of each key is defined by that code.
    state_converter: Dict = field(default_factory=lambda: {
        'hid_dim': 256,
        'num_hidden_layers': [2, 2],
        'activation': 'relu',
    })

    discriminator: Dict = field(
        default_factory=lambda: {
            'hid_dim': 256,
            'num_hidden_layers': 2,
            'activation': 'relu',
            'adversarial_coef': 1,
        })

    positional_encoder: Dict = field(default_factory=lambda: {
        'hid_dim': 256,
        'num_hidden_layers': 2,
        'activation': 'relu',
    })

    inverse_dynamics_model: Dict = field(default_factory=lambda: {
        'hid_dim': 256,
        'num_hidden_layers': 2,
        'activation': 'relu',
    })

    policy: Dict = field(default_factory=lambda: {
        'hid_dim': 256,
        'num_hidden_layers': 2,
        'activation': 'relu',
    })


def calc_alignment_score(
    args,
    model_dict,
    data_loader: DataLoader,
    experiment: "Optional[Experiment]",
):
    """Measure how closely the two converters agree in latent space.

    Every batch must consist of target-domain samples only. The raw target
    observations are encoded by the backward converter; the same observations
    (first mapped through the observation converter configured on
    ``args.domains[0]``, if any) are encoded by the forward converter. The
    score is the mean L2 distance between the two latents divided by the
    mean of their averaged L2 norms.

    Args:
        args: Run configuration; reads ``domains[0]`` (optional keys
            ``'obs_converter'`` / ``'obs_converter_args'``) and ``device``.
        model_dict: Must hold ``"forward_converter"`` and
            ``"backward_converter"``. Each is called with a batch of states
            and must return a pair whose second element is the latent.
        data_loader: Iterable of batches providing ``'domain_ids'`` and
            ``'observations'`` tensors.
        experiment: Metrics logger exposing ``log_metrics``, or ``None`` to
            skip logging.

    Returns:
        The score as a ``float``, or ``None`` when the loader produced no
        samples (nothing is logged in that case).

    Raises:
        AssertionError: If a batch contains a sample whose last ``domain_ids``
            component is not above 0.5, i.e. is not flagged as target domain.
    """
    domain_cfg = args.domains[0]
    obs_converter = None
    if obs_converter_name := domain_cfg.get('obs_converter'):
        # A missing or empty argument dict means "use the converter defaults".
        converter_kwargs = domain_cfg.get('obs_converter_args') or {}
        obs_converter = get_obs_converter(name=obs_converter_name,
                                          **converter_kwargs)

    forward_converter = model_dict["forward_converter"]
    backward_converter = model_dict["backward_converter"]

    source_z_list = []
    target_z_list = []
    for batch in data_loader:
        target_flag = (batch['domain_ids'][..., -1:] > .5).flatten()
        # Reduce on the tensor itself so the check also works for batches
        # that do not live on the CPU.
        assert bool(target_flag.all()), \
            "calc_alignment_score expects target-domain samples only"
        target_states = batch['observations'].to(args.device)

        if len(target_states) == 0:
            continue

        source_states = target_states
        if obs_converter is not None:
            # The converter works on NumPy arrays, hence the host round trip.
            source_states = obs_converter(source_states.cpu().numpy())
            source_states = torch.from_numpy(source_states).to(args.device)

        with torch.no_grad():
            _, source_z = forward_converter(source_states)
            _, target_z = backward_converter(target_states)

        source_z_list.append(source_z.cpu().numpy())
        target_z_list.append(target_z.cpu().numpy())

    if not source_z_list:
        # np.concatenate raises on an empty list; report instead of crashing.
        print('Latent score = n/a (no samples to evaluate)')
        return None

    source_z = np.concatenate(source_z_list)
    target_z = np.concatenate(target_z_list)

    dists = np.linalg.norm(source_z - target_z, axis=-1, keepdims=True)

    source_norm = np.linalg.norm(source_z, axis=-1, keepdims=True)
    target_norm = np.linalg.norm(target_z, axis=-1, keepdims=True)
    norm_two = (source_norm + target_norm) / 2
    norm_all = norm_two.mean()
    dist_norm_all = float((dists / norm_all).mean())
    # '.4f' prints four decimals; the previous ':4f' only set a field width.
    print(f'Latent score = {dist_norm_all:.4f}')

    if experiment is not None:
        metrics_dict = {'latent_dist': dist_norm_all}
        experiment.log_metrics(metrics_dict)

    return dist_norm_all


def pretrain_position_encoding(
    args,
    experiment,
    model_dict,
    optimizer_dict,
    dataloader_dict,
):
    """Pretrain the source and target positional encoders.

    Each epoch runs three passes: a training pass and a validation pass over
    the paired ``"align_train"`` / ``"align_val"`` loaders, followed by a
    source-only training pass over ``"adapt_train"``. Every pass regresses
    column 0 of the encoder output onto the batch's ``"pos"`` entry with an
    MSE loss; with image observations and the image decoder enabled, an image
    reconstruction MSE weighted by ``args.image_recon_coef`` is added.

    Args:
        args: Run configuration. Reads ``device``, ``pretrain_epochs``,
            ``image_observation``, ``use_image_decoder`` and
            ``image_recon_coef``.
        experiment: Metrics logger exposing ``log_metrics``; a falsy value
            disables logging.
        model_dict: Must hold ``"source_positional_encoder"`` and
            ``"target_positional_encoder"``. ``"policy"`` is read for the
            image encoders/decoders when image observations are enabled.
        optimizer_dict: Must hold the ``"pretrain"`` optimizer.
        dataloader_dict: Must hold ``"align_train"``, ``"align_val"`` and
            ``"adapt_train"``. Each value is indexed with ``0`` for the
            source loader and (align entries only) ``1`` for the target one.
    """
    source_positional_encoder = model_dict["source_positional_encoder"]
    target_positional_encoder = model_dict["target_positional_encoder"]
    optimizer_pretrain = optimizer_dict["pretrain"]

    if args.image_observation:
        source_image_encoder = model_dict["policy"].source_image_encoder
        target_image_encoder = model_dict["policy"].target_image_encoder

    mse = torch.nn.MSELoss()

    def finalize_epoch(metrics, experiment, prefix, epoch):
        """Average the per-batch metric lists in place and log them."""
        for name, values in metrics.items():
            metrics[name] = np.mean(values)

        if experiment:
            experiment.log_metrics(
                metrics,
                prefix=prefix,
                epoch=epoch,
            )

        return metrics

    def align_single_epoch(
        args: DictConfig,
        loaders,
        epoch: int,
        experiment: Optional[Experiment] = None,
        optimizer: Optional[Optimizer] = None,
        prefix: str = "",
    ):
        """Run one pass over paired source/target batches.

        ``loaders[0]`` is the source loader and ``loaders[1]`` the target
        loader; they are zipped, so the pass ends with the shorter one.
        Parameters are updated only when ``optimizer`` is given.
        """
        training = optimizer is not None
        epoch_metrics_dict = defaultdict(list)
        bar_format = prefix + " {l_bar}{bar:50}{r_bar}"
        source_dataloader = loaders[0]
        target_dataloader = loaders[1]
        total_len = min(len(source_dataloader), len(target_dataloader))
        paired = zip(source_dataloader, target_dataloader)
        # A validation pass never calls backward(), so skip building the
        # autograd graph there; the computed losses are unaffected.
        with torch.set_grad_enabled(training):
            for source_batch, target_batch in tqdm(paired,
                                                   bar_format=bar_format,
                                                   total=total_len):
                source_obs = source_batch["observations"].to(args.device)
                source_pos = source_batch["pos"].to(args.device)
                source_task_ids = source_batch["task_ids"].to(args.device)
                target_obs = target_batch["observations"].to(args.device)
                target_pos = target_batch["pos"].to(args.device)
                target_task_ids = target_batch["task_ids"].to(args.device)

                if args.image_observation:
                    # Append the encoded image to the observation vector.
                    source_img = source_batch["images"].to(args.device)
                    source_img_encoded = source_image_encoder(source_img)
                    source_obs = torch.cat((source_obs, source_img_encoded),
                                           dim=-1)

                    target_img = target_batch["images"].to(args.device)
                    target_img_encoded = target_image_encoder(target_img)
                    target_obs = torch.cat((target_obs, target_img_encoded),
                                           dim=-1)

                ### Temporal Position
                source_pos_hat = source_positional_encoder(
                    source_obs, source_task_ids)
                source_pos_loss = mse(source_pos_hat[:, 0], source_pos)

                target_pos_hat = target_positional_encoder(
                    target_obs, target_task_ids)
                target_pos_loss = mse(target_pos_hat[:, 0], target_pos)

                total_pos_loss = source_pos_loss + target_pos_loss

                # Recorded before the reconstruction term is added, so the
                # logged "total_pos_loss" covers the position losses only.
                epoch_metrics_dict["source_pos_loss"].append(
                    source_pos_loss.item())
                epoch_metrics_dict["target_pos_loss"].append(
                    target_pos_loss.item())
                epoch_metrics_dict["total_pos_loss"].append(
                    total_pos_loss.item())

                if args.image_observation and args.use_image_decoder:
                    source_image_decoder = model_dict[
                        "policy"].source_image_decoder
                    source_images_recon = source_image_decoder(
                        source_img_encoded)
                    source_images_scaled = (source_img.float() /
                                            255.).clamp(0, 1)
                    source_image_recon_loss = mse(source_images_recon,
                                                  source_images_scaled)

                    target_image_decoder = model_dict[
                        "policy"].target_image_decoder
                    target_images_recon = target_image_decoder(
                        target_img_encoded)
                    target_images_scaled = (target_img.float() /
                                            255.).clamp(0, 1)
                    target_image_recon_loss = mse(target_images_recon,
                                                  target_images_scaled)

                    image_recon_loss = (source_image_recon_loss +
                                        target_image_recon_loss)
                    epoch_metrics_dict["image_recon_loss"].append(
                        image_recon_loss.item())
                    total_pos_loss += image_recon_loss * args.image_recon_coef

                if training:
                    optimizer.zero_grad()
                    total_pos_loss.backward()
                    optimizer.step()

        return finalize_epoch(epoch_metrics_dict, experiment, prefix, epoch)

    def adapt_single_epoch(
        args: DictConfig,
        loaders,
        epoch: int,
        experiment: Optional[Experiment] = None,
        optimizer: Optional[Optimizer] = None,
        prefix: str = "",
    ):
        """Run one source-only pass; only ``loaders[0]`` is consumed.

        Parameters are updated only when ``optimizer`` is given.
        """
        training = optimizer is not None
        epoch_metrics_dict = defaultdict(list)
        bar_format = prefix + " {l_bar}{bar:50}{r_bar}"
        source_dataloader = loaders[0]
        # No autograd graph is needed when nothing is optimized.
        with torch.set_grad_enabled(training):
            for batch in tqdm(source_dataloader, bar_format=bar_format):
                source_obs = batch["observations"].to(args.device)
                source_pos = batch["pos"].to(args.device)
                source_task_ids = batch["task_ids"].to(args.device)

                if args.image_observation:
                    source_img = batch["images"].to(args.device)
                    source_img_encoded = source_image_encoder(source_img)
                    source_obs = torch.cat((source_obs, source_img_encoded),
                                           dim=-1)

                ### Temporal Position
                source_pos_hat = source_positional_encoder(
                    source_obs, source_task_ids)
                source_pos_loss = mse(source_pos_hat[:, 0], source_pos)

                total_pos_loss = source_pos_loss

                epoch_metrics_dict["source_pos_loss"].append(
                    source_pos_loss.item())
                epoch_metrics_dict["total_pos_loss"].append(
                    total_pos_loss.item())

                if args.image_observation and args.use_image_decoder:
                    source_image_decoder = model_dict[
                        "policy"].source_image_decoder
                    source_images_recon = source_image_decoder(
                        source_img_encoded)
                    source_images_scaled = (source_img.float() /
                                            255.).clamp(0, 1)
                    source_image_recon_loss = mse(source_images_recon,
                                                  source_images_scaled)

                    image_recon_loss = source_image_recon_loss
                    epoch_metrics_dict["image_recon_loss"].append(
                        image_recon_loss.item())
                    total_pos_loss += image_recon_loss * args.image_recon_coef

                if training:
                    optimizer.zero_grad()
                    total_pos_loss.backward()
                    optimizer.step()

        return finalize_epoch(epoch_metrics_dict, experiment, prefix, epoch)

    ### Run training loop
    print("Start pretraining temporal position models...")
    for epoch in range(1, args.pretrain_epochs + 1):
        train_epoch_metrics_dict = align_single_epoch(
            args,
            dataloader_dict["align_train"],
            epoch=epoch,
            experiment=experiment,
            optimizer=optimizer_pretrain,
            prefix="pretrain_align_train",
        )

        val_epoch_metrics_dict = align_single_epoch(
            args,
            dataloader_dict["align_val"],
            epoch=epoch,
            experiment=experiment,
            optimizer=None,
            prefix="pretrain_align_val",
        )

        # The adapt pass is logged through `experiment` only; its metrics are
        # not part of the console summary below.
        _ = adapt_single_epoch(
            args,
            dataloader_dict["adapt_train"],
            epoch=epoch,
            experiment=experiment,
            optimizer=optimizer_pretrain,
            prefix="pretrain_adapt_train",
        )

        s = f"Epoch: {epoch}/{args.pretrain_epochs} | "
        s += "Position Loss | "
        s += f"Train : {train_epoch_metrics_dict['total_pos_loss']:.4f} | "
        s += f"Val : {val_epoch_metrics_dict['total_pos_loss']:.4f} | "
        print(s)


def main_loop(
    args,
    experiment,
    model_dict,
    optimizer_dict,
    dataloader_dict,
):
    """Jointly train the converters, discriminators and positional encoders.

    Each epoch runs four passes: align-train, align-val, adapt-train and
    adapt-val. The align passes consume paired source/target batches and sum
    cycle-consistency, discriminator and temporal-position losses; the adapt
    passes consume source batches only and sum a source cycle loss and a
    latent temporal-position loss.

    Args:
        args: Run configuration. Reads ``device``, ``epochs``,
            ``image_observation`` and ``lambda_1`` .. ``lambda_5``.
        experiment: Metrics logger exposing ``log_metrics``; a falsy value
            disables logging.
        model_dict: Must hold ``"forward_converter"``,
            ``"backward_converter"``, ``"source_discriminator"``,
            ``"target_discriminator"``, ``"latent_discriminator"``,
            ``"source_positional_encoder"``, ``"target_positional_encoder"``
            and ``"latent_positional_encoder"``. ``"policy"`` is read for the
            image encoders when image observations are enabled.
        optimizer_dict: Must hold the ``"all"`` optimizer.
        dataloader_dict: Must hold ``"align_train"``, ``"align_val"``,
            ``"adapt_train"`` and ``"adapt_val"``. Each value is indexed with
            ``0`` for the source loader and (align entries only) ``1`` for
            the target one.
    """
    forward_converter = model_dict["forward_converter"]
    backward_converter = model_dict["backward_converter"]
    source_discriminator = model_dict["source_discriminator"]
    target_discriminator = model_dict["target_discriminator"]
    latent_discriminator = model_dict["latent_discriminator"]
    source_positional_encoder = model_dict["source_positional_encoder"]
    target_positional_encoder = model_dict["target_positional_encoder"]
    latent_positional_encoder = model_dict["latent_positional_encoder"]
    optimizer_all: torch.optim.Optimizer = optimizer_dict["all"]

    if args.image_observation:
        source_image_encoder = model_dict["policy"].source_image_encoder
        target_image_encoder = model_dict["policy"].target_image_encoder

    mse = torch.nn.MSELoss()

    def finalize_epoch(metrics, experiment, prefix, epoch):
        """Average the per-batch metric lists in place and log them."""
        for name, values in metrics.items():
            metrics[name] = np.mean(values)

        if experiment:
            experiment.log_metrics(
                metrics,
                prefix=prefix,
                epoch=epoch,
            )

        return metrics

    def align_single_epoch(
        args: DictConfig,
        loaders,
        epoch: int,
        experiment: Optional[Experiment] = None,
        optimizer: Optional[Optimizer] = None,
        prefix: str = "",
    ):
        """Run one pass over paired source/target batches.

        ``loaders[0]`` is the source loader and ``loaders[1]`` the target
        loader; they are zipped, so the pass ends with the shorter one.
        Parameters are updated only when ``optimizer`` is given.
        """
        training = optimizer is not None
        epoch_metrics_dict = defaultdict(list)
        bar_format = prefix + " {l_bar}{bar:50}{r_bar}"
        source_dataloader = loaders[0]
        target_dataloader = loaders[1]
        total_len = min(len(source_dataloader), len(target_dataloader))
        paired = zip(source_dataloader, target_dataloader)
        # A validation pass never calls backward(), so skip building the
        # autograd graph there; the computed losses are unaffected.
        with torch.set_grad_enabled(training):
            for source_batch, target_batch in tqdm(paired,
                                                   bar_format=bar_format,
                                                   total=total_len):
                source_obs = source_batch["observations"].to(args.device)
                source_next_obs = source_batch["next_observations"].to(
                    args.device)
                source_domain_ids = source_batch["domain_ids"].to(args.device)
                source_task_ids = source_batch["task_ids"].to(args.device)

                target_obs = target_batch["observations"].to(args.device)
                target_next_obs = target_batch["next_observations"].to(
                    args.device)
                target_domain_ids = target_batch["domain_ids"].to(args.device)
                target_task_ids = target_batch["task_ids"].to(args.device)

                if args.image_observation:
                    # Append the encoded images to the observation vectors.
                    source_img = source_batch["images"].to(args.device)
                    source_next_img = source_batch["next_images"].to(
                        args.device)
                    source_img_encoded = source_image_encoder(source_img)
                    source_next_img_encoded = source_image_encoder(
                        source_next_img)
                    source_obs = torch.cat((source_obs, source_img_encoded),
                                           dim=-1)
                    source_next_obs = torch.cat(
                        (source_next_obs, source_next_img_encoded), dim=-1)

                    target_img = target_batch["images"].to(args.device)
                    target_next_img = target_batch["next_images"].to(
                        args.device)
                    target_img_encoded = target_image_encoder(target_img)
                    target_next_img_encoded = target_image_encoder(
                        target_next_img)
                    target_obs = torch.cat((target_obs, target_img_encoded),
                                           dim=-1)
                    target_next_obs = torch.cat(
                        (target_next_obs, target_next_img_encoded), dim=-1)

                # Labels/conditions for the discriminators: source-origin
                # rows first, target-origin rows second.
                domain_ids = torch.cat((source_domain_ids, target_domain_ids),
                                       dim=0)
                task_ids = torch.cat((source_task_ids, target_task_ids), dim=0)

                # *_hat: translated once into the other domain;
                # *_hat_hat: translated there and back again.
                target_obs_hat, source_z = forward_converter(source_obs)
                target_next_obs_hat, source_next_z = forward_converter(
                    source_next_obs)
                source_obs_hat_hat, source_z_hat = backward_converter(
                    target_obs_hat)
                source_obs_hat, target_z = backward_converter(target_obs)
                source_next_obs_hat, target_next_z = backward_converter(
                    target_next_obs)
                target_obs_hat_hat, target_z_hat = forward_converter(
                    source_obs_hat)

                ### Cycle Consistency Loss
                source_cycle_loss = mse(source_obs, source_obs_hat_hat)
                target_cycle_loss = mse(target_obs, target_obs_hat_hat)
                latent_cycle_loss = mse(source_z, source_z_hat) + mse(
                    target_z, target_z_hat)
                total_loss = (source_cycle_loss + target_cycle_loss +
                              latent_cycle_loss) * args.lambda_2

                epoch_metrics_dict["source_cycle_loss"].append(
                    source_cycle_loss.item())
                epoch_metrics_dict["target_cycle_loss"].append(
                    target_cycle_loss.item())
                epoch_metrics_dict["latent_cycle_loss"].append(
                    latent_cycle_loss.item())

                ### Source Discriminator
                # Real source transitions first, then transitions translated
                # from the target batch, matching the order of domain_ids.
                source_input = torch.cat((source_obs, source_next_obs), dim=-1)
                source_input_hat = torch.cat(
                    (source_obs_hat, source_next_obs_hat), dim=-1)
                source_input = torch.cat((source_input, source_input_hat),
                                         dim=0)
                source_logit = source_discriminator(source_input, c=task_ids)
                source_disc_loss = mse(source_logit, domain_ids)
                total_loss += source_disc_loss * args.lambda_1

                epoch_metrics_dict["source_disc_loss"].append(
                    source_disc_loss.item())

                ### Target Discriminator
                # Transitions translated from the source batch come first
                # here so that the row order again matches domain_ids.
                target_input = torch.cat((target_obs, target_next_obs), dim=-1)
                target_input_hat = torch.cat(
                    (target_obs_hat, target_next_obs_hat), dim=-1)
                target_input = torch.cat((target_input_hat, target_input),
                                         dim=0)
                target_logit = target_discriminator(target_input, c=task_ids)
                target_disc_loss = mse(target_logit, domain_ids)
                total_loss += target_disc_loss * args.lambda_1

                epoch_metrics_dict["target_disc_loss"].append(
                    target_disc_loss.item())

                ### Latent Discriminator
                source_latent_input = torch.cat((source_z, source_next_z),
                                                dim=-1)
                target_latent_input = torch.cat((target_z, target_next_z),
                                                dim=-1)
                latent_input = torch.cat(
                    (source_latent_input, target_latent_input), dim=0)
                latent_logit = latent_discriminator(latent_input, c=task_ids)
                latent_disc_loss = mse(latent_logit, domain_ids)
                total_loss += latent_disc_loss * args.lambda_4

                epoch_metrics_dict["latent_disc_loss"].append(
                    latent_disc_loss.item())

                ### Temporal Position
                source_pos = source_positional_encoder(source_obs,
                                                       source_task_ids)
                target_pos = target_positional_encoder(target_obs,
                                                       target_task_ids)

                # A translated observation keeps the task ids of the batch it
                # originated from.
                source_pos_hat = source_positional_encoder(
                    source_obs_hat, target_task_ids)
                target_pos_hat = target_positional_encoder(
                    target_obs_hat, source_task_ids)

                source_pos_loss = mse(source_pos, target_pos_hat)
                target_pos_loss = mse(target_pos, source_pos_hat)
                pos_loss = source_pos_loss + target_pos_loss
                total_loss += pos_loss * args.lambda_3

                epoch_metrics_dict["source_pos_loss"].append(
                    source_pos_loss.item())
                epoch_metrics_dict["target_pos_loss"].append(
                    target_pos_loss.item())
                epoch_metrics_dict["total_main_loss"].append(
                    total_loss.item())

                if training:
                    optimizer.zero_grad()
                    total_loss.backward()
                    optimizer.step()

        return finalize_epoch(epoch_metrics_dict, experiment, prefix, epoch)

    def adapt_single_epoch(
        args: DictConfig,
        loaders,
        epoch: int,
        experiment: Optional[Experiment] = None,
        optimizer: Optional[Optimizer] = None,
        prefix: str = "",
    ):
        """Run one source-only pass; only ``loaders[0]`` is consumed.

        Parameters are updated only when ``optimizer`` is given.
        """
        training = optimizer is not None
        epoch_metrics_dict = defaultdict(list)
        bar_format = prefix + " {l_bar}{bar:50}{r_bar}"
        source_dataloader = loaders[0]
        # No autograd graph is needed when nothing is optimized.
        with torch.set_grad_enabled(training):
            for batch in tqdm(source_dataloader, bar_format=bar_format):
                source_obs = batch["observations"].to(args.device)
                source_task_ids = batch["task_ids"].to(args.device)

                if args.image_observation:
                    source_img = batch["images"].to(args.device)
                    source_img_encoded = source_image_encoder(source_img)
                    source_obs = torch.cat((source_obs, source_img_encoded),
                                           dim=-1)

                target_obs_hat, source_z = forward_converter(source_obs)
                # Only the reconstructed observation is used; the latent of
                # the backward pass is discarded.
                source_obs_hat_hat, _ = backward_converter(target_obs_hat)

                ### Cycle Consistency Loss
                cycle_loss = mse(source_obs, source_obs_hat_hat)
                total_loss = cycle_loss * args.lambda_5

                epoch_metrics_dict["cycle_loss"].append(cycle_loss.item())

                ### Temporal Position
                source_pos = source_positional_encoder(source_obs,
                                                       source_task_ids)
                latent_pos = latent_positional_encoder(source_z)

                pos_loss = mse(source_pos, latent_pos)
                total_loss += pos_loss * args.lambda_5

                epoch_metrics_dict["pos_loss"].append(pos_loss.item())
                epoch_metrics_dict["total_main_loss"].append(
                    total_loss.item())

                if training:
                    optimizer.zero_grad()
                    total_loss.backward()
                    optimizer.step()

        return finalize_epoch(epoch_metrics_dict, experiment, prefix, epoch)

    ### Run training loop
    print("Start main loop...")
    for epoch in range(1, args.epochs + 1):
        align_train_epoch_metrics_dict = align_single_epoch(
            args,
            dataloader_dict["align_train"],
            epoch=epoch,
            experiment=experiment,
            optimizer=optimizer_all,
            prefix="align_train",
        )

        align_val_epoch_metrics_dict = align_single_epoch(
            args,
            dataloader_dict["align_val"],
            epoch=epoch,
            experiment=experiment,
            optimizer=None,
            prefix="align_val",
        )

        adapt_train_epoch_metrics_dict = adapt_single_epoch(
            args,
            dataloader_dict["adapt_train"],
            epoch=epoch,
            experiment=experiment,
            optimizer=optimizer_all,
            prefix="adapt_train",
        )

        adapt_val_epoch_metrics_dict = adapt_single_epoch(
            args,
            dataloader_dict["adapt_val"],
            epoch=epoch,
            experiment=experiment,
            optimizer=None,
            prefix="adapt_val",
        )

        s = f"Epoch: {epoch}/{args.epochs} | "
        s += "Total Loss | "
        s += "Align | "
        s += f"Train : {align_train_epoch_metrics_dict['total_main_loss']:.4f} | "
        s += f"Val : {align_val_epoch_metrics_dict['total_main_loss']:.4f} | "
        s += "Adapt | "
        s += f"Train : {adapt_train_epoch_metrics_dict['total_main_loss']:.4f} | "
        s += f"Val : {adapt_val_epoch_metrics_dict['total_main_loss']:.4f} | "
        print(s)
        print(dict(align_train_epoch_metrics_dict))
        print(dict(adapt_train_epoch_metrics_dict))


def train_idm(
    args,
    experiment,
    model_dict,
    optimizer_dict,
    dataloader_dict,
):
    """Train the inverse dynamics model on target-domain transitions.

    The model receives an observation and the next observation and is fit,
    with an MSE loss, to the action stored in the batch. Each epoch runs a
    training pass over ``"align_train"`` and a validation pass over
    ``"align_val"``.

    Args:
        args: Run configuration. Reads ``device``, ``idm_epochs`` and
            ``image_observation``.
        experiment: Metrics logger exposing ``log_metrics``; a falsy value
            disables logging.
        model_dict: Must hold ``"inverse_dynamics_model"``. ``"policy"`` is
            read for the target image encoder when image observations are
            enabled.
        optimizer_dict: Must hold the ``"inverse_dynamics_model"`` optimizer.
        dataloader_dict: Must hold ``"align_train"`` and ``"align_val"``;
            index ``1`` of each value is used as the target-domain loader.
    """
    inverse_dynamics_model = model_dict["inverse_dynamics_model"]
    optimizer_idm: torch.optim.Optimizer = optimizer_dict[
        "inverse_dynamics_model"]

    if args.image_observation:
        target_image_encoder = model_dict["policy"].target_image_encoder

    mse = torch.nn.MSELoss()

    def single_epoch(
        args: DictConfig,
        loaders,
        epoch: int,
        experiment: Optional[Experiment] = None,
        optimizer: Optional[Optimizer] = None,
        prefix: str = "",
    ):
        """Run one pass over the target loader (``loaders[1]``).

        Parameters are updated only when ``optimizer`` is given.
        """
        training = optimizer is not None
        epoch_metrics_dict = defaultdict(list)
        bar_format = prefix + " {l_bar}{bar:50}{r_bar}"
        target_dataloader = loaders[1]
        # A validation pass never calls backward(), so skip building the
        # autograd graph there; the computed loss is unaffected.
        with torch.set_grad_enabled(training):
            for batch in tqdm(target_dataloader, bar_format=bar_format):
                target_obs = batch["observations"].to(args.device)
                target_next_obs = batch["next_observations"].to(args.device)
                target_actions = batch["actions"].to(args.device)

                if args.image_observation:
                    # Append the encoded images to the observation vectors.
                    target_img = batch["images"].to(args.device)
                    target_next_img = batch["next_images"].to(args.device)
                    target_img_encoded = target_image_encoder(target_img)
                    target_next_img_encoded = target_image_encoder(
                        target_next_img)
                    target_obs = torch.cat((target_obs, target_img_encoded),
                                           dim=-1)
                    target_next_obs = torch.cat(
                        (target_next_obs, target_next_img_encoded), dim=-1)

                target_actions_pred = inverse_dynamics_model(
                    target_obs, target_next_obs)

                idm_loss = mse(target_actions_pred, target_actions)

                epoch_metrics_dict["idm_loss"].append(idm_loss.item())

                if training:
                    optimizer.zero_grad()
                    idm_loss.backward()
                    optimizer.step()

        # Collapse the per-batch lists into their means, in place.
        for name, values in epoch_metrics_dict.items():
            epoch_metrics_dict[name] = np.mean(values)

        if experiment:
            experiment.log_metrics(
                epoch_metrics_dict,
                prefix=prefix,
                epoch=epoch,
            )

        return epoch_metrics_dict

    ### Run training loop
    print("Start training inverse dynamics model...")
    for epoch in range(1, args.idm_epochs + 1):
        train_epoch_metrics_dict = single_epoch(
            args,
            dataloader_dict["align_train"],
            epoch=epoch,
            experiment=experiment,
            optimizer=optimizer_idm,
            prefix="train",
        )

        val_epoch_metrics_dict = single_epoch(
            args,
            dataloader_dict["align_val"],
            epoch=epoch,
            experiment=experiment,
            optimizer=None,
            prefix="val",
        )

        s = f"Epoch: {epoch}/{args.idm_epochs} | "
        s += "IDM Loss | "
        s += f"Train : {train_epoch_metrics_dict['idm_loss']:.4f} | "
        s += f"Val : {val_epoch_metrics_dict['idm_loss']:.4f} | "
        print(s)
        print(dict(train_epoch_metrics_dict))


def behavioral_cloning(
    args,
    experiment,
    model_dict,
    optimizer_dict,
    dataloader_dict,
    task_id_manager,
):
    """Train the target-domain policy by behavioral cloning.

    Source observations are translated with the forward converter; the
    inverse dynamics model labels each translated transition with an action,
    and the policy is fit to those labels with an MSE loss. The policy is
    evaluated once before training, then every ``args.eval_interval`` epochs
    and after the final epoch when ``args.evaluate`` is set.

    Args:
        args: Run configuration. Reads ``device``, ``bc_epochs``,
            ``image_observation``, ``evaluate``, ``eval_interval``, ``goal``
            and ``logdir``; it is also forwarded to ``eval_callback``.
        experiment: Metrics logger exposing ``log_metrics``; a falsy value
            disables logging.
        model_dict: Must hold ``"forward_converter"``,
            ``"inverse_dynamics_model"`` and ``"policy"``.
        optimizer_dict: Must hold the ``"bc"`` optimizer.
        dataloader_dict: Must hold ``"adapt_train"`` and ``"adapt_val"``;
            index ``0`` of each value is used as the source-domain loader.
        task_id_manager: Forwarded unchanged to ``eval_callback``.
    """
    forward_converter = model_dict["forward_converter"]
    inverse_dynamics_model = model_dict["inverse_dynamics_model"]
    policy = model_dict["policy"]
    optimizer_bc: torch.optim.Optimizer = optimizer_dict["bc"]

    if args.image_observation:
        source_image_encoder = policy.source_image_encoder

    mse = torch.nn.MSELoss()

    def single_epoch(
        args: DictConfig,
        loaders,
        epoch: int,
        experiment: Optional[Experiment] = None,
        optimizer: Optional[Optimizer] = None,
        prefix: str = "",
    ):
        """Run one pass over the source loader (``loaders[0]``).

        Parameters are updated only when ``optimizer`` is given.
        """
        training = optimizer is not None
        epoch_metrics_dict = defaultdict(list)
        bar_format = prefix + " {l_bar}{bar:50}{r_bar}"
        source_dataloader = loaders[0]
        # A validation pass never calls backward(), so skip building the
        # autograd graph there; the computed loss is unaffected.
        with torch.set_grad_enabled(training):
            for batch in tqdm(source_dataloader, bar_format=bar_format):
                source_obs = batch["observations"].to(args.device)
                source_next_obs = batch["next_observations"].to(args.device)

                if args.image_observation:
                    # Append the encoded images to the observation vectors.
                    source_img = batch["images"].to(args.device)
                    source_next_img = batch["next_images"].to(args.device)
                    source_img_encoded = source_image_encoder(source_img)
                    source_next_img_encoded = source_image_encoder(
                        source_next_img)
                    source_obs = torch.cat((source_obs, source_img_encoded),
                                           dim=-1)
                    source_next_obs = torch.cat(
                        (source_next_obs, source_next_img_encoded), dim=-1)

                target_obs_hat, _ = forward_converter(source_obs)
                target_next_obs_hat, _ = forward_converter(source_next_obs)
                # NOTE(review): the pseudo-labels are not detached, so
                # backward() also reaches the converter and the inverse
                # dynamics model; whether they are updated depends on which
                # parameters optimizer_dict["bc"] holds - confirm intended.
                target_actions_hat = inverse_dynamics_model(
                    target_obs_hat, target_next_obs_hat)
                target_actions_pred = policy(target_obs_hat)

                bc_loss = mse(target_actions_pred, target_actions_hat)

                epoch_metrics_dict["bc_loss"].append(bc_loss.item())

                if training:
                    optimizer.zero_grad()
                    bc_loss.backward()
                    optimizer.step()

        # Collapse the per-batch lists into their means, in place.
        for name, values in epoch_metrics_dict.items():
            epoch_metrics_dict[name] = np.mean(values)

        if experiment:
            experiment.log_metrics(
                epoch_metrics_dict,
                prefix=prefix,
                epoch=epoch,
            )
            # NOTE(review): one-second pause after every logging call,
            # presumably to throttle the metrics backend - confirm.
            sleep(1)

        return epoch_metrics_dict

    # Evaluation of the untrained policy (epoch 0). It runs even when
    # args.evaluate is False.
    eval_callback(
        args=args,
        policy=policy,
        epoch=0,
        goal_ids=[args.goal],
        savedir_root=Path(args.logdir),
        task_id_manager=task_id_manager,
        experiment=experiment,
        log_prefix="adapt",
        skip_source=True,
        target_env=True,
    )

    ### Run training loop
    print("Start behavioral cloning...")
    for epoch in range(1, args.bc_epochs + 1):
        train_epoch_metrics_dict = single_epoch(
            args,
            dataloader_dict["adapt_train"],
            epoch=epoch,
            experiment=experiment,
            optimizer=optimizer_bc,
            prefix="train",
        )

        val_epoch_metrics_dict = single_epoch(
            args,
            dataloader_dict["adapt_val"],
            epoch=epoch,
            experiment=experiment,
            optimizer=None,
            prefix="val",
        )

        # Evaluate on the configured interval and always after the last epoch.
        if args.evaluate and (epoch % args.eval_interval == 0
                              or epoch == args.bc_epochs):
            eval_callback(
                args=args,
                policy=policy,
                epoch=epoch,
                goal_ids=[args.goal],
                savedir_root=Path(args.logdir),
                task_id_manager=task_id_manager,
                experiment=experiment,
                log_prefix="adapt",
                skip_source=True,
                target_env=True,
            )

        s = f"Epoch: {epoch}/{args.bc_epochs} | "
        s += "BC Loss | "
        s += f"Train : {train_epoch_metrics_dict['bc_loss']:.6f} | "
        s += f"Val : {val_epoch_metrics_dict['bc_loss']:.6f} | "
        print(s)
        print(dict(train_epoch_metrics_dict))


def save_translation_models(args, model_dict):
    """Write the state dicts of both converters into ``args.logdir``.

    The forward converter goes to ``forward_converter.pt`` and the backward
    converter to ``backward_converter.pt``.
    """
    converter_names = ("forward_converter", "backward_converter")
    # Resolve both models up front so a missing key fails before any file
    # is written.
    converters = {name: model_dict[name] for name in converter_names}

    for name in converter_names:
        destination = Path(args.logdir) / (name + ".pt")
        torch.save(converters[name].state_dict(), destination)
