import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
from tabulate import tabulate

from xgboost import XGBClassifier
from catboost import CatBoostClassifier
from sklearn.linear_model import LogisticRegression

from tabpfn import TabPFNClassifier
from deeptlf import DeepTFL
from tabm.tabm_pipeline import full_pipeline_tabm

# --- TabNet imports (unconditional: required even when "TabNet" is not in `models`) ---
from pytorch_tabnet.pretraining import TabNetPretrainer
from pytorch_tabnet.tab_model import TabNetClassifier

# Config
# Rows drawn for training in every split; the remaining rows form the test set.
train_size = 400
# Number of repeated splits per model (seeds 0 .. repeats - 1).
repeats = 5
experiment_name = "new_OSDT_gating_per_layer"
results_dir = "comparing_dt_nn_plots/" + experiment_name
# exist_ok=True makes re-running the script against an existing folder harmless.
os.makedirs(results_dir, exist_ok=True)

# Folder names looked up under openml_datasets/ by the main loop.
datasets = ["default-of-credit-card-clients_categorical"]
# Baselines to run. evaluate_once also accepts "TabM", "TabPFN" and "DeepTFL".
models = ["CatBoost", "LogReg", "MLP", "XGBoost", "TabNet"]

# --- SimpleMLP model ---
class SimpleMLP(nn.Module):
    """Feed-forward classifier: input -> 128 -> 64 -> 32 -> output.

    Each hidden layer is followed by a ReLU. The last layer has no
    activation, so the forward pass returns raw class scores (logits).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        widths = [input_dim, 128, 64, 32]
        stack = []
        # Layers are created in input-to-output order, one Linear + ReLU per
        # hidden width.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        stack.append(nn.Linear(widths[-1], output_dim))
        # A single Sequential called `net` keeps the parameter names
        # (net.0.weight, net.2.weight, ...) unchanged.
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Run batch `x` through the layer stack and return the logits."""
        return self.net(x)

def load_openml_processed(test_data_file, test_labels_file, label_column="label"):
    """Load a feature CSV and a label CSV and return them as ``(X, y)``.

    Args:
        test_data_file: Path of the CSV holding the feature columns.
        test_labels_file: Path of the CSV holding the labels.
        label_column: Name of the label column. It is removed from the
            features if present there, and selected from the label file if
            present there.

    Returns:
        A tuple ``(X, y)``: the feature DataFrame without ``label_column``
        and the label Series. When the label file has no ``label_column``,
        its first column is used instead.
    """
    features = pd.read_csv(test_data_file)
    label_frame = pd.read_csv(test_labels_file)

    # Keep the target out of the feature matrix.
    if label_column in features.columns:
        features = features.drop(columns=[label_column])

    if label_column in label_frame.columns:
        targets = label_frame[label_column]
    else:
        # No column with the expected name: fall back to the first column.
        targets = label_frame.iloc[:, 0]

    return features, targets

def tabnet_pipeline(X_train, Y_train, X_valid, Y_valid):
    """Pretrain a TabNet encoder, fine-tune a classifier, return its accuracy.

    Args:
        X_train: Training features as a NumPy array (cast to float32 here).
        Y_train: Training labels (converted with ``np.array``).
        X_valid: Evaluation features as a NumPy array (cast to float32 here).
        Y_valid: Evaluation labels (converted with ``np.array``).

    Returns:
        The ``valid_accuracy`` value recorded in the classifier history for
        the last epoch that ran.

    Raises:
        ValueError: If ``X_train`` contains no rows.
    """
    # TabNet expects np arrays
    X_train = X_train.astype(np.float32)
    X_valid = X_valid.astype(np.float32)
    Y_train = np.array(Y_train)
    Y_valid = np.array(Y_valid)

    n_train = len(X_train)
    if n_train == 0:
        raise ValueError("tabnet_pipeline needs at least one training sample")

    # pytorch-tabnet's fit() drops incomplete training batches by default
    # (drop_last=True). With fewer training rows than the batch size, the only
    # batch would be dropped and no optimisation step would ever run, so the
    # batch size is capped at the number of training rows.
    batch_size = min(512, n_train)
    virtual_batch_size = min(128, batch_size)

    # Pretraining (unsupervised)
    pretrainer = TabNetPretrainer(
        optimizer_fn=torch.optim.Adam,
        optimizer_params=dict(lr=2e-2),
        mask_type='entmax'
    )
    pretrainer.fit(
        X_train=X_train,
        pretraining_ratio=0.8,
        max_epochs=30,
        patience=10,
        batch_size=batch_size,
        virtual_batch_size=virtual_batch_size
    )

    # Main classifier
    clf = TabNetClassifier(
        optimizer_fn=torch.optim.Adam,
        optimizer_params=dict(lr=2e-2),
        scheduler_params={"step_size": 10, "gamma": 0.9},
        scheduler_fn=torch.optim.lr_scheduler.StepLR,
        mask_type='sparsemax'
    )
    # NOTE(review): the eval set is monitored during fitting (patience=10), so
    # a caller that passes its test split here lets test data influence when
    # training stops. Consider a separate validation split.
    clf.fit(
        X_train=X_train, y_train=Y_train,
        eval_set=[(X_valid, Y_valid)],
        eval_name=['valid'],
        eval_metric=['accuracy'],
        from_unsupervised=pretrainer,
        max_epochs=30,
        patience=10,
        batch_size=batch_size,
        virtual_batch_size=virtual_batch_size
    )
    # Validation accuracy
    # NOTE(review): this is the metric logged for the last epoch that ran;
    # confirm it corresponds to the weights the fitted model ends up holding.
    valid_accuracy = clf.history['valid_accuracy'][-1]
    return valid_accuracy

