import multiprocessing as mp
import os
import pickle
from pathlib import Path

import elegy as eg
import optax
import pyvista as pv
from typer import Typer

from . import pde
from .config import Config, get_config
from .data import make_data_generator
from .model import make_network
from .validation import (MeshValidator, load_meshes, predict_and_save_mesh,
                         process_mesh)

# Worker processes for this CLI are started with the "spawn" method.
# ``set_start_method`` raises RuntimeError("context has already been set") if it
# is called again without ``force`` once a method is fixed. That happens when a
# spawned child re-executes this module (multiprocessing sets the child's
# start method before re-running the main module) or when the module is
# reloaded. Skip the call if "spawn" is already in effect. A *different*
# pre-selected method still raises, exactly as before, so such a conflict
# stays visible.
if mp.get_start_method(allow_none=True) != "spawn":
    mp.set_start_method("spawn")


# Typer application object: every function decorated with ``@app.command()``
# below is registered on it as a CLI subcommand.
app: Typer = Typer()


def make_model(cfg: Config) -> eg.Model:
    """Assemble the elegy ``Model`` described by ``cfg``.

    Module: ``pde.PDE`` built from ``make_network(cfg.model)`` plus the
    equation and boundary-condition objects looked up in the ``pde`` module by
    the names ``cfg.equation`` and ``cfg.boundary_condition``.

    Losses: two Huber losses, one on the ``"pde"`` output (logged as
    ``loss/pde_loss``) and one on the ``"bc"`` output (``loss/bc_loss``).

    Optimizer, from the inside out:

    * ``clip_by_global_norm(cfg.grad_norm)`` chained with the optax optimizer
      named ``cfg.optimizer``, called as ``optimizer(lr, **cfg.optimizer_args)``;
    * ``MultiSteps`` accumulating gradients over
      ``cfg.grad_accumulation_steps`` steps;
    * ``apply_if_finite`` with ``max_consecutive_errors=10``, so non-finite
      updates are skipped according to optax's semantics for that wrapper;
    * ``inject_hyperparams`` driving ``lr`` with a warmup + cosine-decay
      schedule (0.0 -> ``max_value`` over ``warmup_steps``, then down to
      ``minimum_value`` by ``decay_steps``).

    NOTE(review): ``inject_hyperparams`` wraps ``MultiSteps`` from the
    outside, so the schedule presumably advances on every micro-batch rather
    than every accumulated step. Confirm that warmup/decay steps are meant in
    micro-batches.
    """
    lr_cfg = cfg.learning_rate

    pde_module = pde.PDE(
        make_network(cfg.model),
        getattr(pde, cfg.equation),
        getattr(pde, cfg.boundary_condition),
    )
    losses = [
        eg.losses.Huber(on="pde", name="loss/pde_loss"),
        eg.losses.Huber(on="bc", name="loss/bc_loss"),
    ]

    def build_transform(lr):
        # Named factory (instead of a lambda) so each wrapping layer is
        # readable. inject_hyperparams calls it with the scheduled ``lr``.
        clipped_step = optax.chain(
            optax.clip_by_global_norm(cfg.grad_norm),
            getattr(optax, cfg.optimizer)(lr, **cfg.optimizer_args),
        )
        accumulated = optax.MultiSteps(
            clipped_step,
            every_k_schedule=cfg.grad_accumulation_steps,
        )
        return optax.apply_if_finite(accumulated, max_consecutive_errors=10)

    lr_schedule = optax.warmup_cosine_decay_schedule(
        0.0,
        lr_cfg.max_value,
        lr_cfg.warmup_steps,
        lr_cfg.decay_steps,
        lr_cfg.minimum_value,
    )
    optimizer = eg.Optimizer(
        optax.inject_hyperparams(build_transform)(lr=lr_schedule)
    )

    return eg.Model(module=pde_module, loss=losses, optimizer=optimizer)


