import argparse
import augmax
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow as tf
import matplotlib.pyplot as plt
from flax import linen as nn
from flax.training.train_state import TrainState
from jax import random, value_and_grad
from jax.tree_util import tree_map
import jax.tree_util as tree_util
from tqdm import tqdm
from flax.traverse_util import flatten_dict, unflatten_dict

from src.utils import flatten_params, rngmix, timeblock
from src.datasets import load_cifar10
#from .cifar10_vit_train import ViTModel, make_stuff

# Main function for fine-tuning
def main():
    """Fine-tune selected transformer layers of a pre-trained CIFAR-10 ViT.

    Loads a serialized parameter tree from ``--model-path``, re-initialises the
    attention projections (query/key/value/out) of the encoder layers chosen
    with ``--finetune-layer-which``, and trains every parameter of those
    layers while all other parameters are frozen. Training stops early once
    the test loss has not improved for ``patience`` epochs. The parameters
    with the lowest test loss are written to ``--ckpt-path`` together with a
    loss/accuracy plot.

    Raises:
        ValueError: on invalid layer indices or a train set whose size is not
            a multiple of the batch size.
        RuntimeError: if no epoch improved the test loss (e.g. NaN losses),
            so there is no best checkpoint to write.
    """
    import os

    args = _parse_args()

    # The RoPE variant is a drop-in replacement exposing the same names.
    if args.rope_use:
        from .cifar10_vit_train_rope import ViTModel, make_stuff
    else:
        from .cifar10_vit_train import ViTModel, make_stuff

    num_epochs = 100
    batch_size = 100
    patience = 10  # epochs without test-loss improvement before stopping
    finetune_layers = _parse_layer_indices(args.finetune_layer_which, args.num_layers)

    # Parameter tree layout (example with 3 layers / 4 heads): top-level
    # Conv_0, Dense_0, cls_token, pos_embedding, and per layer
    # TransformerEncoderLayer_{i}/{Dense_0, Dense_1, LayerNorm_0, LayerNorm_1,
    # MultiHeadDotProductAttention_0/{query, key, value, out}/{kernel, bias}}.
    model = ViTModel(num_layers=args.num_layers, num_heads=args.num_heads)
    with open(args.model_path, "rb") as f:
        serialized_params = f.read()
    dummy_input = jnp.ones((1, 32, 32, 3))
    template_params = model.init(jax.random.PRNGKey(0), dummy_input)["params"]
    pretrained_params = flax.serialization.from_bytes(template_params, serialized_params)

    # Flax linen modules are frozen dataclasses: assigning an attribute after
    # construction raises. Only pass the layer list to the model if it
    # declares such a field; otherwise it is consumed solely by this script.
    if "finetune_layer_which" in getattr(model, "__dataclass_fields__", {}):
        finetune_model = model.clone(finetune_layer_which=finetune_layers)
    else:
        finetune_model = model

    rng = jax.random.PRNGKey(args.seed)
    fresh_params = finetune_model.init(rng, dummy_input)["params"]

    combined_params = _combine_params(pretrained_params, fresh_params, finetune_layers)
    param_labels = _label_params(combined_params, finetune_layers)

    # Sanity check: list every parameter that the optimizer will update.
    print("Trainable parts of the model:")
    for path, label in jax.tree_util.tree_flatten_with_path(param_labels)[0]:
        if label == "trainable":
            print("  " + "/".join(str(p.key) for p in path))

    # The dataset is loaded before the optimizer so the LR schedule length
    # follows the actual number of training examples.
    with timeblock("Load datasets"):
        train_ds, test_ds = load_cifar10(data_dir=args.data_dir)
        num_train_examples = train_ds["images_u8"].shape[0]
    if num_train_examples % batch_size != 0:
        raise ValueError(
            f"Number of training examples ({num_train_examples}) must be a "
            f"multiple of the batch size ({batch_size})"
        )
    steps_per_epoch = num_train_examples // batch_size

    base_tx = _make_base_optimizer(
        args.optimizer, args.learning_rate, num_epochs * steps_per_epoch
    )
    # Frozen parameters receive zero updates; only labelled layers train.
    tx = optax.multi_transform(
        {"trainable": base_tx, "frozen": optax.set_to_zero()},
        param_labels,
    )
    train_state = TrainState.create(
        apply_fn=finetune_model.apply,
        params=combined_params,
        tx=tx,
    )
    stuff = make_stuff(finetune_model)

    history = {"train_loss": [], "train_accuracy": [], "test_loss": [], "test_accuracy": []}
    best_test_loss = float("inf")
    best_params = None
    best_metric = None
    best_epoch = -1
    epochs_since_improvement = 0

    for epoch in tqdm(range(num_epochs), desc="Epochs"):
        infos = []
        with timeblock(f"Epoch {epoch}"):
            batch_ix = random.permutation(
                rngmix(rng, f"epoch-{epoch}"), num_train_examples
            ).reshape((-1, batch_size))
            batch_rngs = random.split(rngmix(rng, f"batch_rngs-{epoch}"), batch_ix.shape[0])
            for batch_rng, ix in zip(batch_rngs, batch_ix):
                images_u8 = train_ds["images_u8"][ix, :, :, :]
                batch_labels = train_ds["labels"][ix]
                train_state, info = stuff["step"](batch_rng, train_state, images_u8, batch_labels)
                infos.append(info)

        train_loss = float(sum(batch_size * x["batch_loss"] for x in infos) / num_train_examples)
        train_accuracy = float(sum(x["num_correct"] for x in infos) / num_train_examples)

        with timeblock("Test set eval"):
            test_loss, test_accuracy = stuff["dataset_loss_and_accuracy"](
                train_state.params,
                test_ds,
                # NOTE(review): presumably the evaluation batch size (the
                # CIFAR-10 test set has 10,000 images) — confirm in make_stuff.
                2000,
            )
        test_loss = float(test_loss)
        test_accuracy = float(test_accuracy)

        history["train_loss"].append(train_loss)
        history["train_accuracy"].append(train_accuracy)
        history["test_loss"].append(test_loss)
        history["test_accuracy"].append(test_accuracy)

        print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

        # Model selection / early stopping on test loss (NaN never improves).
        if test_loss < best_test_loss:
            best_test_loss = test_loss
            best_params = train_state.params
            best_epoch = epoch
            best_metric = {
                "train_loss": f"{train_loss:.4f}",
                "train_accuracy": f"{train_accuracy:.4f}",
                "test_loss": f"{test_loss:.4f}",
                "test_accuracy": f"{test_accuracy:.4f}",
            }
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
            if epochs_since_improvement >= patience:
                print(f"Early stopping triggered at epoch {epoch}")
                break

    os.makedirs(args.ckpt_path, exist_ok=True)
    # Plot first so the curves survive even if there is nothing to save.
    _plot_metrics(history, os.path.join(args.ckpt_path, f"{args.seed}metrics_plot.png"))

    if best_params is None:
        raise RuntimeError(
            "No epoch improved the test loss (all losses non-finite?); no checkpoint written."
        )

    with timeblock("Model serialization"):
        metrics_str = f"trainloss_{best_metric['train_loss']}_testloss_{best_metric['test_loss']}_trainacc_{best_metric['train_accuracy']}_testacc_{best_metric['test_accuracy']}"
        finetune_layer_str = ".".join(str(x) for x in finetune_layers)
        weights_file = (
            f"cifar10_vit_attn_finetune_seed{args.seed}_"
            f"opt_{args.optimizer}_lr_{args.learning_rate}_num_layers_{finetune_model.num_layers}_finetune_layer_which_{finetune_layer_str}_"
            f"patch_size_{finetune_model.patch_size}_num_heads_{finetune_model.num_heads}_hidden_dim_{finetune_model.hidden_dim}_embedding_dim_{finetune_model.embedding_dim}_best_epoch{best_epoch}_{metrics_str}"
        )
        with open(os.path.join(args.ckpt_path, weights_file), "wb") as f:
            f.write(flax.serialization.to_bytes(best_params))


