import os

from comet_ml import Experiment

# isort: split

from pathlib import PosixPath

import yaml
from omegaconf import OmegaConf
from torch.utils.data import ConcatDataset

import d4rl
from common.dail.adapt import adapt_main
from common.dail.dail import (train_gama, train_source_policy,
                              train_target_dynamics_model)
from common.dail.models import DAILAgent
from common.dail.utils import (GAMAConfig, calc_alignment_score,
                               configure_model_params)
from common.ours.utils import create_savedir_root
from common.utils.evaluate import eval_callback
from common.utils.process_dataset import *


def record_hparams(hparams: dict, savedir: Path):
    """Write the hyper-parameters to ``hparams.yaml`` inside ``savedir``.

    Every top-level key is stored with an ``align_`` prefix. Top-level values
    that are path objects are written as plain strings; all other values are
    handed to ``yaml.safe_dump`` unchanged. Nested containers are not
    traversed, so only top-level path values are converted.

    Args:
        hparams: Mapping of hyper-parameter names to values.
        savedir: Directory that receives ``hparams.yaml``. It is not created
            here, so it has to exist already.
    """
    # Local import keeps this block self-contained; the file header only
    # imports PosixPath explicitly.
    from pathlib import PurePath

    # PurePath is the common base of PosixPath and WindowsPath, so path values
    # are stringified on every platform rather than only on POSIX systems.
    # The safe YAML dumper only handles standard YAML types.
    log_hparams = {
        f'align_{key}': str(val) if isinstance(val, PurePath) else val
        for key, val in hparams.items()
    }
    path = savedir / 'hparams.yaml'

    with open(path, 'w') as f:
        yaml.safe_dump(log_hparams, f)
    logger.info(f'Hyper-parameters are saved as {path}.')


def align(
    args: DictConfig,
    experiment: Optional[Experiment],
    agent: DAILAgent,
    dataloader_dict: Dict[str, DataLoader],
):
    """Run the alignment training stages and, when applicable, score them.

    The three stages run in a fixed order on the ``"train"`` loader:
    ``train_source_policy``, ``train_target_dynamics_model`` and
    ``train_gama``. Each one receives ``args`` twice (as ``args`` and as
    ``omega_args``) and a ``model_path`` inside ``<args.logdir>/model``.

    Afterwards ``calc_alignment_score`` is called on the ``"validation"``
    loader, but only when there are exactly two domains, both ``env_tag``
    values contain ``'point'``, and the second domain has no
    ``obs_converter`` attribute.

    Args:
        args: Run configuration; ``logdir``, the ``num_epoch_*`` options,
            ``n_domains`` and ``domains`` are read here.
        experiment: Optional experiment handle forwarded to every stage.
        agent: The agent handed to every stage.
        dataloader_dict: Loaders keyed by ``"train"`` and ``"validation"``.
    """
    checkpoint_dir = Path(args.logdir) / "model"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # ----------------------------------------
    # Start training
    # ----------------------------------------

    # (stage function, name of its epoch-count option, checkpoint file name)
    # NOTE(review): the first two stages are given the same "000.pt" path;
    # whether the second overwrites the first depends on the stage functions.
    stages = (
        (train_source_policy, "num_epoch_bc", "000.pt"),
        (train_target_dynamics_model, "num_epoch_dynamics", "000.pt"),
        (train_gama, "num_epoch_gama", "010.pt"),
    )
    for run_stage, epochs_option, checkpoint_name in stages:
        # The epoch count is looked up right before each stage runs.
        run_stage(
            args=args,
            omega_args=args,
            agent=agent,
            epochs=getattr(args, epochs_option),
            train_loader=dataloader_dict["train"],
            model_path=checkpoint_dir / checkpoint_name,
            experiment=experiment,
        )

    # Guard clauses: skip the alignment score unless every condition holds.
    if args.n_domains != 2:
        return
    if any('point' not in args.domains[idx].env_tag for idx in (0, 1)):
        return
    # NOTE(review): `hasattr` tests for presence only, while `main` checks
    # `domain_info.get('obs_converter')` for truthiness — confirm intent.
    if hasattr(args.domains[1], 'obs_converter'):
        return

    calc_alignment_score(
        args=args,
        agent=agent,
        data_loader=dataloader_dict["validation"],
        experiment=experiment,
    )


