import os
import sys

import comet_ml
import yaml
from comet_ml import Experiment

# isort: split

import argparse
import logging
import random
import string
from pathlib import Path
from subprocess import call
from typing import Callable, Dict, Optional, Tuple

import d4rl
import gym
import numpy as np
import torch
from common.algorithm_main import epoch_loop
from common.models import Discriminator, Policy, ReconstructionDecoder
from torch.utils.data import DataLoader

from custom.utils import CheckPointer, create_savedir_root, prepare_dataset

# Module-level logger. Only the level is set here; no handler is configured
# in this file.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Task IDs selected by `_filter_by_id`. This is only a default: the
# `__main__` section at the bottom of this file rebinds it once the
# environment and the held-out goal are known.
TRAIN_TASK_IDS = [1, 2, 3, 4, 5, 6]


def trans_into_source_obs(original_obs, shift: bool = False):
    """Convert an observation into its source-domain counterpart.

    The last axis is re-ordered with the index list ``[1, 0, 3, 2]``
    (presumably swapping the x/y position and x/y velocity entries --
    confirm against the environment's observation layout).

    Args:
        original_obs: Array-like supporting ``obs[..., [i, j, k, l]]``
            indexing; the last axis must have at least 4 entries.
        shift: When True, 6 is subtracted in place from the first two
            entries of the re-ordered observation.

    Returns:
        The re-ordered (and optionally shifted) observation. For NumPy
        input the list index produces a copy, so ``original_obs`` itself is
        not modified.
    """
    permutation = [1, 0, 3, 2]
    converted = original_obs[..., permutation]
    if not shift:
        return converted

    converted[..., :2] -= 6
    return converted


def get_action_translator(action_type: str):
    """Build the function that maps an action into the source domain.

    Args:
        action_type: ``'normal'`` keeps the action unchanged; ``'inv'``
            negates it.

    Returns:
        A one-argument callable applying the selected mapping.

    Raises:
        AssertionError: If ``action_type`` is neither ``'normal'`` nor
            ``'inv'`` (only when assertions are enabled).
    """
    valid_types = ['normal', 'inv']
    assert action_type in valid_types, f'invalid action_type "{action_type}"'

    # Decide once, at construction time, whether the sign has to be flipped.
    flip_sign = action_type != 'normal'

    def trans_into_source_action(original_action):
        return -original_action if flip_sign else original_action

    return trans_into_source_action


