import os 
import torch
import argparse
import copy
import pickle
from pathlib import Path
import itertools
import jax.numpy as jnp
import matplotlib.pyplot as plt
from flax.serialization import from_bytes
from jax import random, vmap
from tqdm import tqdm
from scipy.optimize import linear_sum_assignment
import numpy as np
import json 
import matplotlib.pyplot as plt
from flax.traverse_util import flatten_dict, unflatten_dict
from flax.core.frozen_dict import freeze, unfreeze
from model import  FlaxViTMoEForImageClassification, print_model, print_model_with_prefix
from weight_matching_moe import make_stuff
from datasets import build_dataset
import multiprocessing as mp
from engine import accuracy
from flax.training.train_state import TrainState
from weight_matching import apply_permutation, weight_matching, permute_moe_block
from weight_matching import flax_vit_permutation_spec_moe, plot_interp_loss, plot_interp_acc
from utils import ec2_get_instance_type, flatten_params, lerp, unflatten_params
# Select the "spawn" multiprocessing start method at import time; force=True
# replaces any start method that was already set.
# NOTE(review): presumably needed so DataLoader worker processes are not
# forked from a process that has already initialised JAX — confirm.
mp.set_start_method("spawn", force=True)

def data_loader(args):
    """Build the training and validation ``DataLoader`` objects.

    Both datasets come from ``build_dataset``. As a side effect the number of
    classes reported for the training split is stored on ``args.nb_classes``.

    Args:
        args: Namespace providing ``batch_size``, ``num_workers`` and
            ``pin_mem`` (plus whatever ``build_dataset`` reads from it).

    Returns:
        ``(train_loader, val_loader)``. The training loader draws samples
        through a ``RandomSampler``, the validation loader through a
        ``SequentialSampler``; neither drops the last partial batch.
    """
    train_set, args.nb_classes = build_dataset(is_train=True, args=args)
    val_set, _ = build_dataset(is_train=False, args=args)

    train_sampler = torch.utils.data.RandomSampler(train_set)
    val_sampler = torch.utils.data.SequentialSampler(val_set)

    # Settings shared by both loaders.
    shared_kwargs = {
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
        "pin_memory": args.pin_mem,
        "drop_last": False,
    }

    train_dl = torch.utils.data.DataLoader(
        train_set, sampler=train_sampler, **shared_kwargs
    )
    val_dl = torch.utils.data.DataLoader(
        val_set, sampler=val_sampler, **shared_kwargs
    )
    return train_dl, val_dl
def to_serializable(val):
    """Convert a NumPy/JAX value into a plain Python object for ``json.dump``.

    Args:
        val: Any object.

    Returns:
        * a (possibly nested) Python list, or a Python scalar for 0-d arrays,
          when ``val`` is a JAX or NumPy array (``tolist()``);
        * the equivalent Python scalar when ``val`` is any NumPy scalar
          (``item()``) — this covers every width of float/int as well as
          ``np.bool_``, not only the 32/64-bit float and int types;
        * ``val`` itself otherwise.
    """
    if isinstance(val, (jnp.ndarray, np.ndarray)):
        return val.tolist()
    # np.generic is the common base class of all NumPy scalar types, so this
    # also catches float16, int8/16, unsigned ints and bool_, which json
    # cannot encode on its own.
    if isinstance(val, np.generic):
        return val.item()
    return val

# Recursively convert the results
def recursive_to_serializable(obj):
    """Walk nested dicts and lists, converting every leaf with ``to_serializable``.

    Dict keys are kept as they are; only values are converted. Containers
    other than ``dict`` and ``list`` are treated as leaves and handed to
    ``to_serializable`` directly.
    """
    if isinstance(obj, list):
        return [recursive_to_serializable(item) for item in obj]
    if isinstance(obj, dict):
        converted = {}
        for key, value in obj.items():
            converted[key] = recursive_to_serializable(value)
        return converted
    return to_serializable(obj)

