import os
import time
import pandas as pd
from fire import Fire
from sklearn.svm import SVC
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import SGDClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from Learning_transmission_probs.src.Monte_carlo_experiments.graph import create_graph
from Learning_transmission_probs.src.Monte_carlo_experiments.configs import TotalConfig
from Learning_transmission_probs.src.Monte_carlo_experiments.experiment import Experiment
from Learning_transmission_probs.src.Monte_carlo_experiments.form_tables import generate_q_hat_summaries


def main(classifier_name="SVM", graph_type="Tree", feature_type="Limited", n_runs=10):
    """Run the Monte Carlo robustness experiments and build the q-hat summaries.

    For every ground-truth pair ``(p, q)`` taken from a fixed 5 x 5 grid
    (0.1, 0.3, 0.5, 0.7, 0.9 on each axis), a fresh ``Experiment`` is created
    and run ``n_runs`` times on the selected graph with the selected
    classifier. Afterwards ``generate_q_hat_summaries`` is called on the
    result CSV of this graph/classifier/feature combination.

    All file paths (input edge list, results directory) are resolved against
    the current working directory.

    Args:
        classifier_name: Key selecting the scikit-learn estimator class. One of
            "Random Forest", "Decision Tree", "Naive Bayes", "KNN", "SVM",
            "SGD", "Logistic Regression". The class itself (not an instance)
            is handed to ``Experiment.run``.
        graph_type: "Tree", "Loopy" or "Insiders_network". The last one builds
            the graph from ``insiders_network_links.csv`` (``;``-separated);
            the others are generated from ``TotalConfig`` settings.
        feature_type: "Limited" or "Extended"; forwarded to ``Experiment.run``.
        n_runs: Number of repetitions per ``(p, q)`` pair. The default of 10
            is the count used for the robustness tables in the paper.

    Raises:
        ValueError: If ``n_runs`` is not a positive integer, or if
            ``classifier_name``, ``graph_type`` or ``feature_type`` is not one
            of the supported values.
    """
    # perf_counter is monotonic, so the reported duration cannot be distorted
    # by system clock adjustments during a long run.
    start_time = time.perf_counter()

    if not isinstance(n_runs, int) or n_runs < 1:
        raise ValueError(f"n_runs must be a positive integer, got {n_runs!r}")

    classifiers = {
        "Random Forest": RandomForestClassifier,
        "Decision Tree": DecisionTreeClassifier,
        "Naive Bayes": GaussianNB,
        "KNN": KNeighborsClassifier,
        "SVM": SVC,
        "SGD": SGDClassifier,
        "Logistic Regression": LogisticRegression,
    }

    valid_graph_types = {"Tree", "Loopy", "Insiders_network"}
    valid_feature_types = {"Limited", "Extended"}

    if classifier_name not in classifiers:
        raise ValueError(
            f"Classifier '{classifier_name}' not supported. Choose from: {list(classifiers)}"
        )

    # Sets have no stable iteration order for strings across interpreter
    # runs; sort the choices so the error message is reproducible.
    if graph_type not in valid_graph_types:
        raise ValueError(
            f"Graph type '{graph_type}' not supported. Choose from: {sorted(valid_graph_types)}"
        )

    if feature_type not in valid_feature_types:
        raise ValueError(
            f"Feature type '{feature_type}' not supported. Choose from: {sorted(valid_feature_types)}"
        )

    config = TotalConfig()

    # Create the graph: either from the edge list on disk or synthetically.
    if graph_type == "Insiders_network":
        edges_df = pd.read_csv(
            "Learning_transmission_probs/Data/insiders_network_links.csv",
            sep=";",
        )
        graph = create_graph(graph_type=graph_type, df=edges_df)
    else:
        graph = create_graph(
            node_count=config.graph_node_count,
            seed=config.random_seed,
            graph_type=graph_type,
        )

    clf = classifiers[classifier_name]

    # Ground-truth transmission probability grids.
    P = [0.1, 0.3, 0.5, 0.7, 0.9]
    Q = [0.1, 0.3, 0.5, 0.7, 0.9]

    print(f"Classifier: {classifier_name}")

    for p_i in P:
        for q_i in Q:
            for _ in range(n_runs):
                # A new Experiment per repetition, as in the original protocol.
                experiment = Experiment(graph=graph, config=config, graph_type=graph_type)
                print(f"\nGround truth transmission probabilities: p = {p_i}, q = {q_i}")
                experiment.run(
                    classifier_name=classifier_name,
                    classifier=clf,
                    graph_type=graph_type,
                    GT_P=p_i,
                    GT_Q=q_i,
                    feature_type=feature_type,
                )

    results_dir = os.path.join(os.getcwd(), "Learning_transmission_probs/Results/synthetic_results")
    robustness_dir = os.path.join(results_dir, "Robustness_tables", graph_type)
    os.makedirs(robustness_dir, exist_ok=True)

    # NOTE(review): this CSV is presumably written by Experiment.run; that
    # code is not visible here - confirm the file name stays in sync.
    result_csv = os.path.join(
        results_dir,
        f"Result_{graph_type}_graph_{classifier_name}_classifier_{feature_type}_features.csv",
    )
    generate_q_hat_summaries(
        resultDF_path=result_csv,
        output_dir=robustness_dir,
        classifier_name=classifier_name,
        feature_type=feature_type,
    )
    print(f"Results saved in {robustness_dir}")

    end_time = time.perf_counter()
    print("Time taken: ", end_time - start_time)


if __name__ == "__main__":
    # Expose main() as a command-line interface: python-fire maps the
    # --classifier_name / --graph_type / --feature_type flags onto the
    # function's keyword arguments.
    Fire(component=main)

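# Example invocation (run from the directory that contains the Learning_transmission_probs package,
# because the data and results paths are resolved relative to the current working directory):
# python -m Learning_transmission_probs.src.Monte_carlo_experiments.main --classifier_name="SVM" --graph_type="Tree" --feature_type="Limited"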