def main(args, experiment: Optional[Experiment], hparams: dict):
    """Build the datasets, run alignment, evaluate, adapt and evaluate again.

    Steps, in order:
      1. Create the save directory, store ``hparams`` there and set
         ``args.logdir`` to it.
      2. For every entry of ``args.domains``: read the trajectory dataset,
         filter and truncate it, apply the optional observation/action
         converters, tag it with its domain id, split it into train and
         validation parts and convert both parts to step datasets.
      3. Concatenate the per-domain step datasets into one train loader and
         one validation loader.
      4. Fill in the model dimensions on ``args``, build the ``DAILAgent``
         and run ``align``.
      5. Call ``eval_callback`` with ``target_env=False``, run
         ``adapt_main``, then call ``eval_callback`` with ``target_env=True``.

    Args:
        args: Run configuration; it is modified in place (``logdir``,
            ``num_task_ids`` and the source/target dimension fields).
        experiment: Optional experiment handle used for logging.
        hparams: Hyper-parameters written to the save directory.
    """
    savedir_root = create_savedir_root(phase_tag='dail_align')
    record_hparams(hparams=hparams, savedir=savedir_root)

    if experiment:
        experiment.log_parameter('save_dir', str(savedir_root))

    args.logdir = savedir_root

    ### Dataset
    # One entry per domain in each list. The trajectory-level lists are
    # filled for parity with the step-level ones but are not read below.
    train_traj_parts = []
    val_traj_parts = []
    train_step_parts = []
    val_step_parts = []
    task_id_managers = []

    for domain_info in args.domains:
        if args.multienv:
            dataset, task_id_manager = read_multi_dataset(
                domain_info=domain_info,
                image_observation=args.image_observation,
                args=args,
            )
        else:
            dataset, task_id_manager = read_dataset(
                path=domain_info.dataset,
                env_id=domain_info.env,
                # One additional task for complex tasks, none otherwise.
                n_additional_tasks=int(bool(args.complex_task)),
                image_observation=args.image_observation,
                domain_id=domain_info.domain_id,
                args=args,
            )
        if task_id_manager is not None:
            task_id_manager.add_task_id_to_traj_dataset(dataset)

        # Filtering pipeline: drop single-step trajectories, keep the
        # training goals, then cap the number of trajectories.
        multi_step_only = remove_single_step_trajectories(dataset)
        goal_filtered = filter_by_goal_id(multi_step_only,
                                          goal_ids=args.train_goal_ids,
                                          task_id_manager=task_id_manager)
        dataset = select_n_trajectories(goal_filtered, n_traj=args.n_traj)
        if args.multienv:
            dataset.n_task_id = args.n_task_ids

        obs_converter_name = domain_info.get('obs_converter')
        if obs_converter_name:
            # Missing or empty converter arguments mean "no extra kwargs".
            converter_kwargs = domain_info.get('obs_converter_args') or {}
            dataset.apply_obs_converter(
                get_obs_converter(name=obs_converter_name,
                                  **converter_kwargs))

        action_converter_name = domain_info.get('action_converter')
        if action_converter_name:
            dataset.apply_action_converter(
                get_action_converter(name=action_converter_name))

        dataset.add_domain_id(domain_id=domain_info.domain_id,
                              n_domain_id=args.n_domains)

        train_part, val_part = train_val_split(dataset,
                                               train_ratio=args.train_ratio)
        train_traj_parts.append(train_part)
        val_traj_parts.append(val_part)

        train_steps = convert_traj_dataset_to_step_dataset(train_part)
        val_steps = convert_traj_dataset_to_step_dataset(val_part)

        # Both splits are wrapped with the same observation/action sizes.
        dim_kwargs = dict(obs_dim=args.max_obs_dim,
                          action_dim=args.max_action_dim)
        train_step_parts.append(TorchStepDataset(train_steps, **dim_kwargs))
        val_step_parts.append(TorchStepDataset(val_steps, **dim_kwargs))

        task_id_managers.append(task_id_manager)

    all_train_steps = ConcatDataset(train_step_parts)
    all_val_steps = ConcatDataset(val_step_parts)
    dataloader_dict: Dict[str, DataLoader] = {
        'train':
        DataLoader(all_train_steps, batch_size=args.batch_size, shuffle=True),
        'validation':
        DataLoader(all_val_steps, batch_size=args.batch_size, shuffle=False),
    }

    ### Model
    # The number of task ids is read off the first validation sample.
    first_sample: Dict[str,
                       torch.Tensor] = dataloader_dict['validation'].dataset[0]
    args.num_task_ids = first_sample['task_ids'].shape[-1]

    # Domain 0 supplies the source dimensions, domain 1 the target ones.
    domains = args.domains
    args.source_state_dim = domains[0].obs_dim
    args.target_state_dim = domains[1].obs_dim
    args.source_action_dim = domains[0].action_dim
    args.target_action_dim = domains[1].action_dim

    args = configure_model_params(args)

    logger.info('Alignment phase started.')

    agent = DAILAgent(args).to(args.device)

    align(
        args=args,
        experiment=experiment,
        agent=agent,
        dataloader_dict=dataloader_dict,
    )

    # Arguments shared by both evaluation calls.
    # NOTE: only the first domain's task-id manager is used — do all domains
    # share the same manager?
    shared_eval_kwargs = dict(
        policy=agent,
        epoch=0,
        savedir_root=savedir_root,
        task_id_manager=task_id_managers[0],
        experiment=experiment,
    )

    eval_callback(
        args=args,
        goal_ids=args.train_goal_ids,
        target_env=False,
        **shared_eval_kwargs,
    )
    logger.info('Alignment phase finished!')

    adapt_main(args, experiment, agent)

    eval_callback(
        args=args,
        goal_ids=[args.goal],
        log_prefix='adapt',
        target_env=True,
        **shared_eval_kwargs,
    )


