import argparse
import logging
from pathlib import Path
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from common.models import Policy
from sklearn.manifold import TSNE

from custom.utils import prepare_dataset

# Module-level logger. Its level is set to INFO here; no handler is attached
# in this file.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Task IDs that _filter_by_id() keeps. This default is rebound in the
# ``__main__`` block according to ``--env`` and ``--goal``.
TRAIN_TASK_IDS = [1, 2, 3, 4, 5, 6]


def trans_into_source_obs(original_obs, shift: bool = False):
    """Reorder the last axis of ``original_obs`` as ``[1, 0, 3, 2]``.

    Args:
        original_obs: Array-like supporting fancy indexing whose last axis
            has at least four entries.
        shift: When true, 6 is subtracted from the first two entries of the
            reordered last axis.

    Returns:
        The reordered (and optionally shifted) observations. Indexing with a
        list yields a new array, so ``original_obs`` itself is not modified.
    """
    swapped = original_obs[..., [1, 0, 3, 2]]
    if not shift:
        return swapped

    # In-place on the freshly indexed copy only.
    swapped[..., :2] -= 6
    return swapped


def get_action_translator(action_type: str):
    """Build the function that maps an action into the source domain.

    Args:
        action_type: ``'normal'`` to pass actions through unchanged, or
            ``'inv'`` to negate them.

    Returns:
        A one-argument callable applying the selected translation.

    Raises:
        ValueError: If ``action_type`` is neither ``'normal'`` nor ``'inv'``.
    """
    # Explicit raise instead of ``assert``: assertions are removed under
    # ``python -O``, which would let an unknown type silently negate actions.
    if action_type not in ('normal', 'inv'):
        raise ValueError(f'invalid action_type "{action_type}"')

    def trans_into_source_action(original_action):
        if action_type == 'normal':
            return original_action
        return -original_action

    return trans_into_source_action


