from typing import Any, Callable, Optional, Tuple, Type

import flax.linen as nn
import jax.numpy as jnp

# Loose aliases used only in annotations below; array, key and dtype values
# are duck-typed, so they are not constrained further here.
Array = Any
PRNGKey = Any
# A shape is a tuple of ints of any length, hence the variadic form
# (`Tuple[int]` would describe a 1-tuple only).
Shape = Tuple[int, ...]
Dtype = Any

class IdentityLayer(nn.Module):
  """No-op module that hands its input back unchanged.

  Wrapping an array in this layer lets it be referred to by the layer's name
  without altering its value.
  """

  @nn.compact
  def __call__(self, x):
    """Returns `x` as-is."""
    return x


class AddPositionEmbs(nn.Module):
  """Adds a learned positional embedding to a batch of sequences.

  Attributes:
    posemb_init: initializer for the positional embedding parameter.
    param_dtype: dtype handed to the initializer of the embedding parameter.
  """

  posemb_init: Callable[[PRNGKey, Shape, Dtype], Array]
  param_dtype: Dtype = jnp.float32

  @nn.compact
  def __call__(self, inputs):
    """Adds the `pos_embedding` parameter to `inputs`.

    Args:
      inputs: Rank-3 array, `(batch_size, seq_len, emb_dim)`.

    Returns:
      `inputs + pos_embedding`. The embedding has shape
      `(1, seq_len, emb_dim)`, so it is shared across the batch axis.
    """
    rank = inputs.ndim
    assert rank == 3, ('Number of dimensions should be 3,'
                       ' but it is: %d' % rank)
    # One embedding row per sequence position; leading axis of size 1.
    _, seq_len, emb_dim = inputs.shape
    pos_embedding = self.param(
        'pos_embedding',
        self.posemb_init,
        (1, seq_len, emb_dim),
        self.param_dtype,
    )
    return inputs + pos_embedding


class MlpBlock(nn.Module):
  """Feed-forward block: Dense -> GELU -> Dropout -> Dense -> Dropout.

  Attributes:
    mlp_dim: number of features of the first (hidden) Dense layer.
    dtype: computation dtype of both Dense layers.
    param_dtype: parameter dtype of both Dense layers.
    out_dim: number of output features; when None, the size of the last
      input axis is used.
    dropout_rate: rate shared by both Dropout layers.
    kernel_init: kernel initializer shared by both Dense layers.
    bias_init: bias initializer shared by both Dense layers.
  """

  mlp_dim: int
  dtype: Dtype = jnp.float32
  param_dtype: Dtype = jnp.float32
  out_dim: Optional[int] = None
  dropout_rate: float = 0.1
  kernel_init: Callable[[PRNGKey, Shape, Dtype],
                        Array] = nn.initializers.xavier_uniform()
  bias_init: Callable[[PRNGKey, Shape, Dtype],
                      Array] = nn.initializers.normal(stddev=1e-6)

  @nn.compact
  def __call__(self, inputs, *, deterministic):
    """Applies the feed-forward block.

    Args:
      inputs: Array whose last axis is the feature axis.
      deterministic: Forwarded to both Dropout layers.

    Returns:
      Array with `out_dim` features (or the input feature size if `out_dim`
      is None) on the last axis.
    """
    out_features = self.out_dim
    if out_features is None:
      out_features = inputs.shape[-1]

    # Settings common to the two Dense layers.
    dense_kwargs = dict(
        dtype=self.dtype,
        param_dtype=self.param_dtype,
        kernel_init=self.kernel_init,
        bias_init=self.bias_init,
    )

    # Submodules are created in the same order as before so that their
    # auto-generated names are unchanged.
    hidden = nn.Dense(features=self.mlp_dim, **dense_kwargs)(inputs)
    hidden = nn.gelu(hidden)
    hidden = nn.Dropout(rate=self.dropout_rate)(
        hidden, deterministic=deterministic)

    projected = nn.Dense(features=out_features, **dense_kwargs)(hidden)
    return nn.Dropout(rate=self.dropout_rate)(
        projected, deterministic=deterministic)


