import pandas as pd
import numpy as np
from datasets import Dataset, concatenate_datasets, load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
    DataCollatorWithPadding
)
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report, roc_auc_score
import torch
import json
import wandb
from typing import Dict, List
from accelerate import Accelerator

# NOTE: global RNG seeding is currently disabled. The loaders in this file
# that call np.random (distractor sampling / option shuffling) therefore
# produce different examples on every run, even though their dataset-level
# shuffle(seed=42) calls are deterministic. Uncomment to make runs repeatable.
# np.random.seed(42)
# torch.manual_seed(42)


# Module-level Accelerator instance. In the code visible in this file it is
# only consulted for `is_main_process`, to gate wandb setup/logging and
# console output.
accelerator = Accelerator()
    

def load_mmlu_data(split="train"):
    """Load one MMLU split as a 4-way (A-D) classification dataset.

    Args:
        split: One of "train", "validation" or "test". "train" maps to the
            ``auxiliary_train`` split of ``cais/mmlu``.

    Returns:
        A ``Dataset`` with a ``text`` column (the four lettered options, one
        per line; the question text is currently not part of the prompt) and
        a ``label`` column holding the dataset's ``answer`` value.

    Raises:
        ValueError: If ``split`` is not one of the supported names.
    """
    hf_split_names = {
        "train": "auxiliary_train",
        "validation": "validation",
        "test": "test",
    }
    # Fail fast with a clear message; previously an unknown split fell through
    # every branch and crashed later with UnboundLocalError.
    if split not in hf_split_names:
        raise ValueError(
            f"Unknown split {split!r}; expected one of {sorted(hf_split_names)}"
        )

    dataset = load_dataset("cais/mmlu", "all", split=hf_split_names[split])

    # Shuffle the dataset
    dataset = dataset.shuffle(seed=42)

    # Format data for 4-way classification (A, B, C, D)
    formatted_data = []
    for item in dataset:
        choices = item['choices']

        # Options only -- the question line is deliberately left out here.
        # Indexing (rather than zip) keeps the IndexError on short choice lists.
        prompt = "\n".join(
            f"{letter}. {choices[i]}" for i, letter in enumerate("ABCD")
        )

        formatted_data.append({
            'text': prompt,
            'label': item['answer']  # This should be 0, 1, 2, or 3 for A, B, C, D
        })

    return Dataset.from_list(formatted_data)


def load_super_gqpa_data(ratio=0.5):
    """Load SuperGPQA as a 10-way (A-J) multiple-choice classification task.

    The ``train`` split of ``m-a-p/SuperGPQA`` is shuffled with a fixed seed
    and cut at ``ratio``: the leading part becomes the training set and the
    remainder the test set. Items with fewer than 10 options are reported and
    dropped. For the rest, the correct option and nine distractors are put in
    a random order using the global ``np.random`` state; the prompt lists the
    lettered options only (the question text is not included).

    Returns:
        ``(train_dataset, test_dataset)``, each with ``text`` and ``label``
        columns, where ``label`` is the position (0-9) of the correct option.
    """
    letters = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"]

    # Fixed-seed shuffle so the train/test membership is reproducible.
    shuffled = load_dataset("m-a-p/SuperGPQA")["train"].shuffle(seed=42)
    cut = int(len(shuffled) * ratio)
    train_data = shuffled.select(range(cut))
    test_data = shuffled.select(range(cut, len(shuffled)))

    def format_dataset(dataset):
        """Turn raw items into {'text', 'label'} rows with ten shuffled options."""
        rows = []
        for item in dataset:
            candidates = item['options'].copy()

            # Ten options are required to fill the A-J slots.
            if len(candidates) < 10:
                print(f"Not enough options, got {len(candidates)}")
                continue

            # Pull the correct option out; what remains are the distractors.
            answer_pos = ord(item['answer_letter']) - ord('A')
            answer_text = candidates.pop(answer_pos)

            if len(candidates) > 9:
                # More than nine distractors available: sample nine of them.
                picked = np.random.choice(len(candidates), 9, replace=False)
                distractors = [candidates[j] for j in picked]
            else:
                # Exactly nine left: use them all.
                distractors = candidates

            # The correct answer sits at index 9 of the pool before shuffling.
            pool = distractors + [answer_text]
            order = np.random.permutation(10)
            label = np.where(order == 9)[0][0]

            prompt = "\n".join(
                f"{letter}. {pool[src]}" for letter, src in zip(letters, order)
            ).strip()

            rows.append({'text': prompt, 'label': label})

        return Dataset.from_list(rows)

    train_dataset = format_dataset(train_data)
    test_dataset = format_dataset(test_data)

    print(f"Train dataset length: {len(train_dataset)}")
    print(f"Test dataset length: {len(test_dataset)}")

    return train_dataset, test_dataset

