from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import os
os.sys.path.append('src')
from tabpfn import TabPFNClassifier
from embedding import TabPFNEmbedding
from sklearn.datasets import fetch_openml
from sklearn.linear_model import LogisticRegression
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder

# Marker colours keyed by encoded class label: saturated shades for training
# points and lighter shades for test points. Only labels 0-2 are covered, so
# any other label raises KeyError when the colours are looked up.
_TRAIN_COLORS = {0: 'red', 1: 'blue', 2: 'green'}
_TEST_COLORS = {0: 'lightcoral', 1: 'lightskyblue', 2: 'lightgreen'}


def _extract_embeddings(model, n_fold, X_train, y_train, X_test):
    """Fit a ``TabPFNEmbedding`` extractor and return train/test embeddings.

    Args:
        model: Classifier handed to ``TabPFNEmbedding`` as ``tabpfn_clf``.
        n_fold: Forwarded unchanged to ``TabPFNEmbedding``; this script uses
            0 for the run it labels "vanilla" and 10 for the "K-fold" run.
        X_train: Training features.
        y_train: Training labels.
        X_test: Test features.

    Returns:
        A ``(train_embedding, test_embedding)`` tuple.
    """
    extractor = TabPFNEmbedding(tabpfn_clf=model, n_fold=n_fold)
    extractor.fit(X_train, y_train)
    # NOTE(review): only element [0] of each get_embeddings() result is kept;
    # presumably the leading axis indexes estimators (the model is built with
    # n_estimators=1) -- confirm against TabPFNEmbedding.
    train_embedding = extractor.get_embeddings(
        X_train, y_train, X_test, data_source='train'
    )[0]
    test_embedding = extractor.get_embeddings(
        X_train, y_train, X_test, data_source='test'
    )[0]
    return train_embedding, test_embedding


def _report_accuracy(train_embedding, y_train, test_embedding, y_test):
    """Fit a logistic regression on the train embeddings and print test accuracy."""
    clf = LogisticRegression()
    clf.fit(train_embedding, y_train)
    y_pred = clf.predict(test_embedding)
    print(f"Accuracy: {accuracy_score(y_test, y_pred)}")


def _plot_pca(train_embedding, y_train, test_embedding, y_test, title, out_path):
    """Save a 2-D PCA scatter plot of the train and test embeddings.

    The PCA is fitted on the train embeddings only and then applied to the
    test embeddings, so both sets share the same projection. Train points are
    drawn as 'x' markers and test points with the default marker.

    Args:
        train_embedding: Embeddings of the training samples.
        y_train: Encoded training labels (keys of ``_TRAIN_COLORS``).
        test_embedding: Embeddings of the test samples.
        y_test: Encoded test labels (keys of ``_TEST_COLORS``).
        title: Plot title.
        out_path: File the figure is saved to.
    """
    pca = PCA(n_components=2)
    train_2d = pca.fit_transform(train_embedding)
    test_2d = pca.transform(test_embedding)

    # Use an explicit figure rather than pyplot's implicit "current" one, so
    # each saved file contains only the two scatter layers drawn here.
    fig, ax = plt.subplots()
    ax.scatter(
        train_2d[:, 0],
        train_2d[:, 1],
        c=[_TRAIN_COLORS[label] for label in y_train],
        alpha=0.5,
        marker='x',
    )
    ax.scatter(
        test_2d[:, 0],
        test_2d[:, 1],
        c=[_TEST_COLORS[label] for label in y_test],
        alpha=0.5,
    )
    ax.set_title(title)
    fig.savefig(out_path)
    plt.close(fig)


def _run_experiment(header, model, n_fold, split, title, out_path):
    """Run one embedding experiment: extract, score, and plot.

    Args:
        header: Line printed before the experiment starts.
        model: Classifier passed on to the embedding extractor.
        n_fold: ``n_fold`` setting for the embedding extractor.
        split: ``(X_train, X_test, y_train, y_test)`` as returned by
            ``train_test_split``.
        title: Title of the PCA plot.
        out_path: File the PCA plot is written to.
    """
    X_train, X_test, y_train, y_test = split
    print(header)
    train_embedding, test_embedding = _extract_embeddings(
        model, n_fold, X_train, y_train, X_test
    )
    _report_accuracy(train_embedding, y_train, test_embedding, y_test)
    _plot_pca(train_embedding, y_train, test_embedding, y_test, title, out_path)


def main():
    """Compare vanilla and K-fold TabPFN embeddings on the OpenML 'kc1' dataset.

    For each embedding variant the script prints the test accuracy of a
    logistic regression trained on the embeddings and saves a 2-D PCA scatter
    plot to a PDF in the current working directory.
    """
    # Announce the load before it starts: fetch_openml may have to download
    # the dataset, which is the first slow step of the script.
    print("Loading classification dataset (kc1)...")
    X, y = fetch_openml(name='kc1', version=1, as_frame=False, return_X_y=True)
    # Encode the class labels as integers so they can key the colour maps.
    y = LabelEncoder().fit_transform(y)
    split = train_test_split(X, y, test_size=0.33, random_state=42)

    # The same classifier instance is shared by both experiments.
    model = TabPFNClassifier(n_estimators=1, ignore_pretraining_limits=True)

    _run_experiment(
        'Vanilla TabPFN Embedding:',
        model,
        0,
        split,
        'Vanilla Embeddings',
        'vanilla_embeddings_pca.pdf',
    )
    _run_experiment(
        'K-fold TabPFN Embedding:',
        model,
        10,
        split,
        'K-fold Embeddings',
        'k_fold_embeddings_pca.pdf',
    )


if __name__ == "__main__":
    main()
