import argparse
import augmax
import flax
import jax
import jax.nn
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
from flax import linen as nn
from flax.training.train_state import TrainState
from jax import jit, random, value_and_grad, vmap
from tqdm import tqdm
import matplotlib.pyplot as plt
from src.utils import rngmix, timeblock

def _rope_rotate(x, base=10000.0):
    """Rotate ``x`` by Rotary Position Embeddings (RoPE) along its token axis.

    Args:
      x: array laid out as ``[batch..., seq_len, num_heads, head_dim]``. This is
        the layout in which Flax's ``MultiHeadDotProductAttention`` passes query
        and key to its ``attention_fn``. ``head_dim`` must be even.
      base: wavelength base; channel pair ``j`` uses frequency
        ``base**(-2j/head_dim)``.

    Returns:
      Array with the same shape and dtype as ``x``. Channel pair
      ``(2j, 2j+1)`` of the token at position ``p`` (``p = 0..seq_len-1``) is
      rotated by the angle ``p * base**(-2j/head_dim)``.

    Raises:
      ValueError: if ``head_dim`` is odd.
    """
    seq_len, head_dim = x.shape[-3], x.shape[-1]
    if head_dim % 2 != 0:
        raise ValueError(f"head_dim must be even for RoPE, got {head_dim}")

    # theta_i = base**(-2*(i-1)/d) for i=1..d/2 -> base**(-2j/d) for j=0..d/2-1
    freqs = base ** (-jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim)
    positions = jnp.arange(seq_len, dtype=jnp.float32)
    # (seq_len, head_dim/2) -> (seq_len, head_dim): pair (2j, 2j+1) shares theta_j.
    angles = jnp.repeat(jnp.outer(positions, freqs), 2, axis=-1)
    # (seq_len, 1, head_dim) broadcasts over leading batch dims and the head axis.
    cos = jnp.cos(angles)[:, None, :]
    sin = jnp.sin(angles)[:, None, :]

    # (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...): the 90-degree partner of each pair.
    x_even = x[..., ::2]
    x_odd = x[..., 1::2]
    x_rotated = jnp.stack([-x_odd, x_even], axis=-1).reshape(x.shape)
    return (x * cos + x_rotated * sin).astype(x.dtype)


def rope_dot_product_attention(query, key, value,
                               bias=None, dropout_rng=None, dropout_rate=0.0,
                               deterministic=False, dtype=jnp.float32, precision=None,
                               mask=None, broadcast_dropout=True):
    """Dot-product attention with Rotary Position Embeddings applied to query and key.

    Intended as the ``attention_fn`` of ``nn.MultiHeadDotProductAttention``.
    Query and key arrive as ``[batch..., length, num_heads, head_dim]``, the
    same layout ``nn.dot_product_attention`` expects. RoPE is therefore applied
    directly over axis -3 (the token axis), with no transposes. Transposing the
    length and head axes would rotate by head index instead of token position.
    Every token in a head would then share one rotation, leaving q·k, and hence
    the attention scores, exactly as without RoPE.

    Query and key positions are counted independently as ``0..length-1``.

    Args:
      query, key, value: attention inputs in the layout described above.
      bias, mask, broadcast_dropout, dropout_rng, dropout_rate, deterministic,
      dtype, precision: forwarded unchanged to ``nn.dot_product_attention``.
        ``mask`` and ``broadcast_dropout`` are accepted so that Flax versions
        which pass them (or filter kwargs by this function's signature) have
        them honoured rather than rejected or dropped.

    Returns:
      The output of ``nn.dot_product_attention`` on the rotated query/key.

    Raises:
      ValueError: if ``head_dim`` is odd.
    """
    query_rope = _rope_rotate(query)
    key_rope = _rope_rotate(key)
    return nn.dot_product_attention(
        query_rope, key_rope, value, bias=bias, mask=mask,
        broadcast_dropout=broadcast_dropout, dropout_rng=dropout_rng,
        dropout_rate=dropout_rate, deterministic=deterministic, dtype=dtype,
        precision=precision,
    )

