import argparse
import os, re
import jax
import flax
import copy
import torch
import wandb
import optax
import numpy as np
from tqdm import tqdm
import jax.numpy as jnp
from flax import linen as nn
from typing import Any, Dict, List
from flax.jax_utils import replicate, unreplicate
from flax.core.frozen_dict import freeze, unfreeze
from flax.training import checkpoints, train_state
from flax.traverse_util import flatten_dict, unflatten_dict
from flax.training.common_utils import get_metrics, onehot, shard
from lmc_model import  LMCFlaxViTForImageClassification, print_model, print_model_with_prefix
from transformers.models.vit.modeling_flax_vit import ViTConfig
from datasets import build_dataset
import multiprocessing as mp
from pprint import pprint
import json
import shutil
from flax.serialization import from_bytes
# Select the "spawn" start method for the multiprocessing module; force=True
# overrides a start method that may already have been chosen.
# NOTE(review): presumably this keeps DataLoader worker processes from being
# forked out of a parent that has already initialised JAX/CUDA — confirm.
mp.set_start_method("spawn", force=True)
# ---------- Checkpoint & Dataset Helpers ----------
def load_flax_params(checkpoint_dir, target):
    """Deserialize ``flax_model.msgpack`` found in ``checkpoint_dir``.

    Args:
        checkpoint_dir: Directory that contains a ``flax_model.msgpack`` file.
        target: Restore target handed to ``from_bytes`` as its first argument.

    Returns:
        The object produced by ``from_bytes(target, <file bytes>)``.

    Raises:
        OSError: If the msgpack file cannot be opened (for example
            ``FileNotFoundError`` when it does not exist).
    """
    checkpoint_file = os.path.join(checkpoint_dir, "flax_model.msgpack")
    with open(checkpoint_file, "rb") as handle:
        serialized = handle.read()
    return from_bytes(target, serialized)
def parse_step_from_path(path: str) -> int:
    """Return the step number encoded as ``last_<digits>`` in ``path``.

    The first ``last_<digits>`` occurrence in the string is used. When the
    path contains no such pattern the function returns ``0``.
    """
    found = re.search(r'last_(\d+)', path)
    if found is None:
        return 0
    return int(found.group(1))
def remove_old_dirs_with_prefix(save_path, prefix, keep_step):
    """Delete stale ``<prefix>*`` directories, keeping ``<prefix><keep_step>``.

    Only directories are removed; regular files whose name starts with
    ``prefix`` are left untouched.

    Args:
        save_path: Directory whose direct children are inspected.
        prefix: Name prefix that marks a candidate directory (e.g. ``"last_"``).
        keep_step: Step whose directory, named ``f"{prefix}{keep_step}"``,
            must survive.

    Raises:
        OSError: If ``save_path`` cannot be listed or a directory cannot be
            removed.
    """
    # Compare against the exact directory name. A suffix test such as
    # ``endswith(str(keep_step))`` would also spare e.g. ``last_15`` when
    # keep_step is 5.
    keep_name = f"{prefix}{keep_step}"
    for fname in os.listdir(save_path):
        if not fname.startswith(prefix) or fname == keep_name:
            continue
        full_path = os.path.join(save_path, fname)
        if os.path.isdir(full_path):
            shutil.rmtree(full_path)
def imagenet_data_loader(args):
    """Build the training and validation ``DataLoader`` objects.

    The datasets come from ``build_dataset``. As a side effect,
    ``args.nb_classes`` is set to the second value returned by the training
    ``build_dataset`` call.

    Args:
        args: Namespace forwarded to ``build_dataset``; ``batch_size``,
            ``num_workers`` and ``pin_mem`` are read here as well.

    Returns:
        ``(data_loader_train, data_loader_val)``. The training loader samples
        randomly and drops the last incomplete batch; the validation loader
        iterates sequentially and keeps it.
    """
    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    # Options shared by both loaders.
    loader_options = dict(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
    )
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=torch.utils.data.RandomSampler(dataset_train),
        drop_last=True,
        **loader_options,
    )
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=torch.utils.data.SequentialSampler(dataset_val),
        drop_last=False,
        **loader_options,
    )
    return data_loader_train, data_loader_val
