from comet_ml import Experiment

# isort: split

import argparse
import logging
import warnings
from pathlib import Path
from typing import Dict, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import yaml
from omegaconf import OmegaConf
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader

import d4rl
from common.ours.models import Policy
from common.utils.process_dataset import get_obs_converter

# Suppress every FutureWarning raised in this process from here on (this is a
# global filter, not limited to this module).
warnings.simplefilter(action='ignore', category=FutureWarning)

# Module-level logger used by the visualizers below for progress messages;
# its level is lowered to INFO so those messages pass this logger's own
# threshold (handlers are configured elsewhere).
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def visualize_latent(args,
                     policy: Policy,
                     data_loader: DataLoader,
                     epoch: int,
                     savedir: Path,
                     n_steps: int = 3000,
                     n_lines: int = 20,
                     shuffle: bool = False,
                     metrics_only: bool = True,
                     show: bool = False,
                     experiment: Optional[Experiment] = None):
    """Dispatch to the latent visualizer that matches the configured domains.

    Nothing happens when ``args.get("no_visualize_latent", False)`` is truthy.
    Otherwise exactly one backend is invoked:

    * ``visualize_latent_p2p`` when ``args.n_domains == 2``, both ``env_tag``
      values contain ``'point'`` and the second domain has no
      ``obs_converter`` attribute;
    * ``visualize_latent_robot`` when the ``env_tag`` of the first two
      domains both contain ``'Lift'``;
    * ``visualize_latent_others`` in every remaining case, always called
      with ``n_lines=0``.

    Every other argument is forwarded unchanged as a keyword argument.
    The function always returns ``None``.
    """
    if args.get("no_visualize_latent", False):
        return

    # Arguments common to all backends. ``n_lines`` is supplied per branch
    # because the fallback backend overrides it with 0.
    forwarded = dict(
        args=args,
        policy=policy,
        data_loader=data_loader,
        epoch=epoch,
        savedir=savedir,
        n_steps=n_steps,
        shuffle=shuffle,
        metrics_only=metrics_only,
        show=show,
        experiment=experiment,
    )

    is_point_pair = (args.n_domains == 2
                     and 'point' in args.domains[0].env_tag
                     and 'point' in args.domains[1].env_tag
                     and not hasattr(args.domains[1], 'obs_converter'))
    if is_point_pair:
        visualize_latent_p2p(n_lines=n_lines, **forwarded)
        return

    is_lift_pair = ('Lift' in args.domains[0].env_tag
                    and 'Lift' in args.domains[1].env_tag)
    if is_lift_pair:
        visualize_latent_robot(n_lines=n_lines, **forwarded)
        return

    visualize_latent_others(n_lines=0, **forwarded)


