from sklearn.feature_extraction.text import CountVectorizer
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import MultinomialNB
from sklearn.svm import LinearSVC
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score
from tmu.models.classification.vanilla_classifier import TMClassifier
import numpy as np
import tqdm

def evaluate_classifiers(
    data,
    tm_epochs=10,
    *,
    tm_num_clauses=1000,
    tm_T=8000,
    tm_s=2.0,
    tm_platform="CPU",
    tm_weighted_clauses=True,
    tm_clause_drop_p=0.75,
):
    """Train several text classifiers on bag-of-words counts and report test accuracy.

    A ``CountVectorizer`` is fitted on the training texts only; its vocabulary
    is then applied to the test texts. Five scikit-learn models (random forest,
    logistic regression, multinomial naive Bayes, linear SVM, MLP) and one
    ``TMClassifier`` are trained and scored on that representation.

    Args:
        data: 4-tuple ``(X_train_augmented, y_train, X_test_text, y_test)``.
            The ``X_*`` entries are collections of raw documents accepted by
            ``CountVectorizer``. The ``y_*`` entries are labels. They must be
            non-negative integers, because the TM path casts them to uint32.
        tm_epochs: Number of times ``TMClassifier.fit`` is called on the
            training set.
        tm_num_clauses, tm_T, tm_s: Passed positionally to ``TMClassifier``
            (number of clauses, ``T``, ``s``).
        tm_platform: ``platform`` argument of ``TMClassifier``.
        tm_weighted_clauses: ``weighted_clauses`` argument of ``TMClassifier``.
        tm_clause_drop_p: ``clause_drop_p`` argument of ``TMClassifier``.

    Returns:
        dict mapping each classifier name ("RF", "LR", "NB", "SVM", "MLP",
        "TM") to its test-set accuracy as a fraction in [0, 1]. All entries,
        including "TM", use the same scale so they can be compared directly.
    """
    X_train_augmented, y_train, X_test_text, y_test = data

    vectorizer = CountVectorizer()
    # Fit the vocabulary on training data only, so no test information leaks in.
    X_train_vec = vectorizer.fit_transform(X_train_augmented)
    X_test_vec = vectorizer.transform(X_test_text)

    classifiers = {
        "RF": RandomForestClassifier(n_estimators=100, random_state=42),
        "LR": LogisticRegression(random_state=42),
        "NB": MultinomialNB(),
        "SVM": LinearSVC(random_state=42),
        "MLP": MLPClassifier(random_state=42),
    }

    results = {}
    # The context manager closes the bar even if a fit/predict call raises.
    # The "+ 1" step accounts for the Tsetlin Machine evaluated below.
    with tqdm.tqdm(
        total=len(classifiers) + 1, desc="Classifiers", position=0, leave=True
    ) as progress:
        for name, classifier in classifiers.items():
            classifier.fit(X_train_vec, y_train)
            results[name] = accuracy_score(y_test, classifier.predict(X_test_vec))
            progress.update(1)

        results["TM"] = _evaluate_tsetlin_machine(
            X_train_vec,
            y_train,
            X_test_vec,
            y_test,
            epochs=tm_epochs,
            num_clauses=tm_num_clauses,
            T=tm_T,
            s=tm_s,
            platform=tm_platform,
            weighted_clauses=tm_weighted_clauses,
            clause_drop_p=tm_clause_drop_p,
        )
        progress.update(1)

    return results


def _evaluate_tsetlin_machine(
    X_train_vec,
    y_train,
    X_test_vec,
    y_test,
    *,
    epochs,
    num_clauses,
    T,
    s,
    platform,
    weighted_clauses,
    clause_drop_p,
):
    """Train a ``TMClassifier`` on densified count features; return test accuracy in [0, 1]."""
    # TMClassifier is given dense uint32 matrices, so densify the sparse counts.
    # NOTE(review): Tsetlin Machines are normally trained on boolean features.
    # Term counts > 1 are passed through unchanged here; confirm TMClassifier
    # treats them as intended.
    X_train_tm = np.asarray(X_train_vec.toarray(), dtype=np.uint32)
    X_test_tm = np.asarray(X_test_vec.toarray(), dtype=np.uint32)
    # np.asarray accepts lists, ndarrays and Series alike; `.astype` needs an array-like object.
    y_train_tm = np.asarray(y_train, dtype=np.uint32)
    y_test_tm = np.asarray(y_test, dtype=np.uint32)

    tm = TMClassifier(
        num_clauses,
        T,
        s,
        platform=platform,
        weighted_clauses=weighted_clauses,
        clause_drop_p=clause_drop_p,
    )
    # Call fit() once per requested epoch on the full training set.
    for _ in tqdm.tqdm(range(epochs), desc="Training TM"):
        tm.fit(X_train_tm, y_train_tm)

    # Use the same metric (and hence the same fraction scale) as the sklearn models.
    return accuracy_score(y_test_tm, tm.predict(X_test_tm))