import argparse
import json
import os
from typing import Dict, List, Optional, Union

import numpy as np
import pandas as pd
import torch
import wandb
from accelerate import Accelerator
from datasets import Dataset
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
    set_seed,
)

from dataset_utils import load_mcq_dataset

def compute_metrics(pred):
    """Compute accuracy and macro-averaged precision/recall/F1 for a Trainer eval.

    Args:
        pred: An ``EvalPrediction``-like object exposing ``label_ids`` and
            ``predictions``. ``predictions`` is either the logits array or a
            tuple whose first element is the logits (models that return extra
            outputs alongside the logits produce a tuple).

    Returns:
        Dict with ``accuracy``, ``f1``, ``precision`` and ``recall``.
    """
    labels = pred.label_ids
    logits = pred.predictions
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.asarray(logits).argmax(-1)

    # A class that is never predicted (common early in training) has undefined
    # precision; score it as 0 explicitly instead of warning on every eval.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='macro', zero_division=0
    )
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

# Top-level module names that form a sequence-classification head in the
# Hugging Face architectures this script targets: `classifier` (BERT, RoBERTa,
# DeBERTa), `score` (Llama, Mistral, Qwen) and `pooler` (the pooling layer that
# feeds the classifier in BERT/DeBERTa). Parameter names are matched on whole
# dotted components, not substrings, so backbone modules such as
# `encoder.layer.0.attention.output.dense` are never mistaken for the head.
_HEAD_MODULE_NAMES = ("classifier", "score", "pooler")

# Default LoRA targets for decoder-only models with LLaMA-style projections.
_DECODER_LORA_TARGETS = ("q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj")
# DeBERTa-v2/v3 name their attention projections `query_proj`/`key_proj`/
# `value_proj`, which the BERT-style `query`/`key`/`value` suffixes never match.
_DEBERTA_V2_LORA_TARGETS = ("query_proj", "key_proj", "value_proj", "output.dense")
# BERT-style encoder targets; also the fallback for unrecognized models.
_ENCODER_LORA_TARGETS = ("query", "key", "value", "output.dense")


def _is_head_param(param_name: str) -> bool:
    """Return True if the named parameter belongs to the classification head."""
    return any(part in _HEAD_MODULE_NAMES for part in param_name.split("."))


def _load_classification_model(model_name: str, num_labels: int, random_init: bool):
    """Build a sequence-classification model, pretrained or randomly initialized."""
    if random_init:
        # Only the architecture comes from the checkpoint; all weights are random.
        config = AutoConfig.from_pretrained(
            model_name,
            num_labels=num_labels,
            problem_type="single_label_classification",
        )
        return AutoModelForSequenceClassification.from_config(config)
    return AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        problem_type="single_label_classification",
    )


def _reinit_classification_head(model) -> None:
    """Re-draw every Linear layer of the head from N(0, 0.02) with zero bias."""
    head = getattr(model, "classifier", None)
    if head is None:
        head = getattr(model, "score", None)
    if head is None:
        return
    # Walk the head's submodules so composite heads (e.g. RoBERTa's
    # dense + out_proj) are handled as well as a bare nn.Linear.
    for module in head.modules():
        if isinstance(module, torch.nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)


def _default_lora_targets(model, model_name: str) -> List[str]:
    """Pick default LoRA target modules from the config's model_type and the checkpoint name."""
    model_type = str(getattr(model.config, "model_type", "") or "").lower()
    name = model_name.lower()
    if any(key in model_type or key in name for key in ("llama", "mistral", "qwen")):
        return list(_DECODER_LORA_TARGETS)
    # Checked before the generic encoder fallback: "deberta" contains "bert".
    if model_type == "deberta-v2" or "deberta-v2" in name or "deberta-v3" in name:
        return list(_DEBERTA_V2_LORA_TARGETS)
    return list(_ENCODER_LORA_TARGETS)


def _freeze_backbone(model) -> None:
    """Disable gradients for every non-head parameter and report the split."""
    for name, param in model.named_parameters():
        if not _is_head_param(name):
            param.requires_grad = False

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen_params = total_params - trainable_params

    print(f"Froze {frozen_params:,} parameters out of {total_params:,} total parameters")
    print(f"Training only {trainable_params:,} parameters ({trainable_params/total_params:.2%} of model)")


