import sys
import subprocess
import argparse
import csv
import os
import matplotlib.pyplot as plt
import numpy as np
from statistics import mean, stdev

def run_experiment(party, size, fuzzy_lengths, num_runs=10):
    """Run ``./bin/test_fix`` repeatedly and collect its CSV result line.

    Each run invokes the binary with ``party``, ``size`` and ``fuzzy_lengths``
    as positional arguments and captures stdout and stderr together. From the
    output it keeps the last line that contains a comma and does not start
    with the ``size,`` header. It removes " MB"/"MB" unit suffixes from that
    line, then splits it into fields and strips surrounding whitespace from
    each field.

    Args:
        party: Party number forwarded to the binary.
        size: Database size forwarded to the binary.
        fuzzy_lengths: Fuzzy key length forwarded to the binary. It is
            converted with ``str``, so ``None`` is passed as ``"None"``.
        num_runs: How many times to run the binary.

    Returns:
        A list with one list of string fields per successful run, or ``None``
        if no run produced a CSV line.
    """
    results = []
    print(f"Running test with party={party}, size={size}, {num_runs} times")

    # Use an argv list instead of a shell string. This avoids shell parsing
    # and injection, and forwards arguments verbatim.
    cmd = ["./bin/test_fix", str(party), str(size), str(fuzzy_lengths)]

    for run in range(num_runs):
        print(f"  Run {run+1}/{num_runs}")
        print(f"    Command: {' '.join(cmd)}")

        try:
            result = subprocess.run(cmd, check=True,
                                    stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                                    universal_newlines=True)
        except subprocess.CalledProcessError as e:
            print(f"    Error in run {run+1}: {e}")
            print(f"    Output: {e.output}")
            continue
        except OSError as e:
            # Without a shell, a missing or non-executable binary raises here
            # instead of showing up as a non-zero exit status.
            print(f"    Error in run {run+1}: {e}")
            continue

        output_lines = result.stdout.strip().split('\n')

        # The result row is the last comma-separated line that is not the header.
        csv_line = None
        for line in reversed(output_lines):
            if ',' in line and not line.startswith('size,'):
                csv_line = line
                break

        if csv_line:
            cleaned = csv_line.replace(' MB', '').replace('MB', '')
            values = [field.strip() for field in cleaned.split(',')]
            results.append(values)
            print(f"    Run {run+1} successful")
        else:
            print(f"    Error: Could not find CSV data in output for run {run+1}")

    if not results:
        print(f"All runs failed for size={size}")
        return None

    print(f"Completed {len(results)}/{num_runs} successful runs for size={size}")
    return results

def calculate_statistics(multiple_results):
    """Compute per-column mean and sample standard deviation over several runs.

    Args:
        multiple_results: List of runs. Each run is a list of field values,
            usually the strings parsed by ``run_experiment``.

    Returns:
        ``{"mean": [...], "std": [...]}`` with one entry per column, or
        ``None`` if ``multiple_results`` is empty or ``None``.

        For each column, mean and std are computed over the values that parse
        as floats. If no value in a column parses, the first run's raw value
        is used as the "mean" and the std is 0. The std is also 0 when a
        column has fewer than two numeric values.
    """
    if not multiple_results:
        return None

    def _as_float(val):
        # None for values float() cannot handle, e.g. labels or None.
        try:
            return float(val)
        except (TypeError, ValueError):
            return None

    # Size the output by the longest row, so ragged runs (one run printing an
    # extra field) don't raise IndexError. Shorter rows skip the missing columns.
    num_cols = max(len(row) for row in multiple_results)

    means = []
    stds = []
    for col in range(num_cols):
        raw = [row[col] for row in multiple_results if col < len(row)]
        numeric = [value for value in map(_as_float, raw) if value is not None]
        if numeric:
            means.append(mean(numeric))
            stds.append(stdev(numeric) if len(numeric) > 1 else 0)
        else:
            # Non-numeric column: carry the first run's value through unchanged.
            means.append(raw[0])
            stds.append(0)

    return {"mean": means, "std": stds}

def save_results_to_csv(results_data, headers, filename):
    """Write every run plus its mean/std rows to ``filename`` as CSV.

    The header row is ``headers`` followed by a ``run_type`` column. For each
    truthy entry of ``results_data``, one row is written per run, tagged
    ``run_1``, ``run_2``, and so on. If the entry has a ``"stats"`` key, a
    ``mean`` row and a ``std`` row follow.
    """
    with open(filename, 'w', newline='') as handle:
        out = csv.writer(handle)
        out.writerow(headers + ["run_type"])
        for entry in filter(None, results_data):
            out.writerows(
                row + [f"run_{idx}"]
                for idx, row in enumerate(entry["runs"], start=1)
            )
            if "stats" not in entry:
                continue
            summary = entry["stats"]
            out.writerow(summary["mean"] + ["mean"])
            out.writerow(summary["std"] + ["std"])

    print(f"Results written to {filename}")

