#!/usr/bin/env python
import os
import numpy as np

try:
    import jax

    # On a CPU-only host, request Eigen multithreading for XLA's CPU backend.
    # NOTE(review): jax.devices() initialises JAX's backends, and JAX's docs
    # advise setting XLA_FLAGS before that happens, so the value written here
    # most likely does not affect this process (only subprocesses that inherit
    # the environment). Confirm, and move the flag ahead of the first JAX
    # initialisation if CPU threading matters.
    if all(d.platform != 'gpu' for d in jax.devices()):
        print("⚠️ No GPU detected — enabling multithreading for CPU.")
        cpu_flag = "--xla_cpu_multi_thread_eigen=true"
        # Append instead of overwriting flags the user already exported. The
        # former bare token "intra_op_parallelism_threads=10" is not an XLA
        # flag (no "--xla_" prefix) and has been dropped.
        existing_flags = os.environ.get("XLA_FLAGS", "")
        if cpu_flag not in existing_flags.split():
            os.environ["XLA_FLAGS"] = f"{existing_flags} {cpu_flag}".strip()
except Exception as exc:
    # Best effort: if JAX isn't installed or fails to load, skip the CPU
    # tuning and say why instead of failing silently.
    print(f"Skipping XLA CPU thread configuration: {exc}")

import jax
import jax.numpy as jnp  # JAX NumPy
from clu import metrics
from flax import struct
import optax  # Common loss functions and optimizers
import sys
import json
import flax.serialization as serialization
import jax.tree_util

# Disable JAX traceback filtering so errors show the full stack, including
# JAX-internal frames, then report which devices JAX sees.
jax.config.update("jax_traceback_filtering", 'off')
print("Devices available:", jax.devices())
import optimizers
import training


from mlp_models_multilayer import CircularMLP, MLPOneEmbed, MLPOneHot

### NOTES ###
# Required argv layout: script name + 12 fixed positional arguments + at least
# one random seed, i.e. len(sys.argv) >= 14. The former `< 13` check let a
# call with no seeds through, which then failed much later with an unrelated
# error (empty stack / index out of range) instead of this usage message.
if len(sys.argv) < 14:
    print("Usage: script.py <learning_rate> <weight_decay> <p> <batch_size> <optimizer> <epochs> <k> <batch_experiment> <num_neurons> <MLP_class> <features> <num_layers> <random_seed_int_1> [<random_seed_int_2> ...]")
    sys.exit(1)

print("start args parsing")
# Parse command-line arguments
learning_rate = float(sys.argv[1])  # optimizer step size
weight_decay = float(sys.argv[2])   # coefficient multiplying the L2 penalty
p = int(sys.argv[3])                # modulus; inputs and labels live in [0, p)
batch_size = int(sys.argv[4])
optimizer = sys.argv[5]
epochs = int(sys.argv[6])
k = int(sys.argv[7])                # number of training batches per model
batch_experiment = sys.argv[8]
num_neurons = int(sys.argv[9])
MLP_class = sys.argv[10]
training_set_size = k * batch_size
# Feature width forwarded to the model constructor.
features = int(sys.argv[11])
num_layers = int(sys.argv[12])
top_k = [1]

# --- Process remaining arguments: one model is trained per seed ---
random_seed_ints = [int(arg) for arg in sys.argv[13:]]
num_models = len(random_seed_ints)
print(f"Random seeds: {random_seed_ints}")

print("making dataset")

def generate_polynomial_dataset(p, f, num_train_batches, batch_size, rng: jax.random.PRNGKey):
    """Sample a batched training set of (a, b, f(a, b)) rows without replacement.

    All p*p pairs (a, b) with 0 <= a, b < p are enumerated and labelled with
    f(a, b). Then `num_train_batches * batch_size` distinct rows are drawn using
    a subkey split from `rng`.

    Returns:
        Array of shape (num_train_batches, batch_size, 3); the last axis is
        (a, b, f(a, b)).

    Raises:
        ValueError: if more rows are requested than the p*p that exist.
    """
    n_pairs = p * p
    n_wanted = num_train_batches * batch_size
    if n_wanted > n_pairs:
        raise ValueError("Not enough data samples for the requested number of batches.")

    grid_a, grid_b = jnp.mgrid[0:p, 0:p]
    labels = f(grid_a, grid_b)
    table = jnp.stack([column.ravel() for column in (grid_a, grid_b, labels)], axis=1)

    _, sample_key = jax.random.split(rng)
    chosen_rows = jax.random.choice(sample_key, n_pairs, (n_wanted,), replace=False)
    return table[chosen_rows].reshape(num_train_batches, batch_size, 3)

def generate_polynomial_dataset_for_seed(seed):
    """Build the modular-addition training batches for one random seed.

    Uses the module-level `p`, `k` (number of batches) and `batch_size`. Each
    pair (a, b) is labelled with (a + b) mod p.
    """
    def modular_sum(a, b):
        return jnp.mod(a + b, p)

    return generate_polynomial_dataset(
        p, modular_sum, k, batch_size, jax.random.PRNGKey(seed))

if batch_experiment == "random_random":
    # One independently sampled training set per seed, stacked along a new
    # leading model axis: (num_models, k, batch_size, 3).
    train_ds_list = [generate_polynomial_dataset_for_seed(seed) for seed in random_seed_ints]
    train_ds = jnp.stack(train_ds_list)
    print(f"Number of training batches: {train_ds.shape[1]}")
else:
    # Everything below needs `train_ds` / `train_ds_list`. Fail fast with a
    # clear message instead of a NameError further down.
    raise ValueError(
        f"Unsupported batch_experiment: {batch_experiment!r} (only 'random_random' is implemented)")

print("made dataset")

def compute_pytree_size(pytree):
    """Return the total storage, in bytes, of every array leaf in `pytree`."""
    leaves = jax.tree_util.tree_leaves(pytree)
    return sum(leaf.size * leaf.dtype.itemsize for leaf in leaves)

# Bytes held by one model's slice of the stacked training set, i.e. all
# axes except the leading per-model axis.
_per_model_dims = train_ds.shape[1:4]
dataset_size_bytes = (_per_model_dims[0] * _per_model_dims[1] * _per_model_dims[2]
                      * train_ds.dtype.itemsize)
dataset_size_mb = dataset_size_bytes / 1024 ** 2
print(f"Dataset size per model: {dataset_size_mb:.2f} MB")

def positive_he_normal(key, shape, dtype=jnp.float32):
    """He-normal initializer with every sample replaced by its absolute value.

    Draws from `jax.nn.initializers.he_normal()` and returns |sample|, so all
    initial weights are non-negative.
    """
    he_init = jax.nn.initializers.he_normal()
    return jnp.abs(he_init(key, shape, dtype))

# Per-step metric collection used by the train/eval loops. compute_metrics
# feeds it logits, labels, loss and l2_loss; `loss` and `l2_loss` are
# aggregated with metrics.Average over those model outputs.
@struct.dataclass
class Metrics(metrics.Collection):
    accuracy: metrics.Accuracy
    loss: metrics.Average.from_output('loss') # type: ignore
    l2_loss: metrics.Average.from_output('l2_loss') # type: ignore

