import argparse
import augmax
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow as tf
import matplotlib.pyplot as plt
from flax import linen as nn
from flax.training.train_state import TrainState
from jax import random, value_and_grad
from jax.tree_util import tree_map
import jax.tree_util as tree_util
from tqdm import tqdm
from flax.traverse_util import flatten_dict, unflatten_dict

from src.utils import flatten_params, rngmix, timeblock
#from .mnist_vit_train import ViTModel, make_stuff, load_datasets

# Main function for fine-tuning
# Path components (in the params tree) of the attention projections that are
# re-initialised and fine-tuned. These are the names this script expects
# ViTModel's params to use.
_LAYER_PREFIX = "TransformerEncoderLayer_"
_ATTENTION_KEY = "MultiHeadDotProductAttention_0"
_ATTENTION_COMPONENTS = ("key", "query", "value", "out")
_ATTENTION_LEAVES = ("bias", "kernel")


def _parse_finetune_layers(spec, num_layers):
    """Convert the ``--finetune-layer-which`` value into a list of layer indices.

    ``"all"`` selects every layer in ``range(num_layers)``. Otherwise ``spec`` is
    a comma-separated list of integers such as ``"0,2"``.

    Raises:
        ValueError: if a token is not an integer, or an index lies outside
            ``[0, num_layers)``. Such an index would otherwise surface later as
            an opaque ``KeyError`` while copying attention weights.
    """
    if spec == "all":
        return list(range(num_layers))
    layers = [int(token) for token in spec.split(",")]
    out_of_range = [idx for idx in layers if not 0 <= idx < num_layers]
    if out_of_range:
        raise ValueError(
            f"layer indices {out_of_range} are out of range for a model with {num_layers} layers"
        )
    return layers


def _label_params(params, finetune_layers):
    """Label every parameter leaf ``'trainable'`` or ``'frozen'``.

    A leaf is trainable only if its path has the form
    ``TransformerEncoderLayer_<i>/MultiHeadDotProductAttention_0/<component>/<leaf>``,
    where ``<component>`` is one of key/query/value/out, ``<leaf>`` is bias or
    kernel, and ``<i>`` is in ``finetune_layers``. The returned tree mirrors the
    structure of ``params``, so it can be passed directly to
    ``optax.multi_transform``.
    """
    trainable_layers = set(finetune_layers)

    def label_fn(path, _leaf):
        keys = [getattr(entry, "key", None) for entry in path]
        if (
            len(keys) == 4
            and isinstance(keys[0], str)
            and keys[0].startswith(_LAYER_PREFIX)
            and keys[1] == _ATTENTION_KEY
            and keys[2] in _ATTENTION_COMPONENTS
            and keys[3] in _ATTENTION_LEAVES
        ):
            layer_idx = keys[0][len(_LAYER_PREFIX):]
            if layer_idx.isdigit() and int(layer_idx) in trainable_layers:
                return "trainable"
        return "frozen"

    return jax.tree_util.tree_map_with_path(label_fn, params)


def _print_trainable_paths(param_labels):
    """Sanity check: print the '/'-joined path of every leaf labelled 'trainable'."""
    labelled_leaves, _ = tree_util.tree_flatten_with_path(param_labels)
    print("Trainable parts of the model:")
    for path, label in labelled_leaves:
        if label == "trainable":
            print("  " + "/".join(str(entry.key) for entry in path))


def _make_base_optimizer(name, learning_rate, total_steps):
    """Build the optimizer applied to the trainable parameters.

    ``"sgd"`` uses momentum 0.9 with a warmup-cosine schedule. The schedule
    starts at 1e-6, warms up for 10 steps to ``learning_rate`` and decays over
    ``total_steps``. ``"adam"`` uses a constant rate. ``"adamw"`` uses a constant
    rate with ``weight_decay=1e-3``.

    Raises:
        ValueError: for an unknown optimizer name.
    """
    if name == "sgd":
        schedule = optax.warmup_cosine_decay_schedule(
            init_value=1e-6,
            peak_value=learning_rate,
            warmup_steps=10,
            decay_steps=total_steps,
        )
        return optax.sgd(schedule, momentum=0.9)
    if name == "adam":
        return optax.adam(learning_rate)
    if name == "adamw":
        return optax.adamw(learning_rate, weight_decay=1e-3)
    raise ValueError(f"unknown optimizer {name!r}")


