import argparse
import augmax
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
from flax import linen as nn
from flax.training.train_state import TrainState
from jax import jit, random, tree_map, value_and_grad, vmap
from tqdm import tqdm
import matplotlib.pyplot as plt

from src.utils import flatten_params, rngmix, timeblock  

# Transformer Encoder Layer
class TransformerEncoderLayer(nn.Module):
    """Single encoder block: self-attention followed by a two-layer MLP.

    Each sub-block first normalises its input with a LayerNorm that has no
    learnable scale or bias, and its output is added back onto the
    un-normalised input (residual connection).
    """

    embedding_dim: int  # qkv feature size for attention; output width of the MLP
    num_heads: int  # number of attention heads
    hidden_dim: int  # width of the MLP's intermediate Dense layer

    @nn.compact
    def __call__(self, x):
        """Return ``x`` after the attention and MLP residual updates.

        NOTE(review): assumes ``x`` is (batch, tokens, embedding_dim) -- confirm
        against the caller.
        """
        # NOTE: Flax auto-names compact submodules by creation order, so the
        # order below (norm, attention, norm, dense, dense) should stay stable
        # to keep parameter names compatible with saved checkpoints.
        normed = nn.LayerNorm(use_bias=False, use_scale=False)(x)
        self_attention = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            qkv_features=self.embedding_dim,
        )
        # Queries and keys/values both come from the same normalised tokens.
        residual = x + self_attention(normed, normed)

        normed = nn.LayerNorm(use_bias=False, use_scale=False)(residual)
        hidden = nn.gelu(nn.Dense(self.hidden_dim)(normed))
        return residual + nn.Dense(self.embedding_dim)(hidden)

# Vision Transformer Model
class ViTModel(nn.Module):
    """Small Vision Transformer classifier.

    Pipeline: patch embedding through a strided convolution, a prepended class
    token, an added positional embedding, ``num_layers`` encoder blocks, and a
    Dense head applied to the class-token output.
    """

    patch_size: int = 7  # side of the square conv kernel, also used as its stride
    embedding_dim: int = 32  # conv output channels / per-token feature width
    num_heads: int = 1  # attention heads in every encoder block
    num_layers: int = 2  # number of stacked encoder blocks
    hidden_dim: int = 32  # MLP hidden width inside every encoder block
    num_classes: int = 10  # number of output logits

    @nn.compact
    def __call__(self, x):
        """Return class logits with ``num_classes`` entries per example.

        NOTE(review): assumes ``x`` is an NHWC image batch such as
        (batch, 28, 28, 1) -- confirm against the caller.
        """
        batch = x.shape[0]
        patch = (self.patch_size, self.patch_size)

        # Kernel size equals stride, so every patch is embedded exactly once.
        patches = nn.Conv(
            features=self.embedding_dim,
            kernel_size=patch,
            strides=patch,
            padding="VALID",
        )(x)
        tokens = patches.reshape((batch, -1, self.embedding_dim))

        # Zero-initialised learnable class token, copied once per example and
        # placed at position 0 of the token axis.
        cls_param = self.param("cls_token", nn.initializers.zeros, (1, 1, self.embedding_dim))
        tokens = jnp.concatenate([jnp.tile(cls_param, (batch, 1, 1)), tokens], axis=1)

        # One learnable position vector per token, broadcast over the batch.
        pos_param = self.param(
            "pos_embedding",
            nn.initializers.normal(stddev=0.02),
            (1, tokens.shape[1], self.embedding_dim),
        )
        tokens = tokens + pos_param

        for _ in range(self.num_layers):
            encoder = TransformerEncoderLayer(
                embedding_dim=self.embedding_dim,
                num_heads=self.num_heads,
                hidden_dim=self.hidden_dim,
            )
            tokens = encoder(tokens)

        # Only the class token (index 0) feeds the classification head.
        return nn.Dense(self.num_classes)(tokens[:, 0, :])

