import numpy as np
import pandas as pd

# Default directory for the JSON-lines input files read by the loaders below
# (a relative path, so it is resolved against the current working directory).
data_path = "data"


def read_json_lines_to_df(file_path):
    """Read a UTF-8 JSON-lines file into a DataFrame.

    Each non-empty line of the file is parsed as one JSON record.

    Args:
        file_path: path of the file to read.

    Returns:
        A DataFrame with one row per record.
    """
    # The context manager closes the handle as soon as pandas has parsed it,
    # instead of leaving it open until garbage collection.
    with open(file_path, "r", encoding="utf8") as f:
        return pd.read_json(f, lines=True)


def get_sample_df(data_path=data_path, task="text_summarization"):
    """Load and join the per-sample files for ``task``.

    Reads three JSON-lines files from ``data_path``:
    ``{task}.txt``, ``{task}_result.txt`` and ``{task}_ppl.txt``. It then
    inner-joins them on ``["id", "watermark_processor"]``, so only rows
    present in all three files are kept.

    Args:
        data_path: directory containing the input files.
        task: file-name prefix of the task to load.

    Returns:
        The merged DataFrame.
    """
    merge_keys = ["id", "watermark_processor"]
    samples = read_json_lines_to_df(f"{data_path}/{task}.txt")
    results = read_json_lines_to_df(f"{data_path}/{task}_result.txt")
    ppl = read_json_lines_to_df(f"{data_path}/{task}_ppl.txt")

    df = pd.merge(samples, results, on=merge_keys)
    df = pd.merge(df, ppl, on=merge_keys)
    # NOTE: an optional "{task}_score.txt" file used to be joined here with
    # how="left" (it misses some rows); that join is currently disabled.
    return df


def get_bootstrap_df(data_path=data_path, task="machine_translation"):
    """Load ``{task}_bleu.txt`` from ``data_path`` as a DataFrame."""
    bleu_file = f"{data_path}/{task}_bleu.txt"
    return read_json_lines_to_df(bleu_file)

def extract_watermark_info(df, return_wp_list=False):
    """Add a human-readable ``show_wp_name`` column describing each watermark.

    Each string in ``df["watermark_processor"]`` is parsed by substring
    matching into two parts:

    - a reweighting-scheme label (LaTeX-style, e.g. the delta-reweight label);
    - a comma-separated list of watermark-key components (``ngram``,
      ``ngram(woh)``, ``skip``, ``keyset``, ``position``).

    The two parts are joined with a single space. When no key component
    matches, the name ends with a trailing space (e.g. ``"No Watermark "``).
    Code later in this file looks the baseline up under exactly that
    spelling.

    Args:
        df: DataFrame with a string ``watermark_processor`` column. It is not
            modified; a new frame is returned.
        return_wp_list: if True, also return the new ``show_wp_name`` column.

    Returns:
        The DataFrame with the added column. If ``return_wp_list`` is True,
        returns ``(df, show_wp)`` instead, where ``show_wp`` is the per-row
        ``show_wp_name`` Series (one entry per row, not de-duplicated).

    Raises:
        ValueError: if a processor string matches no known reweighting scheme,
            or a "John" processor string has no ``delta=<number>`` field.
    """
    import re

    delta_pattern = re.compile(r"delta=(\d+\.?\d*)")

    def get_reweight_name(wp_str):
        # The first matching substring wins, so the order of checks matters.
        # Raw strings keep the LaTeX backslashes without invalid-escape warnings.
        if "Delta" in wp_str:
            return r"$\delta$-reweight"
        if "Gamma" in wp_str:
            return r"$\gamma$-reweight"
        if "Beta" in wp_str:
            return r"$\beta$-reweight"
        if "John" in wp_str:
            match = delta_pattern.search(wp_str)
            if match is None:
                raise ValueError(
                    "Cannot find delta in watermark: {}".format(wp_str)
                )
            return rf"Soft($\delta$={match.group(1)})"
        if "ExpMinSampling" in wp_str:
            return "ExpMinSampling"
        if "InverseSampling" in wp_str:
            return "InverseSampling"
        if "Baseline" in wp_str:
            return "Baseline"
        if "Gumbel" in wp_str:
            return "Gumbel"
        if wp_str == "None":
            return "No Watermark"

        raise ValueError("Unknown watermark: {}".format(wp_str))

    def get_watermark_key_name(wp_str):
        wm_names = []
        if "NGramHashing" in wp_str:
            if "ignore_history:False" in wp_str:
                wm_names.append("ngram")
            else:
                # "woh" = the variant where history is not kept
                # (ignore_history is not False).
                wm_names.append("ngram(woh)")
        if "TokenSkipping" in wp_str:
            wm_names.append("skip")
        if "FixedKeySet" in wp_str:
            wm_names.append("keyset")
        if "PositionHashing" in wp_str:
            wm_names.append("position")
        return ",".join(wm_names)

    def map_wp_str(wp_str):
        return get_reweight_name(wp_str) + " " + get_watermark_key_name(wp_str)

    df = df.assign(show_wp_name=df["watermark_processor"].apply(map_wp_str))
    show_wp = df["show_wp_name"]
    if return_wp_list:
        return df, show_wp
    return df


