import argparse
import augmax
import flax
import jax
import jax.nn
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds 
from flax import linen as nn
from flax.training.train_state import TrainState
from jax import jit, random, value_and_grad, vmap
from tqdm import tqdm
import matplotlib.pyplot as plt


from src.datasets import load_cifar10
from src.utils import flatten_params, rngmix, timeblock

'''Reference snippet (not executed): how the per-channel CIFAR-10 statistics
passed to augmax.Normalize in make_stuff were obtained.

def compute_cifar10_stats(train_ds):
    images = train_ds["images_u8"].astype(jnp.float32) / 255.0  # Shape: (50000, 32, 32, 3)
    mean = jnp.mean(images, axis=(0, 1, 2))  # Per-channel mean
    std = jnp.std(images, axis=(0, 1, 2))    # Per-channel std
    return mean, std

train_ds, test_ds = load_cifar10()
print(compute_cifar10_stats(train_ds))

Output:
(Array([0.49139968, 0.4821584 , 0.44653094], dtype=float32), Array([0.24703221, 0.24348514, 0.26158786], dtype=float32))
'''

def make_stuff(model):
  """Build the data transforms and train/eval helpers bound to `model`.

  Args:
    model: Flax module whose `apply({"params": ...}, images)` returns class
      logits. `get_mha_inputs` additionally reads `model.num_layers` and
      expects every encoder layer to sow an `mha_input` intermediate.

  Returns:
    A dict with the keys "train_transform", "normalize_transform",
    "batch_eval", "step", "dataset_loss_and_accuracy" and "get_mha_inputs".
  """
  # Random augmentation, applied to training images only (see `step`).
  train_transform = augmax.Chain(
      # augmax does not seem to support random crops with padding. See https://github.com/khdlr/augmax/issues/6.
      augmax.RandomSizedCrop(32, 32, zoom_range=(0.8, 1.2)),
      augmax.HorizontalFlip(),
      augmax.Rotate(),
  )

  # Applied to all input images, test and train. The mean/std values are the
  # per-channel CIFAR-10 training-set statistics.
  normalize_transform = augmax.Chain(augmax.ByteToFloat(), augmax.Normalize(mean=[0.49139968, 0.4821584 , 0.44653094], std=[0.24703221, 0.24348514, 0.26158786]))

  @jit
  def batch_eval(params, images_u8, labels):
    """Return `(mean cross-entropy loss, {"num_correct": ...})` for one batch.

    Images are normalized here, so callers pass un-normalized byte images.
    Labels are one-hot encoded over 10 classes.
    """
    # The transforms in `normalize_transform` are given no RNG key (None).
    images_f32 = vmap(normalize_transform)(None, images_u8)
    y_onehot = jax.nn.one_hot(labels, 10)
    logits = model.apply({"params": params}, images_f32)
    l = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=y_onehot))
    num_correct = jnp.sum(jnp.argmax(logits, axis=-1) == labels)
    return l, {"num_correct": num_correct}

  @jit
  def step(rng, train_state, images, labels):
    """Augment one batch, take one optimizer step, and return the new state.

    Returns:
      `(train_state, info)` where `info` holds "batch_loss" and "num_correct"
      measured on the augmented batch before the update.
    """
    # One independent RNG key per image so every example is augmented differently.
    images_transformed = vmap(train_transform)(random.split(rng, images.shape[0]), images)
    (l, info), g = value_and_grad(batch_eval, has_aux=True)(train_state.params, images_transformed,
                                                            labels)
    return train_state.apply_gradients(grads=g), {"batch_loss": l, **info}

  def dataset_loss_and_accuracy(params, dataset, batch_size: int):
    """Return `(mean loss, accuracy)` over `dataset`, evaluated batch by batch.

    `dataset` is indexed with the keys "images_u8" and "labels"; the number of
    examples must be a multiple of `batch_size` (asserted).
    """
    num_examples = dataset["images_u8"].shape[0]
    assert num_examples % batch_size == 0
    # Can't use vmap or run in a single batch since that overloads GPU memory.
    # Batches are contiguous, so plain slices replace an index-matrix gather.
    losses, infos = zip(*[
        batch_eval(
            params,
            dataset["images_u8"][start:start + batch_size],
            dataset["labels"][start:start + batch_size],
        ) for start in range(0, num_examples, batch_size)
    ])
    return (
        jnp.sum(batch_size * jnp.array(losses)) / num_examples,
        sum(x["num_correct"] for x in infos) / num_examples,
    )

  def get_mha_inputs(params, dataset, rng, batch_size: int):
    """Return the sown attention inputs of every encoder layer.

    Draws `batch_size` distinct examples from `dataset` using `rng`, runs the
    model on them, and collects the `mha_input` intermediate of each layer.

    Returns:
      A list with one activation per layer, ordered by layer index.

    Raises:
      KeyError: if a layer did not record an `mha_input` intermediate.
    """
    num_examples = dataset["images_u8"].shape[0]
    indices = random.choice(rng, num_examples, shape=(batch_size,), replace=False)
    images_f32 = vmap(normalize_transform)(None, dataset["images_u8"][indices])

    _, variables = model.apply({"params": params}, images_f32, mutable=['intermediates'])
    intermediate_vars = variables.get('intermediates', {})

    activations = []
    # When submodules are created in a loop, `sow` namespaces the variables
    # with the submodule's auto-generated name (e.g., 'TransformerEncoderLayer_0').
    for i in range(model.num_layers):
      layer_key = f'TransformerEncoderLayer_{i}'
      if layer_key not in intermediate_vars or 'mha_input' not in intermediate_vars[layer_key]:
        raise KeyError(f"Could not find 'mha_input' for layer {i} ('{layer_key}'). "
                       f"Available intermediates: {list(intermediate_vars.keys())}")
      # Flax stores sown values in a tuple, even if there's only one.
      activations.append(intermediate_vars[layer_key]['mha_input'][0])

    # Every layer either appended an activation or raised above, so the list
    # is guaranteed to hold exactly `model.num_layers` entries here.
    return activations

  return {
      "train_transform": train_transform,
      "normalize_transform": normalize_transform,
      "batch_eval": batch_eval,
      "step": step,
      "dataset_loss_and_accuracy": dataset_loss_and_accuracy,
      "get_mha_inputs": get_mha_inputs,
  }

