import argparse
import os
import jax
import flax
import copy
import json
import optax
import torch
import wandb
import numpy as np
from tqdm import tqdm
import jax.numpy as jnp
from flax import linen as nn
from flax.jax_utils import replicate, unreplicate
from flax.training import checkpoints, train_state
from flax.core.frozen_dict import freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers.models.vit.modeling_flax_vit import ViTConfig, FlaxViTForImageClassification
from flax.training.common_utils import get_metrics, onehot, shard
from lmc_model import  LMCFlaxViTForImageClassification, print_model, print_model_with_prefix
from datasets import build_dataset
import multiprocessing as mp
from jax import debug
from pprint import pprint
from typing import Any, Dict, List
import shutil
# Make "spawn" the start method for child processes created via multiprocessing
# (presumably including torch DataLoader workers when --num_workers > 0).
# force=True overrides any start method that was already set.
# NOTE(review): presumably chosen because forking a process that has already
# initialized JAX/CUDA is unsafe — confirm.
mp.set_start_method("spawn", force=True)
def aggregate_metrics(metrics_list):
    """Average each metric over a list of metric dicts.

    Args:
        metrics_list: sequence of dicts that all contain the keys of the first
            entry; for a given key the values must be stackable into one array
            (every element of that array is averaged).

    Returns:
        Dict mapping each key of the first entry to its mean as a Python float.
        An empty input yields an empty dict instead of raising ``IndexError``.
    """
    if not metrics_list:
        return {}
    return {
        key: jnp.mean(jnp.array([m[key] for m in metrics_list])).item()
        for key in metrics_list[0]
    }
# ---------- Checkpoint-directory cleanup and dataset loader ----------
def remove_old_dirs_with_prefix(save_path, prefix, keep_step):
    """Delete sub-directories of ``save_path`` whose names start with ``prefix``.

    The directory named exactly ``f"{prefix}{keep_step}"`` is kept. Plain files
    are never removed, and a non-existent ``save_path`` is a no-op.

    Args:
        save_path: directory to scan.
        prefix: name prefix selecting candidate directories (e.g. ``"best_"``).
        keep_step: step suffix of the single directory to keep.
    """
    if not os.path.isdir(save_path):
        return
    # Compare the full name: a suffix test would also keep e.g. "best_110"
    # when keep_step is 10.
    keep_name = f"{prefix}{keep_step}"
    for fname in os.listdir(save_path):
        if not fname.startswith(prefix) or fname == keep_name:
            continue
        full_path = os.path.join(save_path, fname)
        if os.path.isdir(full_path):
            shutil.rmtree(full_path)
def imagenet_data_loader(args):
    """Build the training and validation ``DataLoader`` objects.

    The training loader samples in random order and drops the final partial
    batch. The validation loader iterates sequentially and keeps it.

    Side effect: stores the class count returned by ``build_dataset`` for the
    training split on ``args.nb_classes``.

    Returns:
        ``(train_loader, val_loader)``.
    """
    train_set, args.nb_classes = build_dataset(is_train=True, args=args)
    val_set, _ = build_dataset(is_train=False, args=args)

    # Options shared by both loaders.
    shared_opts = dict(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
    )
    train_sampler = torch.utils.data.RandomSampler(train_set)
    val_sampler = torch.utils.data.SequentialSampler(val_set)

    train_loader = torch.utils.data.DataLoader(
        train_set, sampler=train_sampler, drop_last=True, **shared_opts
    )
    val_loader = torch.utils.data.DataLoader(
        val_set, sampler=val_sampler, drop_last=False, **shared_opts
    )
    return train_loader, val_loader
