import traceback
from pathlib import Path
from typing import Optional, Union

import matplotlib.collections as mc
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.data as data
from comet_ml import ExistingExperiment, Experiment
from common.models import Policy
from omegaconf import DictConfig
from sklearn.manifold import TSNE

from .dataset_utils import get_dataset
from .datasets import TrajectoryDataset
from .trans_fn import get_trans_observations_fns

# One integer per morphology name ("ant", "point", "maze2d"). Judging by the
# constant's name these are presumably state/observation dimensions -- TODO
# confirm. Not referenced by the code visible in this part of the file.
STATE_DIM_DICT = {"ant": 29, "point": 6, "maze2d": 4}


def load_policy(args, model_path):
    """Construct a ``Policy`` from ``args.policy`` and load its weights.

    Args:
        args: Configuration object providing ``device`` and a ``policy``
            sub-config with the constructor hyper-parameters read below.
        model_path: Location of the saved state dict, passed to
            ``torch.load``.

    Returns:
        The ``Policy`` instance, moved to ``args.device`` with the stored
        state dict loaded.
    """
    cfg = args.policy

    # ``decode_with_state`` is an optional key; default to False when absent.
    decode_with_state = getattr(cfg, "decode_with_state", False)

    policy_kwargs = dict(
        state_dim=cfg.state_dim,
        cond_dim=cfg.cond_dim,
        out_dim=cfg.out_dim,
        domain_dim=cfg.domain_dim,
        latent_dim=cfg.latent_dim,
        hid_dim=cfg.hid_dim,
        num_hidden_layers=cfg.num_hidden_layers,
        activation=cfg.activation,
        repr_activation=cfg.latent_activation,
        enc_sn=cfg.spectral_norm,
        decode_with_state=decode_with_state,
    )
    policy = Policy(**policy_kwargs).to(args.device)

    # map_location keeps the checkpoint loadable on a device other than the
    # one it was saved from.
    weights = torch.load(model_path, map_location=args.device)
    policy.load_state_dict(weights)
    return policy


def make_batches(args):
    """Draw one shuffled batch from the source and the target dataset.

    Both datasets are loaded with ``get_dataset``, wrapped in
    ``TrajectoryDataset`` and served by a shuffling ``DataLoader`` with
    ``batch_size=args.latents_size`` and four worker processes. Only the
    first batch of each loader is consumed.

    Args:
        args: Configuration object; the dataset paths, task ids, reverse
            flags, domain ids, ``max_dataset_size``, ``latents_size`` and the
            ``policy`` dimensions are read from it.

    Returns:
        A ``(source_batch, target_batch)`` tuple.
    """
    # Raw datasets, source first and then target.
    raw_source = get_dataset(
        dataset_path=args.source_dataset,
        task_ids=args.task_ids,
        transform_observations=args.reverse_source_observations,
        transform_actions=args.reverse_source_actions,
    )
    raw_target = get_dataset(
        dataset_path=args.target_dataset,
        task_ids=args.task_ids,
        transform_observations=args.reverse_target_observations,
        transform_actions=args.reverse_target_actions,
    )

    # Settings shared by both TrajectoryDataset wrappers.
    wrapper_kwargs = dict(
        num_task_ids=args.policy.cond_dim,
        max_size=args.max_dataset_size,
        state_dim=args.policy.state_dim,
        domain_dim=args.policy.domain_dim,
    )
    source_trajectories = TrajectoryDataset(
        dataset=raw_source,
        domain_id=args.source_domain_id,
        **wrapper_kwargs,
    )
    target_trajectories = TrajectoryDataset(
        dataset=raw_target,
        domain_id=args.target_domain_id,
        **wrapper_kwargs,
    )

    # Settings shared by both loaders.
    loader_kwargs = dict(
        batch_size=args.latents_size,
        shuffle=True,
        num_workers=4,
    )
    source_loader = data.DataLoader(source_trajectories, **loader_kwargs)
    target_loader = data.DataLoader(target_trajectories, **loader_kwargs)

    # Pull exactly one batch from each loader, in lockstep.
    paired_batches = zip(source_loader, target_loader)
    first_source, first_target = next(paired_batches)
    return first_source, first_target