def rope_dot_product_attention(query, key, value,
                              bias=None, dropout_rng=None, dropout_rate=0.0,
                              deterministic=False, dtype=jnp.float32, precision=None,
                              mask=None, broadcast_dropout=True):
    """Dot-product attention with Rotary Position Embeddings (RoPE) on q and k.

    Intended as the `attention_fn` of `nn.MultiHeadDotProductAttention`, which
    passes query/key/value laid out as `[batch..., length, num_heads, head_dim]`
    (the layout documented for `nn.dot_product_attention`).

    Args:
      query: `[batch..., q_length, num_heads, head_dim]`.
      key: `[batch..., kv_length, num_heads, head_dim]`.
      value: `[batch..., kv_length, num_heads, v_dim]`; not rotated.
      bias, dropout_rng, dropout_rate, deterministic, dtype, precision, mask,
        broadcast_dropout: forwarded unchanged to `nn.dot_product_attention`.

    Returns:
      The output of `nn.dot_product_attention` on the rotated query and key.

    Raises:
      AssertionError: if `head_dim` is odd.

    NOTE(review): positions are taken along axis -3 (the token axis). Rotating
    along the heads axis instead would give every token of a head the same
    rotation, which cancels in q.k and encodes no position at all. Checkpoints
    trained with such a variant will see different attention scores here.
    """
    head_dim = query.shape[-1]
    assert head_dim % 2 == 0, "head_dim must be even for RoPE"

    def _apply_rope(x):
        # Positions index the token axis of *this* tensor, so query and key
        # may have different lengths.
        seq_len = x.shape[-3]
        # theta_j = 10000**(-2j/d) for j = 0..d/2-1
        freqs = 10000.0 ** (-jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim)
        positions = jnp.arange(seq_len, dtype=jnp.float32)
        # freqs_grid shape: (seq_len, head_dim/2)
        freqs_grid = jnp.einsum('i,j->ij', positions, freqs)
        # Each angle is shared by one (even, odd) feature pair -> (seq_len, head_dim)
        emb = jnp.repeat(freqs_grid, 2, axis=-1)
        # (seq_len, 1, head_dim) broadcasts over leading batch dims and over heads.
        cos = jnp.cos(emb)[:, None, :]
        sin = jnp.sin(emb)[:, None, :]

        # Rotate every (x1, x2) pair: (x1*cos - x2*sin, x2*cos + x1*sin).
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        x_rotated = jnp.stack([-x2, x1], axis=-1).reshape(x.shape)
        return x * cos + x_rotated * sin

    query_rope = _apply_rope(query)
    key_rope = _apply_rope(key)

    # Delegate the softmax/dropout/weighting to Flax's stock implementation.
    return nn.dot_product_attention(
        query_rope, key_rope, value, bias=bias, mask=mask,
        broadcast_dropout=broadcast_dropout, dropout_rng=dropout_rng,
        dropout_rate=dropout_rate, deterministic=deterministic, dtype=dtype, precision=precision
    )