# NOTE(review): NUM_OPTIONS is not referenced by any code visible in this file
# (the loaders hard-code their own option counts) -- confirm nothing imports
# it from elsewhere before removing.
NUM_OPTIONS = 20

def load_mmlu_binary(ratio=0.5):
    """Load the MMLU test split as a binary "is this option correct?" dataset.

    The ``test`` split of ``cais/mmlu`` is shuffled with a fixed seed and cut
    at ``ratio`` into a training part and a test part. Every question is then
    expanded into four examples, one per answer choice.

    Args:
        ratio: Fraction of the shuffled questions assigned to the training set.

    Returns:
        ``(train_dataset, test_dataset)``. Each row has ``text`` (the prompt),
        ``label`` (0 = the option is correct, 1 = it is not), ``question``,
        ``option`` and ``is_correct``.
    """
    # Shuffle the dataset with a fixed seed for reproducibility
    full_dataset = load_dataset("cais/mmlu", "all", split="test").shuffle(seed=42)

    # Split into train and test
    split_idx = int(len(full_dataset) * ratio)
    train_data = full_dataset.select(range(split_idx))
    test_data = full_dataset.select(range(split_idx, len(full_dataset)))

    def format_dataset_binary(dataset):
        """Expand every question into one True/False example per choice."""
        formatted_data = []
        num_positive = 0
        num_negative = 0

        for item in dataset:
            question = item['question']
            choices = [item['choices'][i] for i in range(4)]
            correct_answer_index = item['answer']

            for i, option in enumerate(choices):
                is_correct = (i == correct_answer_index)

                prompt = (
                    "Question: Is the following option the correct answer to this question?\n"
                    f"{question}\nOption: {option}\nA. True\nB. False"
                )

                # Label is 0 for True (correct) and 1 for False (incorrect)
                label = 0 if is_correct else 1
                if is_correct:
                    num_positive += 1
                else:
                    num_negative += 1

                formatted_data.append({
                    'text': prompt,
                    'label': label,
                    'question': question,
                    'option': option,
                    'is_correct': is_correct
                })

        total = len(formatted_data)
        # An empty split (e.g. ratio=0.0 or ratio=1.0) must not crash the
        # statistics with ZeroDivisionError.
        denominator = max(total, 1)
        print("Binary classification dataset statistics:")
        print(f"Total examples: {total}")
        print(f"Positive examples (correct answers): {num_positive} ({num_positive/denominator:.2%})")
        print(f"Negative examples (incorrect answers): {num_negative} ({num_negative/denominator:.2%})")

        return Dataset.from_list(formatted_data)

    train_dataset = format_dataset_binary(train_data)
    test_dataset = format_dataset_binary(test_data)

    print(f"Binary train dataset length: {len(train_dataset)}")
    print(f"Binary test dataset length: {len(test_dataset)}")

    return train_dataset, test_dataset