# Transformer Encoder Layer
class TransformerEncoderLayer(nn.Module):
    """Pre-LayerNorm Transformer encoder block whose self-attention uses RoPE.

    Computes ``h = x + MHA(LN(x))`` and then ``h + FFN(LN(h))``, where FFN is
    ``Dense(hidden_dim) -> GELU -> Dense(embedding_dim)``. The normalized
    attention input is sown into the ``'intermediates'`` collection under the
    name ``'mha_input'``.
    """

    embedding_dim: int  # residual-stream width; also the attention qkv_features
    num_heads: int  # number of attention heads
    hidden_dim: int  # width of the feed-forward hidden layer

    @nn.compact
    def __call__(self, x):
        # Attention sub-block: pre-norm, then residual add.
        attn_in = nn.LayerNorm(use_scale=True, use_bias=True)(x)
        self.sow('intermediates', 'mha_input', attn_in)
        self_attention = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            qkv_features=self.embedding_dim,
            attention_fn=rope_dot_product_attention,
        )
        residual = x + self_attention(attn_in, attn_in)

        # Feed-forward sub-block: pre-norm, then residual add.
        ffn_in = nn.LayerNorm(use_scale=True, use_bias=True)(residual)
        hidden = nn.gelu(nn.Dense(self.hidden_dim)(ffn_in))
        return residual + nn.Dense(self.embedding_dim)(hidden)

# Vision Transformer Model
class ViTModel(nn.Module):
    """Vision Transformer classifier that takes its positional signal from RoPE.

    The image is cut into non-overlapping ``patch_size x patch_size`` patches by
    a strided convolution. A learned (zero-initialized) class token is
    prepended, and the sequence passes through ``num_layers`` encoder layers.
    The class-token output is then projected to ``num_classes`` logits. No
    absolute positional embedding is added.
    """

    patch_size: int = 7
    embedding_dim: int = 16
    num_heads: int = 0
    num_layers: int = 0
    hidden_dim: int = 64
    num_classes: int = 10

    @nn.compact
    def __call__(self, x):
        window = (self.patch_size, self.patch_size)
        patches = nn.Conv(features=self.embedding_dim, kernel_size=window, strides=window, padding="VALID")(x)
        batch = patches.shape[0]
        tokens = patches.reshape((batch, -1, self.embedding_dim))

        cls = self.param('cls_token', nn.initializers.zeros, (1, 1, self.embedding_dim))
        # No positional embedding here: position enters through RoPE in each attention layer.
        hidden = jnp.concatenate([jnp.tile(cls, (batch, 1, 1)), tokens], axis=1)

        for _ in range(self.num_layers):
            hidden = TransformerEncoderLayer(
                embedding_dim=self.embedding_dim,
                num_heads=self.num_heads,
                hidden_dim=self.hidden_dim,
            )(hidden)

        return nn.Dense(self.num_classes)(hidden[:, 0, :])

