from comet_ml import Experiment

# isort: split

from typing import Dict, List, Optional

import torch
from torch.utils.data import ConcatDataset, DataLoader, random_split

from common.ours.algorithm_main import epoch_loop
from common.ours.models import Policy
from common.ours.utils import CheckPointer, PLPConfig
from common.utils.evaluate import eval_callback
from common.utils.process_dataset import *


def _read_domain_dataset(args: PLPConfig, domain_info):
    """Read the raw trajectory dataset and task-ID manager for one domain.

    Dispatches to ``read_multi_dataset`` when ``args.multienv`` is set and to
    ``read_dataset`` otherwise; the callee's ``(dataset, task_id_manager)``
    result is returned unchanged.
    """
    if args.multienv:
        return read_multi_dataset(
            domain_info=domain_info,
            image_observation=args.image_observation,
            args=args,
        )
    return read_dataset(
        path=domain_info.dataset,
        env_id=domain_info.env,
        image_observation=args.image_observation,
        n_additional_tasks=1 if args.complex_task else 0,
        domain_id=domain_info.domain_id,
        args=args,
    )


def _apply_domain_converters(dataset, domain_info) -> None:
    """Apply the domain's optional observation/action converters to ``dataset``.

    Converters are selected by the ``obs_converter`` / ``action_converter``
    entries of ``domain_info`` and applied through the dataset's
    ``apply_obs_converter`` / ``apply_action_converter`` methods. A non-empty
    ``obs_converter_args`` entry is forwarded as keyword arguments to
    ``get_obs_converter``.
    """
    if obs_converter_name := domain_info.get('obs_converter'):
        # A missing or empty args entry means "call the factory with no kwargs".
        obs_converter_args = domain_info.get('obs_converter_args') or {}
        obs_converter = get_obs_converter(name=obs_converter_name,
                                          **obs_converter_args)
        dataset.apply_obs_converter(obs_converter)

    if action_converter_name := domain_info.get('action_converter'):
        action_converter = get_action_converter(name=action_converter_name)
        dataset.apply_action_converter(action_converter)


def _build_domain_step_datasets(args: PLPConfig):
    """Build per-domain train/validation step datasets for adaptation.

    Domains whose ``target`` entry is truthy contribute a task-ID manager but
    no data.

    Returns:
        ``(train_step_datasets, val_step_datasets, task_id_managers)``.
        ``task_id_managers`` holds one entry per element of ``args.domains``,
        in order, including target domains.
    """
    train_step_datasets: List[TorchStepDataset] = []
    val_step_datasets: List[TorchStepDataset] = []
    task_id_managers = []

    for domain_info in args.domains:
        dataset, task_id_manager = _read_domain_dataset(args, domain_info)
        # Recorded before the target check so the list stays aligned with
        # args.domains.
        task_id_managers.append(task_id_manager)
        if domain_info.get('target'):
            logger.info(
                f'domain ID {domain_info.domain_id} ({domain_info.env_tag}) is a target domain. '
                f'Data is not included in the adaptation phase.')
            continue
        if args.multienv:
            dataset.n_task_id = args.n_task_ids

        if args.complex_task:
            # Complex tasks use the extra (last) task ID reserved via
            # n_additional_tasks when the dataset was read.
            task_id_manager.add_task_id_to_traj_dataset(
                dataset, task_id=dataset.n_task_id - 1)
        elif task_id_manager is not None:
            task_id_manager.add_task_id_to_traj_dataset(dataset)

        dataset = remove_single_step_trajectories(dataset)
        if not args.complex_task:
            dataset = filter_by_goal_id(dataset,
                                        goal_ids=[args.goal],
                                        task_id_manager=task_id_manager)
        dataset = select_n_trajectories(dataset, n_traj=args.adapt_n_traj)

        _apply_domain_converters(dataset, domain_info)
        dataset.add_domain_id(domain_id=domain_info.domain_id,
                              n_domain_id=args.n_domains)

        train_traj, val_traj = train_val_split(dataset,
                                               train_ratio=args.train_ratio)
        train_steps = convert_traj_dataset_to_step_dataset(train_traj)
        val_steps = convert_traj_dataset_to_step_dataset(val_traj)
        train_step_datasets.append(
            TorchStepDataset(train_steps,
                             obs_dim=args.max_obs_dim,
                             action_dim=args.max_action_dim))
        val_step_datasets.append(
            TorchStepDataset(val_steps,
                             obs_dim=args.max_obs_dim,
                             action_dim=args.max_action_dim))

    return train_step_datasets, val_step_datasets, task_id_managers