def prepare_image_batch(images: torch.Tensor, labels: torch.Tensor) -> Dict[str, Any]:
    """Convert one batch to JAX arrays and pass each array through ``shard``.

    Args:
        images: Batch of images as produced by the data loader.
        labels: Labels belonging to ``images``.

    Returns:
        Dict with the keys ``'images'`` and ``'labels'`` holding the sharded
        arrays.
    """
    image_array = jnp.array(images)
    label_array = jnp.array(labels)
    return {'images': shard(image_array), 'labels': shard(label_array)}
def accuracy(logits, labels, topk=(1,)):
    """Compute top-k accuracy, in percent, for every ``k`` in ``topk``.

    Args:
        logits: 2-D array of scores, one row per example and one column per
            class.
        labels: 1-D array of integer class ids, one per example.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        List with one percentage per entry of ``topk``, in the same order.
    """
    largest_k = max(topk)
    num_examples = labels.shape[0]
    # argsort is ascending, so the last `largest_k` columns are the best
    # classes; flipping puts the highest-scoring class first.
    ranked = jnp.flip(jnp.argsort(logits, axis=-1)[:, -largest_k:], axis=-1)
    label_column = labels[:, None]

    def _percent_hit(k):
        # An example counts as a hit when its label is among the first k ranks.
        hits = jnp.any(ranked[:, :k] == label_column, axis=1)
        return 100.0 * jnp.sum(hits) / num_examples

    return [_percent_hit(k) for k in topk]
# ---------- Training Utilities ----------
def get_trainable_mask(params, config):
    """Label every leaf of ``params`` as ``"trainable"`` or ``"frozen"``.

    A leaf is trainable when its key path starts with
    ``vit / encoder / layer / <i>`` with ``i`` in ``config.lmc_layer_indices``
    and the fifth key is ``"attention"`` — or ``"moe"`` when
    ``config.finetune_mlp`` equals ``True``. Every other leaf is frozen.

    Args:
        params: Nested parameter tree whose path entries expose a ``key``
            attribute.
        config: Object providing ``lmc_layer_indices`` and ``finetune_mlp``.

    Returns:
        A tree with the structure of ``params`` and string labels as leaves.
    """
    def _targets_lmc_layer(parts):
        # True for paths of the form vit/encoder/layer/<lmc index>/...
        return (
            parts[:3] == ["vit", "encoder", "layer"]
            and int(parts[3]) in config.lmc_layer_indices
        )

    def _label(path, _leaf):
        parts = [str(entry.key) for entry in path]
        if len(parts) >= 5:
            if _targets_lmc_layer(parts) and parts[4] == "attention":
                return "trainable"
            # Equality with True (not plain truthiness) is kept on purpose so
            # the labelling is unchanged for non-bool values.
            if config.finetune_mlp == True and _targets_lmc_layer(parts) and parts[4] == "moe":  # noqa: E712
                return "trainable"
        return "frozen"

    return jax.tree_util.tree_map_with_path(_label, params)
def pretrained2finetune_params(pretrained_params, finetune_params, config):
    """Seed the fine-tuning parameter tree with pretrained weights.

    Copied from ``pretrained_params`` (as deep copies):
      * ``vit/embeddings``, ``vit/layernorm`` and ``classifier``;
      * every encoder layer whose index is NOT in ``config.lmc_layer_indices``,
        in full;
      * for layers in ``config.lmc_layer_indices``: ``layernorm_before`` and
        ``layernorm_after``, plus ``moe`` when ``config.finetune_mlp`` equals
        ``False``.

    Everything else in an LMC layer (its ``attention`` sub-tree, and ``moe``
    otherwise) keeps the values already present in ``finetune_params``.

    Args:
        pretrained_params: Parameter tree of the pretrained model.
        finetune_params: Parameter tree of the freshly initialised model.
        config: Object providing ``num_hidden_layers``, ``lmc_layer_indices``
            and ``finetune_mlp``.

    Returns:
        The merged parameter tree, frozen with ``freeze``.
    """
    source = unfreeze(pretrained_params)
    merged = unfreeze(finetune_params)

    # Weights outside the encoder stack always come from the pretrained model.
    for name in ("embeddings", "layernorm"):
        merged["vit"][name] = copy.deepcopy(source["vit"][name])
    merged["classifier"] = copy.deepcopy(source["classifier"])

    for idx in range(config.num_hidden_layers):
        layer_key = str(idx)
        if idx not in config.lmc_layer_indices:
            # Regular layer: take the whole pretrained layer.
            merged["vit"]["encoder"]["layer"][layer_key] = copy.deepcopy(
                source["vit"]["encoder"]["layer"][layer_key]
            )
            continue

        # LMC layer: copy from the pretrained layer with the same index.
        src_layer = source["vit"]["encoder"]["layer"][layer_key]
        dst_layer = merged["vit"]["encoder"]["layer"][layer_key]
        for part in ("layernorm_before", "layernorm_after"):
            dst_layer[part] = copy.deepcopy(src_layer[part])
        if config.finetune_mlp == False:  # noqa: E712 - equality kept to preserve semantics
            dst_layer["moe"] = copy.deepcopy(src_layer["moe"])
    return freeze(merged)


