import os
# os.environ["JAX_PLATFORMS"] = "cpu"
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import xarray
import seaborn as sns
from nnx_models import MLP, SIREN, WIRE
from nnx_models import LoRA, add_lora_to_model, merge_lora_params, reset_lora_params
import flax.nnx as nnx
import optax
from soap_jax import soap
import orbax.checkpoint as ocp
import absl.logging
import shutil
from tqdm import tqdm
import typing as tp
import pickle
import matplotlib.image as mpimg
from matplotlib.animation import FuncAnimation, PillowWriter


absl.logging.set_verbosity(absl.logging.FATAL)

A = tp.TypeVar('A')

class CustomVariable(nnx.variablelib.Param[A]):
    """Distinct ``nnx.Param`` subclass that adds no behaviour of its own.

    NOTE(review): presumably exists so that variables of this type can be
    selected separately from ordinary ``nnx.Param`` values -- confirm, since
    nothing in this file references it.
    """

def loss(model, x, y):
    """Return the mean squared error between ``model(x)`` and ``y``.

    Args:
        model: Callable mapping the inputs ``x`` to predictions.
        x: Inputs passed to ``model``.
        y: Targets; must be broadcast-compatible with ``model(x)``.

    Returns:
        Scalar array holding the mean of the squared residuals.
    """
    residual = model(x) - y
    return jnp.mean(jnp.square(residual))

@nnx.jit
def train_step(model, optimizer, x, y):
    """Perform one jitted optimisation step on ``model``.

    Args:
        model: NNX module being trained; passed to ``loss`` as first argument.
        optimizer: ``nnx.Optimizer`` whose ``update`` consumes the gradients.
        x: Inputs forwarded to ``loss``.
        y: Targets forwarded to ``loss``.

    Returns:
        The loss value computed before the optimizer update is applied.
    """
    # Alternative kept for reference -- differentiate only the LoRA parameters:
    #   nnx.value_and_grad(loss, argnums=nnx.DiffState(0, filter=nnx.LoRAParam))
    loss_and_grad_fn = nnx.value_and_grad(loss)
    current_loss, gradients = loss_and_grad_fn(model, x, y)

    # Nothing but the loss is returned: the update is relied upon to take
    # effect through the optimizer object itself.
    optimizer.update(gradients)
    return current_loss

