from pandasql import sqldf
import pandas as pd
from pathlib import Path
import numpy as np
import datetime
import matplotlib
import matplotlib.pyplot as plt
# Global matplotlib setting applied at import time: every figure produced by
# this module renders its text with the "Times New Roman" font family.
plt.rcParams["font.family"] = "Times New Roman"


def load_eval_files(args):
    """Load every evaluation CSV and stack the rows into one DataFrame.

    Args:
        args: Object exposing ``evaluation_files``, an iterable of CSV paths
            (anything ``pandas.read_csv`` accepts).

    Returns:
        The row-wise concatenation of all files, in the order given. Each
        file's original row index is kept, so index labels may repeat.
    """
    frames = list(map(pd.read_csv, args.evaluation_files))
    stacked = pd.concat(frames)
    return stacked
    
def string_to_table_name(s):
    """Turn a model identifier into a string usable inside SQL column names.

    Every ``-``, ``.`` and ``/`` in ``s`` is replaced by ``_``; all other
    characters are left untouched.
    """
    separators_to_underscore = str.maketrans("-./", "___")
    return s.translate(separators_to_underscore)

def format_column_name_for_latex(name):
    """Shorten a LaTeX-escaped result column name for use as a table header.

    The input is expected to be already LaTeX-escaped, i.e. underscores
    appear as ``\\_``. The rewrite rules are applied in order:

    * the ``\\_success\\_rate`` suffix is dropped,
    * the ``meta\\_llama\\_Meta\\_`` vendor prefix is dropped,
    * ``\\_chat\\_hf`` is shortened to ``\\_hf``,
    * the Mistral-7B-Instruct v0.1 and Claude-3-haiku identifiers are
      replaced by short display names.

    Args:
        name: LaTeX-escaped column label.

    Returns:
        The shortened label; unchanged when no rule matches.
    """
    # Raw strings: "\_" is not a valid escape sequence, and writing it in a
    # normal string literal triggers a SyntaxWarning on recent Pythons.
    rewrite_rules = (
        (r"\_success\_rate", ""),
        (r"meta\_llama\_Meta\_", ""),
        (r"\_chat\_hf", r"\_hf"),
        (r"mistralai\_Mistral\_7B\_Instruct\_v0\_1", r"Mistral\_7B\_Inst.1"),
        (r"claude\_3\_haiku\_20240307", r"Claude\_3\_haiku"),
    )
    for pattern, replacement in rewrite_rules:
        name = name.replace(pattern, replacement)
    return name


def reciprocal_rank_fusion(row, success_rate_columns, k=60):
    """Compute a reciprocal-rank-fusion score for one row.

    The score is ``sum(1 / (row[col] + k))`` over the given columns.

    Args:
        row: Mapping-like object (e.g. a DataFrame row) indexable by column
            name.
        success_rate_columns: Names of the columns to fuse. NOTE(review):
            despite the parameter name, the only (currently disabled) call
            site passes the ``*_success_rank`` columns, which is what the
            formula expects - confirm before re-enabling.
        k: Smoothing constant added to every rank. Defaults to 60, the value
            previously hard-coded in this function.

    Returns:
        The fused score; ``0`` when ``success_rate_columns`` is empty.
    """
    return sum(1 / (row[col] + k) for col in success_rate_columns)

def _benchmark_styler(exported_df, escape):
    """Build the heat-map styling shared by the LaTeX and HTML tables.

    Numeric cells are shaded on a fixed 0..1 "Reds" scale, missing values are
    blanked out, numbers are shown with two decimals, and column headers are
    passed through ``format_column_name_for_latex`` after escaping.
    """
    return (
        exported_df.style
        .background_gradient(axis=0, cmap="Reds", low=0, high=1, vmin=0, vmax=1)
        .highlight_null(props="background-color:white; color:white;")
        .format(precision=2, na_rep="", escape=escape)
        .format_index({0: format_column_name_for_latex}, escape=escape, axis=1)
    )


