import os
import polars as pl

import experiments.config

# Name of the results directory, resolved relative to the parent of this
# file's directory; each experiment's results live in
# ``<data_folder>/<config name>/data.csv``. Alternative value: "final_data".
data_folder: str = "data"



def print_ds(ds):
    """Print ``ds`` with polars' display limits raised so it is not truncated.

    Strings up to 1000 characters, and tables up to 1000 rows, 1000 columns
    and 1000 characters wide, are rendered in full.
    """
    display_limits = dict(
        fmt_str_lengths=1000,
        tbl_width_chars=1000,
        tbl_rows=1000,
        tbl_cols=1000,
    )
    with pl.Config(**display_limits):
        print(ds)

def show_alpha_temperature_by_candidate(configs, print_results=True, selected_candidate=10):
    """Print, per task, alpha at one candidate count with one column per temperature.

    For each config, the statistics from ``read_do_stats`` are formatted as
    ``"<alpha_mean>±<alpha_std>"`` (three decimals) and pivoted to one column
    per ``num_candidate``. Only the ``selected_candidate`` column is kept,
    renamed to the config's generation temperature. Configs sharing the same
    ``data_kwargs["ds_name"]`` are then full-joined on ``row_name`` into one
    table per task.

    Args:
        configs: iterable of experiment config dicts, each providing ``name``,
            ``data_kwargs["ds_name"]`` and ``generation_kwargs["temperature"]``.
        print_results: when True, print every per-task table. The
            ``Config: ...`` line for each config is printed regardless.
        selected_candidate: the ``num_candidate`` value whose column is kept.
            It must appear in every config's data, otherwise selecting the
            column fails.
    """
    selected_candidate_str = str(selected_candidate)
    # task -> list of two-column tables (row_name, <temperature>), in config order
    tables_by_task = {}

    for config in configs:
        print(f"Config: {config['name']}")

        temperature_str = str(config["generation_kwargs"]["temperature"])
        candidate_table = (
            read_do_stats(config)
            .with_columns(
                alpha_summ=pl.struct(["alpha_mean", "alpha_std"]).map_elements(
                    lambda s: f"{s['alpha_mean']:.3f}±{s['alpha_std']:.3f}",
                    return_dtype=pl.String,
                )
            )
            .pivot(
                index="row_name",
                on="num_candidate",
                values="alpha_summ",
                sort_columns=True,
                maintain_order=True,
            )
            .select(["row_name", selected_candidate_str])
            .rename({selected_candidate_str: temperature_str})
        )

        task = config["data_kwargs"]["ds_name"]
        tables_by_task.setdefault(task, []).append(candidate_table)

    for task, tables in tables_by_task.items():
        combined_table = tables[0]
        for table in tables[1:]:
            # coalesce=True keeps a single row_name column, so rows that exist
            # only in `table` keep their name. Without it, polars emits a
            # separate row_name_right column that the drop below would discard.
            combined_table = combined_table.join(
                table, on="row_name", how="full", coalesce=True, suffix="_right"
            )
            # Only a repeated temperature can collide here; keep the first one.
            combined_table = combined_table.drop(
                [col for col in combined_table.columns if col.endswith("_right")]
            )

        if print_results:
            print(f"Task: {task.split('/')[-1]}")
            print_ds(combined_table.sort("row_name"))


def read_do_stats(config):
    """Load a config's ``data.csv`` and aggregate alpha per (row_name, num_candidate).

    The file is read from
    ``<dir of this file>/../<data_folder>/<config["name"]>/data.csv``. Its
    ``alpha`` column is JSON-decoded and exploded, so each decoded element
    becomes its own row (presumably one sample per element; verify against
    the writer). ``row_name`` is built as ``"<algorithm>_<source>"``.

    Returns:
        A polars DataFrame with columns ``row_name``, ``num_candidate``,
        ``alpha_mean`` and ``alpha_std``. Note that ``alpha_std`` is the
        standard error of the mean (std divided by sqrt of the non-null
        count), not the raw standard deviation.
    """
    csv_path = os.path.join(
        os.path.dirname(__file__), "..", data_folder, config["name"], "data.csv"
    )
    raw = pl.read_csv(csv_path)
    per_sample = raw.with_columns(
        alpha=pl.col("alpha").str.json_decode(),
        row_name=pl.col("algorithm") + "_" + pl.col("source"),
    ).explode("alpha")
    grouped = per_sample.group_by(["row_name", "num_candidate"])
    return grouped.agg(
        alpha_mean=pl.mean("alpha"),
        alpha_std=pl.std("alpha") / pl.count("alpha").sqrt(),
    )