# Annotation only; the concrete class is chosen from MLP_class below.
model: CircularMLP
mlp_class_lower = f"{MLP_class.lower()}_{num_layers}"
# Supported architectures, keyed by the lower-cased MLP_class argument.
model_class_map = {
    "no_embed": MLPOneHot,
    "one_embed": MLPOneEmbed,
}

base_class_name = MLP_class.lower()
model_class = model_class_map.get(base_class_name)
if model_class is None:
    raise ValueError(f"Unknown MLP_class: {MLP_class}")

kwargs = {"p": p, "num_neurons": num_neurons, "num_layers": num_layers}
# NOTE(review): "embed" is a substring of both "no_embed" and "one_embed", so
# `features` is forwarded to every supported class. If only the embedding
# model should receive it, this should test `base_class_name == "one_embed"`.
# Confirm against mlp_models_multilayer.
if "embed" in base_class_name:
    kwargs["features"] = features

model = model_class(**kwargs)

# Placeholder batch of (a, b) pairs; used only to trace shapes in model.init.
dummy_x = jnp.zeros((batch_size, 2), dtype=jnp.int32)

def cross_entropy_loss(y_pred, y):
    """Mean softmax cross-entropy of logits `y_pred` against integer labels `y`."""
    per_example = optax.softmax_cross_entropy_with_integer_labels(logits=y_pred, labels=y)
    return jnp.mean(per_example)

def total_loss(y_pred_and_l2, y):
    """Cross-entropy on the logits plus the L2 penalty scaled by `weight_decay`.

    `y_pred_and_l2` must be a 3-tuple whose first element is the logits and
    whose last is the L2 term. The middle element is ignored; presumably this
    is the (logits, updates, l2_loss) tuple returned by `apply`.
    """
    logits, _ignored, l2_term = y_pred_and_l2
    penalty = weight_decay * l2_term
    return cross_entropy_loss(logits, y) + penalty

def apply(variables, x, training=False):
    """Run the model and compute the squared L2 norm of its parameters.

    Args:
        variables: dict with a 'params' entry and an optional 'batch_stats'
            entry (a missing or None entry is treated as empty).
        x: input pairs forwarded to `model.apply`.
        training: if True, the 'batch_stats' collection is marked mutable.

    Returns:
        (logits, updates, l2_loss). `updates` is the mutable-collection output
        of `model.apply`, and `l2_loss` is the sum of squares over every
        parameter leaf.
    """
    params = variables['params']
    stats = variables.get("batch_stats", None)
    stats = {} if stats is None else stats
    mutable_collections = ['batch_stats'] if training else []
    outputs, updates = model.apply(
        {'params': params, 'batch_stats': stats},
        x,
        training=training,
        mutable=mutable_collections,
    )
    # The model returns four values; only the first (the logits) is used here.
    logits, _, _, _ = outputs
    squared_norm = sum(jnp.sum(jnp.square(leaf)) for leaf in jax.tree_util.tree_leaves(params))
    return logits, updates, squared_norm

def batched_apply(variables_batch, x_batch, training=False):
    """Vectorised `apply` over a leading model axis of both variables and inputs."""
    def _apply_one(model_vars, inputs):
        return apply(model_vars, inputs, training)

    return jax.vmap(_apply_one, in_axes=(0, 0))(variables_batch, x_batch)

def sample_hessian(prediction, sample):
    """Per-sample loss-Hessian triple handed to the optimizer.

    The first entry is `optimizers.sample_crossentropy_hessian` evaluated at
    `prediction` and `sample[0]`. The other two entries are the constant 0.0.
    """
    ce_hessian = optimizers.sample_crossentropy_hessian(prediction, sample[0])
    return ce_hessian, 0.0, 0.0

def compute_metrics(metrics, *, loss, l2_loss, outputs, labels):
    """Fold one step's outputs into the running metrics collection.

    `outputs` may be the logits array itself, or a tuple/list whose first
    element is the logits.
    """
    if isinstance(outputs, (tuple, list)):
        logits = outputs[0]
    else:
        logits = outputs
    step_metrics = metrics.single_from_model_output(
        logits=logits, labels=labels, loss=loss, l2_loss=l2_loss)
    return metrics.merge(step_metrics)

def prepare_batches(batches_array):
    """Split stacked (a, b, label) rows into int32 inputs and labels.

    Expects a 4-D array whose last axis holds (a, b, label). Returns `x` with
    the (a, b) columns and `y` with the label column, both cast to int32.
    """
    as_int = batches_array.astype(jnp.int32)
    inputs = as_int[:, :, :, :2]
    targets = as_int[:, :, :, 2]
    return inputs, targets

print("model made")

def init_model(seed):
    """Initialise one model's variables from `seed`.

    Shapes are traced with the module-level `dummy_x` in eval mode
    (training=False).
    """
    init_key = jax.random.PRNGKey(seed)
    return model.init(init_key, dummy_x, training=False)

# One independently initialised variable set per seed.
variables_list = [init_model(seed) for seed in random_seed_ints]

model_size_bytes = compute_pytree_size(variables_list[0]['params'])
model_size_mb = model_size_bytes / 1024 ** 2
print(f"Single model size: {model_size_mb:.2f} MB")

# Stack each parameter leaf across models so every leaf gains a leading axis
# of length num_models (the axis vmapped over during training).
variables_batch = {
    'params': jax.tree_util.tree_map(
        lambda *per_model: jnp.stack(per_model),
        *[v['params'] for v in variables_list],
    ),
    'batch_stats': None,
}

params_batch = variables_batch['params']
batch_stats_batch = variables_batch.get('batch_stats', None)

# Optimizer: Adam, or plain SGD (momentum 0.0) for any name starting with "SGD".
if optimizer == "adam":
    tx = optax.adam(learning_rate)
elif optimizer.startswith("SGD"):
    tx = optax.sgd(learning_rate, 0.0)
else:
    raise ValueError("Unsupported optimizer type")

def init_opt_state(params):
    """Return a freshly initialised optimizer state for one model's `params`."""
    return tx.init(params)

# Initialise each model's optimizer state on its own slice of `params_batch`,
# then stack the states so they share the leading model axis.
opt_state_list = [
    init_opt_state(jax.tree_util.tree_map(lambda leaf: leaf[i], params_batch))
    for i in range(num_models)
]
opt_state_batch = jax.tree_util.tree_map(
    lambda *per_model: jnp.stack(per_model), *opt_state_list)

def create_train_state(params, opt_state, rng_key, batch_stats):
    """Bundle one model's parameters, optimizer state and callbacks into a TrainState.

    The callbacks are the module-level `apply`, `total_loss`, `sample_hessian`
    and `compute_metrics`, with `tx` as the optimizer and `Metrics` as the
    initial metrics. Injected noise is fixed at 0.0.
    """
    return training.TrainState(
        apply_fn=apply,
        tx=tx,
        loss_fn=total_loss,
        loss_hessian_fn=sample_hessian,
        compute_metrics_fn=compute_metrics,
        initial_metrics=Metrics,
        injected_noise=0.0,
        params=params,
        opt_state=opt_state,
        batch_stats=batch_stats,
        rng_key=rng_key,
    )

# One TrainState per seed, each holding that model's slice of the stacked
# params / optimizer state and its own PRNG key. They are then stacked into
# one batched pytree for vmap.
states_list = []
for i, seed in enumerate(random_seed_ints):
    state = create_train_state(
        jax.tree_util.tree_map(lambda x: x[i], params_batch),
        jax.tree_util.tree_map(lambda x: x[i], opt_state_batch),
        jax.random.PRNGKey(seed),
        None,
    )
    states_list.append(state)

