"""Few-shot MNIST dataset."""
import os

import wandb
import torch
import matplotlib.pyplot as plt
import torchvision as tv
import normflows as nf
import numpy as np

from calnf.datasets.dataset import Dataset
from calnf.datasets.utils import UnlabelledDataset, FlattenTransform


class FewShotMNIST(Dataset):
    """MNIST dataset adapted to few-shot learning.

    Includes "8" as the nominal class and "9" as the target distribution. Images are
    resized to 32x32 and flattened before being served by the dataloaders.

    Args:
        data_root: Root directory for the data. If this is not an existing directory
            that is both readable and writable, "./data" is used instead.
        n_target: Maximum number of target ("9") training examples to keep.
        include_all_nominal: Whether to use all non-target classes or just "8" in the nominal
            distribution.
    """

    N_SAMPLES_TO_VISUALIZE = 5

    # Digit labels that define the two distributions, and the side length (in pixels)
    # of the resized square images.
    _NOMINAL_DIGIT = 8
    _TARGET_DIGIT = 9
    _IMAGE_SIDE = 32

    def __init__(
        self,
        data_root: str = "/datadrive/cbd/datasets",
        n_target: int = 64,
        include_all_nominal: bool = False,
    ):
        super().__init__()
        self.data_root = data_root
        self.n_target = n_target
        self.include_all_nominal = include_all_nominal

        # Fall back to a local directory when the requested root cannot be used for
        # both reading and writing (the MNIST loaders below are created with
        # download=True).
        if not (
            os.path.exists(data_root)
            and os.path.isdir(data_root)
            and os.access(data_root, os.R_OK)
            and os.access(data_root, os.W_OK)
        ):
            self.data_root = "./data"

    @property
    def obs_dims(self) -> int:
        """Return the number of entries in one flattened observation."""
        return self._IMAGE_SIDE * self._IMAGE_SIDE

    @property
    def latent_dims(self) -> int:
        """Return the number of latent dimensions (same as the observation size)."""
        return self._IMAGE_SIDE * self._IMAGE_SIDE

    @property
    def score_scale(self) -> float:
        """Return the scaling from ELBO to score function."""
        return 1 / self.latent_dims

    def transforms(self):
        """Return the preprocessing pipeline applied to every MNIST image.

        The image is converted to a tensor, resized to 32x32, passed through the
        normflows ``Scale`` and ``Jitter`` transforms and finally flattened.
        """
        return tv.transforms.Compose(
            [
                tv.transforms.ToTensor(),
                tv.transforms.Resize(self._IMAGE_SIDE, antialias=True),
                # NOTE(review): presumably dequantization (rescale, then add noise
                # of one intensity bin) -- confirm against the normflows docs.
                nf.utils.Scale(255.0 / 256.0),
                nf.utils.Jitter(1 / 256.0),
                FlattenTransform(),
            ]
        )

    def _load_split(self, train: bool):
        """Load (downloading if needed) the MNIST train or test split."""
        return tv.datasets.MNIST(
            self.data_root,
            train=train,
            download=True,
            transform=self.transforms(),
        )

    @staticmethod
    def _keep_only(dataset, indices: np.ndarray) -> None:
        """Restrict ``dataset`` in place to the examples at ``indices``.

        Images and labels are subset together so that they stay aligned.
        """
        labels = np.asarray(dataset.targets)
        dataset.data = dataset.data[indices]
        dataset.targets = torch.as_tensor(labels[indices])

    def configure_nominal_data(self, device):
        """Configure the nominal dataloader.

        This method must create two DataLoader instances and save them as attributes:
            - self.nominal_train_loader: DataLoader for training nominal data.
            - self.nominal_test_loader: DataLoader for test nominal data.

        Args:
            device (torch.device): Device to use for the data. Not used by this
                implementation.
        """
        # Load the datasets
        nominal_train_data = self._load_split(train=True)
        nominal_test_data = self._load_split(train=False)

        for split in (nominal_train_data, nominal_test_data):
            labels = np.asarray(split.targets)
            if self.include_all_nominal:
                # Every digit except the target one is nominal; the target digit
                # must not leak into the nominal distribution.
                keep = np.flatnonzero(labels != self._TARGET_DIGIT)
            else:
                # Filter to only "8" images (these will be nominal examples)
                keep = np.flatnonzero(labels == self._NOMINAL_DIGIT)
            self._keep_only(split, keep)

        # Create the dataloaders
        self.nominal_train_loader = torch.utils.data.DataLoader(
            UnlabelledDataset(nominal_train_data), batch_size=32, shuffle=True
        )
        self.nominal_test_loader = torch.utils.data.DataLoader(
            UnlabelledDataset(nominal_test_data), batch_size=64, shuffle=False
        )

    def configure_target_data(self, device):
        """Configure the target dataloader.

        This method must create two DataLoader instances and save them as attributes:
            - self.target_train_loader: DataLoader for training target data.
            - self.target_test_loader: DataLoader for test target data.

        The training set is a random subset of at most ``n_target`` "9" images; the
        test set contains every "9" image of the MNIST test split.

        Args:
            device (torch.device): Device to use for the data. Not used by this
                implementation.
        """
        # Load the datasets
        target_train_data = self._load_split(train=True)
        target_test_data = self._load_split(train=False)

        # Filter to only "9" images (these will be target examples), keeping a random
        # few-shot subset for training. The permutation is drawn with torch so that
        # the selection follows the torch random seed.
        train_candidates = np.flatnonzero(
            np.asarray(target_train_data.targets) == self._TARGET_DIGIT
        )
        chosen = torch.randperm(len(train_candidates))[: self.n_target].numpy()
        self._keep_only(target_train_data, train_candidates[chosen])

        test_candidates = np.flatnonzero(
            np.asarray(target_test_data.targets) == self._TARGET_DIGIT
        )
        self._keep_only(target_test_data, test_candidates)

        # Create the dataloaders
        self.target_train_loader = torch.utils.data.DataLoader(
            UnlabelledDataset(target_train_data), batch_size=64, shuffle=True
        )
        self.target_test_loader = torch.utils.data.DataLoader(
            UnlabelledDataset(target_test_data), batch_size=64, shuffle=False
        )

    def single_particle_elbo(
        self, dist: torch.distributions.Distribution, _: int, obs: torch.Tensor
    ) -> torch.Tensor:
        """Compute the ELBO for a single particle.

        This is actually just the data logprob since this isn't an inverse problem.

        Args:
            dist (torch.distributions.Distribution): Distribution over latent variables.
            _ (int): Number of observations (unused).
            obs (torch.Tensor): Observations.

        Returns:
            torch.Tensor: ELBO for a single particle, i.e. the mean of
                ``dist.log_prob(obs)``.
        """
        return dist.log_prob(obs).mean()

    @torch.no_grad()
    def visualize(
        self,
        nominal_dist: torch.distributions.Distribution,
        target_dist: torch.distributions.Distribution,
    ):
        """Visualize the distribution over latent variables.

        Draws ``N_SAMPLES_TO_VISUALIZE`` samples from each distribution and shows them
        side by side (nominal on the left, target on the right). The figure is logged
        to wandb without committing and then closed.

        Args:
            nominal_dist (torch.distributions.Distribution): Distribution over latent variables.
            target_dist (torch.distributions.Distribution): Distribution over latent variables.
        """
        fig, axs = plt.subplots(
            self.N_SAMPLES_TO_VISUALIZE,
            2,
            figsize=(5, 2.5 * self.N_SAMPLES_TO_VISUALIZE),
        )
        side = self._IMAGE_SIDE

        try:
            # Plot nominal and target samples
            for i, ax in enumerate(axs):
                # NOTE(review): assumes sample() returns one flattened 32x32 image
                # as a tensor -- confirm against the distributions passed in.
                nominal_sample = nominal_dist.sample()
                target_sample = target_dist.sample()

                ax[0].imshow(nominal_sample.cpu().numpy().reshape(side, side), cmap="gray")
                ax[0].axis("off")

                ax[1].imshow(target_sample.cpu().numpy().reshape(side, side), cmap="gray")
                ax[1].axis("off")

                if i == 0:
                    ax[0].set_title("Nominal")
                    ax[1].set_title("Target")

            fig.subplots_adjust(wspace=0.01, hspace=0.01)

            wandb.log({"Posteriors": wandb.Image(fig)}, commit=False)
        finally:
            # Close this specific figure (not just whichever one is current), even
            # when sampling or logging fails, so figures do not accumulate.
            plt.close(fig)