def show_alpha(config, print_results=True):
    """Summarise alpha for one config as a row_name x num_candidate table.

    Each cell is ``"<alpha_mean>±<alpha_std>"`` with three decimals, where
    ``alpha_std`` is the standard error computed by ``read_do_stats``. Columns
    are the ``num_candidate`` values in ascending order, and rows are sorted
    by ``row_name``.

    Args:
        config: experiment config dict with at least a ``name`` key.
        print_results: when True (default), print the table via ``print_ds``.
            This flag was previously ignored and the table was always printed.
            The ``Config: ...`` header line is printed regardless.

    Returns:
        The summary polars DataFrame.
    """
    print(f"Config: {config['name']}")
    stats = read_do_stats(config)

    # Pivot columns are named after the stringified num_candidate values;
    # select them explicitly in ascending num_candidate order.
    candidate_cols = stats["num_candidate"].unique().sort().cast(pl.String).to_list()

    summary = (
        stats.with_columns(
            alpha_summ=pl.struct(["alpha_mean", "alpha_std"]).map_elements(
                lambda s: f"{s['alpha_mean']:.3f}±{s['alpha_std']:.3f}",
                return_dtype=pl.String,
            )
        )
        .pivot(
            index="row_name",
            on="num_candidate",
            values="alpha_summ",
            sort_columns=True,
            maintain_order=True,
        )
        .select(["row_name"] + candidate_cols)
        .sort("row_name")
    )

    if print_results:
        print_ds(summary)
    return summary


def ablation_exp2_tables(selected_candidate=3):
    r"""Write pgfplots snippets of alpha versus temperature for ablation exp. 2.

    Statistics come from ``read_do_stats`` for every config in
    ``experiments.config.ablation_exp2 + experiments.config.ablation_exp2_a1``,
    restricted to ``num_candidate == selected_candidate``. One file
    ``tables/ablation_exp2_table_<short_ds_name>.tex`` is written per short
    dataset name.

    For each row, the file contains:
    - an invisible upper path (``alpha_mean + alpha_std``);
    - an invisible lower path (``alpha_mean - alpha_std``);
    - a ``fill between`` shading of the two;
    - the ``alpha_mean`` curve.

    Points are ordered by temperature. Rows whose name contains "theory" are
    written after all other rows.

    Args:
        selected_candidate: ``num_candidate`` value to plot. It used to be
            ignored (3 was hard-coded); the default keeps that value.
    """
    # data[short_ds_name][row_name][temperature] = (alpha_lower, alpha_upper, alpha)
    data = {}
    for c in experiments.config.ablation_exp2 + experiments.config.ablation_exp2_a1:
        out_ds = read_do_stats(c).filter(pl.col("num_candidate") == selected_candidate)
        temperature = c["generation_kwargs"]["temperature"]
        short_ds_name = experiments.config.short_name(c["data_kwargs"]["ds_name"])

        for row in out_ds.rows(named=True):
            mean, err = row["alpha_mean"], row["alpha_std"]
            data.setdefault(short_ds_name, {}).setdefault(row["row_name"], {})[
                temperature
            ] = (mean - err, mean + err, mean)

    def theory_last(items):
        # Stable sort: non-"theory" rows first, each group keeping its order.
        return sorted(items, key=lambda item: "theory" in item[0])

    def coordinates(row_data, index):
        # "(x, y) " per point (trailing space kept); index 0/1/2 = lower/upper/mean.
        return "".join(
            f"({x}, {values[index]:.3f}) " for x, values in row_data.items()
        )

    for short_ds_name, rows in data.items():
        table_str = ""
        for row_name, row_data in theory_last(rows.items()):
            # Points must follow increasing temperature along the x axis.
            row_data = dict(sorted(row_data.items(), key=lambda x: x[0]))
            table_str += rf"""%shading for {row_name}
\addplot[name path={row_name}_upper, draw=none] coordinates {{
"""
            table_str += coordinates(row_data, 1)
            table_str += rf"""
}};
\addplot[name path={row_name}_lower, draw=none] coordinates {{
"""
            table_str += coordinates(row_data, 0)
            table_str += rf"""
}};
\addplot[{row_name}_style, opacity=0.2] fill between[of={row_name}_upper and {row_name}_lower];

% Actual {row_name} plot
\addplot[{row_name}_style, mark indices={{1,2,3,4,5,13,20}}] coordinates {{
"""
            table_str += coordinates(row_data, 2)
            table_str += rf"""
}};

"""
        # exist_ok avoids the check-then-create race of an exists() test.
        os.makedirs("tables", exist_ok=True)
        with open(f"tables/ablation_exp2_table_{short_ds_name}.tex", "w") as f:
            f.write(table_str)