states = jax.tree_util.tree_map(lambda *per_model: jnp.stack(per_model), *states_list)

# Split (a, b, label) rows into inputs/labels and place them on the device once.
train_x, train_y = (jax.device_put(arr) for arr in prepare_batches(train_ds))

initial_metrics_list = [s.initial_metrics.empty() for s in states_list]
initial_metrics = jax.tree_util.tree_map(
    lambda *per_model: jnp.stack(per_model), *initial_metrics_list)

### Added for test evaluation ###
# Full evaluation grid: every pair (a, b) in [0, p)^2, labelled (a + b) mod p.
a_eval, b_eval = jnp.mgrid[0:p, 0:p]
a_eval = a_eval.ravel()
b_eval = b_eval.ravel()
x_eval = jnp.stack([a_eval, b_eval], axis=-1).astype(jnp.int32)
y_eval = jnp.mod(a_eval + b_eval, p).astype(jnp.int32)

x_eval = jax.device_put(x_eval)
y_eval = jax.device_put(y_eval)

# Replicate the grid for every model: leading axis of length num_models.
x_eval_expanded = jnp.tile(x_eval[None, :, :], (num_models, 1, 1))
y_eval_expanded = jnp.tile(y_eval[None, :], (num_models, 1))

# Split the grid into equal-size batches WITHOUT padding. The previous version
# zero-padded the last batch up to 1024 rows. eval_model feeds every row to
# state.eval_step without a validity mask, so each pad row was scored as a
# real sample (a=0, b=0, label 0), skewing test accuracy and loss (e.g. 831
# duplicate rows for p=97). The largest divisor of p*p that does not exceed
# the cap keeps all batches the same size and counts every sample exactly
# once. For prime p > 32 this is p itself, so there are more, smaller batches.
max_eval_batch_size = 1024
total_eval_samples = x_eval.shape[0]
eval_batch_size = max(
    d for d in range(1, min(total_eval_samples, max_eval_batch_size) + 1)
    if total_eval_samples % d == 0
)
num_eval_batches = total_eval_samples // eval_batch_size
# Retained for code that reads these names; no padding is applied any more.
num_full_batches = num_eval_batches
remaining_samples = 0
x_eval_padded = x_eval_expanded
y_eval_padded = y_eval_expanded

x_eval_batches = x_eval_padded.reshape(num_models, num_eval_batches, eval_batch_size, -1)
y_eval_batches = y_eval_padded.reshape(num_models, num_eval_batches, eval_batch_size)

# Directory for saved parameters; one directory per hyper-parameter setting.
model_dir = "/".join([
    f"embed_128/models_{mlp_class_lower}",
    f"p={p}_bs={batch_size}_nn={num_neurons}_wd={weight_decay}_epochs={epochs}_"
    f"training_set_size={training_set_size}",
])
os.makedirs(model_dir, exist_ok=True)

# === Logging dictionaries ===
# log_by_seed[seed][epoch]: scalar metrics for that epoch.
# epoch_dft_logs_by_seed[seed][epoch][neuron]: {frequency: DFT magnitude}.
log_by_seed = {}
epoch_dft_logs_by_seed = {}
for _seed in random_seed_ints:
    log_by_seed[_seed] = {}
    epoch_dft_logs_by_seed[_seed] = {}

# Model-0 snapshots keyed by epoch: effective embeddings, pre-activations, logits.
epoch_embedding_log, epoch_preactivation_log, epoch_logits_log = {}, {}, {}

# === Training and Evaluation Loops ===
@jax.jit
def train_epoch(states, x_batches, y_batches, initial_metrics):
    """Run one epoch for all models in lockstep.

    Args:
        states: TrainState pytree stacked along a leading model axis.
        x_batches: (num_models, num_batches, batch_size, 2) inputs.
        y_batches: (num_models, num_batches, batch_size) labels.
        initial_metrics: per-model metrics to start accumulating from.

    Returns:
        (updated states, accumulated training metrics), both still batched
        over the model axis.
    """
    def per_model_step(state, metric, x, y):
        return state.train_step(metric, (x, y))

    step_all_models = jax.vmap(per_model_step, in_axes=(0, 0, 0, 0))

    def scan_body(carry, batch):
        cur_states, cur_metrics = carry
        xs, ys = batch
        next_states, next_metrics = step_all_models(cur_states, cur_metrics, xs, ys)
        return (next_states, next_metrics), None

    # lax.scan walks the leading axis, so bring the batch axis to the front.
    batch_major = (jnp.swapaxes(x_batches, 0, 1), jnp.swapaxes(y_batches, 0, 1))
    (final_states, final_metrics), _ = jax.lax.scan(
        scan_body, (states, initial_metrics), batch_major)
    return final_states, final_metrics

@jax.jit
def eval_model(states, x_batches, y_batches, initial_metrics):
    """Accumulate evaluation metrics for all models over every batch.

    Args:
        states: TrainState pytree stacked along a leading model axis (not updated).
        x_batches: (num_models, num_batches, batch_size, 2) inputs.
        y_batches: (num_models, num_batches, batch_size) labels.
        initial_metrics: per-model metrics to start accumulating from.

    Returns:
        The accumulated per-model metrics.
    """
    step_all_models = jax.vmap(
        lambda state, metric, x, y: state.eval_step(metric, (x, y)),
        in_axes=(0, 0, 0, 0),
    )

    def scan_body(acc_metrics, batch):
        xs, ys = batch
        return step_all_models(states, acc_metrics, xs, ys), None

    # lax.scan walks the leading axis, so bring the batch axis to the front.
    batch_major = (jnp.swapaxes(x_batches, 0, 1), jnp.swapaxes(y_batches, 0, 1))
    final_metrics, _ = jax.lax.scan(scan_body, initial_metrics, batch_major)
    return final_metrics


@jax.jit
def compute_margin_stats(params, xs, ys):
    """Return (min, mean) of the logit margin over the given examples.

    An example's margin is its correct-class logit minus the largest logit
    among all other classes. A positive margin means the example is classified
    correctly.

    Args:
        params: a single model's parameters.
        xs: (N, 2) int32 input pairs.
        ys: (N,) int32 labels.
    """
    logits = model.apply({'params': params}, xs, training=False)[0]  # (N, C)
    is_target = jax.nn.one_hot(ys, logits.shape[1], dtype=bool)
    target_logit = logits[jnp.arange(xs.shape[0]), ys]
    # Push the correct class far below everything so max() picks the runner-up.
    best_other = jnp.where(is_target, -1e9, logits).max(axis=1)
    margin = target_logit - best_other
    return margin.min(), margin.mean()


# === Training Loop ===
# Epoch at which each seed first reaches 100% test accuracy, and its losses then.
first_100_acc_epoch_by_seed = {seed: None for seed in random_seed_ints}
first_epoch_loss_by_seed = {seed: None for seed in random_seed_ints}
first_epoch_ce_loss_by_seed = {seed: None for seed in random_seed_ints}