class Encoder1DBlock(nn.Module):
  """Pre-norm Transformer encoder layer: self-attention, then an MLP.

  Attributes:
    mlp_dim: hidden width of the `MlpBlock` that follows attention.
    num_heads: number of heads in nn.MultiHeadDotProductAttention.
    dtype: dtype of the computation (default: float32).
    dropout_rate: dropout rate after attention and inside the MLP.
    attention_dropout_rate: dropout rate passed to the attention module.
  """

  mlp_dim: int
  num_heads: int
  dtype: Dtype = jnp.float32
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1

  @nn.compact
  def __call__(self, inputs, *, deterministic):
    """Applies the encoder layer.

    Args:
      inputs: Rank-3 array, `(batch, seq, hidden)`.
      deterministic: When true, dropout is not applied.

    Returns:
      Array produced by the attention and MLP sub-layers, each added to its
      own input through a residual connection.
    """
    assert inputs.ndim == 3, f'Expected (batch, seq, hidden) got {inputs.shape}'

    # Self-attention sub-layer with residual connection.
    normed = nn.LayerNorm(dtype=self.dtype)(inputs)
    self_attention = nn.MultiHeadDotProductAttention(
        dtype=self.dtype,
        kernel_init=nn.initializers.xavier_uniform(),
        broadcast_dropout=False,
        deterministic=deterministic,
        dropout_rate=self.attention_dropout_rate,
        num_heads=self.num_heads)
    attended = self_attention(normed, normed)
    attended = nn.Dropout(rate=self.dropout_rate)(
        attended, deterministic=deterministic)
    residual = attended + inputs

    # Feed-forward sub-layer with residual connection.
    mlp_in = nn.LayerNorm(dtype=self.dtype)(residual)
    mlp = MlpBlock(
        mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate)
    mlp_out = mlp(mlp_in, deterministic=deterministic)

    return residual + mlp_out


class Encoder(nn.Module):
  """Stack of `Encoder1DBlock` layers followed by a final LayerNorm.

  Attributes:
    num_layers: number of encoder layers.
    mlp_dim: hidden width of the MLP inside each layer.
    num_heads: number of heads in nn.MultiHeadDotProductAttention.
    dropout_rate: dropout rate.
    attention_dropout_rate: dropout rate in self attention.
    add_position_embedding: whether to add learned position embeddings
      (followed by dropout) before the first layer.
  """

  num_layers: int
  mlp_dim: int
  num_heads: int
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  add_position_embedding: bool = True

  @nn.compact
  def __call__(self, x, *, train):
    """Runs the encoder stack.

    Args:
      x: Rank-3 array, `(batch, len, emb)`.
      train: `True` during training; dropout is active only then.

    Returns:
      The layer-normalized output of the last encoder layer.
    """
    assert x.ndim == 3  # (batch, len, emb)
    deterministic = not train

    if self.add_position_embedding:
      position_embedding = AddPositionEmbs(
          posemb_init=nn.initializers.normal(stddev=0.02),  # from BERT.
          name='posembed_input')
      x = position_embedding(x)
      x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)

    # Encoder layers, named by their index.
    for layer_idx in range(self.num_layers):
      layer = Encoder1DBlock(
          mlp_dim=self.mlp_dim,
          dropout_rate=self.dropout_rate,
          attention_dropout_rate=self.attention_dropout_rate,
          name=f'encoderblock_{layer_idx}',
          num_heads=self.num_heads)
      x = layer(x, deterministic=deterministic)

    return nn.LayerNorm(name='encoder_norm')(x)


class VisionTransformer(nn.Module):
  """Vision Transformer: patch embedding, optional encoder, pooling, head.

  Attributes:
    num_classes: output features of the `head` layer; a falsy value skips the
      head entirely.
    patches: object whose `.size` is used as both kernel size and stride of
      the patch-embedding convolution.
    transformer: keyword arguments for `encoder`, or None to skip the encoder.
    hidden_size: number of features produced by the patch embedding.
    resnet: not read by `__call__` in this class.
    representation_size: width of the optional `pre_logits` Dense layer.
    classifier: one of 'token', 'gap', 'unpooled' or 'token_unpooled'.
    head_bias_init: constant used to initialize the head bias.
    encoder: module class instantiated under the name `Transformer`.
    model_name: not read by `__call__` in this class.
  """

  num_classes: int
  patches: Any
  transformer: Any
  hidden_size: int
  resnet: Optional[Any] = None
  representation_size: Optional[int] = None
  classifier: str = 'token'
  head_bias_init: float = 0.
  encoder: Type[nn.Module] = Encoder
  model_name: Optional[str] = None

  @nn.compact
  def __call__(self, inputs, *, train):
    """Applies the model.

    Args:
      inputs: Rank-4 array; presumably `(batch, height, width, channels)`
        images -- TODO confirm against the input pipeline.
      train: Forwarded to the encoder.

    Returns:
      Head output when `num_classes` is truthy, otherwise the (pooled or
      unpooled) pre-logits representation.

    Raises:
      ValueError: if `classifier` is not one of the supported values.
    """
    # The unpacking only serves as a check that the input has four axes.
    _, _, _, _ = inputs.shape

    # We can merge s2d+emb into a single conv; it's the same.
    patch_size = self.patches.size
    x = nn.Conv(
        features=self.hidden_size,
        kernel_size=patch_size,
        strides=patch_size,
        padding='VALID',
        name='embedding')(inputs)

    # Here, x is a grid of embeddings.

    # (Possibly partial) Transformer.
    if self.transformer is not None:
      batch, grid_h, grid_w, channels = x.shape
      x = jnp.reshape(x, [batch, grid_h * grid_w, channels])

      # Prepend a learned class token for the token-based classifiers.
      if self.classifier in ('token', 'token_unpooled'):
        cls = self.param('cls', nn.initializers.zeros, (1, 1, channels))
        x = jnp.concatenate([jnp.tile(cls, [batch, 1, 1]), x], axis=1)

      x = self.encoder(name='Transformer', **self.transformer)(x, train=train)

    # Pooling.
    pooling = self.classifier
    if pooling == 'token':
      x = x[:, 0]
    elif pooling == 'gap':
      pooled_axes = list(range(1, x.ndim - 1))  # (1,) or (1,2)
      x = jnp.mean(x, axis=pooled_axes)
    elif pooling not in ('unpooled', 'token_unpooled'):
      raise ValueError(f'Invalid classifier={self.classifier}')

    # Optional pre-logits projection; otherwise just name the activation.
    if self.representation_size is None:
      x = IdentityLayer(name='pre_logits')(x)
    else:
      x = nn.Dense(features=self.representation_size, name='pre_logits')(x)
      x = nn.tanh(x)

    if not self.num_classes:
      return x

    head = nn.Dense(
        features=self.num_classes,
        name='head',
        kernel_init=nn.initializers.zeros,
        bias_init=nn.initializers.constant(self.head_bias_init))
    return head(x)