def load_hle_binary(ratio=0.5):
    """Load Humanity's Last Exam (multiple-choice rows) for binary classification.

    Rows of the ``cais/hle`` ``test`` split whose ``answer_type`` is
    ``"multipleChoice"`` are shuffled with a fixed seed and cut at ``ratio``
    into a training part and a test part. The options are parsed out of the
    question text (everything after ``"Answer Choices:"``), and each option
    becomes one True/False example.

    Args:
        ratio: Fraction of the shuffled questions assigned to the training set.

    Returns:
        ``(train_dataset, test_dataset)``. Each row has ``text`` (the prompt),
        ``label`` (0 = the option is correct, 1 = it is not), ``question``,
        ``option`` and ``is_correct``.

    Raises:
        ValueError: If a question has no ``"Answer Choices:"`` section.
    """
    dataset = load_dataset("cais/hle", split="test")

    # Only keep rows with answer_type == "multipleChoice"
    dataset = dataset.filter(lambda x: x["answer_type"] == "multipleChoice")

    # Shuffle the dataset with a fixed seed for reproducibility
    full_dataset = dataset.shuffle(seed=42)

    # Split into train and test
    split_idx = int(len(full_dataset) * ratio)
    train_data = full_dataset.select(range(split_idx))
    test_data = full_dataset.select(range(split_idx, len(full_dataset)))

    def parse_options(options_text):
        """Return parallel lists (letters, texts) for lines shaped like 'A. text'."""
        letters, texts = [], []
        for line in options_text.split('\n'):
            # The length check protects line[1]; a one-character line used to
            # raise IndexError.
            if len(line) >= 2 and line[0].isalpha() and line[1] == '.':
                letters.append(line[0].upper())
                texts.append(line[2:].strip())
        return letters, texts

    def format_dataset_binary(dataset):
        """Expand every question into one True/False example per parsed option."""
        formatted_data = []
        num_positive = 0
        num_negative = 0

        for item in dataset:
            question_text = item['question']
            answer = item['answer']

            if "Answer Choices:" not in question_text:
                # Explicit exception instead of `assert False`, which is
                # stripped when Python runs with -O.
                raise ValueError("Question does not contain answer choices")

            # Split the question to get the part with answer choices
            parts = question_text.split("Answer Choices:")
            only_question = parts[0].strip()
            option_letters, options = parse_options(parts[1].strip())

            # Locate the correct option by its parsed letter, not by alphabet
            # offset: if an option line failed to parse, an offset would point
            # at the wrong option. Multi-character answers are skipped here
            # instead of crashing in ord().
            answer_letter = str(answer).strip().upper()
            if answer_letter not in option_letters:
                print(f"Invalid answer index: {answer}")
                continue
            correct_answer_index = option_letters.index(answer_letter)

            for i, option in enumerate(options):
                is_correct = (i == correct_answer_index)

                prompt = (
                    "Question: Is the following option the correct answer to this question?\n"
                    f"{only_question}\nOption: {option}\nA. True\nB. False"
                )

                # Label is 0 for True (correct) and 1 for False (incorrect)
                label = 0 if is_correct else 1
                if is_correct:
                    num_positive += 1
                else:
                    num_negative += 1

                formatted_data.append({
                    'text': prompt,
                    'label': label,
                    'question': only_question,
                    'option': option,
                    'is_correct': is_correct
                })

        total = len(formatted_data)
        # An empty split (e.g. ratio=0.0 or ratio=1.0) must not crash the
        # statistics with ZeroDivisionError.
        denominator = max(total, 1)
        print("Binary classification dataset statistics:")
        print(f"Total examples: {total}")
        print(f"Positive examples (correct answers): {num_positive} ({num_positive/denominator:.2%})")
        print(f"Negative examples (incorrect answers): {num_negative} ({num_negative/denominator:.2%})")

        return Dataset.from_list(formatted_data)

    train_dataset = format_dataset_binary(train_data)
    test_dataset = format_dataset_binary(test_data)

    print(f"Binary train dataset length: {len(train_dataset)}")
    print(f"Binary test dataset length: {len(test_dataset)}")

    return train_dataset, test_dataset


