from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import Dataset
import torch
import json
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import numpy as np

def load_data(file_path):
    """Load classification examples from a JSON file.

    The file must contain a JSON array of objects, each with a
    ``'question'`` and a ``'resolution'`` key. Any other keys are ignored.

    Args:
        file_path: Path to the JSON file.

    Returns:
        A list of ``{'text': ..., 'label': ...}`` dicts, one per input
        object, where ``text`` is the object's ``'question'`` and ``label``
        is its ``'resolution'``.

    Raises:
        KeyError: If an object lacks ``'question'`` or ``'resolution'``.
    """
    # JSON is UTF-8 by specification; decode explicitly so the result does
    # not depend on the platform's default locale encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return [
        {'text': item['question'], 'label': item['resolution']}
        for item in data
    ]

def compute_metrics(labels, preds, probs=None):
    """Score binary predictions against their ground-truth labels.

    Args:
        labels: Ground-truth class labels.
        preds: Predicted class labels, aligned with ``labels``.
        probs: Optional predicted probabilities, aligned with ``labels``.
            When given, a Brier score is added to the result.

    Returns:
        A dict with ``'accuracy'``, ``'f1'``, ``'precision'`` and
        ``'recall'`` (computed with ``average='binary'``), plus
        ``'brier_score'`` when ``probs`` is not ``None``.
    """
    # The fourth element of the result (support) is not reported.
    prf = precision_recall_fscore_support(labels, preds, average='binary')
    precision, recall, f1 = prf[0], prf[1], prf[2]

    metrics = {
        'accuracy': accuracy_score(labels, preds),
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }

    if probs is None:
        return metrics

    # Brier score: mean squared difference between probability and label.
    squared_errors = (np.array(probs) - np.array(labels)) ** 2
    metrics['brier_score'] = np.mean(squared_errors)
    return metrics

def evaluate_on_test(model_path: str, test_data_path: str, tokenizer_name: str):
    """Evaluate a sequence-classification checkpoint on a JSON test set.

    Loads the model and tokenizer, runs batched inference (batch size 16)
    over the examples returned by ``load_data(test_data_path)``, and scores
    the predictions with ``compute_metrics``.

    Args:
        model_path: Checkpoint path or model id passed to
            ``AutoModelForSequenceClassification.from_pretrained``.
        test_data_path: Path to the JSON test file read by ``load_data``.
        tokenizer_name: Tokenizer path or id passed to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        The metrics dict produced by ``compute_metrics``. The predicted
        class is the argmax over the logits, and the probability used for
        the Brier score is softmax column 1.
    """
    # Load model and tokenizer
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    # Evaluation mode disables dropout and similar training-only behavior.
    model.eval()

    # Load test data
    test_dataset = Dataset.from_list(load_data(test_data_path))

    print("Length of test dataset:", len(test_dataset))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    batch_size = 16

    all_preds = []
    all_probs = []
    all_labels = []

    for start in range(0, len(test_dataset), batch_size):
        # Slicing a Dataset yields a dict mapping column name -> list.
        batch = test_dataset[start:start + batch_size]

        # NOTE(review): padding="max_length" pads every batch to 512 tokens.
        # padding=True (pad to the longest sequence in the batch) would
        # likely be faster, but is left unchanged to keep outputs identical.
        encoded = tokenizer(
            batch['text'],
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )

        # Move inputs to the same device as the model.
        inputs = {k: v.to(device) for k, v in encoded.items()}

        with torch.no_grad():
            logits = model(**inputs).logits
            probs = torch.nn.functional.softmax(logits, dim=1)
            preds = torch.argmax(logits, dim=1)

        # Accumulate plain Python numbers rather than numpy scalar objects.
        all_preds.extend(preds.cpu().tolist())
        all_probs.extend(probs[:, 1].cpu().tolist())  # Probability of class 1
        all_labels.extend(batch['label'])

    # Compute metrics
    return compute_metrics(all_labels, all_preds, all_probs)

# Example usage: evaluate the fold_0 checkpoint on the binary test file.
test_results = evaluate_on_test(
    model_path="./results/fold_0/checkpoint-4152",
    test_data_path="/fast/XXXX-3/forecasting/datasets/menge/binary_test.json",
    tokenizer_name="microsoft/deberta-v3-base",
)
# Header and metrics dict on separate lines.
print("Test Set Results:", test_results, sep="\n")