import os
import grp
import pandas as pd
from datasets import DatasetDict, load_from_disk

from adaboost import (
    calculate_avg_unweighted_error,
    calculate_avg_option_error,
    calculate_adaboost_alphas,
)


def _short_model_name(model_name):
    """Return the second ``/``-separated segment of *model_name*.

    Identifiers without a ``/`` (for example ``"gpt2"``) are returned whole
    instead of raising ``IndexError``.
    """
    segments = model_name.split("/")
    return segments[1] if len(segments) > 1 else segments[0]


def format_run_name(sargs, mode="strong"):
    """Build the run identifier string from the script arguments.

    The result is a sequence of ``_``-terminated fields (so it always ends
    with a trailing underscore) and is used by the save/load helpers in this
    file as a directory and file-name stem.

    Args:
        sargs: Argument namespace providing the attributes read below.
        mode: When equal to ``"strong"`` the strong model name and the
            ``is_easy_to_hard`` flag are included; for any other value they
            are omitted and those attributes are not read.

    Returns:
        The formatted run name.
    """
    fields = [_short_model_name(sargs.model_name)]
    if mode == "strong":
        # Only touched in strong mode, so other modes do not need these
        # attributes to exist on ``sargs``.
        fields.append(_short_model_name(sargs.strong_model_name))
        fields.append(f"{int(sargs.is_easy_to_hard)}")
    fields += [
        f"{sargs.dataset_name}",
        f"ar{sargs.adaboost_rounds}",
        f"ne{sargs.num_epochs}",
        f"lr{sargs.learning_rate}",
        f"tbe{int(sargs.is_token_based_error)}",
        f"wbt{int(sargs.is_weight_by_token)}",
        f"co{int(sargs.is_completion_only)}",
        f"pb{sargs.probability_bias}",
        f"tpws{sargs.token_prob_window_size}",
        f"ltk{sargs.logits_top_k}",
        f"cp{int(sargs.is_combine_probs)}",
        f"tkp{int(sargs.is_top_k_pooling)}",
        f"tl{sargs.test_limit}",
    ]
    # Every field, including the last one, is followed by an underscore.
    return "_".join(fields) + "_"


def save_tables(train_dataset, eval_dataset, transfer_dataset, sargs, table_dir):
    """Save the three splits as one ``DatasetDict`` in a new numbered folder.

    The folder is ``<table_dir>/<run name>/<k>`` where ``k`` is the smallest
    non-negative integer for which that path does not exist yet. After
    saving, the folder's group is set to ``sargs.grp_name`` and the
    group-write bit is added. Only the folder itself is changed, not its
    contents.

    Args:
        train_dataset: Stored under the ``"train"`` key.
        eval_dataset: Stored under the ``"eval"`` key.
        transfer_dataset: Stored under the ``"transfer"`` key.
        sargs: Argument namespace; supplies the run-name fields and
            ``grp_name``.
        table_dir: Root directory that holds one sub-directory per run name.
    """
    splits = DatasetDict(
        {"train": train_dataset, "eval": eval_dataset, "transfer": transfer_dataset}
    )

    run_name = format_run_name(sargs)
    run_dir = os.path.join(table_dir, run_name)

    # Probe 0, 1, 2, ... until an unused slot is found.
    slot = 0
    while os.path.exists(os.path.join(run_dir, f"{slot}")):
        slot += 1
    target = os.path.join(run_dir, f"{slot}")

    splits.save_to_disk(target)

    # uid -1 keeps the current owner; only the group is changed.
    group_id = grp.getgrnam(sargs.grp_name).gr_gid
    os.chown(target, -1, group_id)

    current_mode = os.stat(target).st_mode
    os.chmod(target, current_mode | 0o020)

    print(f"Dataset tables saved to {table_dir}/{run_name}")

