import smlmsim
import torch
from torch import Generator, Tensor

import smlm


def sample_fluorophores(
    N: int,
    time_on: float,
    time_off: float,
    time_bleach: float,
    n_periods_max: int,
    n_frames: int,
    photon_flux_mean: float,
    photon_flux_std: float,
    vol_extent: Tensor,
    gen: Generator,
):
    """Simulate ``N`` blinking fluorophores and their per-frame photon signal.

    Random draws from ``gen`` happen in a fixed order (positions, blinking
    periods, start times, photon flux), so a seeded generator reproduces
    the whole population.

    Args:
        N: number of fluorophores.
        time_on, time_off, time_bleach, n_periods_max: blinking-model
            parameters forwarded unchanged to
            ``smlmsim.dynamics.batch_sample_periods``.
        n_frames: number of frames. Start times are drawn uniformly in
            ``[0, n_frames)``, and the periods are discretized with
            ``n_frames=n_frames``.
        photon_flux_mean, photon_flux_std: parameters of the per-emitter
            photon flux (normal, clamped from below; see
            ``sample_photon_flux``).
        vol_extent: per-axis ``(lower, upper)`` bounds of the sampling volume.
        gen: random generator.

    Returns:
        Tensor ``cat([xyz, photons], dim=-1)``. The first three columns are
        the coordinates. The remaining columns are the output of
        ``batch_discretize_periods`` scaled by each emitter's flux
        (presumably one column per frame; TODO confirm).
    """
    positions = sample_coordinates(N, vol_extent=vol_extent, gen=gen)
    periods = sample_dynamics(
        N,
        time_bleach=time_bleach,
        time_off=time_off,
        time_on=time_on,
        n_periods_max=n_periods_max,
        gen=gen,
    )

    # Offset every emitter's timeline by its own random onset (in place,
    # broadcast over the two trailing period dimensions).
    onsets = sample_starting_times(N, length=n_frames, gen=gen)
    periods += onsets[:, None, None]

    per_frame = smlmsim.dynamics.batch_discretize_periods(periods, n_frames=n_frames)
    flux = sample_photon_flux(N, mean=photon_flux_mean, std=photon_flux_std, gen=gen)
    photons = per_frame * flux[:, None]

    return torch.cat([positions, photons], dim=-1)


def sample_coordinates(
    N: int,
    vol_extent: Tensor,
    gen: Generator,
    *,
    z_mean: float = 0.5,
    z_std: float = 0.2,
) -> Tensor:
    """Sample ``N`` emitter positions inside an axis-aligned volume.

    x and y are drawn uniformly over the volume. z is drawn from a normal
    distribution in normalized units and then clipped to ``[0, 1]``, so
    emitters concentrate around the ``z_mean`` plane. The normalized
    coordinates are then mapped affinely onto ``vol_extent``.

    Args:
        N: number of positions to sample.
        vol_extent: bounds of the volume; column 0 holds the lower and
            column 1 the upper bound of each axis in (x, y, z) order.
            Presumably shape ``(3, 2)`` (inferred from the broadcasting
            against ``(N, 3)``; TODO confirm).
        gen: random generator. ``xy`` is drawn before ``z``.
        z_mean: mean of the axial distribution, in normalized ``[0, 1]``
            units (0 = lower z bound, 1 = upper z bound).
        z_std: standard deviation of the axial distribution, in normalized
            units, before clipping.

    Returns:
        Tensor of shape ``(N, 3)`` with coordinates inside ``vol_extent``.
    """
    xy = torch.rand((N, 2), generator=gen)
    z = z_mean + z_std * torch.randn((N, 1), generator=gen)
    # Clipping (rather than resampling) keeps every emitter inside the volume.
    z = torch.clip(z, min=0.0, max=1.0)
    xyz_unit = torch.cat([xy, z], dim=-1)

    lower, upper = vol_extent[:, 0], vol_extent[:, 1]
    return (upper - lower) * xyz_unit + lower


def sample_dynamics(
    N: int,
    time_bleach: float,
    time_off: float,
    time_on: float,
    n_periods_max: int,
    gen: Generator,
) -> Tensor:
    """Sample blinking periods for ``N`` emitters.

    This is a thin wrapper around ``smlmsim.dynamics.batch_sample_periods``.
    All arguments are forwarded unchanged, and its result is returned as is.
    """
    sampler = smlmsim.dynamics.batch_sample_periods
    return sampler(
        N,
        n_periods_max=n_periods_max,
        time_bleach=time_bleach,
        time_off=time_off,
        time_on=time_on,
        gen=gen,
    )


