import cv2
import numpy as np
import torch
import torch.nn as nn
from pytorch3d.structures import Pointclouds


def set_requires_grad(module: nn.Module, requires_grad: bool):
    for p in module.parameters():
        p.requires_grad_(requires_grad)


def compute_distance_transform(mask: torch.Tensor):
    image_size = mask.shape[-1]
    distance_transform = (
        torch.stack(
            [
                torch.from_numpy(
                    cv2.distanceTransform(
                        (1 - m), distanceType=cv2.DIST_L2, maskSize=cv2.DIST_MASK_3
                    )
                    / (image_size / 2)
                )
                for m in mask.squeeze(1).detach().cpu().numpy().astype(np.uint8)
            ]
        )
        .unsqueeze(1)
        .clip(0, 1)
        .to(mask.device)
    )
    return distance_transform


def default(x, d):
    return d if x is None else x


def get_num_points(x):
    if isinstance(x, Pointclouds):
        return x.points_padded().shape[1]
    else:
        assert len(x.shape) == 3
        return x.shape[1]


def get_custom_betas(
    beta_start: float,
    beta_end: float,
    warmup_frac: float = 0.3,
    num_train_timesteps: int = 1000,
):
    """Custom beta schedule"""
    betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
    warmup_frac = 0.3
    warmup_time = int(num_train_timesteps * warmup_frac)
    warmup_steps = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
    warmup_time = min(warmup_time, num_train_timesteps)
    betas[:warmup_time] = warmup_steps[:warmup_time]
    return betas