# Utility Functions
def make_stuff(model):
    """Build the evaluation and training helpers bound to ``model``.

    Args:
        model: Module whose ``apply({"params": params}, images_f32)`` returns
            per-class logits.

    Returns:
        Dict with the keys ``"normalize_transform"``, ``"batch_eval"``,
        ``"step"`` and ``"dataset_loss_and_accuracy"``.
    """
    normalize_transform = augmax.ByteToFloat()

    @jit
    def batch_eval(params, images_u8, labels):
        """Return ``(mean cross-entropy, {"num_correct": count})`` for one batch."""
        # None goes into the transform's first positional slot (presumably its
        # RNG argument, which a byte->float conversion should not need --
        # confirm against augmax).
        images_f32 = vmap(normalize_transform)(None, images_u8)
        logits = model.apply({"params": params}, images_f32)
        # Take the class count from the logits instead of hard-coding 10, so
        # models built with a different ``num_classes`` work unchanged.
        y_onehot = jax.nn.one_hot(labels, logits.shape[-1])
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=y_onehot))
        num_correct = jnp.sum(jnp.argmax(logits, axis=-1) == labels)
        return loss, {"num_correct": num_correct}

    @jit
    def step(train_state, images_u8, labels):
        """Apply one gradient update; return ``(new_state, batch metrics)``."""
        (l, info), g = value_and_grad(batch_eval, has_aux=True)(train_state.params, images_u8, labels)
        return train_state.apply_gradients(grads=g), {"batch_loss": l, **info}

    def dataset_loss_and_accuracy(params, dataset, batch_size: int):
        """Return ``(mean loss, accuracy)`` over ``dataset``, evaluated in batches.

        Args:
            params: Model parameters passed through to ``batch_eval``.
            dataset: Dict with ``"images_u8"`` and ``"labels"`` arrays.
            batch_size: Must divide the number of examples exactly.
        """
        num_examples = dataset["images_u8"].shape[0]
        assert num_examples % batch_size == 0
        # Contiguous slices visit the same examples in the same order as an
        # arange-based index array, without building the index array.
        results = [
            batch_eval(
                params,
                dataset["images_u8"][start:start + batch_size],
                dataset["labels"][start:start + batch_size],
            )
            for start in range(0, num_examples, batch_size)
        ]
        losses, infos = zip(*results)
        # Per-batch means are re-weighted by batch size to get the dataset mean.
        mean_loss = jnp.sum(batch_size * jnp.array(losses)) / num_examples
        accuracy = sum(x["num_correct"] for x in infos) / num_examples
        return mean_loss, accuracy

    return {"normalize_transform": normalize_transform, "batch_eval": batch_eval, "step": step, "dataset_loss_and_accuracy": dataset_loss_and_accuracy}

# Load Datasets
def load_datasets():
    """Load the MNIST train and test splits via TFDS as NumPy data.

    Returns:
        ``(train, test)``: two dicts, each holding ``"images_u8"`` and
        ``"labels"``.
    """

    def _load_split(split_name):
        # batch_size=-1 asks TFDS for the whole split in one batch;
        # as_supervised=True makes it an (image, label) pair.
        images_u8, labels = tfds.as_numpy(
            tfds.load("mnist", split=split_name, batch_size=-1, as_supervised=True)
        )
        return {"images_u8": images_u8, "labels": labels}

    train_ds = _load_split("train")
    test_ds = _load_split("test")
    return train_ds, test_ds

# Main Function: trains on the available JAX devices and checkpoints every epoch (no early stopping is applied)
def _save_metric_plot(train_values, test_values, metric_name, path):
    """Plot per-epoch train/test curves for one metric and save the figure to ``path``."""
    plt.figure()
    plt.plot(train_values, label=f"Train {metric_name}")
    plt.plot(test_values, label=f"Test {metric_name}")
    plt.xlabel("Epoch")
    plt.ylabel(metric_name)
    plt.title(f"Training and Test {metric_name} over Epochs")
    plt.legend()
    plt.tight_layout()
    plt.savefig(path)
    plt.close()


