import argparse
import os
import jax
import jax.numpy as jnp
import numpy as np
import optax
import flax
from flax import linen as nn
from flax.training.train_state import TrainState
from jax import value_and_grad, lax
from jax import random
from tqdm import tqdm
import matplotlib.pyplot as plt
from flax.jax_utils import replicate, unreplicate

# Import from the local IMDB datamodule and utils
from .datamodule import load_imdb_review_data, pad_sequences
from ..utils import rngmix, timeblock

# Attention sub-module name and the projection sub-trees that are
# re-initialized and trained for each selected encoder layer.
_ATTENTION_KEY = 'MultiHeadDotProductAttention_0'
_ATTENTION_COMPONENTS = ('key', 'query', 'value', 'out')


def _parse_args():
    """Define and parse the command-line interface.

    Returns:
        ``(parser, args)``. The parser is returned as well so later validation
        can report problems through ``parser.error`` (usage text + exit code 2).
    """
    parser = argparse.ArgumentParser(description="Finetune BERT attention layers on IMDB Reviews")
    parser.add_argument("--rope-use", action='store_true', help="Use rope if this flag is present")
    parser.add_argument("--seed", type=int, required=True, help="Random seed for fine-tuning")
    parser.add_argument("--model-path", type=str, required=True, help="Path to pre-trained BERT model checkpoint")
    parser.add_argument("--optimizer", choices=["sgd", "adam", "adamw"], required=True, help="Optimizer to use")
    parser.add_argument("--learning-rate", type=float, required=True, help="Learning rate")
    parser.add_argument("--num-layers", type=int, required=True, help="Number of transformer layers in the model")
    parser.add_argument("--num-heads", type=int, required=True, help="Number of transformer heads in the model")
    parser.add_argument("--finetune-layer-which", type=str, required=True, help="Comma-separated indices of attention layers to finetune, or 'all'")
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--epochs", type=int, default=6)
    parser.add_argument("--ckpt-path", type=str, default="./checkpoints_finetune_imdb", help="Path to save finetuned ckpts")
    parser.add_argument("--data-path", type=str, required=True, help="Path to IMDB CSV data file")
    parser.add_argument("--patience", type=int, default=3,
                        help="Stop after this many consecutive epochs without test-loss improvement")
    parser.add_argument("--max-seq-len", type=int, default=256,
                        help="Maximum token sequence length; must match the pretrained checkpoint")
    return parser, parser.parse_args()


def _parse_finetune_layers(spec, num_layers):
    """Convert the ``--finetune-layer-which`` value into a list of layer indices.

    ``"all"`` (case-insensitive) selects every layer. Otherwise ``spec`` is a
    comma-separated list of integers, each of which must lie in
    ``[0, num_layers)``.

    Raises:
        ValueError: if an entry is not an integer or is out of range.
    """
    if spec.lower() == "all":
        return list(range(num_layers))
    indices = [int(part) for part in spec.split(",")]
    out_of_range = [idx for idx in indices if not 0 <= idx < num_layers]
    if out_of_range:
        raise ValueError(f"layer indices {out_of_range} are out of range for a {num_layers}-layer model")
    return indices


def _build_finetune_params(model, model_path, finetune_layer_indices, rng):
    """Load pretrained params and re-initialize the selected attention projections.

    The checkpoint at ``model_path`` is deserialized against the structure of
    a freshly initialized ``model``. Then, for every layer index in
    ``finetune_layer_indices``, the key/query/value/out sub-trees of that
    layer's attention module are replaced by those of a second, independent
    random initialization.

    Returns:
        ``(params, rng)``: a mutable nested dict of parameters and the
        advanced PRNG key. The split order matches the original script, so
        runs stay reproducible for a given seed.
    """
    with open(model_path, "rb") as f:
        serialized_params = f.read()

    rng, init_rng = random.split(rng)
    dummy_input = jnp.zeros((1, model.max_seq_len), dtype=jnp.int32)
    dummy_params = model.init({'params': init_rng, 'dropout': init_rng}, dummy_input)['params']
    pretrained_params = flax.serialization.from_bytes(dummy_params, serialized_params)

    # A second, independent init supplies fresh weights for the attention layers.
    rng, new_init_rng = random.split(rng)
    new_params = model.init({'params': new_init_rng, 'dropout': new_init_rng}, dummy_input)['params']

    # unfreeze() deep-copies the nested dict structure (and turns a FrozenDict,
    # as returned by older Flax versions, into a plain dict). This lets the
    # item assignments below work without aliasing into `pretrained_params`.
    combined_params = flax.core.unfreeze(pretrained_params)
    for idx in finetune_layer_indices:
        layer_key = f'TransformerEncoderLayer_{idx}'
        target = combined_params[layer_key][_ATTENTION_KEY]
        fresh = new_params[layer_key][_ATTENTION_KEY]
        for component in _ATTENTION_COMPONENTS:
            target[component] = flax.core.unfreeze(fresh[component])
    return combined_params, rng