def compute_weights_cost_matrices(model, params_a, params_b, config):
    """
    Compute expert-matching cost matrices for two ViT-MoE parameter sets.

    Both matrices are *distances* (lower means more similar), so either can be
    passed directly to ``linear_sum_assignment`` as a cost.

    Args:
        model: Not used by this function; kept for interface compatibility.
        params_a, params_b: Nested parameter mappings indexed as
            ``params['vit']['encoder']['layer'][str(config.moe_idx)]['moe_block']``.
            The MoE block must contain ``gate`` plus
            ``routed_intermediates_{i}`` / ``routed_outputs_{i}`` entries, each
            holding ``kernel`` and ``bias`` (the latter two under ``dense``).
        config: Object with ``moe_idx`` and ``num_routed_experts``.

    Returns:
        (D, S) — in this order:
        D: [num_experts, num_experts] array. ``D[i, j]`` is the square root of
           the summed squared Frobenius distances between the Gram matrices of
           expert ``i`` of A and expert ``j`` of B (bias-augmented
           intermediate and output weights).
        S: Pairwise Euclidean distance between the centred gating vectors of
           A (rows) and B (columns); one vector per column of the gate kernel.
    """
    moe_idx = str(config.moe_idx)
    num_experts = config.num_routed_experts

    moe_block_a = params_a['vit']['encoder']['layer'][moe_idx]['moe_block']
    moe_block_b = params_b['vit']['encoder']['layer'][moe_idx]['moe_block']

    # ----- Gating distance (S) -----
    gating_vectors_a = _centered_gating_vectors(moe_block_a)
    gating_vectors_b = _centered_gating_vectors(moe_block_b)
    diff = gating_vectors_a[:, None, :] - gating_vectors_b[None, :, :]
    S = np.sqrt(np.sum(diff ** 2, axis=-1))  # Euclidean distance

    # ----- Expert weight distance (D) -----
    # B's Gram matrices do not depend on the expert of A being compared, so
    # build them once instead of once per (i, j) pair. This keeps all of B's
    # Gram matrices in memory at the same time.
    grams_b = [_expert_gram_matrices(moe_block_b, j) for j in range(num_experts)]

    D = np.zeros((num_experts, num_experts))
    for i in range(num_experts):
        gram1_a, gram2_a = _expert_gram_matrices(moe_block_a, i)
        for j, (gram1_b, gram2_b) in enumerate(grams_b):
            D[i, j] = np.sqrt(
                np.linalg.norm(gram1_a - gram1_b, 'fro') ** 2
                + np.linalg.norm(gram2_a - gram2_b, 'fro') ** 2
            )

    return D, S


def _centered_gating_vectors(moe_block):
    """Return one centred gating vector per gate output (kernel column + bias)."""
    kernel = np.asarray(moe_block['gate']['kernel'])
    bias = np.asarray(moe_block['gate']['bias'])
    # Each kernel column is centred by its own mean; the bias by the mean over
    # all of its entries.
    centered_kernel = kernel - np.mean(kernel, axis=0, keepdims=True)
    centered_bias = bias - np.mean(bias)
    return np.hstack([centered_kernel.T, centered_bias[:, None]])


def _expert_gram_matrices(moe_block, index):
    """Return the Gram matrices of expert ``index``'s bias-augmented weights."""
    inter = moe_block[f"routed_intermediates_{index}"]["dense"]
    out = moe_block[f"routed_outputs_{index}"]["dense"]

    # Append the bias as an extra row so it takes part in the comparison.
    w1_tilde = np.vstack([np.asarray(inter["kernel"]), np.asarray(inter["bias"])[None, :]])
    w2_tilde = np.vstack([np.asarray(out["kernel"]), np.asarray(out["bias"])[None, :]])

    return w1_tilde.T @ w1_tilde, w2_tilde @ w2_tilde.T


def select_permutations(D_score, S_score, config, alphas):
    """
    Select expert permutations from a distance matrix, a similarity matrix and
    weighted combinations of the two.

    ``D_score`` is minimised while ``S_score`` is negated before assignment,
    i.e. it is treated as a similarity where larger means a better match. Do
    not pass a distance-valued matrix as ``S_score``.

    Args:
        D_score: [n, n] array-like of distances (lower is better).
        S_score: [n, n] array-like of similarities (higher is better).
        config: Not used by this function; kept for interface compatibility.
        alphas: Iterable of weights; each ``alpha`` yields the cost
            ``alpha * D_std - (1 - alpha) * S_std`` on the standardised
            matrices.

    Returns:
        perm_D: np.ndarray, column assignment minimising ``D_score``.
        perm_S: np.ndarray, column assignment maximising ``S_score``.
        perm_hybrids: list of np.ndarray, one assignment per entry of ``alphas``.
    """
    # Accept plain nested lists (or JAX arrays) as well as NumPy arrays.
    D_score = np.asarray(D_score)
    S_score = np.asarray(S_score)

    # Permutation minimising the distance.
    _, perm_D = linear_sum_assignment(D_score)

    # Permutation maximising the similarity (negate to minimise).
    _, perm_S = linear_sum_assignment(-S_score)

    # Standardise both matrices so they are on a comparable scale; the small
    # epsilon guards against division by zero for constant matrices.
    D_std = (D_score - D_score.mean()) / (D_score.std() + 1e-8)
    S_std = (S_score - S_score.mean()) / (S_score.std() + 1e-8)

    perm_hybrids = [
        np.array(linear_sum_assignment(alpha * D_std - (1 - alpha) * S_std)[1])
        for alpha in alphas
    ]

    return perm_D, perm_S, perm_hybrids