# Split the full p x p grid into the pairs each seed trained on and the
# held-out pairs. Both arrays keep grid order and are always shaped (N, 2),
# including N == 0 (e.g. when the training set covers the whole grid). The
# previous list-comprehension version produced a shape-(0,) array in that
# case, which could not be column-indexed.
coords_full = np.stack([a_eval, b_eval], axis=1)
train_coords_by_seed = {}
test_coords_by_seed = {}
for i, seed in enumerate(random_seed_ints):
    # Flatten this seed's train data to (k*batch_size, 3) numpy.
    train_flat = np.array(train_ds_list[i].reshape(-1, 3))
    # A boolean lookup table over the grid replaces the per-row Python set test.
    seen_mask = np.zeros((p, p), dtype=bool)
    seen_mask[train_flat[:, 0], train_flat[:, 1]] = True
    in_train = seen_mask[coords_full[:, 0], coords_full[:, 1]]
    train_coords = coords_full[in_train].astype(int).reshape(-1, 2)
    test_coords = coords_full[~in_train].astype(int).reshape(-1, 2)
    train_coords_by_seed[seed] = train_coords
    test_coords_by_seed[seed] = test_coords

# Loop-invariant inputs for the periodic logging below. They are built once
# instead of on every logging epoch / every model.
# DFT probes: `a` sweeps 0..p-1 while `b` is held at 2 (resp. 3).
x_freq_b2 = jnp.array([[a, 2] for a in range(p)], dtype=jnp.int32)
x_freq_b3 = jnp.array([[a, 3] for a in range(p)], dtype=jnp.int32)
# Logged DFT frequencies: 1..p//2 (the DC term, index 0, is skipped).
unique_range = range(1, (p // 2) + 1)
# Full p x p input grid in row-major (a, b) order, for the logits snapshot.
a_indices, b_indices = np.meshgrid(np.arange(p), np.arange(p), indexing='ij')
a_indices = a_indices.ravel()
b_indices = b_indices.ravel()
full_input = np.stack([a_indices, b_indices], axis=-1).astype(np.int32)

for epoch in range(epochs):
    current_epoch = epoch + 1
    print(f"Epoch {current_epoch}/{epochs}")

    # === Training ===
    states, train_metrics = train_epoch(states, train_x, train_y, initial_metrics)
    train_losses = []
    train_accuracies = []

    # === Evaluation ===
    # NOTE: evaluation and margin logging run every epoch; `do_eval` only
    # controls the closing banner. Gate the evaluation on `do_eval` to restore
    # the periodic (every 2500 epochs + final epoch) schedule.
    do_eval = current_epoch % 2500 == 0 or current_epoch == epochs
    print(f"\n--- Test Evaluation at Epoch {current_epoch} ---")
    test_metrics = eval_model(states, x_eval_batches, y_eval_batches, initial_metrics)
    test_losses = []
    test_accuracies = []

    for i, seed in enumerate(random_seed_ints):
        # --- Train metrics ---
        train_metric = jax.tree_util.tree_map(lambda x: x[i], train_metrics).compute()
        train_loss = float(train_metric['loss'])
        train_acc = float(train_metric['accuracy'])
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)
        print(f"Model {i + 1}/{num_models}: Train Loss: {train_loss:.6f}, Train Accuracy: {train_acc:.2%}")

        # --- Test metrics ---
        test_metric = jax.tree_util.tree_map(lambda x: x[i], test_metrics).compute()
        test_loss = float(test_metric['loss'])
        test_accuracy = float(test_metric['accuracy'])
        test_l2_loss = float(test_metric['l2_loss'])
        # Total loss is CE + weight_decay * L2, so subtract the penalty to recover CE.
        test_ce_loss = test_loss - weight_decay * test_l2_loss

        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)

        print(f"Model {i + 1}/{num_models}: Test CE Loss: {test_ce_loss:.6f}, Test Total Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.2%}")

        if first_100_acc_epoch_by_seed[seed] is None and test_accuracy >= 0.999999:
            first_100_acc_epoch_by_seed[seed] = current_epoch
            first_epoch_loss_by_seed[seed] = test_loss
            first_epoch_ce_loss_by_seed[seed] = test_ce_loss

            print(
                f"*** Seed {seed} first reached 100% accuracy at epoch {current_epoch} "
                f"with total loss {test_loss:.6f} and CE-only loss {test_ce_loss:.6f} ***"
            )

        # --- Weight norm and logit margins ---
        params_i = jax.tree_util.tree_map(lambda x: x[i], states.params)
        weight_norm = float(sum(jnp.sum(jnp.square(w)) for w in jax.tree_util.tree_leaves(params_i)))

        # Margin on this seed's training pairs.
        tc = train_coords_by_seed[seed]
        ty = jnp.mod(tc[:, 0] + tc[:, 1], p)
        train_min, train_avg = compute_margin_stats(params_i, jnp.array(tc), ty)

        # Margin on the held-out pairs. The former guard `k**2 != p**2` compared
        # the batch count with p. It did not detect an empty held-out set, since
        # the training set size is k * batch_size. It crashed when the grid was
        # fully covered and used the training pairs whenever k == p. Fall back
        # to the training pairs only when there really are no held-out pairs.
        vc = test_coords_by_seed[seed]
        if len(vc) > 0:
            vy = jnp.mod(vc[:, 0] + vc[:, 1], p)
        else:
            vc, vy = tc, ty
        test_min, test_avg = compute_margin_stats(params_i, jnp.array(vc), vy)

        # Margin over the full evaluation grid.
        total_min, total_avg = compute_margin_stats(params_i, x_eval, y_eval)

        log_by_seed[seed][current_epoch] = {
            "train_loss": train_loss,
            "train_accuracy": train_acc,
            "test_loss": test_loss,
            "test_ce_loss": test_ce_loss,
            "test_accuracy": test_accuracy,
            "l2_weight_norm": weight_norm,
            "train_margin":      float(train_min),
            "train_avg_margin":  float(train_avg),
            "test_margin":       float(test_min),
            "test_avg_margin":   float(test_avg),
            "total_margin":      float(total_min),
            "total_avg_margin":  float(total_avg),
        }

    if do_eval:
        print("--- End of Test Evaluation ---\n")

    # === Log the DFT of every first-layer neuron using the fixed probes ===
    if current_epoch % 10000 == 0 or current_epoch == epochs:
        for i, seed in enumerate(random_seed_ints):
            params_i = jax.tree_util.tree_map(lambda x: x[i], states.params)
            _, pre_acts_b2, _, _ = model.apply({'params': params_i}, x_freq_b2, training=False)
            _, pre_acts_b3, _, _ = model.apply({'params': params_i}, x_freq_b3, training=False)

            # Use only the first hidden layer (index 0); rows follow `a`.
            pre_act_b2_np = np.array(pre_acts_b2[0])
            pre_act_b3_np = np.array(pre_acts_b3[0])
            epoch_log = epoch_dft_logs_by_seed[seed].setdefault(current_epoch, {})
            num_neurons_in_layer = pre_act_b2_np.shape[1]
            for neuron_idx in range(num_neurons_in_layer):
                fft_b2 = np.fft.fft(pre_act_b2_np[:, neuron_idx])
                fft_b3 = np.fft.fft(pre_act_b3_np[:, neuron_idx])
                # Keep whichever probe (b=2 or b=3) has the stronger peak.
                if np.max(np.abs(fft_b2)) >= np.max(np.abs(fft_b3)):
                    chosen_fft = fft_b2
                else:
                    chosen_fft = fft_b3
                epoch_log[neuron_idx] = {
                    str(freq): float(np.abs(chosen_fft[freq])) for freq in unique_range
                }

    # === Log model 0's effective embeddings, pre-activations and logits ===
    # Same schedule as the DFT log. The final-epoch test used to read
    # `(current_epoch + 1) == epochs`, which fired on the second-to-last epoch.
    if current_epoch % 10000 == 0 or current_epoch == epochs:
        params_single = jax.tree_util.tree_map(lambda x: x[0], states.params)
        print(params_single.keys())
        horizontal_embeddings = model.extract_effective_embeddings_horizontal(params_single)
        epoch_embedding_log[current_epoch] = horizontal_embeddings.tolist()
        horizontal_file = os.path.join(model_dir, f"interactive_embedding_map_horizontal.html")
        print(f"Logged horizontal effective embeddings (p points) for epoch {current_epoch}.")

        logits, preactivations, _, _ = model.apply({'params': params_single}, full_input, training=False)
        epoch_preactivation_log[current_epoch] = [np.array(layer).tolist() for layer in preactivations]
        epoch_logits_log[current_epoch] = logits.tolist()
        print(f"Logged full preactivations and logits (p^2 points) for epoch {current_epoch}.")