def dataset_concat_fn(
    source_obs: np.ndarray,
    target_obs: np.ndarray,
    source_actions: np.ndarray,
    target_actions: np.ndarray,
    source_domain_id: np.ndarray,
    target_domain_id: np.ndarray,
    task_ids_onehot: np.ndarray,
    source_next_obs: np.ndarray,
    target_next_obs: np.ndarray,
    source_only: bool = False,
    target_only: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Assemble dataset arrays from the source and/or target domain.

    Each domain ID is tiled into one float32 row per observation of that
    domain. With neither flag set, the source arrays are followed by the
    target arrays along the first axis and ``task_ids_onehot`` is repeated
    twice. ``source_only`` takes precedence when both flags are true.

    Returns:
        ``(obs, cond, domains, actions, next_obs)``.
    """
    # Both domain arrays are built up front, independently of the flags.
    source_domain_array = np.tile(
        source_domain_id, (len(source_obs), 1)).astype(np.float32)
    target_domain_array = np.tile(
        target_domain_id, (len(target_obs), 1)).astype(np.float32)

    if source_only:
        return (source_obs, task_ids_onehot, source_domain_array,
                source_actions, source_next_obs)

    if target_only:
        return (target_obs, task_ids_onehot, target_domain_array,
                target_actions, target_next_obs)

    return (
        np.concatenate((source_obs, target_obs)),
        np.concatenate((task_ids_onehot, task_ids_onehot)),
        np.concatenate((source_domain_array, target_domain_array)),
        np.concatenate((source_actions, target_actions)),
        np.concatenate((source_next_obs, target_next_obs)),
    )


def _filter_by_id(task_ids: np.ndarray):
    """Return a boolean mask marking entries found in ``TRAIN_TASK_IDS``.

    The mask has the same shape as ``task_ids``; an element is true when it
    equals any ID currently listed in the module-level ``TRAIN_TASK_IDS``.
    """
    return np.isin(task_ids, TRAIN_TASK_IDS)


def _load_align_hparams(yaml_path: Path) -> dict:
    """Read the ``align_*`` hyper-parameters from ``yaml_path``.

    Returns a dict holding every key listed in ``defaults`` below, using the
    stored value when present and the default otherwise.
    """
    with open(yaml_path, 'r') as f:
        stored = yaml.safe_load(f)
    # An empty YAML document is parsed as None; treat it as "no overrides".
    stored = stored or {}

    defaults = {
        'align_hid_dim': 256,
        'align_latent_dim': 256,
        'align_no_head_domain': False,
        'align_activation': 'relu',
        'align_repr_activation': 'relu',
        'align_shift': False,
        'align_enc_sn': False,
        'align_policy_num_layers': [3, 5, 3],
        'align_decode_with_state': False,
    }
    return {key: stored.get(key, default) for key, default in defaults.items()}


def _encode_validation_pairs(policy, val_loader, device, action_type: str,
                             shift_flag: bool) -> dict:
    """Run ``policy`` on source states and on their target-domain mirror.

    For every validation batch, the samples whose domain ID is not the target
    one are kept, fed to the policy, then transformed into target-domain
    observations (axes swapped, optionally shifted by +6) and fed again with
    the flipped domain ID.

    Returns:
        A dict mapping ``'z'``, ``'z_a'`` and ``'action'`` to a
        ``(source_array, target_array)`` pair of row-aligned numpy arrays.
    """
    names = ('z', 'z_a', 'action')
    source_chunks = {name: [] for name in names}
    target_chunks = {name: [] for name in names}

    for states, _, domain_ids, _, _, _ in val_loader:
        states = states.to(device)
        domain_ids = domain_ids.to(device)

        with torch.no_grad():
            # [0, 1] is target domain where observation is normal
            target_state_flag = domain_ids[..., 1] > 0.5
            source_states = states[~target_state_flag]
            source_domain_ids = domain_ids[~target_state_flag]

            # Every sample is conditioned on the first task-ID slot.
            c = torch.zeros((len(source_domain_ids), policy.cond_dim),
                            device=device)
            c[..., 0] = 1.  # task ID

            source_a, source_z, source_alpha = policy(
                s=source_states,
                d=source_domain_ids,
                c=c,
            )
            if action_type == 'inv':
                source_a = -source_a

            # List indexing copies, so the in-place shift below does not
            # touch ``source_states``.
            target_states = source_states[..., [1, 0, 3, 2]]
            if shift_flag:
                shift = torch.tensor([6, 6, 0, 0],
                                     device=target_states.device,
                                     dtype=torch.float)
                target_states += shift
            target_domain_ids = source_domain_ids[..., [1, 0]]
            c = torch.zeros((len(target_domain_ids), policy.cond_dim),
                            device=device)
            c[..., 0] = 1.  # task ID

            target_a, target_z, target_alpha = policy(
                s=target_states,
                d=target_domain_ids,
                c=c,
            )

        batch_outputs = (
            ('z', source_z, target_z),
            ('z_a', source_alpha, target_alpha),
            ('action', source_a, target_a),
        )
        for name, source_out, target_out in batch_outputs:
            source_chunks[name].append(source_out.cpu().numpy())
            target_chunks[name].append(target_out.cpu().numpy())

    return {
        name: (np.vstack(source_chunks[name]), np.vstack(target_chunks[name]))
        for name in names
    }


def _plot_alignment(name: str, source_vector: np.ndarray,
                    target_vector: np.ndarray, img_root: Path, load_label,
                    show: bool) -> None:
    """Scatter-plot paired source/target vectors and save the figure.

    ``'action'`` vectors are plotted as they are with fixed ``[-1, 1]`` axis
    limits; any other ``name`` is first embedded in 2-D with t-SNE. The mean
    Euclidean distance between paired points is written to
    ``dist_{name}.txt`` and the figure to ``all-{name}.png`` in ``img_root``.
    """
    len_source = len(source_vector)
    stacked = np.concatenate((source_vector, target_vector))

    if name == 'action':
        z_transform = stacked  # for actions
        latent_min, latent_max = [-1, -1], [1, 1]
    else:
        # The iteration count is left at the library default (1000): the
        # keyword was renamed from ``n_iter`` to ``max_iter`` in
        # scikit-learn 1.5, so omitting it works on either side of the rename.
        tsne = TSNE(n_components=2, random_state=0, perplexity=30)
        z_transform = tsne.fit_transform(stacked)
        latent_max = z_transform.max(0)
        latent_min = z_transform.min(0)

    z_transform_source = z_transform[:len_source]
    z_transform_target = z_transform[len_source:]

    plt.clf()

    plt.scatter(*z_transform_source.T, label='source', marker='.')
    plt.scatter(*z_transform_target.T, label='target', marker='.')
    plt.xlim([latent_min[0], latent_max[0]])
    plt.ylim([latent_min[1], latent_max[1]])

    tsne_dist = np.linalg.norm(z_transform_source - z_transform_target,
                               axis=1).mean()
    with open(img_root / f'dist_{name}.txt', 'w') as f:
        f.write(str(tsne_dist))

    # Connect the first (up to) 10 source/target pairs; slicing plus zip
    # keeps this safe when fewer than 10 pairs are available.
    for first, second in zip(z_transform_source[:10],
                             z_transform_target[:10]):
        plt.plot((first[0], second[0]), (first[1], second[1]),
                 color='black',
                 lw=1)
    plt.legend()
    plt.title(f'(All) {name}: {load_label}')
    plt.tight_layout()
    plt.savefig(img_root / f'all-{name}.png')
    if show:
        plt.show()


def main(args):
    """Plot how a trained policy embeds paired source/target observations.

    The checkpoint at ``args.load`` is loaded into a ``Policy`` configured
    from ``hparams.yaml`` (looked up two directories above the checkpoint).
    The policy's ``z``, ``z_a`` and action outputs are collected for
    source-domain validation states and their target-domain counterparts,
    and one figure plus one distance file per output is written next to the
    checkpoint.

    Args:
        args: Parsed command-line namespace. ``args.shift`` is overwritten
            with the ``align_shift`` value from the YAML file before the
            namespace is passed to ``prepare_dataset``.
    """
    checkpoint_path = Path(args.load)
    img_root = checkpoint_path.parent

    yaml_path = img_root.parent / 'hparams.yaml'
    if not yaml_path.exists():
        # Logged as an error: no handler is attached to this logger in this
        # file, and logging's last-resort handler only emits WARNING and
        # above, so an INFO record would be dropped and the script would
        # exit without any message.
        logger.error(f'{yaml_path} not found.')
        return

    hparams = _load_align_hparams(yaml_path)
    shift_flag = hparams['align_shift']

    args.shift = shift_flag
    train_loader, val_loader, _ = prepare_dataset(
        args=args,
        filter_by_id_fn=_filter_by_id,
        trans_into_source_obs=trans_into_source_obs,
        trans_into_source_action=get_action_translator(
            action_type=args.action),
        dataset_concat_fn=dataset_concat_fn,
        task_id_zero=False)

    # One dataset item is enough to size the network inputs and output.
    obs_sample, cond_sample, domain_sample, action_sample, _, _ = (
        train_loader.dataset[0])
    policy = Policy(
        state_dim=obs_sample.shape[-1],
        cond_dim=cond_sample.shape[-1],
        domain_dim=domain_sample.shape[-1],
        latent_dim=hparams['align_latent_dim'],
        hid_dim=hparams['align_hid_dim'],
        out_dim=action_sample.shape[-1],
        no_head_domain=hparams['align_no_head_domain'],
        activation=hparams['align_activation'],
        repr_activation=hparams['align_repr_activation'],
        enc_sn=hparams['align_enc_sn'],
        num_hidden_layers=hparams['align_policy_num_layers'],
        decode_with_state=hparams['align_decode_with_state'],
    ).to(args.device)
    policy.load_state_dict(torch.load(args.load, map_location=args.device))

    paired_vectors = _encode_validation_pairs(
        policy=policy,
        val_loader=val_loader,
        device=args.device,
        action_type=args.action,
        shift_flag=shift_flag,
    )

    for name in ('z', 'z_a', 'action'):
        source_vector, target_vector = paired_vectors[name]
        _plot_alignment(
            name=name,
            source_vector=source_vector,
            target_vector=target_vector,
            img_root=img_root,
            load_label=args.load,
            show=args.show,
        )


if __name__ == '__main__':
    # Command-line entry point. Several options below are not read in this
    # file; the whole namespace is forwarded to prepare_dataset() via main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', type=str)
    parser.add_argument('--load',
                        type=str,
                        default='custom/test/best/model.pt')
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--train_ratio', type=float, default=0.9)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--env', type=str, default='maze2d-medium-v1')
    parser.add_argument('--n_traj', type=int, default=100)
    parser.add_argument('--comet', action='store_true')
    parser.add_argument('--adversarial', action='store_true')
    parser.add_argument('--prior', action='store_true')
    parser.add_argument('--goal',
                        type=int,
                        default=7,
                        help="Goal ID of target task.")
    parser.add_argument(
        '--prior_epochs',
        type=int,
        default=-1,
        help='# epochs for prior learning. It is set to args.epoch by default.'
    )
    parser.add_argument('--adversarial_coef', type=float, default=5.)
    parser.add_argument('--action',
                        type=str,
                        default='normal',
                        help='"normal" or "inv"')
    parser.add_argument('--no_head_domain', action='store_true')
    parser.add_argument('--show', action='store_true')
    args = parser.parse_args()

    # env name -> (dataset file, number of task IDs, numbered from 1)
    env_settings = {
        'maze2d-umaze-v1': ('data/maze2d-umaze-sparse-v1.hdf5', 7),
        'maze2d-medium-v1': ('data/maze2d-medium-sparse-v1.hdf5', 26),
    }
    if args.env not in env_settings:
        raise ValueError(f'Invalid env {args.env} is specified.')

    args.dataset, num_tasks = env_settings[args.env]
    if not 1 <= args.goal <= num_tasks:
        # Explicit check: removing a missing goal from the ID list would
        # otherwise fail with an uninformative ``list.remove`` error.
        raise ValueError(
            f'Invalid goal {args.goal} is specified for env {args.env}; '
            f'expected a value between 1 and {num_tasks}.')

    # Rebinds the module-level list read by _filter_by_id(): every task ID of
    # the environment except the held-out goal.
    TRAIN_TASK_IDS = [
        task_id for task_id in range(1, num_tasks + 1)
        if task_id != args.goal
    ]

    main(args)