def load_mmlu_pro_binary(ratio=0.5):
    """Load MMLU-Pro as a balanced binary "is this option correct?" dataset.

    The ``test`` split of ``TIGER-Lab/MMLU-Pro`` is shuffled with a fixed seed
    and cut at ``ratio`` into a training part and a test part. For every
    question the options are randomly permuted (global ``np.random`` state)
    and only two examples are kept: the correct option and the first
    incorrect option in the permuted order. The prompt contains the option
    only, not the question text.

    Args:
        ratio: Fraction of the shuffled questions assigned to the training set.

    Returns:
        ``(train_dataset, test_dataset)``. Each row has ``text`` (the prompt),
        ``label`` (0 = the option is correct, 1 = it is not), ``question``,
        ``option`` and ``is_correct``.
    """
    dataset = load_dataset("TIGER-Lab/MMLU-Pro")

    # Shuffle the test set with a fixed seed for reproducibility
    full_test_data = dataset["test"].shuffle(seed=42)

    # Split into train and test
    split_idx = int(len(full_test_data) * ratio)
    train_data = full_test_data.select(range(split_idx))
    test_data = full_test_data.select(range(split_idx, len(full_test_data)))

    def format_dataset_binary(dataset):
        """Emit one positive and at most one negative example per question."""
        formatted_data = []
        num_positive = 0
        num_negative = 0

        for item in dataset:
            question = item['question']
            original_options = item['options']

            # Permute the options and find where the correct one landed.
            shuffled_indices = np.random.permutation(np.arange(len(original_options)))
            options = [original_options[i] for i in shuffled_indices]
            # int() keeps numpy scalar types out of the emitted rows.
            correct_answer_index = int(
                np.where(shuffled_indices == item['answer_index'])[0][0]
            )

            kept_negative = False
            for i, option in enumerate(options):
                is_correct = (i == correct_answer_index)

                # Keep only the first incorrect option so classes stay balanced.
                if not is_correct:
                    if kept_negative:
                        continue
                    kept_negative = True

                # Option-only prompt: the question is deliberately omitted.
                prompt = f"Question: Is the following option the correct answer?\nOption: {option}\nA. True\nB. False"

                # Label is 0 for True (correct) and 1 for False (incorrect)
                label = 0 if is_correct else 1
                if is_correct:
                    num_positive += 1
                else:
                    num_negative += 1

                formatted_data.append({
                    'text': prompt,
                    'label': label,
                    'question': question,
                    'option': option,
                    'is_correct': is_correct
                })

        total = len(formatted_data)
        # An empty split (e.g. ratio=0.0 or ratio=1.0) must not crash the
        # statistics with ZeroDivisionError.
        denominator = max(total, 1)
        print("Binary classification dataset statistics:")
        print(f"Total examples: {total}")
        print(f"Positive examples (correct answers): {num_positive} ({num_positive/denominator:.2%})")
        print(f"Negative examples (incorrect answers): {num_negative} ({num_negative/denominator:.2%})")

        return Dataset.from_list(formatted_data)

    train_dataset = format_dataset_binary(train_data)
    test_dataset = format_dataset_binary(test_data)

    print(f"Binary train dataset length: {len(train_dataset)}")
    print(f"Binary test dataset length: {len(test_dataset)}")

    return train_dataset, test_dataset


def load_super_gqpa_binary(ratio=0.5):
    """Load SuperGPQA as a balanced binary "is this option correct?" dataset.

    The ``train`` split of ``m-a-p/SuperGPQA`` is shuffled with a fixed seed
    and cut at ``ratio`` into a training part and a test part. For every
    question the options are randomly permuted (global ``np.random`` state)
    and only two examples are kept: the correct option and the first
    incorrect option in the permuted order. The prompt contains the option
    only, not the question text.

    Args:
        ratio: Fraction of the shuffled questions assigned to the training set.

    Returns:
        ``(train_dataset, test_dataset)``. Each row has ``text`` (the prompt),
        ``label`` (0 = the option is correct, 1 = it is not), ``question``,
        ``option`` and ``is_correct``.
    """
    dataset = load_dataset("m-a-p/SuperGPQA")

    # Shuffle with a fixed seed for reproducibility
    full_test_data = dataset["train"].shuffle(seed=42)

    # Split into train and test
    split_idx = int(len(full_test_data) * ratio)
    train_data = full_test_data.select(range(split_idx))
    test_data = full_test_data.select(range(split_idx, len(full_test_data)))

    def format_dataset_binary(dataset):
        """Emit one positive and at most one negative example per question."""
        formatted_data = []
        num_positive = 0
        num_negative = 0

        for item in dataset:
            question = item['question']
            original_options = item['options']

            # The answer is given as a letter; convert it to a 0-based index.
            original_answer_index = ord(item['answer_letter']) - ord('A')

            # Permute the options and find where the correct one landed.
            shuffled_indices = np.random.permutation(np.arange(len(original_options)))
            options = [original_options[i] for i in shuffled_indices]
            # int() keeps numpy scalar types out of the emitted rows.
            correct_answer_index = int(
                np.where(shuffled_indices == original_answer_index)[0][0]
            )

            kept_negative = False
            for i, option in enumerate(options):
                is_correct = (i == correct_answer_index)

                # Keep only the first incorrect option so classes stay balanced.
                if not is_correct:
                    if kept_negative:
                        continue
                    kept_negative = True

                # Option-only prompt: the question is deliberately omitted.
                prompt = f"Question: Is the following option the correct answer?\nOption: {option}\nA. True\nB. False"

                # Label is 0 for True (correct) and 1 for False (incorrect)
                label = 0 if is_correct else 1
                if is_correct:
                    num_positive += 1
                else:
                    num_negative += 1

                formatted_data.append({
                    'text': prompt,
                    'label': label,
                    'question': question,
                    'option': option,
                    'is_correct': is_correct
                })

        total = len(formatted_data)
        # An empty split (e.g. ratio=0.0 or ratio=1.0) must not crash the
        # statistics with ZeroDivisionError.
        denominator = max(total, 1)
        print("Binary classification dataset statistics:")
        print(f"Total examples: {total}")
        print(f"Positive examples (correct answers): {num_positive} ({num_positive/denominator:.2%})")
        print(f"Negative examples (incorrect answers): {num_negative} ({num_negative/denominator:.2%})")

        return Dataset.from_list(formatted_data)

    train_dataset = format_dataset_binary(train_data)
    test_dataset = format_dataset_binary(test_data)

    print(f"Binary train dataset length: {len(train_dataset)}")
    print(f"Binary test dataset length: {len(test_dataset)}")

    return train_dataset, test_dataset