# --- After training: write out the per-seed metric logs as JSON ---
_freq_root = f"multilayer_heatmaps_logn-neurips-run-2/{mlp_class_lower}-freqs_{p}-"
_freq_leaf = (
    f"freq_distribution_mlp={mlp_class_lower}_p={p}_bs={batch_size}_k={k}"
    f"_nn={num_neurons}_wd={weight_decay}_lr={learning_rate}"
)
freq_json_dir = os.path.join(_freq_root, _freq_leaf)
os.makedirs(freq_json_dir, exist_ok=True)

for seed in random_seed_ints:
    log_file_path = os.path.join(freq_json_dir, f"log_seed_{seed}.json")
    with open(log_file_path, "w") as log_fh:
        json.dump(log_by_seed[seed], log_fh, indent=2)
    print(f"Final log for seed {seed} saved to {log_file_path}")

# === Final Evaluation on Test Set ===
print("Starting final evaluation...")
test_metrics = eval_model(states, x_eval_batches, y_eval_batches, initial_metrics)
network_metrics = {}  # seed -> {"loss": total test loss, "l2_loss": L2 penalty term}
for i, seed in enumerate(random_seed_ints):
    test_metric = jax.tree_util.tree_map(lambda x: x[i], test_metrics).compute()
    test_loss = float(test_metric["loss"])
    test_accuracy = float(test_metric["accuracy"])
    test_l2_loss = float(test_metric["l2_loss"])
    print(f"Model {i + 1}/{num_models}: Final Test Loss: {test_loss:.6f}, Final Test Accuracy: {test_accuracy * 100:.2f}%")
    network_metrics[seed] = {"loss": test_loss, "l2_loss": test_l2_loss}
    params_i = jax.tree_util.tree_map(lambda x: x[i], states.params)

    if test_accuracy >= 0.999:
        # Persist only (near-)perfect models.
        experiment_name = batch_experiment
        optimizer_name = optimizer
        params_file_path = os.path.join(
            model_dir,
            f"params_p_{p}_{batch_experiment}_{optimizer_name}_ts_{training_set_size}_"
            f"bs={batch_size}_nn={num_neurons}_lr={learning_rate}_wd={weight_decay}_"
            f"rs_{seed}.params"
        )
        os.makedirs(os.path.dirname(params_file_path), exist_ok=True)
        with open(params_file_path, 'wb') as f:
            f.write(serialization.to_bytes(params_i))
        print(f"Model {i + 1}: Parameters saved to {params_file_path}")
        continue

    # Otherwise, list every grid pair the model gets wrong.
    print(f"Model {i + 1}: Test accuracy did not exceed 99.9%. Model parameters will not be saved.")
    print(f"\n--- Misclassified Test Examples for Model {i + 1} ---")
    logits, _, _, _ = model.apply({'params': params_i}, x_eval, training=False)
    predictions = jnp.argmax(logits, axis=-1)
    y_true = y_eval
    incorrect_indices = jnp.where(predictions != y_true)[0]
    if incorrect_indices.size == 0:
        print("No misclassifications found. All predictions are correct.")
        continue

    misclassified_x = x_eval[incorrect_indices]
    misclassified_y_true = y_true[incorrect_indices]
    misclassified_y_pred = predictions[incorrect_indices]
    print(f"Total Misclassifications: {len(incorrect_indices)}")
    wrong_rows = zip(misclassified_x, misclassified_y_true, misclassified_y_pred)
    for idx, ((a_val, b_val), true_label, pred_label) in enumerate(wrong_rows, 1):
        print(f"{idx}. a: {int(a_val)}, b: {int(b_val)}, True: {int(true_label)}, Predicted: {int(pred_label)}")

# === FINAL UPGRADE: group neurons by dominant frequency at the final epoch ===
# final_grouping[seed][freq] lists the neuron indices whose largest logged DFT
# magnitude at the final epoch occurs at `freq` (first maximum wins on ties).
final_epoch = epochs
final_grouping = {}
for seed in random_seed_ints:
    per_freq = final_grouping[seed] = {}
    for neuron_idx, dft_dict in epoch_dft_logs_by_seed[seed][final_epoch].items():
        dominant = int(max(dft_dict, key=dft_dict.get))
        per_freq.setdefault(dominant, []).append(neuron_idx)
# for seed in random_seed_ints:
#     for freq, neuron_list in final_grouping[seed].items():
#         new_dict = {}
#         for epoch_num, epoch_dict in epoch_dft_logs_by_seed[seed].items():
#             filtered_neurons = {str(neuron_idx): epoch_dict[neuron_idx] for neuron_idx in neuron_list if neuron_idx in epoch_dict}
#             new_dict[epoch_num] = filtered_neurons
#         output_filepath = os.path.join(model_dir, f"frequency_{freq}_log_seed_{seed}.json")
#         with open(output_filepath, "w") as f:
#             json.dump(new_dict, f, indent=2)
#         print(f"Frequency log for frequency {freq} (seed {seed}) saved to {output_filepath}")


from plots_multilayer import (
    plot_cluster_preactivations,
    reconstruct_sine_fits_multilayer_logn_fits_layers_after_2,
    fit_sine_wave_multi_freq
)

def convert_to_builtin_type(obj):
    """Recursively turn NumPy/JAX arrays inside dicts, lists and tuples into lists.

    Dicts keep their keys, and tuples become lists. Any other object is
    returned unchanged, so the result can be passed to `json.dump`.
    """
    if isinstance(obj, (np.ndarray, jnp.ndarray)):
        return obj.tolist()
    if isinstance(obj, dict):
        return {key: convert_to_builtin_type(val) for key, val in obj.items()}
    if isinstance(obj, (list, tuple)):
        return list(map(convert_to_builtin_type, obj))
    return obj

import copy

def _as_index_array(indices):
    """Return `indices` as a flat int64 NumPy index array (empty input -> empty array)."""
    return np.asarray(list(indices), dtype=np.int64).reshape(-1)


