import pandas as pd
import numpy as np
from datasets import Dataset, concatenate_datasets
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    T5ForConditionalGeneration,
    EarlyStoppingCallback
)
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.model_selection import KFold
import torch
import json
import wandb
from typing import Dict, List
from accelerate import Accelerator
from data_utils import load_paleka
import sklearn

def filter_data(ds):
    """Flatten raw question records into ``{'text', 'label'}`` examples.

    Each record's ``title`` and ``body`` are joined with a single newline to
    form the input text, and its ``resolution`` is mapped to an integer label:
    1 when truthy, 0 otherwise.

    Args:
        ds: Iterable of mapping-like records with ``title``, ``body`` and
            ``resolution`` keys.

    Returns:
        A new list of dicts, one per input record, in input order.
    """
    examples = []
    for record in ds:
        text = "\n".join((record['title'], record['body']))
        label = int(bool(record['resolution']))
        examples.append({'text': text, 'label': label})
    return examples


def compute_metrics2(pred):
    """Binary-classification metrics that also accept tuple-valued predictions.

    Args:
        pred: ``EvalPrediction``-like object with ``label_ids`` and
            ``predictions``. ``predictions`` is either a logits array or a
            tuple whose first element is the logits array.

    Returns:
        Dict with keys ``accuracy``, ``f1``, ``precision`` and ``recall``.
    """
    gold = pred.label_ids
    outputs = pred.predictions

    # A tuple means the model returned several outputs (the case originally
    # written for T5); the logits are taken from the first one. BERT/DeBERTa
    # style models hand over the logits array directly.
    # NOTE(review): for a seq2seq model such as T5ForConditionalGeneration the
    # logits are presumably per-token vocabulary scores, so argmax(-1) would
    # give token ids rather than 0/1 class labels -- confirm before relying on
    # this for T5.
    scores = outputs[0] if isinstance(outputs, tuple) else outputs
    predicted = scores.argmax(-1)

    precision, recall, f1, _ = precision_recall_fscore_support(gold, predicted, average='binary')
    return {
        'accuracy': accuracy_score(gold, predicted),
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }

def compute_metrics(pred):
    """Compute accuracy, F1, precision and recall for binary classification.

    Args:
        pred: ``EvalPrediction``-like object exposing ``label_ids`` and
            ``predictions``. ``predictions`` is either the class-logits array
            or a tuple whose first element is that array.

    Returns:
        Dict with keys ``accuracy``, ``f1``, ``precision`` and ``recall``
        (the Trainer reports them with an ``eval_`` prefix).
    """
    labels = pred.label_ids
    logits = pred.predictions
    # Mirror compute_metrics2: when predictions arrive as a tuple (the model
    # returned several outputs), the logits are the first element. Without
    # this, ``.argmax`` on a tuple raises AttributeError.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = logits.argmax(-1)

    # zero_division=0 gives the same 0.0 that sklearn's default ("warn")
    # returns, but without an UndefinedMetricWarning at every evaluation
    # where the model predicts only one class.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='binary', zero_division=0
    )
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }


def train_fold(
    train_dataset,
    val_dataset,
    model_name: str,
    fold_idx: "int | str",
    base_output_dir: str
) -> Dict[str, float]:
    """Fine-tune ``model_name`` on ``train_dataset`` and evaluate on ``val_dataset``.

    Args:
        train_dataset: HuggingFace ``Dataset`` with a ``text`` column (tokenized
            below); the caller in this file also provides a ``label`` column.
        val_dataset: Dataset with the same columns, used for the periodic
            evaluations during training and for the final evaluation.
        model_name: Hub id or local path of the pretrained checkpoint.
        fold_idx: Identifier used in the output directory and run name; the
            caller in this file passes the string ``"final"``.
        base_output_dir: Directory under which ``fold_{fold_idx}`` is written.

    Returns:
        The metrics dict from ``Trainer.evaluate()`` on ``val_dataset``
        (keys such as ``eval_accuracy`` / ``eval_f1``).
    """
    # Initialize tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    if "t5" in model_name.lower():
        # NOTE(review): this branch looks unusable for training as written:
        #  - T5ForConditionalGeneration is a seq2seq LM and presumably expects
        #    token-id label sequences, not the scalar 0/1 ``label`` column that
        #    filter_data builds;
        #  - weights are loaded in float16 without mixed-precision training, so
        #    optimizer updates at learning_rate=2e-6 are likely to underflow.
        # Confirm before switching model_name to a T5 checkpoint.
        model = T5ForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=torch.float16  # Use half precision
        )
    else:
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=2,
            problem_type="single_label_classification"
        )

    def tokenize_function(examples):
        # Every example is padded/truncated to a fixed 1024 tokens.
        return tokenizer(
            examples['text'],
            padding="max_length",
            truncation=True,
            max_length=1024
        )

    # Tokenize datasets
    tokenized_train = train_dataset.map(tokenize_function, batched=True)
    tokenized_val = val_dataset.map(tokenize_function, batched=True)

    # Define training arguments
    training_args = TrainingArguments(
        output_dir=f"{base_output_dir}/fold_{fold_idx}",
        learning_rate=2e-6,
        per_device_train_batch_size=1, # 8,
        per_device_eval_batch_size=1, # 8,
        # Effective per-device batch size is 1 x 8 = 8, matching the earlier
        # per_device_train_batch_size of 8 without its memory cost.
        gradient_accumulation_steps=8,
        num_train_epochs=30,
        # weight_decay=0.01,
        # NOTE(review): newer transformers releases rename this argument to
        # ``eval_strategy``; confirm against the installed version.
        evaluation_strategy="steps",
        eval_steps=100,
        # load_best_model_at_end=True,
        metric_for_best_model="f1",
        # save_total_limit=1,
        save_strategy="no",
        # report_to="wandb",  # Enable wandb logging
        run_name=f"fold_{fold_idx}",
        logging_steps=10,
        warmup_ratio=0.05,
        lr_scheduler_type="cosine"
    )

    # No Accelerator.prepare() here: Trainer handles device placement and
    # distributed setup itself, and prepare() only wraps models, optimizers,
    # dataloaders and schedulers -- not a Trainer object.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
        compute_metrics=compute_metrics,
        # callbacks=[EarlyStoppingCallback(early_stopping_patience=10)]
    )

    # Train the model
    trainer.train()

    # Evaluate the final (not best -- no checkpoints are kept) model on the
    # validation set.
    eval_results = trainer.evaluate()

    return eval_results