def main():
    """Train a ViT on MNIST, checkpointing every epoch and plotting metrics at the end.

    Command-line flags: ``--seed``, ``--optimizer {sgd,adam,adamw}``,
    ``--learning-rate``, ``--num-layers`` and ``--ckpt-path`` (directory that
    receives the per-epoch parameter files and the two PNG plots).

    Training always runs for the full ``config.num_epochs``; no early stopping
    is applied.
    """
    import os  # function-scope import: only this entry point builds output paths

    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--optimizer", choices=["sgd", "adam", "adamw"], required=True)
    parser.add_argument("--learning-rate", type=float, required=True)
    parser.add_argument("--num-layers", type=int, required=True)
    # NOTE: the default is the filesystem root; pass --ckpt-path explicitly in practice.
    parser.add_argument("--ckpt-path", type=str, default="/", help="Path to ckpt directory")
    args = parser.parse_args()

    class Config:
        pass
    config = Config()
    config.seed = args.seed
    config.optimizer = args.optimizer
    config.learning_rate = args.learning_rate
    config.num_epochs = 100
    config.batch_size = 500

    # Confirm GPU availability with JAX
    print("JAX devices:", jax.devices())

    # Create the output directory up front so the first checkpoint write
    # cannot fail after a full epoch of training.
    os.makedirs(args.ckpt_path, exist_ok=True)

    rng = random.PRNGKey(config.seed)
    model = ViTModel(num_layers=args.num_layers)
    stuff = make_stuff(model)

    with timeblock("load_datasets"):
        train_ds, test_ds = load_datasets()
        num_train_examples = train_ds["images_u8"].shape[0]
        num_test_examples = test_ds["images_u8"].shape[0]
        assert num_train_examples % config.batch_size == 0

    steps_per_epoch = num_train_examples // config.batch_size
    if config.optimizer == "sgd":
        # SGD uses a short warmup followed by cosine decay over the whole run.
        lr_schedule = optax.warmup_cosine_decay_schedule(
            init_value=1e-6,
            peak_value=config.learning_rate,
            warmup_steps=10,
            decay_steps=config.num_epochs * steps_per_epoch,
        )
        tx = optax.sgd(lr_schedule, momentum=0.9)
    elif config.optimizer == "adam":
        tx = optax.adam(config.learning_rate)
    else:
        tx = optax.adamw(config.learning_rate, weight_decay=1e-4)

    init_params = model.init(rngmix(rng, "init"), jnp.zeros((1, 28, 28, 1)))["params"]
    train_state = TrainState.create(apply_fn=model.apply, params=init_params, tx=tx)

    train_losses = []
    train_accuracies = []
    test_losses = []
    test_accuracies = []

    for epoch in tqdm(range(config.num_epochs)):
        infos = []
        with timeblock("Epoch"):
            # Fresh shuffle each epoch, derived deterministically from the seed.
            batch_ix = random.permutation(rngmix(rng, f"epoch-{epoch}"), num_train_examples).reshape((-1, config.batch_size))
            for i in range(batch_ix.shape[0]):
                p = batch_ix[i, :]
                train_state, info = stuff["step"](train_state, train_ds["images_u8"][p, :, :, :], train_ds["labels"][p])
                infos.append(info)

        # Convert to Python floats so the histories hold plain numbers
        # (for formatting and plotting) rather than device arrays.
        train_loss = float(sum(config.batch_size * x["batch_loss"] for x in infos) / num_train_examples)
        train_accuracy = float(sum(x["num_correct"] for x in infos) / num_train_examples)

        with timeblock("Test set eval"):
            # Evaluate the whole test split as one batch, sized from the data
            # instead of a hard-coded 10_000.
            test_loss, test_accuracy = stuff["dataset_loss_and_accuracy"](train_state.params, test_ds, num_test_examples)
        test_loss = float(test_loss)
        test_accuracy = float(test_accuracy)

        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)

        metrics_str = f"_trainloss_{train_loss:.4f}_testloss_{test_loss:.4f}_trainacc_{train_accuracy:.4f}_testacc_{test_accuracy:.4f}"
        weights_name = (
            f"mnist_vit_seed{config.seed}_"
            f"opt_{config.optimizer}_lr_{config.learning_rate}_num_layers_{model.num_layers}_hidden_dim_{model.hidden_dim}_embedding_dim_{model.embedding_dim}_num_heads_{model.num_heads}_epoch{epoch}{metrics_str}"
        )
        with open(os.path.join(args.ckpt_path, weights_name), "wb") as f:
            f.write(flax.serialization.to_bytes(train_state.params))

    # Plot and save metrics
    _save_metric_plot(train_losses, test_losses, "Loss", os.path.join(args.ckpt_path, "loss_plot.png"))
    _save_metric_plot(train_accuracies, test_accuracies, "Accuracy", os.path.join(args.ckpt_path, "accuracy_plot.png"))

# Run training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()