import argparse
import os
import jax
import jax.numpy as jnp
import numpy as np
import optax
import flax
from flax import linen as nn
from flax.training.train_state import TrainState
from jax import random
from tqdm import tqdm
import matplotlib.pyplot as plt
from flax.jax_utils import replicate, unreplicate
from datasets import Dataset
from typing import List

# --- Utility Functions (from dbpedia_bert_train.py) ---

def pad_sequences(sequences: List[np.ndarray], pad_value: int = 0) -> np.ndarray:
    """Right-pad every sequence with ``pad_value`` to the longest length and stack them.

    Args:
        sequences: 1-D sequences (NumPy arrays or lists) of possibly different lengths.
        pad_value: Value written into the padded tail positions.

    Returns:
        A 2-D array with one row per input sequence, each row as long as the
        longest input sequence.

    Raises:
        ValueError: If ``sequences`` is empty (from ``max`` on an empty list).
    """
    lengths = [len(s) for s in sequences]
    longest = max(lengths)
    rows = []
    for s, n in zip(sequences, lengths):
        # Pad only at the end: (before, after) = (0, missing elements).
        rows.append(np.pad(s, (0, longest - n), constant_values=pad_value))
    return np.stack(rows)

# --- Main Fine-tuning Logic ---

def main():
    """Fine-tune selected BERT attention layers on DBpedia and save the best checkpoint.

    Command-line flags select the model variant (``--rope-use``), its shape,
    the optimizer, and which ``TransformerEncoderLayer_<i>`` attention blocks
    to re-initialise and train. All other parameters are loaded from
    ``--model-path`` and kept frozen via ``optax.multi_transform`` with
    ``optax.set_to_zero``.

    Outputs written to ``--ckpt-path``:
      * the parameters from the epoch with the lowest test loss (only if some
        epoch improved on ``inf``), serialized with ``flax.serialization``;
      * a PNG with per-epoch train/test loss and accuracy curves.

    Raises:
        SystemExit: via ``parser.error`` when the batch size is not divisible
            by the local device count, a layer index is out of range, or a
            dataset split cannot fill a single full batch.
    """
    parser = argparse.ArgumentParser(description="Finetune BERT attention layers on DBpedia")
    parser.add_argument("--rope-use", action='store_true', help="Use rope if this flag is present")
    parser.add_argument("--seed", type=int, required=True, help="Random seed for fine-tuning")
    parser.add_argument("--model-path", type=str, required=True, help="Path to pre-trained BERT model checkpoint")
    parser.add_argument("--optimizer", choices=["sgd", "adam", "adamw"], required=True, help="Optimizer to use")
    parser.add_argument("--learning-rate", type=float, required=True, help="Learning rate")
    parser.add_argument("--num-layers", type=int, required=True, help="Number of transformer layers in the model")
    parser.add_argument("--num-heads", type=int, required=True, help="Number of transformer heads in the model")
    parser.add_argument("--embedding-dim", type=int, default=48, help="Embedding dimension of the model")
    parser.add_argument("--hidden-dim", type=int, default=192, help="Hidden dimension of the model's FFN")
    parser.add_argument("--finetune-layer-which", type=str, required=True, help="Comma-separated indices of attention layers to finetune, or 'all'")
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--epochs", type=int, default=4)
    parser.add_argument("--ckpt-path", type=str, default="./checkpoints_finetune_dbpedia", help="Path to save finetuned ckpts")
    parser.add_argument("--train-dataset-path", type=str, required=True, help="Path to train parquet")
    parser.add_argument("--test-dataset-path", type=str, required=True, help="Path to test parquet")
    args = parser.parse_args()

    # --- Import correct model and training steps based on RoPE flag ---
    # Relative imports: the script must be run as a module inside its package.
    if not args.rope_use:
        from .dbpedia_bert_train import BertModel, train_step, eval_step
    else:
        from .dbpedia_bert_train_rope import BertModel, train_step, eval_step

    # --- JAX Setup ---
    num_devices = jax.local_device_count()
    print(f"Using {num_devices} devices")
    # Explicit check rather than `assert` so it still fires under `python -O`.
    if args.batch_size % num_devices != 0:
        parser.error("Batch size must be divisible by the number of devices")
    per_device_batch_size = args.batch_size // num_devices

    os.makedirs(args.ckpt_path, exist_ok=True)
    rng = random.PRNGKey(args.seed)

    finetune_all = args.finetune_layer_which.lower() == "all"
    finetune_layer_indices = (
        list(range(args.num_layers))
        if finetune_all
        else [int(idx) for idx in args.finetune_layer_which.split(",")]
    )
    out_of_range = [idx for idx in finetune_layer_indices if not 0 <= idx < args.num_layers]
    if out_of_range:
        parser.error(
            f"--finetune-layer-which indices {out_of_range} are out of range "
            f"for a model with {args.num_layers} layers"
        )
    print(f"Fine-tuning attention weights for layers: {finetune_layer_indices}")
    # Computed up front: both the checkpoint name and the plot name use it,
    # and the plot is written even when no checkpoint is saved.
    finetune_layer_str = "all" if finetune_all else ".".join(str(x) for x in finetune_layer_indices)

    # --- Data Loading and Preprocessing ---
    print("Loading DBpedia data...")
    train_dataset = Dataset.from_parquet(args.train_dataset_path)
    test_dataset = Dataset.from_parquet(args.test_dataset_path)
    x_train, y_train = train_dataset['input_ids'], train_dataset['label']
    x_test, y_test = test_dataset['input_ids'], test_dataset['label']

    PAD = 0
    train_ds = {"token_ids": x_train, "labels": np.array(y_train, dtype=np.int32)}
    test_ds = {"token_ids": x_test, "labels": np.array(y_test, dtype=np.int32)}

    num_train_examples = len(train_ds["labels"])
    num_test_examples = len(test_ds["labels"])
    # Only full batches are used (pmap needs a fixed per-device shape). A split
    # smaller than one batch would otherwise silently produce NaN epoch metrics.
    for split_name, split_size in (("train", num_train_examples), ("test", num_test_examples)):
        if split_size < args.batch_size:
            parser.error(
                f"{split_name} split has {split_size} examples, "
                f"fewer than --batch-size={args.batch_size}"
            )

    # --- Load and Prepare Model for Fine-tuning ---
    model = BertModel(
        embedding_dim=args.embedding_dim,
        hidden_dim=args.hidden_dim,
        num_layers=args.num_layers,
        num_heads=args.num_heads
    )

    # Load pretrained parameters; `from_bytes` needs a target tree of the right structure.
    with open(args.model_path, "rb") as f:
        serialized_params = f.read()

    rng, init_rng = random.split(rng)
    dummy_input = jnp.zeros((1, model.max_seq_len), dtype=jnp.int32)
    dummy_params = model.init({'params': init_rng, 'dropout': init_rng}, dummy_input)['params']
    pretrained_params = flax.serialization.from_bytes(dummy_params, serialized_params)

    # Initialize a new model to get fresh weights for the attention layers
    rng, new_init_rng = random.split(rng)
    new_params = model.init({'params': new_init_rng, 'dropout': new_init_rng}, dummy_input)['params']

    # `unfreeze` returns a fresh, mutable nested dict for both FrozenDict and
    # plain-dict params. A shallow `.copy()` would share (and mutate) the nested
    # dicts of `pretrained_params`, and fails outright on FrozenDict params.
    combined_params = flax.core.unfreeze(pretrained_params)
    attention_key = 'MultiHeadDotProductAttention_0'
    attention_components = ('key', 'query', 'value', 'out')
    for idx in finetune_layer_indices:
        layer_key = f'TransformerEncoderLayer_{idx}'
        for component in attention_components:
            combined_params[layer_key][attention_key][component] = new_params[layer_key][attention_key][component]

    # --- Setup Optimizer to Freeze Non-Attention Layers ---
    finetune_layer_set = set(finetune_layer_indices)

    def get_labels(params):
        """Creates a PyTree of labels ('trainable' or 'frozen') for the optimizer."""
        def label_fn(path, _):
            # Path is a tuple of key entries, e.g.
            # ('TransformerEncoderLayer_0', 'MultiHeadDotProductAttention_0', 'query', 'kernel')
            if (
                len(path) == 4
                and path[1].key == attention_key
                and path[2].key in attention_components
            ):
                layer_key = path[0].key
                if layer_key.startswith('TransformerEncoderLayer_'):
                    idx = int(layer_key.split('_')[-1])
                    if idx in finetune_layer_set:
                        return 'trainable'
            return 'frozen'
        return jax.tree_util.tree_map_with_path(label_fn, params)

    # Named `param_labels` so it is not confused with the per-batch class labels below.
    param_labels = get_labels(combined_params)

    # Sanity check: list every parameter path the optimizer will update.
    flat_labels, _ = jax.tree_util.tree_flatten_with_path(param_labels)
    trainable_paths = [
        '/'.join(str(p.key) for p in path)
        for path, label in flat_labels
        if label == 'trainable'
    ]
    print("Trainable parts of the model:")
    for path in trainable_paths:
        print(f"  {path}")
    if not trainable_paths:
        print("WARNING: no parameters are marked trainable; all weights will stay frozen.")

    if args.optimizer == "sgd":
        base_tx = optax.sgd(args.learning_rate, momentum=0.9)
    elif args.optimizer == "adamw":
        base_tx = optax.adamw(args.learning_rate, weight_decay=1e-3)
    else:  # adam
        base_tx = optax.adam(args.learning_rate)

    # Use multi_transform to apply the optimizer only to 'trainable' parts
    tx = optax.multi_transform(
        {'trainable': base_tx, 'frozen': optax.set_to_zero()},
        param_labels
    )

    # --- Create and Replicate Training State ---
    train_state = TrainState.create(apply_fn=model.apply, params=combined_params, tx=tx)
    train_state = replicate(train_state)

    # Pmap the training and evaluation steps
    p_train_step = jax.pmap(train_step, axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    # --- Training Loop ---
    train_losses, train_accuracies = [], []
    test_losses, test_accuracies = [], []

    patience = 3
    best_test_loss = float('inf')
    best_params = None
    best_epoch = -1
    epochs_since_improvement = 0

    def shard(array):
        """Reshapes an array for multi-device processing."""
        return jnp.reshape(array, (num_devices, per_device_batch_size) + array.shape[1:])

    def make_batch(token_ids, batch_labels):
        """Pad, mask and shard one full batch for the pmapped step functions."""
        x_batch = jnp.array(pad_sequences(token_ids, pad_value=PAD))
        pad_mask = (x_batch != PAD)[:, None, None, :]
        valid_mask = jnp.ones((args.batch_size,))  # All examples are valid (full batches only)
        return {
            'inputs': shard(x_batch),
            'pad_mask': shard(pad_mask),
            'labels': shard(jnp.array(batch_labels)),
            'valid_mask': shard(valid_mask),
        }

    # Trailing partial batches are dropped so every pmap call sees the same shape.
    train_stop = (num_train_examples // args.batch_size) * args.batch_size
    test_stop = (num_test_examples // args.batch_size) * args.batch_size

    for epoch in range(args.epochs):
        print(f"\n--- Epoch {epoch+1}/{args.epochs} ---")
        rng, epoch_rng = random.split(rng)

        # Host-side copy: indexing a device array element-by-element below
        # would force one device->host transfer per example.
        perm = np.asarray(jax.random.permutation(epoch_rng, num_train_examples))

        # Train
        batch_metrics = []
        pbar = tqdm(range(0, train_stop, args.batch_size), desc="Training")
        for i in pbar:
            batch_perm = perm[i:i + args.batch_size]
            token_ids = [train_ds["token_ids"][int(j)] for j in batch_perm]
            batch = make_batch(token_ids, train_ds["labels"][batch_perm])

            rng, step_rng = random.split(rng)
            rngs = random.split(step_rng, num_devices)
            train_state, metrics = p_train_step(train_state, batch, rngs)
            metrics = unreplicate(metrics)
            batch_metrics.append(metrics)
            pbar.set_postfix({"loss": f"{metrics['loss']:.4f}", "acc": f"{metrics['accuracy']:.4f}"})

        epoch_train_loss = np.mean([m['loss'] for m in batch_metrics])
        epoch_train_acc = np.mean([m['accuracy'] for m in batch_metrics])
        train_losses.append(epoch_train_loss)
        train_accuracies.append(epoch_train_acc)

        # Evaluate
        # NOTE: the trailing partial test batch is skipped as well, so up to
        # batch_size - 1 test examples are excluded from the test metrics.
        batch_metrics = []
        pbar = tqdm(range(0, test_stop, args.batch_size), desc="Evaluating")
        for i in pbar:
            end = i + args.batch_size
            batch = make_batch(test_ds["token_ids"][i:end], test_ds["labels"][i:end])
            metrics = p_eval_step(train_state, batch)
            metrics = unreplicate(metrics)
            batch_metrics.append(metrics)
            pbar.set_postfix({"loss": f"{metrics['loss']:.4f}", "acc": f"{metrics['accuracy']:.4f}"})

        epoch_test_loss = np.mean([m['loss'] for m in batch_metrics])
        epoch_test_acc = np.mean([m['accuracy'] for m in batch_metrics])
        test_losses.append(epoch_test_loss)
        test_accuracies.append(epoch_test_acc)

        print(f"Epoch {epoch+1} Summary: "
              f"Train Loss={epoch_train_loss:.4f}, Train Acc={epoch_train_acc:.4f} | "
              f"Test Loss={epoch_test_loss:.4f}, Test Acc={epoch_test_acc:.4f}")

        # --- Early Stopping and Best Model Saving ---
        if epoch_test_loss < best_test_loss:
            best_test_loss = epoch_test_loss
            best_params = unreplicate(train_state).params
            best_epoch = epoch + 1
            epochs_since_improvement = 0
            print(f"New best test loss: {best_test_loss:.4f}. Saving model for epoch {best_epoch}.")
        else:
            epochs_since_improvement += 1
            if epochs_since_improvement >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}. No improvement for {patience} epochs.")
                break

    # --- Save Best Model and Metrics Plot ---
    if best_params is not None:
        best_metrics = f"testloss_{test_losses[best_epoch-1]:.4f}_testacc_{test_accuracies[best_epoch-1]:.4f}"
        weights_file = (
            f"dbpedia_bert_attn_finetune_seed{args.seed}_"
            f"opt_{args.optimizer}_lr_{args.learning_rate}_L{args.num_layers}_H{args.num_heads}_"
            f"finetune_layers_{finetune_layer_str}_best_epoch{best_epoch}_{best_metrics}.flax"
        )
        weights_path = os.path.join(args.ckpt_path, weights_file)
        with open(weights_path, "wb") as f:
            f.write(flax.serialization.to_bytes(best_params))
        print(f"Saved best model to {weights_path}")
    else:
        print("No epoch improved the test loss; no checkpoint was saved.")

    # Plot and save metrics
    fig, axs = plt.subplots(1, 2, figsize=(12, 5))
    axs[0].plot(train_losses, label='Train Loss')
    axs[0].plot(test_losses, label='Test Loss')
    axs[0].set_xlabel('Epoch'); axs[0].set_ylabel('Loss'); axs[0].set_title('Loss over Epochs'); axs[0].legend()

    axs[1].plot(train_accuracies, label='Train Accuracy')
    axs[1].plot(test_accuracies, label='Test Accuracy')
    axs[1].set_xlabel('Epoch'); axs[1].set_ylabel('Accuracy'); axs[1].set_title('Accuracy over Epochs'); axs[1].legend()

    plot_filename = f"dbpedia_metrics_plot_seed{args.seed}_finetune_{finetune_layer_str}.png"
    plot_path = os.path.join(args.ckpt_path, plot_filename)
    plt.tight_layout()
    plt.savefig(plot_path)
    plt.close(fig)
    print(f"Saved metrics plot to {plot_path}")


# Script entry point. `main` uses relative imports (`from .dbpedia_bert_train ...`),
# so this file is presumably meant to be run as a package module (`python -m ...`).
if __name__ == "__main__":
    main()