import math
import torch
import torch.nn.functional as F

# Default compute device for this module: CUDA if available, otherwise Apple
# MPS if available, otherwise CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")


@torch.no_grad()
def sample_modulus_voronoi_torch(batch_size, ny, nx, *,
                                 n_regions=48, low=0.5, high=5.0,
                                 blur_sigma=0.0, device=DEVICE, dtype=torch.float32):
    """
    Sample a batch of random Voronoi modulus fields E(x), entirely in Torch.

    For each sample, `n_regions` seed points are scattered uniformly over the
    pixel grid ([0, ny-1] x [0, nx-1], in pixel units). Each seed gets a value
    drawn log-uniformly between `low` and `high`. Every pixel takes the value
    of its nearest seed, using squared Euclidean distance.

    If `blur_sigma` is positive, the field is smoothed with a separable
    Gaussian (std `blur_sigma` pixels, radius ceil(3*sigma), at least 1).
    Replicate padding keeps the output at (ny, nx).

    Returns:
        Tensor of shape (batch_size, ny, nx) on `device` with dtype `dtype`.
    """
    opts = dict(device=device, dtype=dtype)
    row_coords = torch.linspace(0, ny - 1, ny, **opts)
    col_coords = torch.linspace(0, nx - 1, nx, **opts)

    # Random draws happen in a fixed order (row seeds, column seeds, values),
    # so a given RNG state always reproduces the same fields.
    seed_shape = (batch_size, n_regions)
    seed_rows = torch.rand(seed_shape, **opts) * (ny - 1)
    seed_cols = torch.rand(seed_shape, **opts) * (nx - 1)
    lo, hi = math.log(low), math.log(high)
    region_vals = torch.exp(torch.rand(seed_shape, **opts) * (hi - lo) + lo)

    # Squared pixel-to-seed distances, broadcast to (B, n_regions, ny, nx);
    # the nearest seed index per pixel has shape (B, ny, nx).
    dist_rows = (row_coords[None, None, :, None] - seed_rows[:, :, None, None]) ** 2
    dist_cols = (col_coords[None, None, None, :] - seed_cols[:, :, None, None]) ** 2
    nearest = (dist_rows + dist_cols).argmin(dim=1)

    # Look up each pixel's region value.
    field = torch.gather(region_vals, 1, nearest.view(batch_size, -1)).view(batch_size, ny, nx)

    if not blur_sigma or not blur_sigma > 0:
        return field

    radius = max(1, int(math.ceil(3 * blur_sigma)))
    taps = torch.arange(-radius, radius + 1, **opts)
    weights = torch.exp(-0.5 * (taps / blur_sigma) ** 2)
    row_kernel = (weights / weights.sum()).view(1, 1, 1, -1)  # smooths along x (width)
    col_kernel = row_kernel.transpose(-1, -2)                  # smooths along y (height)

    # Two 1-D passes: pad the axis being smoothed, then a "valid" convolution.
    img = field.unsqueeze(1)  # (B, 1, ny, nx)
    for pad, kernel in (((radius, radius, 0, 0), row_kernel),
                        ((0, 0, radius, radius), col_kernel)):
        img = F.conv2d(F.pad(img, pad, mode="replicate"), kernel)
    return img[:, 0]