def _build_optimizer(model, learning_rate: float, classifier_lr: float, weight_decay: float):
    """AdamW with backbone params at `learning_rate` and head params at `classifier_lr`."""
    backbone_params, head_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights never receive gradients
        (head_params if _is_head_param(name) else backbone_params).append(param)

    param_groups = [
        {"params": params, "lr": lr}
        for params, lr in ((backbone_params, learning_rate), (head_params, classifier_lr))
        if params
    ]
    return torch.optim.AdamW(param_groups, lr=learning_rate, weight_decay=weight_decay)


def _write_eval_artifacts(eval_dataset: Dataset, preds, output_dir: str) -> None:
    """Write per-example predictions and a classification report as JSON files."""
    # Pull whole columns once instead of materializing a row per example.
    texts = eval_dataset["text"]
    gold_labels = [int(label) for label in eval_dataset["label"]]
    pred_labels = [int(pred) for pred in preds]

    predictions_data = [
        {
            "index": i,
            "text": text,
            "predicted_label": pred,
            "true_label": gold,
            "correct": pred == gold,
        }
        for i, (text, pred, gold) in enumerate(zip(texts, pred_labels, gold_labels))
    ]

    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "predictions.json"), "w", encoding="utf-8") as f:
        json.dump(predictions_data, f, indent=2)

    report = classification_report(gold_labels, pred_labels, output_dict=True, zero_division=0)
    with open(os.path.join(output_dir, "classification_report.json"), "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)


def train_mcq_classifier(
    train_dataset: Dataset, 
    eval_dataset: Dataset, 
    model_name: str,
    output_dir: str,
    num_labels: int = 4,
    learning_rate: float = 1e-5,
    classifier_lr: Optional[float] = None,
    train_batch_size: int = 8,
    eval_batch_size: int = 16,
    gradient_accumulation_steps: int = 8,
    num_train_epochs: int = 50,
    max_seq_length: int = 2048,
    eval_steps: int = 100,
    save_steps: int = 500,
    freeze_base_model: bool = False,
    random_init: bool = False,
    wandb_run_name: Optional[str] = None,
    warmup_ratio: float = 0.1,
    weight_decay: float = 0.1,
    lr_scheduler_type: str = "constant_with_warmup",
    save_final_model: bool = False,
    use_lora: bool = False,
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: Optional[List[str]] = None,
    seed: int = 42,
) -> Dict[str, float]:
    """
    Train a multiple-choice classifier on the provided datasets.

    All hyperparameters below are forwarded to ``TrainingArguments``.
    Intermediate checkpointing is disabled; final weights are written only
    when ``save_final_model`` is set. After training, per-example predictions
    (``predictions.json``) and an sklearn classification report
    (``classification_report.json``) are written to ``output_dir`` by the
    main process.

    Args:
        train_dataset: Dataset for training (needs ``text`` and ``label`` columns)
        eval_dataset: Dataset for evaluation (needs ``text`` and ``label`` columns)
        model_name: Name or path of the pretrained model
        output_dir: Directory to save outputs
        num_labels: Number of classification labels
        learning_rate: Learning rate for the backbone (and the whole model when
            ``classifier_lr`` is None)
        classifier_lr: Optional separate learning rate for the classification
            head (``classifier``/``score``/``pooler`` modules)
        train_batch_size: Per-device batch size for training
        eval_batch_size: Per-device batch size for evaluation
        gradient_accumulation_steps: Number of gradient accumulation steps
        num_train_epochs: Number of training epochs
        max_seq_length: Maximum sequence length for tokenization (inputs are
            padded/truncated to exactly this length)
        eval_steps: Steps between evaluations
        save_steps: Steps between checkpoints; inert while save_strategy is "no"
        freeze_base_model: Train only the classification head
        random_init: Initialize the model with random weights
        wandb_run_name: Optional run name for W&B
        warmup_ratio: Ratio of total steps for warmup
        weight_decay: Weight decay for regularization
        lr_scheduler_type: Type of learning rate scheduler
        save_final_model: Whether to save the final model
        use_lora: Whether to use LoRA for parameter-efficient fine-tuning
        lora_r: LoRA rank parameter
        lora_alpha: LoRA alpha parameter
        lora_dropout: LoRA dropout rate
        lora_target_modules: Module names to apply LoRA to (None picks
            defaults from the model type/name)
        seed: Random seed for model initialization and training

    Returns:
        Evaluation results from ``Trainer.evaluate()``
    """
    # Seed before any weights are created so the new head is reproducible;
    # Trainer only seeds itself once the model already exists.
    set_seed(seed)

    accelerator = Accelerator()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = _load_classification_model(model_name, num_labels, random_init)

    # Decoder-only tokenizers often ship without a pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Sequence-classification heads on decoder-only models (e.g. Llama) rely on
    # pad_token_id to locate each sequence's last real token.
    if model.config.pad_token_id is None:
        model.config.pad_token_id = tokenizer.pad_token_id

    # Enable gradient checkpointing before the model is wrapped.
    if hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()

    _reinit_classification_head(model)

    if use_lora:
        if lora_target_modules is None:
            lora_target_modules = _default_lora_targets(model, model_name)

        # With the embeddings frozen, checkpointed blocks would receive inputs
        # that don't require grad and the adapters inside them would get no
        # gradient; make the embedding outputs require grad to avoid that.
        if getattr(model, "is_gradient_checkpointing", False) and hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()

        # Keep every top-level head module trainable, including a pooler
        # (DeBERTa's pooler is newly initialized, not pretrained).
        head_modules = [name for name in _HEAD_MODULE_NAMES if hasattr(model, name)]

        peft_config = LoraConfig(
            task_type=TaskType.SEQ_CLS,
            inference_mode=False,
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            target_modules=lora_target_modules,
            modules_to_save=head_modules or None,
        )

        print(f"Applying LoRA with rank {lora_r}, alpha {lora_alpha}, dropout {lora_dropout}")
        print(f"Target modules: {lora_target_modules}")

        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

        # LoRA already freezes the base parameters.
        if freeze_base_model:
            print("Note: freeze_base_model=True is redundant when using LoRA, as LoRA already freezes base parameters")
            freeze_base_model = False

    if freeze_base_model:
        _freeze_backbone(model)

    def tokenize_function(examples):
        """Tokenize the input text to a fixed length."""
        return tokenizer(
            examples['text'],
            padding="max_length",
            truncation=True,
            max_length=max_seq_length
        )

    tokenized_train = train_dataset.map(tokenize_function, batched=True, remove_columns=['text'])
    tokenized_eval = eval_dataset.map(tokenize_function, batched=True, remove_columns=['text'])

    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=learning_rate,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=eval_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        num_train_epochs=num_train_epochs,
        weight_decay=weight_decay,
        eval_strategy="steps",
        eval_steps=eval_steps,
        metric_for_best_model="accuracy",
        save_total_limit=1,
        # Intermediate checkpoints are disabled; final weights are governed by
        # `save_final_model`. `save_steps` only applies if this becomes "steps".
        save_strategy="no",
        save_steps=save_steps,
        run_name=wandb_run_name or "mmlu-mcq",
        warmup_ratio=warmup_ratio,
        logging_steps=10,
        report_to="wandb" if accelerator.is_main_process else "none",
        # Mixed precision needs a GPU; fall back to fp32 elsewhere.
        fp16=torch.cuda.is_available(),
        lr_scheduler_type=lr_scheduler_type,
        seed=seed,
    )

    # (None, None) lets Trainer build its default optimizer and scheduler; with
    # a custom optimizer Trainer still builds the scheduler from the args.
    optimizers = (None, None)
    if classifier_lr is not None:
        optimizers = (_build_optimizer(model, learning_rate, classifier_lr, weight_decay), None)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_eval,
        compute_metrics=compute_metrics,
        optimizers=optimizers,
    )

    trainer.train()
    eval_results = trainer.evaluate()

    # Sample-level predictions for the eval set.
    predictions = trainer.predict(tokenized_eval)
    logits = predictions.predictions
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.asarray(logits).argmax(-1)

    # Predictions are gathered on every process; write the files only once.
    if trainer.is_world_process_zero():
        _write_eval_artifacts(eval_dataset, preds, output_dir)

    if save_final_model:
        # save_model coordinates across processes itself, so call it everywhere.
        trainer.save_model(output_dir)
        if use_lora and trainer.is_world_process_zero():
            model.save_pretrained(os.path.join(output_dir, "lora_adapter"))

    return eval_results