def dataset_concat_fn(
    source_obs: np.ndarray,
    target_obs: np.ndarray,
    source_actions: np.ndarray,
    target_actions: np.ndarray,
    source_domain_id: np.ndarray,
    target_domain_id: np.ndarray,
    task_ids_onehot: np.ndarray,
    source_next_obs: np.ndarray,
    target_next_obs: np.ndarray,
    source_only: bool = False,
    target_only: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Assemble the per-sample arrays used to build the alignment dataset.

    Each domain ID is repeated once per observation of its domain and cast
    to ``float32``. Depending on the flags, the result contains only the
    source data, only the target data, or source followed by target.

    Args:
        source_obs, target_obs: Observations of each domain.
        source_actions, target_actions: Actions of each domain.
        source_domain_id, target_domain_id: Domain identifier of each
            domain; tiled to one row per observation.
        task_ids_onehot: Task conditioning array; reused for both domains
            (assumes source and target samples are aligned one-to-one --
            TODO confirm against the caller).
        source_next_obs, target_next_obs: Next observations of each domain.
        source_only: Return only the source-domain arrays. Takes precedence
            over ``target_only``.
        target_only: Return only the target-domain arrays.

    Returns:
        ``(obs, cond, domains, actions, next_obs)``.
    """

    def _repeat_domain_id(domain_id: np.ndarray, n_rows: int):
        # One copy of the domain ID per sample.
        return np.tile(domain_id, (n_rows, 1)).astype(np.float32)

    src_domains = _repeat_domain_id(source_domain_id, len(source_obs))
    tgt_domains = _repeat_domain_id(target_domain_id, len(target_obs))

    if source_only:
        return (source_obs, task_ids_onehot, src_domains, source_actions,
                source_next_obs)

    if target_only:
        return (target_obs, task_ids_onehot, tgt_domains, target_actions,
                target_next_obs)

    # Default: source samples first, target samples second, for every array.
    merged_obs = np.concatenate((source_obs, target_obs))
    merged_next_obs = np.concatenate((source_next_obs, target_next_obs))
    merged_cond = np.concatenate((task_ids_onehot, task_ids_onehot))
    merged_domains = np.concatenate((src_domains, tgt_domains))
    merged_actions = np.concatenate((source_actions, target_actions))
    return (merged_obs, merged_cond, merged_domains, merged_actions,
            merged_next_obs)


def _filter_by_id(task_ids: np.ndarray):
    """Return a boolean mask marking entries whose ID is a training task.

    Args:
        task_ids: Array of task IDs.

    Returns:
        Boolean array shaped like ``task_ids``; True where the ID equals
        one of the values currently stored in the module-level
        ``TRAIN_TASK_IDS``.
    """
    mask = np.zeros_like(task_ids, dtype=bool)
    for train_id in TRAIN_TASK_IDS:
        # Accumulate matches in place, one training ID at a time.
        np.bitwise_or(mask, task_ids == train_id, out=mask)
    return mask


def record_hparams(hparams: dict, savedir: Path):
    """Write the hyper-parameters to ``<savedir>/hparams.yaml``.

    Every key is stored with an ``align_`` prefix.

    Args:
        hparams: Mapping of hyper-parameter names to values.
        savedir: Directory that receives ``hparams.yaml``.
    """
    prefixed = {}
    for name, value in hparams.items():
        prefixed[f'align_{name}'] = value

    path = savedir / 'hparams.yaml'
    with open(path, 'w') as f:
        yaml.safe_dump(prefixed, f)

    logger.info(f'Hyper-parameters are saved as {path}.')


def _log_representation_plots(args, experiment, savedir_root: Path,
                              visualizer_path: Path, epoch: int) -> None:
    """Run the visualizer on the best checkpoint and log what it produced.

    The visualizer is executed as a subprocess. Afterwards every ``*.png``
    under ``<savedir_root>/best`` is logged as an image and every
    ``dist*.txt`` there is parsed as a single float and logged as a metric
    (both only when ``experiment`` is given; the distance is always written
    to the Python log).
    """
    best_dir = savedir_root / 'best'

    logger.info('Generate a representation plot.')
    cmd = [
        sys.executable,
        str(visualizer_path),
        '--n_traj',
        '100',
        '--load',
        str(best_dir / 'model.pt'),
        '--env',
        args.env,
        '--device',
        args.device,
    ]
    if args.action == 'inv':
        cmd += ['--action', 'inv']
    logger.info(f'call: {cmd}')
    call(cmd)
    if experiment is not None:
        for img_path in best_dir.glob('*.png'):
            experiment.log_image(img_path, step=epoch + 1)
    logger.info('Plots are saved.')

    for dist_txt_path in best_dir.glob('dist*.txt'):
        with open(dist_txt_path, 'r') as f:
            dist = float(f.read())
        dist_name = dist_txt_path.stem
        logger.info(f'DIST_INFO: {dist_name}: {dist}')
        if experiment is not None:
            experiment.log_metric(
                name=dist_name,
                value=dist,
                epoch=epoch,
            )


def _build_adapt_cmd(args, savedir_root: Path, use_comet: bool) -> list:
    """Build the command line that launches ``custom/adapt.py``.

    ``--comet`` and ``--name`` are only passed when ``use_comet`` is True.
    """
    cmd = [sys.executable, 'custom/adapt.py']
    if use_comet:
        cmd.append('--comet')
    cmd += [
        '--load',
        str(savedir_root / f'last_epoch{args.epochs:04d}' / 'model.pt'),
        '--device',
        args.device,
    ]
    if use_comet:
        cmd += ['--name', f'eval-{args.name}']
    cmd += [
        '--goal',
        str(args.goal),
        '--env',
        args.env,
    ]
    return cmd


def main(args, experiment: Optional[comet_ml.Experiment], hpamras: dict):
    """Run the alignment phase, then launch the adaptation script.

    Steps: create the save directory and store the hyper-parameters, copy
    the latent visualizer next to the checkpoints, build the data loaders,
    models and optimizers, train for ``args.epochs`` epochs (optionally
    plotting representations along the way) and finally call
    ``custom/adapt.py`` on the last checkpoint.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` section
            for the attributes that are read). ``args.name`` is filled with
            a random 10-character string when an experiment is given and no
            name was set.
        experiment: Comet experiment used for logging, or ``None`` to
            disable Comet logging.
        hpamras: Hyper-parameters to store in ``hparams.yaml``. The
            parameter name is a misspelling of ``hparams`` that is kept so
            existing keyword callers continue to work.
    """
    savedir_root = create_savedir_root(phase_tag='align',
                                       env_tag=args.env[:-3])
    # Use the argument itself. The previous code read a name ``hparams``
    # that is not a parameter of this function, so it only worked when a
    # module-level global of that name happened to exist (script execution).
    record_hparams(hparams=hpamras, savedir=savedir_root)
    cmd = [
        'cp',
        'custom/latent_visualizer.py',
        str(savedir_root),
    ]
    call(cmd)
    visualizer_path = savedir_root / 'latent_visualizer.py'

    if experiment:
        experiment.log_parameter('save_dir', str(savedir_root))

    model_dict: Dict[str, torch.nn.Module] = {}
    optimizer_dict: Dict[str, torch.optim.Optimizer] = {}
    dataloader_dict: Dict[str, DataLoader] = {}

    train_loader, val_loader, _ = prepare_dataset(
        args=args,
        filter_by_id_fn=_filter_by_id,
        trans_into_source_obs=trans_into_source_obs,
        trans_into_source_action=get_action_translator(
            action_type=args.action),
        dataset_concat_fn=dataset_concat_fn,
        task_id_zero=args.task_id_zero)
    dataloader_dict['train'] = train_loader
    dataloader_dict['validation'] = val_loader

    # One sample is enough to read the feature dimensions of each input.
    obs_sample, cond_sample, domain_sample, action_sample, _, _ = train_loader.dataset[
        0]
    policy = Policy(
        state_dim=obs_sample.shape[-1],
        cond_dim=cond_sample.shape[-1],
        domain_dim=domain_sample.shape[-1],
        latent_dim=args.latent_dim,
        hid_dim=args.hid_dim,
        out_dim=action_sample.shape[-1],
        num_hidden_layers=args.policy_num_layers,
        no_head_domain=args.no_head_domain,
        activation=args.activation,
        repr_activation=args.repr_activation,
        enc_sn=args.enc_sn,
        decode_with_state=args.decode_with_state,
    ).to(args.device)
    model_dict['policy'] = policy

    recon_decoder = ReconstructionDecoder(
        latent_dim=policy.latent_dim,
        state_dim=obs_sample.shape[-1],
        hid_dim=policy.hid_dim,
        activation=args.activation,
    ).to(args.device)
    model_dict['recon_decoder'] = recon_decoder
    opt_recon_decoder = torch.optim.Adam(recon_decoder.parameters(),
                                         lr=args.recon_decoder_lr)
    optimizer_dict['recon_decoder'] = opt_recon_decoder

    # With --target_adapt only the encoder and head parameters are
    # optimized; otherwise all policy parameters are.
    if args.target_adapt:
        opt_params = list(policy.encoder.parameters()) + list(
            policy.head.parameters())
    else:
        opt_params = list(policy.parameters())

    optimizer = torch.optim.Adam(opt_params, lr=args.lr)
    optimizer_dict['policy'] = optimizer

    # The discriminator and its optimizer are None when adversarial
    # training is disabled; both entries are still stored in the dicts.
    discriminator = Discriminator(
        latent_dim=policy.latent_dim,
        hid_dim=args.hid_dim,
        cond_dim=cond_sample.shape[-1],
        task_cond=args.task_cond,
        sa_disc=args.sa_disc,
        activation=args.activation,
        num_hidden_layer=args.disc_num_layers,
        sn=args.disc_sn,
    ).to(args.device) if args.adversarial else None
    optimizer_disc = torch.optim.Adam(
        discriminator.parameters(),
        lr=args.disc_lr) if args.adversarial else None
    model_dict['discriminator'] = discriminator
    optimizer_dict['discriminator'] = optimizer_disc

    # Number of epochs between plots. 0 means "first and last epoch only";
    # it is also used when --plot_times is 0 to avoid a division by zero.
    plot_interval = 0
    if not args.no_plot and args.plot_times:
        plot_interval = args.epochs // args.plot_times

    logger.info('Alignment phase started.')
    checkpointer = CheckPointer(n_epochs=args.epochs,
                                savedir_root=savedir_root)
    for epoch in range(args.epochs):
        val_loss = epoch_loop(args,
                              model_dict=model_dict,
                              optimizer_dict=optimizer_dict,
                              dataloader_dict=dataloader_dict,
                              epoch=epoch,
                              experiment=experiment,
                              log_prefix='align')
        checkpointer.save_if_necessary(policy=policy,
                                       val_loss=val_loss,
                                       epoch=epoch)

        if args.no_plot:
            continue

        is_first_epoch = epoch == 0
        is_last_epoch = (epoch + 1) == args.epochs
        on_interval = (plot_interval != 0
                       and (epoch + 1) % plot_interval == 0)
        if is_first_epoch or is_last_epoch or on_interval:
            _log_representation_plots(args=args,
                                      experiment=experiment,
                                      savedir_root=savedir_root,
                                      visualizer_path=visualizer_path,
                                      epoch=epoch)

    logger.info('Alignment phase finished!')

    use_comet = bool(experiment)
    if use_comet and not args.name:
        # The adaptation run is named after this one, so a name is needed.
        args.name = ''.join(
            random.choices(string.ascii_letters + string.digits, k=10))

    logger.info('Call adaptation script.')
    cmd = _build_adapt_cmd(args=args,
                           savedir_root=savedir_root,
                           use_comet=use_comet)
    logger.info(f'call: {cmd}')
    call(cmd)


if __name__ == '__main__':
    # Command-line entry point: parse the options, derive the dataset path
    # and the training task IDs from the environment, collect the
    # hyper-parameters and start the alignment phase.
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', type=str)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--hid_dim', type=int, default=256)
    parser.add_argument('--latent_dim', type=int, default=256)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--disc_lr', type=float, default=1e-3)
    parser.add_argument('--recon_decoder_lr', type=float, default=1e-3)
    parser.add_argument('--train_ratio', type=float, default=0.9)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--adapt_epochs', type=int, default=50)
    parser.add_argument('--plot_times', type=int, default=5)
    parser.add_argument('--n_tasks', type=int, default=-1)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--env', type=str, default='maze2d-medium-v1')
    parser.add_argument('--n_traj', type=int, default=None)
    parser.add_argument('--comet', action='store_true')
    parser.add_argument('--adversarial_coef', type=float, default=0.5)
    parser.add_argument('--sa_disc', action='store_true')
    # `choices` rejects an invalid action type at parse time instead of
    # deep inside dataset preparation.
    parser.add_argument('--action',
                        type=str,
                        default='normal',
                        choices=['normal', 'inv'],
                        help='"normal" or "inv"')
    parser.add_argument('--no_head_domain', action='store_true')
    parser.add_argument('--activation',
                        type=str,
                        default='relu',
                        help='activation. ["relu", "noact", "tanh", "ln"].')
    parser.add_argument(
        '--repr_activation',
        type=str,
        default='relu',
        help=
        'activation before representation. ["relu", "noact", "tanh", "ln"].')
    parser.add_argument('--enc_decay', type=float, default=0.)
    parser.add_argument('--shift', action='store_true')
    parser.add_argument('--enc_sn', action='store_true')
    parser.add_argument('--disc_sn', action='store_true')
    parser.add_argument('--source_only', action='store_true')
    parser.add_argument('--target_only', action='store_true')
    parser.add_argument('--target_adapt', action='store_true')
    parser.add_argument('--task_id_zero', action='store_true')
    parser.add_argument('--no_plot', action='store_true')
    parser.add_argument('--decode_with_state', action='store_true')
    parser.add_argument('--policy_num_layers',
                        type=int,
                        nargs=3,
                        default=[3, 5, 3])
    parser.add_argument('--disc_num_layers', type=int, default=4)
    parser.add_argument(
        '--goal',
        type=int,
        default=7,
        help="Goal ID of target task. It's not used in alignment phase.")
    args = parser.parse_args()

    # Reported through argparse rather than `assert`, which is stripped
    # when Python runs with -O.
    if args.source_only and args.target_only:
        parser.error('--source_only and --target_only cannot be combined.')

    # These two options are always enabled by this script.
    args.adversarial = True
    args.task_cond = True

    # env name -> (dataset path, number of task IDs; IDs run from 1 to N).
    env_specs = {
        'maze2d-umaze-v1': ('data/maze2d-umaze-sparse-v1.hdf5', 7),
        'maze2d-medium-v1': ('data/maze2d-medium-sparse-v1.hdf5', 26),
    }
    if args.env not in env_specs:
        raise ValueError(f'Invalid env {args.env} is specified.')

    args.dataset, n_task_ids = env_specs[args.env]
    args.calc_align_score = True

    # Every task except the held-out goal is a training-task candidate.
    candidate_ids = list(range(1, n_task_ids + 1))
    if args.goal not in candidate_ids:
        raise ValueError(
            f'--goal must be between 1 and {n_task_ids} for env '
            f'{args.env}, but {args.goal} was given.')
    candidate_ids.remove(args.goal)

    if args.n_tasks > 0:
        if args.n_tasks > len(candidate_ids):
            raise ValueError(
                f'--n_tasks must be at most {len(candidate_ids)} for env '
                f'{args.env}, but {args.n_tasks} was given.')
        candidate_ids = random.sample(candidate_ids, k=args.n_tasks)

    # Rebind the module-level list that `_filter_by_id` reads.
    TRAIN_TASK_IDS = candidate_ids
    logger.info(
        f'Train tasks ({len(TRAIN_TASK_IDS)} tasks) = {TRAIN_TASK_IDS}')

    hparams = {
        'env': args.env,
        'batch_size': args.batch_size,
        'hid_dim': args.hid_dim,
        'latent_dim': args.latent_dim,
        'lr': args.lr,
        'disc_lr': args.disc_lr,
        'epochs': args.epochs,
        'train_task_ids': TRAIN_TASK_IDS,
        'test_goal': args.goal,
        'n_traj': args.n_traj,
        'adversarial': args.adversarial,
        'action_type': args.action,
        'no_head_domain': args.no_head_domain,
        'task_cond': args.task_cond,
        'activation': args.activation,
        'repr_activation': args.repr_activation,
        'enc_decay': args.enc_decay,
        'n_tasks': args.n_tasks,
        'shift': args.shift,
        'enc_sn': args.enc_sn,
        'disc_sn': args.disc_sn,
        'source_only': args.source_only,
        'target_adapt': args.target_adapt,
        'calc_align_score': args.calc_align_score,
        'recon_decoder_lr': args.recon_decoder_lr,
        'sa_disc': args.sa_disc,
        'policy_num_layers': args.policy_num_layers,
        'adapt_epochs': args.adapt_epochs,
        'task_id_zero': args.task_id_zero,
        'decode_with_state': args.decode_with_state,
    }

    if args.adversarial:
        hparams['adversarial_coef'] = args.adversarial_coef

    if args.comet:
        experiment = Experiment(project_name='d4rl-cd')
        experiment.set_name(args.name)
        experiment.add_tags(['align', args.env[:-3], f'{args.action}-action'])
        experiment.log_parameters(hparams)
        if args.adversarial:
            experiment.add_tag('adversarial')
    else:
        experiment = None

    main(args, experiment, hparams)