def _balance_binary_labels(dataset):
    """Downsample the majority label so 0s and 1s occur equally often.

    All minority-label rows are kept, plus the first ``k`` majority-label rows
    in dataset order, where ``k`` is the minority count. The result holds the
    kept majority rows followed by the minority rows. On a tie, label 0 is
    treated as the majority, so every row is kept (0s first, then 1s).
    """
    labels = dataset['label']
    indices_1s = [i for i, label in enumerate(labels) if label == 1]
    indices_0s = [i for i, label in enumerate(labels) if label == 0]
    if len(indices_1s) > len(indices_0s):
        majority, minority = indices_1s, indices_0s
    else:
        majority, minority = indices_0s, indices_1s
    return concatenate_datasets([
        dataset.select(majority[:len(minority)]),
        dataset.select(minority),
    ])


def cross_validate(n_folds: int = 5):
    """Load the "spanned" paleka split, balance it, fine-tune once and log test metrics.

    Despite the name, no K-fold split is performed. A single model is trained
    on the whole (balanced) training split via ``train_fold`` and evaluated on
    the balanced test split. Results are logged to wandb from the main process.

    Args:
        n_folds: Currently unused; kept so existing callers keep working.
    """
    # Used here only for main-process checks around printing and wandb.
    accelerator = Accelerator()

    # Load both splits and flatten them into {'text', 'label'} records.
    raw_train, raw_test = load_paleka("spanned")
    train_records = filter_data(raw_train)
    test_records = filter_data(raw_test)

    # Convert to HuggingFace datasets
    train_dataset = Dataset.from_list(train_records)
    test_dataset = Dataset.from_list(test_records)

    # Print diagnostics only on the main process
    if accelerator.is_main_process:
        print(f"Original Train dataset length: {len(train_dataset)}")
        print(f"Original Test dataset length: {len(test_dataset)}")
        print(f"Train label counts: {np.bincount(train_dataset['label'])}")
        print(f"Test label counts: {np.bincount(test_dataset['label'])}")

    # Balance each split so both labels are equally represented, then
    # shuffle with a fixed seed for reproducibility.
    train_dataset = _balance_binary_labels(train_dataset).shuffle(seed=42)
    test_dataset = _balance_binary_labels(test_dataset).shuffle(seed=42)

    # Print diagnostics only on the main process
    if accelerator.is_main_process:
        print(f"Train dataset: {train_dataset[-4:]}")
        print(f"Test dataset: {test_dataset[:4]}")
        print(f"Train label counts: {np.bincount(train_dataset['label'])}")
        print(f"Test label counts: {np.bincount(test_dataset['label'])}")

    # "google/flan-t5-base" was the other checkpoint referenced in this script.
    model_name = "microsoft/deberta-v3-large"

    # Initialize wandb only on the main process
    if accelerator.is_main_process:
        # NOTE(review): the run name ends in "-t5" although model_name is a
        # DeBERTa checkpoint -- confirm the intended run name.
        wandb.init(project="paleka-classifier", name="paleka-b1-temporal-4o-spanned-t5")

    # Train final model on all training data and evaluate on the test split.
    if accelerator.is_main_process:
        print("Training final model on all training data...")

    final_results = train_fold(
        train_dataset,
        test_dataset,
        model_name,
        fold_idx="final",
        base_output_dir="./results"
    )

    # Log results only on the main process
    if accelerator.is_main_process:
        wandb.log({
            "final_accuracy": final_results['eval_accuracy'],
            "final_f1": final_results['eval_f1'],
            "final_precision": final_results['eval_precision'],
            "final_recall": final_results['eval_recall']
        })
        wandb.finish()
    

if __name__ == "__main__":
    # Script entry point; importing this module does not start training.
    cross_validate(5)