from comet_ml import Experiment

# isort: split

import os
from collections import defaultdict
from pathlib import PosixPath

import yaml
from einops import rearrange

import d4rl
from common.cond.transformer_modules import (TransformerPredictor,
                                             generate_square_subsequent_mask)
from common.cond.utils import ContextualConfig
from common.ours.utils import CheckPointer, create_savedir_root
from common.utils.evaluate import eval_callback
from common.utils.process_dataset import *

# Layout string handed to tqdm as ``bar_format`` by the training loops below;
# ``{bar:50}`` requests a 50-character-wide bar between the left and right
# status segments.
TQDM_BAR_FORMAT = "{l_bar}{bar:50}{r_bar}"


def record_hparams(hparams: dict, savedir: Path):
    """Write the hyper-parameters to ``<savedir>/hparams.yaml``.

    Every key is prefixed with ``align_``. Top-level values that are pathlib
    paths are converted to ``str`` first, because ``yaml.safe_dump`` has no
    representer for path objects.

    Args:
        hparams: Mapping of hyper-parameter names to values.
        savedir: Directory in which ``hparams.yaml`` is created.
    """
    # Function-scope import so the block does not depend on a new file-level
    # import; PurePath is the common base of every pathlib flavour.
    from pathlib import PurePath

    # Checking against PurePath (instead of PosixPath only) also stringifies
    # WindowsPath / pure paths, which would otherwise make safe_dump fail.
    log_hparams = {
        f'align_{key}': str(val) if isinstance(val, PurePath) else val
        for key, val in hparams.items()
    }
    path = savedir / 'hparams.yaml'

    with open(path, 'w', encoding='utf-8') as f:
        yaml.safe_dump(log_hparams, f)
    logger.info(f'Hyper-parameters are saved as {path}.')


def _read_domain_dataset(args, domain_info):
    """Read the trajectory dataset and task-ID manager of a single domain.

    Uses ``read_multi_dataset`` when ``args.multienv`` is set and
    ``read_dataset`` otherwise; in the latter case ``n_additional_tasks`` is 1
    when ``args.complex_task`` is set and 0 otherwise.

    Returns:
        Whatever the selected reader returns; callers unpack it as
        ``(dataset, task_id_manager)``.
    """
    if args.multienv:
        return read_multi_dataset(
            domain_info=domain_info,
            image_observation=args.image_observation,
            args=args,
        )
    return read_dataset(
        path=domain_info.dataset,
        env_id=domain_info.env,
        image_observation=args.image_observation,
        n_additional_tasks=1 if args.complex_task else 0,
        domain_id=domain_info.domain_id,
        args=args,
    )


def _convert_and_tag_dataset(dataset, domain_info, args):
    """Apply a domain's optional obs/action converters and add its domain ID.

    The dataset is modified through its own methods; nothing is returned.
    """
    if obs_converter_name := domain_info.get('obs_converter'):
        # Converter arguments are optional: a missing or empty entry means the
        # converter is built from its name alone.
        converter_kwargs = domain_info.get('obs_converter_args') or {}
        obs_converter = get_obs_converter(name=obs_converter_name,
                                          **converter_kwargs)
        dataset.apply_obs_converter(obs_converter)

    if action_converter_name := domain_info.get('action_converter'):
        action_converter = get_action_converter(name=action_converter_name)
        dataset.apply_action_converter(action_converter)

    dataset.add_domain_id(domain_id=domain_info.domain_id,
                          n_domain_id=args.n_domains)


def _build_paired_dataset(traj_datasets, args):
    """Wrap per-domain trajectory datasets into one ``PairedTrajDataset``."""
    return PairedTrajDataset(
        traj_datasets=traj_datasets,
        obs_dim=args.max_obs_dim,
        action_dim=args.max_action_dim,
        seq_len=args.max_seq_len,
        sa_demo=args.sa_demo,
    )


def _build_worker_loader(dataset, args):
    """Create the shuffled, 4-worker ``DataLoader`` used for train and val."""
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        pin_memory=True,
        persistent_workers=True,
        num_workers=4,
        shuffle=True,
    )


