import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import torch.optim as optim
import matplotlib.pyplot as plt
import scipy.ndimage
from torch.utils.data import Dataset, DataLoader


import utils.aliasfree_dct

# ---------------------------------------------------------------------

def get_interpolation_mode(interpolation="bilinear"):
    """Translate an interpolation name into a torchvision ``InterpolationMode``.

    Only ``"bilinear"`` and ``"bicubic"`` are recognised; any other value raises
    ``NotImplementedError`` whose single argument is the rejected name.
    """
    if interpolation not in ("bilinear", "bicubic"):
        raise NotImplementedError(interpolation)
    if interpolation == "bicubic":
        return InterpolationMode.BICUBIC
    return InterpolationMode.BILINEAR


def get_mgrid(sidelen: int) -> torch.Tensor:
    """Build a flattened square grid of 2-D pixel coordinates.

    Each axis takes the values ``k / sidelen - 0.5`` for ``k = 0 .. sidelen-1``,
    i.e. the half-open range [-0.5, 0.5) in steps of ``1 / sidelen``. Points are
    listed in row-major order; column 0 holds the first (row) coordinate and
    column 1 the second (column) coordinate.

    Returns:
        float32 tensor of shape (sidelen * sidelen, 2).
    """
    axis = np.arange(sidelen, dtype=np.float32)
    axis /= sidelen
    axis -= 0.5
    first, second = np.meshgrid(axis, axis, indexing="ij")
    coords = np.stack((first, second), axis=-1)  # (sidelen, sidelen, 2)
    return torch.tensor(coords).reshape(-1, 2)

# ---------------------------------------------------------------------

class SignedDistanceTransform:
    """Turn an image tensor into a signed distance field plus its binary mask.

    Pixels with intensity >= 0.5 are treated as inside the shape. Distances are
    Euclidean, in pixels, computed with ``scipy.ndimage.distance_transform_edt``
    and then divided by the image width (last axis). They are negative inside
    the shape and positive outside.
    """

    def __call__(self, img_tensor: torch.Tensor):
        """
        Args:
            img_tensor: tensor with values in [0, 1], expected shape (1, H, W)
                (e.g. the output of ``transforms.ToTensor()``). Not modified.

        Returns:
            signed_distances: float32 tensor, same shape as ``img_tensor``.
            binary_image:     float32 tensor of {0, 1}, same shape as ``img_tensor``.
        """
        # Threshold at 0.5; True marks inside-the-shape pixels.
        inside = (img_tensor >= 0.5).detach().cpu().numpy()

        # For inside pixels: distance to the nearest outside pixel (0 elsewhere).
        dist_inside = scipy.ndimage.distance_transform_edt(inside)
        # For outside pixels: distance to the nearest inside pixel (0 elsewhere).
        # Use the boolean complement rather than casting (img - 1.0) to uint8.
        # Casting a negative float to an unsigned int is undefined and
        # platform-dependent (commonly 255 on x86-64, but 0 on e.g. AArch64),
        # which would silently zero out every outside distance.
        dist_outside = scipy.ndimage.distance_transform_edt(~inside)

        # Normalize by image width.
        signed = (dist_outside - dist_inside) / float(inside.shape[-1])

        signed_distances = torch.tensor(signed, dtype=torch.float32)
        binary_image = torch.tensor(inside, dtype=torch.float32)
        return signed_distances, binary_image

# ---------------------------------------------------------------------

