import os
# Limit JAX's GPU memory preallocation to 20% of device memory. This must be set
# before `jax` is imported below so it is in place before the XLA GPU backend is
# initialised. Presumably this leaves headroom for PyTorch, which this module
# also imports — confirm.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.2"

import json
import argparse
from pathlib import Path

import jax
import jax.numpy as jnp
import torch
import numpy as np
from flax import serialization

from complete_graph_gnn_jax import MoleculePINN


def _load_pinn_args(exp_dir: Path) -> argparse.Namespace:
    args_path = exp_dir / "pinn_args.json"
    if not args_path.exists():
        raise FileNotFoundError(f"Missing {args_path}. Did you dump pinn_args.json into the experiment folder?")
    with open(args_path, "r") as f:
        cfg = json.load(f)
    return argparse.Namespace(**cfg)


def _restore_ckpt_params(pinn, ckpt_path: Path, key: jax.Array) -> dict:
    """Load the model parameters stored in a flax msgpack checkpoint.

    ``pinn`` is initialised on dummy inputs to obtain the parameter tree the
    model expects. The checkpoint's params are then restored *into* that tree
    with ``serialization.from_state_dict``, the inverse of
    ``serialization.to_bytes``. A checkpoint that lacks parameters the model
    expects therefore fails here, rather than later inside ``pinn.apply``.

    Args:
        pinn: Flax module whose ``init`` accepts ``(key, x, y, t, n)``.
        ckpt_path: Checkpoint file written with ``serialization.to_bytes``.
            It may hold either a dict with a ``"params"`` entry or the bare
            params tree.
        key: PRNG key, used only for the template initialisation.

    Returns:
        The parameter pytree, structured like ``pinn.init(...)["params"]``.
    """
    # Dummy inputs only fix the parameter structure; their values are irrelevant.
    # Use the module's own n_max when it exposes one, otherwise 29.
    n_max = getattr(pinn, "n_max", 29)
    x0 = jnp.zeros((n_max, 3), dtype=jnp.float32)
    y0 = jnp.zeros((n_max, 3), dtype=jnp.float32)
    t0 = jnp.array(0.0, dtype=jnp.float32)
    n0 = jnp.array(10, dtype=jnp.int32)
    template = pinn.init(key, x0, y0, t0, n0)["params"]

    raw = Path(ckpt_path).read_bytes()

    if hasattr(serialization, "msgpack_restore"):
        # Raw state dict. This works whether the checkpoint held the params alone
        # or a larger dict (e.g. with other training state) containing "params".
        ckpt = serialization.msgpack_restore(raw)
    else:
        # Fallback: restore against a template that only has "params";
        # other top-level keys in the checkpoint are ignored.
        ckpt = serialization.from_bytes({"params": template}, raw)

    # A checkpoint of the bare params tree has no top-level "params" key.
    state = ckpt["params"] if "params" in ckpt else ckpt

    # Map the raw state onto the model's own tree structure (checks keys).
    return serialization.from_state_dict(template, state)


