import pandas as pd
import numpy as np
from datasets import Dataset, concatenate_datasets
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.model_selection import KFold
import torch
import json
import wandb
from typing import Dict, List
from accelerate import Accelerator

def load_data(file_path, min_volume=1000, num_options=4):
    """Load multiple-choice questions from a JSON file as text/label records.

    The file is expected to hold a JSON array of objects with at least the
    keys ``prompt``, ``answer_idx``, ``volume`` and ``options``.

    Args:
        file_path: Path to the JSON file.
        min_volume: Keep only items whose ``volume`` is at least this value.
        num_options: Keep only items with exactly this many ``options``.

    Returns:
        List of ``{'text': ..., 'label': ...}`` dicts, where ``text`` is the
        prompt with its first and last character removed and ``label`` is the
        item's ``answer_idx``.
    """
    # Decode as UTF-8 explicitly so non-ASCII prompts don't depend on the locale encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    return [
        {
            # NOTE(review): drops the first and last character, presumably
            # surrounding quotes in the raw prompt — confirm against the dataset.
            'text': item['prompt'][1:-1],
            'label': item['answer_idx'],
        }
        for item in data
        if item['volume'] >= min_volume and len(item['options']) == num_options
    ]

def compute_metrics(pred):
    """Compute accuracy and macro-averaged precision/recall/F1 for a Trainer eval pass.

    Args:
        pred: An ``EvalPrediction``-like object exposing ``label_ids`` and
            ``predictions`` (class logits whose last axis indexes classes).

    Returns:
        Dict with keys ``accuracy``, ``f1``, ``precision`` and ``recall``.
    """
    logits = pred.predictions
    # When a model returns extra outputs, predictions arrives as a tuple whose
    # first element holds the logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    labels = pred.label_ids
    preds = logits.argmax(-1)
    # zero_division=0 yields the same value sklearn's default ("warn") uses for
    # classes never predicted, without a warning on every evaluation.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='macro', zero_division=0
    )
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }

def _reinit_classification_head(model) -> None:
    """Re-draw the classification head's weights from N(0, 0.02); zero a ``classifier`` bias."""
    with torch.no_grad():
        if hasattr(model, 'classifier'):
            head = model.classifier
            # Only a plain linear head is reset; composite heads (modules wrapping
            # several layers) have no top-level ``weight``/``bias`` to re-initialise.
            if isinstance(head, torch.nn.Linear):
                torch.nn.init.normal_(head.weight, mean=0.0, std=0.02)
                if head.bias is not None:
                    torch.nn.init.zeros_(head.bias)
        elif hasattr(model, 'score'):
            torch.nn.init.normal_(model.score.weight, mean=0.0, std=0.02)


def _build_optimizer(model, model_name: str):
    """Return a two-LR AdamW for Llama/Qwen models, or None so Trainer builds its default."""
    lowered = model_name.lower()
    if "llama" not in lowered and "qwen" not in lowered:
        return None
    named = list(model.named_parameters())
    return torch.optim.AdamW(
        [
            # Lower LR for pretrained layers
            {"params": [p for n, p in named if "score" not in n], "lr": 5e-6},
            # Higher LR for the classification head (``score`` on these models)
            {"params": [p for n, p in named if "score" in n], "lr": 5e-5},
        ],
        weight_decay=0.1,
    )


