import pandas as pd
import numpy as np
from datasets import Dataset, concatenate_datasets
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
    AutoConfig
)
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.model_selection import KFold
import torch
import json
import wandb
from typing import Dict, List

def load_data(file_path) -> List[dict]:
    """Load question/resolution records from a JSON file.

    Args:
        file_path: Path to a JSON file containing a list of objects, each
            with at least a ``question`` and a ``resolution`` key.

    Returns:
        One ``{'text': ..., 'label': ...}`` dict per input record, where
        ``text`` is the record's ``question`` and ``label`` its
        ``resolution``. Any other keys in the record are dropped.

    Raises:
        KeyError: If a record lacks ``question`` or ``resolution``.
    """
    # JSON is UTF-8 by specification; do not depend on the locale's default
    # encoding, which can mis-decode non-ASCII question text.
    with open(file_path, 'r', encoding='utf-8') as f:
        records = json.load(f)
    return [
        {'text': record['question'], 'label': record['resolution']}
        for record in records
    ]

def compute_metrics(pred):
    """Score one evaluation pass for the binary classifier.

    Args:
        pred: Evaluation output exposing ``predictions`` (per-class scores;
            the arg-max over the last axis is taken as the predicted class)
            and ``label_ids`` (the reference labels).

    Returns:
        Dict with ``accuracy``, ``f1``, ``precision`` and ``recall``, the
        last three computed with ``average='binary'``.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)

    # Tuple layout is (precision, recall, f-score, support); support is unused.
    prf = precision_recall_fscore_support(y_true, y_pred, average='binary')

    scores = {'accuracy': accuracy_score(y_true, y_pred)}
    scores['f1'] = prf[2]
    scores['precision'] = prf[0]
    scores['recall'] = prf[1]
    return scores

def train_fold(
    train_dataset,
    val_dataset,
    model_name: str,
    fold_idx: "int | str",
    base_output_dir: str,
    *,
    pretrained: bool = False,
) -> Dict[str, float]:
    """Train a two-class sequence classifier on one split and evaluate it.

    Args:
        train_dataset: Training dataset; must provide a ``text`` column, which
            is tokenized here (padded/truncated to 512 tokens).
        val_dataset: Evaluation dataset with the same columns; used for the
            periodic evaluations during training and the final evaluation.
        model_name: Hugging Face model id supplying the tokenizer and the
            model architecture.
        fold_idx: Label for this run. It is only interpolated into the output
            directory and the run name, so a string such as ``"final"`` works
            as well as an integer fold number.
        base_output_dir: Parent directory; this run uses
            ``<base_output_dir>/fold_<fold_idx>`` as its output directory.
        pretrained: When true, start from the checkpoint weights of
            ``model_name``. The default (false) trains from randomly
            initialized weights.

    Returns:
        The metrics dict produced by ``trainer.evaluate()`` on ``val_dataset``
        after training.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    if pretrained:
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=2,
            problem_type="single_label_classification",
        )
    else:
        # Build the architecture from its config only, so the weights are
        # randomly initialized. Previously the full checkpoint was loaded
        # first and then immediately discarded in favour of this model.
        config = AutoConfig.from_pretrained(model_name)
        config.num_labels = 2  # For binary classification
        config.problem_type = "single_label_classification"
        model = AutoModelForSequenceClassification.from_config(config)

    # Print number of parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params}")

    def tokenize_function(examples):
        """Tokenize a batch of examples, padding/truncating to 512 tokens."""
        return tokenizer(
            examples['text'],
            padding="max_length",
            truncation=True,
            max_length=512,
        )

    # Optional experiment (disabled): train only the classification head.
    # for param in model.parameters():
    #     param.requires_grad = False
    # for param in model.classifier.parameters():
    #     param.requires_grad = True
    # trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # total_params = sum(p.numel() for p in model.parameters())
    # print(f"Trainable parameters: {trainable_params} ({trainable_params/total_params:.2%} of total)")

    # Tokenize datasets
    tokenized_train = train_dataset.map(tokenize_function, batched=True)
    tokenized_val = val_dataset.map(tokenize_function, batched=True)

    # Evaluate every 50 steps and log every 20; no checkpoints are written.
    training_args = TrainingArguments(
        output_dir=f"{base_output_dir}/fold_{fold_idx}",
        learning_rate=1e-5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=50,
        weight_decay=0.01,
        # NOTE(review): newer transformers releases rename this argument to
        # `eval_strategy` -- confirm against the installed version.
        evaluation_strategy="steps",
        eval_steps=50,
        # load_best_model_at_end=True,
        metric_for_best_model="f1",
        # save_total_limit=1,
        save_strategy="no",
        report_to="wandb",  # Enable wandb logging
        run_name=f"fold_{fold_idx}",
        logging_steps=20,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
        compute_metrics=compute_metrics,
        # callbacks=[EarlyStoppingCallback(early_stopping_patience=5)]
    )

    trainer.train()

    # Final evaluation on the evaluation split passed in as val_dataset.
    return trainer.evaluate()

