import os
# Disable tokenizers parallelism which can cause deadlocks
# os.environ["TOKENIZERS_PARALLELISM"] = "false"

# os.environ["HF_HOME"] = "./cache/huggingface"

# os.environ["HF_DATASETS_CACHE"] = "/tmp/huggingface_datasets"
# os.environ["TRITON_CACHE_DIR"] = "/tmp/triton_cache"  # Use a local directory instead of NFS

from tqdm import tqdm
import pandas as pd
import numpy as np
from datasets import Dataset, concatenate_datasets, load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
    DataCollatorWithPadding
)
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report
import torch
import json
import wandb
from typing import Dict, List
from accelerate import Accelerator
from peft import (
    LoraConfig,
    TaskType,
    get_peft_model,
)

# Initialize accelerator. main() and train_model() query
# accelerator.is_main_process so that wandb logging and progress printing
# happen on a single process only.
accelerator = Accelerator()

# Fix seeds for reproducibility. The dataset loaders below draw their option
# sampling and shuffling from the global np.random state, so this seed also
# fixes which distractors are kept and where the correct answer lands.
torch.manual_seed(42)
np.random.seed(42)

# Number of classification labels (answer slots) of the model. Loaders keep at
# most this many options per question.
NUM_OPTIONS = 10

def load_mmlu_pro_data():
    """Load MMLU-Pro as an option-only multiple-choice classification task.

    Only the ``test`` split of ``TIGER-Lab/MMLU-Pro`` is used. It is shuffled
    with a fixed seed and split 80/20 into train/test. Each question keeps its
    correct answer plus at most ``NUM_OPTIONS - 1`` of its distractors, and the
    options are shuffled with the global ``np.random`` state.

    The prompt does not include the question text; it consists only of the
    lettered options ("A. ...", "B. ...", one per line).

    Returns:
        ``(train_dataset, test_dataset)``: ``datasets.Dataset`` objects with a
        ``text`` column (the prompt) and an integer ``label`` column (index of
        the correct option, 0 = "A").
    """
    dataset = load_dataset("TIGER-Lab/MMLU-Pro")

    # Shuffle with a fixed seed so the partition is reproducible, then split 80/20.
    all_data = dataset["test"].shuffle(seed=42)
    split_idx = int(len(all_data) * 0.8)
    train_data = all_data.select(range(split_idx))
    test_data = all_data.select(range(split_idx, len(all_data)))

    # Histogram of correct-answer positions, accumulated over train AND test.
    correct_indices = {}

    def format_dataset(split_data):
        """Turn raw MMLU-Pro rows into shuffled, option-only prompts."""
        formatted_data = []
        for item in split_data:
            options = list(item['options'])
            correct_answer = options.pop(item['answer_index'])

            # Keep at most NUM_OPTIONS - 1 distractors (randomly chosen when
            # the question has more than that).
            if len(options) > NUM_OPTIONS - 1:
                picked = np.random.choice(len(options), NUM_OPTIONS - 1, replace=False)
                distractors = [options[i] for i in picked]
            else:
                distractors = options

            # The correct answer sits at the last slot before permuting; its new
            # position is wherever that last index ends up.
            final_options = distractors + [correct_answer]
            last_index = len(final_options) - 1
            shuffled_indices = np.random.permutation(len(final_options))
            shuffled_options = [final_options[i] for i in shuffled_indices]
            new_correct_index = int(np.flatnonzero(shuffled_indices == last_index)[0])

            correct_indices[new_correct_index] = correct_indices.get(new_correct_index, 0) + 1

            prompt = "\n".join(
                f"{chr(65 + i)}. {option}" for i, option in enumerate(shuffled_options)
            ).strip()

            formatted_data.append({
                'text': prompt,
                'label': new_correct_index
            })

        return Dataset.from_list(formatted_data)

    train_dataset = format_dataset(train_data)
    test_dataset = format_dataset(test_data)

    print(f"Train dataset length: {len(train_dataset)}")
    print(f"Test dataset length: {len(test_dataset)}")
    print(f"Correct indices distribution: {correct_indices}")

    return train_dataset, test_dataset


