import jax
import jax.numpy as jnp
from jax import random
# plotting
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.spatial.distance import cdist




def get_problem_data_uniform(problem_params: dict, rng: any, precision: jnp.dtype, sharding: any, logger: any) -> dict:
    """
    Evaluate a test function on a regular tensor-product grid.

    The function is chosen from keys in ``problem_params``:

    - ``"gaussian_peak"`` / ``"sine_gaussian_peak"`` (dict with optional ``"mean"`` and
      ``"std"``, scalar or length-``dim``; defaults 0.5 and 0.2): ``exp(-0.5 * |(x - mean) / std|^2)``
      on a grid over [0, 1]^dim. The sine variant adds ``sum_k sin(2*pi*x_k)``.
      ``"gaussian_peak"`` supplies mean/std whenever it is not None.
    - ``"piecewise_smooth"`` (dict with optional ``"p"``, default 10.0; a non-dict truthy
      value such as ``True`` means "use defaults"): squared sawtooth on a grid over [0, 1]^dim,
      for dim 1 or 2 only.
    - otherwise: a random sum of Fourier and polynomial terms on a grid over [-0.5, 0.5]^dim,
      controlled by ``"complexity"`` and ``"max_terms"`` and seeded by ``rng``.

    Parameters:
    - problem_params (dict): must contain "dim" and "n_samples_per_dim"; other keys as above.
    - rng: JAX PRNG key; only used by the random Fourier/polynomial branch.
    - precision, sharding: accepted so the signature matches the call in get_problem_data;
      not used by this function.
    - logger: logger for progress messages.

    Returns:
    - dict with "x" (grid points, shape (n_samples_per_dim**dim, dim)), "u_exact"
      (function values, shape (n_samples_per_dim**dim, 1)) and "grid_axes"
      (list of the 1-D axis arrays the grid was built from).

    Raises:
    - ValueError: if "piecewise_smooth" is selected with dim other than 1 or 2.
    """
    n_dim = problem_params["dim"]
    n_samples_per_dim = problem_params['n_samples_per_dim']

    logger.info("get_problem_data_uniform")

    def _grid(lo, hi):
        # Tensor-product grid; 'ij' indexing + C-order reshape makes the last axis vary fastest.
        axes = [jnp.linspace(lo, hi, n_samples_per_dim) for _ in range(n_dim)]
        points = jnp.array(jnp.meshgrid(*axes, indexing='ij')).reshape(n_dim, -1).T
        return axes, points

    gaussian_peak = problem_params.get("gaussian_peak", None)
    sine_gaussian_peak = problem_params.get("sine_gaussian_peak", None)

    if gaussian_peak or sine_gaussian_peak:
        if gaussian_peak is not None:
            logger.info("\tgaussian_peak")
            config = gaussian_peak
        else:
            logger.info("\tsin_gaussian_peak")
            config = sine_gaussian_peak

        mean = config.get("mean", 0.5)
        std = config.get("std", 0.2)
        # Scalars are broadcast to every dimension; sequences are used as given.
        mean = jnp.full((n_dim,), mean) if jnp.ndim(mean) == 0 else jnp.array(mean)
        std = jnp.full((n_dim,), std) if jnp.ndim(std) == 0 else jnp.array(std)

        grid_axes, grid_points = _grid(0., 1.)
        diff = (grid_points - mean) / std
        u_exact = jnp.exp(-0.5 * jnp.sum(diff ** 2, axis=1))

        if sine_gaussian_peak:
            u_exact = u_exact + jnp.sum(jnp.sin(2 * jnp.pi * grid_points), axis=1)

        u_exact = u_exact.reshape(-1, 1)

    elif problem_params.get("piecewise_smooth", False):
        logger.info("\tpiecewise_smooth")
        grid_axes, grid_points = _grid(0., 1.)

        config = problem_params["piecewise_smooth"]
        # A non-dict truthy flag (e.g. True) previously crashed on `.get`; treat it as defaults.
        p = config.get("p", 10.0) if isinstance(config, dict) else 10.0

        if n_dim == 1:
            term = p * grid_points[:, 0:1]
        elif n_dim == 2:
            x, y = grid_points[:, 0:1], grid_points[:, 1:2]
            term = p * (y - (x - 0.5) ** 2)
        else:
            raise ValueError("piecewise_smooth only supports dim=1 or dim=2")

        # (2*|t - round(t)| - 1)^2: periodic in t with period 1, values in [0, 1].
        u_exact = (2 * jnp.abs(term - jnp.floor(term + 0.5)) - 1) ** 2

    else:
        grid_axes, grid_points = _grid(-0.5, 0.5)

        complexity = problem_params['complexity']
        max_terms = problem_params['max_terms']

        num_terms = int(1 + complexity * (max_terms - 1))  # Scale from 1 to max_terms
        frequency_scale = 1 + complexity * 5  # Scale Fourier frequencies

        # One subkey per random coefficient array (the first split output is unused).
        _, rng_a, rng_b, rng_c, rng_d = random.split(rng, 5)

        a_fourier = random.normal(rng_a, shape=(num_terms, n_dim)) * complexity
        b_fourier = random.uniform(rng_b, shape=(num_terms, n_dim), minval=1, maxval=frequency_scale)
        c_poly = random.normal(rng_c, shape=(num_terms, n_dim)) * complexity
        d_bias = random.normal(rng_d, shape=(num_terms,))

        powers = jnp.arange(1, num_terms + 1)[:, None]  # (num_terms, 1) exponents 1..num_terms
        bias = jnp.sum(d_bias)

        def f(x):
            # x: a single grid point of shape (n_dim,).
            fourier_terms = jnp.sum(a_fourier * jnp.sin(b_fourier * x), axis=(0, 1))
            poly_terms = jnp.sum(c_poly * (x ** powers), axis=(0, 1))
            return fourier_terms + poly_terms + bias

        # Vectorised over all grid points instead of one Python-level call per point.
        u_exact = jax.vmap(f)(grid_points).reshape(-1, 1)

    return {"x": grid_points, "u_exact": u_exact, "grid_axes": grid_axes}