def main(args, experiment: Optional[Experiment], hparams: dict):
    """Train the transformer policy on paired trajectories.

    The function
      1. creates the save directory and records ``hparams`` there,
      2. builds train/val paired datasets from ``args.domains`` and a test
         paired dataset (from ``args.adapt_domains`` when
         ``args.complex_task`` is set, which also overwrites ``args.domains``),
      3. trains a ``TransformerPredictor`` with Adam for ``args.epochs``
         epochs, computing train, val and test losses every epoch,
      4. checkpoints on the epoch's validation loss and calls
         ``eval_callback`` every ``args.eval_interval`` epochs and after the
         final epoch.

    Args:
        args: Run configuration; read (and, for ``complex_task``, mutated)
            throughout.
        experiment: Optional Comet experiment; parameters and metrics are
            logged to it when it is truthy.
        hparams: Hyper-parameters written to ``hparams.yaml``.
    """
    savedir_root = create_savedir_root(phase_tag='align', name=args.name)
    record_hparams(hparams=hparams, savedir=savedir_root)

    if experiment:
        experiment.log_parameter('save_dir', str(savedir_root))

    ### Dataset
    train_traj_datasets = []
    val_traj_datasets = []
    task_id_managers = []

    for domain_info in args.domains:
        dataset, task_id_manager = _read_domain_dataset(args, domain_info)
        if task_id_manager is not None:
            task_id_manager.add_task_id_to_traj_dataset(dataset)

        dataset = remove_single_step_trajectories(dataset)
        dataset = filter_by_goal_id(dataset,
                                    goal_ids=args.train_goal_ids,
                                    task_id_manager=task_id_manager)
        dataset = select_n_trajectories(dataset, n_traj=args.n_traj)
        _convert_and_tag_dataset(dataset, domain_info, args)

        train_dataset, val_dataset = train_val_split(
            dataset, train_ratio=args.train_ratio)
        train_traj_datasets.append(train_dataset)
        val_traj_datasets.append(val_dataset)

        task_id_managers.append(task_id_manager)

    # Train/val datasets and loaders are built *before* args.domains may be
    # swapped below, so they always use the configuration of the align domains.
    train_paired_traj_dataset = _build_paired_dataset(train_traj_datasets,
                                                      args)
    train_paired_traj_loader = _build_worker_loader(train_paired_traj_dataset,
                                                    args)
    val_paired_traj_dataset = _build_paired_dataset(val_traj_datasets, args)
    val_paired_traj_loader = _build_worker_loader(val_paired_traj_dataset,
                                                  args)

    if args.complex_task:
        # From here on args.domains refers to the adaptation domains.
        args.domains = args.adapt_domains
        read_env_config_yamls(args=args)

    test_traj_datasets = []
    for domain_info in args.domains:
        dataset, task_id_manager = _read_domain_dataset(args, domain_info)

        if args.complex_task:
            # Tag every trajectory with the last task ID of the dataset.
            task_id_manager.add_task_id_to_traj_dataset(
                dataset, task_id=dataset.n_task_id - 1)
        elif task_id_manager is not None:
            task_id_manager.add_task_id_to_traj_dataset(dataset)

        dataset = remove_single_step_trajectories(dataset)
        if not args.complex_task:
            dataset = filter_by_goal_id(dataset,
                                        goal_ids=[args.goal],
                                        task_id_manager=task_id_manager)
        dataset = select_n_trajectories(dataset, n_traj=args.test_n_traj)
        _convert_and_tag_dataset(dataset, domain_info, args)

        test_traj_datasets.append(dataset)

    test_paired_traj_dataset = _build_paired_dataset(test_traj_datasets, args)
    test_paired_traj_loader = DataLoader(test_paired_traj_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True)

    ### Model
    sample: Dict[str, np.ndarray] = train_paired_traj_loader.dataset[0]
    source_obs_dim = args.domains[0].obs_dim
    target_obs_dim = args.domains[1].obs_dim
    source_action_dim = args.domains[0].action_dim
    # Width of the encoder input: observations, plus actions when the
    # demonstration is given as state-action pairs.
    demo_dim = (source_obs_dim + source_action_dim
                if args.sa_demo else source_obs_dim)

    policy = TransformerPredictor(
        enc_in_dim=demo_dim,
        dec_in_dim=target_obs_dim,
        out_dim=sample['actions2'].shape[-1],
        embed_dim=args.latent_dim,
        n_head=args.n_head,
        ff_dim=args.hid_dim,
        n_enc_layers=args.n_enc_layers,
        n_dec_layers=args.n_dec_layers,
        image_observation=args.image_observation,
        image_state_dim=args.image_state_dim,
        coord_conv=args.use_coord_conv,
    ).to(args.device)

    look_ahead_mask = generate_square_subsequent_mask(sz=args.max_seq_len).to(
        args.device)

    optimizer = torch.optim.Adam(list(policy.parameters()), lr=args.lr)
    checkpointer = CheckPointer(n_epochs=args.epochs,
                                savedir_root=savedir_root)

    def process_batch(args,
                      batch,
                      policy,
                      optimizer,
                      scaler,
                      train: bool = False):
        """Run the policy on one batch and optionally take an optimizer step.

        Args:
            args: Run configuration (device, AMP and image options).
            batch: Mapping of batch tensors produced by the paired loaders.
            policy: Model to evaluate / optimize.
            optimizer: Optimizer stepped when ``train`` is true.
            scaler: ``GradScaler`` used for the backward pass and the step.
            train: When true, back-propagate, clip gradients and step.

        Returns:
            Dict of Python floats: ``main_loss`` and, when the image decoder
            is used, ``image_recon_loss``.
        """
        loss_dict = {}
        s_obs = batch['obs1'].to(args.device)[..., :demo_dim]
        t_obs = batch['obs2'].to(args.device)[..., :target_obs_dim]
        t_act = batch['actions2'].to(args.device)
        s_pad_mask = batch['seq_pad_masks1'].to(args.device)
        t_pad_mask = batch['seq_pad_masks2'].to(args.device)

        with torch.cuda.amp.autocast(enabled=args.amp):
            if args.image_observation:
                s_img = batch['image1'].to(args.device)
                t_img = batch['image2'].to(args.device)
                # Fold time into the batch axis so the encoders see one image
                # per row.
                s_img_reshaped = rearrange(s_img, 'b t h w c -> (b t) h w c')
                t_img_reshaped = rearrange(t_img, 'b t h w c -> (b t) h w c')
                s_img_encoded = policy.source_image_encoder(s_img_reshaped)
                t_img_encoded = policy.target_image_encoder(t_img_reshaped)
                # Overwrite the trailing image_state_dim features of each
                # observation with the encoded image.
                s_obs[...,
                      -args.image_state_dim:] = rearrange(s_img_encoded,
                                                          '(b t) d -> b t d',
                                                          b=s_obs.shape[0])
                t_obs[...,
                      -args.image_state_dim:] = rearrange(t_img_encoded,
                                                          '(b t) d -> b t d',
                                                          b=t_obs.shape[0])

                if args.use_image_decoder:
                    images_recon_s = policy.source_image_decoder(s_img_encoded)
                    images_recon_s = rearrange(images_recon_s,
                                               '(b t) h w c -> b t h w c',
                                               b=s_img.shape[0])
                    images_scaled_s = (s_img.float() / 255.).clamp(0, 1)
                    image_recon_loss_s = torch.nn.MSELoss()(images_recon_s,
                                                            images_scaled_s)

                    images_recon_t = policy.target_image_decoder(t_img_encoded)
                    images_recon_t = rearrange(images_recon_t,
                                               '(b t) h w c -> b t h w c',
                                               b=t_img.shape[0])
                    images_scaled_t = (t_img.float() / 255.).clamp(0, 1)
                    image_recon_loss_t = torch.nn.MSELoss()(images_recon_t,
                                                            images_scaled_t)

                    image_recon_loss = image_recon_loss_s + image_recon_loss_t

            out = policy(
                source_obs=s_obs,
                target_obs=t_obs,
                tgt_look_ahead_mask=look_ahead_mask,
                src_pad_mask=s_pad_mask,
                tgt_pad_mask=t_pad_mask,
            )
            # Only positions not flagged by the target padding mask contribute.
            loss = torch.nn.MSELoss()(out[~t_pad_mask], t_act[~t_pad_mask])

        # main_loss is recorded before the reconstruction term is added.
        loss_dict['main_loss'] = loss.item()
        if args.image_observation and args.use_image_decoder:
            loss_dict['image_recon_loss'] = image_recon_loss.item()
            loss = loss + image_recon_loss * args.image_recon_coef

        if train:
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            # Unscale first so clipping operates on the true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(policy.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        return loss_dict

    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    def run_epoch(loader, train: bool):
        """Process every batch of ``loader`` once; return per-key loss lists."""
        losses = defaultdict(list)
        for batch in tqdm(loader,
                          bar_format=TQDM_BAR_FORMAT,
                          total=len(loader)):
            if train:
                loss_dict = process_batch(args,
                                          batch=batch,
                                          policy=policy,
                                          optimizer=optimizer,
                                          scaler=scaler,
                                          train=True)
            else:
                with torch.no_grad():
                    loss_dict = process_batch(args,
                                              batch=batch,
                                              policy=policy,
                                              optimizer=optimizer,
                                              scaler=scaler,
                                              train=False)
            for key, val in loss_dict.items():
                losses[key].append(val)
        return losses

    # NOTE(review): policy.train()/policy.eval() is never toggled in this
    # function, so val/test losses are computed in whatever mode the module is
    # currently in -- confirm this is intended.
    for epoch in range(args.epochs):
        # Losses are collected per epoch. Accumulating them across epochs made
        # every summary (and the checkpoint criterion) a running mean over the
        # whole run instead of the metric of the current epoch.
        epoch_losses = {
            'train': run_epoch(train_paired_traj_loader, train=True),
            'val': run_epoch(val_paired_traj_loader, train=False),
            'test': run_epoch(test_paired_traj_loader, train=False),
        }

        logger.info(f'=== Epoch {epoch} Summary ===')
        metrics_dict = {}
        for split, losses in epoch_losses.items():
            for key, val in losses.items():
                mean_loss = np.array(val).mean()
                logger.info(f'{split}_{key}:\t{mean_loss:.5f}')
                metrics_dict[f'{split}_{key}'] = mean_loss

        checkpointer.save_if_necessary(policy, metrics_dict["val_main_loss"],
                                       epoch)

        if experiment:
            experiment.log_metrics(
                metrics_dict,
                epoch=epoch,
            )

        if (epoch + 1) % args.eval_interval == 0 or (epoch + 1) == args.epochs:
            shared_eval_kwargs = dict(
                args=args,
                policy=policy,
                epoch=epoch + 1,
                savedir_root=savedir_root,
                # TODO: do all domains share the same manager?
                task_id_manager=task_id_managers[0],
                experiment=experiment,
                skip_source=True,
            )
            eval_callback(
                goal_ids=args.train_goal_ids,
                log_prefix='align',
                traj_dataset=val_paired_traj_dataset,
                target_env=False,
                **shared_eval_kwargs,
            )
            eval_callback(
                goal_ids=[args.goal],
                log_prefix='adapt',
                traj_dataset=test_paired_traj_dataset,
                target_env=True,
                **shared_eval_kwargs,
            )