class MoEBlock(nn.Module):
  """Mixture of Experts block with a softmax gate over all experts.

  Every expert is evaluated on every input; the expert outputs are combined
  as a weighted sum using the softmax of the gate logits.

  Attributes:
    num_experts: number of experts (and of gate logits).
    expert_hidden_dim: features of the first Dense layer of each expert.
    dtype: computation dtype of all Dense layers in this block.
    param_dtype: parameter dtype of all Dense layers in this block.
    dropout_rate: rate of the two Dropout layers inside each expert.
    kernel_init: kernel initializer shared by the gate and the experts.
    bias_init: bias initializer shared by the gate and the experts.
  """
  num_experts: int
  expert_hidden_dim: int
  dtype: Dtype = jnp.float32
  param_dtype: Dtype = jnp.float32
  dropout_rate: float = 0.1
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.xavier_uniform()
  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.normal(stddev=1e-6)

  @nn.compact
  def __call__(self, inputs, *, deterministic, return_moe_outputs=False):
    """Applies the gate and all experts and mixes the expert outputs.

    Args:
      inputs: Array whose last axis is the feature axis. The indexing and
        stacking below assume rank 3, `(batch, seq_len, hidden_dim)`.
      deterministic: Passed to every Dropout layer at construction time.
      return_moe_outputs: When true, also return the gate logits and the
        stacked per-expert outputs.

    Returns:
      The mixed output, or the tuple `(output, gate_logits, expert_outputs)`
      when `return_moe_outputs` is true. `gate_logits` are the values before
      the softmax.
    """
    # Each expert maps back to the input feature size so outputs can be mixed.
    hidden_dim = inputs.shape[-1]
    # Gating network
    gate_logits = nn.Dense(
        features=self.num_experts,
        dtype=self.dtype,
        param_dtype=self.param_dtype,
        kernel_init=self.kernel_init,
        bias_init=self.bias_init,
        name='gate'
    )(inputs)  # Shape: (batch, seq_len, num_experts)
    # Normalize over the expert axis to get per-expert mixing weights.
    gates = nn.softmax(gate_logits, axis=-1)

    # Define experts: Dense -> GELU -> Dropout -> Dense -> Dropout.
    # The Dense layers carry explicit names; `get_moe_params` in this file
    # reads `expert_{i}_dense1` / `expert_{i}_dense2` (and `gate`) as direct
    # children of this block's parameters.
    # NOTE(review): that relies on Flax binding the named Dense layers to this
    # module rather than to the Sequential wrapper -- confirm against the
    # installed Flax version before restructuring this loop.
    experts = []
    for i in range(self.num_experts):
      expert = nn.Sequential([
          nn.Dense(
              features=self.expert_hidden_dim,
              dtype=self.dtype,
              param_dtype=self.param_dtype,
              kernel_init=self.kernel_init,
              bias_init=self.bias_init,
              name=f'expert_{i}_dense1'
          ),
          nn.gelu,
          nn.Dropout(rate=self.dropout_rate, deterministic=deterministic),
          nn.Dense(
              features=hidden_dim,
              dtype=self.dtype,
              param_dtype=self.param_dtype,
              kernel_init=self.kernel_init,
              bias_init=self.bias_init,
              name=f'expert_{i}_dense2'
          ),
          nn.Dropout(rate=self.dropout_rate, deterministic=deterministic)
      ])
      experts.append(expert)

    # Apply experts: every expert sees the full, unrouted input.
    expert_outputs = [expert(inputs) for expert in experts]
    expert_outputs = jnp.stack(expert_outputs, axis=2)  # Shape: (batch, seq_len, num_experts, hidden_dim)

    # Weighted sum: broadcast the gates over the feature axis, then reduce
    # over the expert axis.
    weighted_outputs = gates[:, :, :, None] * expert_outputs
    output = jnp.sum(weighted_outputs, axis=2)  # Shape: (batch, seq_len, hidden_dim)

    if return_moe_outputs:
      return output, gate_logits, expert_outputs
    return output