def main_analyse(args):
    """Aggregate evaluation results per program and target model, then export.

    For every distinct ``eval_target_model`` in the loaded evaluation rows,
    counts trials and "successes" (rows whose ``eval_harmful`` equals
    ``args.report_category``, compared case-insensitively and trimmed) per
    program, and derives a success rate and rank.

    Files written to ``args.output_dir`` (created if missing):

    * ``benchmark_<model>.csv`` - per-model table sorted by success rate,
    * ``benchmark.csv`` - all models outer-joined on the program name,
    * ``benchmark.tex`` / ``benchmark.html`` - styled success-rate tables,
    * whatever ``plot_exported_df`` produces.

    Args:
        args: Object exposing ``output_dir``, ``program_name_field`` and
            ``report_category``, plus the attributes read by
            ``load_eval_files`` and ``plot_exported_df``.

    Raises:
        ValueError: If the evaluation data contains no target model at all.
    """
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # NOTE: the SQL below refers to a table called `responses`. sqldf is called
    # without an explicit environment, so it has to find this DataFrame by its
    # variable name in this function - do not rename the local.
    responses = load_eval_files(args)
    program_name_field = args.program_name_field

    # Values interpolated into SQL string literals get their single quotes
    # doubled so an apostrophe in a name cannot break the query.
    # NOTE(review): identifiers (program_name_field, the model prefix) are
    # still interpolated verbatim and must be valid SQL column names.
    report_category = args.report_category.lower().strip().replace("'", "''")

    distinct_target_model_names = sqldf(
        "SELECT DISTINCT eval_target_model FROM responses"
    )["eval_target_model"].tolist()

    combined_df = None
    for target_model in distinct_target_model_names:
        target_model_prefix = string_to_table_name(target_model)
        target_model_literal = target_model.replace("'", "''")
        results_df = sqldf(f"""
WITH records AS (
    SELECT eval_target_model, {program_name_field}, eval_harmful,
           CASE WHEN TRIM(LOWER(eval_harmful)) = '{report_category}' THEN 1.0 ELSE 0.0 END AS is_harmful
    FROM responses
)
SELECT {program_name_field} as {program_name_field},
COUNT({program_name_field}) as {target_model_prefix}_trial_count,
SUM(is_harmful) as {target_model_prefix}_success_count,
AVG(is_harmful) as {target_model_prefix}_success_rate,
RANK () OVER (
        ORDER BY AVG(is_harmful) DESC
    ) as {target_model_prefix}_success_rank
FROM records
WHERE eval_target_model = '{target_model_literal}'
GROUP BY eval_target_model, {program_name_field}
    """)
        results_df = results_df.set_index(program_name_field)

        # Per-model export, best-performing programs first.
        results_df.sort_values(
            by=f"{target_model_prefix}_success_rate", ascending=False
        ).to_csv(output_dir / f"benchmark_{target_model_prefix}.csv")

        if combined_df is None:
            combined_df = results_df
        else:
            # Outer join keeps programs that were evaluated on only some models.
            combined_df = combined_df.join(
                results_df, on=program_name_field, how="outer"
            )

    if combined_df is None:
        raise ValueError(
            "No eval_target_model values found in the evaluation files; "
            "nothing to analyse."
        )

    success_rate_columns = [
        f"{string_to_table_name(m)}_success_rate"
        for m in distinct_target_model_names
    ]

    # The program name must be a regular column from here on: it is selected
    # by name for the exported table and read as a column by the plot.
    # (Ordering by a reciprocal-rank-fusion score over the *_success_rank
    # columns is currently disabled; see reciprocal_rank_fusion.)
    combined_df = combined_df.reset_index()
    combined_df.to_csv(output_dir / "benchmark.csv")

    exported_df = combined_df[[program_name_field] + success_rate_columns]
    plot_exported_df(exported_df, output_dir, args)

    # LaTeX table: headers are additionally rotated by 90 degrees.
    # NOTE(review): newer pandas releases rename Styler.applymap_index to
    # Styler.map_index - switch once the minimum pandas version allows it.
    latex_styler = _benchmark_styler(exported_df, escape="latex").applymap_index(
        lambda v: "rotatebox:{90}--rwrap--latex; transform: rotate(-90) translateX(-32px); height:100px; max-width:25px",
        axis=1,
    )
    latex_styler.to_latex(
        output_dir / "benchmark.tex",
        convert_css=True,
        hrules=True,
        caption=f"Benchmark Results. Autogenerated on {datetime.datetime.now().isoformat()}",
    )

    # HTML table.
    # NOTE(review): the header formatter looks for LaTeX-escaped names
    # ("\_"), which never occur with escape="html", so HTML headers are
    # presumably left unshortened - confirm whether that is intended.
    html_styler = _benchmark_styler(exported_df, escape="html")
    html_styler.to_html(output_dir / "benchmark.html")