def _broadcast_param(value, default, n_dim, precision, label, logger):
    """Return ``value`` as a length-``n_dim`` array of dtype ``precision``.

    A scalar is repeated ``n_dim`` times and a length-``n_dim`` sequence is cast as-is.
    Any other shape is logged as a warning and replaced by ``default`` repeated ``n_dim`` times.
    """
    arr = jnp.array(value)
    if arr.ndim == 0:
        return jnp.full((n_dim,), arr.item(), dtype=precision)
    if arr.shape == (n_dim,):
        return arr.astype(precision)
    logger.warning(
        f"{label} shape {arr.shape} (from original: {value}) is not scalar or ({n_dim},). "
        f"Using default {default}."
    )
    return jnp.full((n_dim,), default, dtype=precision)


def _sawtooth_square(term, precision):
    """Return ``(2*|term - floor(term + 0.5)| - 1)**2`` with constants in ``precision``.

    The result is periodic in ``term`` with period 1 and lies in [0, 1].
    """
    half = jnp.array(0.5, dtype=precision)
    two = jnp.array(2.0, dtype=precision)
    one = jnp.array(1.0, dtype=precision)
    return (two * jnp.abs(term - jnp.floor(term + half)) - one) ** 2


def get_problem_data(
    problem_params: dict,
    rng_key: any,
    precision: jnp.dtype,
    sharding: any,
    logger: any,
    ):
    """
    Generate random sample points and exact solution values within a memory budget.

    The number of samples is chosen so that the estimated size of the main arrays
    stays below ``0.9 * gb_allowed`` GiB. Base samples are drawn from a standard normal
    distribution in ``dim`` dimensions.

    Args:
        problem_params (dict): Problem parameters:
            - "uniform_points" (bool, optional): if truthy, delegate to
              ``get_problem_data_uniform`` and return its result unchanged.
            - "gb_allowed" (float): Memory budget in GiB for the key arrays.
            - "dim" (int): Number of dimensions.
            - "gaussian_peak" (dict, optional): "mean" / "std" (scalar or length-dim;
              defaults 0.5 / 0.2). Points are ``mean + std * z`` with z standard normal,
              and the solution is ``exp(-0.5 * |(x - mean) / std|^2)``.
            - "sine_gaussian_peak" (dict, optional): as "gaussian_peak" (used for
              mean/std only when "gaussian_peak" is None), plus ``sum_k sin(2*pi*x_k)``.
            - "piecewise_smooth" (dict or truthy flag, optional): squared sawtooth on the
              standard-normal points, dim 1 or 2 only. "p" (default 10.0) scales the
              1-D argument; the 2-D argument uses a fixed factor 4.0.
        rng_key (jax.random.PRNGKey): Key for random number generation.
        precision (jnp.dtype): Target dtype; anything not convertible by ``jnp.dtype``
            falls back to float64 with an error log.
        sharding: Currently not applied by this function (accepted for interface
            compatibility and forwarded to ``get_problem_data_uniform``).
        logger (logging.Logger): Logger for informational and error messages.

    Returns:
        dict: {"x": points of shape (N, dim), "u_exact": solution of shape (N, 1)},
        with ``u_exact`` divided by its maximum (skipped if that maximum is 0).

    Raises:
        ValueError: if the budget allows no samples, if "piecewise_smooth" is used
            with dim other than 1 or 2, or if no supported problem type is given.
    """
    if problem_params.get("uniform_points", False):
        return get_problem_data_uniform(problem_params, rng_key, precision, sharding, logger)

    gb_allowed = problem_params['gb_allowed']
    n_dim = problem_params["dim"]

    # Normalise precision to a dtype object; .itemsize and .name are read below.
    try:
        converted_precision = jnp.dtype(precision)
    except TypeError:
        logger.error(f"Invalid precision value: {precision}. Must be convertible to jnp.dtype. Falling back to jnp.float64.")
        # Wrap in jnp.dtype so the fallback is a dtype object, same as the success path.
        converted_precision = jnp.dtype(jnp.float64)
    precision = converted_precision

    bytes_per_float = precision.itemsize

    # Rough per-sample footprint of the largest arrays.
    # Default (Gaussian variants): points (n_dim) + diff intermediate (n_dim) + solution (1).
    floats_per_sample_approx = 2 * n_dim + 1
    if problem_params.get("piecewise_smooth"):
        if n_dim == 1:    # points (1) + solution (1) + one intermediate (1)
            floats_per_sample_approx = 1 + 1 + 1
        elif n_dim == 2:  # points (2) + solution (1) + intermediate `s` (1)
            floats_per_sample_approx = 2 + 1 + 1

    memory_per_sample_bytes = floats_per_sample_approx * bytes_per_float
    gb_allowed_bytes = gb_allowed * (1024**3)

    safety_factor = 0.9
    effective_gb_allowed_bytes = gb_allowed_bytes * safety_factor

    if memory_per_sample_bytes <= 0:
        max_n_mc_samples = 0
    else:
        max_n_mc_samples = int(effective_gb_allowed_bytes / memory_per_sample_bytes)

    if max_n_mc_samples <= 0:
        err_msg = (
            f"Calculated max_n_mc_samples is {max_n_mc_samples}. "
            f"Low gb_allowed ({gb_allowed}GB), high n_dim ({n_dim}), precision ({precision.name}), or memory estimation. "
            f"Mem per sample: {memory_per_sample_bytes} bytes."
        )
        logger.error(err_msg)
        raise ValueError(err_msg)

    logger.info(f"Target precision: {precision.name}")
    logger.info(f"Max total memory (for key arrays): {gb_allowed:.3f} GB. Effective (with {safety_factor*100}% factor): {effective_gb_allowed_bytes / (1024**3):.3f} GB.")
    logger.info(f"Estimated memory per sample: {memory_per_sample_bytes} bytes ({floats_per_sample_approx} floats of size {bytes_per_float}).")
    logger.info(f"Targeting {max_n_mc_samples} Monte Carlo samples for {n_dim} dimensions.")

    rng_key, subkey = jax.random.split(rng_key)
    mc_points = jax.random.normal(
        subkey,
        shape=(max_n_mc_samples, n_dim),
        dtype=precision,
    )

    gaussian_peak_config = problem_params.get("gaussian_peak")
    sine_gaussian_peak_config = problem_params.get("sine_gaussian_peak")

    if gaussian_peak_config or sine_gaussian_peak_config:
        config_source = gaussian_peak_config if gaussian_peak_config is not None else sine_gaussian_peak_config
        mean_arr = _broadcast_param(config_source.get("mean", 0.5), 0.5, n_dim, precision, "Mean", logger)
        std_arr = _broadcast_param(config_source.get("std", 0.2), 0.2, n_dim, precision, "Std", logger)

        # Sample around the peak: N(mean, std^2) per dimension.
        points = mean_arr + std_arr * mc_points
        diff = (points - mean_arr) / std_arr
        squared_dist = jnp.sum(diff ** 2, axis=1)
        u_exact = jnp.exp(jnp.array(-0.5, dtype=precision) * squared_dist)

        if sine_gaussian_peak_config:
            sine_term = jnp.sum(jnp.sin(jnp.array(2 * jnp.pi, dtype=precision) * points), axis=1)
            u_exact = u_exact + sine_term

        u_exact = u_exact.reshape(-1, 1)

    elif problem_params.get("piecewise_smooth", False):
        piecewise_config = problem_params["piecewise_smooth"]
        if not isinstance(piecewise_config, dict):
            # A bare truthy flag (e.g. True) means "use the defaults".
            piecewise_config = {}

        # Previously `points` was left unbound on this branch, crashing the return below.
        points = mc_points

        if n_dim == 1:
            p_val_arr = jnp.array(piecewise_config.get("p", 10.0), dtype=precision)
            u_exact = _sawtooth_square(p_val_arr * points, precision)  # (N, 1)
        elif n_dim == 2:
            x_coords, y_coords = points[:, 0:1], points[:, 1:2]
            s = y_coords - (x_coords - jnp.array(0.5, dtype=precision)) ** 2
            # NOTE(review): the 2-D factor is fixed at 4.0, whereas get_problem_data_uniform
            # scales the same `s` by "p" — confirm which behaviour is intended.
            u_exact = _sawtooth_square(jnp.array(4.0, dtype=precision) * s, precision)  # (N, 1)
        else:
            err_msg = "piecewise_smooth (Monte Carlo) currently only supports dim=1 or dim=2"
            logger.error(err_msg)
            raise ValueError(err_msg)
    else:
        err_msg = "Problem type not specified or supported (e.g., gaussian_peak, sine_gaussian_peak, or piecewise_smooth)."
        logger.error(err_msg)
        raise ValueError(err_msg)

    u_exact = u_exact.astype(precision)  # guarantee dtype

    # Scale so the largest value is 1; avoid dividing by zero for an all-zero solution.
    u_max = u_exact.max()
    if float(u_max) == 0.0:
        logger.warning("Maximum of u_exact is 0; skipping normalization to avoid division by zero.")
    else:
        u_exact = u_exact / u_max

    mem_points_gb = points.nbytes / (1024**3)
    mem_uexact_gb = u_exact.nbytes / (1024**3)
    total_mem_gb = mem_points_gb + mem_uexact_gb

    logger.info(f"Generated {points.shape[0]} samples in {n_dim}D.")
    logger.info(f"Shape of points: {points.shape} (dtype: {points.dtype.name}).")
    logger.info(f"Shape of solution: {u_exact.shape} (dtype: {u_exact.dtype.name}).")
    logger.info(f"Global memory for points: {mem_points_gb:.4f} GB, solution: {mem_uexact_gb:.4f} GB. Total: {total_mem_gb:.4f} GB.")

    # Warn if the real arrays overshoot the budget by more than the safety margin implies
    # (with safety_factor 0.9 the threshold is about 1.17 * gb_allowed).
    if total_mem_gb > gb_allowed_bytes / (1024**3) * (1/safety_factor * 1.05):
        logger.warning(
            f"Final arrays ({total_mem_gb:.3f} GB) notably exceed allowed memory scaled by safety factor ({gb_allowed_bytes / (1024**3) / safety_factor:.3f} GB)."
            " This might be due to JAX overheads or estimation nuances."
        )

    return {
        "x": points,
        "u_exact": u_exact,
    }


