import pandas as pd
import numpy as np
from datasets import Dataset, concatenate_datasets, load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report
import torch
import json
import wandb
from typing import Dict, List
from accelerate import Accelerator

# np.random.seed(42)
# torch.manual_seed(42)

def load_mmlu_data2(split="train"):
    """Load MMLU as 4-way classification examples that show only the answer options.

    This is the options-only variant of ``load_mmlu_data``: the question text
    is intentionally left out of the prompt. The four choices keep their
    original order (they are not shuffled), so ``label`` is MMLU's own
    ``answer`` index.

    Args:
        split: ``"train"`` (MMLU's ``auxiliary_train`` split), ``"validation"``
            or ``"test"``.

    Returns:
        A ``Dataset`` with a ``text`` column (the four options prefixed with
        ``A.``-``D.``, one per line) and a ``label`` column holding the
        answer index (0-3 for A-D).

    Raises:
        ValueError: If ``split`` is not one of the names above.
    """
    hf_splits = {"train": "auxiliary_train", "validation": "validation", "test": "test"}
    if split not in hf_splits:
        # Fail fast with a clear message instead of using an unbound `dataset`.
        raise ValueError(f"Unknown split {split!r}; expected one of {sorted(hf_splits)}")

    # Fixed-seed shuffle so the example order is reproducible across runs.
    dataset = load_dataset("cais/mmlu", "all", split=hf_splits[split]).shuffle(seed=42)

    formatted_data = []
    for item in dataset:
        choices = item['choices']
        # Options only (no question), in their original order, one per line.
        prompt = "\n".join(f"{letter}. {choices[i]}" for i, letter in enumerate("ABCD"))
        formatted_data.append({
            'text': prompt,
            'label': item['answer'],  # index of the correct option: 0-3 for A-D
        })

    return Dataset.from_list(formatted_data)


def load_mmlu_data(split="train"):
    """Load MMLU as 4-way (A/B/C/D) sequence-classification examples.

    Each example's ``text`` is the question on its first line followed by
    the four options prefixed ``A.``-``D.``, one per line. The options keep
    their original order, so ``label`` is MMLU's own ``answer`` index.

    Args:
        split: ``"train"`` (MMLU's ``auxiliary_train`` split), ``"validation"``
            or ``"test"``.

    Returns:
        A ``Dataset`` with ``text`` and ``label`` columns (label 0-3 for A-D).

    Raises:
        ValueError: If ``split`` is not one of the names above.
    """
    hf_splits = {"train": "auxiliary_train", "validation": "validation", "test": "test"}
    if split not in hf_splits:
        # Fail fast with a clear message instead of using an unbound `dataset`.
        raise ValueError(f"Unknown split {split!r}; expected one of {sorted(hf_splits)}")

    # Fixed-seed shuffle so the example order is reproducible across runs.
    dataset = load_dataset("cais/mmlu", "all", split=hf_splits[split]).shuffle(seed=42)

    formatted_data = []
    for item in dataset:
        choices = item['choices']
        option_lines = [f"{letter}. {choices[i]}" for i, letter in enumerate("ABCD")]
        prompt = "\n".join([item['question'], *option_lines])
        formatted_data.append({
            'text': prompt,
            'label': item['answer'],  # index of the correct option: 0-3 for A-D
        })

    return Dataset.from_list(formatted_data)

def compute_metrics(pred):
    """Compute accuracy plus macro-averaged F1, precision and recall.

    Intended as a Hugging Face ``Trainer`` ``compute_metrics`` callback.

    Args:
        pred: Evaluation output exposing ``predictions`` (per-class scores,
            class axis last) and ``label_ids`` (gold class indices).

    Returns:
        Dict with keys ``accuracy``, ``f1``, ``precision`` and ``recall``.
    """
    y_true = pred.label_ids
    # The predicted class is the highest-scoring one along the last axis.
    y_pred = pred.predictions.argmax(-1)

    macro_precision, macro_recall, macro_f1, _support = precision_recall_fscore_support(
        y_true, y_pred, average='macro'
    )
    accuracy = accuracy_score(y_true, y_pred)

    return dict(
        accuracy=accuracy,
        f1=macro_f1,
        precision=macro_precision,
        recall=macro_recall,
    )

def _balanced_class_weights(labels, num_labels: int = 4) -> torch.Tensor:
    """Return inverse-frequency class weights as a float32 tensor.

    weight[c] = total / (n_classes * count[c]). Classes that never occur in
    ``labels`` get weight 0 instead of triggering a division by zero.
    """
    counts = np.bincount(np.asarray(labels, dtype=np.int64), minlength=num_labels)
    weights = np.zeros(len(counts), dtype=np.float64)
    present = counts > 0
    weights[present] = counts.sum() / (len(counts) * counts[present])
    return torch.tensor(weights, dtype=torch.float32)


def _backbone_and_head_params(model):
    """Split ``model``'s parameters into (backbone, head) lists by name.

    Parameters whose name starts with ``model.base_model_prefix`` (e.g.
    ``deberta.``) are the pretrained backbone; every other parameter (the
    classifier and any pooling layer in front of it) is the head. This
    covers every parameter and does not hard-code an architecture attribute.
    """
    prefix = f"{model.base_model_prefix}."
    backbone, head = [], []
    for name, param in model.named_parameters():
        (backbone if name.startswith(prefix) else head).append(param)
    return backbone, head