def run_inference(args, policy, batch):
    """Run ``policy`` on one batch and return everything as NumPy arrays.

    Args:
        args: Configuration object; only ``args.device`` is read.
        policy: Callable taking ``(states, task_ids, domain_ids)`` and
            returning ``(actions_pred, z, alpha)`` tensors.
        batch: Six-element sequence whose first four entries are the states,
            task ids, domain ids and actions tensors; the last two entries
            are ignored.

    Returns:
        Tuple ``(states, task_ids, domain_ids, actions, actions_pred, z,
        alpha)`` of NumPy arrays. ``actions_pred`` is cut to the same number
        of columns as ``actions``.
    """
    states, task_ids, domain_ids, actions, _unused_a, _unused_b = batch
    device = args.device

    def _as_numpy(tensor):
        return tensor.cpu().numpy()

    with torch.inference_mode():
        predicted, latent_z, latent_alpha = policy(
            states.to(device),
            task_ids.to(device),
            domain_ids.to(device),
        )

    actions_np = _as_numpy(actions)
    # Keep only as many predicted columns as the ground-truth actions have.
    num_action_cols = actions_np.shape[1]
    predicted_np = _as_numpy(predicted[:, :num_action_cols])

    return (
        _as_numpy(states),
        _as_numpy(task_ids),
        _as_numpy(domain_ids),
        actions_np,
        predicted_np,
        _as_numpy(latent_z),
        _as_numpy(latent_alpha),
    )