def get_score_function(ckpt_path: str, exp_dir: str | None = None):
    """Build a PyTorch-facing score function backed by a trained JAX PINN.

    Steps:
    * Load ``pinn_args.json`` from *exp_dir* (the checkpoint's parent
      directory by default).
    * Rebuild the ``MoleculePINN`` and restore its parameters from *ckpt_path*.
    * Place the parameters on the first GPU if JAX has a GPU backend, and on
      the CPU otherwise.

    Args:
        ckpt_path: Path to the serialized flax checkpoint.
        exp_dir: Directory containing ``pinn_args.json``. Defaults to the
            checkpoint's parent directory.

    Returns:
        ``score_func(z_noisy, z_0, sigma, n_list) -> torch.Tensor``.
        Inputs are torch tensors; the result is a CPU torch tensor.
        With ``t = sigma**2 / 2``:
        * samples with ``t > pinn_args.t_min`` get ``-u_x`` from the PINN;
        * all others get the closed form ``(z_0 - z_noisy) / sigma**2``.

    Raises:
        FileNotFoundError: If ``pinn_args.json`` is missing from *exp_dir*.
        AttributeError: If ``pinn_args.json`` does not define ``t_min``.
    """
    ckpt_path = Path(ckpt_path)
    exp_dir = Path(exp_dir) if exp_dir is not None else ckpt_path.parent

    # 1) load pinn args from disk
    pinn_args = _load_pinn_args(exp_dir)
    # Read the cutoff once here, so a config without ``t_min`` fails at
    # construction time rather than on the first score evaluation.
    t_min = pinn_args.t_min

    # 2) build pinn
    key = jax.random.PRNGKey(0)
    pinn = MoleculePINN(
        n_max=29,
        n_fourier=pinn_args.n_fourier,
        r_fourier_min=pinn_args.r_fourier_min,
        r_fourier_max=pinn_args.r_fourier_max,
        t_fourier_min=pinn_args.t_fourier_min,
        t_fourier_max=pinn_args.t_fourier_max,
        apply_log=pinn_args.apply_log,
    )

    # 3) restore params
    params = _restore_ckpt_params(pinn, ckpt_path, key)

    # 4) put params on the first GPU if available, otherwise on CPU.
    # jax.devices("gpu") raises RuntimeError (it does not return []) when no
    # GPU backend is present, so the CPU fallback must catch that.
    try:
        gpus = jax.devices("gpu")
    except RuntimeError:
        gpus = []
    dev = gpus[0] if gpus else jax.devices("cpu")[0]
    params = jax.device_put(params, dev)

    # ----- PINN derivatives -----
    def f_0(params, x, y, t, n):
        # Return u twice: once as the value to differentiate, once as aux output.
        u_0 = pinn.apply({"params": params}, x, y, t, n)
        return u_0, u_0

    def f_x(params, x, y, t, n):
        # NOTE(review): positional index 0 is ``params``, so argnums=2 takes the
        # derivative w.r.t. ``y`` (z_0 at the call site), not ``x`` (z_noisy).
        # Confirm this matches the intended ``u_x`` / score sign convention.
        u_x, u_0 = jax.jacrev(f_0, argnums=2, has_aux=True)(params, x, y, t, n)
        return u_x, (u_x, u_0)

    # Per-sample derivative, vmapped over the leading axis of x, y, t and n.
    # params are shared across the batch.
    batched_f_x = jax.jit(jax.vmap(f_x, in_axes=(None, 0, 0, 0, 0)))

    # ----- conversions -----
    def torch_to_jax(x: torch.Tensor) -> jax.Array:
        # simple + reliable (host copy). If you need faster, switch to DLPack.
        return jnp.asarray(x.detach().cpu().numpy())

    def jax_to_torch(x: jax.Array) -> torch.Tensor:
        # Host copy; the returned tensor lives on the CPU.
        return torch.tensor(np.array(x))

    def sigma_to_t(sigma):
        return (sigma ** 2) / 2.0

    # ----- score function -----
    def score_func(z_noisy, z_0, sigma, n_list):
        """Return the score for a batch as a CPU torch tensor.

        Args:
            z_noisy: Noisy coordinates (torch tensor). Passed as the PINN's ``x``.
            z_0: Reference coordinates (torch tensor). Passed as the PINN's ``y``.
            sigma: Per-sample noise level (torch tensor). Converted via
                ``t = sigma**2 / 2``.
            n_list: Per-sample count passed to the PINN as ``n``; cast to int32.
        """
        z_noisy = torch_to_jax(z_noisy)
        z_0 = torch_to_jax(z_0)
        sigma = torch_to_jax(sigma)
        n_list = torch_to_jax(n_list).astype(jnp.int32)

        t = sigma_to_t(sigma)

        _, (u_x, _u_0) = batched_f_x(params, z_noisy, z_0, t, n_list)

        # The [:, None, None] broadcast assumes sigma (and so t) has shape
        # (batch,) — TODO confirm against callers.
        use_pinn = (t > t_min)[:, None, None]
        score_original = (z_0 - z_noisy) / (sigma[:, None, None] ** 2)
        score_pinn = -u_x

        # jnp.where evaluates both branches, so the PINN derivative is computed
        # for every sample, including those that end up using the closed form.
        score = jnp.where(use_pinn, score_pinn, score_original)
        return jax_to_torch(score)

    return score_func
