import os

import comet_ml
from comet_ml import Experiment

# isort: split
import argparse
import logging
import sys
from copy import deepcopy
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple

import d4rl
import gym
import numpy as np
import torch
import yaml
from common.algorithm_main import epoch_loop
from common.models import Policy
from torch.utils.data import DataLoader

from custom.utils import create_savedir_root, eval_policy, prepare_dataset

# Module logger; INFO so per-epoch evaluation results are emitted.
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
# Task (goal) ID used for filtering and evaluation; rebound from --goal when run as a script.
TEST_TASK_ID = 7


def trans_into_source_obs(original_obs, shift: bool = False):
    """Map a target-domain observation into the source-domain layout.

    The last axis is reordered as ``[1, 0, 3, 2]`` (entries 0<->1 and 2<->3
    are swapped; for maze2d this is presumably (x, y, vx, vy) — TODO confirm).
    With ``shift=True`` the first two reordered entries are decreased by 6.

    The reorder uses integer-list (advanced) indexing, which produces a new
    array for NumPy inputs, so the in-place shift never modifies
    ``original_obs``.
    """
    reordered = original_obs[..., [1, 0, 3, 2]]
    if not shift:
        return reordered
    reordered[..., :2] -= 6
    return reordered


def get_action_translator(action_type: str):
    """Build a function that converts target-domain actions to source-domain.

    Args:
        action_type: ``'normal'`` returns the action unchanged (same object);
            ``'inv'`` returns its negation.

    Returns:
        A one-argument callable applying the selected mapping.
    """
    assert action_type in ('normal',
                           'inv'), f'invalid action_type "{action_type}"'
    negate = action_type != 'normal'

    def trans_into_source_action(original_action):
        return -original_action if negate else original_action

    return trans_into_source_action