def plot_seperate(party, sizes, results_data, output_filename=None):
    """Plot separate graphs for time and communication with error bars.

    Draws two side-by-side stacked bar charts, with one bar per database size:

    * Left panel (time): matching time (stats column 9) stacked with model
      repair time (columns 3 + 4).
    * Right panel (communication): matching communication (column 8) stacked
      with model-repair communication (columns 6 + 7). Communication columns
      are divided by 1024 and labelled GB.

    Error bars combine the per-column standard deviations in quadrature.

    Args:
        party: Party number. Party 2 is not supported; nothing is plotted.
        sizes: Bar positions / x tick values, one per entry of ``results_data``.
        results_data: List of dicts whose ``"stats"`` entry holds ``"mean"``
            and ``"std"`` lists, as built by ``calculate_statistics``.
        output_filename: If given, save the figure there and close it;
            otherwise show it.
    """
    if party == 2:
        print("Party 2 is not supported for plots.")
        return
    if not results_data:
        print("No results to plot.")
        return
    if len(sizes) != len(results_data):
        print(f"Cannot plot: got {len(sizes)} sizes but {len(results_data)} result sets.")
        return

    def _has_stats(data):
        stats = data.get("stats") if isinstance(data, dict) else None
        return isinstance(stats, dict) and "mean" in stats and "std" in stats

    if not all(_has_stats(data) for data in results_data):
        print("Cannot plot: some sizes are missing mean/std statistics.")
        return

    means = [data["stats"]["mean"] for data in results_data]
    stds = [data["stats"]["std"] for data in results_data]

    def column(rows, idx, scale=1.0):
        # One value per size, taken from stats column ``idx`` and divided by ``scale``.
        return np.array([row[idx] / scale for row in rows], dtype=float)

    # Matching (fuzzy) phase: time in column 9, communication in column 8.
    fuzzy_time = column(means, 9)
    fuzzy_time_std = column(stds, 9)
    fuzzy_comm = column(means, 8, 1024.0)
    fuzzy_comm_std = column(stds, 8, 1024.0)

    # Model repair: time = columns 3 + 4, communication = columns 7 + 6.
    repair_time = column(means, 3) + column(means, 4)
    repair_time_std = np.sqrt(column(stds, 3) ** 2 + column(stds, 4) ** 2)
    repair_comm = column(means, 7, 1024.0) + column(means, 6, 1024.0)
    repair_comm_std = np.sqrt(column(stds, 7, 1024.0) ** 2 + column(stds, 6, 1024.0) ** 2)

    total_time = fuzzy_time + repair_time
    total_comm = repair_comm + fuzzy_comm

    # Bar width is 70% of the smallest gap between neighbouring sizes. This
    # avoids the IndexError a single size caused and the negative width that
    # unsorted sizes caused. With one distinct size, fall back to 70% of it.
    ordered = sorted(set(sizes))
    gaps = [b - a for a, b in zip(ordered, ordered[1:])]
    width = min(gaps) * 0.7 if gaps else (abs(ordered[0]) * 0.7 or 0.7)

    plt.rcParams['text.usetex'] = True
    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['font.serif'] = 'Times'

    fig, (ax_time, ax_comm) = plt.subplots(1, 2, figsize=(2*12, 7.4))
    small_font = 32
    big_font = 38
    positions = np.array(sizes)

    ax_time.bar(positions, fuzzy_time, width, label='Matching',
                yerr=fuzzy_time_std, capsize=5)
    ax_time.bar(positions, repair_time, width, label='Model Repair',
                bottom=fuzzy_time, yerr=repair_time_std, capsize=5)

    ax_comm.bar(positions, fuzzy_comm, width, label='Matching',
                yerr=fuzzy_comm_std, capsize=5)
    ax_comm.bar(positions, repair_comm, width, label='Model Repair',
                bottom=fuzzy_comm, yerr=repair_comm_std, capsize=5)

    for ax, ylabel in ((ax_time, 'Time (seconds)'), (ax_comm, 'Communication (GB)')):
        ax.set_xlabel('Database Size', fontsize=big_font)
        ax.set_ylabel(ylabel, fontsize=big_font)
        ax.tick_params(axis='both', which='major', labelsize=small_font)
        ax.set_xticks(sizes)
        ax.legend(loc='upper left', fontsize=small_font)

    # Leave 13% headroom above the tallest stack for the total labels.
    ax_time.set_ylim(0, 1.13 * max(total_time))
    ax_comm.set_ylim(0, 1.13 * max(total_comm))

    for pos, t, c in zip(positions, total_time, total_comm):
        ax_time.text(pos, t, f'{t:.1f}s', ha='center', va='bottom', size=small_font)
        ax_comm.text(pos, c, f'{c:.1f}GB', ha='center', va='bottom', size=small_font)

    plt.tight_layout()
    plt.subplots_adjust(wspace=0.2)

    if output_filename:
        plt.savefig(output_filename)
        print(f"Plot saved to {output_filename}")
        # A new figure is created on every call; release it once saved.
        plt.close(fig)
    else:
        plt.show()