# Utility Functions
def make_stuff(model):
    """Build the jitted train/eval helpers bound to ``model``.

    Args:
      model: Flax module mapping float images to class logits.
        - ``num_layers`` is read by ``get_mha_inputs``.
        - ``num_classes`` (10 if absent) sets the one-hot width of the labels.

    Returns:
      Dict with keys:
        "normalize_transform": the ``augmax.ByteToFloat`` transform for uint8 images.
        "batch_eval": jitted ``(params, images_u8, labels) -> (mean_loss, {"num_correct": n})``.
        "step": jitted ``(train_state, images_u8, labels) -> (new_state, info)``; one gradient step.
        "dataset_loss_and_accuracy": ``(params, dataset, batch_size) -> (mean_loss, accuracy)``.
        "get_mha_inputs": ``(params, dataset, rng, batch_size) -> list`` of per-layer attention inputs.
    """
    normalize_transform = augmax.ByteToFloat()
    # Follow the model's class count so the one-hot width always matches the logits.
    num_classes = getattr(model, "num_classes", 10)

    @jit
    def batch_eval(params, images_u8, labels):
        """Return the batch's mean softmax cross-entropy and its number of correct predictions."""
        # The transform is called as (rng, inputs); no rng is supplied here.
        images_f32 = vmap(normalize_transform)(None, images_u8)
        logits = model.apply({"params": params}, images_f32)
        y_onehot = jax.nn.one_hot(labels, num_classes)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=y_onehot))
        num_correct = jnp.sum(jnp.argmax(logits, axis=-1) == labels)
        return loss, {"num_correct": num_correct}

    @jit
    def step(train_state, images_u8, labels):
        """Apply one gradient update; returns the new state and batch loss/accuracy info."""
        (l, info), g = value_and_grad(batch_eval, has_aux=True)(train_state.params, images_u8, labels)
        return train_state.apply_gradients(grads=g), {"batch_loss": l, **info}

    def dataset_loss_and_accuracy(params, dataset, batch_size: int):
        """Evaluate ``params`` on every example of ``dataset`` in contiguous batches.

        Raises:
          ValueError: if the dataset is empty, ``batch_size`` is not positive, or
            ``batch_size`` does not divide the number of examples.
        """
        num_examples = dataset["images_u8"].shape[0]
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if num_examples == 0:
            raise ValueError("dataset is empty")
        if num_examples % batch_size != 0:
            raise ValueError(f"batch_size {batch_size} must divide the number of examples {num_examples}")
        losses, correct_counts = [], []
        # Contiguous slices select the same rows as an arange-based index, without fancy indexing.
        for start in range(0, num_examples, batch_size):
            stop = start + batch_size
            loss, info = batch_eval(params, dataset["images_u8"][start:stop], dataset["labels"][start:stop])
            losses.append(loss)
            correct_counts.append(info["num_correct"])
        mean_loss = jnp.sum(batch_size * jnp.array(losses)) / num_examples
        accuracy = sum(correct_counts) / num_examples
        return mean_loss, accuracy

    def get_mha_inputs(params, dataset, rng, batch_size: int):
        """Return each encoder layer's sown ``'mha_input'`` for a random batch of ``batch_size`` examples.

        Examples are drawn without replacement using ``rng``. Element ``i`` of the
        returned list is the first value sown under ``'mha_input'`` by
        ``TransformerEncoderLayer_{i}``.

        Raises:
          KeyError: if any layer's ``'mha_input'`` intermediate is missing.
        """
        num_examples = dataset["images_u8"].shape[0]
        indices = random.choice(rng, num_examples, shape=(batch_size,), replace=False)
        images_u8 = dataset["images_u8"][indices]
        images_f32 = vmap(normalize_transform)(None, images_u8)
        _, variables = model.apply({"params": params}, images_f32, mutable=['intermediates'])
        intermediate_vars = variables.get('intermediates', {})
        activations = []
        for i in range(model.num_layers):
            layer_key = f'TransformerEncoderLayer_{i}'
            layer_vars = intermediate_vars.get(layer_key, {})
            if 'mha_input' not in layer_vars:
                raise KeyError(f"Could not find 'mha_input' for layer {i} ('{layer_key}'). "
                               f"Available intermediates: {list(intermediate_vars.keys())}")
            # sow stores a tuple of values; take the one recorded in this forward pass.
            activations.append(layer_vars['mha_input'][0])
        return activations

    return {"normalize_transform": normalize_transform, "batch_eval": batch_eval, "step": step, "dataset_loss_and_accuracy": dataset_loss_and_accuracy, "get_mha_inputs": get_mha_inputs}

# Load Datasets
def load_datasets(data_dir="/log-lmc_attn-mnist/data"):
    """Load the MNIST train and test splits as NumPy arrays via TFDS.

    Each split is requested with ``batch_size=-1`` and ``as_supervised=True``,
    i.e. as a single (images, labels) batch.

    Args:
      data_dir: directory TFDS reads from / downloads into.

    Returns:
      ``(train, test)``; each is a dict with keys ``"images_u8"`` and ``"labels"``.
    """
    def _split(split_name):
        images_u8, labels = tfds.as_numpy(
            tfds.load("mnist", split=split_name, batch_size=-1, as_supervised=True, data_dir=data_dir)
        )
        return {"images_u8": images_u8, "labels": labels}

    train = _split("train")
    test = _split("test")
    return train, test

# Main training entry point
def _save_metrics_plot(train_losses, test_losses, train_accuracies, test_accuracies, out_path):
    """Save a two-panel figure (loss, accuracy) of the per-epoch metrics to ``out_path``."""
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))
    # Loss subplot
    ax_loss.plot(train_losses, label='Train Loss')
    ax_loss.plot(test_losses, label='Test Loss')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.set_title('Training and Test Loss over Epochs')
    ax_loss.legend()
    # Accuracy subplot
    ax_acc.plot(train_accuracies, label='Train Accuracy')
    ax_acc.plot(test_accuracies, label='Test Accuracy')
    ax_acc.set_xlabel('Epoch')
    ax_acc.set_ylabel('Accuracy')
    ax_acc.set_title('Training and Test Accuracy over Epochs')
    ax_acc.legend()
    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)