# Hyperparameters encoded as `abbr-value` in the run name whenever they differ
# from their argparse default. Insertion order fixes the component order.
_RUN_NAME_PARAMS = {
    "max_seq_length": "seq",
    "learning_rate": "lr",
    "classifier_lr": "clr",
    "train_batch_size": "bs",
    "eval_batch_size": "ebs",
    "gradacc_steps": "ga",
    "num_train_epochs": "ep",
    "num_options": "nopt",
    "train_ratio": "trr",
    "test_ratio": "ter",
    "seed": "sd",
    "eval_steps": "es",
    "save_steps": "ss",
    "warmup_ratio": "wr",
    "weight_decay": "wd",
    "lr_scheduler": "sch",
    "option_sampling_strategy": "oss",
}

# Boolean flags encoded as a bare abbreviation when enabled.
_RUN_NAME_FLAGS = {
    "freeze_base_model": "frz",
    "random_init": "ri",
    "only_options": "oo",
    "split_test_set": "spl",
    "use_lora": "lora",
    "save_final_model": "sfm",
}

# LoRA hyperparameters; encoded only when --use_lora is set. `lrk` is
# deliberately distinct from the learning-rate abbreviation `lr`.
_LORA_RUN_NAME_PARAMS = {
    "lora_r": "lrk",
    "lora_alpha": "la",
    "lora_dropout": "ld",
}