def _sample_align_subset(align_datasets: List[TorchStepDataset],
                         adapt_datasets: List[TorchStepDataset],
                         size_rate: float):
    """Randomly draw align-phase samples in proportion to the adaptation data.

    The requested size is ``int(total adaptation samples * size_rate)``,
    capped at the number of available align samples.

    Returns:
        A ``torch.utils.data.Subset`` of the concatenated align datasets.

    Raises:
        ValueError: if ``size_rate`` yields a negative sample count.
    """
    adapt_data_size = sum(len(dataset) for dataset in adapt_datasets)
    target_size = int(adapt_data_size * size_rate)
    if target_size < 0:
        raise ValueError(
            f'adapt_align_data_size_rate must be non-negative, got {size_rate}'
        )
    align_dataset = ConcatDataset(align_datasets)
    # Cap so both split lengths are non-negative. Without this, a negative
    # remainder only worked by accident of list slicing, and a 1-sample align
    # set made random_split interpret the lengths as fractions and raise.
    target_size = min(target_size, len(align_dataset))
    subset, _ = random_split(
        align_dataset, [target_size, len(align_dataset) - target_size])
    return subset


def _make_step_loader(dataset, args: PLPConfig) -> DataLoader:
    """Create a shuffled, pinned-memory DataLoader with the shared settings."""
    return DataLoader(dataset,
                      batch_size=args.batch_size,
                      num_workers=args.num_data_workers,
                      persistent_workers=args.num_data_workers > 0,
                      pin_memory=True,
                      shuffle=True)


def _evaluate_on_target(args: PLPConfig, policy: Policy, epoch: int,
                        savedir_root, task_id_manager,
                        experiment: Optional[Experiment]) -> None:
    """Run ``eval_callback`` on the target env for ``args.goal`` (prefix 'adapt')."""
    eval_callback(
        args=args,
        policy=policy,
        epoch=epoch,
        goal_ids=[args.goal],
        savedir_root=savedir_root,
        task_id_manager=task_id_manager,
        log_prefix='adapt',
        experiment=experiment,
        target_env=True,
    )