def main():
    """Run the expert-matching interpolation experiment for two ViT-MoE checkpoints.

    Pipeline:
      1. Parse the command line, load both checkpoints and build the loaders.
      2. Evaluate naive linear interpolation between the two parameter sets.
      3. Compute weight-based cost matrices, derive one expert permutation from
         each via Hungarian assignment, and for every distinct permutation:
         permute model B's MoE block, run weight matching against model A,
         apply the resulting permutation and evaluate the interpolation again.
      4. Write all curves to ``results/imagenet/[<a>+<b>].json`` and save the
         loss / accuracy plots under ``./plots/imagenet``.

    NOTE: every metric is measured on the validation loader only. The
    ``train_*`` and ``test_*`` entries of the results therefore hold identical
    values; both sets of keys are kept so the output format and the plotting
    helpers' inputs stay unchanged.
    """
    parser = argparse.ArgumentParser(description="Expert matching for ViT-MoE models on ImagNet")
    parser.add_argument("--model-a", type=str, required=True, help="Path to first  ViT-MoE model checkpoint")
    parser.add_argument("--model-b", type=str, required=True, help="Path to second ViT-MoE model checkpoint")
    parser.add_argument("--data-path", type=str, required=True)
    parser.add_argument("--batch-size", type=int, default = 64)
    parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility")
    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'])
    parser.add_argument("--input-size", type=int, default=224)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--pin-mem', action='store_true')
    parser.add_argument('--color-jitter', type=float, default=0.4)
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1')
    parser.add_argument('--train-interpolation', type=str, default='bicubic')
    parser.add_argument('--reprob', type=float, default=0.25)
    parser.add_argument('--remode', type=str, default='pixel')
    parser.add_argument('--recount', type=int, default=1)
    # NOTE(review): --plot-path is accepted but not consulted anywhere below;
    # outputs always go to ./plots/imagenet and ./results/imagenet.
    parser.add_argument("--plot-path", type=str, default="/", help="Path to plot directory")
    args = parser.parse_args()

    model_a = FlaxViTMoEForImageClassification.from_pretrained(args.model_a)
    model_b = FlaxViTMoEForImageClassification.from_pretrained(args.model_b)
    params_a = model_a.params
    params_b = model_b.params
    config = copy.deepcopy(model_a.config)
    model = FlaxViTMoEForImageClassification(config)
    stuff = make_stuff(model=model, num_classes=1000)
    evaluate = stuff["dataset_loss_and_accuracy"]

    # The training loader is built (data_loader also sets args.nb_classes) but
    # only the validation loader is used for evaluation.
    _, eval_loader = data_loader(args)

    permutation_spec = flax_vit_permutation_spec_moe(config=config)
    lambdas = jnp.linspace(0, 1, num=25)

    # ----- Naive interpolation (no alignment) -----
    train_loss_interp_naive, train_acc_interp_naive = _interpolation_sweep(
        evaluate, params_a, params_b, lambdas, eval_loader, "Naive Interpolation"
    )
    test_loss_interp_naive = list(train_loss_interp_naive)
    test_acc_interp_naive = list(train_acc_interp_naive)

    # ----- Expert permutations from the weight-based cost matrices -----
    # Both matrices are used as costs, i.e. minimised by the assignment.
    D_weight, S_weight = compute_weights_cost_matrices(model, params_a, params_b, config)
    _, perm_expert_weight = linear_sum_assignment(D_weight)
    _, perm_gating_weight = linear_sum_assignment(S_weight)

    # Keep each distinct permutation once, under the first name that produced it.
    seen_perms = set()
    selected_perms = {}
    candidates = (
        ("expert_weight_matching", perm_expert_weight),
        ("gating_weight_matching", perm_gating_weight),
    )
    for name, perm in candidates:
        perm_tuple = tuple(perm)
        if perm_tuple not in seen_perms:
            selected_perms[name] = perm_tuple
            seen_perms.add(perm_tuple)

    # ----- Interpolation after expert permutation + weight matching -----
    train_loss_interp_clever_list, test_loss_interp_clever_list = [], []
    train_acc_interp_clever_list, test_acc_interp_clever_list = [], []
    for label, pi in selected_perms.items():
        params_b_pi = permute_moe_block(params_b, list(pi), config)

        # Align the remaining permutable units of B_pi to A.
        final_permutation = weight_matching(
            random.PRNGKey(args.seed),
            permutation_spec,
            flatten_params(params_a),
            flatten_params(params_b_pi),
        )
        params_b_pi_aligned = unflatten_params(
            apply_permutation(permutation_spec, final_permutation, flatten_params(params_b_pi))
        )

        losses, accs = _interpolation_sweep(
            evaluate, params_a, params_b_pi_aligned, lambdas, eval_loader,
            f"Permuted Interpolation {label}",
        )
        train_loss_interp_clever_list.append(losses)
        test_loss_interp_clever_list.append(list(losses))
        train_acc_interp_clever_list.append(accs)
        test_acc_interp_clever_list.append(list(accs))

    # Convert to serializable format
    results = {
        "train_loss_interp_naive": train_loss_interp_naive,
        "test_loss_interp_naive": test_loss_interp_naive,
        "train_acc_interp_naive": train_acc_interp_naive,
        "test_acc_interp_naive": test_acc_interp_naive,
        "train_loss_interp_clever_list": train_loss_interp_clever_list,
        "test_loss_interp_clever_list": test_loss_interp_clever_list,
        "train_acc_interp_clever_list": train_acc_interp_clever_list,
        "test_acc_interp_clever_list": test_acc_interp_clever_list,
        "selected_perms": {k: list(v) for k, v in selected_perms.items()}
    }
    results = recursive_to_serializable(results)

    # Sanity-check that every curve has one point per lambda.
    assert len(lambdas) == len(train_loss_interp_naive) == len(test_loss_interp_naive)
    assert all(len(lambdas) == len(tl) for tl in train_loss_interp_clever_list)
    assert all(len(lambdas) == len(tl) for tl in test_loss_interp_clever_list)

    # Save directories
    os.makedirs("./plots/imagenet", exist_ok=True)
    os.makedirs("./results/imagenet", exist_ok=True)

    # Save results JSON
    print("Save List of Values...")
    name_a = os.path.basename(args.model_a.rstrip("/"))
    name_b = os.path.basename(args.model_b.rstrip("/"))
    result_path = f'results/imagenet/[{name_a}+{name_b}].json'
    with open(result_path, 'w') as f:
        json.dump(results, f, indent=4)

    # Plot
    print("Generating plots...")
    perm_labels_selected = list(selected_perms.keys())

    loss_fig = plot_interp_loss(
        lambdas,
        train_loss_interp_naive, test_loss_interp_naive,
        train_loss_interp_clever_list, test_loss_interp_clever_list,
        perm_labels_selected
    )
    loss_fig_path = f"./plots/imagenet/[{name_a}+{name_b}]_weight_matching_interp_loss.png"
    plt.savefig(loss_fig_path, dpi=300)
    plt.close(loss_fig)

    acc_fig = plot_interp_acc(
        lambdas,
        train_acc_interp_naive, test_acc_interp_naive,
        train_acc_interp_clever_list, test_acc_interp_clever_list,
        perm_labels_selected
    )
    acc_fig_path = f"./plots/imagenet/[{name_a}+{name_b}]_weight_matching_interp_acc.png"
    plt.savefig(acc_fig_path, dpi=300)
    plt.close(acc_fig)


def _interpolation_sweep(evaluate, params_start, params_end, lambdas, loader, desc):
    """Evaluate ``lerp(lam, params_start, params_end)`` for every ``lam``.

    Args:
        evaluate: Callable ``(params, loader) -> (loss, accuracy)``.
        params_start, params_end: Parameter trees at the two interpolation ends.
        lambdas: Iterable of interpolation coefficients.
        loader: Data loader handed to ``evaluate``.
        desc: Label shown on the progress bar.

    Returns:
        ``(losses, accuracies)``: two lists with one entry per coefficient.
    """
    losses, accuracies = [], []
    for lam in tqdm(lambdas, desc=desc):
        interpolated = freeze(lerp(lam, unfreeze(params_start), unfreeze(params_end)))
        loss, acc = evaluate(interpolated, loader)
        losses.append(loss)
        accuracies.append(acc)
    return losses, accuracies


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()