def compute_metrics(pred):
    """Compute evaluation metrics from a Trainer prediction object.

    Args:
        pred: Object exposing ``label_ids`` (true labels) and ``predictions``
            (per-class logits, one row per example).

    Returns:
        Dict with ``accuracy``, macro-averaged ``f1`` / ``precision`` /
        ``recall``, binary ``auc`` (0.0 when it cannot be computed), and
        per-class precision/recall for classes 0 and 1 (0 when a class is
        absent from the report).
    """
    labels = pred.label_ids
    logits = pred.predictions
    preds = logits.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='macro')
    acc = accuracy_score(labels, preds)

    # roc_auc_score treats the greater label (1) as the positive class, so it
    # must be scored with P(label == 1). Passing column 0 instead reports
    # 1 - AUC. The float32 cast avoids half-precision softmax on CPU.
    try:
        probs = torch.softmax(
            torch.as_tensor(logits, dtype=torch.float32), dim=1
        )[:, 1].numpy()
        auc = roc_auc_score(labels, probs)
    except (ValueError, IndexError) as exc:
        # E.g. only one class present in `labels`, or fewer than two logit
        # columns. Keep the metric key but make the fallback visible.
        print(f"AUC could not be computed ({exc}); reporting 0.0")
        auc = 0.0

    # Per-class metrics; the report is keyed by the label rendered as a string.
    report = classification_report(labels, preds, output_dict=True)
    class_0 = report.get('0', {})
    class_1 = report.get('1', {})

    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'auc': auc,
        'class_0_precision': class_0.get('precision', 0),
        'class_0_recall': class_0.get('recall', 0),
        'class_1_precision': class_1.get('precision', 0),
        'class_1_recall': class_1.get('recall', 0)
    }

