
import os
import jax 
import optax
import copy
import pickle 
import argparse
import itertools
import numpy as np
from tqdm import tqdm
import json 
from jax import random
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax.tree_util import tree_map
from jax.lax import stop_gradient
from dataclasses import dataclass, asdict
from jax import jit, random, value_and_grad
from flax.training.train_state import TrainState
from flax.core import freeze, unfreeze
from flax.serialization import from_bytes
from model import FlaxTransformerMoE, print_model
from datamodule import load_agnews_data, create_data_loader, AGNewsDataset, flax_make_numpy_loader
from utils import ec2_get_instance_type, flatten_params, lerp, unflatten_params

# Module-level tokenizer/model defaults. None of these names is read by the
# code in this file: main() binds its own local PAD from load_agnews_data().
# NOTE(review): presumably kept for other modules that import them — confirm
# before removing.
vocab_size: int = 15000  # presumably the tokenizer vocabulary size
max_seq_len: int = 100  # presumably the maximum sequence length in tokens
PAD: int = 0  # presumably the padding token id
UNK: int = 1  # presumably the unknown-token id


def to_serializable(val):
    """Convert a NumPy/JAX value into a JSON-serializable Python object.

    NumPy and JAX arrays become (nested) Python lists via ``tolist()``.
    NumPy scalars of any dtype become the matching Python scalar via
    ``item()``. Any other value is returned unchanged.

    Args:
        val: The value to convert.

    Returns:
        A JSON-friendly equivalent of ``val``, or ``val`` itself.
    """
    if isinstance(val, (jnp.ndarray, np.ndarray)):
        return val.tolist()
    # np.generic is the base class of every NumPy scalar type (float16/32/64,
    # int8..int64, uint*, bool_, ...). The former explicit float32/float64/
    # int32/int64 list let e.g. np.bool_ or np.float16 through unconverted,
    # which made json.dump fail.
    if isinstance(val, np.generic):
        return val.item()
    return val