def _parse_stat_value(val):
    """Return ``val`` converted to float, or unchanged if it is not numeric."""
    try:
        return float(val)
    except ValueError:
        return val


def _load_results_csv(csv_filename):
    """Read a CSV written by ``save_results_to_csv`` back into results_data form.

    Rows are grouped by consecutive values of their first field (the size).
    The trailing ``run_type`` field decides where each row goes:

    * ``run_<i>`` rows are appended, as strings, to the group's ``"runs"``.
    * ``mean``/``std`` rows become ``stats["mean"]``/``stats["std"]``, with
      numeric fields converted to float.

    Rows with any other ``run_type`` are ignored, and blank rows are skipped.

    Raises:
        FileNotFoundError: If ``csv_filename`` does not exist.
    """
    results_data = []
    current_data = None

    with open(csv_filename, 'r', newline='') as csvfile:
        reader = csv.reader(csvfile)
        next(reader, None)  # header row; the default tolerates an empty file
        for row in reader:
            if not row:
                continue

            size = int(float(row[0]))
            run_type = row[-1]
            is_stat = run_type in ("mean", "std")
            if not (is_stat or run_type.startswith("run_")):
                continue

            # Start a new group whenever the size changes. This matches the
            # runs-then-mean-then-std order that save_results_to_csv writes.
            if current_data is None or current_data["size"] != size:
                if current_data is not None:
                    results_data.append(current_data)
                current_data = {"size": size, "runs": []}

            if is_stat:
                # Mean/std rows can hold non-numeric fields (a column whose
                # first-run value was not a number); keep those as strings.
                current_data.setdefault("stats", {})[run_type] = [
                    _parse_stat_value(val) for val in row[:-1]
                ]
            else:
                current_data["runs"].append(row[:-1])

    if current_data is not None:
        results_data.append(current_data)
    return results_data


def main():
    """Command-line entry point: run experiments (or load a CSV), then plot.

    Without ``--plot-only``, runs ``run_experiment`` for every requested size,
    saves all runs and their mean/std rows to the CSV file, and plots the
    sizes that produced results. With ``--plot-only``, loads previously saved
    results from the CSV file and plots them.
    """
    parser = argparse.ArgumentParser(description='Run test_fix with multiple sizes and save results to CSV')
    parser.add_argument('party', type=int, choices=[1, 2], help='Party number (1 or 2)')
    parser.add_argument('sizes', type=int, nargs='+', help='List of sizes to test')
    parser.add_argument('--fuzzy', type=int, help='length of the fuzzy keys')
    parser.add_argument('--plot-only', action='store_true', help='Only plot results from existing CSV')
    parser.add_argument('--csv-filename', type=str, help='CSV filename for existing results')
    parser.add_argument('--output', type=str, help='Output filename for the plot')
    parser.add_argument('--runs', type=int, default=10, help='Number of runs for each experiment (default: 10)')
    args = parser.parse_args()

    party = args.party
    sizes = args.sizes
    num_runs = args.runs
    fuzzy_lengths = args.fuzzy

    exp_name = f"party:{party}_sizes:{sizes}_fuzzy:{args.fuzzy}"
    csv_filename = args.csv_filename if args.csv_filename else exp_name + '.csv'

    if party == 1:
        # NOTE(review): plot_seperate reads stats column 8 as matching
        # communication and column 9 as matching time, so the header lists
        # ten columns. Confirm against test_fix's party-1 output.
        headers = ['size', 'entry_size', 'total_time', 'recv_db_time', 'recv_keys_time',
                   'decrypt_time', 'ot_comm', 'db_comm', 'fuzzy_comm', 'fuzzy_time']
    else:
        headers = ['size', 'entry_size', 'total_time', 'encrypt_time', 'send_db_time', 'send_keys_time', 'comm']

    if args.plot_only:
        try:
            results_data = _load_results_csv(csv_filename)
        except FileNotFoundError:
            print(f"Error: CSV file {csv_filename} not found.")
            return
        print(f"Loaded statistics for {len(results_data)} sizes from {csv_filename}")
    else:
        results_data = []
        for size in sizes:
            multiple_results = run_experiment(party, size, fuzzy_lengths, num_runs)
            if multiple_results:
                stats = calculate_statistics(multiple_results)
                results_data.append({
                    "size": size,
                    "runs": multiple_results,
                    "stats": stats,
                })
                print(f"Successfully collected {len(multiple_results)} results for size={size}")
            else:
                print(f"No valid results for size={size}")

        save_results_to_csv(results_data, headers, csv_filename)

    # Plot only sizes that have results. Failed sizes are missing from
    # results_data, and the bar positions must line up with it one-to-one.
    sizes = [data["size"] for data in results_data]
    if not results_data:
        print("No results to plot.")
        return

    output_filename = args.output if args.output else exp_name + ".png"
    plot_seperate(party, sizes, results_data, output_filename)

if __name__ == "__main__":
    # Run the command-line workflow only when executed as a script, so the
    # functions in this module can be imported without side effects.
    main()