class Encoder1DBlockMoE(nn.Module):
  """Transformer encoder layer whose feed-forward part is a `MoEBlock`.

  Attributes:
    mlp_dim: not read by `__call__` in this class.
    num_heads: number of heads in nn.MultiHeadDotProductAttention.
    num_experts: number of experts in the MoE block.
    expert_hidden_dim: hidden width of each expert.
    dtype: dtype of the computation (default: float32).
    dropout_rate: dropout rate after attention and inside the experts.
    attention_dropout_rate: dropout rate passed to the attention module.
  """
  mlp_dim: int
  num_heads: int
  num_experts: int
  expert_hidden_dim: int
  # Optional parameters (with defaults)
  dtype: Dtype = jnp.float32
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1

  @nn.compact
  def __call__(self, inputs, *, deterministic, return_moe_outputs=False):
    """Applies self-attention followed by the MoE block.

    Args:
      inputs: Input array of the layer.
      deterministic: When true, dropout is not applied.
      return_moe_outputs: When true, also return what the MoE block reports.

    Returns:
      The layer output, or `(output, gate_logits, expert_outputs)` when
      `return_moe_outputs` is true.
    """
    # Self-attention sub-layer with residual connection.
    normed = nn.LayerNorm(dtype=self.dtype)(inputs)
    self_attention = nn.MultiHeadDotProductAttention(
        dtype=self.dtype,
        kernel_init=nn.initializers.xavier_uniform(),
        broadcast_dropout=False,
        deterministic=deterministic,
        dropout_rate=self.attention_dropout_rate,
        num_heads=self.num_heads)
    attended = self_attention(normed, normed)
    attended = nn.Dropout(rate=self.dropout_rate)(
        attended, deterministic=deterministic)
    residual = attended + inputs

    # Mixture-of-experts sub-layer with residual connection.
    moe_in = nn.LayerNorm(dtype=self.dtype)(residual)
    moe = MoEBlock(
        num_experts=self.num_experts,
        expert_hidden_dim=self.expert_hidden_dim,
        dtype=self.dtype,
        dropout_rate=self.dropout_rate,
        name='MoEBlock'
    )

    if not return_moe_outputs:
      return residual + moe(moe_in, deterministic=deterministic)

    mixed, gate_logits, expert_outputs = moe(
        moe_in, deterministic=deterministic, return_moe_outputs=True)
    return residual + mixed, gate_logits, expert_outputs