# --- Define Data Generation/Loading Function ---
def get_2d_problem_data(problem_params: dict, precision: jnp.dtype, sharding: any, logger: any) -> dict:
    """
    Build the fixed 2-D benchmark on a square grid over [-0.5, 0.5]^2.

        u(x1, x2) = sin(2*pi*x1) + |x2|**0.5 + x1*cos(pi*x2) + log(1 + |x1 - x2|)

    Parameters:
    - problem_params (dict): must contain "n_samples", the number of grid points per axis.
    - precision: dtype for the grid axes and the 0.5 / 1.0 constants.
    - sharding: placement passed to jax.device_put for both returned arrays.
    - logger: logger for the progress message.

    Returns:
    - dict with "x" of shape (n_samples**2, 2) and "u_exact" of shape (n_samples**2, 1).
    """
    n_samples = problem_params['n_samples']
    logger.info(f"Generating problem data with n_samples={n_samples}")

    # Both axes use the same points; default 'xy' meshgrid indexing means the
    # first coordinate varies fastest after flattening.
    axis = jnp.linspace(-0.5, 0.5, n_samples, dtype=precision)
    grid_first, grid_second = jnp.meshgrid(axis, axis)
    x1 = grid_first.ravel()
    x2 = grid_second.ravel()

    half = jnp.array(0.5, dtype=precision)
    one = jnp.array(1.0, dtype=precision)

    oscillation = jnp.sin(2 * jnp.pi * x1)
    cusp = jnp.abs(x2) ** half
    coupling = x1 * jnp.cos(jnp.pi * x2)
    log_term = jnp.log(one + jnp.abs(x1 - x2))
    values = oscillation + cusp + coupling + log_term

    coords = jnp.stack([x1, x2], axis=1)
    return {
        "x": jax.device_put(coords, sharding),
        "u_exact": jax.device_put(values[:, None], sharding),
    }



