import random
import multiprocessing as mp
import time
import numpy as np
import math
import matplotlib.pyplot as plt
from scipy import stats
import pandas as pd
from our_method import MCGraph, BidirectionalSolver, get_dag_signature
from baseline_method import MultiNodeDAGToSimplify

class SpeedupExperiment:
    """Benchmark harness that times two graph-simplification methods.

    Every trial builds a random graph, chooses a random subset of nodes to
    remove, runs ``MultiNodeDAGToSimplify`` and ``BidirectionalSolver`` on
    that same graph, and reports the ratio of their wall-clock times.
    """

    def __init__(self, num_trials=50, random_seed=42):
        """Record the configuration and seed the random generators.

        Args:
            num_trials: How many trials ``run_experiment`` performs per
                configuration.
            random_seed: Seed passed to both ``random.seed`` and
                ``np.random.seed``.
        """
        self.num_trials = num_trials
        self.random_seed = random_seed
        # These are the process-wide generators, so constructing an
        # instance resets the random stream for everything that uses them.
        random.seed(random_seed)
        np.random.seed(random_seed)

    def run_single_trial(self, num_nodes, edge_prob, num_to_remove):
        """Time both methods on one random graph and return the speedup.

        Args:
            num_nodes: Number of nodes in the generated graph.
            edge_prob: Probability of adding an edge between each ordered
                pair ``(i, j)`` with ``i < j``.
            num_to_remove: Number of nodes sampled for removal.

        Returns:
            ``bn_time / mag_time`` (time of ``MultiNodeDAGToSimplify``
            divided by time of ``BidirectionalSolver``), or ``None`` when
            the configuration is invalid, the second measurement is not
            positive, or either method raises.
        """
        # Reject invalid configurations up front so no graph is built
        # (and no random numbers are consumed) for a trial that is skipped.
        if num_nodes <= 2 or num_to_remove >= num_nodes:
            return None

        # MAKING THE GRAPH
        nodes = [f'N{i+1}' for i in range(num_nodes)]
        mcgraph = MCGraph(nodes=nodes)

        # Edges only ever go from a lower-indexed node to a higher-indexed one.
        for i in range(num_nodes):
            for j in range(i+1, num_nodes):
                if random.random() < edge_prob:
                    mcgraph.add_edge(nodes[i], '-->', nodes[j])

        nodes_to_remove = random.sample(nodes, num_to_remove)

        try:
            # TESTING BN SIMPLIFICATION
            # perf_counter is the high-resolution clock intended for timing
            # short intervals; a coarse clock can report 0 for fast runs,
            # which would silently discard the trial below.
            start_time = time.perf_counter()
            within_BNs = MultiNodeDAGToSimplify(mcgraph, nodes_to_remove)
            # The results are not inspected; the calls are kept because
            # their cost is part of what is being measured.
            topo_orders = within_BNs.get_topological_orders(nodes_to_remove)
            bn_graphs, strats = within_BNs.solve()
            bn_time = time.perf_counter() - start_time

            # SIMPLIFYING THE MAGs
            start_time = time.perf_counter()
            solver = BidirectionalSolver(mcgraph, nodes_to_remove)
            mag_solutions = solver.solve()
            mag_time = time.perf_counter() - start_time

            # A non-positive denominator cannot yield a meaningful ratio.
            if mag_time > 0:
                return bn_time / mag_time
            return None

        except Exception as e:
            # Best effort: a failing trial is reported and skipped so the
            # remaining trials of the experiment still run.
            print(f"Error in trial: {e}")
            return None

    def run_experiment(self, num_nodes, edge_prob, num_to_remove):
        """Run ``self.num_trials`` trials and summarise the speedups.

        Args:
            num_nodes: Forwarded to ``run_single_trial``.
            edge_prob: Forwarded to ``run_single_trial``.
            num_to_remove: Forwarded to ``run_single_trial``.

        Returns:
            ``None`` when no trial produced a finite speedup; otherwise a
            dict with keys 'mean', 'std' (sample standard deviation),
            'sem', 'ci_lower', 'ci_upper' (95% t-interval), 'n_samples' and
            'raw_data' (array of the individual speedups). With a single
            valid sample the spread statistics and interval bounds are NaN.
        """
        speedups = []

        print(f"Running experiment: nodes={num_nodes}, edge_prob={edge_prob:.2f}, remove={num_to_remove}")

        for trial in range(self.num_trials):
            if trial % 10 == 0:
                print(f"  Trial {trial}/{self.num_trials}")

            speedup = self.run_single_trial(num_nodes, edge_prob, num_to_remove)
            # Skipped or failed trials come back as None; drop those and
            # any non-finite ratio.
            if speedup is not None and np.isfinite(speedup):
                speedups.append(speedup)

        if not speedups:
            return None

        speedups = np.array(speedups)
        n_samples = len(speedups)

        # statistics; mean, sample std, std error
        mean_speedup = np.mean(speedups)

        if n_samples > 1:
            std_speedup = np.std(speedups, ddof=1)
            sem_speedup = std_speedup / np.sqrt(n_samples)

            # 95% CI from the t distribution with n - 1 degrees of freedom
            confidence_level = 0.95
            alpha = 1 - confidence_level
            t_critical = stats.t.ppf(1 - alpha/2, n_samples - 1)
            margin_error = t_critical * sem_speedup
        else:
            # One observation carries no spread information. Set NaN
            # explicitly rather than computing it through a zero
            # degrees-of-freedom division, which emits runtime warnings.
            std_speedup = sem_speedup = margin_error = np.nan

        return {
            'mean': mean_speedup,
            'std': std_speedup,
            'sem': sem_speedup,
            'ci_lower': mean_speedup - margin_error,
            'ci_upper': mean_speedup + margin_error,
            'n_samples': n_samples,
            'raw_data': speedups
        }

