import os
os.environ["JAX_LOG_COMPILES"] = "0"
os.environ["XLA_FLAGS"] = "--xla_gpu_enable_analytical_sol_latency_estimator=false"

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax.experimental import mesh_utils
from nnx_models import MLP, SIREN, WIRE
from nnx_models import add_lora_to_model, merge_lora_params, reset_lora_params
import flax.nnx as nnx
import optax
from soap_jax import soap
import orbax.checkpoint as ocp
import absl.logging
import shutil
from tqdm import tqdm
import typing as tp
# Keep JAX in 32-bit mode. This is set explicitly before any module-level
# arrays (below) are created so their default dtypes are 32-bit.
jax.config.update("jax_enable_x64", False)

# Only emit absl log records at FATAL level (silences INFO/WARNING output).
absl.logging.set_verbosity(absl.logging.FATAL)
# Row/column index arrays of the upper triangle (diagonal included) of a 3x3 matrix.
# NOTE(review): not referenced in this part of the file — confirm it is used elsewhere before removing.
rows_up, cols_up = jnp.triu_indices(3)

def loss_bssn(model, x, y):
    """Return the mean squared error between ``model(x)`` and the target ``y``.

    The mean is taken over every element of the residual array.
    """
    residual = model(x) - y
    return jnp.square(residual).mean()

def loss_jac_bssn(model, x, y):
    """Return the mean squared error between the model's input-Jacobian and ``y``.

    ``jax.jacfwd(model)`` at a single point yields an array with the output
    axes first and the input axes last. It is vmapped over the leading axis
    of ``x``, and the last two axes are swapped so that each point's Jacobian
    is laid out as (input, output) before it is compared with ``y``.
    """
    batched_jacobian = jax.vmap(jax.jacfwd(model))
    jac_in_out = batched_jacobian(x).transpose(0, 2, 1)
    return jnp.square(jac_in_out - y).mean()

def loss_norm_grad(model, x, y, dy, param_filter):
    """Combine the value-loss and Jacobian-loss gradients with equal unit-norm weight.

    Computes the gradient of ``loss_bssn`` (fit of ``model(x)`` to ``y``) and of
    ``loss_jac_bssn`` (fit of the model's input-Jacobian to ``dy``). Both are
    taken with respect to the variables selected by ``param_filter``. Each
    flattened gradient is rescaled to unit L2 norm and the two are summed, so
    the update direction is not dominated by whichever loss has the larger
    gradient scale.

    Args:
        model: nnx module being trained.
        x: input coordinates passed to both losses.
        y: value targets for ``loss_bssn``.
        dy: Jacobian targets for ``loss_jac_bssn``.
        param_filter: nnx filter selecting the differentiated variables.

    Returns:
        ``(g, loss0 + loss1, loss0, loss1)``. ``g`` has the same pytree
        structure as the gradient of ``loss_bssn``; ``loss0`` / ``loss1`` are
        the value and Jacobian losses.
    """
    # ``import jax`` alone does not guarantee that the ``jax.flatten_util``
    # submodule is loaded, so import the function explicitly.
    from jax.flatten_util import ravel_pytree

    diff_state = nnx.DiffState(0, filter=param_filter)
    loss0, grads0 = nnx.value_and_grad(loss_bssn, argnums=diff_state)(model, x, y)
    loss1, grads1 = nnx.value_and_grad(loss_jac_bssn, argnums=diff_state)(model, x, dy)

    flat0, unflatten = ravel_pytree(grads0)
    flat1, _ = ravel_pytree(grads1)
    grads = jnp.stack([flat0, flat1], axis=0)
    norms = jnp.linalg.norm(grads, axis=1, keepdims=True)
    # A zero gradient would turn the division into 0/0 = NaN and corrupt the
    # optimizer state; leave such a gradient at zero instead. Non-zero
    # gradients are normalized exactly as before.
    safe_norms = jnp.where(norms > 0, norms, 1.0)
    g = unflatten(jnp.sum(grads / safe_norms, axis=0))
    return g, loss0 + loss1, loss0, loss1

