import os
import heapq
from absl import logging
import flax
import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt
import numpy as np
import optax
import tqdm
from flax import traverse_util
from vit_jax import checkpoint, input_pipeline, utils, models, train
from vit_jax.configs import common as common_config
from vit_jax.configs import models as models_config

# Set absl logging verbosity to INFO for the whole script.
logging.set_verbosity(logging.INFO)

def _save_line_plot(series, ylabel, out_path):
    """Draw one figure with a labelled curve per entry and save it.

    Args:
        series: Mapping of legend label -> sequence of per-epoch values.
        ylabel: Label for the y axis (the x axis is always 'Epoch').
        out_path: Destination path handed to ``plt.savefig``.
    """
    plt.figure()
    for label, values in series.items():
        plt.plot(values, label=label)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(out_path)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close()


def main():
    """Fine-tune a pretrained ViT-S/16 on CIFAR-10 and record per-epoch metrics.

    The function loads pretrained weights from ``/weights/<model_name>.npz``,
    trains for ``num_epochs`` epochs with clipped SGD + momentum, evaluates on
    the test split before training and after every epoch, writes one ``.npz``
    checkpoint per epoch, and finally saves loss / accuracy / learning-rate
    plots. All outputs go to ``ckpt_path``.

    Raises:
        FileNotFoundError: If the pretrained weights file does not exist.
        ValueError: If the training set is smaller than one batch, or the
            test pipeline yields no batches.
    """
    # Configuration
    ckpt_path = '/'
    os.makedirs(ckpt_path, exist_ok=True)
    model_name = 'ViT-S_16'
    model_path = f"/weights/{model_name}.npz"
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model path {model_path} does not exist.")

    dataset = 'cifar10'
    batch_size = 512
    num_epochs = 20  # Increased from original total_steps=100 (~1 epoch) for monitoring

    # Load dataset
    config = common_config.with_dataset(common_config.get_config(), dataset)
    config.batch = batch_size
    config.pp.crop = 224
    ds_train = input_pipeline.get_data_from_tfds(config=config, mode='train')
    ds_test = input_pipeline.get_data_from_tfds(config=config, mode='test')
    # Query the dataset info once and reuse it for both fields.
    train_info = input_pipeline.get_dataset_info(dataset, 'train')
    num_classes = train_info['num_classes']
    num_train_examples = train_info['num_examples']
    steps_per_epoch = num_train_examples // batch_size  # ~97 for CIFAR-10 with batch_size=512
    if steps_per_epoch == 0:
        raise ValueError(
            f"batch_size={batch_size} exceeds the {num_train_examples} training examples."
        )
    total_steps = num_epochs * steps_per_epoch

    # Model initialization
    model_config = models_config.MODEL_CONFIGS[model_name]
    model = models.VisionTransformer(num_classes=num_classes, **model_config)
    # Initialise on CPU using a single example taken from the first training batch.
    # NOTE(review): the `[0, :1]` indexing presumes batches carry a leading
    # device axis followed by the per-device batch axis — confirm against the pipeline.
    variables = jax.jit(lambda: model.init(
        jax.random.PRNGKey(0),
        next(iter(ds_train.as_numpy_iterator()))['image'][0, :1],
        train=False,
    ), backend='cpu')()
    params = checkpoint.load_pretrained(
        pretrained_path=model_path,
        init_params=variables['params'],
        model_config=model_config,
    )

    # Cast every parameter leaf to float32 so the whole tree has a uniform dtype.
    params = jax.tree_util.tree_map(
        lambda x: x.astype(jnp.float32),
        params
    )

    # Replicate parameters across devices
    params_repl = flax.jax_utils.replicate(params)
    vit_apply_repl = jax.pmap(lambda params, inputs: model.apply(
        dict(params=params), inputs, train=False))

    # Optimizer and learning rate schedule
    warmup_steps = 5
    decay_type = 'cosine'
    grad_norm_clip = 1
    accum_steps = 8
    base_lr = 0.0005
    lr_fn = utils.create_learning_rate_schedule(total_steps, base_lr, decay_type, warmup_steps)
    tx = optax.chain(
        optax.clip_by_global_norm(grad_norm_clip),
        optax.sgd(
            learning_rate=lr_fn,
            momentum=0.9,
            accumulator_dtype='bfloat16',
        ),
    )
    update_fn_repl = train.make_update_fn(apply_fn=model.apply, accum_steps=accum_steps, tx=tx)
    opt_state = tx.init(params)
    opt_state_repl = flax.jax_utils.replicate(opt_state)
    update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0))

    # Metrics storage (one entry per epoch)
    train_losses = []
    train_accuracies = []
    test_losses = []
    test_accuracies = []
    learning_rates = []

    # Define test metrics computation
    @jax.pmap
    def compute_batch_metrics(params, batch):
        """Return (cross-entropy loss, accuracy) for one per-device batch."""
        logits = model.apply({'params': params}, batch['image'], train=False)
        # Labels are multiplied against log-probabilities and arg-maxed below,
        # i.e. they are treated as one-hot / soft label vectors.
        loss = -jnp.sum(jax.nn.log_softmax(logits) * batch['label']) / batch['label'].shape[0]
        accuracy = (logits.argmax(axis=-1) == batch['label'].argmax(axis=-1)).mean()
        return loss, accuracy

    def evaluate_test(params_repl):
        """Average loss and accuracy over one pass of ``ds_test``.

        Returns:
            Tuple ``(mean_loss, mean_accuracy)`` as Python floats.

        Raises:
            ValueError: If the test iterator yields no batches.
        """
        test_loss_sum = 0.0
        test_accuracy_sum = 0.0
        num_batches = 0
        # NOTE(review): this loop only terminates if the test pipeline is
        # finite — presumably mode='test' yields a single pass; confirm.
        for batch in ds_test.as_numpy_iterator():
            loss, accuracy = compute_batch_metrics(params_repl, batch)
            test_loss_sum += jnp.mean(loss)
            test_accuracy_sum += jnp.mean(accuracy)
            num_batches += 1
        if num_batches == 0:
            raise ValueError("Test dataset produced no batches; cannot compute metrics.")
        return float(test_loss_sum / num_batches), float(test_accuracy_sum / num_batches)

    test_loss, test_accuracy = evaluate_test(params_repl)
    print(f'Start: testloss_{test_loss:.4f}_testacc_{test_accuracy:.4f}')

    # Training loop
    train_iter = iter(ds_train.as_numpy_iterator())
    global_step = 0
    for epoch in range(num_epochs):
        train_loss_sum = 0.0
        train_correct_sum = 0.0
        for _ in range(steps_per_epoch):
            batch = next(train_iter)
            params_repl, opt_state_repl, loss_repl, update_rng_repl = update_fn_repl(
                params_repl, opt_state_repl, batch, update_rng_repl
            )
            train_loss_sum += jnp.mean(loss_repl)
            # Accuracy is measured with a second forward pass using the
            # already-updated parameters on the batch that was just trained on.
            predicted = vit_apply_repl(params_repl, batch['image'])
            is_correct = (predicted.argmax(axis=-1) == batch['label'].argmax(axis=-1))
            train_correct_sum += jnp.mean(is_correct)
            global_step += 1

        # Compute epoch metrics (pulled to host floats once per epoch)
        train_loss = float(train_loss_sum / steps_per_epoch)
        train_accuracy = float(train_correct_sum / steps_per_epoch)
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)

        # Evaluate on test set
        test_loss, test_accuracy = evaluate_test(params_repl)
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)

        print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

        # Learning rate as reported by the schedule at the end of this epoch
        learning_rates.append(float(lr_fn(global_step)))

        # Checkpoint the (unreplicated) parameters; metrics are embedded in the file name.
        metrics_str = f"_trainloss_{train_loss:.4f}_testloss_{test_loss:.4f}_trainacc_{train_accuracy:.4f}_testacc_{test_accuracy:.4f}"
        # os.path.join avoids the doubled separator that string concatenation
        # produces when ckpt_path already ends with '/'.
        weights_file = os.path.join(ckpt_path, f"cifar10_vit_epoch{epoch}{metrics_str}.npz")
        with open(weights_file, "wb") as f:
            np.savez(f, **traverse_util.flatten_dict(flax.jax_utils.unreplicate(params_repl), sep='/'))

    # Save plots
    _save_line_plot(
        {'Train Loss': train_losses, 'Test Loss': test_losses},
        'Loss',
        os.path.join(ckpt_path, 'loss_plot.png'),
    )
    _save_line_plot(
        {'Train Accuracy': train_accuracies, 'Test Accuracy': test_accuracies},
        'Accuracy',
        os.path.join(ckpt_path, 'accuracy_plot.png'),
    )
    _save_line_plot(
        {'Learning Rate': learning_rates},
        'Learning Rate',
        os.path.join(ckpt_path, 'lr_plot.png'),
    )

    # Every epoch is checkpointed; no best-model selection is performed.
    print("Training completed. Plots and per-epoch checkpoints saved.")

# Run the fine-tuning pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()