from typing import Any, Callable, Optional, Tuple, Type

import flax.linen as nn
import jax.numpy as jnp

# Loose type aliases used in the annotations of the modules below.
Array = Any
PRNGKey = Any
# A shape is a tuple of ints of any rank; `Tuple[int]` would denote a 1-tuple.
Shape = Tuple[int, ...]
Dtype = Any

class IdentityLayer(nn.Module):
  """Identity layer, convenient for giving a name to an array.

  The module declares no parameters; calling it returns its argument as is,
  so it can be used to route a value through a named submodule.
  """

  @nn.compact
  def __call__(self, x):
    """Returns `x` unchanged."""
    return x


class AddPositionEmbs(nn.Module):
  """Adds a learned positional embedding to a sequence of embeddings.

  Attributes:
    posemb_init: initializer of the positional embedding table.
    param_dtype: dtype handed to `posemb_init` when the table is created.
  """

  posemb_init: Callable[[PRNGKey, Shape, Dtype], Array]
  param_dtype: Dtype = jnp.float32

  @nn.compact
  def __call__(self, inputs):
    """Adds the `pos_embedding` parameter to `inputs`.

    Args:
      inputs: rank-3 array laid out as `(batch_size, seq_len, emb_dim)`.

    Returns:
      `inputs + pos_embedding`; the table has shape `(1, seq_len, emb_dim)`
      and is broadcast over the batch axis.
    """
    rank = inputs.ndim
    assert rank == 3, 'Number of dimensions should be 3, but it is: %d' % rank
    # One row per position; the leading 1 broadcasts over the batch.
    _, seq_len, emb_dim = inputs.shape
    table = self.param(
        'pos_embedding',
        self.posemb_init,
        (1, seq_len, emb_dim),
        self.param_dtype)
    return inputs + table


class MlpBlock(nn.Module):
  """Feed-forward block: Dense -> GELU -> Dropout -> Dense -> Dropout.

  Attributes:
    mlp_dim: number of features of the first (hidden) Dense layer.
    dtype: dtype passed to both Dense layers as `dtype`.
    param_dtype: dtype passed to both Dense layers as `param_dtype`.
    out_dim: number of output features; when None the size of the last input
      axis is used.
    dropout_rate: rate of both Dropout layers.
    kernel_init: kernel initializer of both Dense layers.
    bias_init: bias initializer of both Dense layers.
  """

  mlp_dim: int
  dtype: Dtype = jnp.float32
  param_dtype: Dtype = jnp.float32
  out_dim: Optional[int] = None
  dropout_rate: float = 0.1
  kernel_init: Callable[[PRNGKey, Shape, Dtype],
                        Array] = nn.initializers.xavier_uniform()
  bias_init: Callable[[PRNGKey, Shape, Dtype],
                      Array] = nn.initializers.normal(stddev=1e-6)

  @nn.compact
  def __call__(self, inputs, *, deterministic):
    """Applies the block to `inputs`.

    Args:
      inputs: array whose last axis is the feature axis.
      deterministic: forwarded to both Dropout layers.

    Returns:
      The output of the second Dropout layer.
    """
    out_features = self.out_dim
    if out_features is None:
      out_features = inputs.shape[-1]

    # Settings shared by the two projections.
    dense_kwargs = dict(
        dtype=self.dtype,
        param_dtype=self.param_dtype,
        kernel_init=self.kernel_init,
        bias_init=self.bias_init)  # pytype: disable=wrong-arg-types

    hidden = nn.Dense(features=self.mlp_dim, **dense_kwargs)(inputs)
    hidden = nn.gelu(hidden)
    hidden = nn.Dropout(rate=self.dropout_rate)(
        hidden, deterministic=deterministic)

    projected = nn.Dense(features=out_features, **dense_kwargs)(hidden)
    return nn.Dropout(rate=self.dropout_rate)(
        projected, deterministic=deterministic)


class Encoder1DBlock(nn.Module):
  """Transformer encoder layer: self-attention and MLP, each with a residual.

  LayerNorm is applied before each of the two sub-blocks.

  Attributes:
    mlp_dim: dimension of the mlp on top of attention block.
    num_heads: Number of heads in nn.MultiHeadDotProductAttention.
    dtype: the dtype of the computation (default: float32).
    dropout_rate: rate of the dropout after attention and inside the mlp.
    attention_dropout_rate: dropout rate passed to the attention module.
    attention_fn: Attention function to use in MultiHeadDotProductAttention.
  """

  mlp_dim: int
  num_heads: int
  dtype: Dtype = jnp.float32
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  attention_fn: Callable = nn.dot_product_attention

  @nn.compact
  def __call__(self, inputs, *, deterministic):
    """Applies the encoder layer.

    Args:
      inputs: rank-3 array, `(batch, seq, hidden)`.
      deterministic: Dropout will not be applied when set to true.

    Returns:
      Output of the layer: the attention residual plus the MLP output.
    """
    assert inputs.ndim == 3, f'Expected (batch, seq, hidden) got {inputs.shape}'

    # --- Self-attention sub-block ---
    normed = nn.LayerNorm(dtype=self.dtype)(inputs)
    # Store the attention input in the 'intermediates' collection so it can
    # be read back after `apply(..., mutable=['intermediates'])`.
    self.sow('intermediates', 'mha_input', normed)
    attention = nn.MultiHeadDotProductAttention(
        num_heads=self.num_heads,
        dtype=self.dtype,
        kernel_init=nn.initializers.xavier_uniform(),
        broadcast_dropout=False,
        deterministic=deterministic,
        dropout_rate=self.attention_dropout_rate,
        attention_fn=self.attention_fn)
    attended = attention(normed, normed)
    attended = nn.Dropout(rate=self.dropout_rate)(
        attended, deterministic=deterministic)
    residual = attended + inputs

    # --- MLP sub-block ---
    mlp_in = nn.LayerNorm(dtype=self.dtype)(residual)
    mlp_out = MlpBlock(
        mlp_dim=self.mlp_dim,
        dtype=self.dtype,
        dropout_rate=self.dropout_rate)(
            mlp_in, deterministic=deterministic)

    return residual + mlp_out