def plot_exported_df(df, output_dir, args):
    """Draw the success-rate grid (programs x target models) and save it as PDF.

    Each (program, target model) pair present in ``df`` becomes one square
    marker coloured by its success rate on the "Reds" colour map. Only
    programs listed in ``args.program_names`` and columns listed in
    ``args.target_model_names`` are drawn; their order in those lists fixes
    the x and y positions respectively.

    Args:
        df: DataFrame with one program-name column and one column per target
            model holding the success rate.
        output_dir: Directory receiving ``benchmark_plot.pdf``.
        args: Object exposing ``target_model_names`` (column names of ``df``
            to plot), ``program_names`` and, optionally,
            ``program_name_field`` (name of the program column; defaults to
            ``"program_name"``).
    """
    # Use the same configurable program column as the rest of the pipeline
    # instead of assuming it is literally called "program_name".
    name_column = getattr(args, "program_name_field", "program_name")
    program_names = list(df[name_column])

    # Long-format records, one per (target model column, program) cell.
    data = []
    for target_model in df:
        if target_model == name_column:
            continue
        for program_name, asr in zip(program_names, df[target_model]):
            data.append({
                "target_model": target_model,
                "program_name": program_name,
                "asr": asr,
            })
    # Printed so the available column names can be compared against
    # args.target_model_names.
    print(sorted({d["target_model"] for d in data}))

    distinct_target_models = args.target_model_names
    distinct_target_models_to_ix = {
        m_name: ix for ix, m_name in enumerate(distinct_target_models)
    }
    distinct_target_models_display_names = [
        n.replace("_success_rate", "").replace("meta_llama_Meta_", "")
        for n in distinct_target_models
    ]

    distinct_programs = args.program_names
    distinct_programs_to_ix = {
        p_name: ix for ix, p_name in enumerate(distinct_programs)
    }
    distinct_programs_display_names = [
        n.replace("meta-llama/Meta-", "") for n in distinct_programs
    ]

    print("program names")
    for n in program_names:
        print(n)
    print()

    # Keep only cells whose program and model were both requested.
    selected = [
        d for d in data
        if d["program_name"] in distinct_programs_to_ix
        and d["target_model"] in distinct_target_models_to_ix
    ]
    plot_x = [distinct_programs_to_ix[d["program_name"]] for d in selected]
    plot_y = [distinct_target_models_to_ix[d["target_model"]] for d in selected]
    plot_c = [d["asr"] for d in selected]

    fig = plt.figure(figsize=(50, 7))
    try:
        plt.scatter(plot_x, plot_y, c=plot_c, cmap="Reds", marker="s", s=600)
        plt.xticks(
            [distinct_programs_to_ix[p] for p in distinct_programs],
            labels=distinct_programs_display_names,
            rotation=90,
        )
        plt.yticks(
            [distinct_target_models_to_ix[p] for p in distinct_target_models],
            labels=distinct_target_models_display_names,
            fontsize=50,
        )
        plt.tight_layout()
        plt.colorbar()
        plt.savefig(Path(output_dir) / "benchmark_plot.pdf", bbox_inches="tight")
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)