@nnx.jit(static_argnames=['param_filter'])
def train_step_jac(model, optimizer, x, y, dy, param_filter):
    """Run one jitted optimizer step on the combined value + Jacobian objective.

    The update direction comes from ``loss_norm_grad``: the sum of the
    unit-norm gradients of the value loss and of the Jacobian loss, restricted
    to the variables selected by ``param_filter``. ``param_filter`` is a static
    argument, so it must be hashable, and each distinct filter compiles a new
    version of this step.

    Returns:
        ``(total_loss, value_loss, jacobian_loss)`` for the batch.
    """
    direction, *batch_losses = loss_norm_grad(model, x, y, dy, param_filter)
    optimizer.update(direction)
    total_loss, value_loss, jacobian_loss = batch_losses
    return total_loss, value_loss, jacobian_loss

@nnx.jit(static_argnames=['param_type'])
def train_step(model, optimizer, x, y, param_type):
    """Run one jitted optimizer step on the value loss (``loss_bssn``) only.

    Gradients are taken only with respect to the variables matched by
    ``param_type``, which is static and therefore must be hashable. The
    result is ``(loss, 0.0, 0.0)``, so it unpacks like the three-tuple
    returned by ``train_step_jac``.
    """
    value_and_grad_fn = nnx.value_and_grad(
        loss_bssn, argnums=nnx.DiffState(0, filter=param_type)
    )
    batch_loss, grads = value_and_grad_fn(model, x, y)
    optimizer.update(grads)
    return batch_loss, 0.0, 0.0


def get_spectral_centroid(
    bssn_variables_array: jax.Array,
    grid_size: int = 213,
    extent: float = 30.0,
) -> jax.Array:
    """Return the amplitude-weighted mean frequency magnitude, averaged over channels.

    The array is reshaped (C order) to ``(grid_size, grid_size, grid_size, C)``.
    For each channel ``c`` the spectral centroid is
    ``sum(|k| * |F_c|) / sum(|F_c|)``, where ``F_c`` is the 3-D FFT of the
    channel and ``|k|`` is the frequency magnitude for a sample spacing of
    ``extent / (grid_size - 1)``. A channel whose FFT is identically zero
    gives NaN (0/0).

    The defaults reproduce the original hard-coded grid: 213 points spanning
    30 units, matching ``jnp.linspace(-15, 15, 213)``.

    Args:
        bssn_variables_array: array with ``grid_size**3`` rows (any trailing
            channel axis after reshaping).
        grid_size: number of samples along each spatial axis.
        extent: distance between the first and last sample along each axis.

    Returns:
        0-d jax array: mean of the per-channel spectral centroids.
    """
    spacing = extent / (grid_size - 1)
    grid = bssn_variables_array.reshape(grid_size, grid_size, grid_size, -1)

    # Same frequencies on all three axes, so compute them once.
    freq = jnp.fft.fftfreq(grid_size, d=spacing)
    kx, ky, kz = jnp.meshgrid(freq, freq, freq, indexing='ij')
    k_magnitude = jnp.sqrt(kx**2 + ky**2 + kz**2)

    def channel_centroid(variable):
        amplitude_spectrum = jnp.abs(jnp.fft.fftn(variable))
        return jnp.sum(k_magnitude * amplitude_spectrum) / jnp.sum(amplitude_spectrum)

    # Channels are transformed one at a time to bound peak memory to a single
    # complex FFT of the grid.
    centroids = [channel_centroid(grid[..., c]) for c in range(grid.shape[-1])]
    return jnp.array(centroids).mean()