def _optional_float(text: str) -> Optional[float]:
    """argparse ``type``: parse a float, mapping ``none`` or an empty string to None."""
    if text.strip().lower() in ("", "none"):
        return None
    return float(text)


def _format_run_value(value) -> str:
    """Render a hyperparameter value compactly and losslessly for run names.

    Small non-zero floats (e.g. learning rates) use scientific notation without
    a zero-padded exponent: 5e-06 -> "5e-6", 1.5e-05 -> "1.5e-5". Everything
    else uses ``str``.
    """
    if isinstance(value, float) and 0 < abs(value) < 0.01:
        mantissa, exponent = f"{value:.6e}".split("e")
        mantissa = mantissa.rstrip("0").rstrip(".")
        return f"{mantissa}e{int(exponent)}"
    return str(value)


def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for MCQ classifier training."""
    parser = argparse.ArgumentParser(description="MCQ Classifier Training")

    # Dataset arguments
    parser.add_argument("--dataset_name", type=str, default="mmlu_pro",
                        help="Name of the dataset to use")
    parser.add_argument("--train_dataset", type=str, default=None,
                        help="Optional: Separate dataset to use for training")
    parser.add_argument("--test_dataset", type=str, default=None,
                        help="Optional: Separate dataset to use for testing")
    parser.add_argument("--num_options", type=int, default=4,
                        help="Number of options for MCQ")
    parser.add_argument("--train_ratio", type=float, default=0.5,
                        help="Proportion of data for training")
    parser.add_argument("--test_ratio", type=float, default=0.5,
                        help="Proportion of data for testing")
    parser.add_argument("--split_test_set", action="store_true",
                        help="Split test set into train/test when no training set is available")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for reproducibility")
    parser.add_argument("--only_options", action="store_true",
                        help="Include only options in the prompt (no question)")
    parser.add_argument("--option_sampling_strategy", type=str, default="both",
                        choices=["both", "correct", "incorrect"],
                        help="Strategy for sampling additional options: 'both' uses all options, "
                             "'correct' uses only correct answers, 'incorrect' uses only incorrect answers")

    # Model arguments
    parser.add_argument("--model_name", type=str, default="microsoft/deberta-v3-large",
                        help="Name of the pretrained model")
    parser.add_argument("--max_seq_length", type=int, default=512,
                        help="Maximum sequence length for tokenization")
    parser.add_argument("--freeze_base_model", action="store_true",
                        help="Freeze base model and only train classifier head")
    parser.add_argument("--random_init", action="store_true",
                        help="Initialize model with random weights")
    parser.add_argument("--save_final_model", action="store_true",
                        help="Save the final model")

    # Training arguments
    parser.add_argument("--output_dir", type=str, default=None,
                        help="Directory to save model outputs (if None, auto-generated)")
    parser.add_argument("--learning_rate", type=float, default=1e-5,
                        help="Learning rate for optimization (backbone)")
    parser.add_argument("--classifier_lr", type=_optional_float, default=1e-4,
                        help="Optional separate learning rate for classifier head ('none' to disable)")
    parser.add_argument("--train_batch_size", type=int, default=8,
                        help="Batch size for training")
    parser.add_argument("--eval_batch_size", type=int, default=16,
                        help="Batch size for evaluation")
    parser.add_argument("--gradacc_steps", type=int, default=8,
                        help="Number of gradient accumulation steps")
    parser.add_argument("--num_train_epochs", type=int, default=60,
                        help="Number of training epochs")
    parser.add_argument("--eval_steps", type=int, default=100,
                        help="Steps between evaluations")
    parser.add_argument("--save_steps", type=int, default=500,
                        help="Steps between model saves")
    parser.add_argument("--warmup_ratio", type=float, default=0.1,
                        help="Ratio of total steps for warmup")
    parser.add_argument("--weight_decay", type=float, default=0.1,
                        help="Weight decay for regularization")
    parser.add_argument("--lr_scheduler", type=str, default="constant_with_warmup",
                        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
                        help="Learning rate scheduler type")

    # Wandb arguments
    parser.add_argument("--wandb_project", type=str, default="cawn",
                        help="Wandb project name")
    parser.add_argument("--wandb_run_name", type=str, default=None,
                        help="Wandb run name (if None, auto-generated)")

    # LoRA arguments
    parser.add_argument("--use_lora", action="store_true",
                        help="Use LoRA for parameter-efficient fine-tuning")
    parser.add_argument("--lora_r", type=int, default=8,
                        help="LoRA rank parameter")
    parser.add_argument("--lora_alpha", type=int, default=16,
                        help="LoRA alpha parameter")
    parser.add_argument("--lora_dropout", type=float, default=0.05,
                        help="LoRA dropout rate")
    parser.add_argument("--lora_target_modules", type=str, default=None,
                        help="Comma-separated list of module names to apply LoRA to (if None, uses defaults)")

    return parser


def _build_run_components(args: argparse.Namespace, parser: argparse.ArgumentParser) -> List[str]:
    """List run-name components for every setting that differs from its parser default.

    Defaults come from the parser itself, so the name can never drift out of
    sync with the actual CLI defaults.
    """
    components = [args.dataset_name]

    if args.model_name != parser.get_default("model_name"):
        # Short model name: last path component (tolerates a trailing slash).
        model_short = args.model_name.rstrip("/").split("/")[-1]
        components.append(f"mod-{model_short}")

    for param, abbr in _RUN_NAME_PARAMS.items():
        value = getattr(args, param)
        if value != parser.get_default(param):
            components.append(f"{abbr}-{_format_run_value(value)}")

    for flag, abbr in _RUN_NAME_FLAGS.items():
        if getattr(args, flag):
            components.append(abbr)

    if args.use_lora:
        for param, abbr in _LORA_RUN_NAME_PARAMS.items():
            value = getattr(args, param)
            if value != parser.get_default(param):
                components.append(f"{abbr}-{_format_run_value(value)}")

    return components


def main():
    """Parse CLI arguments, derive run naming, train the classifier, and log results.

    Returns:
        The evaluation metrics dict returned by ``train_mcq_classifier``.
    """
    parser = _build_parser()
    args = parser.parse_args()
    print(args)

    run_name = "-".join(_build_run_components(args, parser))
    if args.output_dir is None:
        args.output_dir = f"./results/cawn/{run_name}"
    if args.wandb_run_name is None:
        args.wandb_run_name = run_name

    # Only the main process talks to wandb.
    accelerator = Accelerator()
    if accelerator.is_main_process:
        # Record every CLI argument (including classifier_lr) in the run config.
        wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=vars(args))

    train_dataset, test_dataset = load_mcq_dataset(
        dataset_name=args.dataset_name,
        num_options=args.num_options,
        train_dataset_name=args.train_dataset,
        test_dataset_name=args.test_dataset,
        train_ratio=args.train_ratio,
        test_ratio=args.test_ratio,
        seed=args.seed,
        split_test_set=args.split_test_set,
        only_options=args.only_options,
        option_sampling_strategy=args.option_sampling_strategy
    )

    lora_target_modules = None
    if args.lora_target_modules:
        lora_target_modules = [module.strip() for module in args.lora_target_modules.split(",") if module.strip()]

    results = train_mcq_classifier(
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        model_name=args.model_name,
        output_dir=args.output_dir,
        num_labels=args.num_options,
        learning_rate=args.learning_rate,
        classifier_lr=args.classifier_lr,
        train_batch_size=args.train_batch_size,
        eval_batch_size=args.eval_batch_size,
        gradient_accumulation_steps=args.gradacc_steps,
        num_train_epochs=args.num_train_epochs,
        max_seq_length=args.max_seq_length,
        eval_steps=args.eval_steps,
        save_steps=args.save_steps,
        freeze_base_model=args.freeze_base_model,
        random_init=args.random_init,
        wandb_run_name=args.wandb_run_name,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        lr_scheduler_type=args.lr_scheduler,
        save_final_model=args.save_final_model,
        use_lora=args.use_lora,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        lora_target_modules=lora_target_modules,
    )

    if accelerator.is_main_process:
        wandb.log({f"final_{key}": value for key, value in results.items()})

        print("\nTest performance:")
        print(results)

        wandb.finish()

    return results

if __name__ == "__main__":
    # Prefer reproducible cuDNN kernel selection over autotuned (faster) kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    results = main()