def visualize_latent_others(
    args,
    policy: Policy,
    data_loader: DataLoader,
    epoch: int,
    savedir: Path,
    n_steps: int = 3000,
    n_lines: int = 20,
    shuffle: bool = False,
    metrics_only: bool = True,
    show: bool = False,
    experiment: Optional[Experiment] = None,
):
    """Save t-SNE scatter plots of source and target encoder latents.

    Batches are read from ``data_loader`` until more than ``n_steps`` source
    samples have been encoded. Rows whose last ``domain_ids`` entry is above
    0.5 are treated as target samples, all other rows as source samples.
    Both groups are passed through ``policy.encoder``, embedded jointly with
    t-SNE and written to ``savedir`` as ``all_with_legend.png`` and
    ``all.png``.

    Args:
        args: Run configuration. Reads ``domains[0]`` (``obs_converter`` and
            ``obs_converter_args``), ``device``, ``image_observation`` and
            ``image_state_dim``.
        policy: Model providing ``encoder`` and, when
            ``args.image_observation`` is set, ``image_encoder``. It is put
            in eval mode while encoding and in train mode afterwards, also
            when encoding fails.
        data_loader: Iterable of dict batches with ``'observations'`` and
            ``'domain_ids'`` (plus ``'images'`` for image observations).
        epoch: Step attached to the images logged to ``experiment``.
        savedir: Directory the PNG files are written to.
        n_steps: Stop reading batches once more source samples than this
            have been collected.
        n_lines: Not used by this visualizer; kept so that all visualizers
            share one signature.
        shuffle: Permute source and target latents (independently) before
            running t-SNE.
        metrics_only: When true the function returns immediately, since this
            visualizer has no metrics to log.
        show: Call ``plt.show()`` after each figure has been saved.
        experiment: Optional Comet experiment that receives the saved images.
    """
    if metrics_only:
        return
    logger.info('Visualization started.')

    obs_converter = None
    if obs_converter_name := args.domains[0].get('obs_converter'):
        converter_kwargs = args.domains[0].get('obs_converter_args') or {}
        obs_converter = get_obs_converter(name=obs_converter_name,
                                          **converter_kwargs)

    def encode(batch, mask, converter):
        """Return ``policy.encoder`` latents for the rows picked by *mask*."""
        states = batch['observations'][mask].to(args.device)
        if args.image_observation:
            images = batch['images'][mask].to(args.device)
            with torch.inference_mode():
                image_features = policy.image_encoder(images.float())
            # The trailing ``image_state_dim`` entries of each state are
            # overwritten with the encoded image features.
            states[..., -args.image_state_dim:] = image_features
        domain_ids = batch['domain_ids'][mask].to(args.device)
        if converter is not None:
            states = torch.from_numpy(converter(
                states.cpu().numpy())).to(args.device)
        with torch.no_grad():
            return policy.encoder(torch.cat((states, domain_ids), dim=-1))

    source_z_chunks = []
    target_z_chunks = []
    n_collected = 0

    policy.eval()
    try:
        for batch in data_loader:
            target_flag = (batch['domain_ids'][..., -1:] > .5).flatten()
            # Only source observations go through the observation converter.
            target_z = encode(batch, target_flag, None)
            source_z = encode(batch, ~target_flag, obs_converter)

            source_z_chunks.append(source_z.cpu().numpy())
            target_z_chunks.append(target_z.cpu().numpy())

            n_collected += len(source_z)
            if n_collected > n_steps:
                break
    finally:
        # The policy is only needed for encoding; always hand it back in
        # train mode, even if a batch failed.
        policy.train()

    source_z_list = np.concatenate(source_z_chunks)
    target_z_list = np.concatenate(target_z_chunks)

    if shuffle:
        source_z_list = source_z_list[np.random.permutation(
            len(source_z_list))]
        target_z_list = target_z_list[np.random.permutation(
            len(target_z_list))]

    # ===================TSNE========================
    n_source = len(source_z_list)
    all_z_list = np.concatenate((source_z_list, target_z_list))

    # The iteration count is left at the library default (1000): the keyword
    # was renamed from ``n_iter`` to ``max_iter`` in recent scikit-learn, so
    # naming it explicitly would break on one side of that rename.
    tsne = TSNE(n_components=2, random_state=0, perplexity=30)
    z_transform = tsne.fit_transform(all_z_list)
    latent_max = z_transform.max(0)
    latent_min = z_transform.min(0)

    z_transform_source = z_transform[:n_source]
    z_transform_target = z_transform[n_source:]

    # ===================ALL========================
    def plot_all(legend: bool = False):
        """Draw, save and optionally log/show one source-vs-target scatter."""
        marker_size = 100
        fig, ax = plt.subplots(figsize=(7, 7))
        try:
            ax.scatter(*z_transform_source.T,
                       label='Source',
                       marker='.',
                       s=marker_size)
            ax.scatter(*z_transform_target.T,
                       label='Target',
                       marker='.',
                       s=marker_size)
            ax.set_xlim([latent_min[0], latent_max[0]])
            ax.set_ylim([latent_min[1], latent_max[1]])

            if legend:
                ax.legend(prop={'size': 20})
            ax.tick_params(labelbottom=False, labelleft=False, length=0)
            fig.tight_layout()

            file_name = 'all_with_legend.png' if legend else 'all.png'
            img_path = savedir / file_name
            fig.savefig(img_path)

            if experiment is not None:
                experiment.log_image(img_path, step=epoch)
            if show:
                plt.show()
        finally:
            # Without this every call would leave another open figure
            # behind in pyplot's global registry.
            plt.close(fig)

    plot_all(True)
    plot_all(False)

    logger.info('Visualization finished.')