def create_spectrum_switch(max_d):
    """Create functions that sum one anti-diagonal of a 2-D array.

    For a 2-D array ``a`` and integer ``d`` the anti-diagonal sum is
    ``sum(a[i, d - i] for each row i with 0 <= d - i < a.shape[1])``.
    Out-of-range ``d`` values are clamped to ``[0, max_d]``, matching the
    index semantics of ``jax.lax.switch`` over ``max_d + 1`` branches.

    Args:
        max_d: largest anti-diagonal index; larger requests are clamped to it.

    Returns:
        ``(single, batched)`` where ``single(fft_spectrum, d_index)`` returns
        the scalar sum for one index, and ``batched(fft_spectrum, d_indices)``
        is jitted and vmapped over a 1-D array of indices (the spectrum is
        shared across the batch).
    """

    def get_s_spectrum_base(fft_spectrum, d_val):
        rows, cols = fft_spectrum.shape
        i_indices = jnp.arange(rows)
        j_indices = d_val - i_indices
        valid_mask = (j_indices >= 0) & (j_indices < cols)
        # Clamp invalid column indices to 0 so the gather stays in bounds;
        # those entries are zeroed by the second ``where``.
        valid_j = jnp.where(valid_mask, j_indices, 0)
        elements = jnp.where(valid_mask, fft_spectrum[i_indices, valid_j], 0)
        return jnp.sum(elements)

    def get_s_spectrum_switch(fft_spectrum, d_index):
        # The sum is computed directly from the (possibly traced) index instead
        # of dispatching through ``max_d + 1`` ``lax.switch`` branches: under
        # ``vmap`` a batched switch index executes every branch for every
        # element, so the direct form avoids O(max_d) redundant work.
        d_val = jnp.clip(d_index, 0, max_d)
        return get_s_spectrum_base(fft_spectrum, d_val)

    return get_s_spectrum_switch, jax.jit(jax.vmap(get_s_spectrum_switch, in_axes=(None, 0)))

