import copy
import os
import random
from queue import Queue
from random import choices, sample, uniform


import numpy as np
import pandas as pd
import networkx as nx
from scipy.stats import entropy
from scipy.optimize import minimize
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from Learning_transmission_probs.src.Monte_carlo_experiments.configs import ExperimentConfig
from Learning_transmission_probs.src.Monte_carlo_experiments.classifier_hyper_parameter_optimise import (
    hyper_parameter_optimise_multi_classifier,
)


# Seeds NumPy's global RNG only. The simulation code in this file draws from the
# stdlib `random` module (uniform / choices / sample / random.choice /
# random.random), which this call does NOT seed.
np.random.seed(42)


class Experiment:
    def __init__(
        self, graph: nx.DiGraph, config: ExperimentConfig, graph_type="Tree"
    ) -> None:
        self.classifier = None
        self.classifier_name = None
        self.config = config
        self.GT_P = None
        self.GT_Q = None
        self.graph = copy.deepcopy(graph)
        self.inv_for_classification = []
        self.node_best_hyper_params = {}
        self.non_activated_entity_symptom_probs = (
            config.non_activated_entity_symptom_probs
        )  # q_baseline
        self.number_of_classification_vectors = config.number_of_classification_vectors
        self.number_of_propagation = (
            config.number_of_propagation_for_empirical_graph
            if graph_type == "Insiders_network"
            else config.number_of_propagation_for_synthetic_graph
        )
        self.observed_data_class_label = config.observed_data_class_label
        self.results_P_and_Q = []
        self.seeds = (
            config.random_seed
            if graph_type != "Insiders_network"
            else config.empirical_seeds
        )
        self.simulated_data_class_label = config.simulated_data_class_label
        self.symtptom_options = config.symtptom_options

    def run(
        self,
        classifier_name,
        classifier,
        graph_type,
        GT_P,  # Ground Truth propagation probability 'p'
        GT_Q,  # Ground Truth symptom probability 'q',
        feature_type,
    ):
        self.node_best_hyper_params = {}
        self.classifier_name = classifier_name
        self.classifier = classifier
        self.GT_P = GT_P
        self.GT_Q = GT_Q
        self.graph_type = graph_type
        self.propagation_prob = GT_P
        self.activated_entity_positive_symptom_prob = GT_Q
        self.feature_type = feature_type
        self.set_node_property()
        if graph_type == "Insiders_network":
            self.find_shortest_distance_from_seeds()
            self.inv_for_classification = list(
                set(self.nodes) - set(self.isolated_nodes)
            )
        else:
            self.explore_network_with_random_restart_to_select_entity()
        self.set_P_and_Q(isFromSimulation=False)
        self.inference_via_classification()

    def test_random(self, success_probability) -> bool:
        return uniform(0, 1) < success_probability

    def find_shortest_distance_from_seeds(self):
        shortest_paths = {}

        for seed in self.seeds:
            paths_from_seed = nx.single_source_shortest_path(self.graph, seed)

            for target, path in paths_from_seed.items():
                if target not in shortest_paths or len(shortest_paths[target]) > len(
                    path
                ):
                    shortest_paths[target] = path

        self.shortest_path = []
        self.nodes = []
        self.farthest_nodes = []
        self.isolated_nodes = []
        self.shortest_distance = []
        temp_short_dist = []

        for target_node in self.graph.nodes:
            self.nodes.append(target_node)
            path = shortest_paths.get(target_node, "no_path_value")
            self.shortest_path.append(path)

            if path == "no_path_value":
                self.farthest_nodes.append(target_node)
                self.isolated_nodes.append(target_node)
                self.shortest_distance.append("no distance")
                temp_short_dist.append(15)
            else:
                dist = len(path) - 1
                self.shortest_distance.append(dist)
                temp_short_dist.append(dist)

                if dist > 5:
                    self.farthest_nodes.append(target_node)

    def set_node_property(self):
        if isinstance(self.seeds, int):
            self.seeds = [self.seeds]

        seeds_set = set(self.seeds)
        for node, properties in self.graph.nodes(data=True):
            properties["is_seed"] = node in seeds_set
            properties["is_in_pathway"] = False
            properties["is_informed"] = False
            properties["symptom"] = 0
        return True

    def explore_network_with_random_restart_to_select_entity(self):
        self.inv_for_classification = set(self.seeds)
        current_node = self.seeds[0]

        while len(self.inv_for_classification) < self.graph.number_of_nodes():
            second_order_neighbours = (
                set(self.graph.neighbors(current_node)) - self.inv_for_classification
            )
            if not second_order_neighbours:
                current_node = random.choice(list(self.inv_for_classification))
                continue
            if random.random() < 0.5:
                current_node = self.seeds[0]
            else:
                selected_neighbour = sample(sorted(second_order_neighbours), 1)[0]
                self.inv_for_classification.update(second_order_neighbours)
                current_node = selected_neighbour

    def set_P_and_Q(self, isFromSimulation=False):
        for node, properties in self.graph.nodes(data=True):
            properties["activated_entity_positive_symptom_prob"] = (
                self.activated_entity_positive_symptom_prob
            )
            neighbors = list(self.graph.neighbors(node))
            for neighbor in neighbors:
                self.graph[node][neighbor]["propagation_prob"] = self.propagation_prob
        return True

    def propagate_from_seeds(self):
        queue = Queue()
        for seed in self.seeds:
            queue.put(seed)
        while not queue.empty():
            current_node = queue.get()
            self.graph.nodes[current_node]["is_informed"] = True
            neighbours = [*self.graph.neighbors(current_node)]
            for neighbour in neighbours:
                if self.graph.nodes[neighbour]["is_informed"]:
                    continue
                if self.test_random(
                    self.graph[current_node][neighbour]["propagation_prob"]
                ):
                    self.graph.nodes[neighbour]["is_informed"] = True
                    queue.put(neighbour)

    def symptom(self):
        symptoms = []
        informed = []
        informed_nodes = 0
        for node, parameters in self.graph.nodes(data=True):
            informed.append(parameters["is_informed"])
            if parameters["is_informed"]:
                informed_nodes += 1
                current_symptom = [
                    (
                        self.symtptom_options[0]
                        if self.test_random(
                            parameters["activated_entity_positive_symptom_prob"]
                        )
                        else self.symtptom_options[-1]
                    )
                ]

            else:
                current_symptom = choices(
                    self.symtptom_options, self.non_activated_entity_symptom_probs
                )
            parameters["symptom"] = current_symptom[0]
            symptoms.append(current_symptom[0])
        self.symptoms = symptoms
        self.informed = informed

    def save_results_to_csv(self, results, file_name):
        results_df = pd.DataFrame(
            [results],
            columns=[
                "GT_P",
                "GT_Q",
                "P_hat",
                "Q_hat",
                "MSE_P",
                "MSE_Q",
                "MSE_PQ",
                "Accuracy",
            ],
        )
        results_dir = os.path.join(os.getcwd(), "Learning_transmission_probs/Results/synthetic_results")
        os.makedirs(results_dir, exist_ok=True)
        file_path = os.path.join(results_dir, file_name)
        if os.path.exists(file_path):
            existing_df = pd.read_csv(file_path)
            results_df = pd.concat([existing_df, results_df], ignore_index=True)
            results_df = results_df.sort_values(by=["GT_P", "GT_Q"]).reset_index(
                drop=True
            )
        results_df.to_csv(file_path, index=False)
        print(f"Results saved to {file_path}")

    def compute_feature_vector(self, symptom_matrix, class_label):
        feature_vectors = []
        n_entitys = symptom_matrix.shape[0]
        n_prop = symptom_matrix.shape[1]
        symptomDf = pd.DataFrame(symptom_matrix)
        for i in range(0, n_entitys):
            single_feature_vector = []
            current_propagation_symptom = symptomDf.iloc[i]
            count_1 = (current_propagation_symptom == 1).sum()
            count_neg1 = (current_propagation_symptom == -1).sum()
            count_0 = (current_propagation_symptom == 0).sum()
            frac_1 = count_1 / n_prop
            frac_neg1 = count_neg1 / n_prop
            frac_0 = count_0 / n_prop

            if self.feature_type == "Limited":
                single_feature_vector = [frac_1, frac_0, frac_neg1, class_label]
            elif self.feature_type == "Extended":
                # Constructing feature vector with statistical and temporal properties:
                # - Basic stats: mean and variance of signal values
                # - Distribution-based entropy over signal types (1, 0, -1)
                # - Temporal difference: change in counts of 1s, 0s, and -1s between first and second half
                mean = np.mean(current_propagation_symptom)
                variance = np.var(current_propagation_symptom)
                counts = np.array([count_1, count_neg1, count_0])
                probs = counts / n_prop
                probs = probs[probs > 0]
                entropy_val = entropy(probs)
                first_half = current_propagation_symptom[: n_prop // 2]
                second_half = current_propagation_symptom[n_prop // 2 :]
                count_1_first = np.sum(first_half == 1)
                count_1_second = np.sum(second_half == 1)
                diff_1 = count_1_second - count_1_first
                count_neg1_first = np.sum(first_half == -1)
                count_neg1_second = np.sum(second_half == -1)
                diff_neg1 = count_neg1_second - count_neg1_first
                count_0_first = np.sum(first_half == 0)
                count_0_second = np.sum(second_half == 0)
                diff_0 = count_0_second - count_0_first
                single_feature_vector = [
                    frac_1,
                    frac_0,
                    frac_neg1,
                    mean,
                    variance,
                    entropy_val,
                    diff_1,
                    diff_neg1,
                    diff_0,
                    class_label,
                ]

            feature_vectors.append(single_feature_vector)
        return feature_vectors

    def entity_symptom_vector(self):
        symptom_vector = [
            (
                1
                if parameters["symptom"] == 1
                else (-1 if parameters["symptom"] == -1 else 0)
            )
            for node, parameters in self.graph.nodes(data=True)
            if node in self.inv_for_classification
        ]
        return symptom_vector

    def entity_symptom_matrix(self, class_label):
        symptom_data_matrix = np.zeros(
            (len(self.inv_for_classification), self.number_of_propagation),
            dtype=int,
        )
        for round_idx in range(self.number_of_propagation):
            self.propagate_from_seeds()
            self.symptom()
            symptom_vector = self.entity_symptom_vector()
            symptom_data_matrix[:, round_idx] = symptom_vector
            for _, properties in self.graph.nodes(data=True):
                properties["is_informed"] = False
                properties["symptom"] = 0
        round_feature_vectors = self.compute_feature_vector(
            symptom_matrix=symptom_data_matrix, class_label=class_label
        )
        return round_feature_vectors

    def get_classification_vectors(self, class_label):
        round_features = []
        for round in range(0, self.number_of_classification_vectors):
            current_round_feature_vectors = self.entity_symptom_matrix(
                class_label=class_label
            )
            round_features.append(
                pd.DataFrame({f"round_{round}": current_round_feature_vectors})
            )
        feature_df = pd.concat(round_features, axis=1)
        return feature_df

    def classification_of_feature_vector(
        self, observed_feature_vector, simulated_feature_vector, entity=None
    ):
        combined_dataset = observed_feature_vector + simulated_feature_vector
        combined_dataset = np.array(combined_dataset)

        X = combined_dataset[:, :-1]
        Y = combined_dataset[:, -1]
        X_train, X_test, y_train, y_test = train_test_split(
            X, Y, test_size=0.4, random_state=42
        )
        scaler = MinMaxScaler()
        if not self.node_best_hyper_params:
            model = (
                self.classifier(max_iter=100)
                if self.classifier_name == "Logistic Regression"
                else self.classifier()
            )
            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)
            accuracy = accuracy_score(y_test, y_pred)
            return accuracy
        else:
            current_entitys_best_hyper_param = self.node_best_hyper_params[str(entity)]
            model = self.classifier(**current_entitys_best_hyper_param)
            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)
            accuracy = accuracy_score(y_test, y_pred)
            return accuracy

    def perform_classification_on_entity_level(self, params):
        (
            self.propagation_prob,
            self.activated_entity_positive_symptom_prob,
        ) = params
        self.set_P_and_Q(isFromSimulation=True)
        self.simulated_data_feature_vector_DF = self.get_classification_vectors(
            class_label=self.simulated_data_class_label
        )
        classification_accuracies_for_current_sim = []
        for inv in range(0, len(self.inv_for_classification)):
            observed_inv_features = self.observed_data_feature_vector_DF.iloc[inv]
            simulated_inv_features = self.simulated_data_feature_vector_DF.iloc[inv]
            accuracy = self.classification_of_feature_vector(
                observed_inv_features.tolist(),
                simulated_inv_features.tolist(),
                entity=inv,
            )
            classification_accuracies_for_current_sim.append(accuracy)
        # print(classification_accuracies_for_current_sim) #To check entity level accuracies
        print(
            "P:",
            self.propagation_prob,
            "Q:",
            self.activated_entity_positive_symptom_prob,
            "Accuracy:",
            np.mean(classification_accuracies_for_current_sim),
        )
        self.results_P_and_Q.append(
            (
                np.mean(classification_accuracies_for_current_sim),
                self.propagation_prob,
                self.activated_entity_positive_symptom_prob,
            )
        )
        return np.mean(classification_accuracies_for_current_sim)

    def inference_via_classification(self):
        # Step 1: Generate observed feature vectors
        self.observed_data_feature_vector_DF = self.get_classification_vectors(
            class_label=self.observed_data_class_label
        )
        print(
            "Observed data feature vector",
            len(self.observed_data_feature_vector_DF),
            self.observed_data_feature_vector_DF,
        )

        # Step 2: Helper function to run Powell optimization
        def run_powell_optimization(initial_weights):
            return minimize(
                self.perform_classification_on_entity_level,
                initial_weights,
                bounds=[(0.0, 1.0), (0.0, 1.0)],
                method="Powell",
                options={"maxiter": 300, "disp": True, "ftol": 0.001, "xtol": 0.001},
            )

        # Step 3: Helper to extract best (p, q) from results
        def extract_best_p_q(results_df):
            min_acc = results_df["accuracy"].min()
            best_rows = results_df[results_df["accuracy"] == min_acc]

            if len(best_rows) == 1:
                return best_rows.iloc[0]["p"], best_rows.iloc[0]["q"]
            else:
                return best_rows["p"].mean(), best_rows["q"].mean()

        # Step 4: Helper to calculate MSEs
        def calculate_mse(p_hat, q_hat):
            mse_p = (self.GT_P - p_hat) ** 2
            mse_q = (self.GT_Q - q_hat) ** 2
            return mse_p, mse_q, (mse_p + mse_q) / 2

        # Step 5: First optimization
        initial_weights = np.array([0.1, 0.1])
        result_initial = run_powell_optimization(initial_weights)
        result_df = pd.DataFrame(self.results_P_and_Q, columns=["accuracy", "p", "q"])
        optimized_weights_initial = result_initial.x
        p_hat, q_hat = optimized_weights_initial
        mse_p, mse_q, mse_pq = calculate_mse(p_hat, q_hat)
        print("Powell's prediction:", p_hat, q_hat)

        # Step 6: Hyperparameter optimization
        (hyper_accs, self.node_best_hyper_params) = (
            hyper_parameter_optimise_multi_classifier(
                observed_fv_DF=self.observed_data_feature_vector_DF,
                simulated_fv_DF=self.simulated_data_feature_vector_DF,
                classifier=self.classifier_name,
            )
        )
        print("Best hyper parameters", self.node_best_hyper_params)
        print("Acc Powell:", result_initial.fun, "Hyper opt acc:", np.mean(hyper_accs))

        # Step 7: Re-run Powell optimization with classifiers now having optimised hyper parameters
        self.results_P_and_Q = []
        result_final = run_powell_optimization(np.array([0.1, 0.1]))
        result_df = pd.DataFrame(self.results_P_and_Q, columns=["accuracy", "p", "q"])
        final_p, final_q = extract_best_p_q(result_df)

        if result_final.fun > result_df["accuracy"].min():
            print(
                "Best p and q:",
                final_p,
                final_q,
                "accuracy:",
                result_df["accuracy"].min(),
            )
            p_hat, q_hat = final_p, final_q
            mse_p, mse_q, mse_pq = calculate_mse(p_hat, q_hat)
            accuracy = result_df["accuracy"].min()
        else:
            print("Best p and q:", *result_final.x, "accuracy:", result_final.fun)
            p_hat, q_hat = result_final.x
            mse_p, mse_q, mse_pq = calculate_mse(p_hat, q_hat)
            accuracy = result_final.fun

        results = [
            self.GT_P,
            self.GT_Q,
            round(p_hat, 2),
            round(q_hat, 2),
            mse_p,
            mse_q,
            mse_pq,
            accuracy,
        ]
        self.save_results_to_csv(
            results,
            f"Result_{self.graph_type}_graph_{self.classifier_name}_classifier_{self.feature_type}_features.csv",
        ) 
    @property
    def node_count(self):
        return self.graph.order()