def load_super_gqpa_data():
    """Load SuperGPQA for multiple-choice classification with at most NUM_OPTIONS options.

    The ``train`` split of ``m-a-p/SuperGPQA`` is shuffled with a fixed seed and
    split 50/50 into train/test. Each question keeps its correct answer plus at
    most ``NUM_OPTIONS - 1`` of its distractors, and the options are shuffled
    with the global ``np.random`` state. The prompt is the question followed by
    the lettered options ("A. ...", "B. ...", one per line).

    Returns:
        ``(train_dataset, test_dataset)``: ``datasets.Dataset`` objects with a
        ``text`` column and an integer ``label`` column (0 = "A").
    """
    dataset = load_dataset("m-a-p/SuperGPQA")

    # Shuffle with a fixed seed so the partition is reproducible, then split 50/50.
    all_data = dataset["train"].shuffle(seed=42)
    split_idx = len(all_data) // 2
    train_data = all_data.select(range(split_idx))
    test_data = all_data.select(range(split_idx, len(all_data)))

    # Histogram of correct-answer positions, accumulated over train AND test.
    correct_indices = {}

    def format_dataset(split_data):
        """Turn raw SuperGPQA rows into shuffled question + options prompts."""
        formatted_data = []
        for item in split_data:
            question = item['question']
            options = list(item['options'])

            # SuperGPQA stores the answer as a letter; convert it to a list index.
            correct_answer_index = ord(item['answer_letter']) - ord('A')
            correct_answer = options.pop(correct_answer_index)

            # Keep at most NUM_OPTIONS - 1 distractors (randomly chosen when
            # the question has more than that).
            if len(options) > NUM_OPTIONS - 1:
                picked = np.random.choice(len(options), NUM_OPTIONS - 1, replace=False)
                distractors = [options[i] for i in picked]
            else:
                distractors = options

            # The correct answer sits at the last slot before permuting; its new
            # position is wherever that last index ends up.
            final_options = distractors + [correct_answer]
            last_index = len(final_options) - 1
            shuffled_indices = np.random.permutation(len(final_options))
            shuffled_options = [final_options[i] for i in shuffled_indices]
            new_correct_index = int(np.flatnonzero(shuffled_indices == last_index)[0])

            correct_indices[new_correct_index] = correct_indices.get(new_correct_index, 0) + 1

            prompt_lines = [question] + [
                f"{chr(65 + i)}. {option}" for i, option in enumerate(shuffled_options)
            ]
            prompt = "\n".join(prompt_lines).strip()

            formatted_data.append({
                'text': prompt,
                'label': new_correct_index
            })

        return Dataset.from_list(formatted_data)

    train_dataset = format_dataset(train_data)
    test_dataset = format_dataset(test_data)

    print(f"Train dataset length SUPER_GPQA: {len(train_dataset)}")
    print(f"Test dataset length SUPER_GPQA: {len(test_dataset)}")
    print(f"Correct indices distribution SUPER_GPQA: {correct_indices}")

    return train_dataset, test_dataset