def train_model(
    train_dataset,
    eval_dataset,
    model_name: str,
    output_dir: str,
    class_weights: torch.Tensor = None
) -> Dict[str, float]:
    """Fine-tune a 2-label sequence classifier and return its final eval metrics.

    Steps, in order: load tokenizer and model, make sure a pad token exists,
    enable gradient checkpointing when available, re-initialise the
    classification head, tokenize both datasets (padded/truncated to 256
    tokens), build an AdamW optimizer, train with the Hugging Face Trainer and
    finally evaluate on ``eval_dataset``.

    Args:
        train_dataset: Dataset with at least ``text`` and ``label`` columns.
        eval_dataset: Dataset with the same columns, used for the periodic and
            the final evaluation.
        model_name: Model name or path passed to ``from_pretrained``.
        output_dir: Passed to ``TrainingArguments`` as ``output_dir``.
        class_weights: Optional per-class weights; when given, training uses a
            weighted cross-entropy loss instead of the model's own loss.

    Returns:
        The metrics dict returned by ``trainer.evaluate()``.
    """

    # Initialize tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=2,  # For binary classification (True/False)
        problem_type="single_label_classification",
    )

    # Add padding token if it doesn't exist
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # For Llama models, ensure pad_token_id is properly set
    if model.config.pad_token_id is None:
        model.config.pad_token_id = tokenizer.pad_token_id

    # Enable gradient checkpointing before the model is wrapped
    if hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()

    # Re-initialise the classification head with small random weights. The
    # head lives under `classifier` or `score`, depending on the model class.
    if hasattr(model, 'classifier'):
        model.classifier.weight.data.normal_(mean=0.0, std=0.02)
        model.classifier.bias.data.zero_()
    elif hasattr(model, 'score'):
        # Only the weight is reset in this branch; no bias is touched.
        model.score.weight.data.normal_(mean=0.0, std=0.02)


    # Check which attribute exists for the classification head
    # NOTE(review): classification_head_attr is only consumed by the
    # commented-out freezing code below; it is unused while that stays disabled.
    classification_head_attr = None
    if hasattr(model, 'classifier'):
        classification_head_attr = 'classifier'
    elif hasattr(model, 'score'):
        classification_head_attr = 'score'

    # Disabled alternative: train the classification head only.
    # # Freeze all layers except the classification head
    # for param in model.parameters():
    #     param.requires_grad = False

    # # Unfreeze only the classification head
    # if classification_head_attr:
    #     for param in getattr(model, classification_head_attr).parameters():
    #         param.requires_grad = True
    # else:
    #     # If no classification head is found, unfreeze the last layer
    #     print("Warning: No classifier or score attribute found. Unfreezing all parameters.")
    #     for param in model.parameters():
    #         param.requires_grad = True

    # Print number of trainable parameters
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {trainable_params}")
    print(f"Total parameters: {total_params}")

    def tokenize_function(examples):
        """Tokenize a batch of prompts, padded/truncated to exactly 256 tokens."""
        return tokenizer(
            examples['text'],
            padding="max_length",
            truncation=True,
            max_length=256
        )

    # Tokenize datasets. Only the raw `text` column is dropped here; any other
    # columns the datasets carry are left in place.
    tokenized_train = train_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=['text']
    )
    tokenized_eval = eval_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=['text']
    )

    # Define training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=5e-6,
        per_device_train_batch_size=64,
        per_device_eval_batch_size=64,
        gradient_accumulation_steps=1,
        num_train_epochs=100,  # NOTE(review): an earlier comment said "reduced from 100", but the value is 100
        weight_decay=0.1,
        eval_strategy="steps",
        eval_steps=100,
        metric_for_best_model="f1",  # NOTE(review): presumably has no effect while save_strategy="no" -- confirm
        save_total_limit=0,
        run_name="binary-classifier",
        warmup_ratio=0.1,
        logging_steps=10,
        report_to="wandb" if accelerator.is_main_process else "none",  # only the main process reports to wandb
        save_strategy="no",  # no checkpoints are written
        fp16=True,
        lr_scheduler_type="constant_with_warmup",
    )

    # Initialize optimizer with different parameters based on model architecture
    if "llama" in model_name.lower() or "qwen" in model_name.lower():
        # Llama/Qwen: parameters whose name contains "score" (the
        # classification head) get a 10x higher learning rate than the rest.
        optimizer = torch.optim.AdamW(
            [
                {"params": [p for n, p in model.named_parameters() if "score" not in n], "lr": 5e-6},
                {"params": [p for n, p in model.named_parameters() if "score" in n], "lr": 5e-5}
            ],
            weight_decay=0.1
        )
    else:
        # All other models: one learning rate for every parameter.
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=5e-6,
            weight_decay=0.1
        )

    # Create loss function with class weights if provided
    if class_weights is not None:
        # Don't move to device here - it's too early, model device might change later
        # Don't create the loss function yet either, we'll do it in compute_loss

        # Custom trainer with weighted loss
        class WeightedLossTrainer(Trainer):
            """Trainer that replaces the model loss with class-weighted cross-entropy."""

            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                # Captured from the enclosing train_model() call.
                self.class_weights = class_weights

            def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
                """Return the weighted cross-entropy loss (and outputs if requested)."""
                # Labels are removed from `inputs`, so the forward pass below
                # receives no labels.
                labels = inputs.pop("labels")
                outputs = model(**inputs)
                logits = outputs.logits

                # Move class weights to the same device as logits AND convert to the same dtype
                weights = self.class_weights.to(device=logits.device, dtype=logits.dtype)

                # Create the loss function with the properly converted weights
                device_loss_fct = torch.nn.CrossEntropyLoss(weight=weights)

                loss = device_loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
                return (loss, outputs) if return_outputs else loss

        trainer_class = WeightedLossTrainer
    else:
        trainer_class = Trainer

    # Initialize trainer
    trainer = trainer_class(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_eval,
        compute_metrics=compute_metrics,
        optimizers=(optimizer, None)  # Custom optimizer, default scheduler
    )

    # Train the model
    trainer.train()

    # Final evaluation on eval_dataset (the caller decides which split that is)
    eval_results = trainer.evaluate()

    return eval_results