def _label_params(params, finetune_layer_indices):
    """Label every parameter leaf 'trainable' or 'frozen' for optax.multi_transform.

    A leaf is trainable only if its path is exactly
    ``TransformerEncoderLayer_<i>/MultiHeadDotProductAttention_0/<component>/<bias|kernel>``,
    with ``i`` in ``finetune_layer_indices`` and ``component`` one of
    key/query/value/out.
    """
    selected = set(finetune_layer_indices)

    def label_fn(path, _):
        if (len(path) == 4
                and path[1].key == _ATTENTION_KEY
                and path[2].key in _ATTENTION_COMPONENTS
                and path[3].key in ('bias', 'kernel')):
            layer_key = path[0].key
            if layer_key.startswith('TransformerEncoderLayer_'):
                if int(layer_key.split('_')[-1]) in selected:
                    return 'trainable'
        return 'frozen'

    return jax.tree_util.tree_map_with_path(label_fn, params)


def _trainable_paths(param_labels):
    """Return the '/'-joined key path of every leaf labelled 'trainable'."""
    flat, _ = jax.tree_util.tree_flatten_with_path(param_labels)
    return ['/'.join(str(entry.key) for entry in path)
            for path, label in flat if label == 'trainable']


def _make_base_optimizer(name, learning_rate):
    """Return the optax optimizer applied to the trainable parameters."""
    if name == "sgd":
        return optax.sgd(learning_rate, momentum=0.9)
    if name == "adamw":
        return optax.adamw(learning_rate, weight_decay=1e-3)
    return optax.adam(learning_rate)


def _make_sharded_batch(token_ids, labels, pad_id, num_devices):
    """Pad token sequences and reshape everything into a pmap-ready batch.

    The batch is built with NumPy on the host. That lets ``pmap`` send each
    shard directly to its device instead of first staging the whole batch on
    device 0. Every array gets a leading ``num_devices`` axis.
    """
    x_batch = np.asarray(pad_sequences(token_ids, pad_value=pad_id), dtype=np.int32)
    pad_mask = (x_batch != pad_id)[:, None, None, :]
    labels = np.asarray(labels, dtype=np.int32)

    def shard(array):
        return array.reshape((num_devices, -1) + array.shape[1:])

    return {'inputs': shard(x_batch), 'pad_mask': shard(pad_mask), 'labels': shard(labels)}


def _mean_metric(batch_metrics, key):
    """Average ``key`` over per-batch metric dicts (all batches have equal size)."""
    return np.mean([m[key] for m in batch_metrics])