class Encoder(nn.Module):
  """Stack of `Encoder1DBlock` layers followed by a final LayerNorm.

  Attributes:
    num_layers: number of layers
    mlp_dim: dimension of the mlp on top of attention block
    num_heads: Number of heads in nn.MultiHeadDotProductAttention
    dropout_rate: dropout rate.
    attention_dropout_rate: dropout rate in self attention.
    add_position_embedding: Whether to add learned positional embeddings
      (followed by dropout) before the first layer.
    attention_fn: Attention function to use in Encoder1DBlock.
  """

  num_layers: int
  mlp_dim: int
  num_heads: int
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  add_position_embedding: bool = True
  attention_fn: Callable = nn.dot_product_attention

  @nn.compact
  def __call__(self, x, *, train):
    """Encodes a sequence of embeddings.

    Args:
      x: rank-3 array, `(batch, len, emb)`.
      train: Set to `True` when training; dropout is deterministic otherwise.

    Returns:
      Output of the 'encoder_norm' LayerNorm applied after the last layer.
    """
    assert x.ndim == 3  # (batch, len, emb)
    deterministic = not train

    if self.add_position_embedding:
      pos_embed = AddPositionEmbs(
          posemb_init=nn.initializers.normal(stddev=0.02),  # from BERT.
          name='posembed_input')
      x = pos_embed(x)
      x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)

    # Settings shared by every encoder layer; only the name differs.
    block_kwargs = dict(
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout_rate=self.dropout_rate,
        attention_dropout_rate=self.attention_dropout_rate,
        attention_fn=self.attention_fn)
    for layer_idx in range(self.num_layers):
      block = Encoder1DBlock(name=f'encoderblock_{layer_idx}', **block_kwargs)
      x = block(x, deterministic=deterministic)

    return nn.LayerNorm(name='encoder_norm')(x)


class VisionTransformer(nn.Module):
  """Vision Transformer: patch embedding, optional encoder, optional head.

  Attributes:
    num_classes: number of features of the 'head' layer; when falsy no head
      is added and the pre-logits are returned.
    patches: object whose `size` is used as both kernel size and stride of
      the patch-embedding convolution.
    transformer: keyword arguments for `encoder`, or None to skip the encoder.
    hidden_size: number of features of the patch embedding.
    resnet: not read by the code of this class.
    representation_size: when not None, a Dense + tanh 'pre_logits' layer of
      this size is applied before the head.
    classifier: one of 'token', 'token_unpooled', 'gap', 'unpooled'.
    head_bias_init: constant used to initialize the head bias.
    encoder: module class instantiated under the name 'Transformer'.
    model_name: not read by the code of this class.
  """

  num_classes: int
  patches: Any
  transformer: Any
  hidden_size: int
  resnet: Optional[Any] = None
  representation_size: Optional[int] = None
  classifier: str = 'token'
  head_bias_init: float = 0.
  encoder: Type[nn.Module] = Encoder
  model_name: Optional[str] = None

  @nn.compact
  def __call__(self, inputs, *, train):
    """Applies the model.

    Args:
      inputs: rank-4 array (the shape is unpacked into four values).
      train: forwarded to the encoder as `train`.

    Returns:
      The head output when `num_classes` is truthy, otherwise the
      'pre_logits' output.

    Raises:
      ValueError: if `classifier` is not one of the supported values.
    """
    # The values are not needed here, but the unpacking enforces rank 4.
    batch, height, width, channels = inputs.shape

    # We can merge s2d+emb into a single conv; it's the same.
    patch_size = self.patches.size
    x = nn.Conv(
        features=self.hidden_size,
        kernel_size=patch_size,
        strides=patch_size,
        padding='VALID',
        name='embedding')(inputs)

    # Here, x is a grid of embeddings.

    # (Possibly partial) Transformer.
    if self.transformer is not None:
      batch, grid_h, grid_w, features = x.shape
      # Flatten the grid into a sequence of grid_h * grid_w tokens.
      x = jnp.reshape(x, [batch, grid_h * grid_w, features])

      # If we want to add a class token, add it here.
      if self.classifier in ('token', 'token_unpooled'):
        cls_token = self.param('cls', nn.initializers.zeros, (1, 1, features))
        cls_tokens = jnp.tile(cls_token, [batch, 1, 1])
        x = jnp.concatenate([cls_tokens, x], axis=1)

      transformer = self.encoder(name='Transformer', **self.transformer)
      x = transformer(x, train=train)

    pooling = self.classifier
    if pooling == 'token':
      x = x[:, 0]
    elif pooling == 'gap':
      x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)
    elif pooling not in ('unpooled', 'token_unpooled'):
      raise ValueError(f'Invalid classifier={self.classifier}')

    if self.representation_size is None:
      x = IdentityLayer(name='pre_logits')(x)
    else:
      x = nn.Dense(features=self.representation_size, name='pre_logits')(x)
      x = nn.tanh(x)

    if not self.num_classes:
      return x
    head = nn.Dense(
        features=self.num_classes,
        name='head',
        kernel_init=nn.initializers.zeros,
        bias_init=nn.initializers.constant(self.head_bias_init))
    return head(x)

def get_mha_inputs(model, params, images, train=False):
    """Return the values recorded as 'mha_input' for every encoder block.

    Runs `model.apply` with the 'intermediates' collection mutable and reads
    the entries stored under 'Transformer/encoderblock_<i>/mha_input'.

    Args:
        model: object exposing `apply(...)` and a `transformer` config that
            provides `num_layers` either as a mapping key (dict-like configs)
            or as an attribute.
        params: parameters, passed to `apply` as `{"params": params}`.
        images: input passed to `apply`.
        train: forwarded to `apply` as the `train` keyword.

    Returns:
        A list with one entry per layer, ordered by layer index. Each entry is
        the first element stored under that layer's 'mha_input'.

    Raises:
        KeyError: if 'mha_input' is missing for any layer.
    """
    from collections.abc import Mapping

    _, variables = model.apply(
        {"params": params},
        images,
        train=train,
        mutable=['intermediates'],
    )
    intermediate_vars = variables.get('intermediates', {})

    # The config is splatted with `**` when the encoder is built, so it may be
    # a plain mapping without attribute access; support both forms.
    transformer_cfg = model.transformer
    if isinstance(transformer_cfg, Mapping):
        num_layers = transformer_cfg['num_layers']
    else:
        num_layers = transformer_cfg.num_layers

    activations = []
    for i in range(num_layers):
        layer_key = f'encoderblock_{i}'
        try:
            recorded = intermediate_vars['Transformer'][layer_key]['mha_input']
        except KeyError:
            raise KeyError(f"Could not find 'mha_input' for layer {i} ('Transformer/{layer_key}'). "
                           f"Available intermediates: {list(intermediate_vars.keys())}") from None
        # Only the first recorded value is used.
        activations.append(recorded[0])
    # The loop either appends or raises, so exactly `num_layers` entries exist.
    return activations

