import torch
import torch.fft
import scipy.sparse
import scipy.sparse.linalg
from scipy.fft import idctn
import numpy as np
import random


def sample_grf(N, alpha=2.0, tau=3):
    """
    Draw a 2D Gaussian random field on an N x N grid by spectral synthesis.

    Standard-normal white noise (taken from NumPy's global RNG) is weighted in
    cosine-transform space by

        tau**(alpha - 1) * (pi**2 * (k1**2 + k2**2) + tau**2) ** (-alpha / 2),

    with the (0, 0) constant mode zeroed out. The weighted noise is scaled by N
    and mapped back to physical space with
    ``scipy.fft.idctn(type=2, norm='ortho')``.

    Args:
        N: number of grid points per side.
        alpha: exponent of the spectral decay; larger values damp high
            frequencies more strongly.
        tau: shift parameter entering the spectrum as tau**2.

    Returns:
        float32 torch tensor of shape (N, N).
    """
    modes = np.arange(N)
    kx, ky = np.meshgrid(modes, modes, indexing='ij')

    weights = (np.pi ** 2 * (kx ** 2 + ky ** 2) + tau ** 2) ** (-alpha / 2)
    weights = tau ** (alpha - 1) * weights
    weights[0, 0] = 0.0  # drop the constant (mean) mode

    noise = np.random.randn(N, N)
    spectral = N * weights * noise
    return torch.tensor(idctn(spectral, type=2, norm='ortho'), dtype=torch.float32)


def psi(a_raw):
    """
    Threshold a raw field into a two-valued, piecewise-constant permeability.

    Entries with ``a_raw >= 0`` become 12.0; every other entry (including NaN)
    becomes 3.0. The result is a float32 tensor with the shape of ``a_raw``.
    """
    permeability = torch.full_like(a_raw, 3.0, dtype=torch.float32)
    permeability[a_raw >= 0] = 12.0
    return permeability


def solve_pde(a: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
    """
    Solve -div(a ∇u) = f with Dirichlet BCs (u = 0) on [0,1]^2 using finite differences.

    The N x N nodes are uniformly spaced with dx = 1 / (N - 1) and include the
    boundary. Interior nodes use the 5-point stencil for div(a ∇u). Each face
    coefficient is the arithmetic mean of the two adjacent node values, e.g.
    a_{i+1/2,j} = (a[i, j] + a[i+1, j]) / 2. Boundary nodes get identity rows
    with a zero right-hand side.

    The sparse system is assembled in vectorized (COO -> CSR) form rather than
    entry by entry, which avoids a Python double loop over all N*N nodes.

    Args:
        a: [N, N] permeability field.
        f: [N, N] RHS field (same shape as ``a``).

    Returns:
        [N, N] float32 tensor holding the discrete solution u.

    Raises:
        ValueError: if ``a`` is not a square 2D field with N >= 2, or if ``f``
            does not have the same shape as ``a``.
    """
    # detach()/cpu() so inputs that require grad or live on a GPU are accepted;
    # assemble in float64 regardless of the input precision.
    a_np = a.detach().cpu().numpy().astype(np.float64)
    f_np = f.detach().cpu().numpy().astype(np.float64)

    if a_np.ndim != 2 or a_np.shape[0] != a_np.shape[1]:
        raise ValueError(f"a must be a square 2D field, got shape {a_np.shape}")
    if f_np.shape != a_np.shape:
        raise ValueError(
            f"f must have the same shape as a: {f_np.shape} vs {a_np.shape}"
        )
    N = a_np.shape[0]
    if N < 2:
        raise ValueError(f"need at least 2 grid points per side, got N={N}")

    dx = 1.0 / (N - 1)
    dx2 = dx * dx

    # Flat unknown index of node (i, j) is i * N + j.
    idx = np.arange(N * N).reshape(N, N)

    boundary = np.ones((N, N), dtype=bool)
    boundary[1:-1, 1:-1] = False
    bnd = idx[boundary]
    inner = idx[1:-1, 1:-1].ravel()

    # Face-averaged permeabilities around every interior node.
    centre = a_np[1:-1, 1:-1]
    a_ip1 = (0.5 * (centre + a_np[2:, 1:-1])).ravel()
    a_im1 = (0.5 * (centre + a_np[:-2, 1:-1])).ravel()
    a_jp1 = (0.5 * (centre + a_np[1:-1, 2:])).ravel()
    a_jm1 = (0.5 * (centre + a_np[1:-1, :-2])).ravel()

    # Rows: boundary identity entries, then diagonal, j+1, j-1, i+1, i-1 neighbours.
    rows = np.concatenate([bnd, inner, inner, inner, inner, inner])
    cols = np.concatenate([bnd, inner, inner + 1, inner - 1, inner + N, inner - N])
    vals = np.concatenate([
        np.ones(bnd.size),
        -(a_ip1 + a_im1 + a_jp1 + a_jm1) / dx2,
        a_jp1 / dx2,
        a_jm1 / dx2,
        a_ip1 / dx2,
        a_im1 / dx2,
    ])
    A = scipy.sparse.csr_matrix((vals, (rows, cols)), shape=(N * N, N * N))

    # A discretizes +div(a ∇u), so the RHS is -f; boundary rows enforce u = 0.
    b = -f_np.reshape(-1)
    b[bnd] = 0.0

    u = scipy.sparse.linalg.spsolve(A, b)
    return torch.tensor(u.reshape(N, N), dtype=torch.float32)


def generate_sample(N, alpha=2.0, tau=3.0):
    """
    Generate one (permeability, solution) pair for the Darcy problem.

    A raw Gaussian random field is drawn with ``sample_grf`` and thresholded
    by ``psi`` into a permeability field ``a``. The PDE is then solved with a
    constant unit forcing f = 1.

    Args:
        N: grid points per side.
        alpha: GRF spectral exponent passed to ``sample_grf`` (default 2.0,
            the previously hard-coded value).
        tau: GRF spectral parameter passed to ``sample_grf`` (default 3.0,
            the previously hard-coded value).

    Returns:
        Tuple ``(a, u)`` of [N, N] float32 tensors: permeability and solution.
    """
    a_raw = sample_grf(N, alpha=alpha, tau=tau)
    a = psi(a_raw)
    f = torch.ones(N, N, dtype=torch.float32)  # constant forcing
    u = solve_pde(a, f)
    return a, u


if __name__ == '__main__':
    SEED = 42
    # Seed every RNG in use; the GRF sampler draws from NumPy's global RNG.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)

    N = 64  # grid size
    num_samples = 20_000

    # Preallocate and fill in place. Collecting per-sample tensors in lists and
    # then calling torch.stack would keep two full copies of each dataset
    # alive at the stacking step.
    a_all = torch.empty((num_samples, N, N), dtype=torch.float32)
    u_all = torch.empty((num_samples, N, N), dtype=torch.float32)

    for k in range(num_samples):
        a_all[k], u_all[k] = generate_sample(N)

    data = {
        'a': a_all,  # shape: (num_samples, N, N)
        'u': u_all,
    }

    torch.save(data, 'darcy_dataset.pt')