def zero_dead_neurons_general(params, dead_neurons_by_layer):
    """
    Zero out the weights for dead neurons in a network of arbitrary depth.
    Works whether the first hidden layer is named "dense", "dense_1", or "input_dense".

    For every dead neuron of hidden layer l, its incoming kernel column and its
    bias entry are set to 0.0. Its outgoing kernel row in the next layer is
    zeroed too: the next hidden layer, or "output_dense" (when present) after
    the last hidden layer. This now also applies when the first layer is the
    only hidden layer; previously its rows in "output_dense" were left intact.

    Args:
        params: dict mapping layer name -> {"kernel", "bias"} arrays. It is
            deep-copied and not modified.
        dead_neurons_by_layer: one iterable of neuron indices per hidden layer,
            first hidden layer first. Layers with an empty list are not touched.

    Returns:
        A new params dict with the dead neurons' weights zeroed.

    Raises:
        ValueError: if no first hidden layer is found, or if the number of
            dead-neuron lists differs from the number of hidden layers.
    """
    new_params = copy.deepcopy(params)
    final_output_key = "output_dense"

    # 1. Identify the first-layer key.
    first_layer_key = next(
        (cand for cand in ("dense", "dense_1", "input_dense") if cand in new_params), None)
    if first_layer_key is None:
        raise ValueError("Could not find first hidden layer in parameters.")

    # 2. Deeper hidden layers, ordered by numeric suffix (dense_2 < dense_3 < ...).
    additional_keys = sorted(
        (key for key in new_params
         if key.startswith("dense_") and key not in (first_layer_key, final_output_key)),
        key=lambda key: int(key.split("_")[1]),
    )
    hidden_keys = [first_layer_key] + additional_keys

    # 3. Sanity-check the dead-neuron list length.
    total_layers = len(hidden_keys)
    if len(dead_neurons_by_layer) != total_layers:
        raise ValueError(f"Expected {total_layers} dead‑neuron lists, got {len(dead_neurons_by_layer)}.")

    # 4. Per hidden layer: zero incoming weights and bias, then outgoing weights.
    # One vectorised .at[] update per array instead of one full copy per index.
    for layer_pos, key in enumerate(hidden_keys):
        dead = _as_index_array(dead_neurons_by_layer[layer_pos])
        if dead.size == 0:
            continue  # nothing to zero; don't touch this layer's arrays
        layer = new_params[key]
        layer["kernel"] = layer["kernel"].at[:, dead].set(0.0)
        layer["bias"] = layer["bias"].at[dead].set(0.0)

        if layer_pos + 1 < total_layers:
            consumer_key = hidden_keys[layer_pos + 1]
        elif final_output_key in new_params:
            consumer_key = final_output_key
        else:
            continue
        new_params[consumer_key]["kernel"] = (
            new_params[consumer_key]["kernel"].at[dead, :].set(0.0)
        )

    return new_params

def _evaluate_separable_fit(fit_a, fit_b, bias_val, p):
    """Return the (p, p) grid whose [a, b] entry is ``fit_a(a) + fit_b(b) + bias_val``."""
    grid = np.zeros((p, p))
    for a in range(p):
        for b in range(p):
            grid[a, b] = fit_a(a) + fit_b(b) + bias_val
    return grid


def _evaluate_lookup_fit(row_fns, p):
    """Return the (p, p) grid whose [a, b] entry is ``row_fns[a](b)``."""
    grid = np.zeros((p, p))
    for a in range(p):
        for b in range(p):
            grid[a, b] = row_fns[a](b)
    return grid


def _cross_entropy_and_accuracy(logits, labels, num_classes):
    """Return ``(mean cross-entropy, accuracy in percent)`` of logits vs. integer labels.

    ``logits`` is reshaped to ``(-1, num_classes)`` and ``labels`` is flattened to
    match. The log-sum-exp is shifted by each row's max for numerical stability.
    """
    logits_flat = np.asarray(logits).reshape(-1, num_classes)
    labels_flat = np.asarray(labels).reshape(-1)
    row_max = np.max(logits_flat, axis=1, keepdims=True)
    logsumexp = row_max[:, 0] + np.log(np.sum(np.exp(logits_flat - row_max), axis=1))
    target_logits = logits_flat[np.arange(logits_flat.shape[0]), labels_flat]
    ce_loss = np.mean(-target_logits + logsumexp)
    accuracy = np.mean(np.argmax(logits_flat, axis=1) == labels_flat) * 100
    return ce_loss, accuracy


def _weighted_mean_r2(freq_distribution):
    """Return the neuron-count-weighted mean R² of a frequency distribution.

    Each value of ``freq_distribution`` is unpacked as ``(count, r2, *rest)``.
    A distribution whose total count is not positive yields ``0.0``.
    """
    total_neurons = sum(count for count, _r2, *_ in freq_distribution.values())
    if total_neurons <= 0:
        return 0.0
    weighted_sum = sum(count * r2 for count, r2, *_ in freq_distribution.values())
    return float(weighted_sum / total_neurons)