if __name__ == "__main__":
    jax.config.update("jax_default_matmul_precision", "highest")

    devices = jax.devices()
    num_devices = len(devices)

    # 1-D mesh over all devices: training batches are sharded along 'batch',
    # optimizer state is replicated on every device.
    mesh_devices = mesh_utils.create_device_mesh((num_devices,))
    mesh = Mesh(mesh_devices, axis_names=('batch',))

    data_sharding = NamedSharding(mesh, P('batch'))
    replicated_sharding = NamedSharding(mesh, P())

    hyperparams = {
        "model": {
            "hidden_dim" : 256,
            "num_hidden_layers": 6,
            "fourier_emb_scale" : 1.0,
            "fourier_emb_dim" : 256,
        },
        "training": {
            "learning_rate": 1e-3,
            "alpha": 1e-2,
            "batch_size": 100000,
            "num_epochs": 11,
            "optim": "soap", # Soap or Adam
            "jac_training": True,
            "print_freq": 10,
            "params_type": nnx.Param, # nnx.LoRAParam for LoRA
            # nnx.PathContains('hidden_layers') for full finetuning with only hidden layers
            "lora_rank": 32,
            "init_step": 3625,
            "final_step": 3625,
        },
        "data_file": "/restricteddata/t_comp/bssn/merger",
        "save_file": "...", # Where to save the model state
        "restore_file": None, # Starting from a checkpoint model state, doesn't work with LoRA
    }
    training_cfg = hyperparams["training"]
    params_type = training_cfg["params_type"]
    # LoRA adapters are added before training and merged back after every time step.
    use_lora = params_type is nnx.LoRAParam

    model = MLP(
        input_dim=3,
        output_dim=18,
        hidden_dim=hyperparams["model"]["hidden_dim"],
        fourier_emb_scale=hyperparams["model"]["fourier_emb_scale"],
        fourier_emb_dim=hyperparams["model"]["fourier_emb_dim"],
        num_hidden_layers=hyperparams["model"]["num_hidden_layers"]
    )

    checkpointer = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())

    if hyperparams['restore_file'] is not None:
        # The freshly built model's state serves as the restore target (tree structure and shapes).
        graphdef, abstract_state = nnx.split(model)
        state_restored = checkpointer.restore(
            hyperparams["restore_file"],
            args=ocp.args.StandardRestore(abstract_state))
        model = nnx.merge(graphdef, state_restored)

    if use_lora:
        add_lora_to_model(model, lora_rank=training_cfg["lora_rank"])

    def define_lr(init_lr=1e-2, decay_steps=4001, final_lr=1e-3):
        """Cosine-decay schedule starting at ``init_lr`` over ``decay_steps`` steps.

        ``final_lr`` is passed as optax's ``alpha``: it is the final multiplier of
        ``init_lr`` (final rate = ``init_lr * final_lr``), not an absolute rate.
        """
        return optax.cosine_decay_schedule(init_value=init_lr, decay_steps=decay_steps, alpha=final_lr)

    @nnx.jit
    def eval_step_diff(model, coords, bssn_array):
        """Return (MSE, relative L2 norm, mean |err|, max |err|) of the model on ``coords``.

        The model is applied in 100000-point chunks (a Python loop unrolled at
        trace time) and the predictions are concatenated before the metrics
        are computed against ``bssn_array``.
        """
        pred = jnp.concatenate(
            [model(coords[s:s + 100000]) for s in range(0, coords.shape[0], 100000)],
            axis=0,
        )
        err = pred - bssn_array
        mse = jnp.mean(err**2)
        l2_rel_norm = jnp.linalg.norm(err) / jnp.linalg.norm(bssn_array)
        mean_abs = jnp.mean(jnp.abs(err))
        max_abs = jnp.max(jnp.abs(err))
        return mse, l2_rel_norm, mean_abs, max_abs

    def load_step_arrays(step, load_jac):
        """Load the value array (and, if ``load_jac``, the Jacobian array) for ``step``.

        Channels 3, 4 and 5 of the last axis are dropped from both arrays.
        Returns ``(bssn, jac)``; ``jac`` is None when ``load_jac`` is False.
        """
        tag = f"{step:04d}"
        bssn = jnp.load(f"{hyperparams['data_file']}/bssn_variables_{tag}.npy")
        bssn = jnp.concatenate([bssn[:, :3], bssn[:, 6:]], axis=-1)
        jac = None
        if load_jac:
            jac = jnp.load(f"{hyperparams['data_file']}/jac_variables_{tag}.npy")
            jac = jnp.concatenate([jac[:, :, :3], jac[:, :, 6:]], axis=-1)
        return bssn, jac

    key = jax.random.PRNGKey(0)

    # Flattened 213^3 grid of points in [-15, 15]^3 (row-major, 'ij' indexing).
    coords = jnp.stack(
        jnp.meshgrid(
            jnp.linspace(-15.0, 15.0, 213),
            jnp.linspace(-15.0, 15.0, 213),
            jnp.linspace(-15.0, 15.0, 213),
            indexing='ij'
        )
    , axis=-1).reshape(-1, 3)

    # Kept as an int: it is both formatted into file names and used as a range bound.
    start_step = training_cfg['init_step']
    if hyperparams['restore_file'] is not None:
        # Report how well the restored checkpoint fits its own time step.
        bssn_variables_array, _ = load_step_arrays(start_step, load_jac=False)
        mse, l2_rel_norm, mean_abs, max_abs = eval_step_diff(model, coords, bssn_variables_array)
        print(f"Initial MSE: {mse:.2e}")
        print(f"L2 rel norm: {l2_rel_norm:.2e}")
        print(f"Mean ABS error is {mean_abs:.2e}")
        print(f"Max ABS: {max_abs:.2e}")
        # The restored checkpoint already covers init_step; continue with the next one.
        start_step += 1

    print(f"Starting compression with params type: {training_cfg['params_type']}")
    print(f"Fourier embeddings scale: {hyperparams['model']['fourier_emb_scale']}")

    @jax.jit
    def permute_indices(key, coords):
        return jax.random.permutation(key, coords.shape[0])

    batch_size = training_cfg["batch_size"]
    num_points = coords.shape[0]
    # Only full batches are used; a trailing partial batch is skipped each epoch.
    num_batches = num_points // batch_size
    if num_batches == 0:
        raise ValueError(f"batch_size={batch_size} exceeds the number of grid points ({num_points}).")

    for i in tqdm(range(start_step, training_cfg['final_step'] + 1), desc='Lora compression'):
        bssn_variables_array, jac_variables_array = load_step_arrays(i, load_jac=training_cfg["jac_training"])

        timestep = f"{i:04d}"

        print(f"Starting time step {timestep} training.")

        schedule = define_lr(
            init_lr=training_cfg["learning_rate"],
            final_lr=training_cfg["alpha"],
            decay_steps=training_cfg['num_epochs'] * num_batches,
        )

        # A fresh optimizer (and schedule) is built for every time step.
        if training_cfg["optim"] == "soap":
            optim = nnx.Optimizer(model, soap(learning_rate=schedule, precondition_frequency=1), wrt=params_type)
        elif training_cfg["optim"] == "adam":
            optim = nnx.Optimizer(model, optax.adam(learning_rate=schedule), wrt=params_type)
        else:
            raise ValueError(f"Unknown optimizer: {hyperparams['training']['optim']}")

        optim_state = nnx.state(optim)
        optim_state = jax.device_put(optim_state, replicated_sharding)
        nnx.update(optim, optim_state)

        for epoch in range(training_cfg['num_epochs']):
            # Split before use: sampling with ``key`` and then splitting the same
            # key next epoch would reuse PRNG state.
            key, subkey = jax.random.split(key)
            indices = permute_indices(subkey, coords)
            loss_avg_total = 0
            loss_avg0 = 0
            loss_avg1 = 0
            for j in range(0, num_batches * batch_size, batch_size):
                batch_indices = indices[j:j+batch_size]
                coords_batch = jax.device_put(coords[batch_indices], data_sharding)
                bssn_variables_batch = jax.device_put(bssn_variables_array[batch_indices], data_sharding)
                if training_cfg["jac_training"]:
                    jac_variables_batch = jax.device_put(jac_variables_array[batch_indices], data_sharding)
                    loss_step, loss0, loss1 = train_step_jac(model, optim, coords_batch, bssn_variables_batch, jac_variables_batch, params_type)
                else:
                    loss_step, loss0, loss1 = train_step(model, optim, coords_batch, bssn_variables_batch, params_type)
                loss_avg_total += loss_step
                loss_avg0 += loss0
                loss_avg1 += loss1
            loss_avg = loss_avg_total / num_batches
            if epoch % training_cfg["print_freq"] == 0:
                if training_cfg["jac_training"]:
                    print(f"Epoch {epoch}: Loss total: {loss_avg:.2e} Loss bssn: {loss_avg0 / num_batches:.2e} Loss jac_bssn: {loss_avg1 / num_batches:.2e}")
                else:
                    print(f"Epoch {epoch}: Loss BSSN: {loss_avg:.2e}")

        if use_lora:
            merge_lora_params(model)
            reset_lora_params(model, lora_rank=training_cfg['lora_rank'])

        mse, l2_rel_norm, mean_abs, max_abs = eval_step_diff(model, coords, bssn_variables_array)
        print(f"L2 rel norm is: {l2_rel_norm:.2e}")
        print(f"Mean ABS error is {mean_abs:.2e}")
        print(f"Max ABS error is {max_abs:.2e}")

        # Overwrite any existing checkpoint for this time step.
        save_path = hyperparams['save_file'] + "/bssn_variables_" + timestep
        if os.path.exists(save_path):
            shutil.rmtree(save_path)
        checkpointer.save(save_path, nnx.state(model))

    checkpointer.close()