class MNISTSDFDataset(Dataset):
    """MNIST digits represented as signed distance fields (SDFs).

    Every image is resized with ``transforms.Resize(size, ...)``, converted with
    ``transforms.ToTensor()`` and passed through ``SignedDistanceTransform``.
    Items are:

    * ``return_image=False`` (default): ``(signed_distance_values, coord_values)``
      with shapes ``(H*W, 1)`` and ``(H*W, 2)``. ``coord_values`` is the same
      grid tensor object for every item.
    * ``return_image=True``: ``(signed_distance_img, digit_class)`` with the SDF
      in grid form ``(1, H, W)``.

    ``size`` may be an int, a 1-element sequence, or an ``(H, W)`` pair. It is
    forwarded unchanged to ``transforms.Resize``.
    """

    def __init__(
        self,
        root: str = "./data",
        train: bool = True,
        download: bool = True,
        size=(32, 32),
        interpolation: str = "bilinear",
        antialias=None,
        return_image: bool = False,
    ):
        # Resolve the output shape first so an invalid ``size`` fails before any download.
        height, width = self._output_hw(size)
        self.transform = transforms.Compose([
            transforms.Resize(size, interpolation=get_interpolation_mode(interpolation), antialias=antialias),
            transforms.ToTensor(),
            SignedDistanceTransform(),
        ])
        self.img_dataset = torchvision.datasets.MNIST(root=root, train=train, download=download)
        # Build the grid from both H and W. Using size[0] alone crashes for int
        # sizes and yields H*H coordinates for H*W values when non-square.
        self.meshgrid = self._coord_grid(height, width)
        self.return_image = return_image
        self.size = size

    @staticmethod
    def _output_hw(size):
        """Return the (H, W) that ``transforms.Resize(size)`` yields for a square input."""
        # MNIST images are square (28x28), so an int (or 1-element) size,
        # which Resize applies to the smaller edge, yields a size x size image.
        if isinstance(size, int):
            return size, size
        dims = tuple(size)
        if len(dims) == 1:
            return dims[0], dims[0]
        if len(dims) != 2:
            raise ValueError(f"size must be an int or a sequence of length 1 or 2, got {size!r}")
        return dims[0], dims[1]

    @staticmethod
    def _coord_grid(height, width):
        """Return the flattened (H*W, 2) coordinate grid matching an (H, W) image."""
        if height == width:
            # Square output: reuse the shared helper so coordinates are unchanged.
            return get_mgrid(height)
        # Non-square output: per-axis values k / n - 0.5 (i.e. [-0.5, 0.5)),
        # row-major order, column 0 = row coordinate, column 1 = column coordinate.
        rows = np.arange(height, dtype=np.float32) / height - 0.5
        cols = np.arange(width, dtype=np.float32) / width - 0.5
        grid = np.stack(np.meshgrid(rows, cols, indexing="ij"), axis=-1)  # (H, W, 2)
        return torch.tensor(grid).reshape(-1, 2)

    def __len__(self):
        return len(self.img_dataset)

    def __getitem__(self, idx: int):
        img, digit_class = self.img_dataset[idx]
        # The transform also returns the thresholded binary image; unused here.
        signed_distance_img, _binary_image = self.transform(img)  # both (1, H, W)

        if self.return_image:
            # Grid form: (1, H, W), label
            return signed_distance_img, digit_class

        # Flatten row-major so row k of the values matches row k of the grid.
        signed_distance_values = signed_distance_img.reshape(-1, 1)  # (H*W, 1)
        return signed_distance_values, self.meshgrid  # (H*W, 1), (H*W, 2)

# ---------------------------------------------------------------------








def get_mnistsdf_batch(
    B: int,
    root: str = "./data",
    train: bool = True,
    size=(32, 32),
    return_image: bool = False,
    num_workers: int = 0,
    shuffle: bool = True,
):
    """Load a single batch of ``B`` MNIST-SDF samples.

    A fresh ``MNISTSDFDataset`` (with ``download=True``) and ``DataLoader`` are
    built on every call, and only the first batch is returned.

    Returns:
        The collated two-element batch from the DataLoader:
        - ``return_image=True``: ``signed_distance_img`` (B, 1, H, W) and
          ``labels`` (B,).
        - otherwise: ``signed_distance_values`` (B, H*W, 1) and
          ``coord_values`` (B, H*W, 2). The same grid is stacked per sample.

    Raises:
        ValueError: if ``B`` exceeds the dataset size. The DataLoader itself
            rejects non-positive or otherwise invalid batch sizes.
    """
    ds = MNISTSDFDataset(root=root, train=train, download=True, size=size, return_image=return_image)

    loader = DataLoader(ds, batch_size=B, shuffle=shuffle, num_workers=num_workers, drop_last=True)
    # With drop_last=True and B > len(ds) there are no batches at all, and
    # next() would leak a bare StopIteration; report the real problem instead.
    if len(loader) == 0:
        raise ValueError(f"batch size B={B} exceeds dataset size {len(ds)}")

    # Both return modes yield the collated batch unchanged (shapes in docstring).
    return next(iter(loader))