def experiment_vary_nodes():
    """Experiment 1: sweep the number of nodes with the other knobs fixed.

    Returns:
        A ``(DataFrame, 'num_nodes')`` pair. The frame holds one row per
        node count for which ``run_experiment`` returned a summary; the
        string names the swept column.
    """
    harness = SpeedupExperiment(num_trials=30)

    # Held constant for the whole sweep.
    fixed_edge_prob, fixed_removals = 0.4, 4

    rows = []
    for graph_size in (5, 7, 10, 12, 15, 18, 20, 22, 25):
        summary = harness.run_experiment(graph_size, fixed_edge_prob, fixed_removals)
        if not summary:
            # No usable trials for this size; leave it out of the table.
            continue
        row = {'num_nodes': graph_size}
        row.update(summary)
        rows.append(row)

    return pd.DataFrame(rows), 'num_nodes'

def experiment_vary_edge_prob():
    """Experiment 2: sweep the edge probability with the other knobs fixed.

    Returns:
        A ``(DataFrame, 'edge_prob')`` pair. The frame holds one row per
        probability for which ``run_experiment`` returned a summary; the
        string names the swept column.
    """
    harness = SpeedupExperiment(num_trials=30)

    # Held constant for the whole sweep.
    graph_size, fixed_removals = 10, 4

    rows = []
    for probability in (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9):
        summary = harness.run_experiment(graph_size, probability, fixed_removals)
        if not summary:
            # No usable trials at this density; leave it out of the table.
            continue
        row = {'edge_prob': probability}
        row.update(summary)
        rows.append(row)

    return pd.DataFrame(rows), 'edge_prob'

def experiment_vary_remove():
    """Experiment 3: sweep the number of removed nodes, other knobs fixed.

    Returns:
        A ``(DataFrame, 'num_to_remove')`` pair. The frame holds one row
        per removal count for which ``run_experiment`` returned a summary;
        the string names the swept column.
    """
    harness = SpeedupExperiment(num_trials=30)

    # Held constant for the whole sweep.
    graph_size, fixed_edge_prob = 5, 0.4

    # Only a single removal count is evaluated at the moment; a wider
    # sweep over 1 through 7 has been disabled.
    removal_counts = (1,)

    rows = []
    for removal_count in removal_counts:
        # Removing as many nodes as the graph has (or more) is not a
        # valid configuration, so such counts are skipped.
        if removal_count >= graph_size:
            continue
        summary = harness.run_experiment(graph_size, fixed_edge_prob, removal_count)
        if not summary:
            continue
        row = {'num_to_remove': removal_count}
        row.update(summary)
        rows.append(row)

    return pd.DataFrame(rows), 'num_to_remove'

