from comet_ml import Experiment

# isort: split

from typing import Dict, Optional

import torch
from torch.utils.data import ConcatDataset, DataLoader

from common.dail.dail import train_source_policy
from common.dail.models import DAILAgent
from common.ours.utils import create_savedir_root
from common.utils.evaluate import eval_callback
from common.utils.process_dataset import *


def adapt(
    args: DictConfig,
    experiment: Optional[Experiment],
    agent: DAILAgent,
    dataloader_dict: Dict[str, DataLoader],
):
    """Run adaptation-phase training by delegating to ``train_source_policy``.

    Args:
        args: Run configuration. ``logdir`` and ``num_epoch_adapt`` are read
            here; the same object is forwarded as both ``args`` and
            ``omega_args``.
        experiment: Optional Comet experiment, forwarded unchanged.
        agent: Agent forwarded to ``train_source_policy``.
        dataloader_dict: Mapping of split name to loader; only the
            ``"train"`` entry is used here.
    """
    # Checkpoint location handed to the trainer: <logdir>/model/020.pt
    checkpoint_path = Path(args.logdir) / "model" / "020.pt"

    # Source-domain policy training (described as BC by the original author).
    n_epochs = args.num_epoch_adapt
    source_loader = dataloader_dict["train"]
    train_source_policy(
        args=args,
        omega_args=args,
        agent=agent,
        epochs=n_epochs,
        train_loader=source_loader,
        model_path=checkpoint_path,
        experiment=experiment,
    )


def adapt_main(args, experiment: Optional[Experiment], agent: DAILAgent):
    """Build the adaptation-phase data loaders and run ``adapt``.

    For every entry of ``args.domains`` that is not flagged as a target
    domain, the trajectory dataset is read, filtered, converted, tagged with
    its domain ID, split into train/validation parts and flattened into
    per-step datasets. The per-domain step datasets are concatenated into one
    train and one validation ``DataLoader``, which are passed to ``adapt``.

    Args:
        args: Run configuration. It is mutated in place: ``eval_interval``
            and ``n_eval_episodes`` are overwritten with their ``adapt_``
            counterparts, and when ``complex_task`` is set ``domains`` is
            replaced by ``adapt_domains``.
        experiment: Optional Comet experiment. When truthy, the save
            directory is logged under ``adapt_save_dir``.
        agent: Agent forwarded to ``adapt``.

    Raises:
        ValueError: If no domain contributes any data (for example when
            every configured domain is marked as a target domain).
    """
    logger.info('Adaptation phase started.')
    args.eval_interval = args.adapt_eval_interval
    args.n_eval_episodes = args.adapt_n_eval_episodes

    savedir_root = create_savedir_root(phase_tag='adapt_dail')
    if experiment:
        experiment.log_parameter('adapt_save_dir', str(savedir_root))

    train_step_datasets = []
    val_step_datasets = []

    if args.complex_task:
        args.domains = args.adapt_domains
        read_env_config_yamls(args=args)

    for domain_info in args.domains:
        if args.multienv:
            dataset, task_id_manager = read_multi_dataset(
                domain_info=domain_info,
                image_observation=args.image_observation,
                args=args,
            )
        else:
            dataset, task_id_manager = read_dataset(
                path=domain_info.dataset,
                env_id=domain_info.env,
                n_additional_tasks=1 if args.complex_task else 0,
                image_observation=args.image_observation,
                domain_id=domain_info.domain_id,
                args=args,
            )
        if task_id_manager is not None:
            task_id_manager.add_task_id_to_traj_dataset(dataset)
        if domain_info.get('target'):
            logger.info(
                f'domain ID {domain_info.domain_id} ({domain_info.env_tag}) is a target domain. '
                f'Data is not included in the adaptation phase.')
            continue

        # NOTE(review): for non-target domains this is a second call to
        # add_task_id_to_traj_dataset on the same dataset (the first one is
        # just above). Kept as-is; confirm whether the repeat is intended.
        if args.complex_task:
            task_id_manager.add_task_id_to_traj_dataset(
                dataset, task_id=dataset.n_task_id - 1)
        elif task_id_manager is not None:
            task_id_manager.add_task_id_to_traj_dataset(dataset)

        dataset = remove_single_step_trajectories(dataset)
        if not args.complex_task:
            dataset = filter_by_goal_id(dataset,
                                        goal_ids=[args.goal],
                                        task_id_manager=task_id_manager)
        dataset = select_n_trajectories(dataset, n_traj=args.adapt_n_traj)
        if args.multienv:
            dataset.n_task_id = args.n_task_ids

        # Optional per-domain observation converter, with optional kwargs.
        if obs_converter_name := domain_info.get('obs_converter'):
            if trans_args := domain_info.get('obs_converter_args'):
                obs_converter = get_obs_converter(name=obs_converter_name,
                                                  **trans_args)
            else:
                obs_converter = get_obs_converter(name=obs_converter_name)

            dataset.apply_obs_converter(obs_converter)

        # Optional per-domain action converter.
        if action_converter_name := domain_info.get('action_converter'):
            action_converter = get_action_converter(name=action_converter_name)
            dataset.apply_action_converter(action_converter)

        dataset.add_domain_id(domain_id=domain_info.domain_id,
                              n_domain_id=args.n_domains)

        train_dataset, val_dataset = train_val_split(
            dataset, train_ratio=args.train_ratio)

        train_step_dataset = convert_traj_dataset_to_step_dataset(
            train_dataset)
        val_step_dataset = convert_traj_dataset_to_step_dataset(val_dataset)

        train_step_dataset = TorchStepDataset(train_step_dataset,
                                              obs_dim=args.max_obs_dim,
                                              action_dim=args.max_action_dim)
        val_step_dataset = TorchStepDataset(val_step_dataset,
                                            obs_dim=args.max_obs_dim,
                                            action_dim=args.max_action_dim)

        train_step_datasets.append(train_step_dataset)
        val_step_datasets.append(val_step_dataset)

    # ConcatDataset rejects an empty list with an opaque assertion; fail
    # early with a message that points at the configuration instead.
    if not train_step_datasets:
        raise ValueError(
            'No dataset is available for the adaptation phase: '
            'args.domains is empty or every domain is marked as a target.')

    train_step_dataset = ConcatDataset(train_step_datasets)
    val_step_dataset = ConcatDataset(val_step_datasets)
    train_step_loader = DataLoader(train_step_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True)
    val_step_loader = DataLoader(val_step_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True)
    dataloader_dict: Dict[str, DataLoader] = {
        'train': train_step_loader,
        'validation': val_step_loader,
    }

    # TODO Do better
    # NOTE(review): repeats the assignment made at the top of this function;
    # presumably guards against the readers above modifying args -- confirm.
    args.n_eval_episodes = args.adapt_n_eval_episodes

    adapt(args, experiment, agent=agent, dataloader_dict=dataloader_dict)

    logger.info('Adaptation phase finished!')
