from comet_ml import Experiment

# isort: split

import os
from pathlib import PosixPath

import yaml
from torch.utils.data import ConcatDataset

import d4rl
from common.ours.adapt import adapt
from common.ours.algorithm_main import epoch_loop
from common.ours.latent_visualizer import visualize_latent
from common.ours.models import Discriminator, Policy
from common.ours.utils import (CheckPointer, PLPConfig, create_savedir_root)
from common.utils.evaluate import eval_callback
from common.utils.process_dataset import *


def record_hparams(hparams: dict, savedir: Path):
    """Save the alignment-phase hyper-parameters to ``savedir/hparams.yaml``.

    Every top-level key is prefixed with ``align_``. Path-like values are
    converted to strings at any nesting depth (inside dicts, lists and
    tuples), because ``yaml.safe_dump`` cannot represent path objects.

    Args:
        hparams: Hyper-parameters to record, e.g. the plain container
            produced by ``OmegaConf.to_container``.
        savedir: Existing directory in which ``hparams.yaml`` is written.
    """

    def to_yaml_safe(value):
        # Recurse into containers: nested configs (e.g. per-domain entries)
        # may also hold path objects, which safe_dump would reject.
        if isinstance(value, os.PathLike):
            return os.fspath(value)
        if isinstance(value, dict):
            return {k: to_yaml_safe(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            # safe_dump writes tuples as plain YAML sequences anyway.
            return [to_yaml_safe(v) for v in value]
        return value

    log_hparams = {
        f'align_{key}': to_yaml_safe(val)
        for key, val in hparams.items()
    }
    path = savedir / 'hparams.yaml'

    with open(path, 'w', encoding='utf-8') as f:
        yaml.safe_dump(log_hparams, f)
    logger.info(f'Hyper-parameters are saved as {path}.')


def main(args: PLPConfig, experiment: Optional[Experiment], hparams: dict):
    """Run the alignment phase and then hand the trained policy to ``adapt``.

    Steps:
      1. For every entry of ``args.domains``, load the dataset, filter it,
         apply the optional obs/action converters and split it into
         train / validation trajectory datasets.
      2. Build step-level dataloaders and, when ``args.tcc`` is set,
         paired-trajectory dataloaders.
      3. Train ``Policy``, plus a ``Discriminator`` when
         ``args.adversarial_coef > 0``, for ``args.epochs`` epochs. Each epoch
         checkpoints, visualizes latents (unless ``args.naive_bc``) and
         periodically evaluates.
      4. Call ``adapt`` with the trained policy and the collected datasets.

    Args:
        args: Merged experiment configuration. ``args.savedir_root`` is
            overwritten with the save directory before ``adapt`` is called.
        experiment: Comet experiment used for logging, or ``None``.
        hparams: Plain-container copy of the configuration, written to
            ``hparams.yaml`` in the save directory.
    """
    savedir_root = create_savedir_root(phase_tag='align', name=args.name)
    record_hparams(hparams=hparams, savedir=savedir_root)

    if experiment:
        experiment.log_parameter('save_dir', str(savedir_root))

    ### Dataset
    # One entry per domain under each key:
    #   'train_traj' / 'validation_traj' -> trajectory-level datasets
    #   'train_step' / 'validation_step' -> TorchStepDataset instances
    dataset_dict = defaultdict(list)
    task_id_managers = []

    for domain_info in args.domains:
        if args.multienv:
            dataset, task_id_manager = read_multi_dataset(
                domain_info=domain_info,
                image_observation=args.image_observation,
                args=args,
            )
        else:
            dataset, task_id_manager = read_dataset(
                path=domain_info.dataset,
                env_id=domain_info.env,
                n_additional_tasks=1 if args.complex_task else 0,
                image_observation=args.image_observation,
                domain_id=domain_info.domain_id,
                args=args,
            )
        if task_id_manager is not None:
            task_id_manager.add_task_id_to_traj_dataset(dataset)

        dataset = remove_single_step_trajectories(dataset)
        dataset = filter_by_goal_id(dataset,
                                    goal_ids=args.train_goal_ids,
                                    task_id_manager=task_id_manager)
        dataset = select_n_trajectories(dataset, n_traj=args.n_traj)

        # Optional per-domain observation / action converters from the config.
        if obs_converter_name := domain_info.get('obs_converter'):
            if trans_args := domain_info.get('obs_converter_args'):
                obs_converter = get_obs_converter(name=obs_converter_name,
                                                  **trans_args)
            else:
                obs_converter = get_obs_converter(name=obs_converter_name)
            dataset.apply_obs_converter(obs_converter)

        if action_converter_name := domain_info.get('action_converter'):
            action_converter = get_action_converter(name=action_converter_name)
            dataset.apply_action_converter(action_converter)

        dataset.add_domain_id(domain_id=domain_info.domain_id,
                              n_domain_id=args.n_domains)

        train_dataset, val_dataset = train_val_split(
            dataset, train_ratio=args.train_ratio)
        dataset_dict["train_traj"].append(train_dataset)
        dataset_dict["validation_traj"].append(val_dataset)

        # Step-level views share obs_dim / action_dim across all domains
        # (presumably padded to the max sizes inside TorchStepDataset).
        train_step_dataset = TorchStepDataset(
            convert_traj_dataset_to_step_dataset(train_dataset),
            obs_dim=args.max_obs_dim,
            action_dim=args.max_action_dim,
        )
        val_step_dataset = TorchStepDataset(
            convert_traj_dataset_to_step_dataset(val_dataset),
            obs_dim=args.max_obs_dim,
            action_dim=args.max_action_dim,
        )
        dataset_dict["train_step"].append(train_step_dataset)
        dataset_dict["validation_step"].append(val_step_dataset)

        task_id_managers.append(task_id_manager)

    # Build one step-level dataloader per split ('train' / 'validation') by
    # concatenating the per-domain step datasets.
    dataloader_dict = {}
    for key, datasets in dataset_dict.items():
        # Trajectory datasets are used for TCC pairing below and by `adapt`.
        if "traj" in key:
            continue
        if args.multienv:
            # NOTE(review): presumably forces every domain to use the same
            # task-id width so the concatenated batches have equal shapes.
            for step_dataset in datasets:
                step_dataset.data.n_task_id = args.n_task_ids
        dataloader_dict[key.replace("_step", "")] = DataLoader(
            ConcatDataset(datasets),
            batch_size=args.batch_size,
            num_workers=args.num_data_workers,
            pin_memory=True,
            persistent_workers=args.num_data_workers > 0,
            shuffle='train' in key,
        )

    if args.tcc:
        train_paired_traj_dataset = PairedTrajDataset(
            traj_datasets=dataset_dict["train_traj"],
            obs_dim=args.max_obs_dim,
            action_dim=args.max_action_dim,
            seq_len=args.max_seq_len,
        )
        train_paired_traj_loader = DataLoader(
            train_paired_traj_dataset,
            batch_size=args.tcc_batch_size,
            pin_memory=True,
            num_workers=0,
            shuffle=True,
        )
        if args.tcc_validation:
            # Validation trajectories are stored under "validation_traj";
            # the defaultdict would silently yield [] for any other key.
            val_paired_traj_dataset = PairedTrajDataset(
                traj_datasets=dataset_dict["validation_traj"],
                obs_dim=args.max_obs_dim,
                action_dim=args.max_action_dim,
                seq_len=args.max_seq_len,
            )
            val_paired_traj_loader = DataLoader(val_paired_traj_dataset,
                                                batch_size=args.tcc_batch_size,
                                                shuffle=True)
        else:
            val_paired_traj_loader = None

        paired_traj_dataloader_dict: Dict[str, DataLoader] = {
            'train': train_paired_traj_loader,
            'validation': val_paired_traj_loader,
        }
    else:
        paired_traj_dataloader_dict = {}

    ### Model
    model_dict: Dict[str, torch.nn.Module] = {}
    optimizer_dict: Dict[str, torch.optim.Optimizer] = {}

    # One sample is used only to infer the task-id / domain-id widths.
    sample: Dict[str, torch.Tensor] = dataloader_dict['validation'].dataset[0]
    policy = Policy(
        state_dim=args.max_obs_dim,
        cond_dim=sample['task_ids'].shape[-1],
        domain_dim=sample['domain_ids'].shape[-1],
        latent_dim=args.latent_dim,
        hid_dim=args.hid_dim,
        out_dim=args.max_action_dim,
        num_hidden_layers=args.policy_num_layers,
        activation=args.activation,
        repr_activation=args.repr_activation,
        decode_with_state=args.decode_with_state,
        image_observation=args.image_observation,
        image_state_dim=args.image_state_dim,
        use_coord_conv=args.use_coord_conv,
        z_norm=args.z_norm,
        discrete=args.discrete,
        discrete_bins=args.discrete_bins,
        naive_bc=args.naive_bc,
        use_image_decoder=args.use_image_decoder,
        input_image_state_into_decoder=args.input_image_state_into_decoder,
    ).to(args.device)
    model_dict['policy'] = policy

    # With target_adapt only the encoder and head are optimized; other
    # policy parameters are left out of the optimizer.
    if args.target_adapt:
        opt_params = list(policy.encoder.parameters()) + list(
            policy.head.parameters())
    else:
        opt_params = list(policy.parameters())

    optimizer = torch.optim.AdamW(opt_params,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
    optimizer_dict['policy'] = optimizer

    # The domain discriminator (and its optimizer) exist only when the
    # adversarial loss is enabled; otherwise both dict entries are None.
    use_adversarial = args.adversarial_coef > 0
    discriminator = Discriminator(
        latent_dim=policy.latent_dim,
        hid_dim=args.hid_dim,
        cond_dim=sample['task_ids'].shape[-1],
        num_classes=args.n_domains,
        activation=args.activation,
        num_hidden_layer=args.disc_num_layers,
    ).to(args.device) if use_adversarial else None
    optimizer_disc = torch.optim.Adam(
        discriminator.parameters(),
        lr=args.disc_lr) if use_adversarial else None
    model_dict['discriminator'] = discriminator
    optimizer_dict['discriminator'] = optimizer_disc

    logger.info('Alignment phase started.')
    checkpointer = CheckPointer(n_epochs=args.epochs,
                                savedir_root=savedir_root)

    for epoch in range(args.epochs):
        val_loss = epoch_loop(
            args,
            model_dict=model_dict,
            optimizer_dict=optimizer_dict,
            dataloader_dict=dataloader_dict,
            epoch=epoch,
            experiment=experiment,
            paired_dataloader_dict=paired_traj_dataloader_dict,
            log_prefix='align')
        checkpointer.save_if_necessary(policy=policy,
                                       val_loss=val_loss,
                                       epoch=epoch)

        # Full evaluation every `eval_interval` epochs and after the last one.
        completed_epochs = epoch + 1
        eval_flag = (completed_epochs % args.eval_interval == 0
                     or completed_epochs == args.epochs)

        if not args.naive_bc:
            # Between evaluations only latent metrics are computed.
            visualize_latent(
                args,
                policy=policy,
                data_loader=dataloader_dict['train'],
                epoch=completed_epochs,
                shuffle=True,
                metrics_only=not eval_flag,
                savedir=savedir_root,
                experiment=experiment,
            )

        if eval_flag:
            eval_callback(
                args=args,
                policy=policy,
                epoch=completed_epochs,
                goal_ids=args.train_goal_ids,
                savedir_root=savedir_root,
                # NOTE(review): only the first domain's task-id manager is
                # used; assumes all domains share the same mapping.
                task_id_manager=task_id_managers[0],
                log_prefix='align',
                experiment=experiment,
                target_env=False)

    logger.info('Alignment phase finished!')

    args.savedir_root = str(savedir_root)
    adapt(args, experiment, policy, dataset_dict)


if __name__ == '__main__':
    # Config precedence (lowest -> highest): PLPConfig defaults, the YAML
    # config file, then command-line overrides.
    args = OmegaConf.structured(PLPConfig)
    cli_args = OmegaConf.from_cli()

    config_file_path = cli_args.get('config', 'common/ours/config/p2p.yaml')
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if Path(config_file_path).suffix != '.yaml':
        raise ValueError(
            f'Config file must have a .yaml suffix: {config_file_path}')
    file_args = OmegaConf.load(config_file_path)
    args = OmegaConf.merge(args, file_args)

    args: PLPConfig = OmegaConf.merge(args, cli_args)
    read_env_config_yamls(args)

    torch.set_num_threads(16)
    torch.backends.cudnn.benchmark = True

    # Every domain must declare the same number of goals.
    goal_counts = np.array(
        [domain_info.n_goals for domain_info in args.domains])
    if not np.all(goal_counts == goal_counts[0]):
        raise ValueError(
            'All domains must have the same n_goals, '
            f'got {goal_counts.tolist()}')

    goal_candidates = get_goal_candidates(
        n_goals=args.domains[0].n_goals
        if not args.multienv else args.n_task_ids,
        target_goal=args.goal,
        align=True,
        complex_task=args.complex_task,
        n_tasks=args.n_tasks,
        is_r2r='Lift' in args.domains[0]['env'],
        r2r_single_layer=args.r2r_single_layer,
    )
    args.train_goal_ids = goal_candidates
    hparams = OmegaConf.to_container(args)

    if args.comet:
        experiment = Experiment(
            project_name=os.environ['COMET_PLP_PROJECT_NAME'])
        experiment.set_name(args.name)
        experiment.log_parameters(hparams)
        # Tag with the stem of the config file that was actually loaded
        # (env name); args.config may be unset when `config=` is not given.
        experiment.add_tag(Path(config_file_path).stem)
        if args.naive_bc:
            experiment.add_tag("BC")
        else:
            experiment.add_tag('PLP')
            if args.adversarial_coef > 0:
                experiment.add_tag("Adv")
            if args.mmd_coef > 0:
                experiment.add_tag("MMD")
            if args.tcc:
                experiment.add_tag("TCC")
            if args.state_pred:
                experiment.add_tag("state_pred")
    else:
        experiment = None

    print(args)
    main(args, experiment, hparams)
