import time
import os
import pandas as pd
from fire import Fire
from sklearn.svm import SVC
from Learning_transmission_probs.src.Empirical_experiment.analyse_data import (
    get_seeds_and_edges,
)
from Learning_transmission_probs.src.Empirical_experiment.analyse_data import (
    get_significant_announcement_day_with_investors_traded_on_window,
)
from Learning_transmission_probs.src.Empirical_experiment.create_graph import (
    create_graph,
)
from Learning_transmission_probs.src.Empirical_experiment.experiment import Experiment
from Learning_transmission_probs.src.Empirical_experiment.configs import TotalConfig
from Learning_transmission_probs.src.Empirical_experiment.form_emprirical_inferred_table import (
    prepare_inferred_table,
)


def _empirical_results_path(filename):
    """Return ``filename`` joined onto the empirical-results directory.

    The directory is resolved relative to the current working directory, so the
    script must be launched from the directory that contains
    ``Learning_transmission_probs/``.
    """
    return os.path.join(
        os.getcwd(),
        "Learning_transmission_probs/Results/empirical_results",
        filename,
    )


def _run_company_experiments(company, config, clf, n_repeats):
    """Run ``n_repeats`` experiments for ``company`` in both periods.

    Args:
        company: Company name, as listed in ``company_names.csv``.
        config: The shared ``TotalConfig`` instance.
        clf: Classifier handed to every ``Experiment.run`` call.
        n_repeats: Number of experiment repetitions per period.
    """
    # These inputs depend only on the company, not on the period, so they are
    # loaded once per company instead of once per period.
    uniqueInvestorOfCompany, edgesDf, seeds = get_seeds_and_edges(company)
    announcementDayWithWindowedInvestorDf = (
        get_significant_announcement_day_with_investors_traded_on_window(
            companyName=company, consider_all_announcements=True
        )
    )

    # Non-announcement period first, then pre-announcement. The order is kept
    # stable because results may be accumulated in run order -- NOTE(review):
    # confirm whether Experiment.run / prepare_inferred_table depend on it.
    for non_announcement in (True, False):
        # The label is derived from the same flag passed to Experiment.run, so
        # the log line always names the period actually being run.
        period = (
            "Non-Announcement period"
            if non_announcement
            else "Pre-Announcement period"
        )
        print(
            f"Learning transmission probabilities for company: {company} during the {period}"
        )

        # Built per period so each period's experiments start from a fresh
        # graph object. NOTE(review): could be hoisted with the data above if
        # Experiment never mutates the graph -- confirm.
        graph = create_graph(df=edgesDf)

        for _ in range(n_repeats):
            experiment = Experiment(
                graph=graph,
                config=config,
                seeds=seeds,
                announcementDayWithWindowedInvestorDf=announcementDayWithWindowedInvestorDf,
                uniqueInvestorOfCompany=uniqueInvestorOfCompany,
                company_name=company,
            )
            experiment.run(
                classifier=clf,
                non_announcement=non_announcement,
                company_name=company,
            )


def main(n_repeats=50):
    """Learn transmission probabilities for every company and tabulate them.

    For each company in ``Learning_transmission_probs/Data/company_names.csv``,
    the following steps run:

    * ``n_repeats`` experiments for the non-announcement period and
      ``n_repeats`` for the pre-announcement period, all sharing one ``SVC``
      classifier instance;
    * read ``Infered_p_and_q_of_all_companies.csv`` from the empirical-results
      directory;
    * convert it with ``prepare_inferred_table``;
    * write ``Inferred_p_and_q_of_all_companies_table.csv`` alongside it.

    Paths are relative to the current working directory.

    Args:
        n_repeats: Number of experiment repetitions per company and period
            (default 50).
    """
    start_time = time.perf_counter()

    config = TotalConfig()
    companyNameDf = pd.read_csv("Learning_transmission_probs/Data/company_names.csv")
    clf = SVC()

    for company in companyNameDf["company"].tolist():
        _run_company_experiments(company, config, clf, n_repeats)

    # NOTE(review): this file is presumably produced by Experiment.run; the
    # "Infered" spelling must match whatever name that code writes.
    result_df = pd.read_csv(
        _empirical_results_path("Infered_p_and_q_of_all_companies.csv"),
        sep=",",
    )
    print(result_df)

    table_df = prepare_inferred_table(result_df)
    table_df.to_csv(
        _empirical_results_path("Inferred_p_and_q_of_all_companies_table.csv"),
        sep=",",
        index=False,
    )
    print("Results saved")

    print("Time taken: ", time.perf_counter() - start_time)


if __name__ == "__main__":
    # Hand main() to python-fire so it can be invoked from the command line.
    Fire(component=main)

# Usage (run from the directory that contains Learning_transmission_probs/):
#   python -m Learning_transmission_probs.src.Empirical_experiment.main
