from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1 import ImageGrid
from torch.utils.data import Dataset

from group_discovery.utils import fig_to_img


class MNISTVectorDataset(Dataset):
    """Dataset of MNIST digits stored as vectors in an ``.npz`` archive.

    The archive must contain the arrays ``"data"``, ``"labels"`` and
    ``"transform"``. All three are loaded eagerly into memory. ``data`` and
    ``transform`` are cast to float32 and ``labels`` to int64.

    Args:
        path: Path (str or ``Path``) to the ``.npz`` file.
        return_transform: If True, ``__getitem__`` returns ``(x, transform)``;
            otherwise it returns ``x`` alone. Labels are loaded but not
            returned by ``__getitem__``.
    """

    def __init__(self, path, return_transform=False):
        path = Path(path)
        # ``np.load`` on an .npz returns an NpzFile that holds the file open;
        # the context manager closes it once ``astype`` has copied the arrays.
        with np.load(path) as data:
            self.data = data["data"].astype(np.float32)
            self.labels = data["labels"].astype(np.int64)
            self.transforms = data["transform"].astype(np.float32)

        self.return_transform = return_transform

    def __len__(self):
        """Return the number of samples (size of the first axis of ``data``)."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return sample ``idx``, paired with its transform if requested."""
        x = self.data[idx]

        # Must match the attribute name assigned in ``__init__``.
        if self.return_transform:
            return x, self.transforms[idx]
        return x


def plot_MNIST_vectors(x):
    """Scatter-plot a single vectorised digit and return the rendered image.

    Args:
        x: Array indexed as ``x[:, 0]`` / ``x[:, 1]`` for the horizontal and
            vertical coordinate of each point. Axes are fixed to [-1, 1].

    Returns:
        The figure converted by ``fig_to_img`` (presumably an RGB image array
        -- confirm against ``group_discovery.utils``).
    """
    fig, ax = plt.subplots(figsize=(20, 20))
    try:
        ax.set_aspect("equal")
        # NOTE(review): the explicit ylim=(-1, 1) set below re-applies a
        # non-inverted y range, so this inversion does not survive -- confirm
        # the intended orientation before changing either call.
        ax.invert_yaxis()
        ax.set(xlim=(-1, 1), ylim=(-1, 1), xticks=[], yticks=[])
        ax.scatter(x[:, 0], x[:, 1], color="k")

        fig.canvas.draw()
        return fig_to_img(fig)
    finally:
        # pyplot keeps a reference to every figure until it is closed;
        # closing here stops repeated calls from accumulating figures.
        plt.close(fig)


def plot_MNIST_grid(x, titles=None):
    """Render a grid of vectorised digits and return the rendered image.

    Args:
        x: Array indexed as ``x[row, col, point, coord]``. ``coord`` 0 and 1
            are the horizontal and vertical positions, plotted in the fixed
            range [-1, 1].
        titles: Optional sequence with one title per column, drawn above the
            first row only.

    Returns:
        The figure converted by ``fig_to_img`` (presumably an RGB image array
        -- confirm against ``group_discovery.utils``).
    """
    nr, nc = x.shape[0], x.shape[1]
    fig = plt.figure(figsize=(1.5 * nc, 1.5 * nr), dpi=100)
    try:
        grid = ImageGrid(
            fig,
            111,
            nrows_ncols=(nr, nc),
            axes_pad=(0.05, 0.1),
            share_all=True,
        )

        for i in range(nr * nc):
            row, col = divmod(i, nc)
            ax = grid[i]
            ax.set_aspect("equal")
            # NOTE(review): the explicit ylim=(-1, 1) set below re-applies a
            # non-inverted y range, so this inversion does not survive --
            # confirm the intended orientation before changing either call.
            ax.invert_yaxis()
            ax.set(xlim=(-1, 1), ylim=(-1, 1), xticks=[], yticks=[])

            ax.scatter(x[row, col, :, 0], x[row, col, :, 1], color="k", s=1)

            if titles is not None and row == 0:
                ax.set_title(titles[col])

        fig.canvas.draw()
        return fig_to_img(fig)
    finally:
        # pyplot keeps a reference to every figure until it is closed;
        # closing here stops repeated calls from accumulating figures.
        plt.close(fig)


def _wrap_angles(angles, low=-180.0):
    """Fold angles in degrees into the half-open window [low, low + 360)."""
    angles = np.asarray(angles, dtype=float)
    return (angles - low) % 360.0 + low


def plot_histogram(angles):
    """Plot a histogram of angles (in degrees) and return the rendered image.

    Angles are folded modulo 360 onto the bin range before counting. A value
    such as +180 (equivalent to -180) therefore lands in the first bin
    instead of falling outside the last edge and being dropped.

    Args:
        angles: Array-like of angles in degrees.

    Returns:
        The figure converted by ``fig_to_img`` (presumably an RGB image array
        -- confirm against ``group_discovery.utils``).
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        bins = np.linspace(-180, 180, 73) - 0.5  # every 5 degrees
        # Edges span [-180.5, 179.5]; wrap onto exactly that window so no
        # finite angle is silently excluded from the counts.
        ax.hist(_wrap_angles(angles, low=bins[0]), bins=bins, color="k", alpha=0.7)
        ax.set(
            xlabel="Angle (degrees)",
            ylabel="Count",
            xticks=[-180, -135, -90, -45, 0, 45, 90, 135, 180],
            xlim=[-180, 180],
        )

        fig.canvas.draw()
        return fig_to_img(fig)
    finally:
        # pyplot keeps a reference to every figure until it is closed;
        # closing here stops repeated calls from accumulating figures.
        plt.close(fig)