from flax import traverse_util

def sanity_check_params(source_params, target_params, print_matched=False):
    """
    Sanity check for parameter copying. Flattens both param trees and compares keys and shapes.

    Keys present in only one tree are always printed. Shape mismatches between
    keys present in both trees are always reported as well: inline in the
    "Matching layers" listing when `print_matched` is True, otherwise in a
    separate section that is printed only when at least one mismatch exists.

    Args:
        source_params: The original parameters (e.g., from pre-trained model).
        target_params: The target parameters after copying (e.g., new_params).
        print_matched: When True, list every key present in both trees with
            its shape(s).
    """
    # Flatten the parameter trees using '/' as separator for readable paths.
    source_flat = traverse_util.flatten_dict(source_params, sep='/')
    target_flat = traverse_util.flatten_dict(target_params, sep='/')

    source_keys = set(source_flat)
    target_keys = set(target_flat)

    matching_keys = sorted(source_keys & target_keys)
    source_only = sorted(source_keys - target_keys)
    target_only = sorted(target_keys - source_keys)

    if print_matched:
        print("Matching layers:")
        for key in matching_keys:
            source_shape = source_flat[key].shape
            target_shape = target_flat[key].shape
            if source_shape == target_shape:
                print(f"  {key}: {source_shape}")
            else:
                print(f"  {key}: source {source_shape} != target {target_shape}")
    else:
        # A same-named leaf with a different shape is the failure this check
        # exists to catch, so report it even when the full listing is off.
        mismatched = [
            key for key in matching_keys
            if source_flat[key].shape != target_flat[key].shape
        ]
        if mismatched:
            print("Shape mismatches in matching layers:")
            for key in mismatched:
                print(f"  {key}: source {source_flat[key].shape} "
                      f"!= target {target_flat[key].shape}")

    print("\nUnmatched layers in source:")
    for key in source_only:
        print(f"  {key}: {source_flat[key].shape}")

    print("\nUnmatched layers in target:")
    for key in target_only:
        print(f"  {key}: {target_flat[key].shape}")


import os
import heapq
from absl import logging
import flax
import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt
import numpy as np
import optax
import tqdm
import ml_collections
import argparse
import jax.tree_util as tree_util
import flax.jax_utils as jax_utils
from flax import traverse_util

from vit_jax import checkpoint, input_pipeline, utils, models, train
from vit_jax.configs import common as common_config
from vit_jax.configs import models as models_config

def rope_dot_product_attention(query, key, value,
                              bias=None, dropout_rng=None, dropout_rate=0.0,
                              deterministic=False, dtype=jnp.float32, precision=None,
                              mask=None, broadcast_dropout=True):
    """
    Computes dot-product attention after applying Rotary Position Embeddings (RoPE)
    to the query and key.

    Intended as the `attention_fn` of `nn.MultiHeadDotProductAttention`, which
    hands over tensors laid out as `[batch..., length, num_heads, head_dim]`
    (the layout documented for `nn.dot_product_attention`). The position index
    is therefore taken from axis -3. Rotating along the heads axis instead
    would give every token of a head the same angle, which cancels out in the
    query-key product and leaves attention without positional information.

    Args:
        query: `[batch..., q_length, num_heads, head_dim]`; `head_dim` must be even.
        key: `[batch..., kv_length, num_heads, head_dim]`.
        value: passed through unchanged to `nn.dot_product_attention`.
        bias, dropout_rng, dropout_rate, deterministic, dtype, precision:
            forwarded unchanged to `nn.dot_product_attention`.
        mask: optional attention mask; forwarded only when not None.
        broadcast_dropout: forwarded to `nn.dot_product_attention`, so the
            value configured on the attention module is honoured.

    Returns:
        The result of `nn.dot_product_attention` on the rotated query and key.
    """
    head_dim = query.shape[-1]
    assert head_dim % 2 == 0, "head_dim must be even for RoPE"

    # theta_j = 10000**(-2j/d) for j = 0..d/2-1: one frequency per feature pair.
    freqs = 10000.0 ** (-jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim)

    def _apply_rope(x):
        """Rotates each feature pair (2j, 2j+1) of `x` by position * theta_j."""
        seq_len = x.shape[-3]
        positions = jnp.arange(seq_len, dtype=jnp.float32)
        # (seq_len, head_dim/2) -> (seq_len, head_dim): both members of a pair
        # share the same angle.
        angles = jnp.repeat(jnp.einsum('i,j->ij', positions, freqs), 2, axis=-1)
        # (seq_len, 1, head_dim): broadcasts over batch dims and over heads.
        cos = jnp.cos(angles)[:, None, :]
        sin = jnp.sin(angles)[:, None, :]
        x_even = x[..., ::2]
        x_odd = x[..., 1::2]
        # Interleave to [-x1, x0, -x3, x2, ...] so that
        # out[2j]   = x[2j] * cos - x[2j+1] * sin
        # out[2j+1] = x[2j+1] * cos + x[2j] * sin
        x_rotated = jnp.stack([-x_odd, x_even], axis=-1).reshape(x.shape)
        return x * cos + x_rotated * sin

    # Query and key are rotated separately so differing lengths are handled.
    query_rope = _apply_rope(query)
    key_rope = _apply_rope(key)

    # `mask` is forwarded only when given, so Flax versions whose
    # dot_product_attention has no `mask` parameter keep working.
    extra_kwargs = {}
    if mask is not None:
        extra_kwargs['mask'] = mask

    # Call the original dot_product_attention with the rotated q,k
    return nn.dot_product_attention(
        query_rope, key_rope, value, bias=bias,
        broadcast_dropout=broadcast_dropout, dropout_rng=dropout_rng,
        dropout_rate=dropout_rate, deterministic=deterministic, dtype=dtype,
        precision=precision, **extra_kwargs
    )