def run_reconstruction(model, model_params, seed, top_k_val,
                       p, batch_size, k, weight_decay, learning_rate,
                       mlp_class_lower, contrib_a_np, contrib_b_np, bias_layer1,
                       model_accuracy, test_total_loss, test_ce_loss):
    """
    Run the reconstruction process for a given top-k value.

    Fits every hidden neuron with sine-wave reconstructions and evaluates two
    reconstructed networks on the full p×p input grid with labels (a + b) % p:
      * "stored_fits": every hidden layer is replaced by its stored fits.
      * "on_the_fly": layer 1 is refit from the contributions and then pushed
        through the real weights of the deeper layers.
    Frequency distributions, a "first100" summary and the reconstruction
    metrics are written as JSON files under ``freq_json_dir``.

    Args:
      model: The MLP model instance.
      model_params: The parameters for this model (for seed `seed`).
      seed: The current seed (for logging file names).
      top_k_val: The number of key frequencies to use for the reconstruction.
      p, batch_size, k, weight_decay, learning_rate: Hyperparameters.
      mlp_class_lower: Lower-cased version of the model class name (used in file paths).
      contrib_a_np, contrib_b_np: Precomputed contribution arrays from layer 1,
        indexed as [input_value, neuron].
      bias_layer1: Bias values extracted from layer 1.
      model_accuracy: The final test accuracy of this model.
      test_total_loss: The final test total loss of this model.
      test_ce_loss: The final test cross-entropy loss of this model.

    Returns:
      A tuple containing:
         - reconstruction_metrics: dict with loss/accuracy metrics for "model",
           "stored_fits" and "on_the_fly", plus the weighted mean R² per layer
           under string keys "1", "2", ...
         - layer1_freq: frequency distribution for layer 1 (to learn new top-k).
         - neuron_data: the dictionary containing per-neuron data.
         - dominant_freq_clusters: the dominant frequency clusters from reconstruction.
         - freq_json_dir: directory where reconstruction metrics are saved.

    Notes:
      Reads the module-level dicts ``first_100_acc_epoch_by_seed`` and
      ``first_epoch_ce_loss_by_seed`` (presumably keyed by seed; a missing seed
      raises ``KeyError``).
    """
    # --- Gather additional layer parameters (dense_2, dense_3, ... in numeric order) ---
    additional_layer_keys = [key for key in model_params
                             if key.startswith("dense_") and key not in ("dense_1", "output_dense")]
    additional_layer_keys.sort(key=lambda name: int(name.split("_")[1]))
    additional_layers_params = [model_params[key] for key in additional_layer_keys]

    # --- Call the multilayer reconstruction function ---
    (layer1_freq,
     additional_layers_freq,
     layer1_fits,
     additional_layers_fits_lookup,
     dead_neurons_layer1,
     additional_layers_dead_neurons,
     dominant_freq_clusters) = reconstruct_sine_fits_multilayer_logn_fits_layers_after_2(
            contrib_a_np,
            contrib_b_np,
            bias_layer1,
            additional_layers_params,
            p,
            top_k=top_k_val
    )
    num_neurons_layer1 = contrib_a_np.shape[1]

    # --- Full p×p input grid and its labels ---
    a_grid, b_grid = np.meshgrid(np.arange(p), np.arange(p), indexing="ij")
    labels = (a_grid + b_grid) % p
    ab_inputs = np.stack([a_grid.ravel(), b_grid.ravel()], axis=-1).astype(np.int32)
    _, pre_acts_all, _, _ = model.apply({'params': model_params}, ab_inputs, training=False)
    # One (p, p, neurons) array per hidden layer; index 0 is layer 1.
    real_preacts = [np.array(act).reshape((p, p, -1)) for act in pre_acts_all]

    # --- Evaluate every stored fit once on the grid ---
    # Each fit is evaluated p*p times in Python, so this is the expensive part.
    # The grids are reused both for neuron_data and for the stored-fit forward pass.
    layer1_fitted = [_evaluate_separable_fit(fit_a, fit_b, bias_val, p)
                     for fit_a, fit_b, bias_val in layer1_fits]
    additional_fitted = [[_evaluate_lookup_fit(row_fns, p) for row_fns in fit_lookup]
                         for fit_lookup in additional_layers_fits_lookup]

    # --- Build neuron data (real vs. fitted preactivations, real postactivations) ---
    neuron_data = {}
    for layer_num, fitted_grids in enumerate([layer1_fitted] + additional_fitted, start=1):
        neuron_data[layer_num] = {}
        if not fitted_grids:
            continue
        real_layer = real_preacts[layer_num - 1]
        for neuron_idx, fitted in enumerate(fitted_grids):
            real = real_layer[:, :, neuron_idx]
            neuron_data[layer_num][neuron_idx] = {
                'a_values': np.arange(p),
                'b_values': np.arange(p),
                'real_preactivations': real,
                'fitted_preactivations': fitted,
                'postactivations': np.maximum(real, 0.0),
            }

    # --- Set up the directory for logging frequency distributions ---
    freq_json_dir = os.path.join(
        f"multilayer_heatmaps_logn-neurips-run-2/{mlp_class_lower}-freqs_{p}-",
        f"freq_distribution_mlp={mlp_class_lower}_p={p}_bs={batch_size}_k={k}_nn={num_neurons_layer1}_wd={weight_decay}_lr={learning_rate}"
    )
    os.makedirs(freq_json_dir, exist_ok=True)

    # --- Save frequency distributions ---
    layer1_json_path = os.path.join(freq_json_dir, f"freq_distribution_layer_1_top-k_{top_k_val}_seed_{seed}.json")
    with open(layer1_json_path, "w") as f:
        json.dump(convert_to_builtin_type({str(freq): v for freq, v in layer1_freq.items()}), f, indent=2)
    for layer_idx, layer_freq in enumerate(additional_layers_freq, start=2):
        layer_json_path = os.path.join(freq_json_dir, f"freq_distribution_layer_{layer_idx}_top-k_{top_k_val}_seed_{seed}.json")
        with open(layer_json_path, "w") as f:
            json.dump(convert_to_builtin_type(layer_freq), f, indent=2)

    # The "first100" file is named after the last layer index used above.
    # Compute it directly so it also works when there are no additional layers.
    # The old code reused the loop variable, which is unbound when the loop never runs.
    last_layer_idx = 1 + len(additional_layers_freq)
    first100_path = os.path.join(
        freq_json_dir,
        f"first100_layer_{last_layer_idx}_top-k_{top_k_val}_seed_{seed}.json"
    )
    with open(first100_path, "w") as f:
        json.dump({
            "epoch": first_100_acc_epoch_by_seed[seed],
            "ce_loss": first_epoch_ce_loss_by_seed[seed]
        }, f, indent=2)

    # --- Reconstruct the network output ---
    # Stored-fit path, layer 1: ReLU of the stored fits.
    h_reconstructed = np.zeros((p, p, num_neurons_layer1))
    for n in range(num_neurons_layer1):
        h_reconstructed[:, :, n] = np.maximum(layer1_fitted[n], 0.0)

    # On-the-fly path, layer 1: refit each neuron's a/b contributions from scratch.
    h_sim = np.zeros((p, p, num_neurons_layer1))
    for n in range(num_neurons_layer1):
        fit_a_sim, _ = fit_sine_wave_multi_freq(contrib_a_np[:, n], p, top_k=top_k_val)
        fit_b_sim, _ = fit_sine_wave_multi_freq(contrib_b_np[:, n], p, top_k=top_k_val)
        h_sim[:, :, n] = np.maximum(
            _evaluate_separable_fit(fit_a_sim, fit_b_sim, bias_layer1[n], p), 0.0
        )

    # Process additional layers sequentially.
    for layer_idx, key in enumerate(additional_layer_keys, start=2):
        current_weights = np.array(model_params[key]["kernel"])
        current_bias = np.array(model_params[key]["bias"])
        # On-the-fly path: propagate through the real weights of this layer.
        h_sim = np.maximum(np.einsum('abn,nm->abm', h_sim, current_weights) + current_bias, 0)
        # Stored-fit path: the layer output is replaced entirely by its fitted
        # lookup table, so only the layer width comes from the weights.
        num_neurons_current = current_weights.shape[-1]
        layer_grids = additional_fitted[layer_idx - 2]
        h_reconstructed_new = np.zeros((p, p, num_neurons_current))
        for m in range(num_neurons_current):
            h_reconstructed_new[:, :, m] = layer_grids[m]
        h_reconstructed = np.maximum(h_reconstructed_new, 0)

    # Apply the final output layer.
    final_layer_weights = np.array(model_params["output_dense"]["kernel"])
    output_bias = np.array(model_params["output_dense"].get("bias", np.zeros(p)))

    def output_logits(hidden):
        # (p, p, width) hidden activations -> (p, p, classes) logits.
        flat = hidden.reshape(-1, hidden.shape[-1])
        return (np.dot(flat, final_layer_weights) + output_bias).reshape(p, p, -1)

    logits_reconstructed_with_dead = output_logits(h_reconstructed)
    logits_reconstructed = output_logits(h_sim)

    # --- Compute test accuracy and losses ---
    ce_loss_stored, acc_stored = _cross_entropy_and_accuracy(logits_reconstructed_with_dead, labels, p)
    ce_loss_onfly, acc_onfly = _cross_entropy_and_accuracy(logits_reconstructed, labels, p)

    # --- Update parameters by zeroing out dead neurons ---
    dead_by_layer = [dead_neurons_layer1, *additional_layers_dead_neurons]
    updated_params = zero_dead_neurons_general(model_params, dead_by_layer)
    # NOTE(review): dense_1's kernel is not part of this L2 term. Confirm this
    # matches the weight-decay penalty used during training.
    l2_loss = 0.0
    for key in additional_layer_keys:
        l2_loss += np.sum(np.square(np.array(updated_params[key]["kernel"])))
    if "output_dense" in updated_params:
        l2_loss += np.sum(np.square(np.array(updated_params["output_dense"]["kernel"])))
    total_loss_stored = ce_loss_stored + weight_decay * l2_loss
    total_loss_onfly = ce_loss_onfly + weight_decay * l2_loss

    # --- Package reconstruction metrics ---
    reconstruction_metrics = {
        "model": {
            "cross_entropy_loss": float(test_ce_loss),
            "total_loss": float(test_total_loss),
            "accuracy": float(model_accuracy)
        },
        "stored_fits": {
            "cross_entropy_loss": float(ce_loss_stored),
            "total_loss": float(total_loss_stored),
            "accuracy": float(acc_stored)
        },
        "on_the_fly": {
            "cross_entropy_loss": float(ce_loss_onfly),
            "total_loss": float(total_loss_onfly),
            "accuracy": float(acc_onfly)
        }
    }

    # --- Weighted mean R² per layer from the frequency distributions ---
    r2_per_layer = {'1': _weighted_mean_r2(layer1_freq)}
    for layer_idx, dist in enumerate(additional_layers_freq, start=2):
        r2_per_layer[str(layer_idx)] = _weighted_mean_r2(dist)
    reconstruction_metrics.update(r2_per_layer)

    output_json_path = os.path.join(freq_json_dir, f"reconstruction_metrics_top-k={top_k_val}_seed_{seed}.json")
    with open(output_json_path, "w") as f:
        json.dump(reconstruction_metrics, f, indent=2)
    print(f"Reconstruction metrics saved to {output_json_path}")

    return reconstruction_metrics, layer1_freq, neuron_data, dominant_freq_clusters, freq_json_dir