def _plot_metrics(train_losses, test_losses, train_accuracies, test_accuracies, out_path):
    """Save a two-panel figure (loss, accuracy) of per-epoch train/test curves to ``out_path``."""
    fig, axs = plt.subplots(1, 2, figsize=(12, 5))
    # Loss subplot
    axs[0].plot(train_losses, label='Train Loss')
    axs[0].plot(test_losses, label='Test Loss')
    axs[0].set_xlabel('Epoch')
    axs[0].set_ylabel('Loss')
    axs[0].set_title('Training and Test Loss over Epochs')
    axs[0].legend()
    # Accuracy subplot
    axs[1].plot(train_accuracies, label='Train Accuracy')
    axs[1].plot(test_accuracies, label='Test Accuracy')
    axs[1].set_xlabel('Epoch')
    axs[1].set_ylabel('Accuracy')
    axs[1].set_title('Training and Test Accuracy over Epochs')
    axs[1].legend()
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close(fig)


def main():
    """Fine-tune selected attention layers of a pre-trained MNIST ViT.

    Steps:
    1. Load a flax-serialized checkpoint (``--model-path``) into a
       ``ViTModel(num_layers, num_heads)``.
    2. Replace the key/query/value/out attention projections of the layers
       chosen by ``--finetune-layer-which`` with fresh weights initialised from
       ``--seed``.
    3. Train only those projections. All other parameters receive zero updates
       via ``optax.set_to_zero``.

    Training stops early once test loss has not improved for ``--patience``
    epochs. The parameters with the lowest test loss are serialized into
    ``--ckpt-path``, alongside a loss/accuracy plot.

    Raises:
        ValueError: if the training-set size is not divisible by the batch size.
        RuntimeError: if test loss never becomes finite, so no best checkpoint
            exists to save.
    """
    # Parse command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--rope-use", action='store_true', help="Use rope if this flag is present")
    parser.add_argument("--seed", type=int, required=True, help="Random seed for fine-tuning")
    parser.add_argument("--model-path", type=str, required=True, help="Path to pre-trained ViT model checkpoint")
    parser.add_argument("--optimizer", choices=["sgd", "adam", "adamw"], required=True, help="Optimizer to use")
    parser.add_argument("--learning-rate", type=float, required=True, help="Learning rate")
    parser.add_argument("--num-layers", type=int, required=True, help="Number of transformer layers")
    parser.add_argument("--num-heads", type=int, required=True, help="Number of transformer heads")
    parser.add_argument("--finetune-layer-which", type=str, required=True, help="Which attention layer(s) to finetune")
    parser.add_argument("--ckpt-path", type=str, default="/", help="Path to ckpt directory")
    parser.add_argument("--data-dir", type=str, default="/log-lmc_attn-mnist/data", help="Dataset directory passed to load_datasets")
    parser.add_argument("--num-epochs", type=int, default=100, help="Maximum number of fine-tuning epochs")
    parser.add_argument("--batch-size", type=int, default=100, help="Training batch size")
    parser.add_argument("--patience", type=int, default=10, help="Early-stopping patience in epochs (on test loss)")
    args = parser.parse_args()

    try:
        finetune_layers = _parse_finetune_layers(args.finetune_layer_which, args.num_layers)
    except ValueError as err:
        parser.error(f"--finetune-layer-which: {err}")  # exits the process

    # The RoPE and non-RoPE training modules export the same three names.
    if args.rope_use:
        from .mnist_vit_train_rope import ViTModel, make_stuff, load_datasets
    else:
        from .mnist_vit_train import ViTModel, make_stuff, load_datasets

    config = argparse.Namespace(
        seed=args.seed,
        model_path=args.model_path,
        optimizer=args.optimizer,
        learning_rate=args.learning_rate,
        num_layers=args.num_layers,
        num_heads=args.num_heads,
        finetune_layer_which=finetune_layers,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        patience=args.patience,
        data_dir=args.data_dir,
    )

    # Load MNIST datasets
    with timeblock("Load datasets"):
        train_ds, test_ds = load_datasets(data_dir=config.data_dir)
        num_train_examples = train_ds["images_u8"].shape[0]
        num_test_examples = test_ds["images_u8"].shape[0]
    # Batches come from reshaping a permutation to (-1, batch_size), which
    # requires an exact split.
    if num_train_examples % config.batch_size != 0:
        raise ValueError(
            f"batch size {config.batch_size} does not divide {num_train_examples} training examples"
        )

    # Load pre-trained model
    model = ViTModel(num_layers=config.num_layers, num_heads=config.num_heads)
    dummy_input = jnp.ones((1, 28, 28, 1))
    with open(config.model_path, "rb") as f:
        serialized_params = f.read()
    dummy_params = model.init(jax.random.PRNGKey(0), dummy_input)['params']
    pretrained_params = flax.serialization.from_bytes(dummy_params, serialized_params)

    # Flax modules are frozen dataclasses: assigning an attribute after
    # construction raises. Rebuild through clone() only if ViTModel actually
    # declares a `finetune_layer_which` field; otherwise the model needs no
    # changes, because trainability is expressed purely through param labels.
    if "finetune_layer_which" in getattr(model, "__dataclass_fields__", {}):
        finetune_model = model.clone(finetune_layer_which=tuple(finetune_layers))
    else:
        finetune_model = model
    rng = jax.random.PRNGKey(config.seed)
    new_params = finetune_model.init(rng, dummy_input)['params']

    # Deep, mutable copy of the pretrained tree. A shallow .copy() would let the
    # nested assignments below mutate pretrained_params, and fails on FrozenDict.
    combined_params = flax.core.unfreeze(pretrained_params)
    for idx in finetune_layers:
        layer_key = f"{_LAYER_PREFIX}{idx}"
        target_attn = combined_params[layer_key][_ATTENTION_KEY]
        fresh_attn = new_params[layer_key][_ATTENTION_KEY]
        for component in _ATTENTION_COMPONENTS:
            target_attn[component] = fresh_attn[component]

    param_labels = _label_params(combined_params, finetune_layers)
    _print_trainable_paths(param_labels)

    steps_per_epoch = num_train_examples // config.batch_size
    base_tx = _make_base_optimizer(
        config.optimizer, config.learning_rate, config.num_epochs * steps_per_epoch
    )
    # Frozen leaves get zero updates, so only the re-initialised attention
    # projections change.
    tx = optax.multi_transform(
        {'trainable': base_tx, 'frozen': optax.set_to_zero()},
        param_labels,
    )

    # Create training state
    train_state = TrainState.create(
        apply_fn=finetune_model.apply,
        params=combined_params,
        tx=tx,
    )
    stuff = make_stuff(finetune_model)

    train_losses, train_accuracies = [], []
    test_losses, test_accuracies = [], []

    # Early stopping / best-checkpoint tracking (selected on test loss).
    best_test_loss = float("inf")
    best_metric = None
    best_params = None
    best_epoch = -1
    epochs_since_improvement = 0

    # Training loop
    for epoch in tqdm(range(config.num_epochs), desc="Epochs"):
        infos = []
        with timeblock(f"Epoch {epoch}"):
            batch_ix = random.permutation(
                rngmix(rng, f"epoch-{epoch}"), num_train_examples
            ).reshape((-1, config.batch_size))
            for p in batch_ix:
                images_u8 = train_ds["images_u8"][p, :, :, :]
                batch_labels = train_ds["labels"][p]
                train_state, info = stuff["step"](train_state, images_u8, batch_labels)
                infos.append(info)

        # Every batch is full, so this weights each batch's mean loss equally.
        train_loss = float(sum(config.batch_size * x["batch_loss"] for x in infos) / num_train_examples)
        train_accuracy = float(sum(x["num_correct"] for x in infos) / num_train_examples)

        with timeblock("Test set eval"):
            test_loss, test_accuracy = stuff["dataset_loss_and_accuracy"](
                train_state.params,
                test_ds,
                num_test_examples,
            )
        test_loss, test_accuracy = float(test_loss), float(test_accuracy)

        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)

        print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

        # A NaN loss compares False here and so counts as "no improvement".
        if test_loss < best_test_loss:
            best_test_loss = test_loss
            best_params = train_state.params
            best_epoch = epoch
            best_metric = {
                "train_loss": f"{train_loss:.4f}",
                "train_accuracy": f"{train_accuracy:.4f}",
                "test_loss": f"{test_loss:.4f}",
                "test_accuracy": f"{test_accuracy:.4f}",
            }
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
            if epochs_since_improvement >= config.patience:
                print(f"Early stopping triggered at epoch {epoch}")
                break

    if best_params is None:
        raise RuntimeError("Test loss was never finite; no best checkpoint to save.")

    # Save fine-tuned model
    with timeblock("Model serialization"):
        metrics_str = (
            f"trainloss_{best_metric['train_loss']}_testloss_{best_metric['test_loss']}"
            f"_trainacc_{best_metric['train_accuracy']}_testacc_{best_metric['test_accuracy']}"
        )
        finetune_layer_str = ".".join(str(x) for x in finetune_layers)
        weights_file = (
            f"mnist_vit_attn_finetune_seed{config.seed}_"
            f"opt_{config.optimizer}_lr_{config.learning_rate}_num_layers_{finetune_model.num_layers}_finetune_layer_which_{finetune_layer_str}_"
            f"patch_size_{finetune_model.patch_size}_num_heads_{finetune_model.num_heads}_hidden_dim_{finetune_model.hidden_dim}_embedding_dim_{finetune_model.embedding_dim}_best_epoch{best_epoch}_{metrics_str}"
        )
        with open(f'{args.ckpt_path}/{weights_file}', "wb") as f:
            f.write(flax.serialization.to_bytes(best_params))

    # Plot and save metrics
    _plot_metrics(
        train_losses,
        test_losses,
        train_accuracies,
        test_accuracies,
        f"{args.ckpt_path}/{config.seed}metrics_plot.png",
    )

if __name__ == "__main__":
    main()