def load_mmlu_data(split="train"):
    """Load MMLU (``cais/mmlu``, config ``all``) for multiple-choice classification.

    For ``"validation"``/``"test"`` each prompt is the question followed by its
    four original options (A-D) in their original order, and the label is
    MMLU's ``answer`` index.

    For ``"train"`` (MMLU's ``auxiliary_train`` split) each question is expanded
    to up to ``NUM_OPTIONS`` options: its four choices plus ``NUM_OPTIONS - 4``
    extra distractors sampled from a pool of at most 10,000 distinct wrong
    answers collected from the training questions. All options are then
    shuffled with the global ``np.random`` state.

    Args:
        split: ``"train"``, ``"validation"`` or ``"test"``.

    Returns:
        ``datasets.Dataset`` with a ``text`` column and an integer ``label``
        column (index of the correct option, 0 = "A").

    Raises:
        ValueError: if ``split`` is not one of the names above.
    """
    hf_splits = {"train": "auxiliary_train", "validation": "validation", "test": "test"}
    if split not in hf_splits:
        raise ValueError(f"Unknown MMLU split {split!r}; expected one of {list(hf_splits)}")
    dataset = load_dataset("cais/mmlu", "all", split=hf_splits[split])

    correct_indices = {}
    formatted_data = []
    num_extra = max(NUM_OPTIONS - 4, 0)

    if split == "train":
        # Pool of wrong answers taken from every training question.
        incorrect_options = []
        for item in dataset:
            incorrect_options.extend(
                choice for i, choice in enumerate(item['choices']) if i != item['answer']
            )
        # Deduplicate in first-seen order, then shuffle with the seeded global RNG.
        # set() is avoided here: its iteration order for strings depends on the
        # per-process hash seed, which would make the pool irreproducible.
        incorrect_options = list(dict.fromkeys(incorrect_options))
        np.random.shuffle(incorrect_options)
        incorrect_options = incorrect_options[:10000]

    for item in tqdm(dataset):
        question = item['question']
        choices = [item['choices'][i] for i in range(4)]
        correct_answer_idx = item['answer']

        if split != "train":
            prompt = f"{question}\n"
            for i, choice in enumerate(choices):
                prompt += f"{chr(65 + i)}. {choice}\n"
            prompt = prompt.strip()  # Remove trailing newline

            formatted_data.append({
                'text': prompt,
                'label': correct_answer_idx  # 0, 1, 2 or 3 for A, B, C, D
            })
            continue

        # Separate the correct answer BEFORE any shuffling so it is really the
        # one MMLU marks as correct.
        correct_answer = choices[correct_answer_idx]
        distractors = [c for i, c in enumerate(choices) if i != correct_answer_idx]

        # Draw extra distractors, rejecting texts that already appear among this
        # question's own choices (otherwise the correct answer could show up
        # twice). The pool is deduplicated, so at most len(own) draws can be
        # rejected; oversampling by that many still leaves num_extra candidates
        # whenever the pool is large enough.
        extra = []
        if num_extra:
            own = set(choices)
            n_draw = min(len(incorrect_options), num_extra + len(own))
            drawn = np.random.choice(len(incorrect_options), size=n_draw, replace=False)
            extra = [incorrect_options[j] for j in drawn if incorrect_options[j] not in own][:num_extra]

        # The correct answer sits at the last slot before permuting; its new
        # position is wherever that last index ends up.
        final_options = distractors + extra + [correct_answer]
        last_index = len(final_options) - 1
        shuffled_indices = np.random.permutation(len(final_options))
        shuffled_options = [final_options[i] for i in shuffled_indices]
        new_correct_index = int(np.flatnonzero(shuffled_indices == last_index)[0])

        correct_indices[new_correct_index] = correct_indices.get(new_correct_index, 0) + 1

        prompt = f"{question}\n"
        for i, option in enumerate(shuffled_options):
            prompt += f"{chr(65 + i)}. {option}\n"
        prompt = prompt.strip()  # Remove trailing newline

        formatted_data.append({
            'text': prompt,
            'label': new_correct_index
        })

    print(f"Correct indices distribution MMLU {split}: {correct_indices}")

    return Dataset.from_list(formatted_data)


def load_gqpa_data():
    """Load GPQA (``gpqa_main``) as a 4-way multiple-choice classification task.

    The ``train`` split is shuffled with a fixed seed. For each question, the
    three incorrect answers and the correct answer are shuffled with the global
    ``np.random`` state. The prompt is the question followed by the options
    lettered A-D.

    Returns:
        ``datasets.Dataset`` with a ``text`` column and an integer ``label``
        column (0-3 for A-D).
    """
    dataset = load_dataset("Idavidrein/gpqa", "gpqa_main")["train"]

    # Shuffle the dataset with a fixed seed for reproducibility
    dataset = dataset.shuffle(seed=42)

    formatted_data = []
    distribution = {0: 0, 1: 0, 2: 0, 3: 0}
    for item in dataset:
        question = item['Question']
        # The correct answer is placed in the last slot before shuffling.
        raw_options = [
            item['Incorrect Answer 1'],
            item['Incorrect Answer 2'],
            item['Incorrect Answer 3'],
            item['Correct Answer'],
        ]
        correct_slot = len(raw_options) - 1

        # Shuffle positions rather than the texts, so the label is tracked by
        # position instead of being recovered by string lookup (which picks the
        # wrong slot if an incorrect answer has the same text as the correct one).
        # Shuffling a same-length list consumes the same RNG draws.
        order = list(range(len(raw_options)))
        np.random.shuffle(order)
        options = [raw_options[i] for i in order]
        answer_index = order.index(correct_slot)
        distribution[answer_index] += 1

        prompt = f"{question}\n"
        for choice_letter, option in zip("ABCD", options):
            prompt += f"{choice_letter}. {option}\n"
        prompt = prompt.strip()  # Remove trailing newline

        formatted_data.append({
            'text': prompt,
            'label': answer_index
        })

    print("GQPA answer distribution: ", distribution)
    return Dataset.from_list(formatted_data)