def main():
    """Train and evaluate the binary option classifier on SuperGPQA.

    Builds the train/test datasets, derives inverse-frequency class weights
    from the training labels (only when both classes occur), trains via
    ``train_model`` and, on the main process, logs the final metrics to wandb.

    Returns:
        The metrics dict returned by ``train_model``.
    """
    RATIO = 0.5

    # Dataset selection -- the other binary loaders are drop-in alternatives:
    # train_dataset, test_dataset = load_hle_binary(ratio=RATIO)
    # train_dataset, test_dataset = load_mmlu_pro_binary(ratio=RATIO)
    # train_dataset, test_dataset = load_mmlu_binary(ratio=RATIO)
    train_dataset, test_dataset = load_super_gqpa_binary(ratio=RATIO)

    # Tally how often each label occurs in the training set.
    class_counts = {}
    for example in train_dataset:
        label = example['label']
        class_counts[label] = class_counts.get(label, 0) + 1

    # weight_c = N / (num_classes * count_c); skipped unless both classes exist.
    class_weights = None
    if all(cls in class_counts for cls in (0, 1)):
        total_samples = len(train_dataset)
        num_classes = 2
        weights = [
            total_samples / (num_classes * class_counts[cls]) for cls in (0, 1)
        ]
        class_weights = torch.tensor(weights)
        print(f"Using class weights: {weights}")

    on_main_process = accelerator.is_main_process

    # Initialize wandb only on the main process
    if on_main_process:
        wandb.init(project="binary-classifier", name=f"supergpqa-binary-noQ-qwen3-4b-{RATIO*100}percent")

    # Model selection
    # model_name = "microsoft/deberta-v3-large"
    model_name = "/fast/XXXX-3/models/Qwen3-4B"

    if on_main_process:
        print("Training binary classifier model...")

    # Train on the train split with class weights; evaluate on the test split.
    test_results = train_model(
        train_dataset,
        test_dataset,
        model_name,
        output_dir="./results/binary_classifier",
        class_weights=class_weights
    )

    if not on_main_process:
        return test_results

    # The first four metrics are required; the rest default to 0 when absent.
    wandb.log({
        "test_accuracy": test_results['eval_accuracy'],
        "test_f1": test_results['eval_f1'],
        "test_precision": test_results['eval_precision'],
        "test_recall": test_results['eval_recall'],
        "test_auc": test_results.get('eval_auc', 0),
        "test_class_0_precision": test_results.get('eval_class_0_precision', 0),
        "test_class_0_recall": test_results.get('eval_class_0_recall', 0),
        "test_class_1_precision": test_results.get('eval_class_1_precision', 0),
        "test_class_1_recall": test_results.get('eval_class_1_recall', 0)
    })

    print("\nTest performance:")
    print(test_results)

    wandb.finish()

    return test_results

if __name__ == "__main__":
    # Ask cuDNN for deterministic kernels and turn off its autotuner.
    # NOTE: the np.random / torch seed calls near the top of this file are
    # commented out, so these flags alone do not make runs repeatable.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    # main() returns the evaluation metrics dict. It is printed here without a
    # main-process check, in addition to the main-process print inside main().
    test_results = main()
    print("\nTest Performance:")
    print(test_results)