def visualize_latent_robot(
    args,
    policy: Policy,
    data_loader: DataLoader,
    epoch: int,
    savedir: Path,
    n_steps: int = 3000,
    n_lines: int = 20,
    shuffle: bool = False,
    metrics_only: bool = True,
    show: bool = False,
    experiment: Optional[Experiment] = None,
):
    """Save t-SNE plots of policy latents with source/target match lines.

    Batches are read from ``data_loader`` until more than ``n_steps`` source
    samples have been processed. Rows whose last ``domain_ids`` entry is
    above 0.5 are target samples, the rest are source samples. The latent is
    the second value returned by ``policy(states, task_ids, domain_ids)``.
    All latents are embedded jointly with t-SNE and saved to ``savedir`` as
    ``all_with_legend.png`` and ``all.png``.

    Up to ``n_lines`` black lines connect the first source samples to the
    target sample that is closest in Euclidean distance between source state
    columns 21:24 and target state columns 18:21.
    NOTE(review): these column ranges are presumably the position entries of
    the two observation layouts; confirm against the dataset definition.

    Args:
        args: Run configuration. Reads ``domains[0]`` (``obs_converter`` and
            ``obs_converter_args``), ``device``, ``image_observation`` and
            ``image_state_dim``.
        policy: Callable model returning a 3-tuple whose second element is
            the latent; also provides ``image_encoder`` for image
            observations. It is put in eval mode while encoding and in train
            mode afterwards, also when encoding fails.
        data_loader: Iterable of dict batches with ``'observations'``,
            ``'domain_ids'`` and ``'task_ids'`` (plus ``'images'`` for image
            observations).
        epoch: Step attached to the images logged to ``experiment``.
        savedir: Directory the PNG files are written to.
        n_steps: Stop reading batches once more source samples than this
            have been collected.
        n_lines: Maximum number of match lines; capped at the number of
            collected source samples.
        shuffle: Permute source and target samples (independently) before
            running t-SNE and picking the samples to connect.
        metrics_only: When true the function returns immediately, since this
            visualizer has no metrics to log.
        show: Call ``plt.show()`` after each figure has been saved.
        experiment: Optional Comet experiment that receives the saved images.
    """
    if metrics_only:
        return
    logger.info('Visualization started.')

    obs_converter = None
    if obs_converter_name := args.domains[0].get('obs_converter'):
        converter_kwargs = args.domains[0].get('obs_converter_args') or {}
        obs_converter = get_obs_converter(name=obs_converter_name,
                                          **converter_kwargs)

    def encode(batch, mask, converter):
        """Return ``(states, latents)`` for the rows picked by *mask*."""
        states = batch['observations'][mask].to(args.device)
        if args.image_observation:
            images = batch['images'][mask].to(args.device)
            with torch.inference_mode():
                image_features = policy.image_encoder(images.float())
            # The trailing ``image_state_dim`` entries of each state are
            # overwritten with the encoded image features.
            states[..., -args.image_state_dim:] = image_features
        domain_ids = batch['domain_ids'][mask].to(args.device)
        task_ids = batch['task_ids'][mask].to(args.device)
        if converter is not None:
            states = torch.from_numpy(converter(
                states.cpu().numpy())).to(args.device)
        with torch.no_grad():
            _, latents, _ = policy(states, task_ids, domain_ids)
        return states, latents

    source_s_chunks = []
    target_s_chunks = []
    source_z_chunks = []
    target_z_chunks = []
    n_collected = 0

    policy.eval()
    try:
        for batch in data_loader:
            target_flag = (batch['domain_ids'][..., -1:] > .5).flatten()
            # Only source observations go through the observation converter.
            target_states, target_z = encode(batch, target_flag, None)
            source_states, source_z = encode(batch, ~target_flag,
                                             obs_converter)

            source_z_chunks.append(source_z.cpu().numpy())
            target_z_chunks.append(target_z.cpu().numpy())
            source_s_chunks.append(source_states.cpu().numpy())
            target_s_chunks.append(target_states.cpu().numpy())

            n_collected += len(source_z)
            if n_collected > n_steps:
                break
    finally:
        # The policy is only needed for encoding; always hand it back in
        # train mode, even if a batch failed.
        policy.train()

    source_s_list = np.concatenate(source_s_chunks)
    target_s_list = np.concatenate(target_s_chunks)
    source_z_list = np.concatenate(source_z_chunks)
    target_z_list = np.concatenate(target_z_chunks)

    if shuffle:
        # One permutation per domain keeps each state aligned with its latent.
        order = np.random.permutation(len(source_s_list))
        source_s_list = source_s_list[order]
        source_z_list = source_z_list[order]

        order = np.random.permutation(len(target_s_list))
        target_s_list = target_s_list[order]
        target_z_list = target_z_list[order]

    # ===================TSNE========================
    n_source = len(source_z_list)
    all_z_list = np.concatenate((source_z_list, target_z_list))

    # The iteration count is left at the library default (1000): the keyword
    # was renamed from ``n_iter`` to ``max_iter`` in recent scikit-learn, so
    # naming it explicitly would break on one side of that rename.
    tsne = TSNE(n_components=2, random_state=0, perplexity=30)
    z_transform = tsne.fit_transform(all_z_list)
    latent_max = z_transform.max(0)
    latent_min = z_transform.min(0)

    z_transform_source = z_transform[:n_source]
    z_transform_target = z_transform[n_source:]

    # ===================ALL========================
    # Pair each of the first source samples with its nearest target sample.
    # The count is capped so that fewer than ``n_lines`` source samples no
    # longer causes an IndexError while drawing.
    n_pairs = max(0, min(n_lines, n_source))
    source_sample_pos = source_s_list[:n_pairs, 21:24]
    target_pos = target_s_list[:, 18:21]
    target_idxs = [
        int(np.argmin(np.linalg.norm(target_pos - pos, axis=1)))
        for pos in source_sample_pos
    ]

    def plot_all(legend: bool = False):
        """Draw, save and optionally log/show one scatter with match lines."""
        marker_size = 100
        fig, ax = plt.subplots(figsize=(7, 7))
        try:
            ax.scatter(*z_transform_source.T,
                       label='Source',
                       marker='.',
                       s=marker_size)
            ax.scatter(*z_transform_target.T,
                       label='Target',
                       marker='.',
                       s=marker_size)
            ax.set_xlim([latent_min[0], latent_max[0]])
            ax.set_ylim([latent_min[1], latent_max[1]])

            # draw lines
            for source_idx, target_idx in enumerate(target_idxs):
                first = z_transform_source[source_idx]
                second = z_transform_target[target_idx]
                ax.plot((first[0], second[0]), (first[1], second[1]),
                        color='black',
                        lw=1)

            if legend:
                ax.legend(prop={'size': 20})
            ax.tick_params(labelbottom=False, labelleft=False, length=0)
            fig.tight_layout()

            file_name = 'all_with_legend.png' if legend else 'all.png'
            img_path = savedir / file_name
            fig.savefig(img_path)

            if experiment is not None:
                experiment.log_image(img_path, step=epoch)
            if show:
                plt.show()
        finally:
            # Without this every call would leave another open figure
            # behind in pyplot's global registry.
            plt.close(fig)

    plot_all(True)
    plot_all(False)

    logger.info('Visualization finished.')


