"""
Functions for visualizing vectors.
"""

import logging
from typing import Union, Optional

import numpy as np
import torch

from utils.visualization.preprocess import align_shapes

# NOTE(review): `usetex` is not read anywhere in this module; presumably it is
# consumed by plotting callers (e.g. to toggle LaTeX text rendering) — confirm.
usetex = True
# Shared logger for this module, registered under the "custom" logger name.
logger = logging.getLogger("custom")


def _scatter(ax, x, **kwargs):
    """ Creates single scatter plot. """
    x1, x2 = x[:, 0], x[:, 1]
    ax.scatter(x1, x2, **kwargs)


def scatter(ax,
            x: Union[np.ndarray, torch.Tensor],
            y: Optional[Union[np.ndarray, torch.Tensor]] = None,
            color_dict: Optional[dict] = None,
            **kwargs):
    """ Creates scatter plots, one layer per distinct label.

    ``x`` and ``y`` are first passed through ``align_shapes``. When ``y`` is
    given, every distinct label is drawn with its own ``ax.scatter`` call
    carrying ``label=<label>``, and duplicate samples within a label are
    drawn only once.

    :param ax: matplotlib axes
    :param x: samples; assumed 2-D after ``align_shapes`` (rows are samples,
        the first two columns are plotted)
    :param y: optional labels, one per sample
    :param color_dict: optional dict mapping labels to colors. Labels that are
        missing from the dict fall back to the ``color`` kwarg (or matplotlib's
        default when no ``color`` is given). When ``y`` is None, the entry
        stored under the key ``None`` (if present) colors the unlabeled layer.
    :param kwargs: matplotlib properties forwarded to ``ax.scatter``
    :return: ``ax``
    """
    x, y = align_shapes(x, y)

    if y is None:
        # Unlabeled samples: color_dict may provide a color under the None key.
        # As in the labeled branch, color_dict takes precedence over `color`.
        if color_dict and None in color_dict:
            kwargs['color'] = color_dict[None]
        _scatter(ax, x, **kwargs)
        return ax

    # Removed from kwargs so it is not passed twice alongside color=cur_color.
    color = kwargs.pop('color', None)
    for cur_y in np.unique(y):
        idx = np.where(cur_y == y)
        cur_x = np.unique(x[idx], axis=0)  # only plot unique instances
        # Labels absent from color_dict fall back to the explicit color.
        cur_color = color_dict.get(cur_y, color) if color_dict else color
        _scatter(ax, cur_x, label=cur_y, color=cur_color, **kwargs)
    return ax