@app.command()
def train(path: Path = "outputs", dry_run: bool = False, config: Path = None):
    """
    Train the model
    """
    # (The docstring above is Typer's --help text; it is kept verbatim.)
    #
    # Steps:
    # 1. Load the config. This happens before any chdir, so a relative
    #    ``config`` is resolved against the caller's cwd.
    # 2. Switch into ``path``. Relative outputs (config.yml, hist.pkl and the
    #    TensorBoard "runs" dir, plus presumably ``cfg.checkpoint`` if it is
    #    relative) are written there.
    # 3. Build the data generator and model, initialise the model from the
    #    first batch, then fit.
    # 4. Exit the process with status 0 on success.
    print("training")
    cfg = get_config(config)
    print("Config:")
    print(cfg)

    # Resolve the validation meshes against the *original* working directory
    # before changing into ``path``. The previous f"../{...}" form was only
    # correct when ``path`` was exactly one relative directory level deep.
    # It broke for "runs/exp1", ".", absolute output dirs, and absolute
    # validation paths.
    validation_data = os.path.abspath(cfg.data.validation_data)

    os.makedirs(path, exist_ok=True)
    os.chdir(path)
    Path("config.yml").write_text(str(cfg))

    gen = make_data_generator(cfg)
    # Everything that can fail once the generator exists runs under
    # try/finally, so the generator reference is always dropped.
    try:
        model = make_model(cfg)
        frst = next(iter(gen))[0]
        from jax.random import PRNGKey
        from jax.tree_util import tree_map

        print(tree_map(lambda x: x.shape, frst))
        model.module.init(PRNGKey(0), frst, inplace=True)
        model.summary(frst)
        tb = eg.callbacks.TensorBoard("runs")
        hist = model.fit(
            gen,
            epochs=cfg.epochs if not dry_run else 1,
            steps_per_epoch=cfg.steps_per_epoch,
            drop_remaining=False,  # needed because of mesh size difference
            callbacks=[
                eg.callbacks.TerminateOnNaN(),
                tb,
                MeshValidator(
                    load_meshes(validation_data),
                    # The TensorBoard writer is fetched lazily, at call time.
                    lambda: tb.val_writer,
                    cfg,
                ),
                eg.callbacks.ModelCheckpoint(cfg.checkpoint),
            ],
        )
        with Path("hist.pkl").open("wb") as out:
            pickle.dump(hist.history, out)
    finally:
        del gen
    # ``exit`` is the interactive helper injected by ``site`` and is not
    # guaranteed to exist (e.g. under ``python -S`` or frozen builds).
    # Raising SystemExit directly has the same effect without that dependency.
    raise SystemExit(0)


@app.command()
def test(model: str, meshes: str, output: str):
    """
    Test a given model on the given meshes
    """
    # (Docstring above is Typer's --help text and is kept verbatim.)
    # Loads the saved elegy model at ``model`` and the meshes returned by
    # ``load_meshes(meshes)``. For each mesh, it runs ``predict_and_save_mesh``
    # with the model's ``module.u`` and target ``f"{output}/{name}"``, then
    # prints the mean of the error it returns.
    print("testing")
    os.makedirs(output, exist_ok=True)
    named_meshes = load_meshes(meshes)
    trained = eg.load(model)
    u_fn = trained.module.u

    for mesh_name, raw_mesh in named_meshes.items():
        prepared = process_mesh(raw_mesh)
        error = predict_and_save_mesh(u_fn, f"{output}/{mesh_name}", prepared)
        print(f"mesh {mesh_name} with mean error {error.mean()}")


@app.command()
def generate_fem_mesh(path: str):
    """
    Tesselate geometry for FreeFem++ to use as validation data
    """
    # (Docstring above is Typer's --help text and is kept verbatim.)
    # ``path`` is a glob pattern. Every match is read with pyvista, passed
    # through ``tesselate(mesh, 21)`` and exported next to the source file
    # with a ``.mesh`` suffix. A pattern with no matches does nothing.
    from glob import glob

    for source in (Path(match) for match in glob(path)):
        # Imported lazily (only once a file has matched), as before.
        from .data import export_fem_mesh, tesselate

        print(f"tesselating file :{source}")
        tessellated = tesselate(pv.read(source), 21)
        export_fem_mesh(tessellated, source.with_suffix(".mesh"))


# @app.command()
# def compute_spectrum(path: Path):
#     """
#     Approximate a Fourier expansion of the given FEM solution
#     """
#     assert os.path.exists(path)
#     mesh = pv.read(path)
#     mesh = mesh.ctp()
#     import jax.numpy as jnp
#     import numpy as np
#     from jax.scipy.optimize import minimize

#     x = jnp.asarray(mesh.points)
#     u = jnp.asarray(mesh["u"])

#     def loss(theta):
#         n = len(theta) // 5
#         a, k, phi = theta[:n], theta[n:-n], theta[-n:]
#         k = k.reshape(3, n)
#         y = (a * jnp.sin(x @ k + phi)).sum(-1)
#         return optax.l2_loss(y, u).mean()

#     res = minimize(loss, np.random.randn(5 * 1_000), method="BFGS")
#     with Path(path).with_suffix(".fourier").open("wb") as f:
#         pickle.dump(res, f)

#     print(res.fun, res.nfev)


# Script entry point: hand the command line to the Typer app. This module uses
# relative imports, so it must be run as part of its package
# (``python -m <package>.<module>``), not as a standalone file.
if __name__ == "__main__":
    app()
