import string
import sys
from subprocess import call

import comet_ml
import yaml
from comet_ml import Experiment

# isort: split

import argparse
import logging
import random
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import d4rl
import gym
import numpy as np
import torch
from common.dail.dail import (DAILAgent, train_gama, train_source_policy,
                              train_target_dynamics_model)
from omegaconf import DictConfig

from custom.dail_utils import (eval_policy_dail, get_omega_args,
                               record_align_hparams)
from custom.utils import create_savedir_root, prepare_dataset

# Module-level logger; records below INFO are dropped by this logger's level.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Task IDs that `_filter_by_id` keeps. This is only the default value: the
# `__main__` section rebinds it from --env / --phase / --goal / --n_tasks.
SELECTED_TASK_IDS = [1, 2, 3, 4, 5, 6]


def trans_into_source_obs(original_obs, shift: bool = False):
    """Re-order an observation into the source-domain layout.

    The last axis is re-indexed with ``[1, 0, 3, 2]``: entries 0 and 1 trade
    places, and so do entries 2 and 3.  Any leading axes are left alone.

    Args:
        original_obs: Array-like supporting ``obs[..., index_list]`` whose
            last axis has at least four entries (NumPy arrays work; other
            tensor types presumably do too -- TODO confirm).
        shift: When true, 6 is subtracted from the first two entries of the
            re-ordered observation.

    Returns:
        The re-ordered (and optionally shifted) observation.  For NumPy input
        the integer-list indexing produces a new array, so the in-place shift
        does not touch ``original_obs``.
    """
    swap_order = [1, 0, 3, 2]
    source_obs = original_obs[..., swap_order]
    if not shift:
        return source_obs

    # Applied in place on the freshly indexed result.
    source_obs[..., :2] -= 6
    return source_obs


def get_action_translator(action_type: str):
    """Build the function that maps an action into the source domain.

    Args:
        action_type: ``'normal'`` leaves actions untouched; ``'inv'`` negates
            them.

    Returns:
        A one-argument callable applying the selected mapping.

    Raises:
        AssertionError: If ``action_type`` is neither ``'normal'`` nor
            ``'inv'``.
    """
    assert action_type in ['normal',
                           'inv'], f'invalid action_type "{action_type}"'

    # After the check above, anything that is not 'normal' must be 'inv'.
    negate = action_type != 'normal'

    def trans_into_source_action(original_action):
        return -original_action if negate else original_action

    return trans_into_source_action


def task_id_to_target_pos(
    goal_to_task_id: Dict,
    task_id: int,
) -> Tuple[int, int]:
    """Look up the goal position that is mapped to ``task_id``.

    Args:
        goal_to_task_id: Mapping from goal position to task ID.
        task_id: The task ID to search for among the mapping's values.

    Returns:
        The key whose value equals ``task_id``.  If several keys match, the
        one that comes last in the mapping's iteration order is returned.

    Raises:
        AssertionError: If no key maps to ``task_id``.
    """
    matches = [pos for pos, id_ in goal_to_task_id.items() if id_ == task_id]
    # The last match wins, mirroring a full scan that keeps overwriting.
    target_pos = matches[-1] if matches else None
    assert target_pos is not None

    return target_pos