class EncoderMoE(nn.Module):
  """Transformer encoder in which one layer uses a mixture-of-experts MLP.

  Layer `moe_layer_which` (0-based) is an `Encoder1DBlockMoE`; every other
  layer is a regular `Encoder1DBlock`. When `moe_layer_which` lies outside
  `[0, num_layers)` no MoE layer is created and the module behaves like a
  plain encoder.

  Attributes:
    num_layers: number of encoder layers.
    mlp_dim: hidden width of the MLP in the regular layers.
    num_heads: number of heads in nn.MultiHeadDotProductAttention.
    moe_layer_which: index of the layer that is replaced by the MoE variant.
    num_experts: number of experts in the MoE layer.
    expert_hidden_dim: hidden width of each expert.
    dropout_rate: dropout rate.
    attention_dropout_rate: dropout rate in self attention.
    add_position_embedding: whether to add learned position embeddings
      (followed by dropout) before the first layer.
  """
  num_layers: int
  mlp_dim: int
  num_heads: int
  moe_layer_which: int
  num_experts: int
  expert_hidden_dim: int
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  add_position_embedding: bool = True

  @nn.compact
  def __call__(self, x, *, train, return_moe_outputs=False):
    """Runs the encoder stack.

    Args:
      x: Rank-3 array, `(batch, len, emb)`.
      train: `True` during training; dropout is active only then.
      return_moe_outputs: When true, also return the gate logits and expert
        outputs of the MoE layer.

    Returns:
      The layer-normalized encoder output, or the tuple
      `(encoded, gate_logits, expert_outputs)` when `return_moe_outputs` is
      true.

    Raises:
      ValueError: if `return_moe_outputs` is true but `moe_layer_which` does
        not select any of the `num_layers` layers, so there is nothing to
        return.
    """
    assert x.ndim == 3  # (batch, len, emb)

    # Without this check the MoE outputs would simply never be assigned and
    # the final return would fail with an opaque UnboundLocalError.
    has_moe_layer = 0 <= self.moe_layer_which < self.num_layers
    if return_moe_outputs and not has_moe_layer:
      raise ValueError(
          'return_moe_outputs=True requires 0 <= moe_layer_which < num_layers,'
          f' got moe_layer_which={self.moe_layer_which} and'
          f' num_layers={self.num_layers}')

    deterministic = not train

    if self.add_position_embedding:
      x = AddPositionEmbs(
          posemb_init=nn.initializers.normal(stddev=0.02),
          name='posembed_input')(x)
      x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)

    gate_logits = None
    expert_outputs = None

    # Encoder blocks. Both layer kinds share the `encoderblock_{i}` naming.
    for lyr in range(self.num_layers):
      layer_name = f'encoderblock_{lyr}'
      if lyr != self.moe_layer_which:
        x = Encoder1DBlock(
            mlp_dim=self.mlp_dim,
            dropout_rate=self.dropout_rate,
            attention_dropout_rate=self.attention_dropout_rate,
            name=layer_name,
            num_heads=self.num_heads)(x, deterministic=deterministic)
        continue

      moe_layer = Encoder1DBlockMoE(
          mlp_dim=self.mlp_dim,
          num_heads=self.num_heads,
          dropout_rate=self.dropout_rate,
          attention_dropout_rate=self.attention_dropout_rate,
          num_experts=self.num_experts,
          expert_hidden_dim=self.expert_hidden_dim,
          name=layer_name)
      if return_moe_outputs:
        x, gate_logits, expert_outputs = moe_layer(
            x, deterministic=deterministic, return_moe_outputs=True)
      else:
        x = moe_layer(x, deterministic=deterministic)

    encoded = nn.LayerNorm(name='encoder_norm')(x)
    if return_moe_outputs:
      return encoded, gate_logits, expert_outputs
    return encoded

class VisionTransformerMoE(VisionTransformer):
  """VisionTransformer whose encoder defaults to `EncoderMoE`.

  Unlike the base class, `__call__` always runs the encoder (so `transformer`
  must not be None) and can additionally return the MoE gate logits and
  expert outputs.
  """
  encoder: Type[nn.Module] = EncoderMoE

  @nn.compact
  def __call__(self, inputs, *, train, return_moe_outputs=False):
    """Applies the model.

    Args:
      inputs: Rank-4 array; presumably `(batch, height, width, channels)`
        images -- TODO confirm against the input pipeline.
      train: Forwarded to the encoder.
      return_moe_outputs: When true, also return the encoder's gate logits
        and expert outputs.

    Returns:
      The model output, or `(output, gate_logits, expert_outputs)` when
      `return_moe_outputs` is true.

    Raises:
      ValueError: if `classifier` is not one of the supported values.
    """
    # The unpacking only serves as a check that the input has four axes.
    _, _, _, _ = inputs.shape

    patch_size = self.patches.size
    x = nn.Conv(
        features=self.hidden_size,
        kernel_size=patch_size,
        strides=patch_size,
        padding='VALID',
        name='embedding')(inputs)

    # Flatten the grid of patch embeddings into a sequence.
    batch, grid_h, grid_w, channels = x.shape
    x = jnp.reshape(x, [batch, grid_h * grid_w, channels])

    # Prepend a learned class token for the token-based classifiers.
    if self.classifier in ('token', 'token_unpooled'):
      cls = self.param('cls', nn.initializers.zeros, (1, 1, channels))
      x = jnp.concatenate([jnp.tile(cls, [batch, 1, 1]), x], axis=1)

    encoder_output = self.encoder(name='Transformer', **self.transformer)(
        x, train=train, return_moe_outputs=return_moe_outputs)
    if return_moe_outputs:
      encoded, gate_logits, expert_outputs = encoder_output
    else:
      encoded = encoder_output

    # Pooling.
    if self.classifier == 'token':
      cls_output = encoded[:, 0]
    elif self.classifier == 'gap':
      cls_output = jnp.mean(encoded, axis=list(range(1, encoded.ndim - 1)))
    elif self.classifier in ('unpooled', 'token_unpooled'):
      cls_output = encoded
    else:
      raise ValueError(f'Invalid classifier={self.classifier}')

    # Optional pre-logits projection; otherwise just name the activation.
    if self.representation_size is not None:
      cls_output = nn.Dense(
          features=self.representation_size, name='pre_logits')(cls_output)
      cls_output = nn.tanh(cls_output)
    else:
      cls_output = IdentityLayer(name='pre_logits')(cls_output)

    if self.num_classes:
      output = nn.Dense(
          features=self.num_classes,
          name='head',
          kernel_init=nn.initializers.zeros,
          bias_init=nn.initializers.constant(self.head_bias_init))(cls_output)
    else:
      output = cls_output

    if return_moe_outputs:
      return output, gate_logits, expert_outputs
    return output

  def get_moe_outputs(self, params, x):
    """Returns gate logits and expert outputs from the MoE layer.

    Args:
      params: Parameter tree of this model (the value stored under 'params').
      x: Model inputs, as accepted by `__call__`.

    Returns:
      Tuple `(gate_logits, expert_outputs)` from a forward pass with
      `train=False`.
    """
    _, gate_logits, expert_outputs = self.apply(
        {'params': params}, x, train=False, return_moe_outputs=True)
    return gate_logits, expert_outputs

  def get_moe_params(self, params):
    """Returns a dictionary of MoE gating and expert parameters as NumPy arrays.

    Args:
      params: Parameter tree of this model (the value stored under 'params').

    Returns:
      Dict with keys `gating_kernel`, `gating_bias` and, for every expert `i`,
      `expert_{i}_layer{1,2}_{kernel,bias}`.

    Raises:
      KeyError: if `params` or `transformer` lacks one of the expected keys.
    """
    # Imported here so the method does not depend on a module-level import
    # that appears further down the file, after this class.
    import numpy as np

    moe_layer_which = self.transformer['moe_layer_which']
    num_experts = self.transformer['num_experts']
    moe_block_params = (
        params['Transformer'][f'encoderblock_{moe_layer_which}']['MoEBlock'])

    def as_numpy(module_name, leaf):
      return np.array(moe_block_params[module_name][leaf])

    dict_out = {
        # Gating parameters
        'gating_kernel': as_numpy('gate', 'kernel'),
        'gating_bias': as_numpy('gate', 'bias'),
    }
    # Expert parameters: two Dense layers per expert, kernel then bias.
    for i in range(num_experts):
      for layer in (1, 2):
        for leaf in ('kernel', 'bias'):
          dict_out[f'expert_{i}_layer{layer}_{leaf}'] = as_numpy(
              f'expert_{i}_dense{layer}', leaf)
    return dict_out