def _parse_args():
    """Parse the command-line arguments of the fine-tuning script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--rope-use", action='store_true', help="Use rope if this flag is present")
    parser.add_argument("--seed", type=int, required=True, help="Random seed for fine-tuning")
    parser.add_argument("--model-path", type=str, required=True, help="Path to pre-trained ViT model checkpoint")
    parser.add_argument("--optimizer", choices=["sgd", "adam", "adamw"], required=True, help="Optimizer to use")
    parser.add_argument("--learning-rate", type=float, required=True, help="Learning rate")
    parser.add_argument("--num-layers", type=int, required=True, help="Number of transformer layers")
    parser.add_argument("--num-heads", type=int, required=True, help="Number of transformer heads")
    parser.add_argument("--finetune-layer-which", type=str, required=True, help="Which attention layer(s) to finetune")
    parser.add_argument("--ckpt-path", type=str, default="/", help="Path to ckpt directory")
    parser.add_argument("--data-dir", type=str, default="/root/log/cifar10/data", help="CIFAR-10 data directory")
    return parser.parse_args()


def _parse_layer_indices(spec, num_layers):
    """Turn ``"all"`` or a comma-separated index list into layer indices.

    Raises:
        ValueError: if a token is not an integer or lies outside [0, num_layers).
    """
    if spec == "all":
        return list(range(num_layers))
    indices = [int(token) for token in spec.split(",")]
    out_of_range = [i for i in indices if not 0 <= i < num_layers]
    if out_of_range:
        raise ValueError(
            f"--finetune-layer-which indices {out_of_range} are outside [0, {num_layers})"
        )
    return indices


def _to_mutable(tree):
    """Recursively copy every mapping (dict or FrozenDict) into a plain dict."""
    from collections.abc import Mapping

    if isinstance(tree, Mapping):
        return {key: _to_mutable(value) for key, value in tree.items()}
    return tree


def _combine_params(pretrained_params, fresh_params, finetune_layers):
    """Return pre-trained params with chosen layers' attention projections re-initialised.

    The result is a fresh nested dict, so ``pretrained_params`` is never
    mutated (a shallow ``.copy()`` would share the nested dicts).
    """
    combined = _to_mutable(pretrained_params)
    attention_key = "MultiHeadDotProductAttention_0"
    for idx in finetune_layers:
        layer_key = f"TransformerEncoderLayer_{idx}"
        for component in ("key", "query", "value", "out"):
            combined[layer_key][attention_key][component] = _to_mutable(
                fresh_params[layer_key][attention_key][component]
            )
    return combined


def _label_params(params, finetune_layers):
    """Label each leaf 'trainable' or 'frozen' for ``optax.multi_transform``.

    Every parameter under a selected ``TransformerEncoderLayer_{i}`` is
    trainable (attention, MLP and LayerNorms), not only the re-initialised
    attention projections; everything else is frozen.
    """
    trainable_roots = {f"TransformerEncoderLayer_{idx}" for idx in finetune_layers}

    def label_fn(path, _leaf):
        return "trainable" if path[0].key in trainable_roots else "frozen"

    return jax.tree_util.tree_map_with_path(label_fn, params)


def _make_base_optimizer(name, learning_rate, total_steps):
    """Build the optimizer applied to trainable params.

    SGD uses momentum 0.9 and a 10-step warmup followed by cosine decay over
    ``total_steps``; adam/adamw use a constant learning rate.
    """
    if name == "sgd":
        lr_schedule = optax.warmup_cosine_decay_schedule(
            init_value=1e-6,
            peak_value=learning_rate,
            warmup_steps=10,
            decay_steps=total_steps,
        )
        return optax.sgd(lr_schedule, momentum=0.9)
    if name == "adam":
        return optax.adam(learning_rate)
    # argparse restricts the remaining choice to "adamw".
    return optax.adamw(learning_rate, weight_decay=1e-3)


def _plot_metrics(history, out_path):
    """Save side-by-side per-epoch loss and accuracy curves to ``out_path``."""
    fig, axs = plt.subplots(1, 2, figsize=(12, 5))
    axs[0].plot(history["train_loss"], label='Train Loss')
    axs[0].plot(history["test_loss"], label='Test Loss')
    axs[0].set_xlabel('Epoch')
    axs[0].set_ylabel('Loss')
    axs[0].set_title('Training and Test Loss over Epochs')
    axs[0].legend()
    axs[1].plot(history["train_accuracy"], label='Train Accuracy')
    axs[1].plot(history["test_accuracy"], label='Test Accuracy')
    axs[1].set_xlabel('Epoch')
    axs[1].set_ylabel('Accuracy')
    axs[1].set_title('Training and Test Accuracy over Epochs')
    axs[1].legend()
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close(fig)


if __name__ == "__main__":
    main()