def ablation_exp3_tables():
    r"""Write pgfplots snippets of alpha versus number of candidates (ablation exp. 3).

    Statistics come from ``read_do_stats`` for every config in
    ``experiments.config.ablation_exp3``, covering all ``num_candidate``
    values. One file ``tables/ablation_exp3_table_<short_ds_name>.tex`` is
    written per short dataset name.

    For each row, the file contains:
    - an invisible upper path (mean + alpha_std);
    - an invisible lower path (mean - alpha_std);
    - a ``fill between`` shading of the two;
    - the mean curve.

    Points are sorted by ``num_candidate``. Rows whose name contains "theory"
    are written after all other rows.
    """
    # stats[short_ds_name][row_name][num_candidate] = (alpha_lower, alpha_upper, alpha)
    stats = {}
    for cfg in experiments.config.ablation_exp3:
        agg = read_do_stats(cfg)
        ds_key = experiments.config.short_name(cfg["data_kwargs"]["ds_name"])

        for rec in agg.rows(named=True):
            centre, spread = rec["alpha_mean"], rec["alpha_std"]
            per_row = stats.setdefault(ds_key, {}).setdefault(rec["row_name"], {})
            per_row[rec["num_candidate"]] = (centre - spread, centre + spread, centre)

    def fmt_points(points, slot):
        # Concatenate "(x, y) " pairs; slot 0/1/2 selects lower/upper/mean.
        return "".join(f"({x}, {triple[slot]:.3f}) " for x, triple in points.items())

    for ds_key, series in stats.items():
        # Stable sort puts "theory" rows last while preserving relative order.
        ordered = sorted(series.items(), key=lambda kv: "theory" in kv[0])
        chunks = []
        for label, points in ordered:
            points = dict(sorted(points.items(), key=lambda kv: kv[0]))
            chunks.append(
                rf"""%shading for {label}
\addplot[name path={label}_upper, draw=none] coordinates {{
"""
                + fmt_points(points, 1)
                + rf"""
}};
\addplot[name path={label}_lower, draw=none] coordinates {{
"""
                + fmt_points(points, 0)
                + rf"""
}};
\addplot[{label}_style, opacity=0.2] fill between[of={label}_upper and {label}_lower];

% Actual {label} plot
\addplot[{label}_style] coordinates {{
"""
                + fmt_points(points, 2)
                + rf"""
}};

"""
            )
        if not os.path.exists("tables"):
            os.makedirs("tables")
        with open(f"tables/ablation_exp3_table_{ds_key}.tex", "w") as fh:
            fh.write("".join(chunks))