def load_hle():
    """Load the multiple-choice subset of Humanity's Last Exam (``cais/hle``).

    Only rows with ``answer_type == "multipleChoice"`` are kept. Each question
    text is split at "Answer Choices:", and lines of the form "X. text" after it
    are parsed as options keyed by their (upper-cased) letter. ``item['answer']``
    names the correct letter; rows whose answer letter is not among the parsed
    options are skipped with a message.

    At most ``NUM_OPTIONS - 1`` distractors are kept. The options are reshuffled
    with the global ``np.random`` state and re-lettered from "A".

    Returns:
        ``datasets.Dataset`` with ``text`` (question, "Answer Choices:", the
        re-lettered options) and an integer ``label`` (0 = "A").

    Raises:
        ValueError: if a multiple-choice question has no "Answer Choices:" section.
    """
    dataset = load_dataset("cais/hle", split="test")

    # Only keep rows with answer_type == "multipleChoice"
    dataset = dataset.filter(lambda x: x["answer_type"] == "multipleChoice")

    formatted_data = []
    correct_indices = {}

    for item in dataset:
        question_text = item['question']
        answer = item['answer']

        parts = question_text.split("Answer Choices:")
        if len(parts) < 2:
            # Explicit raise instead of `assert False`, which is skipped under -O.
            raise ValueError("Question does not contain answer choices")
        only_question = parts[0].strip()
        options_text = parts[1].strip()

        # Map option letter -> option text in order of appearance. Lines are
        # stripped first, and the length check avoids an IndexError on
        # one-character lines. The first occurrence of a letter wins.
        options_by_letter = {}
        for line in options_text.splitlines():
            line = line.strip()
            if len(line) >= 2 and line[0].isalpha() and line[1] == '.':
                options_by_letter.setdefault(line[0].upper(), line[2:].strip())

        # Normalize the answer (e.g. " b" -> "B"). Multi-character or unknown
        # answers are skipped instead of crashing in ord().
        answer_letter = answer.strip().upper()
        if answer_letter not in options_by_letter:
            print(f"Invalid answer index: {answer}")
            continue

        correct_answer = options_by_letter[answer_letter]
        remaining_options = [
            text for letter, text in options_by_letter.items() if letter != answer_letter
        ]

        # Keep at most NUM_OPTIONS - 1 distractors (randomly chosen when the
        # question has more than that).
        if len(remaining_options) > NUM_OPTIONS - 1:
            picked = np.random.choice(len(remaining_options), NUM_OPTIONS - 1, replace=False)
            distractors = [remaining_options[i] for i in picked]
        else:
            distractors = remaining_options

        # The correct answer sits at the last slot before permuting; its new
        # position is wherever that last index ends up.
        final_options = distractors + [correct_answer]
        last_index = len(final_options) - 1
        shuffled_indices = np.random.permutation(len(final_options))
        shuffled_options = [final_options[i] for i in shuffled_indices]
        new_correct_index = int(np.flatnonzero(shuffled_indices == last_index)[0])

        correct_indices[new_correct_index] = correct_indices.get(new_correct_index, 0) + 1

        prompt = f"{only_question}\n"
        prompt += "Answer Choices:\n"
        for i, option in enumerate(shuffled_options):
            prompt += f"{chr(65 + i)}. {option}\n"
        prompt = prompt.strip()  # Remove trailing newline

        formatted_data.append({
            'text': prompt,
            'label': new_correct_index
        })

    print(f"HLE dataset length: {len(formatted_data)}")
    print(f"HLE correct indices distribution: {correct_indices}")

    return Dataset.from_list(formatted_data)


def compute_metrics(pred):
    """Compute accuracy and macro-averaged precision/recall/F1 for the Trainer.

    Args:
        pred: The evaluation prediction object the Trainer passes in, with
            ``predictions`` (logits whose last axis indexes the classes, or a
            tuple whose first element is those logits) and ``label_ids``.

    Returns:
        Dict with ``accuracy``, ``f1``, ``precision`` and ``recall``.
    """
    logits = pred.predictions
    # Some models return (logits, *extra_outputs); the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    labels = pred.label_ids
    preds = np.asarray(logits).argmax(-1)

    # zero_division=0 gives a class that is never predicted a precision of 0
    # (the value sklearn's default already uses) without emitting an
    # UndefinedMetricWarning at every evaluation.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='macro', zero_division=0
    )
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }



class MultiDatasetTrainer(Trainer):
    """``Trainer`` that also evaluates on a dict of additional named datasets.

    Metrics for each extra dataset are prefixed with its dict key (for example
    ``mmlu_pro_test_accuracy``). They are merged into the dict returned by
    ``evaluate`` alongside the primary ``eval_*`` metrics.
    """

    def __init__(self, eval_datasets=None, *args, **kwargs):
        # NOTE: eval_datasets is the first positional parameter, so the regular
        # Trainer arguments (model, args, ...) must be passed by keyword.
        super().__init__(*args, **kwargs)
        self.eval_datasets = eval_datasets or {}

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        """Evaluate the primary dataset, then every dataset in ``self.eval_datasets``.

        If an extra dataset is the very same object as the primary one that was
        just evaluated, its metrics are copied under the dataset's prefix and
        logged. This avoids a second full evaluation pass that would produce
        the same numbers.

        Returns:
            The primary metrics dict, updated with every extra dataset's metrics.
        """
        results = super().evaluate(eval_dataset, ignore_keys, metric_key_prefix)
        primary = eval_dataset if eval_dataset is not None else self.eval_dataset

        for dataset_name, dataset in self.eval_datasets.items():
            if dataset is primary:
                dataset_results = self._relabel_metrics(results, metric_key_prefix, dataset_name)
                self.log(dataset_results)
            else:
                dataset_results = super().evaluate(dataset, ignore_keys, metric_key_prefix=dataset_name)
            results.update(dataset_results)

        return results

    @staticmethod
    def _relabel_metrics(metrics, old_prefix, new_prefix):
        """Return the ``old_prefix_*`` entries of ``metrics`` re-keyed as ``new_prefix_*``."""
        head = f"{old_prefix}_"
        return {
            f"{new_prefix}_{key[len(head):]}": value
            for key, value in metrics.items()
            if key.startswith(head)
        }


