import argparse
import os
import jax
import copy
import json 
import torch
import optax
from jax import random
import itertools
import numpy as np
from jax import jit 
from tqdm import tqdm
import jax.numpy as jnp
import matplotlib.pyplot as plt
from flax.core.frozen_dict import freeze, unfreeze
from transformers import FlaxViTForImageClassification
from model import  FlaxViTMoEForImageClassification, print_model, print_model_with_prefix
from datasets import build_dataset
import multiprocessing as mp
from engine import accuracy
from flax.training.train_state import TrainState
import multiprocessing as mp
# Force the "spawn" start method for multiprocessing, overriding any start
# method that was already selected (force=True).
# NOTE(review): presumably needed so DataLoader worker processes are not
# forked from a parent that has already initialised JAX -- confirm.
mp.set_start_method("spawn", force=True)

def to_serializable(val):
    """Convert an array or NumPy scalar into a plain Python value.

    Args:
        val: Any object.

    Returns:
        ``val.tolist()`` for ``jnp.ndarray`` / ``np.ndarray`` inputs; a
        built-in ``bool``, ``int`` or ``float`` for NumPy boolean, integer
        and floating scalars of any width; ``val`` itself otherwise.
    """
    if isinstance(val, (jnp.ndarray, np.ndarray)):
        return val.tolist()
    # Match on the abstract NumPy scalar families rather than on specific
    # widths, so float16, int8, uint8, ... are converted as well.
    if isinstance(val, np.bool_):
        return bool(val)
    if isinstance(val, np.integer):
        return int(val)
    if isinstance(val, np.floating):
        return float(val)
    return val