def sample_df_2_stat(df, bootstrap=False, show_wp=None):
    """Summarise every float64 column of ``df`` per ``show_wp_name``.

    The frame is melted to long form and grouped by ``(show_wp_name, score)``,
    where ``score`` is the original column name. Each group is rendered as a
    ``"mean±err"`` string. The number of decimals is chosen from the error
    magnitude and is never negative.

    Args:
        df: DataFrame with a ``show_wp_name`` column; only float64 columns
            are summarised.
        bootstrap: if False, ``err`` is the standard error
            ``std / sqrt(count)``; if True, ``err`` is the raw sample std.
        show_wp: optional iterable of watermark names (list, tuple or Series)
            that selects and orders the output rows. Duplicates are dropped,
            keeping the first occurrence. ``None`` or an empty iterable keeps
            all rows.

    Returns:
        DataFrame of strings indexed by ``show_wp_name``, with one column per
        float64 score.
    """
    score_columns = [c for c in df.columns if df[c].dtype == np.float64]
    sdf = df.melt(
        id_vars=["show_wp_name"],
        value_vars=score_columns,
        var_name="score",
        value_name="value",
    )
    sdf = sdf.groupby(["show_wp_name", "score"]).agg(["mean", "std", "count"])

    def format_fn(x):
        mean = x["mean"]
        if not bootstrap:
            # Standard error of the mean.
            std = x["std"] / np.sqrt(x["count"])
        else:
            # Presumably the rows are bootstrap replicates, so the spread
            # itself is reported as the error.
            std = x["std"]
        if not np.isfinite(std):
            # E.g. a single-sample group, whose std is NaN.
            return f"{mean:.2f}±{std:.2f}"

        std = max(std, 1e-9)
        # Keep decimals down to about a third of the error, but never a
        # negative precision (which would be an invalid format spec).
        # Builtin max is required here: np.max(v, 0) treats 0 as an axis.
        useful_digits = max(-int(np.floor(np.log10(std / 3))), 0)

        fmt_str = f"{{:.{useful_digits}f}}±{{:.{useful_digits}f}}"
        return fmt_str.format(mean, std)

    sdf = sdf["value"].apply(format_fn, axis=1).unstack()
    # Explicit check: `if show_wp:` raises on a pandas Series.
    if show_wp is not None and len(show_wp) > 0:
        sdf = sdf.loc[list(dict.fromkeys(show_wp))]
    return sdf


