from flax.core import freeze, unfreeze
import argparse
import jax.numpy as jnp
from jax import random, vmap
from tqdm import tqdm
from scipy.optimize import linear_sum_assignment
import numpy as np
from flax.serialization import from_bytes

from .cifar10_vit_finetune_moe import ViTModelMoE  
from .cifar10_vit_train import make_stuff
from src.datasets import load_cifar10
from src.utils import flatten_params, lerp, unflatten_params, compute_weights_cost_matrices
from src.weight_matching import apply_permutation, weight_matching, PermutationSpec, vit_permutation_spec_moe, permute_moe_block

def _interpolation_curves(stuff, params_a, params_b, lambdas, train_ds, test_ds, batch_size, desc):
    """Evaluate loss/accuracy along the linear path between two parameter trees.

    For every ``lam`` in ``lambdas`` the parameters
    ``lerp(lam, params_a, params_b)`` are evaluated on the train split and then
    on the test split via ``stuff["dataset_loss_and_accuracy"]``.

    Args:
        stuff: Dict produced by ``make_stuff``; only its
            ``"dataset_loss_and_accuracy"`` entry is used.
        params_a: Parameter tree of the first endpoint model.
        params_b: Parameter tree of the second endpoint model.
        lambdas: Iterable of interpolation coefficients.
        train_ds: Training split, passed through to the evaluator.
        test_ds: Test split, passed through to the evaluator.
        batch_size: Evaluation batch size, passed through to the evaluator.
        desc: Label shown on the tqdm progress bar.

    Returns:
        Dict with keys ``"train_loss"``, ``"test_loss"``, ``"train_acc"`` and
        ``"test_acc"``; each value is a list with one entry per lambda.
    """
    evaluate = stuff["dataset_loss_and_accuracy"]
    curves = {"train_loss": [], "test_loss": [], "train_acc": [], "test_acc": []}
    for lam in tqdm(lambdas, desc=desc):
        # Unfreeze fresh copies each step so lerp never sees a tree that a
        # previous iteration might have touched.
        p_interp = freeze(lerp(lam, unfreeze(params_a), unfreeze(params_b)))
        train_loss, train_acc = evaluate(p_interp, train_ds, batch_size)
        test_loss, test_acc = evaluate(p_interp, test_ds, batch_size)
        curves["train_loss"].append(train_loss)
        curves["test_loss"].append(test_loss)
        curves["train_acc"].append(train_acc)
        curves["test_acc"].append(test_acc)
    return curves


def _print_curves(title, curves):
    """Print a titled report of interpolation curves, values rounded to 4 decimals.

    Args:
        title: Heading printed as ``"<title>:"``.
        curves: Dict as returned by ``_interpolation_curves``.
    """
    print(f"{title}:")
    for label, key in (
        ("Train Loss", "train_loss"),
        ("Test Loss", "test_loss"),
        ("Train Acc", "train_acc"),
        ("Test Acc", "test_acc"),
    ):
        print(f"  {label}:", [float(f"{x:.4f}") for x in curves[key]])


def main():
    """Compare naive and expert-matched interpolation between two ViT-MoE models.

    Steps:
      1. Load both checkpoints.
      2. Evaluate loss/accuracy along the straight line between them.
      3. For each expert-matching strategy: permute model B's MoE block with
         the assignment derived from ``compute_weights_cost_matrices``, align
         the remaining weights with ``weight_matching``, and evaluate the
         interpolation again.

    All results are printed to stdout.
    """
    parser = argparse.ArgumentParser(description="Expert matching for ViT-MoE models on CIFAR-10")
    parser.add_argument("--model-a", type=str, required=True, help="Path to first ViT-MoE model checkpoint")
    parser.add_argument("--model-b", type=str, required=True, help="Path to second ViT-MoE model checkpoint")
    parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility")
    parser.add_argument("--num-layers", type=int, default=1, help="Number of transformer layers")
    parser.add_argument("--num-experts", type=int, default=2, help="Number of experts in MoE block")
    parser.add_argument("--batch-size", type=int, default=5000, help="Evaluation batch size")
    args = parser.parse_args()

    model = ViTModelMoE(num_layers=args.num_layers, num_experts=args.num_experts)
    stuff = make_stuff(model)

    # Initialise once and reuse as the deserialisation template for both
    # checkpoints; from_bytes uses the target only for its structure.
    template = model.init(random.PRNGKey(0), jnp.zeros((1, 32, 32, 3)))["params"]

    def load_model(filepath):
        with open(filepath, "rb") as fh:
            return from_bytes(template, fh.read())

    model_a = load_model(args.model_a)
    model_b = load_model(args.model_b)

    train_ds, test_ds = load_cifar10()
    permutation_spec = vit_permutation_spec_moe(num_experts=args.num_experts)
    lambdas = jnp.linspace(0, 1, num=25)

    def curves_towards(params_b, desc):
        # Interpolate from model A towards ``params_b`` on both splits.
        return _interpolation_curves(
            stuff, model_a, params_b, lambdas, train_ds, test_ds, args.batch_size, desc
        )

    # Baseline: interpolate directly between the unaligned models.
    _print_curves("Naive Interpolation", curves_towards(model_b, "Naive Interpolation"))

    # Compute weight-based cost matrices
    D_weight, S_weight = compute_weights_cost_matrices(model, model_a, model_b)

    # Select expert permutations. linear_sum_assignment *minimises* total cost.
    # NOTE(review): if S_weight is a similarity (not a cost) matrix, this call
    # should pass maximize=True — confirm against compute_weights_cost_matrices.
    _, perm_expert_weight = linear_sum_assignment(D_weight)
    _, perm_gating_weight = linear_sum_assignment(S_weight)

    expert_permutations = {
        "Expert Weight Matching": perm_expert_weight,
        "Gating Weight Matching": perm_gating_weight,
    }

    for method, pi in expert_permutations.items():
        # Reorder model B's experts first, then align the rest of the network.
        model_b_pi = permute_moe_block(model_b, pi, args.num_layers, args.num_experts)
        final_permutation = weight_matching(
            random.PRNGKey(args.seed),
            permutation_spec,
            flatten_params(model_a),
            flatten_params(model_b_pi),
        )
        model_b_pi_aligned = unflatten_params(
            apply_permutation(permutation_spec, final_permutation, flatten_params(model_b_pi))
        )
        _print_curves(method, curves_towards(model_b_pi_aligned, f"{method} Interpolation"))


if __name__ == "__main__":
    main()