import os
import heapq
from absl import logging
import flax
import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt
import numpy as np
import optax
import tqdm
import ml_collections
import argparse
import jax.tree_util as tree_util
import flax.jax_utils as jax_utils
from flax import traverse_util

from vit_jax import checkpoint, input_pipeline, utils, models, train
from vit_jax.configs import common as common_config
from vit_jax.configs import models as models_config

# Helpers for the fine-tuning entry point below.
def _copy_matching_params(target, source, path=()):
    """Copy leaves of `source` into `target` wherever a key exists in both.

    Both arguments are nested plain dicts; `target` is modified in place.
    Keys that exist only in `source` are ignored.

    Raises:
        ValueError: if a shared key is a subtree on one side and a leaf on the
            other, or if two leaves have different shapes. Failing here names
            the offending parameter instead of failing later inside the model.
    """
    for key, src_value in source.items():
        if key not in target:
            continue
        dst_value = target[key]
        name = '/'.join(path + (str(key),))
        src_is_tree = isinstance(src_value, dict)
        dst_is_tree = isinstance(dst_value, dict)
        if src_is_tree and dst_is_tree:
            _copy_matching_params(dst_value, src_value, path + (str(key),))
        elif src_is_tree or dst_is_tree:
            raise ValueError(
                f"Parameter structure mismatch at '{name}': checkpoint and "
                "model disagree on whether this is a subtree or an array.")
        elif np.shape(src_value) != np.shape(dst_value):
            raise ValueError(
                f"Parameter shape mismatch at '{name}': checkpoint has "
                f"{np.shape(src_value)}, model expects {np.shape(dst_value)}.")
        else:
            target[key] = src_value


def _key_path_str(path):
    """Join the `.key` of every entry of a pytree key path with '/'."""
    return '/'.join(str(p.key) for p in path)


def _moe_label(path, _leaf):
    """Label a parameter 'trainable' if its path mentions 'MoE', else 'frozen'."""
    return 'trainable' if 'MoE' in _key_path_str(path) else 'frozen'


def _save_curves(curves, ylabel, out_path):
    """Plot `(label, values)` pairs against the epoch index and save the figure."""
    plt.figure()
    for label, values in curves:
        plt.plot(values, label=label)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(out_path)
    plt.close()