# fmt: on
def prepare_image_batch(images: torch.Tensor, labels: torch.Tensor) -> Dict[str, Any]:
    """Convert a DataLoader batch to host NumPy arrays and shard it across local devices.

    The data is kept on the host rather than converted with ``jnp.array``, which
    would first place the whole batch on the default device. That way ``pmap``
    can copy each shard straight to its own device.

    Args:
        images: batch tensor whose leading axis is the batch dimension
            (assumed to be a CPU tensor, as produced by a DataLoader).
        labels: per-example labels sharing the same leading batch dimension.

    Returns:
        ``{"images": ..., "labels": ...}``, each with a new leading axis of size
        ``jax.local_device_count()``.

    Raises:
        ValueError: if the batch size is not divisible by the local device count.
    """
    num_devices = jax.local_device_count()
    batch_size = images.shape[0]
    if batch_size % num_devices:
        raise ValueError(
            f"batch size {batch_size} is not divisible by the {num_devices} local devices"
        )
    return {"images": shard(np.asarray(images)), "labels": shard(np.asarray(labels))}
def accuracy(logits, labels, topk=(1,)):
    """Top-k accuracy in percent, one value per entry of ``topk``.

    Args:
        logits: 2-D scores, rows are examples and columns are classes.
        labels: 1-D integer class ids, one per row of ``logits``.
        topk: k values to evaluate.

    Returns:
        List of scalars in the order of ``topk``. Each is the percentage of rows
        whose label is among that row's k highest-scoring classes.
    """
    num_rows = labels.shape[0]
    k_max = max(topk)
    # Ascending argsort, keep the k_max largest, then flip so the best class comes first.
    ranked = jnp.argsort(logits, axis=-1)[:, -k_max:][:, ::-1]
    hits = ranked == labels[:, None]
    return [100.0 * jnp.sum(jnp.any(hits[:, :k], axis=1)) / num_rows for k in topk]
def _build_vit_config(args: argparse.Namespace):
    """Load the ``google/vit-base-patch16-224`` config and override fields from ``args``."""
    config = ViTConfig.from_pretrained('google/vit-base-patch16-224')
    config.hidden_size = args.hidden_size
    config.num_hidden_layers = args.num_hidden_layers
    config.num_attention_heads = args.num_attention_heads
    config.intermediate_size = args.intermediate_size
    config.position_embeddings = args.position_embeddings
    config.rotary_value = args.rotary_value
    config.num_shared_experts = args.num_shared_experts
    config.num_routed_experts = args.num_routed_experts
    config.topk = args.topk
    config.routed_scaling_factor = args.routed_scaling_factor
    config.lmc_layer_indices = args.lmc_layer_indices
    # NOTE(review): --q_lora_rank, --qk_rope_head_dim, --kv_lora_rank, --v_head_dim,
    # --qk_nope_head_dim and --attention-bias are parsed but never copied onto the
    # config; confirm whether LMCFlaxViTForImageClassification expects them.
    return config


def _shard_eval_batch(images, labels, num_devices):
    """Pad an evaluation batch to a multiple of ``num_devices`` and shard it for ``pmap``.

    The validation loader keeps its final partial batch, which may not split
    evenly across devices (``shard`` would fail to reshape it). Padding rows are
    zeros with ``mask == 0``, so they contribute nothing to the metric sums.

    Returns:
        Dict with ``images``, ``labels`` and a float32 ``mask`` (1 = real example),
        each with a leading device axis of size ``num_devices``.
    """
    images = np.asarray(images)
    labels = np.asarray(labels)
    num_real = images.shape[0]
    num_pad = (-num_real) % num_devices
    mask = np.ones((num_real,), dtype=np.float32)
    if num_pad:
        images = np.concatenate([images, np.zeros((num_pad,) + images.shape[1:], dtype=images.dtype)])
        labels = np.concatenate([labels, np.zeros((num_pad,) + labels.shape[1:], dtype=labels.dtype)])
        mask = np.concatenate([mask, np.zeros((num_pad,), dtype=np.float32)])
    return shard({"images": images, "labels": labels, "mask": mask})