if __name__ == '__main__':
    # Configuration is layered, later sources overriding earlier ones:
    # GAMAConfig defaults -> YAML config file -> command-line arguments.
    args = OmegaConf.structured(GAMAConfig)
    cli_args = OmegaConf.from_cli()

    config_file_path = cli_args.get('config', 'common/dail/config/p2p.yaml')
    # Explicit raise instead of `assert`: assertions are removed when Python
    # runs with -O, which would let a non-YAML path through unchecked.
    if Path(config_file_path).suffix != '.yaml':
        raise ValueError(
            f'Config file must have a .yaml suffix, got: {config_file_path}')
    file_args = OmegaConf.load(config_file_path)
    args = OmegaConf.merge(args, file_args)

    args: GAMAConfig = OmegaConf.merge(args, cli_args)
    read_env_config_yamls(args)

    # Every domain has to declare the same number of goals.
    goal_counts = np.array(
        [domain_info.n_goals for domain_info in args.domains])
    if not np.all(goal_counts == goal_counts[0]):
        raise ValueError(
            f'All domains must have the same n_goals, got: {goal_counts}')

    goal_candidates = get_goal_candidates(
        # In multi-environment mode the number of task ids is used instead
        # of the first domain's goal count.
        n_goals=args.domains[0].n_goals
        if not args.multienv else args.n_task_ids,
        target_goal=args.goal,
        align=True,
        complex_task=args.complex_task,
        n_tasks=args.n_tasks,
        is_r2r='Lift' in args.domains[0]['env'],
    )
    args.train_goal_ids = goal_candidates
    hparams = OmegaConf.to_container(args)

    if args.comet:
        # The project name comes from the environment; a missing variable
        # raises KeyError here.
        experiment = Experiment(
            project_name=os.environ['COMET_PLP_PROJECT_NAME'])
        experiment.set_name(args.name)
        experiment.log_parameters(hparams)
        experiment.add_tag('gama')
        # Tag with the config file's stem (the original comment calls this
        # the env name). Wrapping in Path makes this work whether the merged
        # `config` value is a string or already a path object.
        experiment.add_tag(Path(args.config).stem)
    else:
        experiment = None

    print(args)
    main(args, experiment, hparams)