# --- Main Loop Over Models (Seeds) ---
# For each seed, do four things:
#   1. Take the i-th model's slice of the stacked parameters.
#   2. Compute its layer-1 contributions.
#   3. Run the top_k=1 reconstruction.
#   4. Append a one-line summary to the training-distributions log.

# Inputs (a, 0) and (0, b) for every value of the varying operand. They are
# used below to read off layer-1 contributions and do not depend on the seed.
x_freq_a = jnp.array([[a, 0] for a in range(p)], dtype=jnp.int32)
x_freq_b = jnp.array([[0, b] for b in range(p)], dtype=jnp.int32)

for i, seed in enumerate(random_seed_ints):
    # Extract parameters for the i-th model (leading axis of each leaf indexes the model).
    model_params_seed = jax.tree_util.tree_map(lambda x: x[i], states.params)
    bias_layer1 = np.array(model.bias(model_params_seed))

    # Compute contributions for layer 1: the third output for the (a, 0)
    # inputs and the fourth output for the (0, b) inputs.
    _, _, contrib_a, _ = model.apply({'params': model_params_seed}, x_freq_a, training=False)
    _, _, _, contrib_b = model.apply({'params': model_params_seed}, x_freq_b, training=False)
    contrib_a_np = np.array(contrib_a)
    contrib_b_np = np.array(contrib_b)

    # mlp_class_plots and base_dir feed the commented-out plotting call below.
    # cluster_grouping is not used in this loop.
    cluster_grouping = final_grouping[seed]
    mlp_class_plots = f"{MLP_class}_seed_{seed}_bs={batch_size}_k={k}"
    base_dir = f"plots-{mlp_class_lower}-multilayer"

    # Test metrics recorded for this seed at epoch `epochs`.
    m = log_by_seed[seed][epochs]
    model_accuracy = m["test_accuracy"]
    test_total_loss = m["test_loss"]
    test_ce_loss = m["test_ce_loss"]

    # --- First Reconstruction: top_k=1 ---
    rec_metrics1, layer1_freq, neuron_data, dominant_freq_clusters, freq_json_dir1 = run_reconstruction(
        model, model_params_seed, seed, top_k_val=1, p=p, batch_size=batch_size, k=k,
        weight_decay=weight_decay, learning_rate=learning_rate,
        mlp_class_lower=mlp_class_lower,
        contrib_a_np=contrib_a_np, contrib_b_np=contrib_b_np, bias_layer1=bias_layer1,
        model_accuracy=model_accuracy, test_total_loss=test_total_loss, test_ce_loss=test_ce_loss
    )

    # --- Determine new top_k based on the layer 1 frequency distribution.
    new_top_k = len(layer1_freq)
    print(f"For seed {seed}, top_k=1 yielded {new_top_k} key frequencies.")

    # # Only perform the second reconstruction for the specific MLP classes.
    # if base_class_name in ["no_embed"]:
    #     rec_metrics2, _, _, _, freq_json_dir2 = run_reconstruction(
    #         model, model_params_seed, seed, top_k_val=new_top_k, p=p, batch_size=batch_size, k=k,
    #         weight_decay=weight_decay, learning_rate=learning_rate,
    #         mlp_class_lower=mlp_class_lower,
    #         contrib_a_np=contrib_a_np, contrib_b_np=contrib_b_np, bias_layer1=bias_layer1,
    #         model_accuracy=model_accuracy, test_total_loss=test_total_loss, test_ce_loss=test_ce_loss
    #     )

    #     # --- Log training distributions for second reconstruction ---
    #     os.makedirs(freq_json_dir2, exist_ok=True)
    #     training_distributions_name2 = f"training_distributions_top-k={new_top_k}_{p}_{neuron_data[1].__len__()}_{weight_decay}_{learning_rate}.txt"
    #     training_distributions_path2 = os.path.join(freq_json_dir2, training_distributions_name2)
    #     with open(training_distributions_path2, "a") as f:
    #         line = f"{seed},<fitted_accuracy_sim_run2>,<fitted_accuracy_with_dead_run2>\n"
    #         f.write(line)
    # else:
    #     print(f"Skipping second reconstruction for model class: {base_class_name}")

    # --- Log training distributions for first reconstruction ---
    # One line per seed with three columns:
    #   seed, on-the-fly ("sim") fitted accuracy, stored-fit ("with dead") fitted accuracy.
    # Both accuracies are in percent, matching the variable names and the run-2
    # template above. The previous code wrote the stored-fit and model
    # cross-entropy losses here instead.
    os.makedirs(freq_json_dir1, exist_ok=True)
    training_distributions_name1 = f"training_distributions_top-k=1_{p}_{len(neuron_data[1])}_{weight_decay}_{learning_rate}.txt"
    training_distributions_path1 = os.path.join(freq_json_dir1, training_distributions_name1)
    fitted_accuracy_sim_run1 = rec_metrics1["on_the_fly"]["accuracy"]
    fitted_accuracy_with_dead_run1 = rec_metrics1["stored_fits"]["accuracy"]
    line = f"{seed},{fitted_accuracy_sim_run1},{fitted_accuracy_with_dead_run1}\n"
    with open(training_distributions_path1, "a") as f:
        f.write(line)

    # # Call the plotting functions.
    # plot_cluster_preactivations(
    #     cluster_groupings=dominant_freq_clusters,
    #     neuron_data=neuron_data,
    #     mlp_class=mlp_class_plots,
    #     seed=seed,
    #     features=features,
    #     num_neurons=num_neurons,
    #     base_dir=base_dir
    # )