def recursive_to_serializable(obj):
    """Recursively convert a nested results structure into JSON-safe values.

    Dicts are rebuilt with converted values (keys are kept as-is). Lists and
    tuples become lists of converted items: JSON has no tuple type, and
    ``json.dump`` would emit a tuple as a list anyway. Every other value is
    passed to ``to_serializable``.

    Args:
        obj: Arbitrarily nested dicts/lists/tuples of leaf values.

    Returns:
        A structure of dicts, lists and ``to_serializable``-converted leaves.
    """
    if isinstance(obj, dict):
        return {key: recursive_to_serializable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Tuples were previously handed to to_serializable whole, so NumPy
        # scalars nested inside them reached json.dump unconverted.
        return [recursive_to_serializable(item) for item in obj]
    return to_serializable(obj)


@dataclass
class ModelConfig:
    """Hyper-parameters passed to ``FlaxTransformerMoE``.

    ``main`` builds this from the ``"config"`` mapping stored in a checkpoint
    bundle (``ModelConfig(**config_a)``), so field names must match that
    mapping's keys; an unknown key makes the generated ``__init__`` raise
    ``TypeError``.
    """

    # Required fields; their positional order is part of the constructor API.
    encoder_vocab_size: int  # presumably the encoder vocabulary size
    d_embed: int  # presumably the embedding / model width
    d_ff: int  # presumably the feed-forward hidden width
    h: int  # presumably the number of attention heads
    N_encoder: int  # presumably the number of encoder layers
    max_seq_len: int  # presumably the maximum input length in tokens
    dropout: float  # dropout rate (how the model uses it is not visible here)

    # Optional mixture-of-experts settings; semantics live in the model code.
    num_experts: int = 0
    moe_idx: int = 0
    num_shared_experts: int = 0
    num_gated_experts: int = 2
    topk: int = 2


def load_combined_model(load_path):
    """Load ``(flax_params, config)`` from a pickled checkpoint bundle.

    The file at ``load_path`` must unpickle to an object indexable by the
    keys ``"flax_params"`` and ``"config"``; a missing key raises
    ``KeyError``.

    SECURITY: ``pickle.load`` can execute arbitrary code while loading.
    Only open checkpoints from trusted sources.
    """
    with open(load_path, "rb") as handle:
        checkpoint = pickle.load(handle)
    params = checkpoint["flax_params"]
    config = checkpoint["config"]
    return params, config


def make_stuff(model, num_classes):
    """Build jitted evaluation/training helpers for a Flax classifier.

    Args:
        model: A Flax module whose ``apply`` is called as
            ``apply({'params': p}, inputs, pad_mask=mask, deterministic=...)``
            and is assumed to return logits of shape ``(batch, num_classes)``
            (TODO confirm against ``FlaxTransformerMoE``).
        num_classes: Number of target classes used to one-hot the labels.

    Returns:
        A dict with three callables:
          * ``"batch_eval"(params, batch)`` -> ``{"loss", "num_correct"}``
          * ``"step"(train_state, batch, dropout_rng=None)`` ->
            ``(new_train_state, {"batch_loss", "num_correct"})``
          * ``"dataset_loss_and_accuracy"(params, dataloader)`` ->
            ``(avg_loss, avg_accuracy)`` as Python floats.

    Batches are dicts with ``'input'``, ``'mask'`` and ``'label'`` entries.
    Labels are assumed to be integer class ids.
    """
    apply_fn = model.apply

    def _loss_and_correct(params, batch, deterministic, dropout_rng=None):
        # Shared forward pass: mean cross-entropy and count of correct argmax
        # predictions for one batch.
        apply_kwargs = {}
        if dropout_rng is not None:
            # Flax dropout layers draw from the 'dropout' RNG stream when
            # deterministic=False; without one, apply() errors for rate > 0.
            apply_kwargs["rngs"] = {"dropout": dropout_rng}
        logits = apply_fn(
            {"params": params},
            batch["input"],
            pad_mask=batch["mask"],
            deterministic=deterministic,
            **apply_kwargs,
        )
        labels = batch["label"]
        y_onehot = jax.nn.one_hot(labels, num_classes)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=y_onehot))
        num_correct = jnp.sum(jnp.argmax(logits, axis=-1) == labels)
        return loss, num_correct

    @jit
    def batch_eval(params, batch):
        # Deterministic (eval-mode) loss and correct-count for one batch.
        loss, num_correct = _loss_and_correct(params, batch, deterministic=True)
        return {"loss": loss, "num_correct": num_correct}

    @jit
    def step(train_state, batch, dropout_rng=None):
        # One gradient update. Pass a PRNG key as ``dropout_rng`` when the
        # model applies dropout; the default None keeps the old call form.
        def loss_fn(params):
            loss, num_correct = _loss_and_correct(
                params, batch, deterministic=False, dropout_rng=dropout_rng
            )
            return loss, {"num_correct": num_correct}

        (loss, info), grads = value_and_grad(loss_fn, has_aux=True)(train_state.params)
        train_state = train_state.apply_gradients(grads=grads)
        return train_state, {"batch_loss": loss, **info}

    def dataset_loss_and_accuracy(params, dataloader):
        # Example-weighted average loss and accuracy over every batch.
        # Raises ValueError if the dataloader yields no examples.
        total_loss = 0.0
        total_correct = 0
        total_count = 0
        pbar = tqdm(dataloader, desc="Evaluating")
        for batch in pbar:
            batch = {k: jax.device_put(v) for k, v in batch.items()}
            metrics = batch_eval(params, batch)
            batch_size = batch["input"].shape[0]
            # Sync to host once and reuse the Python values below.
            batch_loss = float(metrics["loss"])
            batch_correct = int(metrics["num_correct"])
            # batch_loss is a per-batch mean; weighting by batch size keeps a
            # smaller final batch from being over-represented.
            total_loss += batch_loss * batch_size
            total_correct += batch_correct
            total_count += batch_size
            pbar.set_postfix({
                "loss": f"{batch_loss:.4f}",
                "acc": f"{(batch_correct / batch_size):.4f}"
            })
        if total_count == 0:
            raise ValueError("dataset_loss_and_accuracy: dataloader yielded no examples")
        avg_loss = total_loss / total_count
        avg_acc = total_correct / total_count
        return avg_loss, avg_acc

    return {"batch_eval": batch_eval, "step": step, "dataset_loss_and_accuracy": dataset_loss_and_accuracy}


def _parse_args():
    """Parse the command-line arguments for the interpolation experiment."""
    parser = argparse.ArgumentParser(description="Weight matching for Transformer-MoE models on AGNEWS")
    parser.add_argument("--model-a", type=str, required=True, help="Path to first Transformer-MoE model checkpoint")
    parser.add_argument("--model-b", type=str, required=True, help="Path to second Transformer-MoE model checkpoint")
    parser.add_argument("--data-path", type=str, required=True)
    parser.add_argument("--batch-size", type=int, default=100)
    # NOTE(review): accepted for CLI compatibility, but nothing below consumes
    # randomness, so the seed currently has no effect.
    parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility")
    return parser.parse_args()