# Main function for fine-tuning
def main():
    # Parse command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--rope-use", action='store_true', help="Use rope if this flag is present")
    parser.add_argument("--dataset", type=str, required=True, help="Dataset used")
    parser.add_argument("--finetune-layer-which", type=str, required=True, help="Which attention layer(s) to finetune")
    parser.add_argument("--seed", type=int, required=True, help="Random seed for fine-tuning")
    parser.add_argument("--model-path", type=str, required=True, help="Path to pre-trained ViT model checkpoint")
    parser.add_argument("--learning-rate", type=float, required=True, help="Learning rate")
    parser.add_argument("--ckpt-path", type=str, default="/", help="Path to ckpt directory")
    parser.add_argument("--num-epochs", type=int, default=10)
    args = parser.parse_args()

    # Configuration
    ckpt_path = args.ckpt_path
    os.makedirs(ckpt_path, exist_ok=True)
    model_path = args.model_path
    assert os.path.exists(model_path), f"Model path {model_path} does not exist."
    assert args.dataset in ['cifar10', 'cifar100']
    dataset = args.dataset
    batch_size = 512
    num_epochs = args.num_epochs
    patience = 5

    # Load dataset
    config = common_config.with_dataset(common_config.get_config(), dataset)
    config.batch = batch_size
    config.pp.crop = 224
    ds_train = input_pipeline.get_data_from_tfds(config=config, mode='train')
    ds_test = input_pipeline.get_data_from_tfds(config=config, mode='test')
    num_classes = input_pipeline.get_dataset_info(dataset, 'train')['num_classes']
    num_train_examples = input_pipeline.get_dataset_info(dataset, 'train')['num_examples']
    steps_per_epoch = num_train_examples // batch_size
    total_steps = num_epochs * steps_per_epoch

    # Model configurations
    config = ml_collections.ConfigDict()
    config.model_name = 'ViT-S_16'
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 384
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 1536
    config.transformer.num_heads = 6
    config.transformer.num_layers = 12
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.0
    config.classifier = 'token'
    config.representation_size = None

    config_finetune = ml_collections.ConfigDict(config.to_dict())
    #config_finetune.transformer.rope_use = args.rope_use
    finetune_layer_which = list(range(config.transformer.num_layers)) if args.finetune_layer_which=="all" else [int(idx) for idx in args.finetune_layer_which.split(",")]

    if args.rope_use:
        config_finetune.transformer.attention_fn = rope_dot_product_attention
        config_finetune.transformer.add_position_embedding = False
    else:
        config_finetune.transformer.attention_fn = nn.dot_product_attention
        config_finetune.transformer.add_position_embedding = True

    model_config = config
    model_config_finetune = config_finetune

    # Load pre-trained model
    loaded_params = checkpoint.load(model_path)

    # Initialize fine-tuning model
    finetune_model = VisionTransformer(num_classes=num_classes, **model_config_finetune)
    new_params = finetune_model.init(jax.random.PRNGKey(args.seed), jnp.ones((1, 224, 224, 3)), train=False)['params']
    sanity_check_params(new_params, loaded_params, print_matched=False)

    '''#print all layers
    layer_paths = []
    def collect_layer_paths(path, value):
        # Convert path to a readable string by joining path keys
        path_str = '/'.join([str(p.key) for p in path])
        shape = value.shape
        layer_paths.append((path_str, shape))

    jax.tree_util.tree_map_with_path(collect_layer_paths, new_params)
    print("Layers of the model:")
    for path, shape in layer_paths:  
        print(f"  {path} {shape}")
    Layers of the model:
      Transformer/encoder_norm/bias (384,)
      Transformer/encoder_norm/scale (384,)
      Transformer/encoderblock_0/LayerNorm_0/bias (384,)
      Transformer/encoderblock_0/LayerNorm_0/scale (384,)
      Transformer/encoderblock_0/LayerNorm_1/bias (384,)
      Transformer/encoderblock_0/LayerNorm_1/scale (384,)
      Transformer/encoderblock_0/MlpBlock_0/Dense_0/bias (1536,)
      Transformer/encoderblock_0/MlpBlock_0/Dense_0/kernel (384, 1536)
      Transformer/encoderblock_0/MlpBlock_0/Dense_1/bias (384,)
      Transformer/encoderblock_0/MlpBlock_0/Dense_1/kernel (1536, 384)
      Transformer/encoderblock_0/MultiHeadDotProductAttention_0/key/bias (6, 64)
      Transformer/encoderblock_0/MultiHeadDotProductAttention_0/key/kernel (384, 6, 64)
      Transformer/encoderblock_0/MultiHeadDotProductAttention_0/out/bias (384,)
      Transformer/encoderblock_0/MultiHeadDotProductAttention_0/out/kernel (6, 64, 384)
      Transformer/encoderblock_0/MultiHeadDotProductAttention_0/query/bias (6, 64)
      Transformer/encoderblock_0/MultiHeadDotProductAttention_0/query/kernel (384, 6, 64)
      Transformer/encoderblock_0/MultiHeadDotProductAttention_0/value/bias (6, 64)
      Transformer/encoderblock_0/MultiHeadDotProductAttention_0/value/kernel (384, 6, 64)
      Transformer/encoderblock_1/LayerNorm_0/bias (384,)
      Transformer/encoderblock_1/LayerNorm_0/scale (384,)
      Transformer/encoderblock_1/LayerNorm_1/bias (384,)
      Transformer/encoderblock_1/LayerNorm_1/scale (384,)
      Transformer/encoderblock_1/MlpBlock_0/Dense_0/bias (1536,)
      Transformer/encoderblock_1/MlpBlock_0/Dense_0/kernel (384, 1536)
      Transformer/encoderblock_1/MlpBlock_0/Dense_1/bias (384,)
      Transformer/encoderblock_1/MlpBlock_0/Dense_1/kernel (1536, 384)
      Transformer/encoderblock_1/MultiHeadDotProductAttention_0/key/bias (6, 64)
      Transformer/encoderblock_1/MultiHeadDotProductAttention_0/key/kernel (384, 6, 64)
      Transformer/encoderblock_1/MultiHeadDotProductAttention_0/out/bias (384,)
      Transformer/encoderblock_1/MultiHeadDotProductAttention_0/out/kernel (6, 64, 384)
      Transformer/encoderblock_1/MultiHeadDotProductAttention_0/query/bias (6, 64)
      Transformer/encoderblock_1/MultiHeadDotProductAttention_0/query/kernel (384, 6, 64)
      Transformer/encoderblock_1/MultiHeadDotProductAttention_0/value/bias (6, 64)
      Transformer/encoderblock_1/MultiHeadDotProductAttention_0/value/kernel (384, 6, 64)
      Transformer/encoderblock_10/LayerNorm_0/bias (384,)
      Transformer/encoderblock_10/LayerNorm_0/scale (384,)
      Transformer/encoderblock_10/LayerNorm_1/bias (384,)
      Transformer/encoderblock_10/LayerNorm_1/scale (384,)
      Transformer/encoderblock_10/MlpBlock_0/Dense_0/bias (1536,)
      Transformer/encoderblock_10/MlpBlock_0/Dense_0/kernel (384, 1536)
      Transformer/encoderblock_10/MlpBlock_0/Dense_1/bias (384,)
      Transformer/encoderblock_10/MlpBlock_0/Dense_1/kernel (1536, 384)
      Transformer/encoderblock_10/MultiHeadDotProductAttention_0/key/bias (6, 64)
      Transformer/encoderblock_10/MultiHeadDotProductAttention_0/key/kernel (384, 6, 64)
      Transformer/encoderblock_10/MultiHeadDotProductAttention_0/out/bias (384,)
      Transformer/encoderblock_10/MultiHeadDotProductAttention_0/out/kernel (6, 64, 384)
      Transformer/encoderblock_10/MultiHeadDotProductAttention_0/query/bias (6, 64)
      Transformer/encoderblock_10/MultiHeadDotProductAttention_0/query/kernel (384, 6, 64)
      Transformer/encoderblock_10/MultiHeadDotProductAttention_0/value/bias (6, 64)
      Transformer/encoderblock_10/MultiHeadDotProductAttention_0/value/kernel (384, 6, 64)
      Transformer/encoderblock_11/LayerNorm_0/bias (384,)
      Transformer/encoderblock_11/LayerNorm_0/scale (384,)
      Transformer/encoderblock_11/LayerNorm_1/bias (384,)
      Transformer/encoderblock_11/LayerNorm_1/scale (384,)
      Transformer/encoderblock_11/MlpBlock_0/Dense_0/bias (1536,)
      Transformer/encoderblock_11/MlpBlock_0/Dense_0/kernel (384, 1536)
      Transformer/encoderblock_11/MlpBlock_0/Dense_1/bias (384,)
      Transformer/encoderblock_11/MlpBlock_0/Dense_1/kernel (1536, 384)
      Transformer/encoderblock_11/MultiHeadDotProductAttention_0/key/bias (6, 64)
      Transformer/encoderblock_11/MultiHeadDotProductAttention_0/key/kernel (384, 6, 64)
      Transformer/encoderblock_11/MultiHeadDotProductAttention_0/out/bias (384,)
      Transformer/encoderblock_11/MultiHeadDotProductAttention_0/out/kernel (6, 64, 384)
      Transformer/encoderblock_11/MultiHeadDotProductAttention_0/query/bias (6, 64)
      Transformer/encoderblock_11/MultiHeadDotProductAttention_0/query/kernel (384, 6, 64)
      Transformer/encoderblock_11/MultiHeadDotProductAttention_0/value/bias (6, 64)
      Transformer/encoderblock_11/MultiHeadDotProductAttention_0/value/kernel (384, 6, 64)
      Transformer/encoderblock_2/LayerNorm_0/bias (384,)
      Transformer/encoderblock_2/LayerNorm_0/scale (384,)
      Transformer/encoderblock_2/LayerNorm_1/bias (384,)
      Transformer/encoderblock_2/LayerNorm_1/scale (384,)
      Transformer/encoderblock_2/MlpBlock_0/Dense_0/bias (1536,)
      Transformer/encoderblock_2/MlpBlock_0/Dense_0/kernel (384, 1536)
      Transformer/encoderblock_2/MlpBlock_0/Dense_1/bias (384,)
      Transformer/encoderblock_2/MlpBlock_0/Dense_1/kernel (1536, 384)
      Transformer/encoderblock_2/MultiHeadDotProductAttention_0/key/bias (6, 64)
      Transformer/encoderblock_2/MultiHeadDotProductAttention_0/key/kernel (384, 6, 64)
      Transformer/encoderblock_2/MultiHeadDotProductAttention_0/out/bias (384,)
      Transformer/encoderblock_2/MultiHeadDotProductAttention_0/out/kernel (6, 64, 384)
      Transformer/encoderblock_2/MultiHeadDotProductAttention_0/query/bias (6, 64)
      Transformer/encoderblock_2/MultiHeadDotProductAttention_0/query/kernel (384, 6, 64)
      Transformer/encoderblock_2/MultiHeadDotProductAttention_0/value/bias (6, 64)
      Transformer/encoderblock_2/MultiHeadDotProductAttention_0/value/kernel (384, 6, 64)
      Transformer/encoderblock_3/LayerNorm_0/bias (384,)
      Transformer/encoderblock_3/LayerNorm_0/scale (384,)
      Transformer/encoderblock_3/LayerNorm_1/bias (384,)
      Transformer/encoderblock_3/LayerNorm_1/scale (384,)
      Transformer/encoderblock_3/MlpBlock_0/Dense_0/bias (1536,)
      Transformer/encoderblock_3/MlpBlock_0/Dense_0/kernel (384, 1536)
      Transformer/encoderblock_3/MlpBlock_0/Dense_1/bias (384,)
      Transformer/encoderblock_3/MlpBlock_0/Dense_1/kernel (1536, 384)
      Transformer/encoderblock_3/MultiHeadDotProductAttention_0/key/bias (6, 64)
      Transformer/encoderblock_3/MultiHeadDotProductAttention_0/key/kernel (384, 6, 64)
      Transformer/encoderblock_3/MultiHeadDotProductAttention_0/out/bias (384,)
      Transformer/encoderblock_3/MultiHeadDotProductAttention_0/out/kernel (6, 64, 384)
      Transformer/encoderblock_3/MultiHeadDotProductAttention_0/query/bias (6, 64)
      Transformer/encoderblock_3/MultiHeadDotProductAttention_0/query/kernel (384, 6, 64)
      Transformer/encoderblock_3/MultiHeadDotProductAttention_0/value/bias (6, 64)
      Transformer/encoderblock_3/MultiHeadDotProductAttention_0/value/kernel (384, 6, 64)
      Transformer/encoderblock_4/LayerNorm_0/bias (384,)
      Transformer/encoderblock_4/LayerNorm_0/scale (384,)
      Transformer/encoderblock_4/LayerNorm_1/bias (384,)
      Transformer/encoderblock_4/LayerNorm_1/scale (384,)
      Transformer/encoderblock_4/MlpBlock_0/Dense_0/bias (1536,)
      Transformer/encoderblock_4/MlpBlock_0/Dense_0/kernel (384, 1536)
      Transformer/encoderblock_4/MlpBlock_0/Dense_1/bias (384,)
      Transformer/encoderblock_4/MlpBlock_0/Dense_1/kernel (1536, 384)
      Transformer/encoderblock_4/MultiHeadDotProductAttention_0/key/bias (6, 64)
      Transformer/encoderblock_4/MultiHeadDotProductAttention_0/key/kernel (384, 6, 64)
      Transformer/encoderblock_4/MultiHeadDotProductAttention_0/out/bias (384,)
      Transformer/encoderblock_4/MultiHeadDotProductAttention_0/out/kernel (6, 64, 384)
      Transformer/encoderblock_4/MultiHeadDotProductAttention_0/query/bias (6, 64)
      Transformer/encoderblock_4/MultiHeadDotProductAttention_0/query/kernel (384, 6, 64)
      Transformer/encoderblock_4/MultiHeadDotProductAttention_0/value/bias (6, 64)
      Transformer/encoderblock_4/MultiHeadDotProductAttention_0/value/kernel (384, 6, 64)
      Transformer/encoderblock_5/LayerNorm_0/bias (384,)
      Transformer/encoderblock_5/LayerNorm_0/scale (384,)
      Transformer/encoderblock_5/LayerNorm_1/bias (384,)
      Transformer/encoderblock_5/LayerNorm_1/scale (384,)
      Transformer/encoderblock_5/MlpBlock_0/Dense_0/bias (1536,)
      Transformer/encoderblock_5/MlpBlock_0/Dense_0/kernel (384, 1536)
      Transformer/encoderblock_5/MlpBlock_0/Dense_1/bias (384,)
      Transformer/encoderblock_5/MlpBlock_0/Dense_1/kernel (1536, 384)
      Transformer/encoderblock_5/MultiHeadDotProductAttention_0/key/bias (6, 64)
      Transformer/encoderblock_5/MultiHeadDotProductAttention_0/key/kernel (384, 6, 64)
      Transformer/encoderblock_5/MultiHeadDotProductAttention_0/out/bias (384,)
      Transformer/encoderblock_5/MultiHeadDotProductAttention_0/out/kernel (6, 64, 384)
      Transformer/encoderblock_5/MultiHeadDotProductAttention_0/query/bias (6, 64)
      Transformer/encoderblock_5/MultiHeadDotProductAttention_0/query/kernel (384, 6, 64)
      Transformer/encoderblock_5/MultiHeadDotProductAttention_0/value/bias (6, 64)
      Transformer/encoderblock_5/MultiHeadDotProductAttention_0/value/kernel (384, 6, 64)
      Transformer/encoderblock_6/LayerNorm_0/bias (384,)
      Transformer/encoderblock_6/LayerNorm_0/scale (384,)
      Transformer/encoderblock_6/LayerNorm_1/bias (384,)
      Transformer/encoderblock_6/LayerNorm_1/scale (384,)
      Transformer/encoderblock_6/MlpBlock_0/Dense_0/bias (1536,)
      Transformer/encoderblock_6/MlpBlock_0/Dense_0/kernel (384, 1536)
      Transformer/encoderblock_6/MlpBlock_0/Dense_1/bias (384,)
      Transformer/encoderblock_6/MlpBlock_0/Dense_1/kernel (1536, 384)
      Transformer/encoderblock_6/MultiHeadDotProductAttention_0/key/bias (6, 64)
      Transformer/encoderblock_6/MultiHeadDotProductAttention_0/key/kernel (384, 6, 64)
      Transformer/encoderblock_6/MultiHeadDotProductAttention_0/out/bias (384,)
      Transformer/encoderblock_6/MultiHeadDotProductAttention_0/out/kernel (6, 64, 384)
      Transformer/encoderblock_6/MultiHeadDotProductAttention_0/query/bias (6, 64)
      Transformer/encoderblock_6/MultiHeadDotProductAttention_0/query/kernel (384, 6, 64)
      Transformer/encoderblock_6/MultiHeadDotProductAttention_0/value/bias (6, 64)
      Transformer/encoderblock_6/MultiHeadDotProductAttention_0/value/kernel (384, 6, 64)
      Transformer/encoderblock_7/LayerNorm_0/bias (384,)
      Transformer/encoderblock_7/LayerNorm_0/scale (384,)
      Transformer/encoderblock_7/LayerNorm_1/bias (384,)
      Transformer/encoderblock_7/LayerNorm_1/scale (384,)
      Transformer/encoderblock_7/MlpBlock_0/Dense_0/bias (1536,)
      Transformer/encoderblock_7/MlpBlock_0/Dense_0/kernel (384, 1536)
      Transformer/encoderblock_7/MlpBlock_0/Dense_1/bias (384,)
      Transformer/encoderblock_7/MlpBlock_0/Dense_1/kernel (1536, 384)
      Transformer/encoderblock_7/MultiHeadDotProductAttention_0/key/bias (6, 64)
      Transformer/encoderblock_7/MultiHeadDotProductAttention_0/key/kernel (384, 6, 64)
      Transformer/encoderblock_7/MultiHeadDotProductAttention_0/out/bias (384,)
      Transformer/encoderblock_7/MultiHeadDotProductAttention_0/out/kernel (6, 64, 384)
      Transformer/encoderblock_7/MultiHeadDotProductAttention_0/query/bias (6, 64)
      Transformer/encoderblock_7/MultiHeadDotProductAttention_0/query/kernel (384, 6, 64)
      Transformer/encoderblock_7/MultiHeadDotProductAttention_0/value/bias (6, 64)
      Transformer/encoderblock_7/MultiHeadDotProductAttention_0/value/kernel (384, 6, 64)
      Transformer/encoderblock_8/LayerNorm_0/bias (384,)
      Transformer/encoderblock_8/LayerNorm_0/scale (384,)
      Transformer/encoderblock_8/LayerNorm_1/bias (384,)
      Transformer/encoderblock_8/LayerNorm_1/scale (384,)
      Transformer/encoderblock_8/MlpBlock_0/Dense_0/bias (1536,)
      Transformer/encoderblock_8/MlpBlock_0/Dense_0/kernel (384, 1536)
      Transformer/encoderblock_8/MlpBlock_0/Dense_1/bias (384,)
      Transformer/encoderblock_8/MlpBlock_0/Dense_1/kernel (1536, 384)
      Transformer/encoderblock_8/MultiHeadDotProductAttention_0/key/bias (6, 64)
      Transformer/encoderblock_8/MultiHeadDotProductAttention_0/key/kernel (384, 6, 64)
      Transformer/encoderblock_8/MultiHeadDotProductAttention_0/out/bias (384,)
      Transformer/encoderblock_8/MultiHeadDotProductAttention_0/out/kernel (6, 64, 384)
      Transformer/encoderblock_8/MultiHeadDotProductAttention_0/query/bias (6, 64)
      Transformer/encoderblock_8/MultiHeadDotProductAttention_0/query/kernel (384, 6, 64)
      Transformer/encoderblock_8/MultiHeadDotProductAttention_0/value/bias (6, 64)
      Transformer/encoderblock_8/MultiHeadDotProductAttention_0/value/kernel (384, 6, 64)
      Transformer/encoderblock_9/LayerNorm_0/bias (384,)
      Transformer/encoderblock_9/LayerNorm_0/scale (384,)
      Transformer/encoderblock_9/LayerNorm_1/bias (384,)
      Transformer/encoderblock_9/LayerNorm_1/scale (384,)
      Transformer/encoderblock_9/MlpBlock_0/Dense_0/bias (1536,)
      Transformer/encoderblock_9/MlpBlock_0/Dense_0/kernel (384, 1536)
      Transformer/encoderblock_9/MlpBlock_0/Dense_1/bias (384,)
      Transformer/encoderblock_9/MlpBlock_0/Dense_1/kernel (1536, 384)
      Transformer/encoderblock_9/MultiHeadDotProductAttention_0/key/bias (6, 64)
      Transformer/encoderblock_9/MultiHeadDotProductAttention_0/key/kernel (384, 6, 64)
      Transformer/encoderblock_9/MultiHeadDotProductAttention_0/out/bias (384,)
      Transformer/encoderblock_9/MultiHeadDotProductAttention_0/out/kernel (6, 64, 384)
      Transformer/encoderblock_9/MultiHeadDotProductAttention_0/query/bias (6, 64)
      Transformer/encoderblock_9/MultiHeadDotProductAttention_0/query/kernel (384, 6, 64)
      Transformer/encoderblock_9/MultiHeadDotProductAttention_0/value/bias (6, 64)
      Transformer/encoderblock_9/MultiHeadDotProductAttention_0/value/kernel (384, 6, 64)
      Transformer/posembed_input/pos_embedding (1, 197, 384)
      cls (1, 1, 384)
      embedding/bias (384,)
      embedding/kernel (16, 16, 3, 384)
      head/bias (10,)
      head/kernel (384, 10)
    '''
    # Copy pre-trained parameters to fine-tuning model
    def copy_matching_params(target, source):
        """Overwrite entries of `target` in place with same-named entries of `source`.

        Keys missing from `target` are ignored. When both values are dicts the
        copy recurses; otherwise the source value replaces the target value.
        """
        for name, src_value in source.items():
            if name not in target:
                continue
            dst_value = target[name]
            if isinstance(src_value, dict) and isinstance(dst_value, dict):
                copy_matching_params(dst_value, src_value)
            else:
                target[name] = src_value
    
    copy_matching_params(new_params, loaded_params)

    # Define labels for trainable (MoE) and frozen parameters
    def get_labels(params):
        """Build a tree of 'trainable' / 'frozen' labels mirroring `params`.

        A leaf is labelled 'trainable' only when all of the following hold:
          * its path has exactly five components,
          * the third component is 'MultiHeadDotProductAttention_0',
          * the fourth is one of 'key', 'query', 'value', 'out',
          * the fifth is 'bias' or 'kernel',
          * the second is 'encoderblock_<idx>' with `idx` contained in
            `finetune_layer_which` (read from the enclosing scope).
        Every other leaf is labelled 'frozen'.

        Args:
          params: Nested parameter tree to label. The labels are computed
            from this argument (previously it was ignored in favour of the
            enclosing `loaded_params`).

        Returns:
          A tree with the same structure as `params` whose leaves are the
          strings 'trainable' or 'frozen'.
        """
        attention_module = 'MultiHeadDotProductAttention_0'
        projection_names = ('key', 'query', 'value', 'out')
        leaf_names = ('bias', 'kernel')

        def label_fn(path, _):
            # Only the path decides the label; the leaf value is unused.
            is_attention_leaf = (
                len(path) == 5
                and path[2].key == attention_module
                and path[3].key in projection_names
                and path[4].key in leaf_names
            )
            if not is_attention_leaf:
                return 'frozen'
            layer_key = path[1].key
            if not layer_key.startswith('encoderblock_'):
                return 'frozen'
            # The block index is the text after the last underscore.
            idx = int(layer_key.split('_')[-1])
            return 'trainable' if idx in finetune_layer_which else 'frozen'

        return jax.tree_util.tree_map_with_path(label_fn, params)

    labels = get_labels(loaded_params)

    # List to store paths of trainable parameters
    trainable_paths = []

    # Function to collect paths where label is 'trainable'
    def collect_trainable_paths(path, label):
        """Record the slash-joined path of a leaf labelled 'trainable'.

        Appends to the enclosing `trainable_paths` list; leaves carrying any
        other label are ignored. Always returns None.

        Args:
          path: Sequence of key entries (each exposing `.key`) locating the leaf.
          label: Label string stored at that leaf.
        """
        if label != 'trainable':
            return
        components = (str(entry.key) for entry in path)
        trainable_paths.append('/'.join(components))

    # Traverse the label tree and collect trainable paths
    tree_util.tree_map_with_path(collect_trainable_paths, labels)

    # Print the trainable parts
    print("Trainable parts of the model:")
    for path in trainable_paths:
        print(f"  {path}")

    # Optimizer setup
    warmup_steps = 5
    decay_type = 'cosine'
    grad_norm_clip = 1
    accum_steps = 8
    base_lr = args.learning_rate
    lr_fn = utils.create_learning_rate_schedule(total_steps, base_lr, decay_type, warmup_steps)

    base_tx = optax.chain(
        optax.clip_by_global_norm(grad_norm_clip),
        optax.sgd(learning_rate=lr_fn, momentum=0.9)
    )
    tx = optax.multi_transform(
        {'trainable': base_tx, 'frozen': optax.set_to_zero()},
        labels
    )

    # Replicate parameters and optimizer state
    params_repl = jax_utils.replicate(loaded_params)
    opt_state = tx.init(loaded_params)
    opt_state_repl = jax_utils.replicate(opt_state)
    update_rng_repl = jax_utils.replicate(jax.random.PRNGKey(0))

    # Update functions
    vit_apply_repl = jax.pmap(lambda params, inputs: finetune_model.apply(
        dict(params=params), inputs, train=False))
    update_fn_repl = train.make_update_fn(apply_fn=finetune_model.apply, accum_steps=accum_steps, tx=tx)

    # Metrics storage
    train_losses = []
    train_accuracies = []
    test_losses = []
    test_accuracies = []
    learning_rates = []

    # Early stopping
    best_param = None
    best_metric = None
    best_test_loss = float('inf')
    best_epoch = 0
    epochs_since_improvement = 0

    # Test metrics computation
    @jax.pmap
    def compute_batch_metrics(params, batch):
        """Compute loss and accuracy of `finetune_model` on one batch shard.

        Wrapped in `jax.pmap`, so it is mapped over the leading axis of both
        arguments.

        Args:
          params: Model parameters passed to `finetune_model.apply`.
          batch: Mapping with 'image' and 'label' entries. The label is
            multiplied element-wise with the log-softmax and arg-maxed over
            its last axis (presumably one-hot encoded; confirm against the
            data pipeline).

        Returns:
          A `(loss, accuracy)` tuple: the negated sum of label-weighted
          log-probabilities divided by the label's leading dimension, and the
          mean agreement between predicted and labelled arg-max classes.
        """
        images = batch['image']
        labels = batch['label']
        logits = finetune_model.apply({'params': params}, images, train=False)

        log_probs = jax.nn.log_softmax(logits)
        summed_xent = -jnp.sum(log_probs * labels)
        loss = summed_xent / labels.shape[0]

        predicted_class = logits.argmax(axis=-1)
        labelled_class = labels.argmax(axis=-1)
        accuracy = (predicted_class == labelled_class).mean()
        return loss, accuracy

    def evaluate_test(params_repl):
        """Average loss and accuracy over one full pass of `ds_test`.

        Each batch is scored with `compute_batch_metrics`; the per-batch
        results are reduced with `jnp.mean` and then averaged with equal
        weight per batch.

        Args:
          params_repl: Parameters forwarded unchanged to `compute_batch_metrics`.

        Returns:
          A `(mean_loss, mean_accuracy)` tuple averaged over all test batches.
        """
        loss_total = 0.0
        accuracy_total = 0.0
        batch_count = 0
        for batch_count, test_batch in enumerate(ds_test.as_numpy_iterator(), start=1):
            batch_loss, batch_accuracy = compute_batch_metrics(params_repl, test_batch)
            loss_total += jnp.mean(batch_loss)
            accuracy_total += jnp.mean(batch_accuracy)
        mean_loss = loss_total / batch_count
        mean_accuracy = accuracy_total / batch_count
        return mean_loss, mean_accuracy

    # Initial evaluation
    test_loss, test_accuracy = evaluate_test(params_repl)
    print(f'Start: testloss_{test_loss:.4f}_testacc_{test_accuracy:.4f}')

    # Training loop
    train_iter = iter(ds_train.as_numpy_iterator())
    global_step = 0
    for epoch in range(num_epochs):
        train_loss_sum = 0.0
        train_correct_sum = 0.0
        for _ in range(steps_per_epoch):
            batch = next(train_iter)
            params_repl, opt_state_repl, loss_repl, update_rng_repl = update_fn_repl(
                params_repl, opt_state_repl, batch, update_rng_repl
            )
            loss = jnp.mean(loss_repl)
            train_loss_sum += loss
            predicted = vit_apply_repl(params_repl, batch['image'])
            is_correct = (predicted.argmax(axis=-1) == batch['label'].argmax(axis=-1))
            batch_accuracy = jnp.mean(is_correct)
            train_correct_sum += batch_accuracy
            global_step += 1
            print(f"Epoch {epoch}: loss: {loss} batch_accuracy: {batch_accuracy}")

        # Compute epoch metrics
        train_loss = train_loss_sum / steps_per_epoch
        train_accuracy = train_correct_sum / steps_per_epoch
        train_losses.append(float(train_loss))
        train_accuracies.append(float(train_accuracy))

        # Evaluate on test set
        test_loss, test_accuracy = evaluate_test(params_repl)
        test_losses.append(float(test_loss))
        test_accuracies.append(float(test_accuracy))
        print(f"Epoch: {epoch} train_loss: {train_loss:.4f}, train_accuracy: {train_accuracy:.4f}, test_loss: {test_loss:.4f}, test_accuracy: {test_accuracy:.4f}")
        
        # Learning rate
        lr = lr_fn(global_step)
        learning_rates.append(float(lr))

        # Early stopping and save best model
        if test_loss < best_test_loss:
            best_test_loss = test_loss
            best_param = jax_utils.unreplicate(params_repl)
            epochs_since_improvement = 0
            best_metric = {
                "train_loss": f"{train_loss:.4f}",
                "train_accuracy": f"{train_accuracy:.4f}",
                "test_loss": f"{test_loss:.4f}",
                "test_accuracy": f"{test_accuracy:.4f}",
            }
            best_epoch = epoch
        else:
            epochs_since_improvement += 1
            if epochs_since_improvement >= patience:
                print(f"Early stopping triggered at epoch {epoch}")
                break
            
    metrics_str = f"testloss_{best_metric["test_loss"]}_testacc_{best_metric["test_accuracy"]}"
    weights_file = f"{ckpt_path}/imgnet{dataset}_vit_epoch{best_epoch}_{metrics_str}_seed_{args.seed}_lr_{args.learning_rate}_finetune_layer_which_{args.finetune_layer_which}.npz"
    if best_param is None: print("Warning: Training finished, but no best parameters were saved. Check the training logic.")
    with open(weights_file, "wb") as f:
        #f.write(flax.serialization.to_bytes(flax.jax_utils.unreplicate(params_repl)))
        np.savez(f, **traverse_util.flatten_dict(best_param, sep='/'))
        print(f"Saving checkpoint to {weights_file}")

    '''with open(f'{args.ckpt_path}/ckpt.txt', "a") as f:
        f.write(f"\n{weights_file}\n")
        for metric, value in best_metric.items():
            f.write(f"{metric}: {value}\n")'''

    # Plot and save metrics
    fig, axs = plt.subplots(1, 2, figsize=(12, 5))
    # Loss subplot
    axs[0].plot(train_losses, label='Train Loss')
    axs[0].plot(test_losses, label='Test Loss')
    axs[0].set_xlabel('Epoch')
    axs[0].set_ylabel('Loss')
    axs[0].set_title('Training and Test Loss over Epochs')
    axs[0].legend()
    # Accuracy subplot
    axs[1].plot(train_accuracies, label='Train Accuracy')
    axs[1].plot(test_accuracies, label='Test Accuracy')
    axs[1].set_xlabel('Epoch')
    axs[1].set_ylabel('Accuracy')
    axs[1].set_title('Training and Test Accuracy over Epochs')
    axs[1].legend()
    plt.tight_layout()
    plt.savefig(f"{ckpt_path}/{args.seed}metrics_plot.png")
    plt.close()

# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()