def adapt(args: PLPConfig, experiment: Optional[Experiment], policy: Policy,
          align_dataset_dict: Dict[str, List[TorchStepDataset]]):
    """Run the adaptation phase: fine-tune ``policy.core`` on non-target domains.

    Steps:
      1. Override evaluation settings in ``args`` with their ``adapt_*``
         counterparts and zero ``mmd_coef`` / ``hausdorff_coef``.
      2. Build train/validation step datasets from every non-target domain.
         For complex tasks, the domains are first switched to
         ``args.adapt_domains``.
      3. If ``args.adapt_with_all_tasks`` is set, mix in a random subset of
         the align-phase data whose size is proportional to the adaptation
         data.
      4. Train for ``args.adapt_epochs`` epochs with ``epoch_loop`` and
         checkpoint via ``CheckPointer``. Evaluate on the target env before
         training, every ``args.adapt_eval_interval`` epochs, and after the
         last epoch.

    Args:
        args: Run configuration. It is mutated in place (eval settings, loss
            coefficients and, for complex tasks, ``domains``).
        experiment: Optional comet_ml experiment used for logging.
        policy: Policy to adapt. Only ``policy.core`` parameters are given to
            the optimizer.
        align_dataset_dict: Align-phase data. The ``'train_step'`` and
            ``'validation_step'`` lists are read only when
            ``args.adapt_with_all_tasks`` is set.

    Raises:
        ValueError: if no non-target domain yields adaptation data, or if
            ``args.adapt_align_data_size_rate`` is negative.
    """
    logger.info('Adaptation phase started.')
    args.eval_interval = args.adapt_eval_interval
    args.n_eval_episodes = args.adapt_n_eval_episodes
    # Zero the MMD / Hausdorff loss coefficients for this phase (presumably
    # read by epoch_loop via args).
    args.mmd_coef = 0.0
    args.hausdorff_coef = 0.0

    # NOTE(review): str.replace rewrites *every* "align" in the path, not only
    # the phase directory name.
    savedir_root = Path(args.savedir_root.replace("align", "adapt"))
    if experiment:
        experiment.log_parameter('adapt_save_dir', str(savedir_root))

    # TODO remove target domain

    if args.complex_task:
        args.domains = args.adapt_domains
        read_env_config_yamls(args=args)

    train_step_datasets, val_step_datasets, task_id_managers = \
        _build_domain_step_datasets(args)
    if not train_step_datasets:
        raise ValueError(
            'No non-target domain provided data for the adaptation phase.')

    if args.adapt_with_all_tasks:
        # Each subset is sized from the adaptation data collected *before* the
        # align subset is appended.
        train_step_datasets.append(
            _sample_align_subset(align_dataset_dict['train_step'],
                                 train_step_datasets,
                                 args.adapt_align_data_size_rate))
        val_step_datasets.append(
            _sample_align_subset(align_dataset_dict['validation_step'],
                                 val_step_datasets,
                                 args.adapt_align_data_size_rate))

    train_step_dataset = ConcatDataset(train_step_datasets)
    val_step_dataset = ConcatDataset(val_step_datasets)
    dataloader_dict: Dict[str, DataLoader] = {
        'train': _make_step_loader(train_step_dataset, args),
        'validation': _make_step_loader(val_step_dataset, args),
    }

    model_dict: Dict[str, torch.nn.Module] = {'policy': policy}
    # Fix encoder and decoder: parameters outside policy.core are not handed
    # to the optimizer, so the optimizer never updates them.
    lr = args.adapt_lr if args.adapt_lr is not None else args.lr
    optimizer_dict: Dict[str, torch.optim.Optimizer] = {
        'policy':
        torch.optim.Adam(policy.core.parameters(),
                         lr=lr,
                         weight_decay=args.weight_decay),
    }

    checkpointer = CheckPointer(n_epochs=args.adapt_epochs,
                                savedir_root=savedir_root)

    # TODO Do better. Re-applied after dataset loading; presumably guards
    # against args being overwritten by the config/dataset readers — confirm.
    args.n_eval_episodes = args.adapt_n_eval_episodes

    # NOTE(review): assumes all domains share the same task-ID manager.
    task_id_manager = task_id_managers[0]

    _evaluate_on_target(args, policy, 0, savedir_root, task_id_manager,
                        experiment)

    for epoch in range(args.adapt_epochs):
        val_loss = epoch_loop(args,
                              model_dict=model_dict,
                              optimizer_dict=optimizer_dict,
                              dataloader_dict=dataloader_dict,
                              epoch=epoch,
                              experiment=experiment,
                              log_prefix='adapt')
        checkpointer.save_if_necessary(policy=policy,
                                       val_loss=val_loss,
                                       epoch=epoch)

        completed_epochs = epoch + 1
        if (completed_epochs % args.adapt_eval_interval == 0
                or completed_epochs == args.adapt_epochs):
            _evaluate_on_target(args, policy, completed_epochs, savedir_root,
                                task_id_manager, experiment)

    logger.info('Adaptation phase finished!')