def visualize_latents(
    args: DictConfig,
    experiment: Optional[Union[Experiment, ExistingExperiment]],
    model_path: Union[str, Path],
    image_dir: Path,
    epoch: int,
    prefix: str = "",
):
    """Plot states, predicted actions and t-SNE-embedded latents for one batch.

    Loads the policy stored at ``model_path``, draws one batch from the source
    and one from the target dataset, runs the policy on both and writes PNG
    files into ``image_dir``:

    * state and predicted-action scatter plots (via ``visualize_states`` and
      ``visualize_actions``),
    * ``{prefix}_z_all*.png`` / ``{prefix}_alpha_all*.png``: the 2-D t-SNE
      embedding of ``z`` / ``alpha`` with lines from up to 30 source points
      to the target point that is closest in xy position,
    * only when ``"medium"`` occurs in ``args.source_env_id``:
      ``{prefix}_z_domain*.png`` / ``{prefix}_alpha_domain*.png``, a 3x3 grid
      of the same embedding restricted to square xy regions.

    Every latent image is saved twice, once with the zero-padded epoch in its
    name and once under a fixed "latest" name.

    Args:
        args: Run configuration. Besides what ``load_policy``,
            ``make_batches`` and ``run_inference`` read, this function uses
            ``reverse_source_observations`` / ``reverse_target_observations``,
            ``source_env_id`` / ``target_env_id``, ``source_state_dim`` /
            ``target_state_dim`` and ``target_morph``.
        experiment: Optional Comet experiment. When truthy, the "latest" copy
            of each latent image is uploaded with ``log_asset``.
        model_path: Location of the saved policy state dict.
        image_dir: Existing directory that receives the PNG files.
        epoch: Epoch number used in titles, file names and as logging step.
        prefix: String prepended to every file name.
    """
    policy = load_policy(args, model_path)
    # Inference only: switch off train-mode behaviour in case the policy
    # contains layers that act differently while training.
    policy.eval()
    source_batch, target_batch = make_batches(args)

    (source_states, _, _, _, source_actions_pred, source_z,
     source_alpha) = run_inference(args, policy, source_batch)
    (target_states, _, _, _, target_actions_pred, target_z,
     target_alpha) = run_inference(args, policy, target_batch)

    if args.reverse_source_observations:
        _, inv_trans_observation_fn = get_trans_observations_fns(
            args.source_env_id)
        source_states = inv_trans_observation_fn(
            source_states[:, :args.source_state_dim])
    if args.reverse_target_observations:
        _, inv_trans_observation_fn = get_trans_observations_fns(
            args.target_env_id)
        target_states = inv_trans_observation_fn(
            target_states[:, :args.target_state_dim])

    # Shuffle the rows of each domain consistently across its arrays.
    # np.random.shuffle works in place and returns None, so an explicit
    # permutation is drawn and every array is rebound to its reordered copy.
    # The two domains get independent permutations because their batches may
    # differ in length.
    source_order = np.random.permutation(len(source_states))
    source_states = source_states[source_order]
    source_actions_pred = source_actions_pred[source_order]
    source_z = source_z[source_order]
    source_alpha = source_alpha[source_order]

    target_order = np.random.permutation(len(target_states))
    target_states = target_states[target_order]
    target_actions_pred = target_actions_pred[target_order]
    target_z = target_z[target_order]
    target_alpha = target_alpha[target_order]

    visualize_states(
        args,
        experiment,
        image_dir,
        epoch,
        prefix,
        source_states,
        target_states,
    )

    visualize_actions(
        args,
        experiment,
        image_dir,
        epoch,
        prefix,
        source_actions_pred,
        target_actions_pred,
    )

    source_z_latent, target_z_latent = transform_latents(source_z, target_z)
    source_alpha_latent, target_alpha_latent = transform_latents(
        source_alpha, target_alpha)

    # The first two entries of an observation are assumed to be the xy
    # position.
    source_xys = source_states[:, :2]
    # Copy so the coordinate change below cannot write into target_states.
    target_xys = target_states[:, :2].copy()
    if args.target_morph == "maze2d":
        # Swap the two axes, then shift by 1 and scale by 4.
        target_xys = (target_xys[:, [1, 0]] - 1) * 4

    # Whole-batch plots: connect each of the first `num_linked` source points
    # to the target point closest to it in xy space.
    num_linked = 30
    z_lines = _correspondence_lines(
        source_xys[:num_linked],
        target_xys,
        source_z_latent[:num_linked],
        target_z_latent,
        k=1,
    )
    _plot_all_latents(
        experiment,
        image_dir,
        epoch,
        title=f"z_all, epoch:{epoch}",
        epoch_name=f"{prefix}_z_all_{epoch:03d}.png",
        latest_name=f"{prefix}_z_all.png",
        source_latents=source_z_latent,
        target_latents=target_z_latent,
        lines=z_lines,
    )

    alpha_lines = _correspondence_lines(
        source_xys[:num_linked],
        target_xys,
        source_alpha_latent[:num_linked],
        target_alpha_latent,
        k=1,
    )
    _plot_all_latents(
        experiment,
        image_dir,
        epoch,
        title=f"z_a_all, epoch:{epoch}",
        epoch_name=f"{prefix}_alpha_all_{epoch:03d}.png",
        latest_name=f"{prefix}_alpha_all.png",
        source_latents=source_alpha_latent,
        target_latents=target_alpha_latent,
        lines=alpha_lines,
    )

    if "medium" in args.source_env_id:
        # Per-region plots of z: one nearest target per source point.
        _plot_region_grid(
            experiment,
            image_dir,
            epoch,
            suptitle="z",
            # The per-epoch file uses underscores throughout, like the
            # "latest" file and the alpha files.
            epoch_name=f"{prefix}_z_domain_{epoch:03d}.png",
            latest_name=f"{prefix}_z_domain.png",
            source_xys=source_xys,
            target_xys=target_xys,
            source_latents=source_z_latent,
            target_latents=target_z_latent,
            k=1,
        )

        # Per-region plots of alpha: two nearest targets per source point.
        _plot_region_grid(
            experiment,
            image_dir,
            epoch,
            suptitle="z_a",
            epoch_name=f"{prefix}_alpha_domain_{epoch:03d}.png",
            latest_name=f"{prefix}_alpha_domain.png",
            source_xys=source_xys,
            target_xys=target_xys,
            source_latents=source_alpha_latent,
            target_latents=target_alpha_latent,
            k=2,
        )


def _correspondence_lines(source_xys, target_xys, source_latents,
                          target_latents, k):
    """Link each source latent to the latents of its k xy-nearest targets.

    ``source_xys[i]`` and ``source_latents[i]`` describe the same point, as do
    ``target_xys[j]`` and ``target_latents[j]``. Fewer than ``k`` links are
    produced for a source point when fewer targets exist, so empty or small
    inputs yield a shorter (possibly empty) list instead of an error.
    """
    lines = []
    for xy, source_latent in zip(source_xys, source_latents):
        squared_dists = np.sum((target_xys - xy)**2, axis=1)
        for target_index in np.argsort(squared_dists)[:k]:
            lines.append([source_latent, target_latents[target_index]])
    return lines


def _draw_latents(ax, source_latents, target_latents, lines):
    """Scatter both embeddings on ``ax`` and overlay the link segments."""
    ax.scatter(*source_latents.T, label='source', marker='.', alpha=0.5)
    ax.scatter(*target_latents.T, label='target', marker='.', alpha=0.5)
    ax.add_collection(mc.LineCollection(lines, linewidths=1, color="k"))