def save_results(train_dataset, eval_dataset, transfer_dataset, sargs):
    """Compute per-round error metrics and write them as a text table.

    For every round ``t`` in ``0..sargs.adaboost_rounds`` the average
    unweighted and option errors are collected, combined with the AdaBoost
    alphas (prefixed with ``1.0``) into a DataFrame, and written to
    ``<sargs.w2s_folder>/results/<run name>.csv``. The file's group is then
    set to ``sargs.grp_name`` and the group-write bit is added.

    Args:
        train_dataset: Source for the train error and the alphas.
        eval_dataset: Source for the eval and "strong" mode errors.
        transfer_dataset: Source for the transfer option error.
        sargs: Argument namespace; supplies ``w2s_folder``,
            ``adaboost_rounds``, ``is_weight_by_token``, ``grp_name`` and the
            run-name fields.
    """
    result_path = os.path.join(
        sargs.w2s_folder, "results", format_run_name(sargs) + ".csv"
    )
    alphas = [1.0] + calculate_adaboost_alphas(
        train_dataset, sargs.adaboost_rounds, sargs.is_weight_by_token
    )

    # Insertion order here is the column order of the output table.
    metrics = {
        "avg_train_errors": [],
        "avg_eval_errors": [],
        "avg_eval_option_errors": [],
        "avg_transfer_option_errors": [],
        "avg_strong_errors": [],
        "avg_strong_option_errors": [],
    }
    data = {"adaboost_round": list(range(sargs.adaboost_rounds + 1))}

    for round_idx in range(sargs.adaboost_rounds + 1):
        train_err = calculate_avg_unweighted_error(
            train_dataset, round_idx, sargs.is_weight_by_token
        )
        metrics["avg_train_errors"].append(train_err)

        eval_err = calculate_avg_unweighted_error(
            eval_dataset, round_idx, sargs.is_weight_by_token
        )
        metrics["avg_eval_errors"].append(eval_err)

        eval_option_err = calculate_avg_option_error(eval_dataset, round_idx)
        metrics["avg_eval_option_errors"].append(eval_option_err)

        transfer_option_err = calculate_avg_option_error(transfer_dataset, round_idx)
        metrics["avg_transfer_option_errors"].append(transfer_option_err)

        strong_err = calculate_avg_unweighted_error(
            eval_dataset, round_idx, sargs.is_weight_by_token, mode="strong"
        )
        metrics["avg_strong_errors"].append(strong_err)

        strong_option_err = calculate_avg_option_error(
            eval_dataset, round_idx, mode="strong"
        )
        metrics["avg_strong_option_errors"].append(strong_option_err)

    data["adaboost_alphas"] = alphas
    data.update(metrics)

    # NOTE(review): the file has a .csv suffix but holds the whitespace-aligned
    # output of DataFrame.to_string, not comma-separated values - confirm that
    # downstream readers expect this before changing it.
    table_text = pd.DataFrame(data).to_string(index=False)
    with open(result_path, "w") as file:
        file.write(table_text)

    # uid -1 keeps the current owner; only the group is changed.
    group_id = grp.getgrnam(sargs.grp_name).gr_gid
    os.chown(result_path, -1, group_id)
    current_mode = os.stat(result_path).st_mode
    os.chmod(result_path, current_mode | 0o020)

def load_tables(sargs, table_dir):
    """Load the most recently numbered dataset tables for this run.

    Scans ``<table_dir>/<run name>/0``, ``/1``, ... until the first missing
    index and loads the directory just before it, i.e. the highest index of
    the contiguous sequence starting at 0.

    Args:
        sargs: Argument namespace supplying the run-name fields.
        table_dir: Root directory that holds one sub-directory per run name.

    Returns:
        A ``(train, eval, transfer)`` tuple taken from the loaded dataset.

    Raises:
        FileNotFoundError: If no numbered directory exists for this run.
    """
    run_dir = os.path.join(table_dir, format_run_name(sargs))

    next_free = 0
    while os.path.exists(os.path.join(run_dir, f"{next_free}")):
        next_free += 1

    if next_free == 0:
        # Without this guard the code would try to load ".../-1" and fail
        # with an error that hides the real cause.
        raise FileNotFoundError(f"No saved dataset tables found under {run_dir}")

    dataset = load_from_disk(os.path.join(run_dir, f"{next_free - 1}"))
    return dataset["train"], dataset["eval"], dataset["transfer"]