def dataset_concat_fn(
    source_obs: np.ndarray,
    target_obs: np.ndarray,
    source_actions: np.ndarray,
    target_actions: np.ndarray,
    source_domain_id: np.ndarray,
    target_domain_id: np.ndarray,
    task_ids_onehot: np.ndarray,
    source_next_obs: np.ndarray,
    target_next_obs: np.ndarray,
    source_only: bool = False,
    target_only: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Assemble the dataset arrays from source- and target-domain data.

    Each domain ID is tiled into one float32 row per observation of its
    domain.  Depending on the flags, the result is the source data alone, the
    target data alone, or source followed by target concatenated along the
    first axis.  In the combined case ``task_ids_onehot`` is repeated twice,
    once for each domain.

    Args:
        source_obs, target_obs: Observations of each domain.
        source_actions, target_actions: Actions of each domain.
        source_domain_id, target_domain_id: Domain identifier of each domain.
        task_ids_onehot: Task condition array.
        source_next_obs, target_next_obs: Next observations of each domain.
        source_only: Return only the source-domain arrays.  Checked first,
            so it wins if both flags are set.
        target_only: Return only the target-domain arrays.

    Returns:
        ``(obs, cond, domains, actions, next_obs)``.
    """

    def _tile_domain(domain_id: np.ndarray, n_rows: int):
        # One copy of the domain ID per row, stored as float32.
        return np.tile(domain_id, (n_rows, 1)).astype(np.float32)

    # Both arrays are always built, even when only one of them is returned.
    source_domains = _tile_domain(source_domain_id, len(source_obs))
    target_domains = _tile_domain(target_domain_id, len(target_obs))

    if source_only:
        return (source_obs, task_ids_onehot, source_domains, source_actions,
                source_next_obs)

    if target_only:
        return (target_obs, task_ids_onehot, target_domains, target_actions,
                target_next_obs)

    return (
        np.concatenate((source_obs, target_obs)),
        np.concatenate((task_ids_onehot, task_ids_onehot)),
        np.concatenate((source_domains, target_domains)),
        np.concatenate((source_actions, target_actions)),
        np.concatenate((source_next_obs, target_next_obs)),
    )


def _filter_by_id(task_ids: np.ndarray):
    """Return a boolean mask of the entries listed in ``SELECTED_TASK_IDS``.

    Args:
        task_ids: Array of task IDs.

    Returns:
        A boolean array shaped like ``task_ids`` that is ``True`` wherever
        the ID equals one of the module-level ``SELECTED_TASK_IDS`` (read at
        call time, so later rebinding of that global is honoured).
    """
    return np.isin(task_ids, SELECTED_TASK_IDS)


def record_hparams(hparams: dict, savedir: Path):
    """Dump the hyper-parameters to ``<savedir>/hparams.yaml``.

    Every key is written with an ``align_`` prefix; the values are passed to
    ``yaml.safe_dump`` unchanged.

    Args:
        hparams: Hyper-parameter names mapped to their values.
        savedir: Directory that receives ``hparams.yaml``.
    """
    path = savedir / 'hparams.yaml'

    log_hparams = {}
    for key, val in hparams.items():
        log_hparams[f'align_{key}'] = val

    with path.open('w') as f:
        yaml.safe_dump(log_hparams, f)
    logger.info(f'Hyper-parameters are saved as {path}.')


def _run_adaptation_script(args, model_path: Path) -> int:
    """Run ``custom/dail.py`` in the ``adapt`` phase as a child process.

    The child receives the checkpoint path, device, goal and env of this run.
    With ``args.comet`` set it additionally receives ``--comet`` and a run
    name of the form ``eval-<name>``; when ``args.name`` is empty a random
    10-character name is generated and stored back on ``args``.

    Returns:
        The exit status of the child process.
    """
    # NOTE(review): --action and --shift are not forwarded, so the child runs
    # with their CLI defaults unless `record_align_hparams` restores them from
    # the saved hparams -- confirm against that helper.
    logger.info('Call adaptation script.')
    cmd = [sys.executable, 'custom/dail.py']
    if args.comet:
        args.name = args.name if args.name else ''.join(
            random.choices(string.ascii_letters + string.digits, k=10))
        cmd.append('--comet')
    cmd += ['--load', str(model_path), '--device', args.device]
    if args.comet:
        cmd += ['--name', f'eval-{args.name}']
    cmd += ['--goal', str(args.goal), '--phase', 'adapt', '--env', args.env]

    logger.info(f'call: {cmd}')
    returncode = call(cmd)
    if returncode != 0:
        # Report the failure instead of silently discarding the exit status.
        logger.error(f'Adaptation script exited with status {returncode}.')
    return returncode


def main(args, experiment: Optional[comet_ml.Experiment], hparams: dict,
         omega_args: DictConfig):
    """Run one DAIL phase: training (``align``) or evaluation (``adapt``).

    Steps, in order:

    1. Create the save directory, store it in ``omega_args.logdir`` and write
       ``hparams`` there.
    2. Build the data loaders via ``prepare_dataset``.
    3. Create a ``DAILAgent`` and, when ``args.load`` is set, load its weights
       from that checkpoint.
    4. Train the source policy if ``omega_args.train_source_policy``; in the
       ``align`` phase also train the target dynamics model and GAMA when the
       matching ``omega_args`` flags are set.
    5. ``align``: launch this script again in the ``adapt`` phase as a child
       process.  ``adapt``: evaluate the agent with ``eval_policy_dail`` for
       ``source_flag=True`` and ``source_flag=False`` and log the results.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` section).
        experiment: Comet experiment to log to, or ``None`` to skip logging.
        hparams: Hyper-parameters to record in the save directory.
        omega_args: Model/training configuration; ``logdir`` is overwritten.

    Raises:
        SystemExit: With status 1 if the checkpoint named by ``args.load``
            does not exist.
        AssertionError: In the ``adapt`` phase, if ``SELECTED_TASK_IDS`` does
            not hold exactly one ID.
    """
    savedir_root = create_savedir_root(phase_tag=f'dail-{args.phase}',
                                       env_tag=args.env[:-3])
    omega_args.logdir = str(savedir_root)
    record_hparams(hparams=hparams, savedir=savedir_root)

    if experiment:
        experiment.log_parameter('save_dir', str(savedir_root))

    # Task IDs are zero-cleared by `prepare_dataset` in the adapt phase.
    is_task_id_zero = args.phase == 'adapt'
    train_loader, val_loader, goal_to_task_id = prepare_dataset(
        args=args,
        filter_by_id_fn=_filter_by_id,
        trans_into_source_obs=trans_into_source_obs,
        trans_into_source_action=get_action_translator(
            action_type=args.action),
        dataset_concat_fn=dataset_concat_fn,
        task_id_zero=is_task_id_zero)

    if is_task_id_zero:
        logger.info('Task ID is zero-cleared.')

    # A dataset item has six fields; only the condition one is needed here,
    # to read off the task dimension for evaluation.
    _, cond_sample, _, _, _, _ = train_loader.dataset[0]

    logdir = Path(omega_args.logdir)
    model_path = logdir / 'model.pt'

    # Make model
    agent = DAILAgent(omega_args).to(args.device)
    if args.load:
        pretrained_model_path = Path(args.load)
        # Anything that is not a '.pt' file is treated as a directory that
        # contains 'model.pt'.
        if pretrained_model_path.suffix != '.pt':
            pretrained_model_path = pretrained_model_path / 'model.pt'

        if not pretrained_model_path.exists():
            logger.error(f'{pretrained_model_path} does not exist.')
            # A missing checkpoint is a failure: exit with a non-zero status.
            sys.exit(1)

        agent.load_state_dict(
            torch.load(pretrained_model_path, map_location=args.device))
        logger.info(f'Loaded pretrained model {pretrained_model_path}')

    if omega_args.train_source_policy:
        train_source_policy(args=args,
                            omega_args=omega_args,
                            agent=agent,
                            epochs=args.bc_epochs,
                            train_loader=train_loader,
                            model_path=model_path,
                            experiment=experiment)

    if args.phase == 'align' and omega_args.train_dynamics_model:
        train_target_dynamics_model(args=args,
                                    omega_args=omega_args,
                                    agent=agent,
                                    epochs=args.dyn_epochs,
                                    train_loader=train_loader,
                                    model_path=model_path,
                                    experiment=experiment)

    if args.phase == 'align' and omega_args.train_gama:
        train_gama(
            args=args,
            omega_args=omega_args,
            agent=agent,
            epochs=args.gama_epochs,
            train_loader=train_loader,
            val_loader=val_loader,
            model_path=model_path,
            experiment=experiment,
        )

    if args.phase == 'align':
        _run_adaptation_script(args, savedir_root / 'model.pt')

    if args.phase == 'adapt':
        env = gym.make(args.env)
        assert len(SELECTED_TASK_IDS) == 1
        target_pos = task_id_to_target_pos(
            goal_to_task_id=goal_to_task_id,
            task_id=SELECTED_TASK_IDS[0],
        )

        # Arguments shared by the source- and target-side evaluations.
        eval_kwargs = dict(
            env=env,
            policy=agent,
            device=args.device,
            source_trans_fn=trans_into_source_obs,
            source_action_type=args.action,
            task_dim=cond_sample.shape[-1],
            times=10,
            target_center=target_pos,
            render_episodes=args.render_episodes,
            experiment=experiment,
            shift=args.shift,
        )
        success_rate_source, steps_mean_source = eval_policy_dail(
            source_flag=True,
            video_path=savedir_root / 'source' / f'{args.gama_epochs}.gif',
            **eval_kwargs,
        )
        success_rate, steps_mean = eval_policy_dail(
            source_flag=False,
            video_path=savedir_root / 'target' / f'{args.gama_epochs}.gif',
            **eval_kwargs,
        )
        logger.info(f'Success rate (source): {success_rate_source * 100:.1f}% '
                    f'Step len (source): {steps_mean_source:.1f}')
        logger.info(
            f'Success rate: {success_rate * 100:.1f}% Step len: {steps_mean:.1f}'
        )
        if experiment:
            experiment.log_metric('source_success_rate', success_rate_source)
            # Source-side step count (the target value is logged separately).
            experiment.log_metric('source_ave_step', steps_mean_source)
            experiment.log_metric('success_rate', success_rate)
            experiment.log_metric('ave_step', steps_mean)


# Script entry point: parse the CLI, derive per-env settings and the set of
# task IDs, record hyper-parameters, build the config and hand over to `main`.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', type=str)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--train_ratio', type=float, default=0.9)
    parser.add_argument('--bc_epochs', type=int, default=10)
    parser.add_argument('--bc_epochs_adapt', type=int, default=50)
    parser.add_argument('--dyn_epochs', type=int, default=3)
    parser.add_argument('--gama_epochs', type=int, default=10)
    parser.add_argument('--plot_interval', type=int, default=5)
    parser.add_argument('--n_tasks', type=int, default=-1)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--env', type=str, default='maze2d-medium-v1')
    parser.add_argument('--load',
                        type=str,
                        default='',
                        help='dir path to the pre-trained model.')
    parser.add_argument('--n_traj', type=int, default=None)
    parser.add_argument('--comet', action='store_true')
    parser.add_argument('--adversarial_coef', type=float, default=0.5)
    parser.add_argument('--action',
                        type=str,
                        default='normal',
                        choices=['normal', 'inv'],
                        help='"normal" or "inv"')
    parser.add_argument('--task_cond', action='store_true')
    parser.add_argument('--shift', action='store_true')
    parser.add_argument('--source_only', action='store_true')
    parser.add_argument('--target_only', action='store_true')
    parser.add_argument('--phase',
                        type=str,
                        choices=['align', 'adapt'],
                        help='"align" or "adapt"',
                        default='align')
    parser.add_argument(
        '--goal',
        type=int,
        default=7,
        help="Goal ID of target task. It's not used in alignment phase.")
    parser.add_argument('--render_episodes', type=int, default=0)
    args = parser.parse_args()
    # Task conditioning is forced on, whatever --task_cond was given.
    args.task_cond = True

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if args.source_only and args.target_only:
        parser.error('--source_only and --target_only are mutually exclusive.')

    # Per-environment dataset file and number of task IDs.
    env_settings = {
        'maze2d-umaze-v1': ('data/maze2d-umaze-sparse-v1.hdf5', 7),
        'maze2d-medium-v1': ('data/maze2d-medium-sparse-v1.hdf5', 26),
    }
    if args.env not in env_settings:
        raise ValueError(f'Invalid env {args.env} is specified.')
    args.dataset, args.task_dim = env_settings[args.env]
    args.calc_align_score = True

    # Task IDs are 1-based: 1 .. task_dim.
    all_task_ids = list(range(1, args.task_dim + 1))
    if args.phase == 'align':
        # Alignment trains on every task except the held-out goal task.
        if args.goal not in all_task_ids:
            raise ValueError(
                f'--goal must be within 1..{args.task_dim} for {args.env}, '
                f'got {args.goal}.')
        SELECTED_TASK_IDS = [id_ for id_ in all_task_ids if id_ != args.goal]
    else:
        # Adaptation uses the goal task only.
        SELECTED_TASK_IDS = [args.goal]

    # NOTE(review): --n_tasks sub-sampling and this log line have only ever
    # applied to the medium maze; kept that way -- confirm whether the umaze
    # env should honour --n_tasks as well.
    if args.env == 'maze2d-medium-v1':
        if args.n_tasks > 0:
            SELECTED_TASK_IDS = random.sample(SELECTED_TASK_IDS,
                                              k=args.n_tasks)
        logger.info(
            f'Train tasks ({len(SELECTED_TASK_IDS)} tasks) = {SELECTED_TASK_IDS}'
        )

    hparams = {
        'env': args.env,
        'batch_size': args.batch_size,
        'bc_epochs': args.bc_epochs,
        'bc_epochs_adapt': args.bc_epochs_adapt,
        'dyn_epochs': args.dyn_epochs,
        'gama_epochs': args.gama_epochs,
        'train_task_ids': SELECTED_TASK_IDS,
        'n_traj': args.n_traj,
        'action_type': args.action,
        'task_cond': args.task_cond,
        'n_tasks': args.n_tasks,
        'shift': args.shift,
        'source_only': args.source_only,
        'adversarial_coef': args.adversarial_coef,
    }

    if args.comet:
        experiment = Experiment(project_name='d4rl-cd')
        experiment.set_name(args.name)
        experiment.add_tags(
            [f'dail-{args.phase}', args.env[:-3], f'{args.action}-action'])
        experiment.log_parameters(hparams)
    else:
        experiment = None

    omega_args = get_omega_args(phase=args.phase, source_env=args.env)
    # All three training stages are switched on; `main` additionally gates
    # the dynamics-model and GAMA stages on the align phase.
    omega_args.train_source_policy = True
    omega_args.train_dynamics_model = True
    omega_args.train_gama = True

    # Mirror the model dimensions from the config onto the CLI namespace.
    args.source_state_dim = omega_args.source_state_dim
    args.source_action_dim = omega_args.source_action_dim
    args.target_state_dim = omega_args.target_state_dim
    args.target_action_dim = omega_args.target_action_dim

    if args.phase == 'adapt':
        # TODO wrong hparam record
        record_align_hparams(Path(args.load).parent, experiment, args)
        # In the adapt phase the BC epoch count comes from --bc_epochs_adapt.
        args.bc_epochs = args.bc_epochs_adapt

    if args.task_cond:
        # The discriminator input is widened by the number of task IDs.
        omega_args.models.discriminator.in_dim += omega_args.num_task_ids

    main(args, experiment, hparams, omega_args)