def main():
    """Train the RoPE ViT on MNIST.

    After every epoch this function saves the serialized params and a metrics
    plot into ``--ckpt-path``.
    """
    import os  # only needed here, for checkpoint path handling

    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--optimizer", choices=["sgd", "adam", "adamw"], required=True)
    parser.add_argument("--learning-rate", type=float, required=True)
    parser.add_argument("--num-layers", type=int, required=True)
    parser.add_argument("--num-heads", type=int, required=True)
    parser.add_argument("--ckpt-path", type=str, default="/", help="Path to ckpt directory")
    parser.add_argument("--num-epochs", type=int, default=100, help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=500, help="Training batch size")
    args = parser.parse_args()

    config = argparse.Namespace(
        seed=args.seed,
        optimizer=args.optimizer,
        learning_rate=args.learning_rate,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
    )

    # Confirm which devices JAX will use.
    print("JAX devices:", jax.devices())

    # Optionally, keep TensorFlow off GPU to avoid conflicts (uncomment if needed)
    # tf.config.set_visible_devices([], "GPU") # See https://github.com/tensorflow/tensorflow/issues/53831

    rng = random.PRNGKey(config.seed)
    model = ViTModel(num_layers=args.num_layers, num_heads=args.num_heads)
    stuff = make_stuff(model)

    with timeblock("load_datasets"):
        train_ds, test_ds = load_datasets()

    num_train_examples = train_ds["images_u8"].shape[0]
    num_test_examples = test_ds["images_u8"].shape[0]
    if config.batch_size <= 0 or num_train_examples % config.batch_size != 0:
        parser.error(
            f"--batch-size {config.batch_size} must be positive and divide the "
            f"{num_train_examples} training examples"
        )
    steps_per_epoch = num_train_examples // config.batch_size

    if config.optimizer == "sgd":
        # optax's decay_steps is the total schedule length, warmup included.
        lr_schedule = optax.warmup_cosine_decay_schedule(init_value=1e-6, peak_value=config.learning_rate, warmup_steps=10, decay_steps=config.num_epochs * steps_per_epoch)
        tx = optax.sgd(lr_schedule, momentum=0.9)
    elif config.optimizer == "adam":
        tx = optax.adam(config.learning_rate)
    else:
        tx = optax.adamw(config.learning_rate, weight_decay=1e-4)

    init_params = model.init(rngmix(rng, "init"), jnp.zeros((1, 28, 28, 1)))["params"]
    train_state = TrainState.create(apply_fn=model.apply, params=init_params, tx=tx)

    os.makedirs(args.ckpt_path, exist_ok=True)
    run_name = (
        f"mnist_vit_rope_seed{config.seed}_"
        f"opt_{config.optimizer}_lr_{config.learning_rate}_num_layers_{model.num_layers}_patch_size_{model.patch_size}_num_heads_{model.num_heads}_hidden_dim_{model.hidden_dim}_embedding_dim_{model.embedding_dim}"
    )

    train_losses = []
    train_accuracies = []
    test_losses = []
    test_accuracies = []

    for epoch in tqdm(range(config.num_epochs)):
        infos = []
        with timeblock("Epoch"):
            # Move the permutation to the host once per epoch; indexing the NumPy
            # dataset with a JAX array would otherwise convert it on every step.
            perm = np.asarray(random.permutation(rngmix(rng, f"epoch-{epoch}"), num_train_examples))
            for p in perm.reshape((-1, config.batch_size)):
                train_state, info = stuff["step"](train_state, train_ds["images_u8"][p], train_ds["labels"][p])
                infos.append(info)

        train_loss = sum(config.batch_size * x["batch_loss"] for x in infos) / num_train_examples
        train_accuracy = sum(x["num_correct"] for x in infos) / num_train_examples

        with timeblock("Test set eval"):
            # Evaluate the whole test split as a single batch.
            test_loss, test_accuracy = stuff["dataset_loss_and_accuracy"](train_state.params, test_ds, num_test_examples)

        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)

        metrics_str = f"_trainloss_{train_loss:.4f}_testloss_{test_loss:.4f}_trainacc_{train_accuracy:.4f}_testacc_{test_accuracy:.4f}"
        weights_file = os.path.join(args.ckpt_path, f"{run_name}_epoch{epoch}{metrics_str}")
        with open(weights_file, "wb") as f:
            f.write(flax.serialization.to_bytes(train_state.params))

        _save_metrics_plot(
            train_losses, test_losses, train_accuracies, test_accuracies,
            os.path.join(args.ckpt_path, "metrics_plot.png"),
        )


if __name__ == "__main__":
    main()