import os
import time

import absl
import jax
import optax
import orbax
from flax import nnx
from flax.training import orbax_utils
from jax import numpy as jnp
from KNO import build_model, quadratures, utils
from ml_collections.config_flags import config_flags
from numpy.polynomial.legendre import (
    leggauss,  ### don't worry about other quadrature choices tbh
)
from scipy import io


# When debugging, disabling jit lets ordinary print statements work as usual:
# os.environ['JAX_DISABLE_JIT'] = 'True'

# Lookup table from the `config.optim.optimizer` string to an optax optimizer
# factory; register additional optimizers here to expose them to the config.
optimizer_mapping = dict(
    adam=optax.adam,
    sgd=optax.sgd,
)


# Command-line flags parsed by absl. `--config` names an ml_collections config
# file (left unlocked via lock_config=False); all three flags are mandatory.
FLAGS = absl.flags.FLAGS
absl.flags.DEFINE_string(
    "workdir",
    None,
    "Working directory for a run of experiments.",
)
absl.flags.DEFINE_string(
    "runsig",
    None,
    "run signature",
)
config_flags.DEFINE_config_file(
    "config",
    None,
    "Training configuration.",
    lock_config=False,
)
absl.flags.mark_flags_as_required(["workdir", "config", "runsig"])
# Hook JAX's own configuration options into absl flag handling.
jax.config.config_with_absl()