if __name__ == "__main__":
    # Script entry point: load the vorticity trajectory, build the model and
    # the sampling grid, and min-max normalise the data.  The evaluation
    # helpers and the training loop further down rely on the names `model`,
    # `coords`, `vorticity`, `fact` and `min_v` defined here.
    jax.config.update("jax_default_matmul_precision", "highest")

    # Only the trajectory is read from disk.  The companion file
    # ".../2DNS_10000/coord.npy" used to be loaded into `coords` here too, but
    # that value was overwritten by the analytic grid below before ever being
    # read, so the load was dead I/O.
    vorticity = jnp.load(".../2DNS_10000/vorticity_trajectory.npy")

    # NOTE(review): in the experiment logs below the leading number is
    # presumably the reached error and the tuple the omega/scale
    # hyper-parameters of the corresponding model -- confirm with the author.

    # Siren best configurations
    # 1.86e-3 (10, 1)
    # 2.67e-3 (4, 1)
    # 3.88e-3 (20, 1)
    # 1.93e-3 (10, 2)
    # 1.50e-2 (10, 4)
    # 2.27e-3 (10, 0.5)

    # H-Siren best configurations
    # 7.17e-3 (1, 1)
    # 7.23e-2 (10, 5)
    # 1.04e-2 (1, 5)
    # 1.23e-2 (2, 1)
    # 1.53e-2 (4, 1)
    # 1.25e-2 (8, 1)

    # model = SIREN(
    #     input_dim=2,
    #     output_dim=1,
    #     hidden_dim=64,
    #     num_hidden_layers=5,
    #     first_omega=4.,
    #     hidden_omega=1.
    # )

    # Best WIRE configuration
    # 3.28e-3 (10, 1, 1)
    # 5.96e-3 (1, 1, 1)
    # 4.51e-3 (6, 2, 4)
    # 2.42e-3 (10, 2, 4)
    # 7.78e-3 (10, 1, 4)
    # 2.50e-3 (10, 3, 4)

    # jax.clear_caches()
    # exit()

    # model = WIRE(
    #     2,
    #     1,
    #     hidden_dim=32,
    #     num_hidden_layers=5,
    #     first_omega=10.,
    #     hidden_omega=1.,
    #     scale=2.
    # )

    # Active model: maps a 2-D coordinate to one scalar output.
    model = MLP(
        input_dim=2,
        output_dim=1,
        hidden_dim=64,
        fourier_emb_scale=7.0,
        num_hidden_layers=6
    )

    # The checkpointer and the split model are only referenced by the
    # commented-out save/restore snippets in this script; they are kept so
    # those snippets keep working when re-enabled.
    checkpointer = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
    # checkpointer.save("../vorticity0/state", state)
    graphdef, abstract_state = nnx.split(model)
    # state_restored = checkpointer.restore(
    #     "/vorticity500/state",
    #     args=ocp.args.StandardRestore(abstract_state))
    # model = nnx.merge(graphdef, state_restored)

    # Uniform grid_size x grid_size grid on [0, 2*pi) x [0, 2*pi), flattened
    # to a (grid_size**2, 2) list of points ('ij' indexing: first column
    # varies slowest).
    grid_size = 256
    axis_points = jnp.arange(grid_size) * 2 * jnp.pi / grid_size
    coords = jnp.stack(
        jnp.meshgrid(axis_points, axis_points, indexing='ij'),
        axis=-1,
    ).reshape(-1, 2)

    # One (num_points, 1) column per snapshot.  The snapshot count is taken
    # from the data instead of being hard-coded to 10000.
    vorticity = vorticity.reshape(vorticity.shape[0], -1, 1)
    if vorticity.shape[1] != coords.shape[0]:
        raise ValueError(
            f"Each snapshot has {vorticity.shape[1]} values but the grid has "
            f"{coords.shape[0]} points."
        )

    # Min-max normalisation to [0, 1] using the extrema of the whole
    # trajectory; `fact` and `min_v` are reused later to map predictions back
    # to the original scale.
    min_v = vorticity.min()
    fact = vorticity.max() - min_v
    vorticity = (vorticity - min_v) / fact

    # add_lora_to_model(model, lora_rank=16)

    # 1-D coordinate vectors along each grid axis (not used elsewhere in the
    # visible code).
    coords_grid = coords.reshape(grid_size, grid_size, 2)
    dict_coords = {
        "x": coords_grid[:, 0, 0],
        "y": coords_grid[0, :, 1],
    }

    @nnx.jit
    def eval_step_diff(model, coords, vorticity_prev, vorticity_curr):
        """Score a model that predicts the increment between two snapshots.

        The reconstruction is ``model(coords) + vorticity_prev`` and it is
        compared against ``vorticity_curr``.

        Args:
            model: Callable evaluated once at ``coords``.
            coords: Input coordinates passed to ``model``.
            vorticity_prev: Normalised previous snapshot added to the output.
            vorticity_curr: Normalised target snapshot.

        Returns:
            Tuple ``(l2_rel_norm, l2_rel_original, l2_error, max_abs)``:
            relative L2 error on normalised data, relative L2 error after
            undoing the normalisation, absolute L2 error and maximum absolute
            error, the last two also on de-normalised data.
        """
        # Single forward pass shared by every metric (previously the model
        # call was repeated inside each expression).
        reconstruction = model(coords) + vorticity_prev

        # Undo the min-max normalisation with the enclosing `fact` / `min_v`.
        recon_original = reconstruction * fact + min_v
        target_original = vorticity_curr * fact + min_v
        error_original = recon_original - target_original

        l2_rel_norm = jnp.linalg.norm(reconstruction - vorticity_curr) / jnp.linalg.norm(vorticity_curr)
        l2_error = jnp.linalg.norm(error_original)
        l2_rel_original = l2_error / jnp.linalg.norm(target_original)
        max_abs = jnp.max(jnp.abs(error_original))
        return l2_rel_norm, l2_rel_original, l2_error, max_abs
    
    @nnx.jit
    def eval_step(model, coords, vorticity_step):
        """Compute reconstruction metrics of ``model`` for one snapshot.

        Args:
            model: Callable evaluated once at ``coords``.
            coords: Input coordinates passed to ``model``.
            vorticity_step: Normalised target snapshot.

        Returns:
            Tuple ``(l2_rel_norm, l2_rel_original, l2_error, max_abs, psnr)``:
            relative L2 error on normalised data, relative L2 error after
            undoing the normalisation, absolute L2 error, maximum absolute
            error and PSNR in dB, the last three on de-normalised data.
        """
        # Single forward pass shared by every metric (previously the model
        # call was repeated inside each of the five expressions).
        pred = model(coords)

        # Undo the min-max normalisation with the enclosing `fact` / `min_v`.
        pred_original = pred * fact + min_v
        target_original = vorticity_step * fact + min_v
        error_original = pred_original - target_original

        l2_rel_norm = jnp.linalg.norm(pred - vorticity_step) / jnp.linalg.norm(vorticity_step)
        l2_error = jnp.linalg.norm(error_original)
        l2_rel_original = l2_error / jnp.linalg.norm(target_original)
        max_abs = jnp.max(jnp.abs(error_original))
        # The PSNR peak is the maximum of the de-normalised target (not its
        # value range), exactly as in the original formula.
        psnr = 10 * jnp.log10(jnp.max(target_original)**2 / jnp.mean(error_original**2))
        return l2_rel_norm, l2_rel_original, l2_error, max_abs, psnr

    @nnx.jit
    def eval_step_general(pred, vorticity_step):
        """Compute reconstruction metrics for an already-evaluated prediction.

        Args:
            pred: Normalised prediction.
            vorticity_step: Normalised target snapshot.

        Returns:
            Tuple ``(l2_rel_norm, l2_rel_original, l2_error, max_abs, psnr)``:
            relative L2 error on normalised data, relative L2 error after
            undoing the normalisation, absolute L2 error, maximum absolute
            error and PSNR in dB, the last three on de-normalised data.
        """
        # Map both fields back to the original scale via the enclosing
        # `fact` / `min_v`.
        restored_pred = pred * fact + min_v
        restored_target = vorticity_step * fact + min_v
        restored_diff = restored_pred - restored_target

        normalised_rel = jnp.linalg.norm(pred - vorticity_step) / jnp.linalg.norm(vorticity_step)
        restored_rel = jnp.linalg.norm(restored_diff) / jnp.linalg.norm(restored_target)
        abs_l2 = jnp.linalg.norm(restored_diff)
        worst_abs = jnp.max(jnp.abs(restored_diff))
        # Peak value is the maximum of the de-normalised target.
        peak_snr = 10 * jnp.log10(jnp.max(restored_target)**2 / jnp.mean(restored_diff**2))
        return normalised_rel, restored_rel, abs_l2, worst_abs, peak_snr

    @nnx.jit
    def define_lr_optim(model):
        """Create a fresh SOAP optimizer for ``model`` with a cosine LR decay.

        Args:
            model: NNX module whose ``nnx.Param`` variables are optimised.

        Returns:
            A new ``nnx.Optimizer`` wrapping ``model``.
        """
        # Cosine decay starting at 1e-2 over 5000 steps; `alpha` sets the
        # floor of the decay multiplier.
        lr_schedule = optax.cosine_decay_schedule(
            init_value=1e-2,
            decay_steps=5000,
            alpha=1e-3,
        )
        # Alternative kept for reference:
        #   nnx.Optimizer(model, optax.adamw(lr_schedule, weight_decay=1e-5), wrt=nnx.Param)
        gradient_transform = soap(learning_rate=lr_schedule, precondition_frequency=1)
        return nnx.Optimizer(model, gradient_transform, wrt=nnx.Param)

    @nnx.jit 
    def compiled_merge_and_reset(model):
        """Jitted wrapper that calls ``merge_lora_params`` then ``reset_lora_params``.

        Both helpers come from ``nnx_models`` and are applied to ``model`` in
        that order; nothing is returned.

        NOTE(review): presumably this folds the trained LoRA update into the
        base weights and then re-initialises the adapters for the next time
        step -- confirm against ``nnx_models``.  The only call site in this
        file is currently commented out.
        """
        # Order matters: merge first, reset afterwards.
        merge_lora_params(model)
        reset_lora_params(model)

    # Per-snapshot training loop.  For each time step a fresh optimizer is
    # built, the model is fitted to that snapshot, and reconstruction metrics
    # are printed.  The model itself is not re-initialised between snapshots.
    #
    # `num_epochs` should stay equal to `decay_steps` in `define_lr_optim` so
    # the cosine schedule finishes exactly when training on a snapshot ends.
    num_epochs = 5000
    for i in tqdm(range(0, 1000), desc='Lora compression'):
        vorticity_step = vorticity[i]
        optim = define_lr_optim(model)
        print(f"Time step: {i} starting now:")

        for epoch in range(num_epochs):
            loss_step = train_step(model, optim, coords, vorticity_step)
            if epoch % 500 == 0:
                print(f"Epoch {epoch}: Loss: {loss_step:.2e}")

        l2_rel_norm, l2_rel_original, l2_error, max_abs, psnr = eval_step(model, coords, vorticity_step)
        print(f"L2 rel norm is: {l2_rel_norm:.2e}")
        print(f"L2 unormalized rel norm is: {l2_rel_original:.2e}")
        print(f"L2 error is: {l2_error:.2e}")
        print(f"Max ABS error is {max_abs:.2e}")
        print(f"PSNR is {psnr:.2f}")

        # compiled_merge_and_reset(model)
        # state = nnx.state(model)
        # if os.path.exists(f".../lora_1000/vorticity{i}"):
        #     shutil.rmtree(f".../lora_1000/vorticity{i}")
        # checkpointer.save(f".../lora_1000/vorticity{i}/state", state)
        # checkpointer.wait_until_finished()

        # NOTE(review): the run stops after the first snapshot, as before --
        # presumably a debugging stop; remove it to process all time steps.
        # `raise SystemExit` replaces the `exit()` helper, which is installed
        # by the `site` module for interactive use and is not guaranteed to
        # exist in every interpreter configuration.
        raise SystemExit