# Transformer Encoder Layer
class TransformerEncoderLayer(nn.Module):
    """One encoder block: LayerNorm -> RoPE self-attention -> residual, then
    LayerNorm -> Dense/GELU/Dense -> residual.

    The normalized attention input is sown into the 'intermediates' collection
    under the name 'mha_input'.
    """
    embedding_dim: int  # attention qkv_features and output width of the feed-forward block
    num_heads: int  # number of attention heads
    hidden_dim: int  # width of the feed-forward hidden layer

    @nn.compact
    def __call__(self, x):
        # --- self-attention sub-block ---
        attn_in = nn.LayerNorm(use_scale=True, use_bias=True)(x)
        self.sow('intermediates', 'mha_input', attn_in)
        # RoPE is injected through the custom attention function.
        attention = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            qkv_features=self.embedding_dim,
            attention_fn=rope_dot_product_attention,
        )
        after_attn = x + attention(attn_in, attn_in)

        # --- feed-forward sub-block ---
        # The Dense layers are created one statement at a time so their
        # auto-generated names (Dense_0, Dense_1) keep the same order.
        mlp_in = nn.LayerNorm(use_scale=True, use_bias=True)(after_attn)
        mlp_hidden = nn.Dense(self.hidden_dim)(mlp_in)
        mlp_hidden = nn.gelu(mlp_hidden)
        mlp_out = nn.Dense(self.embedding_dim)(mlp_hidden)
        return after_attn + mlp_out
        