# Main function for fine-tuning
def main():
    """Fine-tune the MoE parameters of a ViT-S/16 on CIFAR-10.

    Loads a checkpoint, copies every parameter that also exists in the MoE
    model, trains only parameters whose path contains 'MoE' (all others get
    zero updates), writes one `.npz` checkpoint per epoch and saves loss,
    accuracy and learning-rate plots into `--ckpt-path`.

    Dataset, batch size, number of epochs and optimizer settings are fixed in
    the function body; see the command-line flags for what is configurable.

    Raises:
        FileNotFoundError: if `--model-path` does not exist.
        ValueError: if a checkpoint parameter does not fit the model, or the
            test split yields no batches.
    """
    # Parse command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, required=True, help="Random seed for fine-tuning")
    parser.add_argument("--model-path", type=str, required=True, help="Path to pre-trained ViT model checkpoint")
    parser.add_argument("--learning-rate", type=float, required=True, help="Learning rate")
    parser.add_argument("--num-layers", type=int, default=2, help="Number of transformer encoder layers")
    parser.add_argument("--moe-layer-which", type=int, default=0, help="which mlp (start from 0) layer is replaced by moe")
    parser.add_argument("--num-experts", type=int, default=2, help="Number of experts in MoE block")
    parser.add_argument("--ckpt-path", type=str, default="/", help="Path to ckpt directory")
    args = parser.parse_args()

    # Validate before touching the file system or loading any data.
    if not 0 <= args.moe_layer_which < args.num_layers:
        parser.error(
            f"--moe-layer-which must be in [0, {args.num_layers}), "
            f"got {args.moe_layer_which}")
    if args.num_experts < 1:
        parser.error(f"--num-experts must be at least 1, got {args.num_experts}")
    model_path = args.model_path
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model path {model_path} does not exist.")

    # Configuration
    ckpt_path = args.ckpt_path
    os.makedirs(ckpt_path, exist_ok=True)

    dataset = 'cifar10'
    batch_size = 512
    num_epochs = 15

    # Load dataset
    data_config = common_config.with_dataset(common_config.get_config(), dataset)
    data_config.batch = batch_size
    data_config.pp.crop = 224
    ds_train = input_pipeline.get_data_from_tfds(config=data_config, mode='train')
    ds_test = input_pipeline.get_data_from_tfds(config=data_config, mode='test')
    dataset_info = input_pipeline.get_dataset_info(dataset, 'train')
    num_classes = dataset_info['num_classes']
    num_train_examples = dataset_info['num_examples']
    steps_per_epoch = num_train_examples // batch_size
    total_steps = num_epochs * steps_per_epoch

    # Model configurations
    model_config = ml_collections.ConfigDict()
    model_config.model_name = 'ViT-S_16'
    model_config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    model_config.hidden_size = 384
    model_config.transformer = ml_collections.ConfigDict()
    model_config.transformer.mlp_dim = 1536
    model_config.transformer.num_heads = 6
    model_config.transformer.num_layers = args.num_layers
    model_config.transformer.attention_dropout_rate = 0.0
    model_config.transformer.dropout_rate = 0.0
    model_config.classifier = 'token'
    model_config.representation_size = None

    model_config_moe = ml_collections.ConfigDict(model_config.to_dict())
    model_config_moe.transformer.moe_layer_which = args.moe_layer_which
    model_config_moe.transformer.num_experts = args.num_experts
    model_config_moe.transformer.expert_hidden_dim = model_config.transformer.mlp_dim

    # Load pre-trained model. Unfreezing yields plain nested dicts regardless
    # of whether the loader returns a FrozenDict or a dict.
    pretrained_params = flax.core.unfreeze(checkpoint.load(model_path))

    # Initialize fine-tuning model
    base_rng = jax.random.PRNGKey(args.seed)
    finetune_model = VisionTransformerMoE(num_classes=num_classes, **model_config_moe)
    init_variables = finetune_model.init(base_rng, jnp.ones((1, 224, 224, 3)), train=False)
    # The copy below assigns into the tree, so it must be a mutable dict.
    new_params = flax.core.unfreeze(init_variables['params'])

    # Copy pre-trained parameters to fine-tuning model
    _copy_matching_params(new_params, pretrained_params)

    # Define labels for trainable (MoE) and frozen parameters
    labels = jax.tree_util.tree_map_with_path(_moe_label, new_params)

    # Print the trainable parts
    labelled_leaves, _ = tree_util.tree_flatten_with_path(labels)
    trainable_paths = [
        _key_path_str(path)
        for path, label in labelled_leaves
        if label == 'trainable'
    ]
    print("Trainable parts of the model:")
    for path in trainable_paths:
        print(f"  {path}")

    # Optimizer setup
    warmup_steps = 5
    decay_type = 'cosine'
    grad_norm_clip = 1
    accum_steps = 8
    base_lr = args.learning_rate
    lr_fn = utils.create_learning_rate_schedule(total_steps, base_lr, decay_type, warmup_steps)

    base_tx = optax.chain(
        optax.clip_by_global_norm(grad_norm_clip),
        optax.sgd(learning_rate=lr_fn, momentum=0.9)
    )
    # Frozen parameters receive zero updates; only MoE parameters are trained.
    tx = optax.multi_transform(
        {'trainable': base_tx, 'frozen': optax.set_to_zero()},
        labels
    )

    # Replicate parameters and optimizer state
    params_repl = jax_utils.replicate(new_params)
    opt_state_repl = jax_utils.replicate(tx.init(new_params))
    # Derive the update RNG from --seed (distinct from the init key) so the
    # seed controls all randomness of the run.
    update_rng_repl = jax_utils.replicate(jax.random.fold_in(base_rng, 1))

    # Update functions
    vit_apply_repl = jax.pmap(lambda params, inputs: finetune_model.apply(
        dict(params=params), inputs, train=False))
    update_fn_repl = train.make_update_fn(apply_fn=finetune_model.apply, accum_steps=accum_steps, tx=tx)

    # Metrics storage (one entry per epoch)
    train_losses = []
    train_accuracies = []
    test_losses = []
    test_accuracies = []
    learning_rates = []

    # Test metrics computation
    @jax.pmap
    def compute_batch_metrics(params, batch):
        logits = finetune_model.apply({'params': params}, batch['image'], train=False)
        # Labels are used as one-hot weights for the log-probabilities.
        loss = -jnp.sum(jax.nn.log_softmax(logits) * batch['label']) / batch['label'].shape[0]
        accuracy = (logits.argmax(axis=-1) == batch['label'].argmax(axis=-1)).mean()
        return loss, accuracy

    def evaluate_test(params_repl):
        """Return (mean loss, mean accuracy) over all test batches."""
        test_loss_sum = 0.0
        test_accuracy_sum = 0.0
        num_batches = 0
        for batch in ds_test.as_numpy_iterator():
            loss, accuracy = compute_batch_metrics(params_repl, batch)
            test_loss_sum += jnp.mean(loss)
            test_accuracy_sum += jnp.mean(accuracy)
            num_batches += 1
        if num_batches == 0:
            raise ValueError("The test dataset yielded no batches.")
        return test_loss_sum / num_batches, test_accuracy_sum / num_batches

    # Initial evaluation
    test_loss, test_accuracy = evaluate_test(params_repl)
    print(f'Start: testloss_{test_loss:.4f}_testacc_{test_accuracy:.4f}')

    # Training loop
    train_iter = iter(ds_train.as_numpy_iterator())
    global_step = 0
    for epoch in range(num_epochs):
        train_loss_sum = 0.0
        train_correct_sum = 0.0
        for _ in range(steps_per_epoch):
            batch = next(train_iter)
            params_repl, opt_state_repl, loss_repl, update_rng_repl = update_fn_repl(
                params_repl, opt_state_repl, batch, update_rng_repl
            )
            train_loss_sum += jnp.mean(loss_repl)
            # Training accuracy is measured with the freshly updated parameters.
            predicted = vit_apply_repl(params_repl, batch['image'])
            is_correct = (predicted.argmax(axis=-1) == batch['label'].argmax(axis=-1))
            train_correct_sum += jnp.mean(is_correct)
            global_step += 1

        # Compute epoch metrics
        train_loss = train_loss_sum / steps_per_epoch
        train_accuracy = train_correct_sum / steps_per_epoch
        train_losses.append(float(train_loss))
        train_accuracies.append(float(train_accuracy))

        # Evaluate on test set
        test_loss, test_accuracy = evaluate_test(params_repl)
        test_losses.append(float(test_loss))
        test_accuracies.append(float(test_accuracy))
        print(f"Epoch: {epoch} train_loss: {train_loss:.4f}, train_accuracy: {train_accuracy:.4f}, test_loss: {test_loss:.4f}, test_accuracy: {test_accuracy:.4f}")

        # Learning rate at the end of the epoch
        learning_rates.append(float(lr_fn(global_step)))

        # Save a checkpoint for every epoch; the metrics are part of the name.
        metrics_str = f"trainloss_{train_loss:.4f}_trainacc_{train_accuracy:.4f}_testloss_{test_loss:.4f}_testacc_{test_accuracy:.4f}"
        weights_file = f"{ckpt_path}/cifar10_vit_epoch{epoch}_{metrics_str}_seed_{args.seed}_lr_{args.learning_rate}_moe_layer_which_{args.moe_layer_which}_num_experts_{args.num_experts}.npz"
        with open(weights_file, "wb") as f:
            np.savez(f, **traverse_util.flatten_dict(jax_utils.unreplicate(params_repl), sep='/'))

    # Save plots
    _save_curves(
        [('Train Loss', train_losses), ('Test Loss', test_losses)],
        'Loss', f'{ckpt_path}/loss_plot.png')
    _save_curves(
        [('Train Accuracy', train_accuracies), ('Test Accuracy', test_accuracies)],
        'Accuracy', f'{ckpt_path}/accuracy_plot.png')
    _save_curves(
        [('Learning Rate', learning_rates)],
        'Learning Rate', f'{ckpt_path}/lr_plot.png')

    print("Training completed. Plots and per-epoch checkpoints saved.")

# Run the fine-tuning entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()