def dataset_concat_fn(
    source_obs: np.ndarray,
    target_obs: np.ndarray,
    source_actions: np.ndarray,
    target_actions: np.ndarray,
    source_domain_id: np.ndarray,
    target_domain_id: np.ndarray,
    task_ids_onehot: np.ndarray,
    source_next_obs: np.ndarray,
    target_next_obs: np.ndarray,
    source_only: bool = False,
    target_only: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Assemble the dataset arrays for the adapt phase from the source stream.

    Only the ``source_*`` inputs and ``task_ids_onehot`` are used; every
    ``target_*`` argument and ``target_only`` are accepted for signature
    compatibility with the dataset-preparation hook and ignored.
    ``source_only=True`` is rejected by an assertion.

    Returns:
        ``(obs, cond, domains, actions, next_obs)``. ``domains`` is
        ``source_domain_id`` tiled to one row per source observation and cast
        to float32. This presumes ``source_domain_id`` is a 1-D vector (e.g. a
        one-hot) — TODO confirm.
    """
    assert not source_only

    n_rows = len(source_obs)
    domains = np.tile(source_domain_id, (n_rows, 1)).astype(np.float32)
    return (
        source_obs,
        task_ids_onehot,
        domains,
        source_actions,
        source_next_obs,
    )


def task_id_to_target_pos(
    goal_to_task_id: Dict,
    task_id: int,
) -> Tuple[int, int]:
    """Return the goal position (dict key) whose mapped task ID is ``task_id``.

    If several positions map to the same ID, the last one in iteration order
    is returned.

    Raises:
        ValueError: if no position maps to ``task_id``. This used to be an
            ``assert``, which is stripped under ``python -O`` and would let
            ``None`` leak into evaluation as the target center.
    """
    target_pos = None
    for pos, id_ in goal_to_task_id.items():
        if id_ == task_id:
            target_pos = pos  # keep scanning: the last match wins
    if target_pos is None:
        raise ValueError(
            f'task_id {task_id} not found in goal_to_task_id')

    return target_pos


def _filter_by_id(task_ids: np.ndarray):
    """Element-wise mask that is True where ``task_ids`` equals ``TEST_TASK_ID``.

    Reads the module-level ``TEST_TASK_ID`` at call time, so a value rebound
    by the script entry point is honored.
    """
    return task_ids == TEST_TASK_ID


def eval_callback(
    policy: Policy,
    env: gym.Env,
    epoch: int,
    target_pos: np.array,
    device,
    experiment: Optional[comet_ml.Experiment],
    render_episodes: int = 0,
    savedir_root: Path = Path(),
    source_action_type: str = 'normal',
    shift: bool = False,
):
    """Evaluate ``policy`` twice (source setting, then target) and report it.

    Both runs call ``eval_policy`` with identical settings, except for two
    differences:

    - The first run passes ``source_flag=True`` and a video path under
      ``savedir_root / 'source'``.
    - The second run omits ``source_flag`` (``eval_policy``'s own default
      applies) and uses a path under ``savedir_root / 'target'``.

    Each ``eval_policy`` result is unpacked as (success rate, mean step
    length). Both pairs are logged, and they are sent to Comet as metrics when
    ``experiment`` is not None.
    """
    shared_kwargs = dict(
        env=env,
        policy=policy,
        device=device,
        source_trans_fn=trans_into_source_obs,
        source_action_type=source_action_type,
        target_center=target_pos,
        render_episodes=render_episodes,
        experiment=experiment,
        shift=shift,
    )
    src_rate, src_steps = eval_policy(
        source_flag=True,
        video_path=savedir_root / 'source' / f'{epoch}.gif',
        **shared_kwargs,
    )
    tgt_rate, tgt_steps = eval_policy(
        video_path=savedir_root / 'target' / f'{epoch}.gif',
        **shared_kwargs,
    )

    logger.info(
        f'Epoch: {epoch} Success rate (source): {src_rate * 100:.1f}% '
        f'Step len (source): {src_steps:.1f}')
    logger.info(
        f'Epoch: {epoch} Success rate: {tgt_rate * 100:.1f}% Step len: {tgt_steps:.1f}'
    )

    if experiment is None:
        return
    metrics = {
        'source_success_rate': src_rate,
        'source_steps_mean': src_steps,
        'success_rate': tgt_rate,
        'steps_mean': tgt_steps,
    }
    experiment.log_metrics(metrics, epoch=epoch)


def record_align_hparams(path: Path, experiment: Optional[comet_ml.Experiment],
                         args):
    """Copy alignment-phase settings from ``path / 'hparams.yaml'`` onto ``args``.

    The following attributes are always set: ``no_head_domain``,
    ``latent_dim``, ``hid_dim``, ``activation``, ``repr_activation``,
    ``shift``, ``enc_sn``, ``policy_num_layers`` and ``decode_with_state``.
    Each comes from the matching ``align_<attr>`` key when present, otherwise
    from a fixed default. ``args.action`` takes ``align_action_type`` when
    present. Otherwise it keeps any existing value, falling back to
    ``'normal'``.

    A missing or empty YAML file still applies the defaults. Previously the
    function returned early, and the caller then failed with
    ``AttributeError`` on these fields, since none of them is a CLI option.
    Comet parameters and tags are only recorded when the file was loaded.

    Args:
        path: directory expected to contain ``hparams.yaml``.
        experiment: optional Comet experiment for parameters and tags.
        args: namespace mutated in place.
    """
    yaml_path = path / 'hparams.yaml'
    if yaml_path.exists():
        with open(yaml_path, 'r') as f:
            # An empty YAML document loads as None; treat it as "no overrides".
            align_hparams = yaml.safe_load(f) or {}
        logger.info(f'Align hparams in {yaml_path} are loaded.')
        loaded = True
    else:
        logger.info(f'{yaml_path} not found')
        align_hparams = {}
        loaded = False

    args.action = align_hparams.get('align_action_type',
                                    getattr(args, 'action', 'normal'))

    # args attribute -> fallback used when the 'align_<attr>' key is absent.
    # Built per call so the list default is never shared between calls.
    defaults = {
        'no_head_domain': False,
        'latent_dim': 256,
        'hid_dim': 256,
        'activation': 'relu',
        'repr_activation': 'relu',
        'shift': False,
        'enc_sn': False,
        'policy_num_layers': [3, 5, 3],
        'decode_with_state': False,
    }
    for attr, fallback in defaults.items():
        setattr(args, attr, align_hparams.get(f'align_{attr}', fallback))

    if experiment and loaded:
        experiment.log_parameters(align_hparams)
        if align_hparams.get('align_adversarial'):
            experiment.add_tag('adversarial')

        if 'align_action_type' in align_hparams:
            experiment.add_tag(f'{align_hparams["align_action_type"]}-action')


def main(args, experiment: Optional[comet_ml.Experiment]):
    """Adapt a pretrained aligned policy to the test task and evaluate it.

    Builds the policy from the alignment-phase hyperparameters and loads the
    checkpoint at ``args.load``. It then trains for ``args.epochs`` epochs via
    ``epoch_loop``, with an optimizer that only covers
    ``policy.core.parameters()``. Evaluation happens before training, every
    ``args.eval_interval`` epochs (when positive), and after the last epoch.
    Every evaluation uses the same settings, including ``shift``. Previously
    the epoch-0 baseline omitted it.

    Args:
        args: parsed CLI namespace. Architecture fields such as ``action``,
            ``latent_dim`` and ``shift`` are expected to be filled in by
            ``record_align_hparams``.
        experiment: optional Comet experiment for logging.
    """
    savedir_root = create_savedir_root(phase_tag='adapt',
                                       env_tag=args.env[:-3])
    if experiment:
        experiment.log_parameter('save_dir', str(savedir_root))

    # The alignment run's hparams.yaml sits two levels above the checkpoint
    # (e.g. <run>/best/model.pt -> <run>/hparams.yaml).
    record_align_hparams(Path(args.load).parent.parent, experiment, args)

    model_dict: Dict[str, torch.nn.Module] = {}
    optimizer_dict: Dict[str, torch.optim.Optimizer] = {}
    dataloader_dict: Dict[str, DataLoader] = {}

    train_loader, val_loader, goal_to_task_id = prepare_dataset(
        args=args,
        filter_by_id_fn=_filter_by_id,
        trans_into_source_obs=trans_into_source_obs,
        trans_into_source_action=get_action_translator(
            action_type=args.action),
        dataset_concat_fn=dataset_concat_fn,
        task_id_zero=True)
    dataloader_dict['train'] = train_loader
    dataloader_dict['validation'] = val_loader

    # Layer sizes are inferred from one training sample; the last two items
    # of the sample are not needed here.
    obs_sample, cond_sample, domain_sample, action_sample, _, _ = \
        train_loader.dataset[0]
    policy = Policy(
        state_dim=obs_sample.shape[-1],
        cond_dim=cond_sample.shape[-1],
        domain_dim=domain_sample.shape[-1],
        latent_dim=args.latent_dim,
        hid_dim=args.hid_dim,
        out_dim=action_sample.shape[-1],
        num_hidden_layers=args.policy_num_layers,
        no_head_domain=args.no_head_domain,
        activation=args.activation,
        repr_activation=args.repr_activation,
        enc_sn=args.enc_sn,
        decode_with_state=args.decode_with_state,
    ).to(args.device)
    model_dict['policy'] = policy

    # Only policy.core is handed to the optimizer, so encoder/decoder weights
    # are not updated by it ("fix encoder and decoder").
    optimizer = torch.optim.Adam(policy.core.parameters(), lr=args.lr)
    optimizer_dict['policy'] = optimizer

    # load_state_dict copies into the existing parameter tensors, so building
    # the optimizer first is fine.
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    policy.load_state_dict(torch.load(args.load, map_location=args.device))

    target_pos = task_id_to_target_pos(
        goal_to_task_id=goal_to_task_id,
        task_id=TEST_TASK_ID,
    )

    env = gym.make(args.env)

    def _evaluate(epoch: int) -> None:
        # Single call site so every evaluation uses identical settings.
        eval_callback(
            policy,
            env,
            epoch,
            target_pos,
            args.device,
            experiment,
            render_episodes=args.render_episodes,
            savedir_root=savedir_root,
            source_action_type=args.action,
            shift=args.shift,
        )

    try:
        _evaluate(0)

        for epoch in range(args.epochs):
            epoch_loop(
                args=args,
                model_dict=model_dict,
                optimizer_dict=optimizer_dict,
                dataloader_dict=dataloader_dict,
                epoch=epoch,
                experiment=experiment,
            )

            completed = epoch + 1
            on_interval = (args.eval_interval > 0
                           and completed % args.eval_interval == 0)
            if on_interval or completed == args.epochs:
                _evaluate(completed)
    finally:
        env.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', type=str)
    parser.add_argument(
        '--load',
        type=str,
        default=
        'custom/results/maze2d-umaze/align/20220709-032430/best/model.pt')
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--train_ratio', type=float, default=0.9)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--eval_interval', type=int, default=20)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--dataset',
                        type=str,
                        default='data/maze2d-umaze-sparse-v1.hdf5')
    parser.add_argument('--n_traj', type=int, default=None)
    parser.add_argument('--env', type=str, default='maze2d-medium-v1')
    parser.add_argument('--render_episodes', type=int, default=0)
    parser.add_argument('--comet', action='store_true')
    parser.add_argument('--goal',
                        type=int,
                        default=7,
                        help="Goal ID of target task.")
    args = parser.parse_args()

    # Validate with parser.error rather than assert, which is skipped under -O.
    if not Path(args.load).exists():
        parser.error(f'{args.load} does not exist.')

    # Each supported env is tied to its offline dataset. This overrides any
    # --dataset passed on the command line.
    env_to_dataset = {
        'maze2d-umaze-v1': 'data/maze2d-umaze-sparse-v1.hdf5',
        'maze2d-medium-v1': 'data/maze2d-medium-sparse-v1.hdf5',
    }
    if args.env not in env_to_dataset:
        raise ValueError(f'Invalid env {args.env} is specified.')
    args.dataset = env_to_dataset[args.env]
    args.calc_align_score = True
    # Module-level rebinding: _filter_by_id and main() read this global.
    TEST_TASK_ID = args.goal

    hparams = {
        'load': args.load,
        'batch_size': args.batch_size,
        'lr': args.lr,
        'epochs': args.epochs,
        'test_task_id': TEST_TASK_ID,
    }

    if args.comet:
        experiment = Experiment(project_name='d4rl-cd')
        experiment.set_name(args.name)
        experiment.add_tags(['adapt', args.env[:-3]])
        experiment.log_parameters(hparams)
    else:
        experiment = None

    main(args, experiment)