def train_fold(
    train_dataset, 
    eval_dataset, 
    model_name: str,
    fold_idx: str,
    base_output_dir: str
) -> Dict[str, float]:
    """Fine-tune a 4-way sequence classifier on one split and evaluate it.

    Args:
        train_dataset: Dataset expected to have ``text`` and ``label`` columns.
        eval_dataset: Dataset with the same columns; evaluated every 20 steps
            during training and once more after training.
        model_name: Hugging Face model id to fine-tune.
        fold_idx: Name of the output sub-directory under ``base_output_dir``.
        base_output_dir: Root directory for Trainer outputs.

    Returns:
        Metrics dict from ``Trainer.evaluate()`` on ``eval_dataset``
        (e.g. ``eval_accuracy``, ``eval_f1``).
    """
    # Used only to restrict wandb reporting to the main process.
    accelerator = Accelerator()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=4,  # For A, B, C, D options
        problem_type="single_label_classification",
    )

    # Some causal-LM tokenizers ship without a pad token; reuse EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Decoder-only classifiers (e.g. Llama) need pad_token_id set on the config.
    if model.config.pad_token_id is None:
        model.config.pad_token_id = tokenizer.pad_token_id

    # Enable gradient checkpointing before the model is wrapped
    if hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()

    _reinit_classification_head(model)

    def tokenize_function(examples):
        # Pads every example to a fixed 2048 tokens.
        # NOTE(review): dynamic padding (a DataCollatorWithPadding passed to
        # Trainer) would avoid this large, mostly-padding batch.
        return tokenizer(
            examples['text'],
            padding="max_length",
            truncation=True,
            max_length=2048
        )

    tokenized_train = train_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=['text']
    )
    tokenized_eval = eval_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=['text']
    )

    training_args = TrainingArguments(
        output_dir=f"{base_output_dir}/{fold_idx}",
        # Applies only when Trainer builds its own optimizer; the Llama/Qwen
        # optimizer below sets explicit per-group learning rates.
        learning_rate=5e-6,
        per_device_train_batch_size=64,
        per_device_eval_batch_size=64,
        gradient_accumulation_steps=1,
        num_train_epochs=30,
        weight_decay=0.1,
        eval_strategy="steps",
        eval_steps=20,
        metric_for_best_model="accuracy",
        save_total_limit=1,
        run_name="manifold-4options-qwen",
        warmup_ratio=0.1,
        logging_steps=10,
        report_to="wandb" if accelerator.is_main_process else "none",
        save_strategy="no",
        fp16=True,
        lr_scheduler_type="constant_with_warmup",
    )

    # None for non-Llama/Qwen models -> Trainer creates its default optimizer
    # (previously this name was left unbound and raised UnboundLocalError).
    optimizer = _build_optimizer(model, model_name)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_eval,
        compute_metrics=compute_metrics,
        optimizers=(optimizer, None)  # Scheduler is built by Trainer from training_args
    )

    trainer.train()

    # Final evaluation on the eval split
    eval_results = trainer.evaluate()

    return eval_results

def cross_validate(
    *,
    dry_run: bool = True,
    model_name: str = "Qwen/Qwen2.5-0.5B-Instruct",
    train_path: str = "/fast/XXXX-3/forecasting/datasets/manifold/mcq_raw_train.json",
    test_path: str = "/fast/XXXX-3/forecasting/datasets/manifold/mcq_test.json",
):
    """Load the Manifold MCQ data, print diagnostics, and optionally train a classifier.

    No k-fold split is performed. When training, a single model is fit on the
    whole training split and evaluated on the test split.

    Args:
        dry_run: When True (the default), only print dataset sizes and label
            counts, then return without training or touching wandb.
        model_name: Hugging Face model id to fine-tune
            (e.g. ``microsoft/deberta-v3-large`` as an alternative).
        train_path: JSON file with training questions, read by ``load_data``.
        test_path: JSON file with evaluation questions, read by ``load_data``.

    Returns:
        None when ``dry_run`` is True, otherwise the metrics dict from ``train_fold``.
    """
    accelerator = Accelerator()

    train_data = load_data(train_path)
    test_data = load_data(test_path)

    train_dataset = Dataset.from_list(train_data)
    test_dataset = Dataset.from_list(test_data)

    print(f"Train dataset size: {len(train_dataset)}")
    print(f"Test dataset size: {len(test_dataset)}")

    train_dataset = train_dataset.shuffle(seed=42)
    test_dataset = test_dataset.shuffle(seed=42)

    # minlength=4 so every answer option is reported, even one that never occurs.
    train_label_counts = np.bincount(train_dataset['label'], minlength=4)
    test_label_counts = np.bincount(test_dataset['label'], minlength=4)
    print(f"Train label counts: {train_label_counts}")
    print(f"Test label counts: {test_label_counts}")

    if dry_run:
        return None

    # Initialize wandb only on the main process
    if accelerator.is_main_process:
        wandb.init(project="mcq-classifier", name="manifold-4options-qwen")

    if accelerator.is_main_process:
        print("Training final model on all training data...")
    final_results = train_fold(
        train_dataset,
        test_dataset,
        model_name,
        fold_idx="manifold",
        base_output_dir="./results"
    )

    # Log results only on the main process
    if accelerator.is_main_process:
        wandb.log({
            "final_accuracy": final_results['eval_accuracy'],
            "final_f1": final_results['eval_f1'],
            "final_precision": final_results['eval_precision'],
            "final_recall": final_results['eval_recall']
        })

        wandb.finish()

    return final_results
    

if __name__ == "__main__":
    # Ask cuDNN for deterministic kernels and disable its autotuner (which may
    # pick different algorithms between runs) to make GPU runs more repeatable.
    # No RNG seed is set here.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    cross_validate()