# ---------------------- Plotting Function for 1D, 2D, and 3D ---------------------- #
def plot_functions(complexity, rng, max_terms=5):
    """
    Plot the random Fourier/polynomial test function in 1-D, 2-D and 3-D.

    Parameters:
    - complexity (float): forwarded as problem_params["complexity"].
    - rng: not used; a JAX key is derived from ``complexity`` so each complexity
      level is reproducible. Kept for backward compatibility with existing callers.
    - max_terms (int): forwarded as problem_params["max_terms"].
    """
    import logging

    logger = logging.getLogger(__name__)
    # Derive a reproducible JAX key from the complexity level.
    key = random.PRNGKey(int(complexity * 100))

    def sample(n_dim, n_per_dim):
        # get_problem_data takes a params dict, not keyword arguments; the random
        # Fourier/polynomial function is the default branch of the uniform-grid generator.
        params = {
            "dim": n_dim,
            "n_samples_per_dim": n_per_dim,
            "complexity": complexity,
            "max_terms": max_terms,
        }
        return get_problem_data_uniform(params, key, jnp.float32, None, logger)

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    fig.suptitle(f"Function Complexity: {complexity:.2f}", fontsize=16)

    # 1D Plot
    n1 = 100
    problem_1d = sample(1, n1)
    axes[0].plot(
        np.asarray(problem_1d["x"]).ravel(),
        np.asarray(problem_1d["u_exact"]).ravel(),
        label="1D Function",
    )
    axes[0].set_xlabel("x")
    axes[0].set_ylabel("f(x)")
    axes[0].set_title("1D Function")
    axes[0].grid()

    # 2D Plot: points come from an 'ij' meshgrid, so reshaping restores the (n, n) grid.
    n2 = 50
    problem_2d = sample(2, n2)
    pts_2d = np.asarray(problem_2d["x"])
    x_2d = pts_2d[:, 0].reshape(n2, n2)
    y_2d = pts_2d[:, 1].reshape(n2, n2)
    z_2d = np.asarray(problem_2d["u_exact"]).reshape(n2, n2)
    contour = axes[1].contourf(x_2d, y_2d, z_2d, levels=20, cmap="coolwarm")
    fig.colorbar(contour, ax=axes[1], label="f(x, y)")
    axes[1].set_xlabel("x")
    axes[1].set_ylabel("y")
    axes[1].set_title("2D Function Contour")

    # 3D Plot: surface of the slice at the middle z index.
    n3 = 20
    problem_3d = sample(3, n3)
    ax_x, ax_y, ax_z = (np.asarray(a) for a in problem_3d["grid_axes"])
    k = n3 // 2
    # 'ij' meshgrid + C-order reshape: flat values map to [i_x, i_y, i_z].
    Z = np.asarray(problem_3d["u_exact"]).reshape(n3, n3, n3)[:, :, k]
    X, Y = np.meshgrid(ax_x, ax_y, indexing="ij")

    # Replace the 2-D placeholder axes instead of drawing the 3-D axes on top of it.
    axes[2].remove()
    ax3d = fig.add_subplot(1, 3, 3, projection="3d")
    ax3d.plot_surface(X, Y, Z, cmap="coolwarm")
    ax3d.set_xlabel("x")
    ax3d.set_ylabel("y")
    ax3d.set_zlabel(f"f(x, y, z={ax_z[k]:.2f})")
    ax3d.set_title("3D Function Surface")

    plt.show()

if __name__ == "__main__":
    # Sweep four complexity levels, one figure each. plot_functions derives its
    # own JAX key from the complexity level, so this generator is only forwarded.
    numpy_rng = np.random.default_rng(7)
    for level in np.linspace(0.25, 2.5, 4):
        plot_functions(level, numpy_rng)