def _train_epoch(p_train_step, train_state, train_ds, perm, batch_size, pad_id, num_devices, rng):
    """Run one training pass over the full batches described by ``perm``.

    The trailing ``len(perm) % batch_size`` examples are dropped, so every
    ``pmap`` call sees the same shapes.

    Returns:
        ``(train_state, rng, batch_metrics)``, where ``batch_metrics`` holds one
        host-side metrics dict per step.
    """
    batch_metrics = []
    pbar = tqdm(range(len(perm) // batch_size), desc="Training")
    for step in pbar:
        batch_idx = perm[step * batch_size:(step + 1) * batch_size]
        token_ids = [train_ds["token_ids"][int(j)] for j in batch_idx]
        batch = _make_sharded_batch(token_ids, train_ds["labels"][batch_idx], pad_id, num_devices)

        rng, step_rng = random.split(rng)
        train_state, metrics = p_train_step(train_state, batch, random.split(step_rng, num_devices))
        # NOTE(review): unreplicate() keeps device 0's copy. These are whole-batch
        # metrics only if train_step averages over axis 'batch' (e.g. lax.pmean);
        # confirm in the imported train module.
        metrics = jax.device_get(unreplicate(metrics))
        batch_metrics.append(metrics)
        pbar.set_postfix({"loss": f"{float(metrics['loss']):.4f}", "acc": f"{float(metrics['accuracy']):.4f}"})
    return train_state, rng, batch_metrics


def _evaluate(p_eval_step, train_state, test_ds, batch_size, pad_id, num_devices):
    """Evaluate on each consecutive full ``batch_size`` slice of ``test_ds``.

    NOTE(review): the final ``len(test_ds) % batch_size`` examples are never
    evaluated, so reported test metrics cover a fixed subset of the split.
    """
    batch_metrics = []
    pbar = tqdm(range(len(test_ds["labels"]) // batch_size), desc="Evaluating")
    for step in pbar:
        start, end = step * batch_size, (step + 1) * batch_size
        batch = _make_sharded_batch(test_ds["token_ids"][start:end], test_ds["labels"][start:end],
                                    pad_id, num_devices)
        metrics = jax.device_get(unreplicate(p_eval_step(train_state, batch)))
        batch_metrics.append(metrics)
        pbar.set_postfix({"loss": f"{float(metrics['loss']):.4f}", "acc": f"{float(metrics['accuracy']):.4f}"})
    return batch_metrics


def _save_metrics_plot(train_losses, test_losses, train_accuracies, test_accuracies, out_path):
    """Save side-by-side loss and accuracy curves to ``out_path``.

    The x-axis is 1-based, matching the "Epoch N" numbering printed during
    training. The figure is always closed, even if saving fails.
    """
    epochs = list(range(1, len(train_losses) + 1))
    fig = plt.figure(figsize=(12, 5))
    try:
        plt.subplot(1, 2, 1)
        plt.plot(epochs, train_losses, label='Train Loss')
        plt.plot(epochs, test_losses, label='Test Loss')
        plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.title('Loss over Epochs'); plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(epochs, train_accuracies, label='Train Accuracy')
        plt.plot(epochs, test_accuracies, label='Test Accuracy')
        plt.xlabel('Epoch'); plt.ylabel('Accuracy'); plt.title('Accuracy over Epochs'); plt.legend()

        plt.tight_layout()
        plt.savefig(out_path)
    finally:
        plt.close(fig)


def main():
    """Fine-tune selected BERT attention layers on IMDB reviews.

    Steps:
      1. Parse and validate CLI arguments.
      2. Load the pretrained checkpoint. Re-initialize the attention
         key/query/value/out projections of the chosen layers.
      3. Train only those projections (every other leaf goes through
         ``optax.set_to_zero``) using data-parallel ``pmap`` steps.
      4. Evaluate on the test split after every epoch. Keep the params with
         the lowest test loss. Stop after ``--patience`` epochs without
         improvement.
      5. Write the best parameters (``.flax``) and a loss/accuracy plot
         (``.png``) under ``--ckpt-path``.

    NOTE(review): model selection and early stopping use the *test* split, so
    the best reported test metrics are optimistically biased. A held-out
    validation split would avoid this.
    """
    parser, args = _parse_args()
    try:
        finetune_layer_indices = _parse_finetune_layers(args.finetune_layer_which, args.num_layers)
    except ValueError as err:
        parser.error(f"--finetune-layer-which: {err}")

    if not args.rope_use:
        from .imdbreview_bert_train import BertModel, train_step, eval_step
    else:
        from .imdbreview_bert_train_rope import BertModel, train_step, eval_step

    num_devices = jax.local_device_count()
    print(f"Using {num_devices} devices")
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if args.batch_size <= 0 or args.batch_size % num_devices != 0:
        parser.error(f"--batch-size ({args.batch_size}) must be a positive multiple "
                     f"of the number of devices ({num_devices})")

    os.makedirs(args.ckpt_path, exist_ok=True)
    rng = random.PRNGKey(args.seed)

    # --- Data Loading and Preprocessing ---
    print("Loading and preprocessing IMDB data...")
    (x_train, y_train), (x_test, y_test), num_classes, PAD = load_imdb_review_data(
        args.data_path, max_seq_len=args.max_seq_len)
    train_ds = {"token_ids": x_train, "labels": np.array(y_train, dtype=np.int32)}
    test_ds = {"token_ids": x_test, "labels": np.array(y_test, dtype=np.int32)}
    num_train_examples = len(train_ds["labels"])
    num_test_examples = len(test_ds["labels"])
    if min(num_train_examples, num_test_examples) < args.batch_size:
        # Partial batches are dropped, so a split smaller than one batch would
        # produce zero steps and NaN epoch metrics.
        parser.error(f"--batch-size ({args.batch_size}) exceeds the train ({num_train_examples}) "
                     f"or test ({num_test_examples}) split size")

    # --- Load and Prepare Model Parameters ---
    model = BertModel(num_layers=args.num_layers, num_heads=args.num_heads,
                      num_classes=num_classes, max_seq_len=args.max_seq_len)
    combined_params, rng = _build_finetune_params(model, args.model_path, finetune_layer_indices, rng)

    param_labels = _label_params(combined_params, finetune_layer_indices)
    trainable_paths = _trainable_paths(param_labels)
    print("Trainable parts of the model:")
    for path in trainable_paths:
        print(f"  {path}")
    if not trainable_paths:
        raise ValueError("No parameters matched the expected attention layout; "
                         "nothing would be fine-tuned.")

    tx = optax.multi_transform(
        {'trainable': _make_base_optimizer(args.optimizer, args.learning_rate),
         'frozen': optax.set_to_zero()},
        param_labels,
    )

    # --- Create Training State (replicated across devices) ---
    train_state = replicate(TrainState.create(apply_fn=model.apply, params=combined_params, tx=tx))
    p_train_step = jax.pmap(train_step, axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    # --- Training Loop ---
    train_losses, train_accuracies = [], []
    test_losses, test_accuracies = [], []
    best_test_loss = float('inf')
    best_params = None
    best_epoch = -1
    epochs_since_improvement = 0

    for epoch in range(args.epochs):
        print(f"\n--- Epoch {epoch+1}/{args.epochs} ---")
        rng, epoch_rng = random.split(rng)
        # Pull the permutation to host memory once. Indexing a device array
        # element by element forces a device->host transfer per example.
        perm = np.asarray(random.permutation(epoch_rng, num_train_examples))

        train_state, rng, train_metrics = _train_epoch(
            p_train_step, train_state, train_ds, perm, args.batch_size, PAD, num_devices, rng)
        epoch_train_loss = _mean_metric(train_metrics, 'loss')
        epoch_train_acc = _mean_metric(train_metrics, 'accuracy')
        train_losses.append(epoch_train_loss)
        train_accuracies.append(epoch_train_acc)

        test_metrics = _evaluate(p_eval_step, train_state, test_ds, args.batch_size, PAD, num_devices)
        epoch_test_loss = _mean_metric(test_metrics, 'loss')
        epoch_test_acc = _mean_metric(test_metrics, 'accuracy')
        test_losses.append(epoch_test_loss)
        test_accuracies.append(epoch_test_acc)

        print(f"Epoch {epoch+1} Summary: "
              f"Train Loss={epoch_train_loss:.4f}, Train Acc={epoch_train_acc:.4f} | "
              f"Test Loss={epoch_test_loss:.4f}, Test Acc={epoch_test_acc:.4f}")

        if epoch_test_loss < best_test_loss:
            best_test_loss = epoch_test_loss
            # Copy only the params (not the optimizer state) to host memory.
            best_params = jax.device_get(unreplicate(train_state.params))
            best_epoch = epoch + 1
            epochs_since_improvement = 0
            print(f"New best test loss: {best_test_loss:.4f}. Saving model for epoch {best_epoch}.")
        else:
            epochs_since_improvement += 1
            if epochs_since_improvement >= args.patience:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

    # --- Save Best Model and Metrics ---
    if best_params is not None:
        with timeblock("Model serialization"):
            finetune_layer_str = "all" if args.finetune_layer_which.lower() == "all" else ".".join(str(x) for x in finetune_layer_indices)
            best_metrics = f"testloss_{test_losses[best_epoch-1]:.4f}_testacc_{test_accuracies[best_epoch-1]:.4f}"
            weights_file = (
                f"imdb_bert_attn_finetune_seed{args.seed}_"
                f"opt_{args.optimizer}_lr_{args.learning_rate}_L{args.num_layers}_H{args.num_heads}_"
                f"finetune_layers_{finetune_layer_str}_best_epoch{best_epoch}_{best_metrics}.flax"
            )
            weights_path = os.path.join(args.ckpt_path, weights_file)
            with open(weights_path, "wb") as f:
                f.write(flax.serialization.to_bytes(best_params))
            print(f"Saved best model to {weights_path}")

    plot_filename = f"imdb_metrics_plot_seed{args.seed}_finetune.png"
    plot_path = os.path.join(args.ckpt_path, plot_filename)
    _save_metrics_plot(train_losses, test_losses, train_accuracies, test_accuracies, plot_path)
    print(f"Saved metrics plot to {plot_path}")


# Script entry point. The module uses package-relative imports
# (`from .datamodule import ...`), so run it as part of its package,
# e.g. `python -m <package>.<module> ...`, not as a bare file path.
if __name__ == "__main__":
    main()