def sample_starting_times(N: int, length: int, gen: Generator) -> Tensor:
    """Draw ``N`` start times uniformly from ``[0, length)``.

    Returns:
        Tensor of shape ``(N,)``.
    """
    return torch.rand((N,), generator=gen) * length


def sample_photon_flux(N: int, mean: float, std: float, gen: Generator) -> Tensor:
    """Draw ``N`` per-emitter photon fluxes from ``Normal(mean, std)``.

    Samples are clamped from below at 1% of ``mean``, so a wide ``std``
    cannot produce negative (non-physical) fluxes.

    Returns:
        Tensor of shape ``(N,)``.
    """
    min_flux = 1e-2 * mean
    flux = std * torch.randn((N,), generator=gen) + mean
    return torch.clamp(flux, min=min_flux)


def sample_background(
    h: int, w: int, mean: float, std: float, gen: Generator, *, res=(1, 1)
) -> Tensor:
    """Sample a smooth, non-negative background image of shape ``(h, w)``.

    The background is 2D Perlin noise evaluated on an ``h`` x ``w`` grid
    (one sample per pixel). It is scaled by ``std``, offset by ``mean`` and
    clipped at zero.

    Args:
        h: image height in pixels.
        w: image width in pixels.
        mean: constant offset added to the noise.
        std: multiplier applied to the noise. Perlin noise is bounded and
            not unit-variance, so this acts as an amplitude rather than the
            standard deviation of the result.
        gen: random generator used for the Perlin gradients.
        res: number of Perlin lattice cells along (height, width), as a
            pair of ints. Larger values give finer spatial variation. The
            default ``(1, 1)`` uses a single cell spanning the whole image.

    Returns:
        Background tensor of shape ``(h, w)`` with all values ``>= 0``.
    """
    eps = generate_perlin_noise_2d(shape=(h, w), res=res, gen=gen)
    bg = mean + std * eps
    # Photon background cannot be negative.
    bg = bg.clip(min=0.0)
    return bg


def generate_perlin_noise_2d(shape, res, gen: Generator):
    """Generate 2D Perlin (gradient) noise.

    A random unit gradient is placed at each of the ``(res[0] + 1) x
    (res[1] + 1)`` lattice nodes. The lattice is sampled on a regular
    ``shape[0] x shape[1]`` grid whose first and last samples coincide with
    the outer nodes. Each sample blends the dot products of its four
    surrounding corner gradients with the corner-to-sample offsets, using
    the quintic fade curve ``6t^5 - 15t^4 + 10t^3``.

    Args:
        shape: ``(rows, cols)`` of the output.
        res: number of lattice cells along ``(rows, cols)``.
        gen: random generator for the gradients.

    Returns:
        Tensor of shape ``(shape[0], shape[1])``. Values are zero at
        lattice nodes.
    """
    rows_res, cols_res = res[0], res[1]

    def quintic(t):
        return 6 * t**5 - 15 * t**4 + 10 * t**3

    # Random unit-length gradient vector at every lattice node.
    raw = torch.randn(rows_res + 1, cols_res + 1, 2, generator=gen)
    unit_grads = raw / raw.norm(dim=2, keepdim=True)

    # Sample positions expressed in lattice units.
    row_coords = torch.linspace(0, rows_res, shape[0])
    col_coords = torch.linspace(0, cols_res, shape[1])
    pos = torch.stack(torch.meshgrid(row_coords, col_coords, indexing="ij"), dim=-1)
    pos_r, pos_c = pos[..., 0], pos[..., 1]

    cell = pos.floor().long()
    frac = pos - cell
    w_r, w_c = quintic(frac[..., 0]), quintic(frac[..., 1])

    # Corner indices; clamping keeps samples on the far edge inside the lattice.
    r0 = cell[..., 0].clamp(0, rows_res)
    c0 = cell[..., 1].clamp(0, cols_res)
    r1 = (r0 + 1).clamp(0, rows_res)
    c1 = (c0 + 1).clamp(0, cols_res)

    def corner(cr, cc):
        # Dot product of the corner's gradient with the corner-to-sample offset.
        g = unit_grads[cr, cc]
        return (pos_r - cr) * g[..., 0] + (pos_c - cc) * g[..., 1]

    # Blend along rows at each of the two column corners, then across columns.
    at_c0 = corner(r0, c0) * (1 - w_r) + corner(r1, c0) * w_r
    at_c1 = corner(r0, c1) * (1 - w_r) + corner(r1, c1) * w_r
    return at_c0 * (1 - w_c) + at_c1 * w_c