def plot_results(df, x_col, title, xlabel):
    """Plot mean speedup with 95% CI error bars and the per-trial points.

    Args:
        df: Frame with the columns ``x_col``, 'mean', 'ci_lower',
            'ci_upper' and 'raw_data' (per-row array of trial speedups).
        x_col: Name of the column plotted on the x axis.
        title: Plot title. Currently unused: the title is not drawn.
        xlabel: Label for the x axis.

    Returns:
        The matplotlib figure containing the plot.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    # main plot with error bars (asymmetric: distance below / above the mean)
    lower_err = df['mean'] - df['ci_lower']
    upper_err = df['ci_upper'] - df['mean']
    ax.errorbar(df[x_col], df['mean'],
                yerr=[lower_err, upper_err],
                fmt='o-', linewidth=2, markersize=8, capsize=5, capthick=2,
                label='Mean ± 95% CI')

    # add individual data points in grey, with a small vertical jitter
    # (standard deviation of 2% of the row mean) to separate overlapping points
    for _, row in df.iterrows():
        raw = row['raw_data']
        y_jitter = np.random.normal(0, 0.02 * row['mean'], len(raw))
        ax.scatter([row[x_col]] * len(raw),
                   raw + y_jitter,
                   alpha=0.3, s=10, color='gray')

    # Configure the axes object directly instead of going through pyplot's
    # implicit "current axes" state.
    ax.set_yscale('log')

    ax.set_xlabel(xlabel, fontsize=16)
    ax.set_ylabel('Speedup (Previous Method Time / Our Method Time)', fontsize=12)
    # The y axis starts at 1, so values below 1 fall outside the visible range.
    ax.set_ylim(bottom=1)
    # ax.set_title(title, fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=14)

    # label each mean marker with its value rounded to the nearest integer
    for _, row in df.iterrows():
        ax.annotate(f'{round(row["mean"])}',
                    xy=(row[x_col], row['mean']),
                    xytext=(3, 10), textcoords='offset points',
                    fontsize=12, alpha=0.7)

    fig.tight_layout()
    return fig

def run_all_experiments():
    """Run the three sweeps, print their summaries and save one plot each.

    Returns:
        Dict mapping the swept column name of every experiment that
        produced results to that experiment's DataFrame.
    """
    banner = "=" * 50
    print(banner)
    print("RUNNING SPEEDUP EXPERIMENTS")
    print(banner)

    # (experiment function, plot title, x-axis label)
    plan = (
        (experiment_vary_nodes, "Effect of Number of Nodes on Speedup", "Number of Nodes"),
        (experiment_vary_edge_prob, "Effect of Edge Probability on Speedup", "Edge Probability"),
        (experiment_vary_remove, "Effect of Number of Nodes Removed on Speedup", "Number of Nodes Removed"),
    )

    collected = {}

    for runner, title, xlabel in plan:
        print(f"\n{title}")
        print("-" * len(title))

        frame, swept_col = runner()

        if frame.empty:
            print("No valid results obtained for this experiment.")
            continue

        # One summary line per swept value.
        print("\nResults Summary:")
        for _, rec in frame.iterrows():
            print(f"{xlabel}={rec[x_col_name(swept_col)]:.2f}: "
                  f"Speedup = {rec['mean']:.2f} ± {rec['sem']:.3f} "
                  f"(95% CI: [{rec['ci_lower']:.2f}, {rec['ci_upper']:.2f}]) "
                  f"n={rec['n_samples']}") if False else print(
                  f"{xlabel}={rec[swept_col]:.2f}: "
                  f"Speedup = {rec['mean']:.2f} ± {rec['sem']:.3f} "
                  f"(95% CI: [{rec['ci_lower']:.2f}, {rec['ci_upper']:.2f}]) "
                  f"n={rec['n_samples']}")

        # The image file name is derived from the title: lower-cased,
        # spaces turned into underscores, colons and commas removed.
        figure = plot_results(frame, swept_col, title, xlabel)
        stem = title.lower()
        for old, new in ((" ", "_"), (":", ""), (",", "")):
            stem = stem.replace(old, new)
        figure.savefig(stem + ".png", dpi=300, bbox_inches='tight')
        plt.show()

        collected[swept_col] = frame

    return collected

if __name__ == "__main__":
    # Execute every experiment, then persist each summary table as CSV.
    all_results = run_all_experiments()

    for param_name, table in all_results.items():
        csv_path = f"speedup_vs_{param_name}.csv"
        # The per-trial arrays in 'raw_data' are left out of the CSV.
        table.drop(columns='raw_data').to_csv(csv_path, index=False)
        print(f"Saved results to {csv_path}")