def sample_df_2_stat_undectectable_exp(df, bootstrap=False, show_wp=None):
    """Print how each watermark's per-prompt metric means differ from the baseline.

    Rows are grouped by ``(show_wp_name, reference_id)`` to get per-prompt
    means and standard deviations. The baseline is the group named exactly
    ``"No Watermark "`` (note the trailing space). For every watermark name
    (sorted), four tables are printed:

    - Average Absolute Difference: ``|mean - baseline mean|`` over prompts;
    - Mean Squared Difference: ``(mean - baseline mean) ** 2`` over prompts;
    - Average Mean: the per-prompt means;
    - Average Standard Deviation: the per-prompt standard deviations.

    Each cell is ``"<mean over prompts>±<population std over prompts>"``
    with at least four decimals.

    Args:
        df: DataFrame containing ``show_wp_name``, ``reference_id`` and the
            metric columns listed in ``metric_columns`` below.
        bootstrap: unused; accepted for signature parity with
            ``sample_df_2_stat``.
        show_wp: unused; accepted for signature parity with
            ``sample_df_2_stat``.

    Returns:
        None. The results are printed to stdout.

    Raises:
        KeyError: if a required column or the baseline group is missing.
    """
    metric_columns = [
        "bertscore.precision",
        "bertscore.recall",
        "bertscore.f1",
        "ppl",
        "rouge1",
        "rouge2",
        "rougeL",
    ]
    wp_names = sorted(set(df["show_wp_name"]))

    df = df.loc[:, ["show_wp_name", "reference_id", *metric_columns]]
    grouped = df.groupby(["show_wp_name", "reference_id"])
    df_mean = grouped.mean()
    df_std = grouped.std()
    baseline = df_mean.loc["No Watermark ", :]

    def format_num(x):
        # np.mean / np.std on a pandas Series dispatch to Series.mean/std,
        # which skip NaN values.
        mean = np.mean(x)
        std = np.std(x)
        if std < 1e-8 or not np.isfinite(std):
            # Zero or undefined spread (e.g. an all-NaN metric): fixed precision.
            useful_digits = 4
        else:
            useful_digits = max(4, -int(np.floor(np.log10(std / 3))))
        fmt_str = f"{{:.{useful_digits}f}}±{{:.{useful_digits}f}}"
        return fmt_str.format(mean, std)

    avg_abs_diff = []  # average absolute difference to the baseline
    ms_diff = []  # mean squared difference to the baseline
    avg_std = []  # average per-prompt std
    avg_mean = []  # average per-prompt mean

    for wp_name in wp_names:
        cur_mean = df_mean.loc[wp_name, :]
        cur_std = df_std.loc[wp_name, :]
        # Subtraction aligns on reference_id.
        cur_diff = cur_mean - baseline

        avg_abs_diff.append(cur_diff.abs().apply(format_num, axis=0))
        ms_diff.append((cur_diff ** 2).apply(format_num, axis=0))
        avg_std.append(cur_std.apply(format_num, axis=0))
        avg_mean.append(cur_mean.apply(format_num, axis=0))

    def gather_res(series_list):
        # One row per watermark name, one column per metric.
        df_res = pd.concat(series_list, axis=1).T
        df_res.index = wp_names
        return df_res

    tables = [
        ("Average Absolute Difference:", gather_res(avg_abs_diff)),
        ("Mean Squared Difference:", gather_res(ms_diff)),
        ("Average Mean:", gather_res(avg_mean)),
        ("Average Standard Deviation:", gather_res(avg_std)),
    ]
    for title, table in tables:
        print("-" * 80)
        print(title)
        print(table.to_string())

    return None


def merge_stat_df(df1, df2):
    """Inner-join two statistic tables on their row index."""
    return df1.merge(df2, left_index=True, right_index=True)


if __name__ == "__main__":
    # Load the text-summarization samples, label each watermark, then print
    # the difference-from-baseline tables.
    summary_df = get_sample_df(task='text_summarization')
    tsdf, show_wp = extract_watermark_info(summary_df, return_wp_list=True)
    sample_df_2_stat_undectectable_exp(tsdf, show_wp=show_wp)