# Vision Transformer Model (adapted from mnist_vit_train.py for CIFAR-10)
class ViTModel(nn.Module):
    """Vision Transformer classifier (adapted from mnist_vit_train.py for CIFAR-10).

    The image is cut into non-overlapping patches by a strided convolution, a
    learned CLS token is prepended, the tokens go through `num_layers` encoder
    layers, and the CLS token is projected to `num_classes` logits.
    """
    patch_size: int = 4  # side of a square patch; 32x32 CIFAR-10 images give 8 patches per dimension
    embedding_dim: int = 128  # token width
    num_heads: int = 0  # heads per encoder layer; unused while num_layers is 0
    num_layers: int = 0  # number of encoder layers; 0 builds none
    hidden_dim: int = 512  # feed-forward hidden width inside each encoder layer
    num_classes: int = 10  # CIFAR-10 has 10 classes

    @nn.compact
    def __call__(self, x):
        batch_size = x.shape[0]
        patch = (self.patch_size, self.patch_size)

        # Patch embedding. For (batch, 32, 32, 3) inputs with patch_size=4 this
        # yields (batch, 8, 8, embedding_dim).
        patches = nn.Conv(
            features=self.embedding_dim,
            kernel_size=patch,
            strides=patch,
            padding="VALID",
        )(x)
        # Flatten the patch grid into a token sequence: (batch, 64, embedding_dim).
        tokens = patches.reshape((batch_size, -1, self.embedding_dim))

        # Learned CLS token, replicated per example and placed at position 0.
        cls_token = self.param('cls_token', nn.initializers.zeros, (1, 1, self.embedding_dim))
        cls_tokens = jnp.tile(cls_token, (batch_size, 1, 1))
        tokens = jnp.concatenate([cls_tokens, tokens], axis=1)  # (batch, 65, embedding_dim)

        for _ in range(self.num_layers):
            encoder_layer = TransformerEncoderLayer(
                embedding_dim=self.embedding_dim,
                num_heads=self.num_heads,
                hidden_dim=self.hidden_dim,
            )
            tokens = encoder_layer(tokens)

        # Classify from the CLS token only.
        cls_output = tokens[:, 0, :]
        return nn.Dense(self.num_classes)(cls_output)

'''Reference snippet (not executed): print the path and shape of every
parameter in a parameter tree. The example output that follows is consistent
with num_heads = 4.

    layer_paths = []
    def collect_layer_paths(path, value):
        # Convert path to a readable string by joining path keys
        path_str = '/'.join([str(p.key) for p in path])
        shape = value.shape
        layer_paths.append((path_str, shape))

    jax.tree_util.tree_map_with_path(collect_layer_paths, pretrained_params)
    print("Layers of the model:")
    for path, shape in layer_paths:
        print(f"  {path} {shape}")'''
'''Example output (two encoder layers, 4 heads of size 32, embedding_dim 128):
Layers of the model:
  Conv_0/bias (128,)
  Conv_0/kernel (4, 4, 3, 128)
  Dense_0/bias (10,)
  Dense_0/kernel (128, 10)
  TransformerEncoderLayer_0/Dense_0/bias (512,)
  TransformerEncoderLayer_0/Dense_0/kernel (128, 512)
  TransformerEncoderLayer_0/Dense_1/bias (128,)
  TransformerEncoderLayer_0/Dense_1/kernel (512, 128)
  TransformerEncoderLayer_0/LayerNorm_0/bias (128,)
  TransformerEncoderLayer_0/LayerNorm_0/scale (128,)
  TransformerEncoderLayer_0/LayerNorm_1/bias (128,)
  TransformerEncoderLayer_0/LayerNorm_1/scale (128,)
  TransformerEncoderLayer_0/MultiHeadDotProductAttention_0/key/bias (4, 32)
  TransformerEncoderLayer_0/MultiHeadDotProductAttention_0/key/kernel (128, 4, 32)
  TransformerEncoderLayer_0/MultiHeadDotProductAttention_0/out/bias (128,)
  TransformerEncoderLayer_0/MultiHeadDotProductAttention_0/out/kernel (4, 32, 128)
  TransformerEncoderLayer_0/MultiHeadDotProductAttention_0/query/bias (4, 32)
  TransformerEncoderLayer_0/MultiHeadDotProductAttention_0/query/kernel (128, 4, 32)
  TransformerEncoderLayer_0/MultiHeadDotProductAttention_0/value/bias (4, 32)
  TransformerEncoderLayer_0/MultiHeadDotProductAttention_0/value/kernel (128, 4, 32)
  TransformerEncoderLayer_1/Dense_0/bias (512,)
  TransformerEncoderLayer_1/Dense_0/kernel (128, 512)
  TransformerEncoderLayer_1/Dense_1/bias (128,)
  TransformerEncoderLayer_1/Dense_1/kernel (512, 128)
  TransformerEncoderLayer_1/LayerNorm_0/bias (128,)
  TransformerEncoderLayer_1/LayerNorm_0/scale (128,)
  TransformerEncoderLayer_1/LayerNorm_1/bias (128,)
  TransformerEncoderLayer_1/LayerNorm_1/scale (128,)
  TransformerEncoderLayer_1/MultiHeadDotProductAttention_0/key/bias (4, 32)
  TransformerEncoderLayer_1/MultiHeadDotProductAttention_0/key/kernel (128, 4, 32)
  TransformerEncoderLayer_1/MultiHeadDotProductAttention_0/out/bias (128,)
  TransformerEncoderLayer_1/MultiHeadDotProductAttention_0/out/kernel (4, 32, 128)
  TransformerEncoderLayer_1/MultiHeadDotProductAttention_0/query/bias (4, 32)
  TransformerEncoderLayer_1/MultiHeadDotProductAttention_0/query/kernel (128, 4, 32)
  TransformerEncoderLayer_1/MultiHeadDotProductAttention_0/value/bias (4, 32)
  TransformerEncoderLayer_1/MultiHeadDotProductAttention_0/value/kernel (128, 4, 32)
  cls_token (1, 1, 128)
'''