# Metric names reported by compute_metrics; the evaluation results expose each
# one under an "eval_" prefix (e.g. "eval_f1").
_METRIC_NAMES = ("accuracy", "f1", "precision", "recall")


def _load_split(path):
    """Load one JSON split and wrap it in a Hugging Face ``Dataset``."""
    return Dataset.from_list(load_data(path))


def _run_kfold(dataset, n_folds, model_name, output_dir):
    """Train and evaluate one model per fold of *dataset*; return mean metrics.

    Per-fold metrics are logged to wandb as ``fold_<i>_<metric>`` and the
    means over all folds as ``avg_<metric>``.
    """
    kf = KFold(n_splits=n_folds, shuffle=True, random_state=42)
    fold_results = []

    for fold_idx, (train_idx, cv_val_idx) in enumerate(kf.split(range(len(dataset)))):
        print(f"Training fold {fold_idx + 1}/{n_folds}")
        results = train_fold(
            dataset.select(train_idx),
            dataset.select(cv_val_idx),
            model_name,
            fold_idx,
            output_dir,
        )
        fold_results.append(results)
        wandb.log({
            f"fold_{fold_idx}_{metric}": results[f"eval_{metric}"]
            for metric in _METRIC_NAMES
        })

    avg_metrics = {
        metric: np.mean([r[f"eval_{metric}"] for r in fold_results])
        for metric in _METRIC_NAMES
    }
    wandb.log({f"avg_{metric}": value for metric, value in avg_metrics.items()})
    return avg_metrics


def cross_validate(
    n_folds: int = 5,
    *,
    run_kfold: bool = False,
    data_dir: str = "/fast/XXXX-3/forecasting/datasets/menge",
    model_name: str = "microsoft/deberta-v3-large",
    output_dir: str = "./results",
) -> Dict[str, float]:
    """Train the final classifier and report its metrics on the test split.

    The final model is trained on the train and validation splits combined
    and evaluated on the test split; its metrics are logged to wandb as
    ``final_<metric>``.

    Args:
        n_folds: Number of folds for the optional k-fold pass. Only used when
            ``run_kfold`` is true.
        run_kfold: When true, first run k-fold cross-validation over the
            training split (validation data is not included in the folds).
            Off by default, which matches the previous behaviour where the
            k-fold loop was commented out.
        data_dir: Directory containing ``binary_train.json``,
            ``binary_validation.json`` and ``binary_test.json``.
        model_name: Hugging Face model id passed to ``train_fold``.
        output_dir: Base output directory passed to ``train_fold``.

    Returns:
        The evaluation metrics of the final model on the test split.
    """
    train_dataset = _load_split(f"{data_dir}/binary_train.json")
    val_dataset = _load_split(f"{data_dir}/binary_validation.json")
    test_dataset = _load_split(f"{data_dir}/binary_test.json")
    # Alternative test split (disabled):
    # /fast/XXXX-3/forecasting/datasets/infinitegames/binary_balanced_test.json

    # Print the length of each dataset
    print(f"Train dataset length: {len(train_dataset)}")
    print(f"Val dataset length: {len(val_dataset)}")
    print(f"Test dataset length: {len(test_dataset)}")

    wandb.init(project="forecasting-classifier", name="mengye-temporal-train-random-init")

    if run_kfold:
        avg_metrics = _run_kfold(train_dataset, n_folds, model_name, output_dir)
        print("\nAverage Cross-Validation Metrics:")
        print(avg_metrics)

    # Train the final model on train + validation data and evaluate it on the
    # held-out test split.
    print("Training final model on all training data...")
    train_all = concatenate_datasets([train_dataset, val_dataset])
    final_results = train_fold(
        train_all,
        test_dataset,
        model_name,
        fold_idx="final",
        base_output_dir=output_dir,
    )

    wandb.log({
        f"final_{metric}": final_results[f"eval_{metric}"]
        for metric in _METRIC_NAMES
    })

    wandb.finish()
    return final_results

# Script entry point: run the training/evaluation pipeline when this file is
# executed directly (not when it is imported).
if __name__ == "__main__":
    cross_validate(n_folds=5)
    # NOTE(review): the prints below are disabled. `avg_metrics` and
    # `final_results` are not defined at module scope, so they would have to
    # be obtained from cross_validate before these lines can be re-enabled.
    # print("\nAverage Cross-Validation Metrics:")
    # print(avg_metrics)
    # print("\nFinal Model Performance on Validation Set:")
    # print(final_results)