def evaluate_once(model_name, X, Y, seed):
    """Train one model on a stratified split and return its test accuracy.

    Args:
        model_name: One of "MLP", "XGBoost", "LogReg", "CatBoost", "TabM",
            "TabNet", "TabPFN" or "DeepTFL".
        X: Feature DataFrame.
        Y: Label Series aligned with ``X``.
        seed: Random state of the train/test split. It also seeds torch for
            the "MLP" model so that repeated calls are reproducible.

    Returns:
        Accuracy on the rows not selected for training.

    Raises:
        ValueError: If ``model_name`` is not recognised.
    """
    # `train_size` is the module-level setting; everything else is test data.
    X_train, X_test, Y_train, Y_test = train_test_split(
        X, Y, train_size=train_size, random_state=seed, stratify=Y
    )

    if model_name == "MLP":
        # The number of classes is taken from the full label set.
        output_dim = len(np.unique(Y))
        return _train_and_score_mlp(
            X_train, Y_train, X_test, Y_test, output_dim, seed
        )

    if model_name == "XGBoost":
        model = XGBClassifier(use_label_encoder=False, eval_metric='mlogloss')
        return _fit_and_score(model, X_train, Y_train, X_test, Y_test)

    if model_name == "LogReg":
        # `multi_class` is left at its default: 'auto' was the default value
        # and the parameter is deprecated since scikit-learn 1.5.
        model = LogisticRegression(max_iter=5000)
        return _fit_and_score(model, X_train, Y_train, X_test, Y_test)

    if model_name == "CatBoost":
        model = CatBoostClassifier(verbose=0)
        return _fit_and_score(model, X_train, Y_train, X_test, Y_test)

    if model_name == "TabM":
        return full_pipeline_tabm(
            X_train.values, Y_train.values, X_test.values, Y_test.values
        )

    if model_name == "TabNet":
        # NOTE(review): the test split is handed to tabnet_pipeline as its
        # validation set, where it is used for early stopping.
        return tabnet_pipeline(
            X_train.values, Y_train.values, X_test.values, Y_test.values
        )

    if model_name == "TabPFN":
        return _fit_and_score(
            TabPFNClassifier(),
            X_train.values, Y_train.values, X_test.values, Y_test.values,
        )

    if model_name == "DeepTFL":
        return _fit_and_score(
            DeepTFL(task='class'),
            X_train.values, Y_train.values, X_test.values, Y_test.values,
        )

    raise ValueError(f"Unknown model: {model_name}")


def _fit_and_score(model, X_train, Y_train, X_test, Y_test):
    """Fit an estimator exposing fit/predict and return its test accuracy."""
    model.fit(X_train, Y_train)
    return accuracy_score(Y_test, model.predict(X_test))


def _train_and_score_mlp(X_train, Y_train, X_test, Y_test, output_dim, seed):
    """Train SimpleMLP with Adam for 30 epochs and return its test accuracy."""
    # Seed before the model and the shuffling loader are created so that the
    # weight initialisation and batch order are reproducible for this seed.
    # This reseeds torch's global generator.
    torch.manual_seed(seed)

    model = SimpleMLP(X_train.shape[1], output_dim)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    batch_size = 64
    epochs = 30
    train_loader = DataLoader(
        TensorDataset(
            torch.tensor(X_train.values, dtype=torch.float32),
            torch.tensor(Y_train.values, dtype=torch.long),
        ),
        batch_size=batch_size,
        shuffle=True,
    )

    model.train()
    for _ in range(epochs):
        for xb, yb in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(xb), yb)
            loss.backward()
            optimizer.step()

    model.eval()
    with torch.no_grad():
        test_inputs = torch.tensor(X_test.values, dtype=torch.float32)
        preds = model(test_inputs).argmax(dim=1).cpu().numpy()
    return accuracy_score(Y_test, preds)

# Main
# Each entry is one table row: [dataset, model, sample count, accuracy summary].
final_results = []
result_headers = ["Dataset", "Model", "Samples", "Accuracy (mean ± std)"]
separator = "=" * 80

for dataset in datasets:
    data_dir = f"openml_datasets/{dataset}"
    test_data_file = os.path.join(data_dir, "test_df.csv")
    test_labels_file = os.path.join(data_dir, "test_labels.csv")

    # Both CSV files must exist; otherwise the dataset is skipped entirely.
    if not (os.path.exists(test_data_file) and os.path.exists(test_labels_file)):
        print(f"Dataset '{dataset}' not found. Skipping...")
        continue

    X, y = load_openml_processed(test_data_file, test_labels_file, label_column="label")
    X = X.astype(np.float32)
    # Re-encode the labels as integers and wrap them in a Series.
    y = pd.Series(LabelEncoder().fit_transform(y))

    for model_name in models:
        accuracies = []
        for i in range(repeats):
            # A failing run is reported and left out of the statistics; the
            # remaining repeats still execute.
            try:
                accuracies.append(evaluate_once(model_name, X, y, seed=i))
            except Exception as e:
                print(f"{dataset} | {model_name} | Run {i} failed: {e}")

        summary = (
            f"{np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}"
            if accuracies
            else "FAILED"
        )
        final_results.append([dataset, model_name, len(X), summary])

    # Print the table accumulated so far after every dataset.
    print("\n" + separator)
    print(f"Cumulative Results (up to and including {dataset}):")
    print(tabulate(final_results, headers=result_headers, tablefmt="grid"))
    print(separator + "\n")

# Save final results
results_df = pd.DataFrame(final_results, columns=result_headers)
results_df.to_csv(os.path.join(results_dir, "baseline_model_results.csv"), index=False)
