import argparse
import os
import jax
import copy
import json 
import torch
import optax
from jax import random
import itertools
import numpy as np
from jax import jit 
from tqdm import tqdm
import jax.numpy as jnp
import matplotlib.pyplot as plt
from flax.traverse_util import flatten_dict, unflatten_dict
from flax.core.frozen_dict import freeze, unfreeze
from model import  FlaxViTMoEForImageClassification, print_model, print_model_with_prefix
from datasets import build_dataset
import multiprocessing as mp
from engine import accuracy
from flax.training.train_state import TrainState
from weight_matching import apply_permutation, weight_matching, permute_moe_block
from weight_matching import flax_vit_permutation_spec_moe, plot_interp_loss, plot_interp_acc
from utils import ec2_get_instance_type, flatten_params, lerp, unflatten_params
import multiprocessing as mp
# Switch child processes to the "spawn" start method (presumably for the
# torch DataLoader workers started when --num_workers > 0). Presumably chosen
# because forking a process that has already initialised JAX is unsafe —
# confirm. force=True overrides any start method set earlier.
mp.set_start_method(method="spawn", force=True)
def to_serializable(val):
    """Convert one value into something ``json.dump`` can encode.

    JAX/NumPy arrays become (nested) Python lists via ``tolist()``. NumPy
    scalar types of any width (floats, ints, and ``bool_``) become the
    matching builtin ``float`` / ``int`` / ``bool``. Anything else is
    returned unchanged.

    Args:
        val: The value to convert.

    Returns:
        A JSON-friendly equivalent of ``val``, or ``val`` itself.
    """
    if isinstance(val, (jnp.ndarray, np.ndarray)):
        return val.tolist()
    # The abstract NumPy scalar bases cover every width (float16, int8, ...),
    # not only 32/64-bit. ``bool_`` is not a Python ``bool`` subclass, so json
    # rejects it unless converted.
    if isinstance(val, (np.floating, np.integer, np.bool_)):
        return val.item()
    return val

def recursive_to_serializable(obj):
    """Recursively convert a nested results structure into JSON-friendly types.

    Dicts are rebuilt with converted values (keys untouched). Lists and
    tuples become lists of converted items. Every other value is passed
    through ``to_serializable``.

    Args:
        obj: A value, possibly a nested dict/list/tuple of arrays and scalars.

    Returns:
        The converted structure.
    """
    if isinstance(obj, dict):
        return {k: recursive_to_serializable(v) for k, v in obj.items()}
    # json emits tuples as arrays anyway. Recursing into them here ensures
    # their elements (e.g. NumPy scalars) are converted too.
    if isinstance(obj, (list, tuple)):
        return [recursive_to_serializable(v) for v in obj]
    return to_serializable(obj)
def data_loader(args):
    """Build the evaluation DataLoader and return it twice.

    Only the validation split is built; the training split is not loaded.
    The returned pair is therefore ``(val_loader, val_loader)``, and callers
    unpacking it as ``(train, test)`` get the validation loader in both slots.

    Args:
        args: Parsed CLI namespace. ``batch_size``, ``num_workers`` and
            ``pin_mem`` are read here; the whole namespace is also handed to
            ``build_dataset``.

    Returns:
        A 2-tuple holding the same sequential validation DataLoader twice.
    """
    val_set, _ = build_dataset(is_train=False, args=args)
    in_order = torch.utils.data.SequentialSampler(val_set)
    loader = torch.utils.data.DataLoader(
        val_set,
        sampler=in_order,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,
    )
    return loader, loader
