import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import SGDClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import accuracy_score

# Module-level store of the best hyperparameters found per row index
# (stringified index -> params dict), written by
# hyper_parameter_optimise_multi_classifier.
node_best_params = {}


def hyper_parameter_optimise(
    observed_feature_vector, simulated_feature_vector, classifier="SVM"
):
    """Grid-search hyperparameters for one classifier and score the winner.

    The two inputs are concatenated with ``+`` into a single dataset, so they
    are expected to be lists of rows. The last column of every row is used as
    the label and the remaining columns as features. The data is split 60/40
    (fixed ``random_state=42``); the grid search runs on the training part
    and the refitted best estimator is scored on the held-out part.

    Args:
        observed_feature_vector: List of rows (features followed by label).
        simulated_feature_vector: List of rows with the same layout.
        classifier: One of "SVM", "Random Forest", "Decision Tree",
            "Naive Bayes", "KNN", "SGD" or "Logistic Regression".

    Returns:
        A ``(best_params, accuracy)`` tuple: the parameter dict selected by
        the grid search and the accuracy of the best estimator on the test
        split.

    Raises:
        ValueError: If ``classifier`` is not one of the supported names.
    """
    # Hyperparameter grid for each supported classifier, keyed by its name.
    param_grids = {
        "SVM": {
            "C": [1, 10, 100, 1000],
            "kernel": ["linear", "rbf", "poly"],
            "gamma": [0.01, 0.1, 1],
        },
        "Random Forest": {
            "n_estimators": [100, 200],
            "max_depth": [None, 10, 20],
            "min_samples_split": [2, 5],
        },
        "Decision Tree": {
            "criterion": ["gini", "entropy"],
            "max_depth": [None, 10, 20],
            "min_samples_split": [2, 5],
        },
        "Naive Bayes": {"var_smoothing": [1e-9, 1e-8, 1e-7, 1e-6]},
        "KNN": {
            "n_neighbors": [5, 10],
            "weights": ["uniform", "distance"],
            "algorithm": ["auto", "ball_tree"],
        },
        # NOTE(review): newer scikit-learn releases renamed the "log" loss to
        # "log_loss" and later dropped the old spelling -- confirm against
        # the installed version before relying on the SGD grid.
        "SGD": {
            "loss": ["hinge", "log"],
            "penalty": ["l2", "elasticnet"],
            "alpha": [0.0001, 0.001],
        },
        "Logistic Regression": {
            "penalty": ["l2"],
            "C": [0.01, 0.1, 1],
            "solver": ["liblinear"],
            "max_iter": [100, 200],
        },
    }

    # Fail fast with a clear message instead of hitting an unbound
    # `grid_search` after the (potentially expensive) data preparation.
    if classifier not in param_grids:
        raise ValueError(
            f"Unsupported classifier {classifier!r}; "
            f"expected one of {sorted(param_grids)}"
        )

    # Stack observed and simulated rows into one 2-D array.
    combined_dataset = np.array(observed_feature_vector + simulated_feature_vector)
    X = combined_dataset[:, :-1]
    Y = combined_dataset[:, -1]
    X_train, X_test, y_train, y_test = train_test_split(
        X, Y, test_size=0.4, random_state=42
    )

    if classifier == "SVM":
        # SVM keeps its own search settings: 5-fold CV, explicit accuracy
        # scoring, single process.
        grid_search = GridSearchCV(
            SVC(), param_grids["SVM"], cv=5, scoring="accuracy"
        )
    else:
        estimator_classes = {
            "Random Forest": RandomForestClassifier,
            "Decision Tree": DecisionTreeClassifier,
            "Naive Bayes": GaussianNB,
            "KNN": KNeighborsClassifier,
            "SGD": SGDClassifier,
            "Logistic Regression": LogisticRegression,
        }
        grid_search = GridSearchCV(
            estimator=estimator_classes[classifier](),
            param_grid=param_grids[classifier],
            cv=3,
            n_jobs=-1,
            verbose=2,
        )
    grid_search.fit(X_train, y_train)

    best_hyper_param = grid_search.best_params_
    best_model = grid_search.best_estimator_
    y_pred = best_model.predict(X_test)

    return best_hyper_param, accuracy_score(y_test, y_pred)


def hyper_parameter_optimise_multi_classifier(
    observed_fv_DF, simulated_fv_DF, classifier="SVM"
):
    """Run ``hyper_parameter_optimise`` once per row of the two frames.

    Row ``i`` of ``observed_fv_DF`` is paired with row ``i`` of
    ``simulated_fv_DF`` (positional ``.iloc`` access); each row is converted
    with ``.tolist()`` and passed on as the observed / simulated input.

    Args:
        observed_fv_DF: Object supporting ``len()`` and ``.iloc[i]``; its
            length determines how many rows are processed.
        simulated_fv_DF: Object supporting ``.iloc[i]`` for the same indices.
        classifier: Classifier name forwarded to ``hyper_parameter_optimise``.

    Returns:
        A ``(accuracies, best_params)`` tuple: a list with one accuracy per
        row, and a dict mapping the stringified row index to the best
        hyperparameters found for that row.
    """
    classification_accuracies_for_current_sim = []
    # Collect into a fresh dict so entries from an earlier call (possibly
    # with more rows) cannot leak into this call's result.
    best_params_for_current_sim = {}

    for inv in range(len(observed_fv_DF)):
        best_hyper_parameter_current_inv, accuracy_current_inv = (
            hyper_parameter_optimise(
                observed_feature_vector=observed_fv_DF.iloc[inv].tolist(),
                simulated_feature_vector=simulated_fv_DF.iloc[inv].tolist(),
                classifier=classifier,
            )
        )
        best_params_for_current_sim[str(inv)] = best_hyper_parameter_current_inv
        classification_accuracies_for_current_sim.append(accuracy_current_inv)

    # Keep the module-level dict in sync for code that reads it directly,
    # replacing (not merging with) whatever a previous call left behind.
    node_best_params.clear()
    node_best_params.update(best_params_for_current_sim)

    return classification_accuracies_for_current_sim, best_params_for_current_sim
