from pathlib import Path
import csv
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
# Process-wide matplotlib default: base font size used by every figure drawn
# after this module is imported.
plt.rcParams['font.size'] = 16



def average_score(programs):
    """Return the mean ``syn_score`` over *programs*.

    Each element of *programs* must be a mapping whose ``'syn_score'`` value
    is accepted by ``float()`` (for example a string read from a CSV file).
    The result is whatever ``np.mean`` returns for the list of parsed scores.
    """
    parsed_scores = map(float, (entry['syn_score'] for entry in programs))
    return np.mean(list(parsed_scores))

def average_top_kscore(programs, k):
    """Return the mean ``syn_score`` of the *k* best-scoring programs.

    The programs are ordered by their ``'syn_score'`` (parsed with
    ``float()``) from highest to lowest, the first *k* entries are kept via
    slicing, and their mean is computed by ``average_score``. The input
    sequence itself is not modified.
    """
    def score_of(program):
        return float(program['syn_score'])

    ranked = list(programs)
    ranked.sort(key=score_of, reverse=True)
    best = ranked[:k]
    return average_score(best)

def load_data_rows(args):
    """Read every evaluation CSV and return all of their rows in one list.

    Args:
        args: Object with an ``evaluation_files`` attribute, an iterable of
            paths to CSV files. Each file's first row is used as the header.

    Returns:
        A list with one dict per data row (as produced by
        ``csv.DictReader``), ordered by file and then by row within the
        file. Values are not converted.
    """
    rows = []
    for csv_path in args.evaluation_files:
        # newline='' leaves line-ending handling to the csv module, so line
        # breaks embedded in quoted fields are preserved exactly as written.
        with open(csv_path, newline='') as csv_file:
            rows.extend(csv.DictReader(csv_file))
    return rows

def plot_asr_distribution_by_condition(rows, args):
    """Save a row of ``syn_score`` histograms, one subplot per condition.

    Args:
        rows: Dicts that each carry a ``'condition'`` label and a
            ``'syn_score'`` value accepted by ``float()``.
        args: Object with an ``output_dir`` attribute; the figure is written
            to ``<output_dir>/asr_hist_by_condition.pdf``.

    The subplots share their y axis and use five equal-width bins over
    [0, 1]. The figure is closed after it has been saved.
    """
    conditions = sorted({row['condition'] for row in rows})
    bins = [0, 0.2, 0.4, 0.6, 0.8, 1]

    # squeeze=False always returns a 2-D array of axes. Without it a single
    # condition yields a bare Axes object, which cannot be indexed.
    fig, axes = plt.subplots(
        1,
        len(conditions),
        sharey=True,
        squeeze=False,
        figsize=(len(conditions) * 6, 5),
    )

    for ax, condition in zip(axes[0], conditions):
        condition_asrs = [
            float(row['syn_score']) for row in rows if row['condition'] == condition
        ]
        ax.hist(condition_asrs, bins=bins)
        ax.set_title(
            condition.replace("_tgt_", "\n") + f"\n({len(condition_asrs)} attacks)"
        )
        # The x axes are not shared, so the bin-edge ticks are applied to
        # every subplot rather than only to the last one.
        ax.set_xticks(bins)

    fig.tight_layout()
    fig.savefig(Path(args.output_dir) / "asr_hist_by_condition.pdf")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def _cumulative_average_curves(condition_rows, all_iterations, k):
    """Return (mean, top-k mean) score curves over all_iterations for one condition."""
    condition_iterations = {int(row['syn_iteration']) for row in condition_rows}
    average_rewards = []
    average_top_k_rewards = []

    for it in all_iterations:
        if it not in condition_iterations:
            # NaN placeholders keep both curves the same length as
            # all_iterations when this condition has no row at `it`.
            average_rewards.append(np.nan)
            average_top_k_rewards.append(np.nan)
            continue

        # Scores are cumulative: every program generated up to and including
        # this iteration contributes.
        programs_up_to_iteration = [
            row for row in condition_rows if int(row['syn_iteration']) <= it
        ]
        average_rewards.append(average_score(programs_up_to_iteration))
        average_top_k_rewards.append(
            average_top_kscore(programs_up_to_iteration, k=k)
        )

    return average_rewards, average_top_k_rewards


def _save_condition_curves(all_iterations, curves, color_palette, title, output_path):
    """Plot one line per condition on a new figure, save it, and close it."""
    fig = plt.figure(figsize=(20, 5))
    for condition_ix, curve in enumerate(curves.values()):
        plt.plot(
            all_iterations,
            curve,
            linewidth=3.0,
            color=color_palette[condition_ix],
        )
    plt.title(title)
    # plt.xticks(all_iterations)
    plt.xlabel("Iterations")
    plt.ylabel("Mean ASR")
    # Legend labels are paired with the lines in plotting order, which is the
    # insertion order of `curves`.
    plt.legend(list(curves))
    plt.savefig(output_path)
    # Release the figure so it does not stay open after it has been written.
    plt.close(fig)


def main_analyse_prog_synthesis(args):
    """Plot score statistics per condition for the given evaluation files.

    Args:
        args: Object providing
            ``output_dir`` (created if missing; all plots are written there),
            ``evaluation_files`` (read through ``load_data_rows``),
            ``color_palette`` (indexed by the position of each condition in
            sorted order) and ``plot_extension`` (file extension of the two
            line plots).

    Every row gets a ``'condition'`` key built from its ``'syn_algo'`` and
    ``'syn_examples_type'`` values, plus ``'synthesis_target'`` when that
    column exists. Three figures are then produced:

    * the per-condition score histogram
      (``plot_asr_distribution_by_condition``),
    * the cumulative mean score per iteration, and
    * the cumulative mean score of the top 20 programs per iteration.

    The two line plots are named after the stem of ``output_dir``.
    """
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    k = 20
    rows = load_data_rows(args)

    for row in rows:
        condition = row['syn_algo'] + '_' + row['syn_examples_type']
        if 'synthesis_target' in row:
            condition += "_tgt_" + row['synthesis_target']
        row['condition'] = condition

    plot_asr_distribution_by_condition(rows, args)

    # Group the rows once instead of re-filtering the full list per condition.
    rows_by_condition = {}
    for row in rows:
        rows_by_condition.setdefault(row['condition'], []).append(row)

    all_conditions = sorted(rows_by_condition)
    all_iterations = sorted({int(row['syn_iteration']) for row in rows})

    # Both dicts are filled in sorted condition order, which fixes the line
    # colours and the legend order of the plots below.
    average_rewards = {}
    average_top_k_rewards = {}
    for condition in all_conditions:
        mean_curve, top_k_curve = _cumulative_average_curves(
            rows_by_condition[condition], all_iterations, k
        )
        average_rewards[condition] = mean_curve
        average_top_k_rewards[condition] = top_k_curve

    _save_condition_curves(
        all_iterations,
        average_rewards,
        args.color_palette,
        "Mean Attack Success Rate (All Generated Programs)",
        output_dir / f"{output_dir.stem}_average_rewards.{args.plot_extension}",
    )
    _save_condition_curves(
        all_iterations,
        average_top_k_rewards,
        args.color_palette,
        f"Mean Attack Success Rate (Top {k} Generated Programs)",
        output_dir / f"{output_dir.stem}_average_top_{k}_rewards.{args.plot_extension}",
    )
            