def _interpolate_naive(stuff, params_a, params_b, lambdas, train_loader, test_loader):
    """Evaluate ``utils.lerp(lam, params_a, params_b)`` for every ``lam``.

    No expert permutation or alignment is applied; the raw parameter trees
    are interpolated directly.

    Returns:
        Four lists (train loss, test loss, train acc, test acc), one entry
        per lambda.
    """
    train_losses, test_losses = [], []
    train_accs, test_accs = [], []
    for lam in tqdm(lambdas, desc="Naive Interpolation"):
        naive_p = freeze(lerp(lam, unfreeze(params_a), unfreeze(params_b)))
        train_loss, train_acc = stuff["dataset_loss_and_accuracy"](naive_p, train_loader)
        test_loss, test_acc = stuff["dataset_loss_and_accuracy"](naive_p, test_loader)
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        train_accs.append(train_acc)
        test_accs.append(test_acc)
    return train_losses, test_losses, train_accs, test_accs


def main():
    """Evaluate two checkpoints and their naive parameter interpolation on AG News.

    Writes the per-lambda loss/accuracy curves to
    ``results/agnews/[<model-a name>+<model-b name>].json``.
    """
    args = _parse_args()
    (x_train, y_train), (x_test, y_test), num_classes, PAD = load_agnews_data(args.data_path)
    # The train loader used to be built from the test split, so every
    # "train_*" metric was really a test metric.
    train_loader = list(flax_make_numpy_loader(x_train, y_train, args.batch_size, pad_id=PAD))
    test_loader = list(flax_make_numpy_loader(x_test, y_test, args.batch_size, pad_id=PAD))

    params_a, config_a = load_combined_model(args.model_a)
    params_b, config_b = load_combined_model(args.model_b)
    if config_a != config_b:
        # Only model A's config builds the network; interpolation assumes both
        # checkpoints share an architecture.
        print(f"WARNING: checkpoint configs differ; using config of {args.model_a}\n  A: {config_a}\n  B: {config_b}")
    config = ModelConfig(**config_a)
    print(config)
    model = FlaxTransformerMoE(config, num_classes=num_classes)
    stuff = make_stuff(model=model, num_classes=num_classes)

    # Evaluate the two endpoint (base) models.
    train_loss_a, train_accuracy_a = stuff["dataset_loss_and_accuracy"](params_a, train_loader)
    train_loss_b, train_accuracy_b = stuff["dataset_loss_and_accuracy"](params_b, train_loader)
    test_loss_a, test_accuracy_a = stuff["dataset_loss_and_accuracy"](params_a, test_loader)
    test_loss_b, test_accuracy_b = stuff["dataset_loss_and_accuracy"](params_b, test_loader)
    print({
        "train_loss_a": float(train_loss_a),
        "train_accuracy_a": float(train_accuracy_a),
        "train_loss_b": float(train_loss_b),
        "train_accuracy_b": float(train_accuracy_b),
        "test_loss_a": float(test_loss_a),
        "test_accuracy_a": float(test_accuracy_a),
        "test_loss_b": float(test_loss_b),
        "test_accuracy_b": float(test_accuracy_b),
    })

    baseline_train_loss = 0.5 * (train_loss_a + train_loss_b)
    print(baseline_train_loss)

    # The former (unused) list of every expert permutation was removed: it
    # materialized num_experts! lists for no purpose.
    lambdas = jnp.linspace(0, 1, num=25)
    (train_loss_interp_naive, test_loss_interp_naive,
     train_acc_interp_naive, test_acc_interp_naive) = _interpolate_naive(
        stuff, params_a, params_b, lambdas, train_loader, test_loader
    )

    # Placeholders for permutation-aligned ("clever") interpolation, which is
    # not implemented here; kept so the JSON schema stays stable.
    train_loss_interp_clever_list, test_loss_interp_clever_list = [], []
    train_acc_interp_clever_list, test_acc_interp_clever_list = [], []
    results = {
        "train_loss_interp_naive": train_loss_interp_naive,
        "test_loss_interp_naive": test_loss_interp_naive,
        "train_acc_interp_naive": train_acc_interp_naive,
        "test_acc_interp_naive": test_acc_interp_naive,
        "train_loss_interp_clever_list": train_loss_interp_clever_list,
        "test_loss_interp_clever_list": test_loss_interp_clever_list,
        "train_acc_interp_clever_list": train_acc_interp_clever_list,
        "test_acc_interp_clever_list": test_acc_interp_clever_list,
        "baseline_train_loss": baseline_train_loss
    }
    results = recursive_to_serializable(results)

    out_dir = os.path.join("results", "agnews")
    os.makedirs(out_dir, exist_ok=True)
    name_a = os.path.basename(args.model_a.rstrip("/"))
    name_b = os.path.basename(args.model_b.rstrip("/"))
    out_path = os.path.join(out_dir, f"[{name_a}+{name_b}].json")
    print("Save List of Values...")
    with open(out_path, "w") as f:
        json.dump(results, f, indent=4)


if __name__ == "__main__":
    # main() parses the command-line arguments itself.
    main()