@torch.no_grad()
def solve_elasticity_batched(E, *, nu=0.30, max_iters=20000, tol=1e-6,
                             A_top=0.03, A_bottom=None, clamp_bottom=False,
                             device=None, check_every=1):
    """
    Batched plane-strain linear elasticity, Dirichlet-only, no body force.

    The interior displacement is relaxed toward equilibrium (div sigma = 0) by
    explicit pseudo-time stepping. Boundary conditions:
      * left/right columns are fully clamped (u = 0);
      * the top row gets u_x = 0 and u_y = A_top * sin(pi x);
      * the bottom row gets u_x = 0 and u_y = -A_bottom * sin(pi x), or is
        fully clamped if `clamp_bottom` is True.
    Here x runs from 0 to 1 across the columns.

    Args:
        E: (B, ny, nx) Young's modulus field; converted to float32 on `device`.
        nu: Poisson's ratio. Must lie in (-1, 0.5), because the plane-strain
            Lame lambda is singular at nu = 0.5.
        max_iters: maximum number of relaxation steps.
        tol: stop once the RMS of div(sigma) over the INTERIOR nodes is <= tol.
            Boundary nodes are excluded: their values are pinned, so their
            residual never goes to zero and would block convergence.
        A_top, A_bottom: amplitudes of the prescribed top/bottom normal
            displacement. A_bottom defaults to A_top.
        clamp_bottom: fully clamp the bottom row instead of prescribing u_y.
        device: device to solve on. None means the module-level DEVICE.
        check_every: run the convergence test every `check_every` steps. Each
            test calls `.item()`, which synchronises an accelerator. 1 tests
            every step.

    Returns:
        u: (B, ny, nx, 2) float32 tensor on the solve device. u[..., 0] is the
        column-direction (x) displacement and u[..., 1] the row-direction (y)
        displacement, where positive y points toward larger row index.

    Raises:
        ValueError: on nu outside (-1, 0.5), check_every < 1, E not 3-D, or a
        grid smaller than 2x2.
    """
    if not -1.0 < nu < 0.5:
        raise ValueError(f"nu must lie in the open interval (-1, 0.5), got {nu!r}")
    if check_every < 1:
        raise ValueError(f"check_every must be >= 1, got {check_every!r}")
    if A_bottom is None:
        A_bottom = A_top
    dev = DEVICE if device is None else torch.device(device)

    E = E.to(dev, dtype=torch.float32, non_blocking=True)
    if E.dim() != 3:
        raise ValueError(f"E must have shape (B, ny, nx), got {tuple(E.shape)}")
    B, ny, nx = E.shape
    if ny < 2 or nx < 2:
        raise ValueError(f"grid must be at least 2x2, got {ny}x{nx}")
    dx = 1.0 / (nx - 1)
    dy = 1.0 / (ny - 1)

    # Plane-strain Lame parameters.
    mu = E / (2.0 * (1.0 + nu))
    lam = (nu * E) / ((1.0 + nu) * (1.0 - 2.0 * nu))

    # Prescribed normal displacement profiles, shape (nx,), broadcast over batch.
    x = torch.linspace(0.0, 1.0, nx, device=dev, dtype=torch.float32)
    u_top_y = A_top * torch.sin(torch.pi * x)      # inward at top (array coords: down)
    u_bot_y = -A_bottom * torch.sin(torch.pi * x)  # inward at bottom (array coords: up)

    def enforce_dirichlet(u_):
        # Overwrite every boundary node (both components) with its prescribed value.
        u_[:, :, 0, :] = 0.0
        u_[:, :, -1, :] = 0.0
        u_[:, 0, :, 0] = 0.0
        u_[:, 0, :, 1] = u_top_y
        if clamp_bottom:
            u_[:, -1, :, :] = 0.0
        else:
            u_[:, -1, :, 0] = 0.0
            u_[:, -1, :, 1] = u_bot_y
        return u_

    # Only interior nodes are updated below, so enforcing the boundary once
    # up front is enough.
    u = enforce_dirichlet(torch.zeros((B, ny, nx, 2), device=dev, dtype=torch.float32))
    if B == 0 or ny < 3 or nx < 3:
        # Either no samples, or no interior unknowns to relax.
        return u

    # Conservative global explicit step, based on the stiffest P-wave modulus.
    lam2mu = lam + 2.0 * mu
    dt = 0.20 * (min(dx, dy) ** 2) / (float(torch.amax(lam2mu).item()) + 1e-12)

    def d_dx(a):
        # Central difference along columns, one-sided at the edges. a: (B, ny, nx)
        out = torch.empty_like(a)
        out[:, :, 1:-1] = (a[:, :, 2:] - a[:, :, :-2]) / (2 * dx)
        out[:, :, 0] = (a[:, :, 1] - a[:, :, 0]) / dx
        out[:, :, -1] = (a[:, :, -1] - a[:, :, -2]) / dx
        return out

    def d_dy(a):
        # Central difference along rows, one-sided at the edges. a: (B, ny, nx)
        out = torch.empty_like(a)
        out[:, 1:-1, :] = (a[:, 2:, :] - a[:, :-2, :]) / (2 * dy)
        out[:, 0, :] = (a[:, 1, :] - a[:, 0, :]) / dy
        out[:, -1, :] = (a[:, -1, :] - a[:, -2, :]) / dy
        return out

    for it in range(max_iters):
        ux = u[..., 0]
        uy = u[..., 1]

        # Small-strain tensor and isotropic Hooke's law.
        exx = d_dx(ux)
        eyy = d_dy(uy)
        exy = 0.5 * (d_dy(ux) + d_dx(uy))
        tr = exx + eyy
        sxx = 2.0 * mu * exx + lam * tr
        syy = 2.0 * mu * eyy + lam * tr
        sxy = 2.0 * mu * exy

        # Net force div(sigma) on the free (interior) nodes only.
        fx = (d_dx(sxx) + d_dy(sxy))[:, 1:-1, 1:-1]
        fy = (d_dx(sxy) + d_dy(syy))[:, 1:-1, 1:-1]

        if it % check_every == 0:
            rms = torch.sqrt(torch.mean(fx * fx + fy * fy))
            if float(rms.item()) <= tol:
                break

        u[:, 1:-1, 1:-1, 0] += dt * fx
        u[:, 1:-1, 1:-1, 1] += dt * fy

    return u


