import argparse
import os
import sys
import json
import logging
import datetime
import math

import torch
import numpy as np
from transformers import (
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
from datasets import ClassLabel
from peft import get_peft_model, LoraConfig, TaskType
from omegaconf import OmegaConf
from transformers import TrainerCallback

from training_utils import (
    compute_metrics,
    EnhancedTrainMetricsCallback,
    set_reproducible_seed,
    get_system_info,
    save_confusion_matrix,
    generate_model_card,
    setup_tokenizer,
)
# from nlp_training.pooling import GenericSequenceClassifier
from seq_classifier import GenericSequenceClassifier, GenericSequenceClassifierConfig
from exp_data import get_exp_data_hf


# --- Simple, 2-parameter initializer: (init_type, init_scale) ---
def _init_lora_A(w: torch.nn.Linear, init_type: str, init_scale: float):
    """Overwrite the weight of a LoRA ``A`` layer in place.

    Args:
        w: Layer whose ``weight`` tensor is re-initialized.
        init_type: Case-insensitive scheme name: ``uniform``, ``gaussian``,
            ``orthogonal``, ``xavier_uniform``, ``xavier_normal``,
            ``kaiming_uniform``, ``kaiming_normal`` or ``zeros``.
        init_scale: Interpretation depends on the scheme. For ``uniform`` and
            ``gaussian`` the bound / std is ``init_scale / sqrt(fan_in)``;
            for orthogonal and Xavier it is passed as ``gain``; for the
            Kaiming variants it is passed as ``a``. ``zeros`` ignores it.

    Raises:
        ValueError: If ``init_type`` is not one of the schemes listed above.
    """
    scheme = init_type.lower()
    weight = w.weight
    fan_in = weight.size(1)
    init = torch.nn.init

    # Schemes whose spread is derived from the input dimension.
    if scheme == "uniform":
        limit = init_scale / math.sqrt(fan_in)
        init.uniform_(weight, -limit, +limit)
        return
    if scheme == "gaussian":
        sigma = init_scale / math.sqrt(fan_in)
        init.normal_(weight, mean=0.0, std=sigma)
        return

    # Schemes that take init_scale as their ``gain`` argument.
    gain_initializers = {
        "orthogonal": init.orthogonal_,
        "xavier_uniform": init.xavier_uniform_,
        "xavier_normal": init.xavier_normal_,
    }
    if scheme in gain_initializers:
        gain_initializers[scheme](weight, gain=init_scale)
        return

    # Kaiming variants: init_scale is forwarded as 'a'.
    kaiming_initializers = {
        "kaiming_uniform": init.kaiming_uniform_,
        "kaiming_normal": init.kaiming_normal_,
    }
    if scheme in kaiming_initializers:
        kaiming_initializers[scheme](
            weight, a=init_scale, mode="fan_in", nonlinearity="leaky_relu"
        )
        return

    if scheme == "zeros":
        init.zeros_(weight)
        return

    raise ValueError(f"Unknown init_type: {init_type}")


def reinit_lora_weights(model, init_type: str, init_scale: float):
    """
    Reinitialize LoRA weights after PEFT model creation.

    Walks every sub-module of ``model`` and, in place:

    * re-draws each ``lora_A`` layer through ``_init_lora_A``;
    * zeroes each ``lora_B`` layer (weight, and bias when present);
    * re-draws each ``lora_embedding_A`` parameter;
    * zeroes each ``lora_embedding_B`` parameter.

    The embedding branch only implements ``uniform``, ``gaussian``,
    ``xavier_uniform`` and ``xavier_normal``. Any other ``init_type`` falls
    back to ``xavier_uniform`` for embedding parameters, and a warning is
    logged so the run log reflects what was actually applied.

    Args:
        model: PEFT model with LoRA adapters
        init_type: Initialization method (e.g., 'xavier_uniform', 'kaiming_normal')
        init_scale: Scale parameter for initialization

    Returns:
        model: The same model object, with reinitialized LoRA weights

    Raises:
        ValueError: Propagated from ``_init_lora_A`` when ``init_type`` is
            unknown and the model has at least one ``lora_A`` layer.
    """
    method = init_type.lower()
    # Schemes the embedding branch implements natively.
    embedding_schemes = ("uniform", "gaussian", "xavier_uniform", "xavier_normal")

    lora_a_count = 0
    lora_b_count = 0
    embedding_fallback_count = 0

    for _, module in model.named_modules():
        # Handle Linear LoRA layers
        lora_a_layers = getattr(module, 'lora_A', None)
        if lora_a_layers:
            for lora_A_layer in lora_a_layers.values():
                _init_lora_A(lora_A_layer, init_type, init_scale)
                lora_a_count += 1

        # Handle LoRA B layers - explicitly set to zeros
        lora_b_layers = getattr(module, 'lora_B', None)
        if lora_b_layers:
            for lora_B_layer in lora_b_layers.values():
                torch.nn.init.zeros_(lora_B_layer.weight)
                if getattr(lora_B_layer, 'bias', None) is not None:
                    torch.nn.init.zeros_(lora_B_layer.bias)
                lora_b_count += 1

        # Handle embedding LoRA layers if present. These are bare Parameters,
        # not Linear layers, so they cannot be routed through _init_lora_A.
        embedding_a_params = getattr(module, 'lora_embedding_A', None)
        if embedding_a_params:
            for lora_A_param in embedding_a_params.values():
                if method == "uniform":
                    bound = init_scale / math.sqrt(lora_A_param.size(1))
                    torch.nn.init.uniform_(lora_A_param, -bound, +bound)
                elif method == "gaussian":
                    std = init_scale / math.sqrt(lora_A_param.size(1))
                    torch.nn.init.normal_(lora_A_param, mean=0.0, std=std)
                elif method == "xavier_normal":
                    torch.nn.init.xavier_normal_(lora_A_param, gain=init_scale)
                else:
                    # Either xavier_uniform was requested, or the requested
                    # scheme is not implemented for embeddings (fallback).
                    if method not in embedding_schemes:
                        embedding_fallback_count += 1
                    torch.nn.init.xavier_uniform_(lora_A_param, gain=init_scale)
                lora_a_count += 1

        embedding_b_params = getattr(module, 'lora_embedding_B', None)
        if embedding_b_params:
            for lora_B_param in embedding_b_params.values():
                torch.nn.init.zeros_(lora_B_param)
                lora_b_count += 1

    if embedding_fallback_count:
        logging.warning(
            f"init_type '{init_type}' is not implemented for LoRA embedding adapters; "
            f"{embedding_fallback_count} embedding A matrices were initialized with "
            f"xavier_uniform (gain={init_scale}) instead"
        )

    logging.info(f"Reinitialized {lora_a_count} LoRA A matrices with {init_type} (scale={init_scale})")
    logging.info(f"Reinitialized {lora_b_count} LoRA B matrices to zeros")
    return model


class PeftGradDebugCallback(TrainerCallback):
    """Trainer callback that reports how many LoRA parameters receive gradients.

    A parameter is counted when it is trainable and its name contains
    ``"lora_"``. The summary goes to ``logging``, to stdout, and is appended to
    ``<output_dir>/lora_debug.log``.

    ``on_backward_end`` is not one of the events listed in the
    ``TrainerCallback`` API, so ``Trainer`` never invokes it on its own. The
    inspection is therefore also exposed through ``on_pre_optimizer_step``,
    the documented hook that runs while gradients are still populated. On
    transformers releases that predate that event the hook is simply never
    called.
    """

    def __init__(self):
        # Number of gradient inspections performed so far.
        self.step_count = 0
        # Global steps for which the "callback active" heartbeat was emitted.
        self.logged_steps = set()

    def on_train_begin(self, args, state, control, **kwargs):
        """Announce on the log and on stdout that the callback is installed."""
        logging.info("=== PeftGradDebugCallback: Training started ===")
        print("=== PeftGradDebugCallback: Training started ===")
        sys.stdout.flush()

    def on_step_begin(self, args, state, control, **kwargs):
        """Emit a heartbeat message once for every fifth global step."""
        # Log every few steps to confirm callback is working
        if state.global_step % 5 == 0 and state.global_step not in self.logged_steps:
            msg = f"[PeftGradDebugCallback] Step {state.global_step} beginning - callback active"
            logging.info(msg)
            print(msg)
            sys.stdout.flush()
            self.logged_steps.add(state.global_step)

    def on_pre_optimizer_step(self, args, state, control, **kwargs):
        """Run the gradient inspection from a hook that Trainer dispatches."""
        return self.on_backward_end(args, state, control, **kwargs)

    def on_backward_end(self, args, state, control, **kwargs):
        """Count LoRA parameters with non-zero gradients and report the result.

        Expects ``kwargs["model"]`` to be the model being trained. Reports on
        the first 10 inspections, whenever ``state.global_step`` is a multiple
        of 10, and whenever no LoRA parameter has a non-zero gradient.
        """
        self.step_count += 1
        model = kwargs["model"]
        nz = total = 0
        max_gn = 0.0
        lora_params_with_grad = []
        lora_params_without_grad = []

        for n, p in model.named_parameters():
            if p.requires_grad and "lora_" in n:
                total += 1
                if (g := p.grad) is not None:
                    if torch.count_nonzero(g) > 0:
                        nz += 1
                        max_gn = max(max_gn, g.norm().item())
                        lora_params_with_grad.append(n)
                    else:
                        lora_params_without_grad.append(n)
                else:
                    lora_params_without_grad.append(n)

        # Log every step for the first 10 steps, then every 10 steps, or if no gradients
        should_log = (self.step_count <= 10) or (state.global_step % 10 == 0) or (nz == 0)

        if should_log:
            msg = f"[step {state.global_step}] LoRA non-zero grads: {nz}/{total} | max ∥grad∥={max_gn:.3e}"
            logging.info(msg)
            print(msg)
            sys.stdout.flush()

            # If no gradients, show which parameters don't have gradients
            if nz == 0 and self.step_count <= 5:
                logging.warning(f"No LoRA gradients! Params without grad: {lora_params_without_grad[:5]}")
                print(f"WARNING: No LoRA gradients! Params without grad: {lora_params_without_grad[:5]}")

            # Also log to a separate debug file. The message contains a
            # non-ASCII character, so the encoding is pinned instead of
            # relying on the platform default.
            debug_file = os.path.join(args.output_dir, 'lora_debug.log')
            with open(debug_file, 'a', encoding="utf-8") as f:
                f.write(f"{msg}\n")
                if nz == 0:
                    f.write(f"  Params without grad: {lora_params_without_grad}\n")



def parse_args(argv=None):
    """Build the command-line parser for the LoRA training script and parse arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            ``argparse`` reads ``sys.argv[1:]``, which is what a plain
            ``parse_args()`` call has always done.

    Returns:
        argparse.Namespace: The parsed options. ``--output_dir`` is required;
        every other option has a default.
    """
    parser = argparse.ArgumentParser(description="LoRA fine-tune BERT for text classification")
    parser.add_argument("--dataset_name", type=str, default="ag_news",
                        help="Name of the Hugging Face dataset to use (default: 'ag_news')")
    parser.add_argument("--model_name", type=str, default="bert-base-uncased",
                        help="Pretrained BERT model identifier or path")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Where to save LoRA adapters and model")
    parser.add_argument("--num_epochs", type=int, default=3,
                        help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=16,
                        help="Batch size for training/eval")
    parser.add_argument("--learning_rate", type=float, default=5e-4,
                        help="Learning rate for optimizer")
    parser.add_argument("--max_seq_length", type=int, default=256,
                        help="Max sequence length for tokenization")
    parser.add_argument("--lora_r", type=int, default=8,
                        help="LoRA rank (r)")
    parser.add_argument("--lora_alpha", type=int, default=16,
                        help="LoRA alpha scaling")
    parser.add_argument("--lora_dropout", type=float, default=0.1,
                        help="LoRA dropout rate")
    # The examples name the real flag and use ASCII quotes so --help renders
    # on any terminal encoding.
    parser.add_argument("--lora_target_modules", nargs="+", default=["all-linear"],
                        help="Target modules for LoRA; "
                            "e.g. '--lora_target_modules all-linear' or "
                            "'--lora_target_modules query value'"
                        )
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    parser.add_argument("--val_size", type=float, default=0.1,
                        help="Validation split size")

    parser.add_argument("--debug", action='store_true',
                        help="Enable debug mode with reduced model steps")
    # parser.add_argument("use_mixed_precision", type=int, default=1, choices=[0, 1],
    #                     help="Use mixed precision training (default: True)")
    parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face token for private models (optional)")

    # New arguments for generic classifier
    parser.add_argument("--pooling_strategy", type=str, default="mean",
                        help="Pooling strategy: mean, max, sum, last, first/cls, attention, "
                             "weighted_average, or multi:strategy1,strategy2,... for multi-pooling")
    parser.add_argument("--classifier_hidden_dims", type=int, nargs="*", default=None,
                        help="Hidden dimensions for multi-layer classifier (e.g., --classifier_hidden_dims 512 256)")
    parser.add_argument("--classifier_dropout", type=float, default=0.1,
                        help="Dropout rate for classifier layers")
    parser.add_argument("--classifier_activation", type=str, default="relu",
                        choices=["relu", "gelu", "tanh", "leaky_relu", "swish", "silu"],
                        help="Activation function for classifier hidden layers")
    parser.add_argument("--pooling_combination", type=str, default="concat",
                        choices=["concat", "mean", "max", "learned"],
                        help="How to combine multiple pooling strategies (for multi-pooling)")

    # Custom LoRA initialization arguments
    parser.add_argument("--lora_init_type", type=str, default=None,
                        choices=["uniform", "gaussian", "orthogonal", "xavier_uniform", "xavier_normal",
                                "kaiming_uniform", "kaiming_normal", "zeros"],
                        help="LoRA initialization method. If not specified, uses PEFT default initialization.")
    parser.add_argument("--lora_init_scale", type=float, default=1.0,
                        help="Scale parameter for LoRA initialization (default: 1.0)")

    return parser.parse_args(argv)


def main():
    args = parse_args()
    if len(args.lora_target_modules) == 1:
        args.lora_target_modules = args.lora_target_modules[0]
    
    # Create timestamp string
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    # Append timestamp to output_dir
    args.output_dir = os.path.join(args.output_dir, f"model_{args.model_name}_ds_{args.dataset_name}_train_epoch_{args.num_epochs}_run_{timestamp}_seed_{args.seed}")
    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)
    
    # Setup logging
    print(f"DEBUG: EXECUTING lora_v2.py (LoRA training script)")
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(os.path.join(args.output_dir, 'training.log')),
            logging.StreamHandler()
        ]
    )
    
    logging.info(f"Starting {args.model_name} LoRA training experiment")
    logging.info(f"Arguments: {vars(args)}")
    
    # Set reproducible seed
    set_reproducible_seed(args.seed)
    logging.info(f"Set random seed to {args.seed}")
    
    # Collect system information
    system_info = get_system_info()
    logging.info(f"System info collected: {system_info['python_version'].split()[0]}, "
                f"PyTorch {system_info['torch_version']}, "
                f"Transformers {system_info['transformers_version']}")
    
    # Load and prepare dataset
    logging.info(f"Loading {args.dataset_name} dataset...")
    
    # Load dataset config
    dataset_config = OmegaConf.load(f"conf/dataset/{args.dataset_name}.yaml")
    
    # Handle debug mode by setting sample limits
    if args.debug:
        dataset_config.max_train_samples = 100
        dataset_config.max_val_samples = 100
        dataset_config.max_test_samples = 100

    dataset = get_exp_data_hf(dataset_config, val_size=args.val_size, seed=args.seed)
    
    label_column = dataset_config.label_column
    unique_labels = set(example[label_column] for example in dataset["train"])
    print("Unique labels in training set:", unique_labels)

    # Initialize tokenizer
    # tokenizer = BertTokenizerFast.from_pretrained(args.model_name)
    # load tokenizer + model as before

    tokenizer, tokens_added = setup_tokenizer(args.model_name, args.hf_token)

    def tokenize_fn(ex):
        """Tokenize the configured text column of ``ex``, truncating to ``args.max_seq_length``."""
        texts = ex[dataset_config.text_column]
        return tokenizer(texts, truncation=True, max_length=args.max_seq_length)
    
    # Tokenize dataset
    tokenized = dataset.map(tokenize_fn, batched=True)
    
    # detect device & precision
    use_cuda = torch.cuda.is_available()
    use_mps  = torch.backends.mps.is_available()

    # Data collator
    data_collator = DataCollatorWithPadding(tokenizer)
    
    # Initialize model
    # logging.info(f"Loading base model: {args.model_name}")
    # num_labels = dataset["train"].features["label"].num_classes
    # base_model = AutoModelForSequenceClassification.from_pretrained(
    #     args.model_name,
    #     num_labels=num_labels,
    #     token=args.hf_token,  # Use your Hugging Face token
    # )
    logging.info(f"Loading base model: {args.model_name}")
    num_labels = dataset["train"].features[label_column].num_classes

    print(f"Number of labels ------- : {num_labels}")
    pooling_strategy = getattr(args, 'pooling_strategy', 'mean')
    hidden_dims = getattr(args, 'classifier_hidden_dims', None)
    dropout_rate = getattr(args, 'classifier_dropout', 0.1)
    activation = getattr(args, 'classifier_activation', 'relu')
    # pooling_combination = getattr(args, 'pooling_combination', 'concat')  # For multi-pooling
    max_seq_length = getattr(args, 'max_seq_length', 512)  # For weighted average pooling

    # Create the model with generic classifier
    base_model = GenericSequenceClassifier(
        model_name=args.model_name,
        num_labels=num_labels,
        pooling_strategy=pooling_strategy,
        hidden_dims=hidden_dims,
        dropout_rate=dropout_rate,
        activation=activation,
        hf_token=args.hf_token,
        # Additional kwargs for specific pooling strategies
        # combination=pooling_combination,  # For multi-pooling
        max_seq_length=max_seq_length,  # For weighted average pooling
    )
    
    # Fix model type to avoid PEFT warning about unsupported model type
    base_model.config.model_type = "generic_sequence_classifier"
    # Also fix the underlying transformer model type to match
    if hasattr(base_model, 'transformer') and hasattr(base_model.transformer, 'config'):
        base_model.transformer.config.model_type = "generic_sequence_classifier"
    
    # Register the model with transformers AutoModel for proper PEFT support
    from transformers import AutoModel, AutoConfig
    if not hasattr(AutoConfig, "_model_type_to_module_mapping") or \
       "generic_sequence_classifier" not in AutoConfig._model_type_to_module_mapping:
        AutoConfig.register("generic_sequence_classifier", GenericSequenceClassifierConfig)
        AutoModel.register(GenericSequenceClassifierConfig, GenericSequenceClassifier)

    # print(f"Base model created with GenericSequenceClassifier.")
    # for layer_name, params in base_model.named_parameters():
    #     print(layer_name, params.shape)
    # assert False, 'breakpoint after GenericSequenceClassifier creation'

    # Debugging step to verify label mappings
    # print(f"Model id2label: {base_model.transformer.config.id2label}")
    # print(f"Model label2id: {base_model.transformer.config.label2id}")

    # assert False, 'breakpoint for debugging label mappings'
    # Only resize if needed
    if tokens_added > 0:
        base_model.config.pad_token_id = tokenizer.pad_token_id
        base_model.config.vocab_size = len(tokenizer)
        base_model.resize_token_embeddings(len(tokenizer))
        logging.info(f"Resized model embeddings to {len(tokenizer)} tokens")


    logging.info("Applying LoRA configuration...")
    print(f"LLLLLLLoRA target modules: {args.lora_target_modules}")
    
    # Auto-detect target modules based on model architecture
    def get_model_target_modules(model, requested_modules):
        """
        Resolve the LoRA target modules to use for ``model``.

        Resolution order:
          1. The string ``"all-linear"`` is passed through untouched.
          2. If every requested name is the leaf name of a sub-module with a
             2-D ``weight``, the request is honoured and returned as a list.
             A single name given as a string is wrapped in a list.
          3. Otherwise the first known attention naming scheme found in the
             model is used (BERT/RoBERTa, DistilBERT, Llama/Mistral/Gemma,
             GPT-2).
          4. Otherwise the request is returned unchanged and a warning lists
             the leaf names that were found.

        Args:
            model: Model exposing ``named_modules()``.
            requested_modules: ``"all-linear"``, a single module name, or a
                list of module names.

        Returns:
            ``"all-linear"``, a list of module names, or ``requested_modules``
            as given when nothing could be resolved.
        """
        if isinstance(requested_modules, str) and requested_modules == "all-linear":
            return "all-linear"

        # Leaf names of every sub-module with a 2-D weight. Modules whose
        # weight attribute is None are skipped instead of crashing on .shape.
        available_modules = set()
        for name, module in model.named_modules():
            weight = getattr(module, 'weight', None)
            if weight is not None and len(weight.shape) == 2:
                available_modules.add(name.split('.')[-1])  # Get the last part of the name

        if isinstance(requested_modules, str):
            requested_list = [requested_modules]
        else:
            requested_list = list(requested_modules or [])

        # An explicit request that matches the model must not be overridden
        # by auto-detection.
        if requested_list and all(mod in available_modules for mod in requested_list):
            logging.info(f"Using requested LoRA target modules: {requested_list}")
            return requested_list

        # Known attention naming schemes, in order of preference. A
        # "c_attn + c_proj" entry is omitted: it could never be selected
        # because the plain "c_attn" entry always matched first.
        target_mappings = [
            (["query", "key", "value"], "BERT/RoBERTa-style attention"),
            (["q_lin", "k_lin", "v_lin"], "DistilBERT-style attention"),
            (["q_proj", "k_proj", "v_proj"], "Llama/Mistral/Gemma-style attention"),
            (["c_attn"], "GPT-2-style combined attention"),
        ]

        for modules, description in target_mappings:
            if all(mod in available_modules for mod in modules):
                logging.info(f"Auto-detected LoRA target modules: {modules} ({description})")
                return list(modules)

        # Fallback: return original if no mapping found
        logging.warning(f"No automatic mapping found for requested modules {requested_modules}")
        logging.info(f"Available linear modules in model: {sorted(available_modules)}")
        return requested_modules
    
    # Auto-detect the appropriate target modules
    target_modules = get_model_target_modules(base_model, args.lora_target_modules)
    
    peft_config = LoraConfig(
        task_type=TaskType.SEQ_CLS,  # Reverted: Both methods should train classifier
        inference_mode=False,
        r=args.lora_r,
        # r=32,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        # init_lora_weights='orthogonal',
        target_modules=target_modules,  # Use auto-detected target modules
        # target_modules=["q_proj","v_proj"],  # Use provided target modules
        modules_to_save=["classifier", "pooling_strategy"],  # keeps head as a normal saved module
        # bias='lora_only',  # Only train LoRA biases
    )
    model = get_peft_model(base_model, peft_config)
    
    # Apply custom LoRA initialization if specified
    if args.lora_init_type is not None:
        logging.info(f"Applying custom LoRA initialization: {args.lora_init_type} with scale {args.lora_init_scale}")
        model = reinit_lora_weights(model, args.lora_init_type, args.lora_init_scale)
    else:
        logging.info("Using PEFT default LoRA initialization")

    def debug_peft_matches(m):
        """Print which sub-modules of ``m`` carry LoRA adapter attributes.

        Prints the total count and the first 30 names, or a warning when no
        sub-module has any of the adapter attributes.
        """
        adapter_attrs = ("lora_A", "lora_B", "lora_A_default", "lora_B_default")
        matched = [
            mod_name
            for mod_name, submodule in m.named_modules()
            if any(hasattr(submodule, attr) for attr in adapter_attrs)
        ]
        print(f"[PEFT] LoRA attached to {len(matched)} modules")
        if not matched:
            print("[PEFT] WARNING: No target modules were matched!")
            return
        for mod_name in matched[:30]:
            print("  ", mod_name)


    debug_peft_matches(model)

    model.print_trainable_parameters()
    for n,_ in model.named_modules():
        assert not n.startswith("classifier.") or "lora_" not in n, f"Classifier got LoRA: {n}"
    
    # Debug: Check LoRA parameter requires_grad status
    print("\n=== LoRA Parameter Debug ===")
    lora_params_count = 0
    lora_trainable_count = 0
    for n, p in model.named_parameters():
        if "lora_" in n:
            lora_params_count += 1
            if p.requires_grad:
                lora_trainable_count += 1
                print(f"✓ {n}: requires_grad={p.requires_grad}, shape={p.shape}")
            else:
                print(f"✗ {n}: requires_grad={p.requires_grad}, shape={p.shape}")
    
    print(f"LoRA params total: {lora_params_count}, trainable: {lora_trainable_count}")
    
    # Also check classifier parameters
    classifier_params_count = 0
    classifier_trainable_count = 0
    for n, p in model.named_parameters():
        if "classifier" in n:
            classifier_params_count += 1
            if p.requires_grad:
                classifier_trainable_count += 1
                print(f"✓ Classifier {n}: requires_grad={p.requires_grad}, shape={p.shape}")
    
    print(f"Classifier params total: {classifier_params_count}, trainable: {classifier_trainable_count}")
    print("=" * 30)
    
    # Test gradient flow with a dummy batch
    print("\n=== Testing Gradient Flow ===")
    model.train()
    dummy_batch = {
        'input_ids': torch.randint(0, 1000, (2, 10)).to(next(model.parameters()).device),
        'attention_mask': torch.ones(2, 10).to(next(model.parameters()).device),
        'labels': torch.randint(0, num_labels, (2,)).to(next(model.parameters()).device)
    }
    
    try:
        model.zero_grad()
        outputs = model(**dummy_batch)
        loss = outputs.loss
        loss.backward()
        
        lora_grads_found = 0
        for n, p in model.named_parameters():
            if "lora_" in n and p.grad is not None and torch.count_nonzero(p.grad) > 0:
                lora_grads_found += 1
        
        print(f"Dummy test: {lora_grads_found} LoRA params received gradients")
        print(f"Loss: {loss.item():.4f}")
        
    except Exception as e:
        print(f"Gradient test failed: {e}")
    
    model.zero_grad()  # Clean up
    print("=" * 30)



    # assert False, 'breakpoint after LoRA setup'
    # for n, p in model.named_parameters():
    #     if p.requires_grad:
    #         print(n, p.shape)
    # assert False, 'breakpoint after model and LoRA setup'

    # Print model info
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logging.info(f"Total parameters: {total_params:,}")
    logging.info(f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params*100:.2f}%)")
    
    # Prepare datasets
    train_dataset = tokenized["train"]
    val_dataset = tokenized["val"]
    test_dataset = tokenized["test"]
    
    logging.info(f"Dataset splits - Train: {len(train_dataset)}, "
                f"Val: {len(val_dataset)}, Test: {len(test_dataset)}")
    
    # --------------------- DEBUG BEGIN ---------------------
    # drop raw columns; keep only encodings + label
    # remove_cols = [c for c in dataset["train"].column_names
    #             if c not in [dataset_config.text_column, label_column]]
    # tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=remove_cols)

    # # rename label to 'labels' for Trainer / model API
    # if label_column != "labels":
    #     tokenized = tokenized.rename_column(label_column, "labels")

    # # set torch format; BERT often has token_type_ids
    # cols = ["input_ids", "attention_mask", "labels"]
    # if "token_type_ids" in tokenized["train"].column_names:
    #     cols.append("token_type_ids")
    # tokenized.set_format(type="torch", columns=cols)

    # train_dataset = tokenized["train"]; val_dataset = tokenized["val"]; test_dataset = tokenized["test"]

    # # Count grads in LoRA vs head
    # lora_params = sum(p.numel() for n,p in model.named_parameters() if p.requires_grad and "lora_" in n)
    # head_params = sum(p.numel() for n,p in model.named_parameters() if p.requires_grad and "classifier" in n)
    # print(f"Trainable LoRA params: {lora_params:,}")
    # print(f"Trainable head  params: {head_params:,}")

    # # Verify some LoRA weights actually change with a single step
    # import copy
    # before = {}
    # for n,p in model.named_parameters():
    #     if p.requires_grad and "lora_" in n:
    #         before[n] = p.detach().clone()

    # # batch = next(iter(torch.utils.data.DataLoader(train_dataset, batch_size=2, collate_fn=data_collator)))
    # dl = torch.utils.data.DataLoader(train_dataset, batch_size=8, collate_fn=data_collator, shuffle=True)
    # opt = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-3)

    # def count_nz_grads():
    #     nz = total = 0
    #     for n,p in model.named_parameters():
    #         if p.requires_grad and "lora_" in n:
    #             total += 1
    #             if p.grad is not None and torch.count_nonzero(p.grad).item() > 0:
    #                 nz += 1
    #     return nz, total

    # for step, batch in enumerate(dl):
    #     if step == 0:  # optional few warmup steps
    #         pass
    #     batch = {k: v.to(next(model.parameters()).device) for k,v in batch.items()}
    #     model.train(); model.zero_grad()
    #     loss = model(**batch).loss
    #     loss.backward()
    #     nz, total = count_nz_grads()
    #     print(f"step {step} | loss {loss.item():.4f} | LoRA non-zero grad: {nz}/{total}")
    #     opt.step()
    #     if step == 10: break
    
    
    
    # model.zero_grad()

    # assert False, 'breakpoint after single step grad check'

    # --------------------- DEBUG END ---------------------

    # Training arguments
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        eval_strategy="epoch",
        logging_strategy="steps",
        save_strategy="epoch",
        logging_dir=os.path.join(args.output_dir, "logs"),
        logging_steps=50,
        seed=args.seed,
        metric_for_best_model="eval_loss",
        save_total_limit=1,
        load_best_model_at_end=True,
        bf16=use_cuda,
        use_mps_device=use_mps,  # Use MPS if available
        report_to=None,  # Disable wandb/tensorboard
        # max_steps=20 if args.debug else -1,  # Limit steps for debug mode
        # gradient_checkpointing=True,  # Enable gradient checkpointing for memory efficiency
        # gradient_checkpointing_kwargs={"use_reentrant": False},  # Use newer checkpointing format
    )
    

    logging.info(f"Training arguments: {training_args}")

    # Initialize callback
    metrics_callback = EnhancedTrainMetricsCallback(args.output_dir)
    


    from transformers import TrainerCallback

    class PeftGradDebugCallback(TrainerCallback):
        """Print, per optimizer step, how many LoRA parameters have non-zero gradients.

        ``on_backward_end`` is not one of the events listed in the
        ``TrainerCallback`` API, so ``Trainer`` never invokes it on its own.
        The check is therefore also exposed through ``on_pre_optimizer_step``,
        the documented hook that runs while gradients are still populated. On
        transformers releases that predate that event the hook is simply
        never called.
        """

        def on_pre_optimizer_step(self, args, state, control, **kwargs):
            """Run the gradient check from a hook that Trainer dispatches."""
            return self.on_backward_end(args, state, control, **kwargs)

        def on_backward_end(self, args, state, control, **kwargs):
            """Count trainable ``lora_`` parameters with non-zero grads and print a summary.

            Expects ``kwargs["model"]`` to be the model being trained.
            """
            model = kwargs["model"]
            nz = total = 0
            max_gn = 0.0
            for n, p in model.named_parameters():
                if p.requires_grad and "lora_" in n:
                    total += 1
                    if (g := p.grad) is not None:
                        if torch.count_nonzero(g) > 0:
                            nz += 1
                            max_gn = max(max_gn, g.norm().item())
            # if state.global_step % max(1, args.logging_steps) == 0:
            print(f"[step {state.global_step}] LoRA non-zero grads: {nz}/{total} | max ∥grad∥={max_gn:.3e}")

    # add to your `callbacks=[...]`
    peft_debug_callback = PeftGradDebugCallback()
    callbacks=[metrics_callback, peft_debug_callback]
    
    logging.info("Added PeftGradDebugCallback to monitor LoRA gradient flow")
    
    # Test callback setup by checking LoRA parameters
    lora_param_count = 0
    for n, p in model.named_parameters():
        if p.requires_grad and "lora_" in n:
            lora_param_count += 1
    
    test_msg = f"DEBUG: Found {lora_param_count} LoRA parameters that will be monitored"
    logging.info(test_msg)
    print(test_msg)
    sys.stdout.flush()

    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
    )
    
    # Set trainer reference in callback
    metrics_callback.trainer = trainer
    
    # Train model
    logging.info("Starting training...")
    trainer.train()
    
    # Final evaluation
    logging.info("Performing final evaluation...")
    
    # Evaluate on all splits
    train_metrics = trainer.evaluate(eval_dataset=train_dataset, metric_key_prefix="train")
    val_metrics = trainer.evaluate(eval_dataset=val_dataset, metric_key_prefix="val")
    test_metrics = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")
    
    # Print results
    logging.info("=== Final Results ===")
    logging.info(f"Train   → loss: {train_metrics['train_loss']:.4f}, accuracy: {train_metrics['train_accuracy']:.4f}")
    logging.info(f"Val     → loss: {val_metrics['val_loss']:.4f}, accuracy: {val_metrics['val_accuracy']:.4f}")
    logging.info(f"Test    → loss: {test_metrics['test_loss']:.4f}, accuracy: {test_metrics['test_accuracy']:.4f}")
    
    # Generate predictions for confusion matrix
    logging.info("Generating confusion matrices...")
    # class_names = ['World', 'Sports', 'Business', 'Technology']
    # assume `dataset` is a DatasetDict with a “train” split
    label_feat = dataset["train"].features[label_column]

    if isinstance(label_feat, ClassLabel):
        class_names = label_feat.names
    else:
        # fallback: if labels are raw ints with no ClassLabel feature
        class_names = [str(i) for i in range(label_feat.num_classes)]
    
    # Test set confusion matrix
    test_predictions = trainer.predict(test_dataset)
    test_preds = np.argmax(test_predictions.predictions, axis=-1)
    test_labels = test_predictions.label_ids
    
    save_confusion_matrix(test_labels, test_preds, class_names, 
                         os.path.join(args.output_dir, 'confusion_matrix_test.json'))
    
    # Merge LoRA weights
    logging.info("Merging LoRA weights...")
    model = model.merge_and_unload()
    
    # Save merged model
    merged_output_dir = os.path.join(args.output_dir, "final_model")
    model.save_pretrained(merged_output_dir)
    tokenizer.save_pretrained(merged_output_dir)
    
    # Create experiment configuration
    experiment_info = {
        'experiment_id': f"{args.model_name}_lora_{timestamp}",
        'command': ' '.join(sys.argv),
        'args': vars(args),
        'dataset_info': {
            'name': args.dataset_name,
            'splits': {
                'train': len(train_dataset),
                'val': len(val_dataset),
                'test': len(test_dataset)
            },
            'num_classes': num_labels,
            'class_names': class_names
        },
        'model_info': {
            'base_model': args.model_name,
            'total_params': total_params,
            'trainable_params': trainable_params,
            'lora_config': {
                'r': args.lora_r,
                'alpha': args.lora_alpha,
                'dropout': args.lora_dropout,
                'target_modules': args.lora_target_modules,
                'init_type': args.lora_init_type if args.lora_init_type else 'peft_default',
                'init_scale': args.lora_init_scale if args.lora_init_type else None
            }
        },
        'training_args': training_args.to_dict(),
        'system_info': system_info,
        'results': {
            'train_metrics': train_metrics,
            'val_metrics': val_metrics,
            'test_metrics': test_metrics
        }
    }
    
    # Save experiment configuration
    with open(os.path.join(args.output_dir, 'experiment_config.json'), 'w') as f:
        json.dump(experiment_info, f, indent=2)
    
    # Generate model card
    all_metrics = {**train_metrics, **val_metrics, **test_metrics}
    generate_model_card(args, system_info, experiment_info, all_metrics, args.output_dir)
    
    logging.info("=== Experiment Complete ===")
    logging.info(f"All artifacts saved to: {args.output_dir}")
    logging.info("Files created:")
    for file in os.listdir(args.output_dir):
        if os.path.isfile(os.path.join(args.output_dir, file)):
            logging.info(f"  - {file}")
    logging.info(f"  - merged_output_dir/ (contains final model)")


# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()