def train_model(
    train_dataset, 
    eval_datasets: Dict[str, Dataset], 
    model_name: str,
    output_dir: str,
    use_lora: bool = True,  # Wrap the model with a LoRA adapter
    local_model_path: str = None  # Load model/tokenizer from this directory instead of model_name
) -> Dict[str, float]:
    """Fine-tune a sequence-classification model (optionally with LoRA) and evaluate it.

    Args:
        train_dataset: ``datasets.Dataset`` with ``text`` and ``label`` columns.
        eval_datasets: Named evaluation datasets with the same columns. The first
            entry also serves as the Trainer's primary ``eval_dataset``.
        model_name: Hub id of the base model; ignored when ``local_model_path``
            is given.
        output_dir: Trainer output directory (checkpoint saving is disabled).
        use_lora: Train a LoRA adapter instead of fine-tuning all weights.
        local_model_path: Optional local directory with model and tokenizer.

    Returns:
        Metrics from a final ``trainer.evaluate()`` call after training.
    """
    # Load from the local directory when given, otherwise from model_name.
    model_source = local_model_path or model_name
    if local_model_path:
        print(f"Loading model from local path: {local_model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_source)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_source,
        num_labels=NUM_OPTIONS,
        problem_type="single_label_classification",
        torch_dtype=torch.bfloat16,  # bf16 weights, no quantization
    )

    # Add padding token if it doesn't exist
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Decoder-only classification heads (e.g. Llama/Qwen) use pad_token_id to
    # locate the last non-padding token of each sequence.
    if model.config.pad_token_id is None:
        model.config.pad_token_id = tokenizer.pad_token_id

    # Enable gradient checkpointing on the base model BEFORE applying LoRA, in
    # non-reentrant mode. With reentrant checkpointing and a frozen base model,
    # no input to a checkpointed layer requires grad, so the LoRA weights inside
    # those layers would receive no gradient. The same kwargs go into
    # TrainingArguments below, because Trainer (gradient_checkpointing=True)
    # enables checkpointing again with whatever kwargs it is given.
    gradient_checkpointing_kwargs = {"use_reentrant": False}
    if hasattr(model, "gradient_checkpointing_enable"):
        model.config.use_cache = False  # KV cache is incompatible with checkpointed training
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

    # Apply LoRA for parameter-efficient fine-tuning
    if use_lora:
        lora_config = LoraConfig(
            r=64,
            lora_alpha=128,  # 2x the rank
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type=TaskType.SEQ_CLS,
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    # For full fine-tuning, re-initialize the classification head.
    if hasattr(model, 'classifier') and not use_lora:
        model.classifier.weight.data.normal_(mean=0.0, std=0.02)
        model.classifier.bias.data.zero_()
    elif hasattr(model, 'score') and not use_lora:
        model.score.weight.data.normal_(mean=0.0, std=0.02)

    def tokenize_function(examples):
        # Truncate only. DataCollatorWithPadding pads each batch to its longest
        # sequence instead of every example being padded to 512 tokens.
        return tokenizer(
            examples['text'],
            truncation=True,
            max_length=512
        )

    tokenized_train = train_dataset.map(tokenize_function, batched=True, remove_columns=['text'])
    tokenized_eval_datasets = {
        name: dataset.map(tokenize_function, batched=True, remove_columns=['text'])
        for name, dataset in eval_datasets.items()
    }

    # The first evaluation dataset doubles as the primary eval_dataset.
    primary_eval_dataset = next(iter(tokenized_eval_datasets.values()), None)

    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=1e-4 if use_lora else 5e-6,  # Higher learning rate for LoRA
        per_device_train_batch_size=64,
        per_device_eval_batch_size=64,
        gradient_accumulation_steps=1,
        num_train_epochs=100,
        weight_decay=0.1,
        eval_strategy="steps",
        eval_steps=20,
        metric_for_best_model="accuracy",
        save_total_limit=0,  # Don't save any checkpoints
        run_name=f"mmlu-pro-{NUM_OPTIONS}way-70B-lora",
        warmup_ratio=0.1,
        logging_steps=1,
        report_to="wandb" if accelerator.is_main_process else "none",
        bf16=True,
        save_strategy="no",  # Don't save any models
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
        lr_scheduler_type="constant_with_warmup",
        ddp_find_unused_parameters=False,
    )

    trainer = MultiDatasetTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=primary_eval_dataset,
        eval_datasets=tokenized_eval_datasets,
        compute_metrics=compute_metrics,
        data_collator=DataCollatorWithPadding(tokenizer),  # Dynamic per-batch padding
    )

    trainer.train()

    # Final evaluation on all evaluation datasets
    eval_results = trainer.evaluate()

    return eval_results

def main():
    """Train a LoRA classifier on MMLU-Pro and report its test metrics.

    wandb initialization, metric logging and printing happen only on the
    main accelerator process; every process trains.

    Returns:
        The metrics dict from ``train_model``'s final evaluation.
    """
    mmlu_pro_train, mmlu_pro_test = load_mmlu_pro_data()
    eval_sets = {"mmlu_pro_test": mmlu_pro_test}

    output_dir = "/fast/nchandak/classification/mmlu_pro_lora"

    if accelerator.is_main_process:
        wandb.init(project="mcq-classifier-lora", name=f"mmlu-pro-noq-{NUM_OPTIONS}way-qwen72B-lora")

    # Local checkpoint of the base model (Qwen2.5-72B-Instruct).
    model_dir = "/fast/rolmedo/models/qwen2.5-72b-it/snapshots/model/"

    if accelerator.is_main_process:
        print("Training 70B model with LoRA on MMLU-Pro data...")

    results = train_model(
        mmlu_pro_train,
        eval_sets,
        model_name=None,  # Unused because local_model_path is set
        output_dir=output_dir,
        use_lora=True,
        local_model_path=model_dir,
    )

    if not accelerator.is_main_process:
        return results

    metric_names = ("accuracy", "f1", "precision", "recall")
    for set_name in eval_sets:
        wandb.log({f"{set_name}_{metric}": results[f"{set_name}_{metric}"] for metric in metric_names})

    print("\nEvaluation performance:")
    print(results)

    wandb.finish()
    return results

if __name__ == "__main__":
    # Reproducibility: no cuDNN autotuning, deterministic cuDNN kernels only.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    final_metrics = main()
    print("\nTest Performance:")
    print(final_metrics)