# --------------------------
# Fast dataset generation
# --------------------------

def generate_dataset_fast(N_samples, ny, nx, *,
                          batch_size=64,
                          n_regions=32, low=1.0, high=10.0, blur_sigma=1.0,
                          nu=0.30, max_iters=20_000, tol=1e-6,
                          A_top=0.10, A_bottom=None, clamp_bottom=False,
                          device=DEVICE):
    """
    Vectorised, batched generator of (E, u) pairs. Returns CPU tensors ready to save.

    Each batch of up to `batch_size` samples is produced in two steps:
    `sample_modulus_voronoi_torch` samples moduli on `device`, then
    `solve_elasticity_batched` computes the displacements. Both results are
    copied into preallocated CPU float32 tensors.

    Returns:
        dict with keys
          'E': (N_samples, ny, nx) float32 CPU tensor of moduli,
          'u': (N_samples, ny, nx, 2) float32 CPU tensor of displacements,
          'params': plain dict recording the generation settings. The
            solver entry also includes 'solve_device', the device the
            displacements were actually computed on (None if N_samples == 0).

    Raises:
        ValueError: if N_samples < 0 or batch_size < 1. A non-positive
        batch size would otherwise never make progress.
    """
    if N_samples < 0:
        raise ValueError(f"N_samples must be >= 0, got {N_samples!r}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

    # Fill preallocated CPU buffers instead of concatenating a list of batches.
    # This avoids holding two full copies of the dataset at peak.
    E = torch.empty((N_samples, ny, nx), dtype=torch.float32)   # (N, ny, nx)
    U = torch.empty((N_samples, ny, nx, 2), dtype=torch.float32)  # (N, ny, nx, 2)
    solve_device = None

    for start in range(0, N_samples, batch_size):
        stop = min(start + batch_size, N_samples)
        b = stop - start

        # sample E on the requested device
        E_b = sample_modulus_voronoi_torch(b, ny, nx,
                                           n_regions=n_regions, low=low, high=high,
                                           blur_sigma=blur_sigma, device=device)

        # The solver chooses its own device and moves E there if needed.
        U_b = solve_elasticity_batched(E_b, nu=nu, max_iters=max_iters, tol=tol,
                                       A_top=A_top, A_bottom=A_bottom, clamp_bottom=clamp_bottom)
        solve_device = U_b.device

        # copy_ handles the device -> CPU transfer
        E[start:stop].copy_(E_b.detach())
        U[start:stop].copy_(U_b.detach())

    params = {
        'ny': ny, 'nx': nx, 'N_samples': N_samples,
        'sampler': {'type': 'voronoi', 'n_regions': n_regions, 'low': low, 'high': high, 'blur_sigma': blur_sigma},
        'solver': {'fn': 'solve_elasticity_batched', 'nu': nu, 'max_iters': max_iters, 'tol': tol,
                   'A_top': A_top, 'A_bottom': (A_top if A_bottom is None else A_bottom),
                   'clamp_bottom': clamp_bottom, 'device': str(device),
                   'solve_device': None if solve_device is None else str(solve_device)},
    }
    return {'E': E, 'u': U, 'params': params}



if __name__ == "__main__":
    # Reproducible run: fixed seed, 10k samples on a 64x64 grid, saved with torch.save.
    torch.manual_seed(42)
    grid_ny = grid_nx = 64
    n_samples = 10_000

    generator_kwargs = dict(
        batch_size=128,
        n_regions=32, low=1.0, high=10.0, blur_sigma=1.0,
        nu=0.30, max_iters=20_000, tol=1e-6,
        A_top=0.10, A_bottom=None, clamp_bottom=False,
        device=DEVICE,
    )
    dataset = generate_dataset_fast(n_samples, grid_ny, grid_nx, **generator_kwargs)

    torch.save(dataset, "elasticity_dataset.pt")