def main_exp_tables(selected_candidate=3):
    r"""Write the main-experiment LaTeX table to ``tables/main_exp_table.tex``.

    For every config in ``experiments.config.main_exp``, alpha statistics at
    ``num_candidate == selected_candidate`` are grouped by model pair, draft
    sampling strategy, method and dataset.

    Every method is compared with the ``$\alpha^\ast$`` row of the same model
    pair and sampling strategy. The difference is divided by the combined
    standard error ``sqrt(alpha_std**2 + best_std**2)``:
    - values whose ratio is below 1.96 are bolded;
    - values at or above 1.96 get an up/down arrow on their difference.

    Args:
        selected_candidate: ``num_candidate`` value reported in the table. It
            used to be ignored (3 was hard-coded); the default keeps that value.
    """
    shown_ds_name_dict = {
        "alpaca": "Alpaca",
        "cnn_dailymail": "CNN-DailyMail",
        "wmt14-de2en": "WMT'14",
    }
    sample_strategy_dict = {
        "K_Seq_verify": r"With \\ Replacement",
        "K_Seq_theory": r"With \\ Replacement",
        "wr_optimal_theory": r"With \\ Replacement",
        "wr_recursive_verify": r"With \\ Replacement",
        "wor_optimal_theory": r"Without \\ Replacement",
        "wor_recursive_verify": r"Without \\ Replacement",
        "GCSpS_verify": r"Greedy",
        "GCSpS_theory": r"Greedy",
    }
    method_dict = {
        "K_Seq_verify": r"K-SEQ",
        "K_Seq_theory": r"$\alpha^{\opn{K-SEQ}}$",
        "wr_optimal_theory": r"$\alpha^\ast$",
        "wr_recursive_verify": r"RRS",
        "wor_optimal_theory": r"$\alpha^\ast$",
        "wor_recursive_verify": r"RRS",
        "GCSpS_verify": r"Verify",
        "GCSpS_theory": r"$\alpha^\ast$",
    }
    # Reference method that every row of a sampling strategy is compared with.
    best_method = r"$\alpha^\ast$"
    # |delta alpha| / combined standard error at or above this is significant.
    z_threshold = 1.96

    def to_shown_model_name(model_name):
        if "llama" in model_name:
            model_name = model_name.replace("llama", "LLaMA")
        if "opt" in model_name:
            model_name = model_name.replace("opt", "OPT")
        else:
            model_name = (
                model_name.replace("EAGLE-Vicuna-7B-v1.3", "Eagle-0.24B")
                .replace("EAGLE-Qwen2-7B-Instruct", "Eagle-0.26B")
                .replace("vicuna-7b-v1.3", "Vicuna-7B")
                .replace("Qwen2-7B-Instruct", "Qwen2-7B")
            )
        return model_name.replace("b", "B").replace("m", "M")

    # data[model_pair][sampling_strategy][method][shown_ds_name]= (alpha, alpha_std)
    data = {}

    for c in experiments.config.main_exp:
        out_ds = read_do_stats(c).filter(pl.col("num_candidate") == selected_candidate)
        shown_ds_name = shown_ds_name_dict[
            experiments.config.short_name(c["data_kwargs"]["ds_name"])
        ]
        model_pair = rf"{to_shown_model_name(experiments.config.short_name(c['model_kwargs']['draft_model_str']))} \\ {to_shown_model_name(experiments.config.short_name(c['model_kwargs']['target_model_str']))}"

        for row in out_ds.rows(named=True):
            row_name = row["row_name"]
            sampling_strategy = sample_strategy_dict[row_name]
            method = method_dict[row_name]
            data.setdefault(model_pair, {}).setdefault(
                sampling_strategy, {}
            ).setdefault(method, {})[shown_ds_name] = (
                row["alpha_mean"],
                row["alpha_std"],
            )

    # reorder: model pairs by first character of the shown draft name
    # (presumably OPT / LLaMA / Eagle drafts); unknown prefixes go last
    # instead of raising KeyError.
    model_pair_rank = {"O": 0, "L": 1, "E": 2}
    data = dict(
        sorted(
            data.items(),
            key=lambda x: model_pair_rank.get(x[0][0], len(model_pair_rank)),
        )
    )
    method_order = [
        "RRS",
        "K-SEQ",
        r"$\alpha^{\opn{K-SEQ}}$",
        "Verify",
        r"$\alpha^\ast$",
    ]
    ds_order = ["Alpaca", "CNN-DailyMail", "WMT'14"]
    for model_pair in data:
        data[model_pair] = dict(
            sorted(
                data[model_pair].items(),
                key=lambda x: (
                    {"With": 0, "Without": 1, "Greedy": 2}[x[0].split(" ")[0]]
                ),
            )
        )
        for sampling_strategy in data[model_pair]:
            data[model_pair][sampling_strategy] = dict(
                sorted(
                    data[model_pair][sampling_strategy].items(),
                    key=lambda x: (method_order.index(x[0])),
                )
            )
            for method in data[model_pair][sampling_strategy]:
                data[model_pair][sampling_strategy][method] = dict(
                    sorted(
                        data[model_pair][sampling_strategy][method].items(),
                        key=lambda x: (ds_order.index(x[0])),
                    )
                )

    table_str = (
        rf"""
\begin{{tabular}}{{lllcccccc}}
\toprule
\multirow{{2}}{{*}}{{Model Pairs}} & \multirow{{2}}{{*}}{{Draft Sampling}} & \multirow{{2}}{{*}}{{Method}} & """
        + " & ".join([rf"\multicolumn{{2}}{{c}}{{{ds}}}" for ds in ds_order])
        + rf"""\\

"""
        + "".join(
            [
                rf"\cmidrule(r){{{4 + 2 * i} - {5 + 2 * i}}}"
                for i in range(len(ds_order))
            ]
        )
        + rf"""
&  &  & """
        + " & ".join([r"$\alpha$ & $\Delta \alpha$" for _ in ds_order])
        + rf""" \\
\midrule
"""
    )
    col_num = 3 + 2 * len(ds_order)
    for pair_idx, (model_pair, strategies) in enumerate(data.items()):
        method_pair_row_num = sum(len(methods) for methods in strategies.values())
        col1 = (
            rf"\multirow{{{method_pair_row_num}}}{{*}}{{\makecell[l]{{{model_pair}}}}}"
        )
        for strategy_idx, (sampling_strategy, methods) in enumerate(strategies.items()):
            sampling_strategy_row_num = len(methods)
            col2 = rf"\multirow{{{sampling_strategy_row_num}}}{{*}}{{\makecell[l]{{{sampling_strategy}}}}}"
            # Explicit reference row (previously read through a stale loop variable).
            best_alphas = methods[best_method]
            for method_idx, (method, alphas_by_ds) in enumerate(methods.items()):
                row_str = ""
                # Separate model pairs with \midrule, strategies with \cmidrule.
                if pair_idx > 0 and strategy_idx == 0 and method_idx == 0:
                    row_str += r"""\midrule
"""
                if strategy_idx > 0 and method_idx == 0:
                    row_str += rf"""\cmidrule(r){{2-{col_num}}}"""
                # Multirow cells are emitted only on the first row they span.
                if strategy_idx == 0 and method_idx == 0:
                    row_str += col1
                row_str += " & "
                if method_idx == 0:
                    row_str += col2
                row_str += f" & {method} & "
                numbers = []
                for ds_name in ds_order:
                    alpha, alpha_std = alphas_by_ds[ds_name]
                    best_alpha, best_std = best_alphas[ds_name]
                    alpha_diff = alpha - best_alpha
                    std = (alpha_std**2 + best_std**2) ** 0.5
                    if std > 0:
                        significance = abs(alpha_diff) / std
                    else:
                        # No uncertainty: equal values are not significant,
                        # any gap is. Avoids ZeroDivisionError.
                        significance = 0.0 if alpha_diff == 0 else float("inf")
                    significant = significance >= z_threshold

                    if significant:
                        numbers.append(rf"${100*alpha:.1f} \pm {100*alpha_std:.1f}$")
                    else:
                        numbers.append(rf"$\bm{{{100*alpha:.1f}}} \pm {100*alpha_std:.1f}$")
                    if method == best_method:
                        numbers.append("-")
                    elif significant:
                        diff_str = (
                            rf"${100*alpha_diff:.1f}"
                            + (r"\uparrow" if alpha_diff > 0 else r"\downarrow")
                            + "$"
                        )
                        numbers.append(diff_str)
                    else:
                        numbers.append(rf"${100*alpha_diff:.1f}$")
                row_str += " & ".join(numbers)
                row_str += rf""" \\
"""
                table_str += row_str
    table_str += rf"""
\bottomrule
\end{{tabular}}
"""
    # The directory may not exist if this runs without the ablation builders.
    os.makedirs("tables", exist_ok=True)
    with open("tables/main_exp_table.tex", "w") as f:
        f.write(table_str)


def analyze_all():
    """Regenerate all LaTeX outputs under ``tables/``.

    Runs the ablation-2, ablation-3 and main-experiment table builders in that
    order. For ad-hoc inspection, ``show_alpha(experiments.config.debug_expN)``
    (N = 1..6) can be called here; those calls are not run by default.
    """
    for build_tables in (ablation_exp2_tables, ablation_exp3_tables, main_exp_tables):
        build_tables()


if __name__ == "__main__":
    # Script entry point: rebuild every table via analyze_all().
    analyze_all()