def main():
    """Train a ViT on CIFAR-10 from command-line arguments.

    After every epoch this evaluates on the test set, writes the serialized
    parameters to a file in `--ckpt-path` whose name encodes the
    hyper-parameters and the epoch's metrics, and rewrites
    `<ckpt-path>/metrics_plot.png` with the loss/accuracy curves so far.
    """
    import os  # only needed here, to create the checkpoint directory

    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--optimizer", choices=["sgd", "adam", "adamw"], required=True)
    parser.add_argument("--learning-rate", type=float, required=True)
    parser.add_argument("--num-layers", type=int, required=True)
    parser.add_argument("--num-heads", type=int, required=True)
    parser.add_argument("--ckpt-path", type=str, default="/", help="Path to ckpt directory")
    args = parser.parse_args()

    # Configuration
    class Config:
        pass
    config = Config()
    config.seed = args.seed
    config.optimizer = args.optimizer
    config.learning_rate = args.learning_rate
    config.num_epochs = 100
    config.batch_size = 100
    # Must divide the test-set size (asserted by dataset_loss_and_accuracy).
    eval_batch_size = 1000

    # Create the output directory up front so the first checkpoint write
    # cannot fail after a full epoch of training.
    os.makedirs(args.ckpt_path, exist_ok=True)

    rng = random.PRNGKey(config.seed)
    model = ViTModel(num_layers=args.num_layers, num_heads=args.num_heads)
    stuff = make_stuff(model)

    with timeblock("Load datasets"):
        train_ds, test_ds = load_cifar10(data_dir="/root/log/cifar10/data")
        num_train_examples = train_ds["images_u8"].shape[0]
        num_test_examples = test_ds["images_u8"].shape[0]
        # The epoch permutation below is reshaped into full batches only.
        assert num_train_examples % config.batch_size == 0
        print(f"Number of training examples: {num_train_examples}")
        print(f"Number of test examples: {num_test_examples}")

    if config.optimizer == "sgd":
        # Only SGD uses a schedule: short linear warm-up, then cosine decay
        # over the whole run.
        lr_schedule = optax.warmup_cosine_decay_schedule(
            init_value=1e-6,
            peak_value=config.learning_rate,
            warmup_steps=10,
            decay_steps=config.num_epochs * (num_train_examples // config.batch_size)
        )
        tx = optax.sgd(lr_schedule, momentum=0.9)
    elif config.optimizer == "adam":
        tx = optax.adam(config.learning_rate)
    else:  # adamw
        tx = optax.adamw(config.learning_rate, weight_decay=1e-3)

    # Initialize model parameters (CIFAR-10 input shape: 32x32x3)
    init_params = model.init(rngmix(rng, "init"), jnp.zeros((1, 32, 32, 3)))["params"]

    # Print all layers
    layer_paths = []
    def collect_layer_paths(path, value):
        # Convert path to a readable string by joining path keys
        path_str = '/'.join([str(p.key) for p in path])
        shape = value.shape
        layer_paths.append((path_str, shape))

    jax.tree_util.tree_map_with_path(collect_layer_paths, init_params)
    print("Layers of the model:")
    for path, shape in layer_paths:
        print(f"  {path} {shape}")

    train_state = TrainState.create(
        apply_fn=model.apply,
        params=init_params,
        tx=tx
    )

    train_losses = []
    train_accuracies = []
    test_losses = []
    test_accuracies = []

    # Training loop
    for epoch in tqdm(range(config.num_epochs)):
        infos = []
        # The whole epoch (shuffle, every step, metric reduction) is timed.
        with timeblock(f"Epoch {epoch}"):
            batch_ix = random.permutation(rngmix(rng, f"epoch-{epoch}"), num_train_examples).reshape(
                (-1, config.batch_size)
            )
            batch_rngs = random.split(rngmix(rng, f"batch_rngs-{epoch}"), batch_ix.shape[0])
            for i in range(batch_ix.shape[0]):
                p = batch_ix[i, :]
                images_u8 = train_ds["images_u8"][p, :, :, :]
                labels = train_ds["labels"][p]
                train_state, info = stuff["step"](batch_rngs[i], train_state, images_u8, labels)
                infos.append(info)

            # JAX dispatches work asynchronously; converting to Python floats
            # waits for the results, so the timer covers the real compute.
            train_loss = float(
                sum(config.batch_size * x["batch_loss"] for x in infos) / num_train_examples
            )
            train_accuracy = float(sum(x["num_correct"] for x in infos) / num_train_examples)

        # Test evaluation
        with timeblock("Test set evaluation"):
            test_loss, test_accuracy = stuff["dataset_loss_and_accuracy"](
                train_state.params, test_ds, eval_batch_size
            )
            test_loss = float(test_loss)
            test_accuracy = float(test_accuracy)

        # Histories hold plain floats, so no device buffers are kept alive and
        # matplotlib receives ordinary numbers.
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)

        metrics_str = f"_trainloss_{train_loss:.4f}_testloss_{test_loss:.4f}_trainacc_{train_accuracy:.4f}_testacc_{test_accuracy:.4f}"
        weights_file = (
            f"{args.ckpt_path}/cifar10_vit_seed{config.seed}_"
            f"opt_{config.optimizer}_lr_{config.learning_rate}_num_layers_{model.num_layers}_patch_size_{model.patch_size}_num_heads_{model.num_heads}_hidden_dim_{model.hidden_dim}_embedding_dim_{model.embedding_dim}_epoch{epoch}{metrics_str}"
        )
        with open(weights_file, "wb") as f:
            f.write(flax.serialization.to_bytes(train_state.params))

        # Plot and save metrics (the same file is overwritten every epoch)
        fig, axs = plt.subplots(1, 2, figsize=(12, 5))
        # Loss subplot
        axs[0].plot(train_losses, label='Train Loss')
        axs[0].plot(test_losses, label='Test Loss')
        axs[0].set_xlabel('Epoch')
        axs[0].set_ylabel('Loss')
        axs[0].set_title('Training and Test Loss over Epochs')
        axs[0].legend()
        # Accuracy subplot
        axs[1].plot(train_accuracies, label='Train Accuracy')
        axs[1].plot(test_accuracies, label='Test Accuracy')
        axs[1].set_xlabel('Epoch')
        axs[1].set_ylabel('Accuracy')
        axs[1].set_title('Training and Test Accuracy over Epochs')
        axs[1].legend()
        plt.tight_layout()
        plt.savefig(f"{args.ckpt_path}/metrics_plot.png")
        # Close the figure so one is not leaked per epoch.
        plt.close()

# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()