def _save_and_log(fig, experiment, image_dir, epoch, epoch_name, latest_name):
    """Save ``fig`` under both names and log the latest copy if possible."""
    fig.tight_layout()
    fig.savefig(image_dir / epoch_name)
    fig.savefig(image_dir / latest_name)
    if experiment:
        experiment.log_asset(image_dir / latest_name, step=epoch)


def _plot_all_latents(experiment, image_dir, epoch, title, epoch_name,
                      latest_name, source_latents, target_latents, lines):
    """Write the whole-batch embedding plot for one latent variable."""
    fig = plt.figure(figsize=(8, 6), dpi=120)
    try:
        ax = fig.add_subplot()
        _draw_latents(ax, source_latents, target_latents, lines)
        ax.legend()
        ax.set_title(title)
        _save_and_log(fig, experiment, image_dir, epoch, epoch_name,
                      latest_name)
    finally:
        # Release the figure; this function may run once per epoch.
        plt.close(fig)


def _plot_region_grid(experiment, image_dir, epoch, suptitle, epoch_name,
                      latest_name, source_xys, target_xys, source_latents,
                      target_latents, k, num_linked=5, scale=4, offset=0):
    """Write a 3x3 grid of embedding plots, one per square xy region.

    Region ``(i, j)`` spans ``((i - 0.25) * 2 * scale, (i + 0.75) * 2 * scale]``
    in x and the analogous interval in y (plus ``offset``). Within a region,
    up to ``num_linked`` source points are linked to their ``k`` xy-nearest
    target points of the same region.
    """
    fig = plt.figure(figsize=(8, 8), dpi=120)
    try:
        for i in range(3):
            for j in range(3):
                x_min = (i - 0.25) * 2 * scale + offset
                x_max = (i + 0.75) * 2 * scale + offset
                y_min = (j - 0.25) * 2 * scale + offset
                y_max = (j + 0.75) * 2 * scale + offset

                source_mask = ((x_min < source_xys[:, 0])
                               & (source_xys[:, 0] <= x_max)
                               & (y_min < source_xys[:, 1])
                               & (source_xys[:, 1] <= y_max))
                target_mask = ((x_min < target_xys[:, 0])
                               & (target_xys[:, 0] <= x_max)
                               & (y_min < target_xys[:, 1])
                               & (target_xys[:, 1] <= y_max))

                region_source_latents = source_latents[source_mask]
                region_target_latents = target_latents[target_mask]
                lines = _correspondence_lines(
                    source_xys[source_mask][:num_linked],
                    target_xys[target_mask],
                    region_source_latents[:num_linked],
                    region_target_latents,
                    k,
                )

                # Row 3 of the grid is j == 0, so y increases upwards.
                ax = fig.add_subplot(3, 3, 7 + i - 3 * j)
                ax.set_title(f"[{x_min}, {x_max}], [{y_min}, {y_max}]")
                _draw_latents(ax, region_source_latents,
                              region_target_latents, lines)
                ax.set_xlim(-100, 100)
                ax.set_ylim(-100, 100)

                if i == 2 and j == 2:
                    ax.legend()

        fig.suptitle(suptitle)
        _save_and_log(fig, experiment, image_dir, epoch, epoch_name,
                      latest_name)
    finally:
        plt.close(fig)


def visualize_actions(args, experiment, image_dir, epoch, prefix,
                      actions_source, actions_target):
    """Save scatter plots of the first two action dimensions per domain.

    Writes ``{prefix}_actions_source.png`` and ``{prefix}_actions_target.png``
    into ``image_dir`` and, when ``experiment`` is truthy, logs both files
    with ``log_asset``.

    Args:
        args: Configuration; ``source_env_id`` and ``target_env_id`` are used
            in the plot titles.
        experiment: Optional experiment object providing ``log_asset``.
        image_dir: Directory (path object) receiving the images.
        epoch: Epoch number shown in the titles and used as logging step.
        prefix: String prepended to the file names.
        actions_source: 2-D array of source actions; columns 0 and 1 are
            plotted.
        actions_target: 2-D array of target actions; columns 0 and 1 are
            plotted.
    """
    panels = (
        (
            actions_source,
            f"source actions, {args.source_env_id}, epoch:{epoch}",
            image_dir / f"{prefix}_actions_source.png",
        ),
        (
            actions_target,
            f"target actions, {args.target_env_id}, epoch:{epoch}",
            image_dir / f"{prefix}_actions_target.png",
        ),
    )

    for points, title, path in panels:
        fig = plt.figure(figsize=(6, 6), dpi=120)
        try:
            ax = fig.add_subplot()
            ax.grid()
            ax.scatter(*points[:, :2].T, marker=".")
            ax.set_title(title)
            fig.tight_layout()
            fig.savefig(path)
        finally:
            # Close the figure so repeated calls do not pile up open figures.
            plt.close(fig)

    if experiment:
        for _, _, path in panels:
            experiment.log_asset(path, step=epoch)