if __name__ == '__main__':
    # Limit the number of CPU threads PyTorch uses for intra-op parallelism.
    torch.set_num_threads(4)

    # Configuration is layered: structured defaults from ContextualConfig,
    # then the YAML file, then command-line overrides (last merge wins).
    args = OmegaConf.structured(ContextualConfig)
    cli_args = OmegaConf.from_cli()

    config_file_path = cli_args.get('config', 'common/cond/config/p2p.yaml')
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if Path(config_file_path).suffix != '.yaml':
        raise ValueError(
            f'Config file must have a .yaml suffix: {config_file_path}')
    file_args = OmegaConf.load(config_file_path)
    args = OmegaConf.merge(args, file_args)

    args: ContextualConfig = OmegaConf.merge(args, cli_args)
    read_env_config_yamls(args)

    # Every domain must declare the same number of goals.
    goal_counts = np.array(
        [domain_info.n_goals for domain_info in args.domains])
    if not np.all(goal_counts == goal_counts[0]):
        raise ValueError(
            f'All domains must have the same n_goals, got {goal_counts}.')

    goal_candidates = get_goal_candidates(
        n_goals=args.domains[0].n_goals
        if not args.multienv else args.n_task_ids,
        target_goal=args.goal,
        align=True,
        complex_task=args.complex_task,
        n_tasks=args.n_tasks,
        is_r2r='Lift' in args.domains[0]['env'],
    )
    args.train_goal_ids = goal_candidates
    hparams = OmegaConf.to_container(args)

    if args.comet:
        # The Comet project name is taken from the environment; a missing
        # COMET_PLP_PROJECT_NAME raises KeyError here.
        experiment = Experiment(
            project_name=os.environ['COMET_PLP_PROJECT_NAME'])
        experiment.set_name(args.name)
        experiment.log_parameters(hparams)
        experiment.add_tag('cond')
        experiment.add_tag(args.config.stem)  # env name
    else:
        experiment = None

    print(args)
    main(args, experiment, hparams)