# Recursively convert the results
def recursive_to_serializable(obj):
    """Apply ``to_serializable`` to every leaf of a nested container.

    Dicts, lists and tuples are traversed recursively and rebuilt with the
    same container type (dict keys are left untouched); every other object
    is passed to ``to_serializable``.

    Args:
        obj: Arbitrarily nested dicts/lists/tuples of values.

    Returns:
        A new structure with the same nesting and converted leaves.
    """
    if isinstance(obj, dict):
        return {k: recursive_to_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        converted = [recursive_to_serializable(v) for v in obj]
        # Tuples used to be treated as leaves, which left any NumPy/JAX
        # values inside them unconverted; recurse and keep the tuple type.
        return converted if isinstance(obj, list) else tuple(converted)
    return to_serializable(obj)
def data_loader(args):
    """Build the training and validation ``DataLoader`` objects.

    ``build_dataset`` is called once per split.  The second value it returns
    for the training split is stored on ``args.nb_classes``; the one returned
    for the validation split is discarded.

    Args:
        args: Namespace providing ``batch_size``, ``num_workers`` and
            ``pin_mem``; it is also forwarded to ``build_dataset`` and is
            mutated (``nb_classes`` is set).

    Returns:
        ``(train_loader, val_loader)``: the training loader draws through a
        ``RandomSampler``, the validation loader through a
        ``SequentialSampler``.
    """
    train_set, num_classes = build_dataset(is_train=True, args=args)
    args.nb_classes = num_classes
    val_set, _ = build_dataset(is_train=False, args=args)

    train_sampler = torch.utils.data.RandomSampler(train_set)
    val_sampler = torch.utils.data.SequentialSampler(val_set)

    # Options shared by both loaders; incomplete final batches are kept.
    common = dict(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,
    )
    train_loader = torch.utils.data.DataLoader(train_set, sampler=train_sampler, **common)
    val_loader = torch.utils.data.DataLoader(val_set, sampler=val_sampler, **common)
    return train_loader, val_loader
def make_stuff(model, num_classes: int):
    """Build evaluation and training helpers bound to ``model``.

    Args:
        model: Object whose ``__call__`` accepts ``params``, ``pixel_values``
            and ``train`` keyword arguments and returns an object exposing a
            ``logits`` attribute.
        num_classes: Number of classes used to one-hot encode the labels.

    Returns:
        A dict with three callables:

        * ``"batch_eval"(params, images, labels)`` -- jitted; returns
          ``(logits, labels, loss)`` with the model called with
          ``train=False``.
        * ``"step"(train_state, images, labels)`` -- jitted; computes the
          loss gradient w.r.t. ``train_state.params`` (model called with
          ``train=True``), applies it with ``apply_gradients`` and returns
          ``(new_train_state, {"batch_loss", "logits", "labels"})``.
        * ``"dataset_loss_and_accuracy"(params, dataloader)`` -- iterates the
          dataloader and returns ``(avg_loss, avg_acc1)``, both weighted by
          batch size.  Raises ``ZeroDivisionError`` if the dataloader yields
          no batches.
    """
    apply_fn = model.__call__

    def mean_cross_entropy(logits, labels):
        # Mean softmax cross-entropy against one-hot encoded labels; shared
        # by the evaluation and the training path.
        y_onehot = jax.nn.one_hot(labels, num_classes)
        return jnp.mean(optax.softmax_cross_entropy(logits, y_onehot))

    @jit
    def batch_eval(params, images, labels):
        logits = apply_fn(params=params, pixel_values=images, train=False).logits
        return logits, labels, mean_cross_entropy(logits, labels)

    @jit
    def step(train_state: TrainState, images, labels):
        def loss_fn(params):
            logits = apply_fn(params=params, pixel_values=images, train=True).logits
            # Logits are returned as auxiliary output of the loss function.
            return mean_cross_entropy(logits, labels), logits

        (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(train_state.params)
        train_state = train_state.apply_gradients(grads=grads)
        return train_state, {"batch_loss": loss, "logits": logits, "labels": labels}

    def dataset_loss_and_accuracy(params, dataloader):
        # Only loss and top-1 accuracy are returned, so only those are
        # accumulated; top-5 is still shown on the progress bar.
        total_loss, total_acc1, total_count = 0.0, 0.0, 0
        pbar = tqdm(dataloader, desc="Evaluating", leave=False)
        for images, labels in pbar:
            # Batches are expected to expose ``.numpy()``; move them to the
            # default JAX device before running the jitted evaluation.
            images = jax.device_put(images.numpy())
            labels = jax.device_put(labels.numpy())
            logits, labels, loss = batch_eval(params, images, labels)
            acc1, acc5 = accuracy(np.asarray(logits), np.asarray(labels), topk=(1, 5))
            batch_size = images.shape[0]
            total_loss += float(loss) * batch_size
            # acc1 is treated as a percentage: store it as a fraction here
            # and scale back to percent after averaging.
            total_acc1 += acc1 * batch_size / 100.0
            total_count += batch_size
            pbar.set_postfix({"loss": f"{loss:.4f}","acc1": f"{acc1:.2f}%","acc5": f"{acc5:.2f}%"})
        avg_loss = total_loss / total_count
        avg_acc1 = total_acc1 / total_count * 100
        return avg_loss, avg_acc1

    return {"batch_eval": batch_eval,"step": step,"dataset_loss_and_accuracy": dataset_loss_and_accuracy,}
def get_trainable_mask(params, config, train_classifier=False):
    """
    Label every leaf of ``params`` as ``"trainable"`` or ``"frozen"``.

    A leaf is ``"trainable"`` when its path starts with
    ``vit/encoder/layer/<config.moe_idx>/moe_block/<member>`` where
    ``<member>`` is ``gate`` or begins with ``expert_intermediates_`` or
    ``expert_outputs_``.  With ``train_classifier=True`` leaves under the
    top-level ``classifier`` entry are ``"trainable"`` too.  Everything else
    is ``"frozen"``.  The returned tree has the same structure as ``params``.
    """
    expert_prefixes = ("expert_intermediates_", "expert_outputs_")

    def _label(path, _leaf):
        names = [str(entry.key) for entry in path]
        # An MoE parameter path has at least six components.
        if len(names) >= 6:
            root, encoder, layer, index, block, member = names[:6]
            if (
                root == "vit"
                and encoder == "encoder"
                and layer == "layer"
                and index == str(config.moe_idx)
                and block == "moe_block"
                and (member == "gate" or member.startswith(expert_prefixes))
            ):
                return "trainable"
        if train_classifier and names[0] == "classifier":
            return "trainable"
        return "frozen"

    return jax.tree_util.tree_map_with_path(_label, params)
def pretrained2finetune_parmas(pretrained_params, finetune_params, config):
    """Copy pretrained ViT weights into the parameter tree of the MoE model.

    Copied from ``pretrained_params``:

    * ``vit/embeddings``, ``vit/layernorm`` and ``classifier``;
    * every encoder layer whose index differs from ``config.moe_idx``, in
      full;
    * for the layer at ``config.moe_idx``, only ``layernorm_before``,
      ``layernorm_after`` and ``attention`` (taken from the pretrained layer
      with the same index).  All other entries of that layer keep the values
      they have in ``finetune_params``.

    Args:
        pretrained_params: Parameter tree of the pretrained model.
        finetune_params: Parameter tree of the MoE model.
        config: Provides ``num_hidden_layers`` and ``moe_idx``.

    Returns:
        The updated fine-tune parameter tree, passed through ``freeze``.
    """
    source = unfreeze(pretrained_params)
    target = unfreeze(finetune_params)

    # Non-encoder parameters are taken over unchanged.
    for name in ("embeddings", "layernorm"):
        target["vit"][name] = copy.deepcopy(source["vit"][name])
    target["classifier"] = copy.deepcopy(source["classifier"])

    shared_parts = ("layernorm_before", "layernorm_after", "attention")
    for layer_idx in range(config.num_hidden_layers):
        key = str(layer_idx)
        source_layers = source["vit"]["encoder"]["layer"]
        target_layers = target["vit"]["encoder"]["layer"]
        if layer_idx != config.moe_idx:
            # Ordinary ViT layer: replace it wholesale.
            target_layers[key] = copy.deepcopy(source_layers[key])
            continue
        # MoE layer: only the sub-modules it shares with the pretrained
        # layer of the same index are copied.
        source_layer = source_layers[key]
        target_layer = target_layers[key]
        for part in shared_parts:
            target_layer[part] = copy.deepcopy(source_layer[part])

    return freeze(target)


def _build_arg_parser():
    """Create the command-line parser for the layer-wise evaluation script."""
    parser = argparse.ArgumentParser(description="Fine-tune ViT with MoE on Imagenet")
    # --- Model & Training Config ---
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--batch-size", type=int, default=64)
    # --- Data Config ---
    parser.add_argument("--data-path", type=str, required=True)
    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'])
    parser.add_argument("--input-size", type=int, default=224)
    # The hyphenated spelling matches the other options; the original
    # underscore spelling is kept as an alias for existing command lines.
    parser.add_argument('--num-workers', '--num_workers', dest='num_workers', type=int, default=8)
    parser.add_argument('--pin-mem', action='store_true')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--color-jitter', type=float, default=0.4)
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1')
    parser.add_argument('--train-interpolation', type=str, default='bicubic')
    parser.add_argument('--reprob', type=float, default=0.25)
    parser.add_argument('--remode', type=str, default='pixel')
    parser.add_argument('--recount', type=int, default=1)
    return parser


def _plot_layer_metrics(layer_loss, layer_acc, base_loss, base_acc, output_path):
    """Save and show bar charts of per-layer loss and accuracy.

    ``layer_loss`` / ``layer_acc`` hold one value per encoder layer;
    ``base_loss`` / ``base_acc`` are drawn as dashed horizontal reference
    lines labelled "Pretrained".  The figure is written to ``output_path``.
    """
    x = np.arange(len(layer_loss))
    plt.rcParams.update({"font.family": "serif","legend.frameon": False,"lines.linewidth": 2,})
    plt.style.use("tableau-colorblind10")
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    FONT_SMALL = 11
    FONT_MEDIUM = 13
    FONT_LARGE = 16
    # Validation loss
    axes[0].bar(x, layer_loss, color='steelblue')
    axes[0].axhline(y=base_loss, color='lightsalmon', linestyle='--', label=f'Pretrained: {base_loss:.4f}')
    axes[0].set_title('Validation Loss', fontsize=FONT_LARGE)
    axes[0].set_xlabel('Layer', fontsize=FONT_MEDIUM)
    axes[0].set_ylabel('Loss', fontsize=FONT_MEDIUM)
    axes[0].set_xticks(x)
    axes[0].tick_params(axis='both', labelsize=FONT_SMALL)
    axes[0].legend(fontsize=FONT_SMALL)
    # Validation accuracy (values are plotted as given, on a 0-100 axis)
    axes[1].bar(x, layer_acc, color='steelblue')
    axes[1].axhline(y=base_acc, color='lightsalmon', linestyle='--', label=f'Pretrained: {base_acc:.2f}%')
    axes[1].set_title('Validation Accuracy (%)', fontsize=FONT_LARGE)
    axes[1].set_xlabel('Layer', fontsize=FONT_MEDIUM)
    axes[1].set_ylabel('Accuracy (%)', fontsize=FONT_MEDIUM)
    axes[1].set_xticks(x)
    axes[1].set_ylim(0, 100)
    axes[1].tick_params(axis='both', labelsize=FONT_SMALL)
    axes[1].legend(fontsize=FONT_SMALL)
    plt.tight_layout()
    plt.savefig(output_path)
    print(f"Saved plot to {output_path}")
    plt.show()


def main():
    """Evaluate a ViT with an MoE block at each encoder layer and plot the results.

    For every encoder layer index of the pretrained model, a
    ``FlaxViTMoEForImageClassification`` is built with ``moe_idx`` set to
    that index (one routed expert, no shared experts, ``topk=1``), its
    parameters are filled via ``pretrained2finetune_parmas`` and it is
    evaluated on the validation loader.  The unmodified pretrained model is
    evaluated last as the baseline.  Losses and accuracies are written to
    ``results.npz`` and plotted to ``imagenet_layer_metric.pdf``.
    """
    args = _build_arg_parser().parse_args()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    rng = jax.random.PRNGKey(args.seed)
    # Only the validation loader is used; data_loader also sets args.nb_classes.
    _train_loader, val_loader = data_loader(args)
    pretrained_model = FlaxViTForImageClassification.from_pretrained(args.model_path, dtype=jnp.float32)

    # Take layer and class counts from the loaded model instead of assuming
    # a 12-layer, 1000-class checkpoint.
    base_config = pretrained_model.config
    num_layers = base_config.num_hidden_layers
    num_classes = base_config.num_labels

    list_val_loss, list_val_acc = [], []

    def evaluate(model, params):
        # Evaluate one model on the validation set and record its metrics.
        stuff = make_stuff(model=model, num_classes=num_classes)
        val_loss, val_acc = stuff["dataset_loss_and_accuracy"](params, val_loader)
        print({"val_loss": float(val_loss),"val_acc": float(val_acc),})
        list_val_loss.append(to_serializable(val_loss))
        list_val_acc.append(to_serializable(val_acc))

    for moe_idx in range(num_layers):
        config = copy.deepcopy(base_config)
        config.num_routed_experts = 1
        config.num_shared_experts = 0
        config.topk = 1
        config.moe_idx = moe_idx
        finetune_model = FlaxViTMoEForImageClassification(config)
        # NOTE(review): the dummy input is channels-first; confirm this is
        # the layout the module's init expects and adjust the shape if not.
        dummy_inputs = jnp.ones((1, config.num_channels, config.image_size, config.image_size), dtype=jnp.float32)
        # The same PRNG key is reused for every layer index.
        variables = finetune_model.module.init(rng, pixel_values=dummy_inputs)
        finetune_model.params = variables['params']
        finetune_params = pretrained2finetune_parmas(pretrained_model.params, finetune_model.params, config)
        evaluate(finetune_model, finetune_params)

    # Baseline: the unmodified pretrained model is stored as the last entry.
    evaluate(pretrained_model, pretrained_model.params)
    np.savez('results.npz', val_loss=list_val_loss, val_acc=list_val_acc)

    _plot_layer_metrics(
        list_val_loss[:-1],
        list_val_acc[:-1],
        list_val_loss[-1],
        list_val_acc[-1],
        "imagenet_layer_metric.pdf",
    )
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