# ---------- Main ----------
def _evaluate(state, val_loader, parallel_eval_step):
    """Run ``parallel_eval_step`` over every batch of ``val_loader``.

    Args:
        state: Replicated train state handed unchanged to
            ``parallel_eval_step``.
        val_loader: Iterable yielding ``(images, labels)`` batches.
        parallel_eval_step: pmapped evaluation function returning a dict with
            the entries ``"loss"``, ``"acc1"`` and ``"acc5"``.

    Returns:
        ``(loss, acc1, acc5)`` as Python floats, each the unweighted mean of
        the per-batch values.
    """
    eval_results = []
    for images, labels in tqdm(val_loader, desc="Evaluating", leave=False):
        batch = prepare_image_batch(images, labels)
        eval_results.append(parallel_eval_step(state, batch))
    # get_metrics already selects one device copy of each metric and stacks
    # the values over batches. Applying unreplicate on top of that would index
    # the batch axis and keep only the FIRST batch, so it is not used here.
    eval_metrics = get_metrics(eval_results)
    eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
    return float(eval_metrics["loss"]), float(eval_metrics["acc1"]), float(eval_metrics["acc5"])


def main(args: argparse.Namespace):
    """Fine-tune selected layers of a pretrained LMC ViT and log to wandb.

    The function loads the pretrained checkpoint, builds a second model whose
    config carries the LMC settings from ``args``, copies the reusable
    pretrained weights into it, freezes everything that
    ``get_trainable_mask`` does not label trainable, and then alternates one
    training epoch with one validation pass. After every epoch the current
    model is saved as ``last_<step>``; the best top-1 model so far is saved as
    ``best_<step>``.

    Args:
        args: Parsed command-line namespace (see the argument parser at the
            bottom of this file).

    Raises:
        FileNotFoundError: If ``args.model_path`` does not exist.
    """
    if not os.path.exists(args.model_path):
        raise FileNotFoundError(f"Checkpoint path not found: {args.model_path}")
    # The model config is read from the directory containing the checkpoint path.
    config = ViTConfig.from_pretrained(os.path.dirname(args.model_path))
    os.makedirs(args.wandb_run_dir, exist_ok=True)

    # The same name identifies the wandb run and the output directory.
    run_name = (
        f"finetune-{config.position_embeddings}-indice{','.join(str(i) for i in args.lmc_layer_indices)}-heads{args.num_attention_heads}"
        f"-shared{config.num_shared_experts}-routed{config.num_routed_experts}-topk{config.topk}-mlp{str(args.finetune_mlp)}-seed{args.seed}"
    )
    wandb.init(
        project=args.wandb_project,
        entity=args.wandb_entity,
        group=args.wandb_group,
        id=args.wandb_id,
        name=run_name,
        dir=args.wandb_run_dir,  # make --wandb-run-dir effective
        save_code=True,
    )
    save_path = os.path.join(args.save_dir, run_name)

    # --- Seeds & RNG ---
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    rng = jax.random.PRNGKey(args.seed)
    train_loader, val_loader = imagenet_data_loader(args)

    # --- Load pretrained model ---
    pretrained_model = LMCFlaxViTForImageClassification(config, dtype=jnp.dtype(args.dtype))
    lmc_config = copy.deepcopy(config)
    new_config = copy.deepcopy(config)
    lmc_config.num_attention_heads = args.num_attention_heads
    lmc_config.routed_scaling_factor = args.routed_scaling_factor
    new_config.lmc_config = lmc_config
    new_config.lmc_layer_indices = args.lmc_layer_indices
    new_config.finetune_mlp = args.finetune_mlp
    pretrained_params = checkpoints.restore_checkpoint(
        ckpt_dir=args.model_path, target={"params": pretrained_model.params}
    )["params"]
    pretrained_model.params = pretrained_params

    # --- Initialize fine-tuning model ---
    model = LMCFlaxViTForImageClassification(
        new_config,
        input_shape=(1, new_config.image_size, new_config.image_size, new_config.num_channels),
        seed=args.seed,
        dtype=jnp.dtype(args.dtype),
    )
    model.config.save_pretrained(save_path)
    print_model(model.params)
    model.params = pretrained2finetune_params(pretrained_model.params, model.params, new_config)
    label_mask = get_trainable_mask(model.params, new_config)
    print(json.dumps(label_mask, indent=2))

    # --- Optimizer: AdamW on trainable leaves, zero updates on frozen ones ---
    num_train_steps = len(train_loader) * args.epochs
    num_warmup_steps = len(train_loader) * args.warmup_epochs
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=args.warmup_lr,
        peak_value=args.lr,
        warmup_steps=num_warmup_steps,
        decay_steps=num_train_steps,
        end_value=args.min_lr,
    )
    tx = optax.multi_transform(
        transforms={
            'trainable': optax.adamw(
                learning_rate=lr_schedule,
                b1=args.adamw_beta1,
                b2=args.adamw_beta2,
                eps=args.adamw_eps,
                weight_decay=args.weight_decay,
            ),
            'frozen': optax.set_to_zero(),
        },
        param_labels=label_mask,
    )
    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=tx)
    if args.restore_checkpoint_path:
        # NOTE(review): only the parameters and the step counter are restored;
        # the optimizer state is still the fresh one built by
        # TrainState.create above — confirm this is intended for resumed runs.
        restored_params = load_flax_params(args.restore_checkpoint_path, state.params)
        restored_step = parse_step_from_path(args.restore_checkpoint_path)
        state = state.replace(params=restored_params, step=restored_step)
        print(f"train state restored from {args.restore_checkpoint_path}")
        print(f"skip trian step to {state.step}")
    latest_global_step = state.step
    curr_epoch = latest_global_step // len(train_loader)
    state = replicate(state)

    def train_step(state, batch, rng):
        """One optimisation step on a per-device shard; runs under pmap."""
        dropout_rng, new_dropout_rng = jax.random.split(rng)

        def loss_fn(params):
            outputs = state.apply_fn(params=params, pixel_values=batch["images"], train=True, dropout_rng=dropout_rng)
            logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
            loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch["labels"]).mean()
            return loss, logits

        (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
        # Average gradients over devices so every replica applies the same update.
        grads = jax.lax.pmean(grads, axis_name="batch")
        state = state.apply_gradients(grads=grads)
        acc1, acc5 = accuracy(logits, batch["labels"], topk=(1, 5))
        metrics = {"loss": loss, "acc1": acc1, "acc5": acc5, "learning_rate": lr_schedule(state.step)}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return state, metrics, new_dropout_rng

    def eval_step(state, batch):
        """Forward pass and metrics on a per-device shard; runs under pmap."""
        outputs = state.apply_fn(params=state.params, pixel_values=batch["images"], train=False)
        logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch["labels"]).mean()
        acc1, acc5 = accuracy(logits, batch["labels"], topk=(1, 5))
        metrics = {"loss": loss, "acc1": acc1, "acc5": acc5}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    parallel_train_step = jax.pmap(train_step, "batch")
    parallel_eval_step = jax.pmap(eval_step, "batch")
    global_step = latest_global_step
    best_val_acc1 = 0.0

    # Baseline evaluation before any fine-tuning step (marked "just for
    # testing" in the original script).
    val_loss, val_acc1, val_acc5 = _evaluate(state, val_loader, parallel_eval_step)
    print(f"Validation Loss: {val_loss:.4f}, Top-1 Accuracy: {val_acc1:.2f}%, Top-5 Accuracy: {val_acc5:.2f}%")

    #### START FINETUNING ####
    print("Starting training...")
    print(f"JAX devices: {jax.devices()}")
    print(f"Using {jax.local_device_count()} devices")
    # Derive the per-device dropout keys once. They are advanced by every
    # train step, so re-splitting from the same `rng` inside the epoch loop
    # would replay the identical key sequence in every epoch.
    dropout_rngs = jax.random.split(rng, jax.local_device_count())
    for epoch in range(curr_epoch, args.epochs):
        print(f"Epoch {epoch}")
        pbar = tqdm(train_loader, desc="Training", leave=False)
        for images, labels in pbar:
            # Prepare and shard batch
            batch = prepare_image_batch(images, labels)
            # Run train step
            state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
            train_metrics = unreplicate(train_metrics)
            train_metrics = jax.tree_util.tree_map(jnp.mean, train_metrics)
            loss = float(train_metrics["loss"])
            acc1 = float(train_metrics["acc1"])
            acc5 = float(train_metrics["acc5"])
            curr_lr = float(lr_schedule(global_step))
            if global_step % args.wandb_logging_frequency == 0:
                wandb.log({"train/loss": loss, "train/acc1": acc1, "train/acc5": acc5, "lr": curr_lr}, step=global_step)
            pbar.set_postfix({"train/loss": f"{loss:.3f}", "train/acc1": f"{acc1:.3f}", "train/acc5": f"{acc5:.3f}", "lr": f"{curr_lr:.2e}"})
            global_step += 1

        val_loss, val_acc1, val_acc5 = _evaluate(state, val_loader, parallel_eval_step)
        wandb.log({"val/loss": val_loss, "val/acc1": val_acc1, "val/acc5": val_acc5}, step=global_step)

        # Both checkpoints below are written from the same host copy.
        model.params = unreplicate(state).params
        # Save best model
        if best_val_acc1 <= val_acc1:
            best_val_acc1 = val_acc1
            best_dir = os.path.join(save_path, f"best_{global_step}")
            model.save_pretrained(best_dir)
            print(f"Best model saved at step {global_step}")
            remove_old_dirs_with_prefix(save_path, "best_", global_step)
        # Save last model
        last_dir = os.path.join(save_path, f"last_{global_step}")
        model.save_pretrained(last_dir)
        print(f"Last model saved at step {global_step}")
        remove_old_dirs_with_prefix(save_path, "last_", global_step)