def visualize_states(args, experiment, image_dir, epoch, prefix, states_source,
                     states_target):
    """Save scatter plots of the first two state dimensions per domain.

    Writes ``{prefix}_states_source.png`` and ``{prefix}_states_target.png``
    into ``image_dir`` and, when ``experiment`` is truthy, logs both files
    with ``log_asset``. Also sets the global matplotlib font size to 12.

    Args:
        args: Configuration; ``source_env_id`` and ``target_env_id`` are used
            in the plot titles.
        experiment: Optional experiment object providing ``log_asset``.
        image_dir: Directory (path object) receiving the images.
        epoch: Epoch number shown in the titles and used as logging step.
        prefix: String prepended to the file names.
        states_source: 2-D array of source states; columns 0 and 1 are
            plotted.
        states_target: 2-D array of target states; columns 0 and 1 are
            plotted.
    """
    # Global setting: affects every figure created after this point.
    plt.rcParams["font.size"] = 12

    panels = (
        (
            states_source,
            f"source states, {args.source_env_id}, epoch:{epoch}",
            image_dir / f"{prefix}_states_source.png",
        ),
        (
            states_target,
            f"target states, {args.target_env_id}, epoch:{epoch}",
            image_dir / f"{prefix}_states_target.png",
        ),
    )

    for points, title, path in panels:
        fig = plt.figure(figsize=(6, 6), dpi=120)
        try:
            ax = fig.add_subplot()
            ax.grid()
            ax.scatter(*points[:, :2].T, marker=".")
            ax.set_title(title)
            fig.tight_layout()
            fig.savefig(path)
        finally:
            # Close the figure so repeated calls do not pile up open figures.
            plt.close(fig)

    if experiment:
        for _, _, path in panels:
            experiment.log_asset(path, step=epoch)


def make_lines(
    source_latents_corr: np.ndarray,
    target_latents_corr: np.ndarray,
) -> list:
    """Pair every source point with its block of corresponding target points.

    The targets are expected to be grouped per source point: with
    ``k = len(target_latents_corr) // len(source_latents_corr)``, targets
    ``k * i`` to ``k * i + k - 1`` belong to source ``i``. Targets beyond
    ``k * len(source_latents_corr)`` are ignored.

    Args:
        source_latents_corr: Sequence of source points.
        target_latents_corr: Sequence of target points, grouped as described
            above.

    Returns:
        A list of ``[source_point, target_point]`` pairs in source-major
        order; empty when there are no source points.
    """
    num_sources = len(source_latents_corr)
    if num_sources == 0:
        # Nothing to connect; also avoids dividing by zero below.
        return []

    per_source = len(target_latents_corr) // num_sources
    return [[source_latents_corr[i], target_latents_corr[per_source * i + j]]
            for i in range(num_sources)
            for j in range(per_source)]


def transform_latents(
    x_source: np.ndarray,
    x_target: np.ndarray,
):
    """Embed source and target latents into 2-D with one joint t-SNE fit.

    The two arrays are concatenated (source rows first) and embedded together
    so that both domains share the same 2-D space, then split back apart.

    Args:
        x_source: Source latents, one row per sample.
        x_target: Target latents, one row per sample.

    Returns:
        Tuple ``(source_embedding, target_embedding)`` holding the embedded
        rows of each input in their original order.
    """
    num_source_rows = len(x_source)
    stacked = np.concatenate((x_source, x_target))

    # Fixed random_state so the same input yields the same embedding.
    embedder = TSNE(
        n_components=2,
        random_state=0,
        perplexity=30,
        n_iter=1000,
        verbose=0,
    )
    embedded = embedder.fit_transform(stacked)

    return embedded[:num_source_rows], embedded[num_source_rows:]