if __name__ == "__main__":
    # Smoke test: build the loaders on the CPU and save a grid showing a few raw
    # training examples from the nominal and target training sets.
    cpu = torch.device("cpu")
    dataset = FewShotMNIST()
    dataset.configure_nominal_data(cpu)
    dataset.configure_target_data(cpu)

    # Only the first batch of each training loader is needed
    nominal_batch = next(iter(dataset.nominal_train_loader))
    target_batch = next(iter(dataset.target_train_loader))

    N_SAMPLES_TO_VISUALIZE = 5
    fig, axs = plt.subplots(
        nrows=N_SAMPLES_TO_VISUALIZE,
        ncols=2,
        figsize=(5, 2.5 * N_SAMPLES_TO_VISUALIZE),
    )

    # Left column shows nominal examples, right column shows target examples
    columns = (("Nominal", nominal_batch), ("Target", target_batch))
    for row, row_axes in enumerate(axs):
        for col, (title, batch) in enumerate(columns):
            panel = row_axes[col]
            panel.imshow(batch[row].numpy().reshape(32, 32), cmap="gray")
            panel.axis("off")

            # Column titles go on the top row only
            if row == 0:
                panel.set_title(title)

    plt.subplots_adjust(wspace=0.01, hspace=0.01)
    plt.savefig("mnist_samples.png")