def visualize_latent_p2p(args,
                         policy: Policy,
                         data_loader: DataLoader,
                         epoch: int,
                         savedir: Path,
                         n_steps: int = 3000,
                         n_lines: int = 20,
                         shuffle: bool = False,
                         metrics_only: bool = True,
                         show: bool = False,
                         experiment: Optional[Experiment] = None):
    """Log latent-alignment metrics and plot paired latents for point mazes.

    NOTE! It assumes there are only two point environments,
    where domain1 has no observation translation.

    Only target rows (last ``domain_ids`` entry above 0.5) are read. Each
    target state is encoded twice with ``policy.encoder``: once as is, and
    once as its source-domain counterpart, obtained by applying the
    ``domains[0]`` observation converter (if any) and swapping the two
    ``domain_ids`` columns. Row ``k`` of the source and target latents
    therefore describes the same underlying state.

    Metrics logged to ``experiment`` (when given):

    * ``repr_norm``: mean over all pairs of the average of the source and
      target latent norms;
    * ``latent_dist``: mean pairwise latent distance divided by
      ``repr_norm``.

    Unless ``metrics_only`` is set, three images are written to ``savedir``:
    ``all.png`` and ``all_with_legend.png`` (joint t-SNE scatter with up to
    ``n_lines`` lines joining paired points) and ``area.png`` (a 3x3 grid
    restricted to rectangles of the first two target state coordinates, each
    with up to ``n_lines // 2`` pair lines and the mean squared latent
    difference in its title).

    Args:
        args: Run configuration. Reads ``domains[0]`` (``obs_converter``,
            ``obs_converter_args``, ``env_tag``) and ``device``.
        policy: Model providing ``encoder``. It is put in eval mode while
            encoding and in train mode afterwards on every exit path.
        data_loader: Iterable of dict batches with ``'observations'`` and
            ``'domain_ids'``.
        epoch: Step attached to logged metrics and images.
        savedir: Directory the PNG files are written to.
        n_steps: Stop reading batches once more samples than this have been
            collected.
        n_lines: Maximum number of pair lines in the overview plots; capped
            at the number of available pairs.
        shuffle: Apply one shared random permutation to all collected arrays
            (pairing is preserved).
        metrics_only: Log the metrics and return without plotting.
        show: Call ``plt.show()`` after each figure has been saved.
        experiment: Optional Comet experiment receiving metrics and images.

    Raises:
        ValueError: When plotting and ``args.domains[0].env_tag`` contains
            none of ``'medium'``, ``'umaze'`` or ``'large'``. The two overview
            images have already been written at that point.
    """
    logger.info('Visualization started.')

    obs_converter = None
    if obs_converter_name := args.domains[0].get('obs_converter'):
        converter_kwargs = args.domains[0].get('obs_converter_args') or {}
        obs_converter = get_obs_converter(name=obs_converter_name,
                                          **converter_kwargs)

    target_s_chunks = []
    source_z_chunks = []
    target_z_chunks = []
    n_collected = 0

    policy.eval()
    try:
        for batch in data_loader:
            target_flag = (batch['domain_ids'][..., -1:] > .5).flatten()
            target_states = batch['observations'][target_flag].to(args.device)
            target_domain_ids = batch['domain_ids'][target_flag].to(
                args.device)

            # The source counterpart is the same state, optionally converted
            # to the source observation layout, with swapped domain columns.
            source_states = target_states
            if obs_converter is not None:
                source_states = torch.from_numpy(
                    obs_converter(target_states.cpu().numpy())).to(
                        args.device)
            source_domain_ids = target_domain_ids[..., [1, 0]]

            with torch.no_grad():
                target_z = policy.encoder(
                    torch.cat((target_states, target_domain_ids), dim=-1))
                source_z = policy.encoder(
                    torch.cat((source_states, source_domain_ids), dim=-1))

            source_z_chunks.append(source_z.cpu().numpy())
            target_z_chunks.append(target_z.cpu().numpy())
            target_s_chunks.append(target_states.cpu().numpy())

            n_collected += len(source_z)
            if n_collected > n_steps:
                break
    finally:
        # Restore train mode on every exit path. Previously the
        # ``metrics_only`` early return left the policy in eval mode.
        policy.train()

    target_s_list = np.concatenate(target_s_chunks)
    source_z_list = np.concatenate(source_z_chunks)
    target_z_list = np.concatenate(target_z_chunks)

    if shuffle:
        # A single permutation keeps source/target rows paired.
        order = np.random.permutation(len(source_z_list))
        target_s_list = target_s_list[order]
        source_z_list = source_z_list[order]
        target_z_list = target_z_list[order]

    # metrics
    source_norm = np.linalg.norm(source_z_list, axis=-1, keepdims=True)
    target_norm = np.linalg.norm(target_z_list, axis=-1, keepdims=True)
    pair_dists = np.linalg.norm(source_z_list - target_z_list,
                                axis=-1,
                                keepdims=True)
    norm_all = ((source_norm + target_norm) / 2).mean()
    dist_norm_all = (pair_dists / norm_all).mean()

    if experiment is not None:
        metrics_dict = {
            'repr_norm': norm_all,
            'latent_dist': dist_norm_all,
        }
        experiment.log_metrics(metrics_dict, step=epoch)

    if metrics_only:
        logger.info('Metrics logging finished.')
        return

    sns.set_context('paper', font_scale=2)

    # ===================TSNE========================
    n_pairs = len(source_z_list)
    all_z_list = np.concatenate((source_z_list, target_z_list))

    # The iteration count is left at the library default (1000): the keyword
    # was renamed from ``n_iter`` to ``max_iter`` in recent scikit-learn, so
    # naming it explicitly would break on one side of that rename.
    tsne = TSNE(n_components=2, random_state=0, perplexity=30)
    z_transform = tsne.fit_transform(all_z_list)
    latent_max = z_transform.max(0)
    latent_min = z_transform.min(0)

    z_transform_source = z_transform[:n_pairs]
    z_transform_target = z_transform[n_pairs:]

    # ===================ALL========================
    # Never try to draw more pair lines than there are pairs.
    n_overview_lines = max(0, min(n_lines, n_pairs))

    def plot_all(legend: bool = False):
        """Draw, save and optionally log/show one scatter with pair lines."""
        marker_size = 100
        fig, ax = plt.subplots(figsize=(7, 7))
        try:
            ax.scatter(*z_transform_source.T,
                       label='Source',
                       marker='.',
                       s=marker_size)
            ax.scatter(*z_transform_target.T,
                       label='Target',
                       marker='.',
                       s=marker_size)
            ax.set_xlim([latent_min[0], latent_max[0]])
            ax.set_ylim([latent_min[1], latent_max[1]])

            # draw lines
            for k in range(n_overview_lines):
                first = z_transform_source[k]
                second = z_transform_target[k]
                ax.plot((first[0], second[0]), (first[1], second[1]),
                        color='black',
                        lw=2)
            if legend:
                ax.legend(prop={'size': 20})
            ax.tick_params(labelbottom=False, labelleft=False, length=0)
            fig.tight_layout()

            file_name = 'all_with_legend.png' if legend else 'all.png'
            img_path = savedir / file_name
            fig.savefig(img_path)

            if experiment is not None:
                experiment.log_image(img_path, step=epoch)
            if show:
                plt.show()
        finally:
            # Without this every call would leave another open figure
            # behind in pyplot's global registry.
            plt.close(fig)

    plot_all(legend=False)
    plot_all(legend=True)

    # ===================AREA========================
    # Rectangles (in the first two state coordinates) of the 3x3 grid.
    env_tag = args.domains[0].env_tag
    if 'medium' in env_tag:
        xrange_list = [(0.5, 2.5), (2.5, 4.5), (4.5, 6.5)]  # state
        yrange_list = [(0.5, 2.5), (2.5, 4.5), (4.5, 6.5)]  # state
    elif 'umaze' in env_tag:
        xrange_list = [(0.5, 1.5), (1.5, 2.5), (2.5, 3.5)]  # state
        yrange_list = [(0.5, 1.5), (1.5, 2.5), (2.5, 3.5)]  # state
    elif 'large' in env_tag:
        xrange_list = [(0.5, 3.0), (3.0, 5.5), (5.5, 7.5)]  # state
        yrange_list = [(0.5, 4.0), (4.0, 7.5), (7.5, 10.5)]  # state
    else:
        raise ValueError(f'{args.domains[0].env_tag} is not a valid env name')

    xval = target_s_list[..., 0]
    yval = target_s_list[..., 1]

    fig = plt.figure(figsize=(10, 10))
    try:
        for i in range(9):
            x_bounds = xrange_list[i % 3]
            y_bounds = yrange_list[i // 3]

            in_x = (x_bounds[0] < xval) & (xval < x_bounds[1])
            in_y = (y_bounds[0] < yval) & (yval < y_bounds[1])
            flag = in_x & in_y

            # Mean squared latent difference of the pairs in this cell
            # (NaN when the cell is empty).
            distance = ((source_z_list[flag] - target_z_list[flag])**2).mean()

            cell_source = z_transform_source[flag]
            cell_target = z_transform_target[flag]

            ax = fig.add_subplot(3, 3, i + 1)
            ax.scatter(*cell_source.T, label='source', marker='.')
            ax.scatter(*cell_target.T, label='target', marker='.')
            ax.set_xlim([latent_min[0], latent_max[0]])
            ax.set_ylim([latent_min[1], latent_max[1]])
            ax.set_title(
                f'x = {x_bounds}, y = {y_bounds}\n loss = {distance:.5f}')

            # draw lines, limited to the pairs that actually fall in the cell
            n_cell_lines = max(0, min(n_lines // 2, len(cell_source)))
            for k in range(n_cell_lines):
                first = cell_source[k]
                second = cell_target[k]
                ax.plot((first[0], second[0]), (first[1], second[1]),
                        color='black',
                        lw=1)

            if i == 0:
                ax.legend()

        fig.suptitle(f'Epoch: {epoch} (Area)\n {savedir}')
        fig.tight_layout()
        img_path = savedir / 'area.png'
        fig.savefig(img_path)
        if experiment is not None:
            experiment.log_image(img_path, step=epoch)

        if show:
            plt.show()
    finally:
        plt.close(fig)

    logger.info('Visualization finished.')
