import hydra
from omegaconf import DictConfig, OmegaConf
import argparse
import os
import sys
import json
import logging
import datetime

import torch
import numpy as np
from transformers import (
    BertForSequenceClassification,
    BertTokenizerFast,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
    TrainerCallback,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from datasets import ClassLabel
import numpy as np
from exp_data import get_exp_data_hf

from training_utils import (
    compute_metrics,
    EnhancedTrainMetricsCallback,
    set_reproducible_seed,
    get_system_info,
    save_confusion_matrix,
    generate_model_card,
    setup_tokenizer,
)

from seq_classifier import GenericSequenceClassifier


def parse_args(argv=None):
    """Parse the command-line options of the fine-tuning script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` reads ``sys.argv[1:]``, which is the
            behaviour existing callers rely on. Passing an explicit list
            lets the parser be driven programmatically.

    Returns:
        argparse.Namespace: The parsed options. ``--output_dir`` is the only
        required one; every other option has a default.
    """
    parser = argparse.ArgumentParser(description="Fine-tune BERT for text classification")
    parser.add_argument("--dataset_name", type=str, default="ag_news",
                        help="Name of the Hugging Face dataset to use (default: 'ag_news')")
    parser.add_argument("--model_name", type=str, default="bert-base-uncased",
                        help="Pretrained BERT model identifier or path")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory to save the model and outputs")
    parser.add_argument("--num_epochs", type=int, default=3,
                        help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=16,
                        help="Batch size for training/eval")
    parser.add_argument("--learning_rate", type=float, default=5e-4,
                        help="Learning rate for optimizer")
    parser.add_argument("--max_seq_length", type=int, default=256,
                        help="Max sequence length for tokenization")
    parser.add_argument("--freeze_base", action="store_true",
                        help="If set, freeze BERT base and only train the classification head")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    parser.add_argument("--val_size", type=float, default=0.1,
                        help="Validation split size")
    parser.add_argument("--debug", action='store_true',
                        help="Enable debug mode with reduced model steps")
    parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face token for private models (optional)")
    # Options that configure the generic classifier head and its pooling
    parser.add_argument("--pooling_strategy", type=str, default="mean", 
                        help="Pooling strategy: mean, max, sum, last, first/cls, attention, "
                             "weighted_average, or multi:strategy1,strategy2,... for multi-pooling")
    parser.add_argument("--classifier_hidden_dims", type=int, nargs="*", default=None,
                        help="Hidden dimensions for multi-layer classifier (e.g., --classifier_hidden_dims 512 256)")
    parser.add_argument("--classifier_dropout", type=float, default=0.1,
                        help="Dropout rate for classifier layers")
    parser.add_argument("--classifier_activation", type=str, default="relu",
                        choices=["relu", "gelu", "tanh", "leaky_relu", "swish", "silu"],
                        help="Activation function for classifier hidden layers")
    parser.add_argument("--pooling_combination", type=str, default="concat",
                        choices=["concat", "mean", "max", "learned"],
                        help="How to combine multiple pooling strategies (for multi-pooling)")

    return parser.parse_args(argv)


# Placeholder written to logs/artifacts instead of the Hugging Face token.
_REDACTED = "***"


def _args_as_dict(args):
    """Return the run settings as a plain dict with ``hf_token`` masked.

    ``args`` is either an ``argparse.Namespace`` (direct CLI use) or a
    ``DictConfig`` (Hydra). ``vars()`` only yields the settings for the
    former; on a ``DictConfig`` it exposes OmegaConf's internal attributes,
    which are neither meaningful to log nor JSON-serialisable.
    """
    if isinstance(args, DictConfig):
        plain = OmegaConf.to_container(args, resolve=True)
    else:
        plain = dict(vars(args))
    if plain.get("hf_token"):
        plain["hf_token"] = _REDACTED
    return plain


def _redact_command(argv):
    """Join ``argv`` into one string, masking any ``hf_token`` value."""
    pieces = []
    mask_next = False
    for token in argv:
        if mask_next:
            # Value belonging to a preceding bare ``--hf_token`` flag.
            pieces.append(_REDACTED)
            mask_next = False
        elif token == "--hf_token":
            pieces.append(token)
            mask_next = True
        elif token.startswith(("--hf_token=", "hf_token=")):
            # ``--hf_token=...`` (argparse) or ``hf_token=...`` (Hydra override).
            pieces.append(token.split("=", 1)[0] + "=" + _REDACTED)
        else:
            pieces.append(token)
    return " ".join(pieces)


def main(cfg: DictConfig = None):
    """Run one fine-tuning experiment end to end.

    The run creates ``<output_dir>/run_<timestamp>/``, configures logging to
    ``training.log`` in that directory, loads the dataset described by
    ``conf/dataset/<dataset_name>.yaml``, tokenizes it, builds a
    ``GenericSequenceClassifier``, trains it with the Hugging Face
    ``Trainer``, evaluates on the train/val/test splits, and writes the test
    confusion matrix, the final model (``final_model/``),
    ``experiment_config.json`` and a model card.

    Args:
        cfg: Hydra/OmegaConf configuration. When ``None`` the settings are
            read from the command line via ``parse_args()``.

    Note:
        ``output_dir`` on the settings object is rewritten in place to point
        at the timestamped run directory.
    """
    # Direct CLI use falls back to argparse; otherwise the Hydra config is used as-is.
    args = parse_args() if cfg is None else cfg

    # Every run gets its own timestamped sub-directory.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    args.output_dir = os.path.join(args.output_dir, f"run_{timestamp}")
    os.makedirs(args.output_dir, exist_ok=True)

    # Setup logging
    print("DEBUG: EXECUTING finetune_v2.py (head_only/full training script)")
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(os.path.join(args.output_dir, 'training.log')),
            logging.StreamHandler()
        ]
    )

    # Plain-dict view of the settings (token masked) for logging and the JSON artifact.
    args_dict = _args_as_dict(args)

    logging.info(f"Starting {args.model_name} fine-tuning experiment")
    logging.info(f"Arguments: {args_dict}")

    # Set reproducible seed
    set_reproducible_seed(args.seed)
    logging.info(f"Set random seed to {args.seed}")

    # Collect system information
    system_info = get_system_info()
    logging.info(f"System info collected: {system_info['python_version'].split()[0]}, "
                f"PyTorch {system_info['torch_version']}, "
                f"Transformers {system_info['transformers_version']}")

    # Load and prepare dataset
    logging.info(f"Loading {args.dataset_name} dataset...")

    # Dataset config path is relative to the current working directory.
    dataset_config = OmegaConf.load(f"conf/dataset/{args.dataset_name}.yaml")

    # Debug mode caps every split at 100 samples.
    if args.debug:
        dataset_config.max_train_samples = 100
        dataset_config.max_val_samples = 100
        dataset_config.max_test_samples = 100

    dataset = get_exp_data_hf(dataset_config, val_size=args.val_size, seed=args.seed)

    tokenizer, tokens_added = setup_tokenizer(args.model_name, args.hf_token)

    def tokenize_fn(ex):
        text_column = dataset_config.text_column
        return tokenizer(ex[text_column], truncation=True, max_length=args.max_seq_length)

    # Tokenize dataset
    tokenized = dataset.map(tokenize_fn, batched=True)

    # Detect device & precision. bf16 is requested only when the CUDA device
    # actually supports it; TrainingArguments rejects bf16 otherwise.
    use_cuda = torch.cuda.is_available()
    use_mps = torch.backends.mps.is_available()
    use_bf16 = use_cuda and torch.cuda.is_bf16_supported()

    # Data collator
    data_collator = DataCollatorWithPadding(tokenizer)

    # Initialize model
    logging.info(f"Loading base model: {args.model_name}")
    label_column = dataset_config.label_column
    num_labels = dataset["train"].features[label_column].num_classes
    # Classifier options are optional on the settings object, hence getattr with defaults.
    pooling_strategy = getattr(args, 'pooling_strategy', 'mean')
    hidden_dims = getattr(args, 'classifier_hidden_dims', None)
    dropout_rate = getattr(args, 'classifier_dropout', 0.1)
    activation = getattr(args, 'classifier_activation', 'relu')
    pooling_combination = getattr(args, 'pooling_combination', 'concat')  # For multi-pooling
    max_seq_length = getattr(args, 'max_seq_length', 512)  # For weighted average pooling

    # Create the model with generic classifier
    model = GenericSequenceClassifier(
        model_name=args.model_name,
        num_labels=num_labels,
        pooling_strategy=pooling_strategy,
        hidden_dims=hidden_dims,
        dropout_rate=dropout_rate,
        activation=activation,
        hf_token=args.hf_token,
        # Additional kwargs for specific pooling strategies
        combination=pooling_combination,  # For multi-pooling
        max_seq_length=max_seq_length,  # For weighted average pooling
    )

    logging.info(f"Created model with {pooling_strategy} pooling")
    logging.info(f"Pooling output size: {model.get_pooling_output_size()}")

    # Only resize if the tokenizer setup added tokens
    if tokens_added > 0:
        model.config.pad_token_id = tokenizer.pad_token_id
        model.config.vocab_size = len(tokenizer)
        model.resize_token_embeddings(len(tokenizer))
        logging.info(f"Resized model embeddings to {len(tokenizer)} tokens")

    if args.freeze_base:
        # Freeze the base model (encoder) but keep classifier trainable
        for param in model.transformer.parameters():
            param.requires_grad = False
        logging.info("Froze base model parameters")

    # Print model info
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logging.info(f"Total parameters: {total_params:,}")
    logging.info(f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params*100:.2f}%)")
    train_dataset = tokenized["train"]
    val_dataset = tokenized["val"]
    test_dataset = tokenized["test"]

    # Training arguments
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        eval_strategy="epoch",
        logging_strategy="steps",
        save_strategy="epoch",
        logging_dir=os.path.join(args.output_dir, "logs"),
        logging_steps=50,
        seed=args.seed,
        metric_for_best_model="eval_loss",
        save_total_limit=1,
        load_best_model_at_end=True,
        bf16=use_bf16,
        use_mps_device=use_mps,  # Use MPS if available
        # The string "none" disables reporting integrations; Python None does
        # not (transformers versions that map None to "all" enable them).
        report_to="none",
    )

    metrics_callback = EnhancedTrainMetricsCallback(args.output_dir)

    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[metrics_callback]
    )

    # Set trainer reference in callback
    metrics_callback.trainer = trainer

    # Train model
    logging.info("Starting training...")
    trainer.train()

    # Final evaluation
    logging.info("Performing final evaluation...")

    # Evaluate on all splits
    train_metrics = trainer.evaluate(eval_dataset=train_dataset, metric_key_prefix="train")
    val_metrics = trainer.evaluate(eval_dataset=val_dataset, metric_key_prefix="val")
    test_metrics = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")

    # Print results
    logging.info("=== Final Results ===")
    logging.info(f"Train   → loss: {train_metrics['train_loss']:.4f}, accuracy: {train_metrics['train_accuracy']:.4f}")
    logging.info(f"Val     → loss: {val_metrics['val_loss']:.4f}, accuracy: {val_metrics['val_accuracy']:.4f}")
    logging.info(f"Test    → loss: {test_metrics['test_loss']:.4f}, accuracy: {test_metrics['test_accuracy']:.4f}")

    # Generate predictions for confusion matrix
    logging.info("Generating confusion matrices...")
    # Read the class names from the configured label column (the same column
    # num_labels was derived from) rather than a hard-coded "label" feature.
    label_feat = dataset["train"].features[label_column]

    if isinstance(label_feat, ClassLabel):
        class_names = label_feat.names
    else:
        # Fallback: no ClassLabel names available, so use the class indices.
        class_names = [str(i) for i in range(num_labels)]

    # Test set confusion matrix
    test_predictions = trainer.predict(test_dataset)
    test_preds = np.argmax(test_predictions.predictions, axis=-1)
    test_labels = test_predictions.label_ids

    save_confusion_matrix(test_labels, test_preds, class_names, 
                         os.path.join(args.output_dir, 'confusion_matrix_test.json'))

    merged_output_dir = os.path.join(args.output_dir, "final_model")
    trainer.save_model(merged_output_dir)
    # Create experiment configuration
    experiment_info = {
        'experiment_id': f"{args.model_name}_fully_finetune_{timestamp}",
        'command': _redact_command(sys.argv),
        'args': args_dict,
        'dataset_info': {
            'name': args.dataset_name,
            'splits': {
                'train': len(train_dataset),
                'val': len(val_dataset),
                'test': len(test_dataset)
            },
            'num_classes': num_labels,
            'class_names': class_names
        },
        'model_info': {
            'base_model': args.model_name,
            'total_params': total_params,
            'trainable_params': trainable_params,
        },
        'training_args': training_args.to_dict(),
        'system_info': system_info,
        'results': {
            'train_metrics': train_metrics,
            'val_metrics': val_metrics,
            'test_metrics': test_metrics
        }
    }

    # Save experiment configuration
    with open(os.path.join(args.output_dir, 'experiment_config.json'), 'w') as f:
        json.dump(experiment_info, f, indent=2)

    # Generate model card
    all_metrics = {**train_metrics, **val_metrics, **test_metrics}
    generate_model_card(args, system_info, experiment_info, all_metrics, args.output_dir)

    logging.info("=== Experiment Complete ===")
    logging.info(f"All artifacts saved to: {args.output_dir}")
    logging.info("Files created:")
    for entry in os.listdir(args.output_dir):
        if os.path.isfile(os.path.join(args.output_dir, entry)):
            logging.info(f"  - {entry}")
    # Name the directory the model was actually saved to.
    logging.info(f"  - {os.path.basename(merged_output_dir)}/ (contains final model)")


if __name__ == '__main__':
    # `--hydra` is this script's own switch for selecting the Hydra entry
    # point; without it the argparse path inside main() is used.
    #
    # No ImportError fallback here: hydra is imported unconditionally at the
    # top of the file, so an ImportError reaching this point could only come
    # from inside main(), and retrying would restart the whole experiment.
    if '--hydra' in sys.argv:
        # Strip the switch before Hydra parses sys.argv, so Hydra's own
        # command-line parser never sees a flag that is not one of its options.
        sys.argv = [arg for arg in sys.argv if arg != '--hydra']
        hydra.main(config_path="../../conf", config_name="config")(main)()
    else:
        main()