def main(argv):
    """Train (and optionally checkpoint) the model described by ``FLAGS.config``.

    Reads ``FLAGS.workdir``, ``FLAGS.runsig`` and ``FLAGS.config``; loads the
    arrays ``x``, ``x_grid``, ``y``, ``y_grid`` from ``config.data.filepath``;
    builds the model, quadrature rule and optimizer; runs the train/eval loop
    (optionally training kernel layers one at a time first); logs to
    ``<workdir>/<runsig>/log.txt``; and, if ``config.save_model`` is set,
    writes an orbax checkpoint to ``<workdir>/<runsig>/ckpt``.

    Args:
        argv: Leftover positional command-line arguments from absl (unused).

    Raises:
        ValueError: if the batch size exceeds the number of training samples,
            or if no quadrature rule is available for the input domain.
    """
    workdir = FLAGS.workdir
    runsig = FLAGS.runsig
    config = FLAGS.config

    utils.create_path(workdir, verbose=False)
    exp_workdir = os.path.join(workdir, runsig)
    utils.create_path(exp_workdir, verbose=False)

    logger = utils.get_logger(os.path.join(exp_workdir, "log.txt"), displaying=False)
    logger.info("=============== Experiment Setup ===============")
    logger.info(config)
    logger.info("================================================")

    logger.info(f"DEVICE = {jax.default_backend()}")

    ################# loading data ##############################################################################

    key = jax.random.PRNGKey(seed=config.seed)

    data = jnp.load(config.data.filepath)
    x, x_grid, y, y_grid = data["x"], data["x_grid"], data["y"], data["y_grid"]
    grids = (x_grid, y_grid)
    x_train, x_test = x[: config.data.ntrain], x[-config.data.ntest :]
    y_train, y_test = y[: config.data.ntrain], y[-config.data.ntest :]

    # Split before every use: consuming `key` directly and then deriving more
    # keys from it (for the interpolation centers below) correlates the streams.
    key, shuffle_key = jax.random.split(key)
    shuffle = jax.random.permutation(shuffle_key, len(x_train))
    x_train = x_train[shuffle]
    y_train = y_train[shuffle]

    logger.info(
        f"{x_train.shape=}, {y_train.shape=}\n {x_test.shape=}, {y_test.shape=}"
    )
    logger.info(f"{x_grid.shape=}, {y_grid.shape=}")

    # Only the inputs are encoded. Targets stay in their original units; the
    # model's outputs are decoded with `y_normalizer` before computing the loss.
    y_normalizer = None
    if config.data.normalize:
        x_normalizer = utils.UnitGaussianNormalizer(x_train)
        y_normalizer = utils.UnitGaussianNormalizer(y_train)
        x_train = x_normalizer.encode(x_train)
        x_test = x_normalizer.encode(x_test)

    ### manual mini-batching of the jnp arrays
    train_batch_size = config.training.batch_size

    @jax.jit
    def get_train_batch(i):
        """Return the ``i``-th contiguous training batch ``(x, y)``."""
        x_batch = jax.lax.dynamic_slice_in_dim(
            x_train, i * train_batch_size, train_batch_size
        )
        y_batch = jax.lax.dynamic_slice_in_dim(
            y_train, i * train_batch_size, train_batch_size
        )
        return x_batch, y_batch

    # Trailing samples that do not fill a whole batch are dropped each epoch.
    num_train_batches = len(x_train) // train_batch_size
    if num_train_batches == 0:
        raise ValueError(
            f"batch_size={train_batch_size} exceeds the number of training "
            f"samples ({len(x_train)}); no training step would ever run"
        )

    ### things we need to configure for the model, that are DATASET SPECIFIC ##########################################

    ### e.g. the number of features fed into the lifting MLP, i.e. dim(x_grid) + dim(f_x),
    ### and the output-function dimension used by the projection MLP

    x_codomain_dim = x_train.shape[-1]
    x_domain_dim = x_grid.shape[-1]

    y_codomain_dim = y_train.shape[-1]
    # NOTE: y_domain_dim and x_domain_dim currently can't differ for this method.
    y_domain_dim = y_grid.shape[-1]

    in_features = x_codomain_dim + x_domain_dim
    out_features = y_codomain_dim

    ################# setting up the quadrature rule #############################################################

    quadrature_fn_1d = leggauss

    num_base_points = config.model.num_base_quad_pts

    quadrature_rule = None
    if x_domain_dim == 1:
        quadrature_rule = quadratures.quadrature_unit_hypercube(
            1,
            num_base_points,
            quadrature_fn_1d,
        )

    elif x_domain_dim == 2:
        ### the irregular domains
        if config.dataset == "darcy_triangular":
            quadrature_rule = quadratures.triangle_quad_rule(
                num_base_points, quadrature_fn_1d
            )
        elif config.dataset == "darcy_triangular_notch":
            quadrature_rule = utils.get_triangular_notch_quadrature_rule(
                num_base_points, quadrature_fn_1d
            )
        else:
            num_subdivisions = config.model.num_domain_subdivisions_for_quad
            quadrature_rule = quadratures.triangle_mesh_quad_rule(
                num_base_points,
                quadrature_fn_1d,
                mesh=utils.unit_square_triangular_mesh(num_subdivisions),
            )

    if config.dataset == "diffrec_3d":
        # Precomputed rule loaded from disk: nodes reshaped to (-1, 3),
        # weights to a column vector.
        rule = io.loadmat("datasets/sphere_quad/n_1000.mat")
        t, w = rule["t"].astype(jnp.float32), rule["w"].astype(jnp.float32)
        quadrature_rule = (t.reshape(-1, 3), w.flatten()[:, None])

    if quadrature_rule is None:
        raise ValueError(
            f"no quadrature rule available for {x_domain_dim=} and "
            f"dataset={config.dataset!r}"
        )

    logger.info(f"num quadrature nodes: {len(quadrature_rule[0])}")

    ### determining the centers for kernel least squares interpolation
    if config.model.interpolation_scheme == "kernel_least_squares":
        # Heuristic: use a fixed 30% of the input grid points as centers
        # (alternatives such as sqrt(n) centers could be explored).
        num_centers = int(0.3 * len(x_grid))
        logger.info(f"using {num_centers=} in the least squares interpolation scheme")
        ### randomly select subset of x_grid
        key, centers_key = jax.random.split(key)
        centers = jax.random.choice(
            centers_key, x_grid, shape=(num_centers,), replace=False
        )
    else:
        centers = None

    dynamic_config = {
        "in_features": in_features,
        "out_features": out_features,
        "ndims": x_domain_dim,
        "centers": centers,
    }

    ################# load model ##############################################################################

    model = build_model(config, dynamic_config)

    params = nnx.state(model, nnx.Param)
    logger.info(
        f"param_count: {sum(p.size for p in jax.tree_util.tree_leaves(params))}"
    )

    ################ set up optimizer #########################################################################

    opt = config.optim

    def lr_schedule_fn(epochs):
        """Learning-rate schedule spanning ``epochs`` epochs of optimizer steps."""
        return utils.cosine_annealing_plus_constant(
            total_steps=(epochs * num_train_batches),
            init_value=opt.init_value,
            warmup_frac=opt.warmup_frac,
            peak_value=opt.peak_value,
            end_value=opt.end_value,
            num_cosine_cycles=opt.num_cosine_cycles,
            num_constant_cycles=opt.num_constant_cycles,
            gamma=opt.gamma,
        )

    optimizer_choice = optimizer_mapping[config.optim.optimizer]
    optimizer = nnx.Optimizer(
        model, optimizer_choice(lr_schedule_fn(config.training.epochs))
    )
    metrics = nnx.MultiMetric(
        loss=nnx.metrics.Average("loss"),
    )

    ################# stuff for independent layer training  #############################################################

    if config.training.individual_layer_training:
        kernel_layer_paths = utils.get_kernel_layer_paths(config)

        total_epochs_training_individual_layers = config.training.epochs_per_layer * (
            len(kernel_layer_paths)
        )

        main_optimizer = optimizer_choice(
            lr_schedule_fn(
                config.training.epochs - total_epochs_training_individual_layers
            )
        )
        layer_optimizer = optimizer_choice(
            lr_schedule_fn(config.training.epochs_per_layer)
        )

        # Parameters labelled "frozen" receive zero updates.
        partition_optimizers = {
            "trainable": layer_optimizer,
            "frozen": optax.set_to_zero(),
        }

    ################ define functions to do forward pass + param updates #######################################

    def batch_rel_l2(model, x, y):
        """Mean relative L2 error of ``model`` on the batch ``(x, y)``.

        The model is vmapped over the leading axis of ``x``; when normalization
        is enabled its outputs are decoded to the units of ``y`` first. Norms
        are taken along axis 1 and the ratios averaged.
        """
        y_pred = nnx.vmap(lambda xi: model(xi, *quadrature_rule, *grids))(x)
        if y_normalizer is not None:
            y_pred = y_normalizer.decode(y_pred)
        return jnp.mean(
            jnp.linalg.norm(y_pred - y, axis=1) / jnp.linalg.norm(y, axis=1)
        )

    @jax.jit
    def train_step(graphdef, state, batch):
        """Take one optimizer step on ``batch``; return the updated state."""
        x, y = batch
        model, optimizer, metrics = nnx.merge(graphdef, state)

        loss, grads = nnx.value_and_grad(lambda m: batch_rel_l2(m, x, y))(model)
        metrics.update(loss=loss)
        optimizer.update(grads)

        _, state = nnx.split((model, optimizer, metrics))
        return state

    @jax.jit
    def eval_step(graphdef, state, batch):
        """Accumulate the loss on ``batch`` into the metrics (no param update)."""
        x, y = batch
        # Round-trip all three objects from `state`. Splitting the closure's
        # outer `optimizer` instead would bake its stale state into the trace
        # and overwrite the optimizer state (incl. step count) after each eval.
        model, optimizer, metrics = nnx.merge(graphdef, state)

        metrics.update(loss=batch_rel_l2(model, x, y))

        _, state = nnx.split((model, optimizer, metrics))
        return state

    metrics_history = {"train_loss": [], "test_loss": []}
    graphdef, state = nnx.split((model, optimizer, metrics))

    ################ begin training phase  ##########################################################

    for epoch in range(config.training.epochs):
        ############# changing the optimizer based on which layer is training ##########################

        if config.training.individual_layer_training:
            # Each time a layer's budget of epochs is used up, switch the
            # optimizer to the next layer and pop the previous one.
            if epoch % config.training.epochs_per_layer == 0 and kernel_layer_paths:
                logger.info(
                    f"now training kernel-based layer {len(kernel_layer_paths)}"
                )
                # The Python `model` is only updated functionally via `state`;
                # rebuild it so the new optimizer starts from the trained weights.
                model, _, metrics = nnx.merge(graphdef, state)
                params_partition = jax.tree_util.tree_map_with_path(
                    lambda path, _: "trainable"
                    if all(arg in path for arg in kernel_layer_paths[-1])
                    else "frozen",
                    nnx.state(model, nnx.Param),
                )
                optimizer = nnx.Optimizer(
                    model, optax.multi_transform(partition_optimizers, params_partition)
                )
                graphdef, state = nnx.split((model, optimizer, metrics))
                kernel_layer_paths.pop()

            ### back to entire model
            if epoch == total_epochs_training_individual_layers:
                logger.info("now training full model")
                model, _, metrics = nnx.merge(graphdef, state)
                optimizer = nnx.Optimizer(model, main_optimizer)
                graphdef, state = nnx.split((model, optimizer, metrics))

        ### train
        train_t1 = time.perf_counter()
        for i in range(num_train_batches):
            state = train_step(graphdef, state, get_train_batch(i))
        # JAX dispatches asynchronously; wait for the results so the timing
        # measures computation rather than dispatch.
        state = jax.block_until_ready(state)
        train_time = time.perf_counter() - train_t1

        _, _, metrics = nnx.merge(graphdef, state)
        ### get training metrics
        if epoch % config.training.log_at == 0:
            for value in metrics.compute().values():
                metrics_history["train_loss"].append(value)
                logger.info(f"{epoch=}, train_loss: {value}, {train_time=}")

        ### reset after each training epoch for testing metrics
        metrics.reset()
        state[2] = nnx.state(metrics)

        ### inference
        if epoch % config.testing.eval_at == 0:
            eval_t1 = time.perf_counter()
            state = jax.block_until_ready(eval_step(graphdef, state, (x_test, y_test)))
            eval_time = time.perf_counter() - eval_t1

            _, _, metrics = nnx.merge(graphdef, state)
            for value in metrics.compute().values():
                metrics_history["test_loss"].append(value)
                logger.info(f"{epoch=}, test_loss: {value}, {eval_time=}")
            metrics.reset()
            state[2] = nnx.state(metrics)

    # Guard against empty histories (e.g. zero epochs): min() of nothing fails.
    if metrics_history["train_loss"]:
        logger.info(f"best train loss: {jnp.array(metrics_history['train_loss']).min()}")
    if metrics_history["test_loss"]:
        logger.info(f"best test loss: {jnp.array(metrics_history['test_loss']).min()}")

    # Materialize the final trained objects from the functional state.
    model, optimizer, metrics = nnx.merge(graphdef, state)
    ####### saving model ############################################################################

    if config.save_model:
        logger.info("saving checkpoint ...")
        ckpt = {
            "model": nnx.state(model),
            "config": config.to_dict(),
            "metrics_history": metrics_history,
            # Index of the last epoch run (defined even when no epoch ran).
            "step": config.training.epochs - 1,
        }

        checkpointer = orbax.checkpoint.PyTreeCheckpointer()
        save_args = orbax_utils.save_args_from_target(ckpt)
        checkpointer.save(
            os.path.abspath(os.path.join(exp_workdir, "ckpt")),
            ckpt,
            save_args=save_args,
            force=True,
        )
        logger.info("DONE!")


if __name__ == "__main__":
    # A bare `import absl` does not load the `absl.app` submodule; import it
    # explicitly rather than relying on another package having done so.
    import absl.app

    absl.app.run(main)
