from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np
import torch
from sklearn.metrics import roc_auc_score, confusion_matrix
from scipy.special import softmax
import time

# Load the data bundle: a dict stored as a pickled object inside a .npy file.
# NOTE(review): allow_pickle=True can execute arbitrary code embedded in the
# file — only ever point this at trusted data files.
bundle = np.load('ADNI_NC_MCI.npy', allow_pickle=True)
data = bundle.item()

# NOTE(review): `adj` is not used by the code visible here — confirm whether
# later code needs it.
adj = torch.from_numpy(data['adj']).float()
# Swap axes 1 and 2 of the time-series array (feature pooling later averages
# over the resulting axis 1).
X = torch.from_numpy(data['timeseries']).float().transpose(1, 2)
Y = torch.from_numpy(data['label']).long()

# Hold out 25% of samples for testing. Alternative seeds: 1 11 111 112 113.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    Y,
    test_size=0.25,
    random_state=113,
)

# Random-forest baseline; random_state=42 makes the forest reproducible.
clf = RandomForestClassifier(n_estimators=100, random_state=42)

# Time feature pooling + fitting + test-set inference. perf_counter() is a
# monotonic, high-resolution clock intended for measuring intervals, whereas
# time.time() follows the system wall clock and can jump if it is adjusted.
start_time = time.perf_counter()

# Average over axis 1 so each sample becomes one flat feature vector (fit()
# needs 2-D input). NOTE(review): after the transpose at load time this
# presumably averages over time points per region — confirm the data layout.
X_train_pooled = np.mean(X_train.numpy(), axis=1)
# Pass labels as 1-D: a (n, 1) column only triggers sklearn's
# DataConversionWarning before being flattened internally, so the fitted
# model is the same either way.
model = clf.fit(X_train_pooled, y_train.numpy().ravel())  # fit() returns clf itself

X_test_pooled = np.mean(X_test.numpy(), axis=1)
Y_Pred = clf.predict(X_test_pooled)
Y_Prob = clf.predict_proba(X_test_pooled)  # columns ordered as clf.classes_
Y_Pred = torch.tensor(Y_Pred)
elapsed = time.perf_counter() - start_time

# Model-size statistic; computed after the timer stops so it is not counted
# as training/inference cost. Print order matches the original script.
total_nodes = sum(tree.tree_.node_count for tree in clf.estimators_)
print("Total number of nodes in all trees:", total_nodes)
print(f"Total time elapsed: {elapsed:.4f}s")

# ---- Evaluation on the held-out split ----
# Flatten labels to 1-D so this works whether they were stored as (n,) or
# (n, 1); `y_test.squeeze(1)` raised IndexError for 1-D labels.
y_true = y_test.reshape(-1)

correct = Y_Pred.eq(y_true).sum().item()
print('ACC: {:.4f}'.format(correct / y_true.shape[0]))

y_true_np = y_true.numpy()
y_pred_np = Y_Pred.numpy()

# AUROC must be computed from a continuous score, not hard labels: with 0/1
# predictions the ROC curve has a single operating point and the "AUROC"
# collapses to (sensitivity + specificity) / 2. Column 1 of predict_proba is
# the probability of clf.classes_[1], the larger label, which is the positive
# class for sklearn's binary roc_auc_score.
if np.unique(y_true_np).size < 2:
    # roc_auc_score raises when only one class is present in y_true.
    auroc = float('nan')
else:
    auroc = roc_auc_score(y_true_np, Y_Prob[:, 1])
print('AUROC: {:.4f}'.format(auroc))

# Pin the label order to the classifier's classes so the matrix is always
# 2x2 (rows = true, cols = predicted), even if one class is absent from the
# test split. Assumes a binary task; unpacking fails loudly otherwise.
conf_matrix = confusion_matrix(y_true_np, y_pred_np, labels=clf.classes_)
tn, fp, fn, tp = conf_matrix.ravel()


def _safe_ratio(num, den):
    """Return num / den as a float, or NaN when den is 0 (metric undefined)."""
    return float(num) / den if den else float('nan')


# Sensitivity, Specificity, F1 (positive class = clf.classes_[1])
sensitivity = _safe_ratio(tp, tp + fn)  # recall of the positive class
specificity = _safe_ratio(tn, tn + fp)  # recall of the negative class
f1_score = _safe_ratio(2 * tp, 2 * tp + fp + fn)

print('Sensitivity: {:.4f}'.format(sensitivity))
print('Specificity: {:.4f}'.format(specificity))
print('F1: {:.4f}'.format(f1_score))