def train_model(
    train_dataset,
    eval_dataset,
    model_name: str,
    output_dir: str,
    use_class_weights: bool = False,
    backbone_lr: float = 5e-6,
    head_lr: float = 5e-6,
) -> Dict[str, float]:
    """Fine-tune ``model_name`` as a 4-way (A-D) classifier and evaluate it.

    Args:
        train_dataset: Dataset with ``text`` and ``label`` columns to train on.
        eval_dataset: Dataset with the same columns; evaluated every 100 steps
            during training and once more at the end.
        model_name: Hugging Face model id or local path to fine-tune.
        output_dir: Output directory passed to ``TrainingArguments``.
        use_class_weights: If True, train with cross-entropy weighted by the
            inverse class frequency of ``train_dataset``; otherwise use the
            model's standard unweighted loss.
        backbone_lr: AdamW learning rate for the pretrained backbone.
        head_lr: AdamW learning rate for all remaining (head) parameters.

    Returns:
        The metrics dict from ``Trainer.evaluate()`` on ``eval_dataset``
        (keys prefixed with ``eval_``).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # NOTE(review): any newly initialised head weights are drawn from the
    # global RNG here; seed it beforehand if runs must be reproducible.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=4,  # For A, B, C, D options
        problem_type="single_label_classification",
    )

    def tokenize_function(examples):
        # Fixed-length padding: every example becomes exactly 512 tokens.
        return tokenizer(
            examples['text'],
            padding="max_length",
            truncation=True,
            max_length=512,
        )

    tokenized_train = train_dataset.map(tokenize_function, batched=True)
    tokenized_eval = eval_dataset.map(tokenize_function, batched=True)

    # learning_rate and weight_decay here are NOT used for optimisation:
    # a custom optimizer is passed to the Trainer below. The scheduler
    # settings (warmup_ratio, lr_scheduler_type) drive the default scheduler
    # that the Trainer builds around that optimizer.
    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=2e-5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=8,
        num_train_epochs=200,
        weight_decay=0.1,
        eval_strategy="steps",
        eval_steps=100,
        metric_for_best_model="accuracy",
        save_total_limit=1,
        run_name="mmlu-mcq",
        warmup_ratio=0.1,
        logging_steps=10,
        report_to="wandb",
        save_strategy="no",
        lr_scheduler_type="linear",
    )

    class WeightedTrainer(Trainer):
        """Trainer whose loss is cross-entropy with per-class weights."""

        def __init__(self, class_weights, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.class_weights = class_weights.to(self.args.device)

        def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
            labels = inputs.pop("labels")
            outputs = model(**inputs)
            logits = outputs.logits
            loss_fct = torch.nn.CrossEntropyLoss(weight=self.class_weights)
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
            return (loss, outputs) if return_outputs else loss

    # Select parameters by name rather than via `model.deberta` /
    # `model.classifier`, so that every parameter (including any pooler
    # sitting between backbone and classifier) is optimised, for any model.
    backbone_params, head_params = _backbone_and_head_params(model)
    param_groups = [
        {"params": params, "lr": lr}
        for params, lr in ((backbone_params, backbone_lr), (head_params, head_lr))
        if params
    ]
    optimizer = torch.optim.AdamW(param_groups, weight_decay=0.01)

    trainer_kwargs = dict(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_eval,
        compute_metrics=compute_metrics,
        optimizers=(optimizer, None),  # custom optimizer, default scheduler
    )
    if use_class_weights:
        class_weights = _balanced_class_weights(
            train_dataset['label'], model.config.num_labels
        )
        trainer = WeightedTrainer(class_weights, **trainer_kwargs)
    else:
        trainer = Trainer(**trainer_kwargs)

    trainer.train()

    # Final evaluation on eval_dataset.
    return trainer.evaluate()

def main():
    """Fine-tune on 40% of MMLU's test split and evaluate on the other 60%.

    MMLU's test split is divided with a fixed seed (42) into a 40% training
    portion and a 60% held-out portion. The model is trained on the former
    and evaluated on the latter. Metrics are logged to Weights & Biases and
    printed on the main process only.

    Returns:
        The metrics dict returned by ``train_model`` (keys prefixed ``eval_``).
    """
    accelerator = Accelerator()

    # Only the test split is needed: it is divided 40/60 into a training
    # portion and a held-out evaluation portion.
    mmlu_test = load_mmlu_data(split="test")
    splits = mmlu_test.train_test_split(test_size=0.6, seed=42)
    train_dataset = splits["train"]
    test_dataset = splits["test"]

    # Initialize wandb only on the main process.
    if accelerator.is_main_process:
        wandb.init(project="mcq-classifier", name="mmlu-try2-60percent")

    # Alternative backbone: "google/flan-t5-xl"
    model_name = "microsoft/deberta-v3-large"

    if accelerator.is_main_process:
        print("Training model on training data...")

    # Train on the 40% portion, evaluate on the held-out 60% portion.
    test_results = train_model(
        train_dataset,
        test_dataset,
        model_name,
        output_dir="./results/mmlu/validation",
    )

    if accelerator.is_main_process:
        wandb.log({
            "test_accuracy": test_results['eval_accuracy'],
            "test_f1": test_results['eval_f1'],
            "test_precision": test_results['eval_precision'],
            "test_recall": test_results['eval_recall']
        })

        print("\nTest performance:")
        print(test_results)

        wandb.finish()

    return test_results

if __name__ == "__main__":
    # cuDNN: disable the kernel autotuner and require deterministic
    # algorithms, so GPU kernels do not vary from run to run.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    results = main()
    print("\nTest Performance:")
    print(results)