#  Copyright (c) Prior Labs GmbH 2025.
"""Example of using TabPFN for multiclass classification.

This example demonstrates how to use TabPFNClassifier on a multiclass classification task
using a local, header-less CSV file (`gla-jsc.csv`) whose last column holds the class label.
"""

import time

import pandas as pd
from sklearn.metrics import accuracy_score, roc_auc_score, balanced_accuracy_score
from sklearn.model_selection import StratifiedShuffleSplit  # 使用 StratifiedShuffleSplit
from sklearn.preprocessing import MinMaxScaler  # 导入 MinMaxScaler

from tabpfn import TabPFNClassifier

# Load the dataset from a local CSV file that has no header row.
data = pd.read_csv('gla-jsc.csv', header=None)

# Every column except the last one is a feature; the last column is the class label.
X = data.iloc[:, :-1].values
y = data.iloc[:, -1].values

# Draw a single stratified shuffle split (33% test) so that the class
# proportions of y are approximately preserved in both partitions.
# Only one split is requested, so take it directly instead of looping.
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.33, random_state=42)
train_index, test_index = next(sss.split(X, y))
y_train, y_test = y[train_index], y[test_index]

# Fit the min-max scaler on the training rows only and reuse the learned
# ranges for the test rows. Fitting on the full matrix before splitting would
# leak the test set's feature ranges into the training data. As a consequence,
# scaled test values may fall slightly outside [0, 1].
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X[train_index])
X_test = scaler.transform(X[test_index])

# Time model fitting and probability prediction together. perf_counter() is a
# monotonic, high-resolution clock intended for measuring intervals, whereas
# time.time() can jump if the system clock is adjusted mid-run.
start_time = time.perf_counter()

# Initialize a classifier with default settings and fit it on the training split.
clf = TabPFNClassifier()
clf.fit(X_train, y_train)

# Predict class probabilities for the test split.
prediction_probabilities = clf.predict_proba(X_test)

end_time = time.perf_counter()

# Compare the number of distinct classes present in y_test with the number of
# probability columns returned by the classifier, and report both.
num_classes = len(set(y_test))
num_columns = prediction_probabilities.shape[1]
print("Number of classes in y_test:", num_classes)
print("Shape of prediction probabilities:", prediction_probabilities.shape)

# ROC AUC is only computed when every probability column has a matching class
# in y_test; otherwise report the mismatch instead of calling the metric.
if num_columns != num_classes:
    print(
        f"Error: The number of classes in y_true ({num_classes}) does not match "
        f"the number of columns in y_score ({num_columns})")
elif num_classes == 2:
    # For a binary target roc_auc_score expects a 1-D score for the class with
    # the greater label; handing it the full (n_samples, 2) matrix raises a
    # ValueError. Assumes predict_proba columns follow sorted class labels
    # (the scikit-learn convention), so column 1 is that class.
    print("ROC AUC:", roc_auc_score(y_test, prediction_probabilities[:, 1]))
else:
    # Multiclass case: one-vs-rest AUC over the full probability matrix.
    print("ROC AUC:", roc_auc_score(y_test, prediction_probabilities, multi_class="ovr"))

# Hard label predictions for the test split.
predictions = clf.predict(X_test)

# Compute both label-based scores up front: balanced accuracy and, for
# comparison, plain accuracy.
balanced_acc = balanced_accuracy_score(y_test, predictions)
accuracy = accuracy_score(y_test, predictions)

# Report the scores in the same order as before.
print("Balanced Accuracy:", balanced_acc)
print("Accuracy:", accuracy)

# Elapsed time between the start_time and end_time marks taken earlier in the script.
print("Time", end_time - start_time)
