from datasets import load_dataset
from openicl import DatasetReader, PromptTemplate, PPLInferencer, AccEvaluator
from util.noise import gen_noise_classification
from util.template import gen_template_classification
from util.retriever import get_retriever
from util.verification import data_verification
from util.misc import setup_seed
from util.partition import (
    data2usename,
    data2numcls,
)

import argparse
import os
import json
import wandb
import socket

from huggingface_hub import login, HfApi, HfFolder

# Hugging Face Hub credentials. Read the token from the HF_TOKEN environment
# variable so a real token never has to be committed to source control; fall
# back to the original placeholder so behaviour is unchanged when it is unset.
token = os.environ.get("HF_TOKEN", "HFTOKEN")
login(token=token)


def run(args):
    """Run PPL-based in-context-learning inference and save predictions/labels.

    Loads and verifies the dataset, applies ``gen_noise_classification`` with
    ``p=0``, builds the prompt template and retriever, runs ``PPLInferencer``,
    and writes predictions and references as JSON files into ``args.log_dir``
    (``prediction_{tag}.json`` / ``label_{tag}.json`` where ``tag`` is
    ``"debug"`` if ``args.debug`` else ``"run"``).

    Args:
        args: parsed CLI namespace; reads ``dataset``, ``num_class``, ``debug``,
            ``seed``, ``model``, ``ice_num`` and ``log_dir`` (plus whatever
            ``gen_template_classification`` / ``get_retriever`` read from it).
    """
    # Seed up front so any randomness in dataset verification / noise
    # generation is reproducible as well.
    setup_seed(args.seed)

    dataset = load_dataset(data2usename[args.dataset])
    dataset = data_verification(dataset, data2numcls[args.dataset], debug=args.debug)

    # Gen Noise. p=0 presumably yields a clean (noise-free) label column --
    # NOTE(review): confirm in util.noise that it adds the "new_label" column.
    dataset = gen_noise_classification(
        dataset,
        num_class=args.num_class,
        p=0,
        split="train",
        dataname=args.dataset,
    )

    # Re-seed right before template/retriever construction so their RNG state
    # matches what it was when seeding happened only here, regardless of how
    # much randomness the preprocessing above consumed.
    setup_seed(args.seed)
    tp_dict = gen_template_classification(args)
    retriever_cls = get_retriever(args)

    # "gen-*" datasets keep their input text in "paraphrase" instead of "sentence".
    input_col = "paraphrase" if args.dataset.startswith("gen-") else "sentence"

    # DatasetReader.references is taken from the "new_label" column.
    data = DatasetReader(dataset, input_columns=[input_col], output_column="new_label")
    template = PromptTemplate(tp_dict, {input_col: "</text>"}, ice_token="</E>")
    retriever = retriever_cls(data, ice_num=args.ice_num)
    inferencer = PPLInferencer(
        model_name=args.model,
        output_json_filepath=args.log_dir,
    )
    predictions = inferencer.inference(
        retriever, ice_template=template, output_json_filepath=args.log_dir
    )

    # Save predictions and gold labels side by side for later scoring.
    tag = "debug" if args.debug else "run"
    prediction_file = f"{args.log_dir}/prediction_{tag}.json"
    with open(prediction_file, "w", encoding="utf-8") as f:
        json.dump(predictions, f)

    label_file = f"{args.log_dir}/label_{tag}.json"
    with open(label_file, "w", encoding="utf-8") as f:
        json.dump(data.references, f)


def print_score(args):
    """Load saved predictions and labels, compute accuracy, print and return it.

    Reads ``prediction_{tag}.json`` and ``label_{tag}.json`` from
    ``args.log_dir``, where ``tag`` is ``"debug"`` if ``args.debug`` is set
    and ``"run"`` otherwise (the same names ``run`` writes).

    Args:
        args: parsed CLI namespace; reads ``log_dir`` and ``debug``.

    Returns:
        Whatever ``AccEvaluator().score`` returns; it is also printed.

    Raises:
        FileNotFoundError: if either JSON file does not exist, e.g. when the
            script is invoked without ``--run`` for a configuration that has
            never been run.
    """
    tag = "debug" if args.debug else "run"
    prediction_file = f"{args.log_dir}/prediction_{tag}.json"
    label_file = f"{args.log_dir}/label_{tag}.json"

    with open(prediction_file, "r", encoding="utf-8") as f:
        prediction = json.load(f)
    with open(label_file, "r", encoding="utf-8") as f:
        reference = json.load(f)

    evaluator = AccEvaluator()
    score = evaluator.score(predictions=prediction, references=reference)
    print(score)
    return score


if __name__ == "__main__":
    # Named `parser` (not `args`) so it is not confused with the parsed namespace.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="sst2", help="Dataset name")
    parser.add_argument(
        "--model",
        type=str,
        default="EleutherAI/gpt-neo-2.7B",
        help="Pretrained LLM model name",
    )
    parser.add_argument(
        "--retriever", type=str, default="random", help="Retriever Type"
    )  # use 'bm25'
    parser.add_argument("--ice_num", default=8, type=int)
    parser.add_argument(
        "--log_dir",
        type=str,
        default="centr_icl_log",
        help="Logging directory",
    )
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--run", action="store_true")
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--proj_name", default="FL-ICL-debug")
    args = parser.parse_args()

    # Model ids like "org/name" become filesystem-safe "org_name".
    model = args.model.replace("/", "_")

    # The cache root is chosen per machine; only hosts whose name contains
    # "server_name" are configured.
    host_name = socket.gethostname()
    if "server_name" not in host_name.lower():
        raise ValueError("Check the server hostname for log_dir initialization.")
    cache_root = "cache/root/directory"

    args.log_dir = os.path.join(
        cache_root,
        args.log_dir,
        f"{args.dataset}/model={model}_retriever={args.retriever}_icenum={args.ice_num}/seed={args.seed}",
    )
    os.makedirs(args.log_dir, exist_ok=True)

    args.num_class = data2numcls[args.dataset]

    run_name = f"centr_{args.dataset}_model={model}_retriever={args.retriever}_icenum={args.ice_num}_seed={args.seed}"
    # Using the run as a context manager finishes it on exit, including when
    # inference or scoring raises, instead of leaving it unfinished.
    with wandb.init(config=args, project=args.proj_name, name=run_name) as wb_run:
        if args.run:
            run(args)

        scores = print_score(args)
        wb_run.log(scores)