def make_stuff(model, num_classes: int):
    """Build jitted evaluation/training helpers around a Flax classifier.

    Args:
        model: Object whose ``__call__`` accepts ``params=``,
            ``pixel_values=`` and ``train=`` keywords and returns something
            with a ``.logits`` attribute.
        num_classes: Width of the one-hot targets for the cross-entropy loss.
            It must match the last dimension of ``logits``.

    Returns:
        A dict with:
          * ``"batch_eval"``: jitted ``(params, images, labels) ->
            (logits, labels, mean_loss)`` run with ``train=False``.
          * ``"step"``: jitted ``(train_state, images, labels) ->
            (new_train_state, metrics_dict)`` doing one gradient update.
          * ``"dataset_loss_and_accuracy"``: ``(params, dataloader) ->
            (avg_loss, avg_top1)``, averaged over the whole loader.
    """
    apply_fn = model.__call__

    def _mean_xent(logits, labels):
        # Mean softmax cross-entropy against integer labels (one-hot encoded).
        y_onehot = jax.nn.one_hot(labels, num_classes)
        return jnp.mean(optax.softmax_cross_entropy(logits, y_onehot))

    @jit
    def batch_eval(params, images, labels):
        logits = apply_fn(params=params, pixel_values=images, train=False).logits
        return logits, labels, _mean_xent(logits, labels)

    @jit
    def step(train_state: TrainState, images, labels):
        # NOTE(review): train=True is passed without a dropout RNG; if the
        # model applies dropout this may fail — confirm before using `step`.
        def loss_fn(params):
            logits = apply_fn(params=params, pixel_values=images, train=True).logits
            return _mean_xent(logits, labels), logits

        (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(train_state.params)
        train_state = train_state.apply_gradients(grads=grads)
        return train_state, {"batch_loss": loss, "logits": logits, "labels": labels}

    def dataset_loss_and_accuracy(params, dataloader):
        """Return ``(sample-weighted mean loss, top-1 accuracy)`` over ``dataloader``.

        The accuracy uses the same scale as ``accuracy`` (apparently percent,
        given the /100 and *100 conversions below).

        Raises:
            ValueError: If ``dataloader`` yields no samples; the averages
                would otherwise divide by zero.
        """
        total_loss, total_correct1, total_count = 0.0, 0.0, 0
        pbar = tqdm(dataloader, desc="Evaluating", leave=False)
        for images, labels in pbar:
            images = jax.device_put(np.asarray(images.numpy()))
            labels = jax.device_put(np.asarray(labels.numpy()))
            logits, labels, loss = batch_eval(params, images, labels)
            acc1, acc5 = accuracy(np.asarray(logits), np.asarray(labels), topk=(1, 5))
            batch_size = images.shape[0]
            loss_value = float(loss)
            # Weight by batch size so a smaller final batch is averaged correctly.
            total_loss += loss_value * batch_size
            # Convert the per-batch percentage into a count of correct samples.
            total_correct1 += acc1 * batch_size / 100.0
            total_count += batch_size
            pbar.set_postfix({"loss": f"{loss_value:.4f}", "acc1": f"{acc1:.2f}%", "acc5": f"{acc5:.2f}%"})
        if total_count == 0:
            raise ValueError("dataset_loss_and_accuracy: dataloader yielded no samples")
        return total_loss / total_count, total_correct1 / total_count * 100

    return {
        "batch_eval": batch_eval,
        "step": step,
        "dataset_loss_and_accuracy": dataset_loss_and_accuracy,
    }


def _parse_args():
    """Parse command-line options for the ViT-MoE weight-matching experiment."""
    parser = argparse.ArgumentParser(description="Weight matching for ViT-MoE models on ImageNet-1k")
    parser.add_argument("--model-a", type=str, required=True, help="Path to first  ViT-MoE model checkpoint")
    parser.add_argument("--model-b", type=str, required=True, help="Path to second ViT-MoE model checkpoint")
    parser.add_argument("--data-path", type=str, required=True)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility")
    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'])
    parser.add_argument("--input-size", type=int, default=224)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--pin-mem', action='store_true')
    # The options below are not read in this file; presumably consumed by
    # build_dataset (data transforms) — confirm.
    parser.add_argument('--color-jitter', type=float, default=0.4)
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1')
    parser.add_argument('--train-interpolation', type=str, default='bicubic')
    parser.add_argument('--reprob', type=float, default=0.25)
    parser.add_argument('--remode', type=str, default='pixel')
    parser.add_argument('--recount', type=int, default=1)
    return parser.parse_args()


def _interpolation_curve(stuff, params_a, params_b, lambdas, loader, desc):
    """Evaluate linear interpolations between two parameter trees.

    For each ``lam`` in ``lambdas``, ``lerp(lam, params_a, params_b)`` is
    evaluated on ``loader``.

    Returns:
        ``(losses, accuracies)``: two lists aligned with ``lambdas``.
    """
    losses, accs = [], []
    for lam in tqdm(lambdas, desc=desc):
        interp = freeze(lerp(lam, unfreeze(params_a), unfreeze(params_b)))
        loss, acc = stuff["dataset_loss_and_accuracy"](interp, loader)
        losses.append(loss)
        accs.append(acc)
    return losses, accs


def _save_figure(fig, path):
    """Write ``fig`` to ``path`` at 300 dpi, then close it to free memory."""
    fig.savefig(path, dpi=300)
    plt.close(fig)


def main():
    """Compare naive vs. permutation-aligned interpolation of two ViT-MoE models.

    Loads two checkpoints and reports each model's loss and accuracy. Then,
    for 25 evenly spaced lambdas in [0, 1], it evaluates:

    (a) plain linear interpolation of the parameters; and
    (b) for every ordering of the routed experts, interpolation after
        permuting model B's MoE block and weight-matching it to model A.

    Curves are written to ``results/imagenet/[A+B].json``, and loss/accuracy
    plots to ``plots/imagenet/``.
    """
    args = _parse_args()
    train_loader, test_loader = data_loader(args)
    # Only the validation split is loaded, so "train" metrics below are also
    # computed on the validation set.
    train_loader = test_loader

    model_a = FlaxViTMoEForImageClassification.from_pretrained(args.model_a)
    model_b = FlaxViTMoEForImageClassification.from_pretrained(args.model_b)
    params_a = model_a.params
    params_b = model_b.params
    config = copy.deepcopy(model_a.config)
    model = FlaxViTMoEForImageClassification(config)
    # NOTE(review): 1000 classes is hard-coded (ImageNet-1k); other --data-set
    # choices likely need a different value.
    stuff = make_stuff(model=model, num_classes=1000)
    evaluate = stuff["dataset_loss_and_accuracy"]

    train_loss_a, train_accuracy_a = evaluate(params_a, train_loader)
    train_loss_b, train_accuracy_b = evaluate(params_b, train_loader)
    if test_loader is train_loader:
        # Same loader and eval-mode forward pass: reuse rather than re-run a
        # full pass over the dataset.
        test_loss_a, test_accuracy_a = train_loss_a, train_accuracy_a
        test_loss_b, test_accuracy_b = train_loss_b, train_accuracy_b
    else:
        test_loss_a, test_accuracy_a = evaluate(params_a, test_loader)
        test_loss_b, test_accuracy_b = evaluate(params_b, test_loader)
    print({
        "train_loss_a": float(train_loss_a), "train_accuracy_a": float(train_accuracy_a),
        "train_loss_b": float(train_loss_b), "train_accuracy_b": float(train_accuracy_b),
        "test_loss_a": float(test_loss_a), "test_accuracy_a": float(test_accuracy_a),
        "test_loss_b": float(test_loss_b), "test_accuracy_b": float(test_accuracy_b),
    })
    baseline_train_loss = 0.5 * (train_loss_a + train_loss_b)
    print(baseline_train_loss)

    permutation_spec = flax_vit_permutation_spec_moe(config=config)
    # Every ordering of the routed experts; this grows factorially with
    # config.num_routed_experts.
    expert_perms = [list(p) for p in itertools.permutations(range(config.num_routed_experts))]
    perm_labels = [str(p) for p in expert_perms]  # e.g. "[0, 1]", "[1, 0]"
    lambdas = jnp.linspace(0, 1, num=25)

    # Naive interpolation (no permutation). The test split is not evaluated
    # separately (see above), so the "test" curves mirror the train curves.
    train_loss_interp_naive, train_acc_interp_naive = _interpolation_curve(
        stuff, params_a, params_b, lambdas, train_loader, "Naive Interpolation"
    )
    test_loss_interp_naive = list(train_loss_interp_naive)
    test_acc_interp_naive = list(train_acc_interp_naive)

    train_loss_interp_clever_list, test_loss_interp_clever_list = [], []
    train_acc_interp_clever_list, test_acc_interp_clever_list = [], []
    for pi, label in zip(expert_perms, perm_labels):
        # Presumably re-orders B's routed experts by `pi`; then weight matching
        # aligns the permuted B to A.
        model_b_pi = permute_moe_block(params_b, pi, config)
        final_permutation = weight_matching(
            random.PRNGKey(args.seed), permutation_spec,
            flatten_params(params_a), flatten_params(model_b_pi),
        )
        model_b_pi_aligned = unflatten_params(
            apply_permutation(permutation_spec, final_permutation, flatten_params(model_b_pi))
        )
        train_loss_interp, train_acc_interp = _interpolation_curve(
            stuff, params_a, model_b_pi_aligned, lambdas, train_loader,
            f"Permuted Interpolation {label}",
        )
        train_loss_interp_clever_list.append(train_loss_interp)
        test_loss_interp_clever_list.append(list(train_loss_interp))
        train_acc_interp_clever_list.append(train_acc_interp)
        test_acc_interp_clever_list.append(list(train_acc_interp))

    # Sanity checks: one value per lambda on every curve.
    assert len(lambdas) == len(train_loss_interp_naive) == len(test_loss_interp_naive)
    assert all(len(lambdas) == len(tl) for tl in train_loss_interp_clever_list)
    assert all(len(lambdas) == len(tl) for tl in test_loss_interp_clever_list)

    results = recursive_to_serializable({
        "train_loss_interp_naive": train_loss_interp_naive,
        "test_loss_interp_naive": test_loss_interp_naive,
        "train_acc_interp_naive": train_acc_interp_naive,
        "test_acc_interp_naive": test_acc_interp_naive,
        "train_loss_interp_clever_list": train_loss_interp_clever_list,
        "test_loss_interp_clever_list": test_loss_interp_clever_list,
        "train_acc_interp_clever_list": train_acc_interp_clever_list,
        "test_acc_interp_clever_list": test_acc_interp_clever_list,
        "baseline_train_loss": baseline_train_loss,
    })

    # Create the directories the outputs below are written into.
    os.makedirs("./plots/imagenet", exist_ok=True)
    os.makedirs("./results/imagenet", exist_ok=True)

    print("Save List of Values...")
    name_a = os.path.basename(args.model_a.rstrip("/"))
    name_b = os.path.basename(args.model_b.rstrip("/"))
    with open(f'results/imagenet/[{name_a}+{name_b}].json', 'w', encoding="utf-8") as f:
        json.dump(results, f, indent=4)

    print("Generating plots...")
    loss_fig = plot_interp_loss(
        lambdas,
        train_loss_interp_naive, test_loss_interp_naive,
        train_loss_interp_clever_list, test_loss_interp_clever_list,
        perm_labels,
    )
    _save_figure(loss_fig, f"./plots/imagenet/[{name_a}+{name_b}]_weight_matching_interp_loss.png")
    acc_fig = plot_interp_acc(
        lambdas,
        train_acc_interp_naive, test_acc_interp_naive,
        train_acc_interp_clever_list, test_acc_interp_clever_list,
        perm_labels,
    )
    _save_figure(acc_fig, f"./plots/imagenet/[{name_a}+{name_b}]_weight_matching_interp_accuracy.png")
if __name__ == "__main__":
    # Script entry point; argument parsing happens inside main().
    main()