# Script entry point: build the command-line interface and start fine-tuning.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Fine-tune ViT with MoE on Imagenet")
    # --- Model & Training Config ---
    parser.add_argument("--model-path", type=str, default="", help="Path of Pretrained Model")
    parser.add_argument("--num-attention-heads", type=int, default = 1)
    parser.add_argument('--routed-scaling-factor', type=float, default=1.0,help='')
    parser.add_argument("--lmc-layer-indices",type=int,nargs="*",default=[],help="List of lmc layer indices (optional, default: empty list)")
    parser.add_argument("--finetune-mlp",action="store_true",help="Enable fine-tuning for the MLP. Default is False.")
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    # NOTE(review): main() in this file does not read --sched; it may be
    # consumed elsewhere through `args` — confirm before removing.
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
    parser.add_argument("--adamw-beta1", type=float, default=0.9)
    parser.add_argument("--adamw-beta2", type=float, default=0.999)    
    parser.add_argument("--adamw-eps", type=float, default=1e-8)
    # NOTE(review): main() in this file implements no early stopping, so
    # --patience currently has no visible effect — confirm.
    parser.add_argument("--patience", type=int, default=10, help="Early stopping patience")
    parser.add_argument("--restore-checkpoint-path", type=str, help="if you want to restart from specific checkpoint, set this arg to checkpoint path")
    # --- Data Config ---
    parser.add_argument("--data-path", type=str, required=True)
    parser.add_argument("--save-dir", type=str, required=True)
    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'])
    parser.add_argument("--input-size", type=int, default=224)
    # Accept the hyphenated spelling used by every other flag while keeping
    # the original --num_workers spelling working; both store args.num_workers.
    parser.add_argument('--num-workers', '--num_workers', dest='num_workers', type=int, default=8)
    parser.add_argument('--pin-mem', action='store_true')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--color-jitter', type=float, default=0.4)
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1')
    parser.add_argument('--train-interpolation', type=str, default='bicubic')
    parser.add_argument('--reprob', type=float, default=0.25)
    parser.add_argument('--remode', type=str, default='pixel')
    parser.add_argument('--recount', type=int, default=1)
    parser.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="bfloat16", help="model datatype")
    # --- Logging Config ---
    parser.add_argument("--wandb-entity", default="", help="wandb entity for logging")
    parser.add_argument("--wandb-group", default="", help="wandb group for logging")
    parser.add_argument("--wandb-project", default="", help="wandb project name for logging")
    parser.add_argument("--wandb-run-dir", default=".wandb", help="wandb run dir")
    parser.add_argument("--wandb-id", default=None, help="wandb run id for logging")
    parser.add_argument("--wandb-logging-frequency", type=int, default=100, help="do logging every logging_frequency step")
    args = parser.parse_args()
    main(args)
