import argparse
import os
import jax
import flax
import copy
import torch
import wandb
import numpy as np
from tqdm import tqdm
import jax.numpy as jnp
from flax import linen as nn

from flax.traverse_util import flatten_dict, unflatten_dict
from flax.core.frozen_dict import freeze, unfreeze
from model import  FlaxViTMoEForImageClassification, print_model, print_model_with_prefix
from transformers import FlaxViTForImageClassification
from datasets import build_dataset
from engine import train_epoch, evaluate, create_train_state, accuracy
import multiprocessing as mp
from pprint import pprint
import json

# Force the "spawn" start method for every multiprocessing child (this includes
# the DataLoader workers created below when num_workers > 0).
# NOTE(review): presumably chosen because forking after JAX/CUDA initialisation
# is unsafe -- confirm.
mp.set_start_method("spawn", force=True)
# Make sure WANDB_API_KEY exists in the environment, but never overwrite a key
# the user already exported: unconditional assignment of "" would silently
# discard real credentials.
os.environ.setdefault("WANDB_API_KEY", "")
# ---------- Dataset Loader ----------
def data_loader(args):
    """Create the training and validation ``DataLoader`` objects.

    Both datasets come from ``build_dataset``. The class count returned for
    the training split is written back to ``args.nb_classes`` (the count
    returned for the validation split is ignored).

    Args:
        args: Namespace providing ``batch_size``, ``num_workers`` and
            ``pin_mem``; it is also forwarded to ``build_dataset``.

    Returns:
        A ``(train_loader, val_loader)`` tuple. The training loader samples
        randomly and drops the last incomplete batch; the validation loader
        iterates sequentially and keeps every sample.
    """
    data_utils = torch.utils.data

    train_set, num_classes = build_dataset(is_train=True, args=args)
    args.nb_classes = num_classes
    val_set, _ = build_dataset(is_train=False, args=args)

    train_loader = data_utils.DataLoader(
        train_set,
        sampler=data_utils.RandomSampler(train_set),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    val_loader = data_utils.DataLoader(
        val_set,
        sampler=data_utils.SequentialSampler(val_set),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,
    )
    return train_loader, val_loader

# ---------- Training Utilities ----------
def get_trainable_mask(params, config, *, train_classifier=False):
    """Label every leaf of ``params`` as ``"trainable"`` or ``"frozen"``.

    A leaf is trainable when its path starts with
    ``vit/encoder/layer/<config.moe_idx>/moe_block`` and has at least one more
    path component below ``moe_block``. Every other leaf is frozen, unless
    ``train_classifier`` is set, in which case leaves under the top-level
    ``"classifier"`` entry are trainable as well.

    Args:
        params: Parameter pytree (nested dict / FrozenDict).
        config: Object exposing ``moe_idx``, the index of the MoE encoder layer.
        train_classifier: Also mark the classifier head as trainable.
            Defaults to ``False``, which reproduces the previous behaviour.

    Returns:
        A pytree with the same structure as ``params`` whose leaves are the
        strings ``"trainable"`` or ``"frozen"``.
        NOTE(review): presumably consumed as the label tree of the optimizer
        built in ``create_train_state`` -- confirm against ``engine``.
    """
    # Loop-invariant: the path prefix that identifies the MoE block.
    moe_prefix = ["vit", "encoder", "layer", str(config.moe_idx), "moe_block"]

    def key_name(entry):
        # Dict-style path entries expose ``.key``; attribute and sequence
        # entries expose ``.name`` / ``.idx``. Supporting all of them keeps
        # the mask usable on trees that are not made of dicts only.
        for attr in ("key", "name", "idx"):
            if hasattr(entry, attr):
                return str(getattr(entry, attr))
        return str(entry)

    def label_fn(path, _):
        keys = [key_name(entry) for entry in path]
        # Require something below ``moe_block`` (>= 6 components in total).
        if len(keys) >= 6 and keys[:5] == moe_prefix:
            return "trainable"
        if train_classifier and keys and keys[0] == "classifier":
            return "trainable"
        return "frozen"

    return jax.tree_util.tree_map_with_path(label_fn, params)
def pretrained2finetune_parmas(pretrained_params, finetune_params, config):
    """Transplant pretrained ViT weights into the MoE fine-tuning parameter tree.

    Both inputs are unfrozen first and the result is returned as a new frozen
    tree. The following entries are deep-copied from the pretrained tree:

    * ``vit/embeddings``, ``vit/layernorm`` and ``classifier``;
    * every encoder layer other than ``config.moe_idx``, in full;
    * for layer ``config.moe_idx`` only ``layernorm_before``,
      ``layernorm_after`` and ``attention``. Everything else in that layer
      (including ``moe_block``) keeps the values found in ``finetune_params``.

    Args:
        pretrained_params: Parameter tree of the pretrained dense ViT.
        finetune_params: Parameter tree of the MoE model to be filled.
        config: Object exposing ``num_hidden_layers`` and ``moe_idx``.

    Returns:
        The frozen, updated fine-tuning parameter tree.
    """
    source = unfreeze(pretrained_params)
    target = unfreeze(finetune_params)

    # Non-encoder parameters are taken over unchanged.
    for vit_entry in ("embeddings", "layernorm"):
        target["vit"][vit_entry] = copy.deepcopy(source["vit"][vit_entry])
    target["classifier"] = copy.deepcopy(source["classifier"])

    for layer_idx in range(config.num_hidden_layers):
        layer_key = str(layer_idx)
        src_layers = source["vit"]["encoder"]["layer"]

        if layer_idx != config.moe_idx:
            # Ordinary transformer layer: replace it wholesale.
            replacement = copy.deepcopy(src_layers[layer_key])
            target["vit"]["encoder"]["layer"][layer_key] = replacement
            continue

        # MoE layer: only the sub-modules it shares with the dense layer of the
        # same index are copied; the MoE block itself is left untouched.
        dense_layer = src_layers[layer_key]
        moe_layer = target["vit"]["encoder"]["layer"][layer_key]
        for shared in ("layernorm_before", "layernorm_after", "attention"):
            moe_layer[shared] = copy.deepcopy(dense_layer[shared])

    return freeze(target)


# ---------- Main ----------
def _build_arg_parser():
    """Return the command-line parser for the fine-tuning script."""
    parser = argparse.ArgumentParser(description="Fine-tune ViT with MoE on Imagenet")
    # --- Model & Training Config ---
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--moe-idx", type=int, required=True)
    parser.add_argument("--num-routed-experts", type=int, default=2)
    parser.add_argument("--num-shared-experts", type=int, default=1)
    parser.add_argument("--topk", type=int, default=2)
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER')
    parser.add_argument("--patience", type=int, default=6, help="Early stopping patience")
    # --- Data Config ---
    parser.add_argument("--data-path", type=str, required=True)
    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'])
    parser.add_argument("--input-size", type=int, default=224)
    # Every other flag is hyphenated; accept that spelling too while keeping
    # the original underscore form working for existing launch scripts.
    parser.add_argument('--num-workers', '--num_workers', dest='num_workers', type=int, default=8)
    parser.add_argument('--pin-mem', action='store_true')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--color-jitter', type=float, default=0.4)
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1')
    parser.add_argument('--train-interpolation', type=str, default='bicubic')
    parser.add_argument('--reprob', type=float, default=0.25)
    parser.add_argument('--remode', type=str, default='pixel')
    parser.add_argument('--recount', type=int, default=1)
    return parser


def _run_name(args):
    """Return the identifier shared by the wandb run and the checkpoint folder."""
    return f"idx{args.moe_idx}-lr{args.lr}-seed-{args.seed}-shared{args.num_shared_experts}-routed{args.num_routed_experts}-topk{args.topk}"


def main():
    """Fine-tune the MoE layer of a pretrained ViT and checkpoint the best model.

    Steps: parse the CLI, start a wandb run, seed torch/numpy/JAX, load the
    pretrained ViT, build the MoE variant and fill it with the pretrained
    weights, then train for up to ``--epochs`` epochs. After every epoch the
    model is evaluated; the parameters are saved whenever the validation loss
    improves, the full train state is saved every 10 epochs, and training
    stops early once the loss has not improved for ``--patience`` epochs.
    """
    args = _build_arg_parser().parse_args()
    # Built once so the wandb run name and the checkpoint directory cannot drift.
    run_name = _run_name(args)
    wandb.init(project="LMC_for_VitMOE", name=run_name)

    # --- Seeds & RNG ---
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    rng = jax.random.PRNGKey(args.seed)
    # A JAX key must not be consumed twice: give parameter initialisation its
    # own subkey instead of the key that is split again for every epoch.
    rng, init_rng = jax.random.split(rng)

    # --- Load pretrained model ---
    pretrained_model = FlaxViTForImageClassification.from_pretrained(args.model_path, dtype=jnp.float32)
    config = copy.deepcopy(pretrained_model.config)
    config.num_shared_experts = args.num_shared_experts
    config.num_routed_experts = args.num_routed_experts
    config.topk = args.topk
    config.moe_idx = args.moe_idx
    config.dtype = "float32"

    # --- Initialize fine-tuning model ---
    finetune_model = FlaxViTMoEForImageClassification(config)
    # NOTE(review): the dummy batch is channels-first (1, C, H, W); confirm this
    # is the layout the underlying module expects.
    dummy_inputs = jnp.ones(
        (1, config.num_channels, config.image_size, config.image_size), dtype=jnp.float32
    )
    variables = finetune_model.module.init(init_rng, pixel_values=dummy_inputs)
    finetune_model.params = variables['params']
    # Overwrite the freshly initialised weights with the pretrained ones
    # wherever the two architectures share parameters.
    finetune_model.params = pretrained2finetune_parmas(
        pretrained_model.params, finetune_model.params, config
    )
    print_model(finetune_model.params)

    train_loader, val_loader = data_loader(args)
    mask = get_trainable_mask(finetune_model.params, config)
    state = create_train_state(finetune_model, args, len(train_loader), mask)

    save_path = os.path.join("./weights/imagenet/finetune/", run_name)
    os.makedirs(save_path, exist_ok=True)

    best_val_loss = float("inf")
    wait = 0  # epochs elapsed since the last validation-loss improvement
    for epoch in range(args.epochs):
        rng, step_rng = jax.random.split(rng)
        state, train_metrics = train_epoch(state, train_loader, step_rng)
        acc1, acc5, val_loss = evaluate(state, val_loader)
        print(f"Epoch {epoch+1}: "
              f"Train Loss={train_metrics['loss']:.4f}, Acc@1={train_metrics['acc1']:.2f}%, Acc@5={train_metrics['acc5']:.2f}%")
        print(f"Val Loss={val_loss:.4f}, Acc@1={acc1:.2f}%, Acc@5={acc5:.2f}%")
        wandb.log({
            "epoch": epoch + 1, "val/loss": val_loss, "val/acc1": acc1, "val/acc5": acc5,
            "train/loss": train_metrics["loss"], "train/acc1": train_metrics["acc1"], "train/acc5": train_metrics["acc5"],
        })

        # Periodic snapshot of the complete train state.
        if (epoch + 1) % 10 == 0:
            state_path = os.path.join(save_path, f"train_state_epoch{epoch+1}_seed{args.seed}.msgpack")
            with open(state_path, "wb") as f:
                f.write(flax.serialization.to_bytes(state))
            print(f"Saved TrainState to {state_path}")

        if val_loss < best_val_loss:
            # New best model: reset the patience counter and export it.
            best_val_loss = val_loss
            wait = 0
            finetune_model.params = state.params
            finetune_model.save_pretrained(save_path)
            print(f"Checkpoint saved to {save_path} at epoch {epoch+1} (Val Loss = {val_loss:.4f})")
        else:
            wait += 1
            if wait >= args.patience:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

if __name__ == "__main__":
    main()
