import os
import copy
import pandas as pd
import numpy as np
import networkx as nx
from pathlib import Path
from queue import Queue
from sklearn.svm import SVC
from random import uniform, choices
from sklearn.metrics import accuracy_score
from scipy.optimize import minimize
from Learning_transmission_probs.src.Empirical_experiment.configs import (
    ExperimentConfig,
)
from Learning_transmission_probs.src.Empirical_experiment.analyse_data import (
    get_company_specific_transaction_df,
)
from Learning_transmission_probs.src.Empirical_experiment.baseline_trade_prob import (
    baseline_model,
)
from Learning_transmission_probs.src.Empirical_experiment.hyper_parameter_optimisation import (
    hyper_parameter_optimise,
)


# Fix the seed of NumPy's global RNG (used by the np.random.choice calls in
# Experiment.compute_feature_DF_with_bootsrapping).
# NOTE(review): only NumPy is seeded here; the stdlib `random` helpers imported
# above (`uniform`, `choices`) are not seeded by this line.
np.random.seed(42)

# Directory containing this source file.
project_dir = Path(__file__).parent


class Experiment:

    def __init__(
        self,
        graph: nx.DiGraph,
        config: ExperimentConfig,
        seeds=None,
        announcementDayWithWindowedInvestorDf=None,
        uniqueInvestorOfCompany=None,
        company_name=None,
    ) -> None:

        self.GT_P = None
        self.GT_Q = None
        self.nodes = []
        self.seeds = seeds
        self.results = []
        self.classifier = None
        self.trade_options = config.trade_options
        self.results_P_and_Q = []
        self.node_best_hyper_params = {}
        self.non_announcement_trade_days = []
        self.inv_in_non_announcement_days = []
        self.inv_for_classification = []
        self.simulated_data_class_label = config.simulated_data_class_label
        self.observed_data_class_label = config.observed_data_class_label
        self.trade_matrix_shape_0 = 0
        self.trade_matrix_shape_1 = 0
        self.baseline_prob_dict = {}
        self.compute_non_announcement_days = (
            None  # config.compute_non_announcement_days
        )
        self.simulated_data_feature_vector_DF = []
        self.observed_data_feature_vector_DF = []
        self.uniqueInvestorOfCompany = uniqueInvestorOfCompany
        self.company_name = company_name
        self.graph = copy.deepcopy(graph)
        self.companyNonAnnouncementTransactionDf = []
        self.number_of_classification_vectors = config.number_of_classification_vectors
        self.number_of_information_cascades_for_bootstrap_observed = (
            config.number_of_information_cascades_for_bootstrap_observed
        )
        self.announcementDayWithWindowedInvestorDf = (
            announcementDayWithWindowedInvestorDf
        )
        self.uniqueWindowedCompanyInvestor = list(
            set(sum(self.announcementDayWithWindowedInvestorDf["InvestorIDs"], []))
        )
        self.companyAllTransactionDf = get_company_specific_transaction_df(
            unique_investors_from_company=self.uniqueInvestorOfCompany,
            unique_investors_from_graph=self.graph.nodes(),
            company_name=self.company_name,
        )

    def run(self, classifier=None, non_announcement=None, company_name="Nokia"):
        self.compute_non_announcement_days = non_announcement
        self.classifier = classifier
        if self.compute_non_announcement_days:

            self.get_non_anouncement_days()
            self.set_node_property()
            self.assign_baseline_probabilities()
            self.GT_finding_ouside_of_announcement_period()
            self.inference_via_classification()
        else:
            self.set_node_property()
            self.assign_baseline_probabilities()
            self.filtering_investors_who_traded_in_atleast_specified_windowed_announcement()
            self.GT_finding()
            self.inference_via_classification()

    def filtering_investors_who_traded_in_atleast_specified_windowed_announcement(self):
        investor_occurrences = []
        for index, investor_list in self.announcementDayWithWindowedInvestorDf[
            "InvestorIDs"
        ].items():
            investor_occurrences.extend((investor, index) for investor in investor_list)

        investor_df = pd.DataFrame(
            investor_occurrences, columns=["InvestorID", "RowIndex"]
        )

        investor_counts = investor_df.groupby("InvestorID")["RowIndex"].nunique()

        investors_in_multiple_lists = investor_counts[
            investor_counts >= 2
        ].index.tolist()

        self.inv_for_classification = list(
            set(self.inv_for_classification) & set(investors_in_multiple_lists)
        )
     
    def assign_baseline_probabilities(self):
        print("Assigning baseline probabilities")
        baseline_DF = baseline_model(
            CompanyName=self.company_name,
            announcementDayWithWindowedInvestorDf=self.announcementDayWithWindowedInvestorDf,
            companyAllTransactionDf=self.companyAllTransactionDf,
        )

        investor_dict = {
            int(row["InvestorID"]): [
                row["ProfitTradeProb"],
                row["LossTradeProb"],
                row["NoTradeProb"],
            ]
            for _, row in baseline_DF.iterrows()
        }
        inv_which_has_baseline = []
        for node, properties in self.graph.nodes(data=True):

            if node in investor_dict:
                inv_which_has_baseline.append(node)
                properties["baseline_trade_prob"] = investor_dict[node]

        self.baseline_prob_dict = investor_dict

        if self.compute_non_announcement_days:
            self.inv_for_classification = list(
                set(inv_which_has_baseline) & set(self.inv_in_non_announcement_days)
            )  # 485 for Nokia
        else:
            self.inv_for_classification = list(
                set(inv_which_has_baseline) & set(self.uniqueWindowedCompanyInvestor)
            )  # 406 for Nokia

    def test_random(self, success_probability) -> bool:
        random = uniform(0, 1)
        return random < success_probability

    def set_node_property(self):
        is_multiple_seeds = len(self.seeds) > 1
        seeds_set = set(self.seeds) if is_multiple_seeds else None

        for node, properties in self.graph.nodes(data=True):
            if is_multiple_seeds:
                properties["is_seed"] = node in seeds_set
            else:
                properties["is_seed"] = node == self.seeds[0]
            properties["is_in_pathway"] = False
            properties["is_informed"] = False
            properties["trade"] = 0
        return True

    def set_P_and_Q(self):

        for node, properties in self.graph.nodes(data=True):
            properties["informed_correct_trade_probability_Q"] = (
                self.informed_correct_trade_probability_Q
            )
            neighbors = list(self.graph.neighbors(node))
            for neighbor in neighbors:
                self.graph[node][neighbor][
                    "info_propagation_probability_P"
                ] = self.info_propagation_probability_P
        return True

    def propagate_information_from_seeds(self):
        queue = Queue()
        for seed in self.seeds:
            queue.put(seed)
        while not queue.empty():
            current_node = queue.get()
            self.graph.nodes[current_node]["is_informed"] = True
            neighbours = [*self.graph.neighbors(current_node)]
            for neighbour in neighbours:
                if self.graph.nodes[neighbour]["is_informed"]:
                    continue
                if self.test_random(
                    self.graph[current_node][neighbour][
                        "info_propagation_probability_P"
                    ]
                ):
                    self.graph.nodes[neighbour]["is_informed"] = True
                    queue.put(neighbour)

    def trade(self):
        trades = []
        informed_nodes = 0
        for node, parameters in self.graph.nodes(data=True):
            if parameters["is_informed"]:
                informed_nodes += 1
                current_trade = [
                    (
                        self.trade_options[0]
                        if self.test_random(
                            parameters["informed_correct_trade_probability_Q"]
                        )
                        else self.trade_options[-1]
                    )
                ]

            else:
                if node in self.inv_for_classification:
                    current_trade = choices(
                        self.trade_options, self.baseline_prob_dict[node]
                    )
                else:
                    current_trade = [self.trade_options[-1]]

            parameters["trade"] = current_trade[0]
            trades.append(current_trade[0])

    def get_non_anouncement_days(self):
        company_trade_days = pd.unique(self.companyAllTransactionDf["TradeDay"])
        Announcement_days = self.announcementDayWithWindowedInvestorDf[
            "AnnouncementDay"
        ].unique()
        Announcement_days = pd.to_datetime(Announcement_days)

        print("Announcement days", len(Announcement_days))
        print("Company trade days", len(company_trade_days))

        df = pd.DataFrame({"TradeDay": company_trade_days})

        indexes_to_remove = set()

        for announcement in Announcement_days:
            if announcement in df["TradeDay"].values:
                idx = df[df["TradeDay"] == announcement].index[0]

                indexes_to_remove.update(range(max(0, idx - 4), idx + 1))

        df_filtered = df.drop(indexes_to_remove).reset_index(drop=True)

        self.non_announcement_trade_days = df_filtered["TradeDay"].tolist()

        self.companyNonAnnouncementTransactionDf = self.companyAllTransactionDf[
            self.companyAllTransactionDf["TradeDay"].isin(
                self.non_announcement_trade_days
            )
        ]

        print(
            "unique investors outside announcements:",
            len(self.companyNonAnnouncementTransactionDf["InvestorID_EC"].unique()),
        )

        print(
            "Non announcement trade days", len(self.companyNonAnnouncementTransactionDf)
        )

        inv_in_non_announcement_days = list(
            set(self.companyNonAnnouncementTransactionDf["InvestorID_EC"])
        )
        self.inv_in_non_announcement_days = set(self.graph.nodes) & set(
            inv_in_non_announcement_days
        )

    def investor_trade_vector(self):
        trade_vector = [
            1 if parameters["trade"] == 1 else (-1 if parameters["trade"] == -1 else 0)
            for node, parameters in self.graph.nodes(data=True)
            if node in self.inv_for_classification
        ]
        return trade_vector

    def create_feature_DF(self, class_label):
        trade_data_matrix = np.zeros(
            (self.trade_matrix_shape_0, self.trade_matrix_shape_1), dtype=int
        )
        for round_idx in range(self.trade_matrix_shape_1):
            self.propagate_information_from_seeds()
            self.trade()
            trade_vector = self.investor_trade_vector()
            trade_data_matrix[:, round_idx] = trade_vector
            for _, properties in self.graph.nodes(data=True):
                properties["is_informed"] = False
                properties["trade"] = 0
        feature_df = self.compute_feature_DF_with_bootsrapping(
            trade_matrix=trade_data_matrix, class_label=class_label
        )

        return feature_df

    def GT_finding(self):
        announcement = 0
        no_transaction_on_announcement_day = 0

        companyGroupedTradedayDf = self.companyAllTransactionDf.groupby(
            "TradeDay", as_index=False
        ).agg({"Price": "mean"})

        companyGroupedTradedayDf["TradeDay"] = pd.to_datetime(
            companyGroupedTradedayDf["TradeDay"]
        )

        for (
            index,
            singleAnnouncementData,
        ) in self.announcementDayWithWindowedInvestorDf.iterrows():

            flag_for_no_trade_on_announcement_day = 0
            self.count_unique_windowed_inv_connected_to_seeds = 0
            current_announcement_trade_vector = []
            investorIDs = singleAnnouncementData["InvestorIDs"]

            lastTransactionDaysList = singleAnnouncementData["LastTransaction"]

            announcementDay = singleAnnouncementData["AnnouncementDay"]

            for node, parameters in self.graph.nodes(data=True):
                if node in self.inv_for_classification:
                    self.count_unique_windowed_inv_connected_to_seeds += 1
                    if node in investorIDs:
                        lastTradeDay = lastTransactionDaysList.get(node)

                        currentInvestorTransaction = self.companyAllTransactionDf[
                            self.companyAllTransactionDf["InvestorID_EC"] == node
                        ]
                        currentInvestorLastTransactions = currentInvestorTransaction[
                            currentInvestorTransaction["TradeDay"] == lastTradeDay
                        ]

                        current_price = currentInvestorLastTransactions["Price"].mean()
                        volume = sum(currentInvestorLastTransactions["Volume"])

                        test = companyGroupedTradedayDf[
                            companyGroupedTradedayDf["TradeDay"] == announcementDay
                        ]

                        if test.empty:

                            sorted_dates = companyGroupedTradedayDf[
                                "TradeDay"
                            ].to_numpy(dtype="datetime64[ns]")
                            announcementDay_np = np.datetime64(announcementDay, "ns")

                            idx = np.searchsorted(sorted_dates, announcementDay_np)

                            if idx < len(sorted_dates):
                                future_price = companyGroupedTradedayDf.iloc[idx][
                                    "Price"
                                ]
                            else:
                                print(
                                    "Unavailable announcement date", announcementDay_np
                                )
                                print("No available future date")
                                no_transaction_on_announcement_day += 1
                                flag_for_no_trade_on_announcement_day = 1
                                break

                        else:
                            companies_transactions_on_last_trade_day = (
                                companyGroupedTradedayDf[
                                    companyGroupedTradedayDf["TradeDay"]
                                    == announcementDay
                                ].index[0]
                            )
                            if companies_transactions_on_last_trade_day + 1 < len(
                                companyGroupedTradedayDf
                            ):
                                future_price = companyGroupedTradedayDf.iloc[
                                    companies_transactions_on_last_trade_day + 1
                                ]["Price"]
                            else:
                                future_price = companyGroupedTradedayDf.iloc[
                                    companies_transactions_on_last_trade_day
                                ]["Price"]

                        if volume == 0:
                            current_announcement_trade_vector.append(
                                self.trade_options[-1]
                            )
                        elif (current_price < future_price and volume > 0) or (
                            current_price > future_price and volume < 0
                        ):
                            current_announcement_trade_vector.append(
                                self.trade_options[0]
                            )
                        else:
                            current_announcement_trade_vector.append(
                                self.trade_options[1]
                            )
                    else:
                        current_announcement_trade_vector.append(self.trade_options[-1])

            if flag_for_no_trade_on_announcement_day == 0:
                if announcement == 0:
                    trade_data_matrix = np.zeros(
                        (self.count_unique_windowed_inv_connected_to_seeds, 0),
                        dtype=int,
                    )

                current_announcement_trade_vector = np.array(
                    current_announcement_trade_vector
                ).reshape(-1, 1)
                trade_data_matrix = np.hstack(
                    (trade_data_matrix, current_announcement_trade_vector)
                )

            announcement += 1

        self.trade_matrix_shape_0 = trade_data_matrix.shape[0]
        self.trade_matrix_shape_1 = trade_data_matrix.shape[1]

        self.observed_data_feature_vector_DF = (
            self.compute_feature_DF_with_bootsrapping(
                trade_data_matrix, self.observed_data_class_label
            )
        )

    def GT_finding_ouside_of_announcement_period(self):
        announcement = 0

        for non_announcement_trade_day in self.non_announcement_trade_days:
            inv_traded_in_current_non_announcement_day_IDs = (
                self.companyNonAnnouncementTransactionDf[
                    self.companyNonAnnouncementTransactionDf["TradeDay"]
                    == non_announcement_trade_day
                ]["InvestorID_EC"].unique()
            )

            flag_for_no_trade_on_announcement_day = 0
            self.count_unique_windowed_inv_connected_to_seeds = 0
            current_announcement_trade_vector = []
            for node, parameters in self.graph.nodes(data=True):
                if node in self.inv_for_classification:
                    self.count_unique_windowed_inv_connected_to_seeds += 1
                    if node in inv_traded_in_current_non_announcement_day_IDs:

                        currentInvestorLastTransactions = (
                            self.companyNonAnnouncementTransactionDf[
                                (
                                    self.companyNonAnnouncementTransactionDf[
                                        "InvestorID_EC"
                                    ]
                                    == node
                                )
                                & (
                                    self.companyNonAnnouncementTransactionDf["TradeDay"]
                                    == non_announcement_trade_day
                                )
                            ]
                        )

                        current_price = currentInvestorLastTransactions["Price"].mean()
                        volume = sum(currentInvestorLastTransactions["Volume"])

                        future_price = currentInvestorLastTransactions["P_week"].iloc[0]

                        if volume == 0:
                            current_announcement_trade_vector.append(
                                self.trade_options[-1]
                            )
                        elif (current_price < future_price and volume > 0) or (
                            current_price > future_price and volume < 0
                        ):
                            current_announcement_trade_vector.append(
                                self.trade_options[0]
                            )

                        else:
                            current_announcement_trade_vector.append(
                                self.trade_options[1]
                            )
                    else:
                        current_announcement_trade_vector.append(self.trade_options[-1])
            if announcement == 0:
                trade_data_matrix = np.zeros(
                    (self.count_unique_windowed_inv_connected_to_seeds, 0), dtype=int
                )

            current_announcement_trade_vector = np.array(
                current_announcement_trade_vector
            ).reshape(-1, 1)

            trade_data_matrix = np.hstack(
                (trade_data_matrix, current_announcement_trade_vector)
            )

            announcement += 1
        self.trade_matrix_shape_0 = trade_data_matrix.shape[0]
        self.trade_matrix_shape_1 = trade_data_matrix.shape[1]

        self.observed_data_feature_vector_DF = (
            self.compute_feature_DF_with_bootsrapping(
                trade_data_matrix, self.observed_data_class_label
            )
        )
        print("Ground Truth Feature vectros are computed from the transactions data.")

    def compute_feature_DF_with_bootsrapping(self, trade_matrix, class_label):
        n_investors = trade_matrix.shape[0]
        n_information_prop = trade_matrix.shape[1]
        all_feature_vectors = {}
        self.n_information_prop_train = 3 * n_information_prop // 5

        self.n_classification_vector_train, n_classification_vector_test = (
            3 * self.number_of_classification_vectors // 5,
            2 * self.number_of_classification_vectors // 5,
        )

        bootstrapped_trade_matrices_train = [
            trade_matrix[
                :,
                np.random.choice(
                    self.n_information_prop_train,
                    self.number_of_information_cascades_for_bootstrap_observed,
                    replace=False,
                ),
            ]
            for _ in range(self.n_classification_vector_train)
        ]
        bootstrapped_trade_matrices_test = [
            trade_matrix[
                :,
                np.random.choice(
                    np.arange(self.n_information_prop_train + 1, n_information_prop),
                    self.number_of_information_cascades_for_bootstrap_observed,
                    replace=False,
                ),
            ]
            for _ in range(n_classification_vector_test)
        ]

        bootstrapped_trade_matrices = (
            bootstrapped_trade_matrices_train + bootstrapped_trade_matrices_test
        )

        round = 0
        for bootstrapped_trade_matrix in bootstrapped_trade_matrices:
            feature_vectors = []
            for i in range(0, n_investors):
                single_feature_vector = []
                n_information_prop_bootstarpped = bootstrapped_trade_matrix.shape[1]
                current_information_propagation_trade = bootstrapped_trade_matrix[i]
                count_1 = (current_information_propagation_trade == 1).sum()
                count_neg1 = (current_information_propagation_trade == -1).sum()
                count_0 = (current_information_propagation_trade == 0).sum()
                frac_1 = count_1 / n_information_prop_bootstarpped
                frac_neg1 = count_neg1 / n_information_prop_bootstarpped
                frac_0 = count_0 / n_information_prop_bootstarpped
                single_feature_vector = [frac_1, frac_0, frac_neg1, class_label]
                feature_vectors.append(single_feature_vector)
            all_feature_vectors[f"round_{round}"] = feature_vectors
            round += 1

        feature_df = pd.DataFrame(all_feature_vectors)

        return feature_df

    def classification_of_feature_vector(
        self, observed_feature_vector, simulated_feature_vector, investor=None
    ):
        combined_dataset = observed_feature_vector + simulated_feature_vector
        combined_dataset = np.array(combined_dataset)

        X = combined_dataset[:, :-1]

        Y = combined_dataset[:, -1]

        self.n_classification_vector_train, n_classification_vector_test = (
            3 * self.number_of_classification_vectors // 5,
            2 * self.number_of_classification_vectors // 5,
        )
        n_train_feature_vectors = self.n_classification_vector_train * 2
        X_train = X[:n_train_feature_vectors]  # First 60 elements for training
        X_test = X[n_train_feature_vectors:]  # Remaining 40 elements for testing

        y_train = Y[:n_train_feature_vectors]  # First 60 elements for training
        y_test = Y[n_train_feature_vectors:]  # Remaining 40 elements for testing

        if not self.node_best_hyper_params:
            self.classifier.fit(X_train, y_train)
            y_pred = self.classifier.predict(X_test)
            accuracy = accuracy_score(y_test, y_pred)
            return accuracy
        else:
            current_investors_best_hyper_param = self.node_best_hyper_params[
                str(investor)
            ]
            model = SVC(**current_investors_best_hyper_param)
            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)

            accuracy = accuracy_score(y_test, y_pred)
            return accuracy

    def perform_classification_on_investor_level(self, params):
       
        (
            self.info_propagation_probability_P,
            self.informed_correct_trade_probability_Q,
        ) = params

        P = self.info_propagation_probability_P
        Q = self.informed_correct_trade_probability_Q

        epsilon = 0.01
        bounds_initialised = [(epsilon, 1 - epsilon), (epsilon, 1 - epsilon)]
        if not (
            bounds_initialised[0][0] <= P <= bounds_initialised[0][1]
            and bounds_initialised[1][0] <= Q <= bounds_initialised[1][1]
        ):
            return 1.0

        self.set_P_and_Q()  # Set Q

        self.simulated_data_feature_vector_DF = self.create_feature_DF(
            class_label=self.simulated_data_class_label
        )

        classification_accuracies_for_current_sim = []
        for inv in range(0, len(self.inv_for_classification)):
            observed_inv_features = self.observed_data_feature_vector_DF.iloc[inv]
            simulated_inv_features = self.simulated_data_feature_vector_DF.iloc[inv]
            accuracy = self.classification_of_feature_vector(
                observed_inv_features.tolist(),
                simulated_inv_features.tolist(),
                investor=inv,
            )
            classification_accuracies_for_current_sim.append(accuracy)

        print(
            "Company:",
            self.company_name,
            "P:",
            self.info_propagation_probability_P,
            "Q:",
            self.informed_correct_trade_probability_Q,
            "Accuracy:",
            np.mean(classification_accuracies_for_current_sim),
            "Classifier",
            "SVM",
            "Non announcement period computation:",
            self.compute_non_announcement_days,
        )
        # print(classification_accuracies_for_current_sim) #To check entity level accuracies
        self.results_P_and_Q.append(
            (
                np.mean(classification_accuracies_for_current_sim),
                self.info_propagation_probability_P,
                self.informed_correct_trade_probability_Q,
            )
        )

        return np.mean(classification_accuracies_for_current_sim)

    def inference_via_classification(self):
        
        # Step 1: Set initial weights and bounds
        initial_weights = np.array([0.1, 0.1])
        epsilon = 0.01
        bounds_initialised = [(epsilon, 1 - epsilon), (epsilon, 1 - epsilon)]

        # Step 2: Helper to run Powell optimization
        def run_powell_optimization(weights):
            return minimize(
                self.perform_classification_on_investor_level,
                weights,
                bounds=bounds_initialised,
                method="Powell",
                options={"maxiter": 300, "disp": True},
            )

        # Step 3: First Powell optimization
        result_initial = run_powell_optimization(initial_weights)
        optimized_weights_initial = result_initial.x
        p_hat, q_hat = optimized_weights_initial

        # Step 4: Print Powell optimization summary
        print(
            "p_hat and q_hat during initial optimization:", *optimized_weights_initial
        )
        print("Accuracy during initial optimization:", result_initial.fun)

        # Step 5: Set P and Q from Powell output and simulate feature vectors
        self.info_propagation_probability_P = p_hat
        self.informed_correct_trade_probability_Q = q_hat
        self.set_P_and_Q()
        self.simulated_data_feature_vector_DF = self.create_feature_DF(
            class_label=self.simulated_data_class_label
        )

        # Step 6: Hyperparameter optimization
        (
            classification_accuracies_from_hyper_param_optimisation,
            self.node_best_hyper_params,
        ) = hyper_parameter_optimise(
            self.observed_data_feature_vector_DF,
            self.simulated_data_feature_vector_DF,
        )
        print("Best hyper parameters:", self.node_best_hyper_params)
        print(
            "Final accuracy of Powell:",
            result_initial.fun,
            "Optimized hyperparameter models accuracy:",
            np.mean(classification_accuracies_from_hyper_param_optimisation),
        )

        # Step 7: Final Powell optimization after hyperparameter tuning
        self.results_P_and_Q = []
        result_final = run_powell_optimization(initial_weights)
        optimized_weights_final = result_final.x
        p_hat_final, q_hat_final = optimized_weights_final
        accuracy = result_final.fun

        # Step 8: Print comparison of optimization before and after tuning
        print(
            "GT p and q:",
            self.info_propagation_probability_P,
            self.informed_correct_trade_probability_Q,
        )
        print("p_hat and q_hat before optimization:", *optimized_weights_initial)
        print("p_hat and q_hat after optimization:", *optimized_weights_final)
        print("Accuracy before:", result_initial.fun)
        print("Accuracy after:", result_final.fun)

        # Step 9: Store final results
        results = [
            self.company_name,
            p_hat_final,
            q_hat_final,
            accuracy,
            self.compute_non_announcement_days,
        ]
       
        results_df = pd.DataFrame(
            [results],
            columns=[
                "Company",
                "P_hat",
                "Q_hat",
                "Accuracy",
                "Is_non_announcement_period_computation",
            ],
        )

        # Step 10: Save or append to CSV
        file_name = f"Infered_p_and_q_of_all_companies.csv"
        results_dir = os.path.join(
            os.getcwd(), "Learning_transmission_probs/Results/empirical_results"
        )
        os.makedirs(results_dir, exist_ok=True)
        file_path = os.path.join(results_dir, file_name)
        if os.path.exists(file_path):
            existing_df = pd.read_csv(file_path)
            results_df = pd.concat([existing_df, results_df], ignore_index=True)
        results_df.to_csv(file_path, index=False)
        print(f"Results saved to {file_path}")
        