def main(args: argparse.Namespace):
    """Train the LMC ViT described by ``args``, evaluating and checkpointing every epoch.

    Each epoch runs three stages:
      * one data-parallel (``jax.pmap``) pass over the training loader;
      * a full pass over the validation loader, with exact sample-weighted averages;
      * checkpointing: a ``best_`` checkpoint when validation acc@1 reaches the best
        value so far, and a ``last_`` checkpoint unconditionally.

    Metrics are logged to Weights & Biases. The model config and checkpoints are
    written to ``<save_dir>/<run name>``.

    Raises:
        ValueError: if the training loader yields no batches.
    """
    run_name = (
        f"lr{args.lr}-{args.position_embeddings}-epochs{args.epochs}-batch{args.batch_size}"
        f"-shared{args.num_shared_experts}-routed{args.num_routed_experts}-topk{args.topk}"
    )
    # Pass the hyper-parameters via ``config=``. Assigning a dict to ``wandb.config``
    # afterwards only rebinds the module attribute and is never synced to the run.
    wandb.init(
        project=args.wandb_project,
        entity=args.wandb_entity,
        group=args.wandb_group,
        id=args.wandb_id,
        name=run_name,
        config=dict(vars(args)),
        save_code=True,
    )
    save_path = os.path.join(args.save_dir, run_name)

    # --- Seeds & RNG ---
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    rng = jax.random.PRNGKey(args.seed)

    # --- Data ---
    train_loader, val_loader = imagenet_data_loader(args)
    steps_per_epoch = len(train_loader)
    if steps_per_epoch == 0:
        raise ValueError("training loader is empty (dataset smaller than --batch-size with drop_last=True)")
    num_devices = jax.local_device_count()

    # --- Model ---
    config = _build_vit_config(args)
    model = LMCFlaxViTForImageClassification(
        config,
        input_shape=(1, config.image_size, config.image_size, config.num_channels),
        seed=args.seed,
        dtype=jnp.dtype(args.dtype),
    )
    model.config.save_pretrained(save_path)

    # --- Optimizer: linear warmup, then cosine decay (optax's decay_steps includes warmup) ---
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=args.warmup_lr,
        peak_value=args.lr,
        warmup_steps=steps_per_epoch * args.warmup_epochs,
        decay_steps=steps_per_epoch * args.epochs,
        end_value=args.min_lr,
    )
    tx = optax.adamw(
        learning_rate=lr_schedule,
        b1=args.adamw_beta1,
        b2=args.adamw_beta2,
        eps=args.adamw_eps,
        weight_decay=args.weight_decay,
    )
    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=tx)
    if args.restore_checkpoint_path:
        # NOTE(review): restore_checkpoint's default prefix is "checkpoint_", while this
        # script saves with "best_"/"last_". When nothing matches it returns ``state``
        # unchanged, so a directory path may silently restore nothing.
        state = checkpoints.restore_checkpoint(args.restore_checkpoint_path, state)
        print(f"train state restored from {args.restore_checkpoint_path}")
        print(f"skip train step to {state.step}")
        if int(state.step) == 0:
            print("warning: restored state is at step 0; was a checkpoint actually found?")
    # ``state.step`` is a JAX scalar. A plain int keeps epoch arithmetic, ``range``
    # and ``wandb.log(step=...)`` working on host values.
    global_step = int(state.step)
    start_epoch = global_step // steps_per_epoch
    state = replicate(state)

    def train_step(state, batch, rng):
        dropout_rng, new_dropout_rng = jax.random.split(rng)

        def loss_fn(params):
            outputs = state.apply_fn(params=params, pixel_values=batch["images"], train=True, dropout_rng=dropout_rng)
            logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
            loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch["labels"]).mean()
            return loss, logits

        (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
        grads = jax.lax.pmean(grads, axis_name="batch")
        state = state.apply_gradients(grads=grads)
        acc1, acc5 = accuracy(logits, batch["labels"], topk=(1, 5))
        metrics = {"loss": loss, "acc1": acc1, "acc5": acc5}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return state, metrics, new_dropout_rng

    def eval_step(state, batch):
        # Return masked *sums* (psum'd over devices) rather than per-batch means, so the
        # host can compute exact dataset-level averages that ignore padded rows.
        outputs = state.apply_fn(params=state.params, pixel_values=batch["images"], train=False)
        logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
        labels = batch["labels"]
        mask = batch["mask"]
        per_example_loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
        # Rank classes best-first, same convention as ``accuracy``.
        top5 = jnp.argsort(logits, axis=-1)[:, -5:][:, ::-1]
        hits = top5 == labels[:, None]
        sums = {
            "loss": jnp.sum(per_example_loss.astype(jnp.float32) * mask),
            "acc1": jnp.sum(hits[:, 0] * mask),
            "acc5": jnp.sum(jnp.any(hits, axis=1) * mask),
            "count": jnp.sum(mask),
        }
        return jax.lax.psum(sums, axis_name="batch")

    parallel_train_step = jax.pmap(train_step, "batch")
    parallel_eval_step = jax.pmap(eval_step, "batch")
    # NOTE(review): 85.9375 == 55/64, i.e. an acc@1 measured on one 64-image batch.
    # Confirm it is still the intended threshold now that validation covers the full set.
    best_val_acc1 = 85.9375
    print("Starting training...")
    print(f"JAX devices: {jax.devices()}")
    print(f"Using {jax.local_device_count()} devices")
    for epoch in range(start_epoch, args.epochs):
        print(f"Epoch {epoch}")
        # Fresh dropout keys per epoch. Reusing ``rng`` unchanged would replay identical
        # dropout masks every epoch; fold_in also keeps the keys reproducible on resume.
        dropout_rngs = jax.random.split(jax.random.fold_in(rng, epoch), num_devices)
        pbar = tqdm(train_loader, desc="Training", leave=False)
        for images, labels in pbar:
            batch = prepare_image_batch(images, labels)
            state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
            # Metrics were pmean'd inside the step, so replica 0 already holds the global value.
            train_metrics = unreplicate(train_metrics)
            loss, acc1, acc5 = (float(train_metrics[k]) for k in ("loss", "acc1", "acc5"))
            curr_lr = float(lr_schedule(global_step))
            if global_step % args.wandb_logging_frequency == 0:
                wandb.log({"train/loss": loss, "train/acc1": acc1, "train/acc5": acc5, "lr": curr_lr}, step=global_step)
            pbar.set_postfix({"train/loss": f"{loss:.3f}", "train/acc1": f"{acc1:.3f}", "train/acc5": f"{acc5:.3f}", "lr": f"{curr_lr:.2e}"})
            global_step += 1

        # --- Validation over the whole set ---
        eval_sums = []
        for images, labels in tqdm(val_loader, desc="Evaluating", leave=False):
            batch = _shard_eval_batch(images, labels, num_devices)
            # psum makes every replica identical; keep replica 0.
            eval_sums.append(unreplicate(parallel_eval_step(state, batch)))
        eval_sums = jax.device_get(eval_sums)
        totals = {key: sum(float(s[key]) for s in eval_sums) for key in ("loss", "acc1", "acc5", "count")}
        num_val = totals["count"]

        host_state = unreplicate(state)
        if num_val > 0:
            val_loss = totals["loss"] / num_val
            val_acc1 = 100.0 * totals["acc1"] / num_val
            val_acc5 = 100.0 * totals["acc5"] / num_val
            wandb.log({"val/loss": val_loss, "val/acc1": val_acc1, "val/acc5": val_acc5}, step=global_step)
            # Save best model
            if best_val_acc1 <= val_acc1:
                best_val_acc1 = val_acc1
                checkpoints.save_checkpoint(ckpt_dir=save_path, target=host_state, step=global_step, prefix="best_", keep=1)
        else:
            print("warning: validation loader produced no samples; skipping evaluation")
        # Save last model
        checkpoints.save_checkpoint(ckpt_dir=save_path, target=host_state, step=global_step, prefix="last_", keep=1)
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this training script.

    Some options were originally spelled with underscores (``--rotary_value``,
    ``--q_lora_rank``, ..., ``--num_workers``). Each now also accepts the dash
    spelling used by every other option, and both spellings set the same attribute.
    """
    parser = argparse.ArgumentParser(description="Fine-tune ViT with MoE on Imagenet")
    # --- Model & Training Config ---
    parser.add_argument("--hidden-size", type=int, default=768, help="Dimensionality of the encoder layers and the pooler layer.")
    parser.add_argument("--num-hidden-layers", type=int, default=12, help="Number of hidden layers in the Transformer encoder.")
    parser.add_argument("--num-attention-heads", type=int, default=12, help="Number of attention heads for each attention layer in the Transformer encoder.")
    parser.add_argument("--intermediate-size", type=int, default=3072, help="Dimensionality of the intermediate (feed-forward) layer in the Transformer encoder.")
    parser.add_argument("--position-embeddings", type=str, default='sinusoidal', help="Position-embedding variant, stored on the model config as position_embeddings.")
    parser.add_argument("--rotary-value", "--rotary_value", action="store_true", help='Whether or not apply rotary position embeddings on value layer.')
    parser.add_argument("--num-shared-experts", type=int, default=1)
    parser.add_argument("--num-routed-experts", type=int, default=0)
    parser.add_argument("--topk", type=int, default=0)
    # NOTE(review): the five LoRA/head-dim options and --attention-bias below are not
    # read by main() as visible in this file; confirm whether they are still needed.
    parser.add_argument("--q-lora-rank", "--q_lora_rank", type=int, default=8, help='Rank of the LoRA adaptation for query projections.')
    parser.add_argument("--qk-rope-head-dim", "--qk_rope_head_dim", type=int, default=64, help='Head dimension used for RoPE on query/key.')
    parser.add_argument("--kv-lora-rank", "--kv_lora_rank", type=int, default=8, help='Rank of the LoRA adaptation for key/value projections.')
    parser.add_argument("--v-head-dim", "--v_head_dim", type=int, default=64, help='Head dimension used for value projections.')
    parser.add_argument("--qk-nope-head-dim", "--qk_nope_head_dim", type=int, default=64, help='Head dimension for NOPE (non-position encoding) on query/key.')
    parser.add_argument("--attention-bias", action="store_true", help='Use Bias in Attention.')
    parser.add_argument('--routed-scaling-factor', type=float, default=1.0, help='Stored on the model config as routed_scaling_factor.')
    parser.add_argument("--lmc-layer-indices", type=int, nargs="*", default=[], help="List of lmc layer indices (optional, default: empty list)")
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
    parser.add_argument("--adamw-beta1", type=float, default=0.9)
    parser.add_argument("--adamw-beta2", type=float, default=0.999)
    parser.add_argument("--adamw-eps", type=float, default=1e-8)
    parser.add_argument("--patience", type=int, default=10, help="Early stopping patience")
    parser.add_argument("--restore-checkpoint-path", type=str, help="if you want to restart from specific checkpoint, set this arg to checkpoint path")
    # --- Data Config ---
    parser.add_argument("--data-path", type=str, required=True)
    parser.add_argument("--save-dir", type=str, required=True)
    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'])
    parser.add_argument("--input-size", type=int, default=224)
    parser.add_argument("--num-workers", "--num_workers", type=int, default=10)
    parser.add_argument('--pin-mem', action='store_true')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--color-jitter', type=float, default=0.4)
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1')
    parser.add_argument('--train-interpolation', type=str, default='bicubic')
    parser.add_argument('--reprob', type=float, default=0.25)
    parser.add_argument('--remode', type=str, default='pixel')
    parser.add_argument('--recount', type=int, default=1)
    parser.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="bfloat16", help="model datatype")
    # --- Logging ---
    parser.add_argument("--wandb-entity", default=None, help="wandb entity for logging")
    parser.add_argument("--wandb-group", default=None, help="wandb group for logging")
    parser.add_argument("--wandb-project", default=None, help="wandb project name for logging")
    parser.add_argument("--wandb-id", default=None, help="wandb run id passed to wandb.init")
    parser.add_argument("--wandb-logging-frequency", type=int, default=100, help="do logging every logging_frequency step")
    return parser


if __name__ == "__main__":
    main(build_arg_parser().parse_args())