"""
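This script runs the singleton Fed-ICL baseline: each client's local data alone serves as the in-context example pool. Inference is run on the query set once per client, and the per-client accuracies are reported, including the highest accuracy over clients (plus the mean and standard deviation).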
"""

from datasets import load_dataset
from openicl import DatasetReader, PromptTemplate, PPLInferencer, AccEvaluator
from util.noise import gen_noise_classification
from util.template import gen_template_classification
from util.retriever import get_retriever, get_fedretriever
from util.verification import data_verification

# from util.icl_ppl_fed_inferencer import PPLFedInferencer
from util.icl_ppl_fed_weak_budget_inferencer import PPLFedWeakBudgetInferencer
from util.misc import setup_seed
from util.partition import (
    data2usename,
    data2numcls,
    datasplit_subset,
    cls_noniid_partition,
    cls_iid_partition,
)

import argparse
import os
import json
import wandb
import numpy as np
import socket

from huggingface_hub import login, HfApi, HfFolder

# Hugging Face Hub authentication. Prefer the HF_TOKEN environment variable so a
# real credential never has to be committed to source control; the original
# placeholder is kept as the fallback for backward compatibility.
token = os.environ.get("HF_TOKEN", "HFTOKEN")
login(token=token)


def run(args):
    """Run singleton (per-client) ICL inference and save predictions/labels.

    For every client, only that client's own training partition is used as the
    in-context example pool, and PPL-based inference is run on the shared query
    split. Predictions and gold labels are written as JSON files under
    ``<args.log_dir>/client-<cid>/``; the original indices of the query split
    are written to ``<args.log_dir>/query_subset_orig_idxs.json``.

    Args:
        args: Parsed command-line namespace. Reads ``dataset``, ``partition``,
            ``num_clients``, ``major_classes_num`` (non-IID only),
            ``subset_num``, ``seed``, ``model``, ``local_ice_num``,
            ``log_dir`` and ``debug``; it is also forwarded to the template
            generator, the retriever factory and the inferencer.

    Raises:
        ValueError: If ``args.partition`` is neither ``"iid"`` nor ``"noniid"``.
    """
    dataset = load_dataset(data2usename[args.dataset])

    # "gen-*" datasets keep their input text in a "paraphrase" column.
    input_col = "paraphrase" if args.dataset.startswith("gen-") else "sentence"

    # Fall back to the validation split when the dataset ships no test split.
    test_split = "test" if "test" in dataset else "validation"

    # Partition the training data into clients.
    if args.partition == "iid":
        fed_dataset = cls_iid_partition(
            dataset=dataset,
            split="train",
            data_name=args.dataset,
            num_clients=args.num_clients,
            test_split=test_split,
            subset_num=args.subset_num,
        )
    elif args.partition == "noniid":
        fed_dataset = cls_noniid_partition(
            dataset=dataset,
            split="train",
            data_name=args.dataset,
            major_classes_num=args.major_classes_num,
            num_clients=args.num_clients,
            test_split=test_split,
            subset_num=args.subset_num,
        )
    else:
        raise ValueError(
            f"args.partition can only be 'iid' or 'noniid', rather than '{args.partition}'."
        )

    # The query set lives under the ``test_split`` key of the partitioned data.
    # Always index with that key: a literal "test" fails for datasets that only
    # provide a validation split.
    query_dataset = fed_dataset[test_split]

    # Save the original indices of the (possibly subsampled) query set.
    subset_orig_idxs_file = os.path.join(args.log_dir, "query_subset_orig_idxs.json")
    with open(subset_orig_idxs_file, "w") as f:
        json.dump(query_dataset["idx"], f)

    # Prompt template for classification.
    tp_dict = gen_template_classification(args)
    template = PromptTemplate(tp_dict, {input_col: "</text>"}, ice_token="</E>")

    # Build the inferencer once; it is reused for every client.
    setup_seed(args.seed)
    inferencer = PPLFedWeakBudgetInferencer(
        model_name=args.model, output_json_filepath=args.log_dir, args=args
    )

    # Loop-invariant values hoisted out of the per-client loop.
    retriever_cls = get_fedretriever(args)
    run_tag = "debug" if args.debug else "run"
    query_labels = query_dataset["label"]

    # Perform local ICL for each client on its own dataset.
    for cid in range(args.num_clients):
        local_reader = DatasetReader(
            fed_dataset[f"train-client{cid}"],
            input_columns=[input_col],
            output_column="label",
        )

        # Only this client's training data forms the example corpus, so the
        # retriever list handed to the federated inferencer has one member.
        fed_retrievers = [retriever_cls(local_reader, ice_num=args.local_ice_num)]

        local_output_dir = os.path.join(args.log_dir, f"client-{cid}")
        os.makedirs(local_output_dir, exist_ok=True)
        predictions = inferencer.inference(
            fed_retrievers,
            query_dataset=query_dataset,
            ice_template=template,
            concat="simple",
            budget_strategy="uniform",
            output_json_filepath=local_output_dir,
            args=args,
        )

        # ------ Save predictions and gold labels ------
        prediction_file = os.path.join(local_output_dir, f"prediction_{run_tag}.json")
        with open(prediction_file, "w") as f:
            json.dump(predictions, f)

        label_file = os.path.join(local_output_dir, f"label_{run_tag}.json")
        with open(label_file, "w") as f:
            json.dump(query_labels, f)


def print_score(prediction_file, label_file, prefix="Local"):
    """Compute, print and return the accuracy of saved predictions.

    Args:
        prediction_file: Path to a JSON file containing the predictions.
        label_file: Path to a JSON file containing the gold labels.
        prefix: Text printed before the score.

    Returns:
        The result of ``AccEvaluator.score`` for these predictions/labels.
    """
    with open(prediction_file, "r") as f:
        prediction = json.load(f)
    with open(label_file, "r") as f:
        reference = json.load(f)

    score = AccEvaluator().score(predictions=prediction, references=reference)
    print(f"{prefix}: {score}")
    return score


def collect_local_results(args):
    """Aggregate the per-client accuracies saved in ``args.log_dir``.

    For each client ``cid`` the prediction and label JSON files in
    ``<args.log_dir>/client-<cid>/`` are scored with ``print_score``.

    Returns:
        A dict with one ``client-<cid>-accuracy`` entry per client, followed by
        the ``max``, ``mean`` and ``std`` of those accuracies.
    """
    tag = "debug" if args.debug else "run"
    summary = {}
    accuracies = []
    for client_id in range(args.num_clients):
        client_dir = os.path.join(args.log_dir, f"client-{client_id}")
        acc = print_score(
            os.path.join(client_dir, f"prediction_{tag}.json"),
            os.path.join(client_dir, f"label_{tag}.json"),
            prefix=f"Client {client_id} local",
        )["accuracy"]
        summary[f"client-{client_id}-accuracy"] = acc
        accuracies.append(acc)

    # Summary statistics over clients (keyword order keeps key insertion order).
    summary.update(
        max=max(accuracies),
        mean=sum(accuracies) / args.num_clients,
        std=np.std(accuracies).item(),
    )
    return summary


if __name__ == "__main__":
    # Separate the parser from the parsed namespace instead of rebinding `args`.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="sst2", help="Dataset name")
    parser.add_argument(
        "--subset_num",
        type=int,
        default=None,
        help="Number of subset test set for query",
    )
    parser.add_argument("--partition", default="iid", choices=["iid", "noniid"], type=str)
    parser.add_argument("--num_clients", type=int, default=3)
    parser.add_argument("--major_classes_num", default=-1, type=int)
    parser.add_argument(
        "--model",
        type=str,
        default="EleutherAI/gpt-neo-2.7B",
        help="Pretrained LLM model name",
    )
    parser.add_argument(
        "--overall_local_ice_num",
        default=None,
        type=int,
        help="Optional choice, automatically assign local_ice_num based on overall_local_ice_num and num_clients",
    )
    parser.add_argument(
        "--server_ice_num",
        default=-1,
        type=int,
        help="Server side ICE Number, server_ice_num <= num_clients * local_ice_num",
    )
    parser.add_argument(
        "--retriever", type=str, default="topk", help="Server side Retriever Type"
    )  # use 'bm25'
    parser.add_argument(
        "--log_dir",
        type=str,
        default="fed_icl_log",
        help="Logging directory",
    )
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--run", action="store_true")
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--proj_name", default="FL-ICL-debug")
    parser.add_argument("--group_name", default=None, type=str)
    args = parser.parse_args()

    args.num_classes = data2numcls[args.dataset]
    # In this script the per-client ICE number is simply the overall one
    # (it is not split across clients).
    args.local_ice_num = args.overall_local_ice_num

    # ----- Prepare output locations -----
    model = args.model.replace("/", "_")
    part_name = f"{args.partition}_clients={args.num_clients}"
    if args.partition == "noniid":
        part_name += f"_majorclass={args.major_classes_num}"

    data_folder = args.dataset
    if args.subset_num is not None:
        data_folder += f"_query-num={args.subset_num}"

    # The cache root is only configured for a known host; fail fast elsewhere.
    host_name = socket.gethostname()
    if "server_name" not in host_name.lower():
        raise ValueError("Check the server hostname for log_dir initialization.")
    cache_root = "cache/root/directory"

    args.log_dir = os.path.join(
        cache_root,
        args.log_dir,
        f"{data_folder}/{part_name}/singleton_model={model}_retriever=fed{args.retriever}_overall-ice={args.overall_local_ice_num}_server-ice={args.server_ice_num}/seed={args.seed}",
    )
    os.makedirs(args.log_dir, exist_ok=True)

    # ----- Run the FL-ICL pipeline -----
    if args.run:
        run_name = f"singleton_{data_folder}_{part_name}_model={model}_retriever=fed{args.retriever}_overall-local-ice={args.overall_local_ice_num}_server-ice={args.server_ice_num}_seed={args.seed}"
        wb_run = wandb.init(
            config=args, project=args.proj_name, name=run_name, group=args.group_name
        )

        # Skip non-IID configurations whose total number of major-class slots
        # (major_classes_num * num_clients) is smaller than the number of
        # classes. Close the W&B run before exiting. `exit()` is a helper
        # injected by `site` for interactive use, so raise SystemExit instead
        # (same exit status 0).
        if (
            args.partition == "noniid"
            and args.major_classes_num * args.num_clients < args.num_classes
        ):
            wb_run.finish()
            raise SystemExit

        run(args)

    # Read the saved results and compute the performance summary.
    score_summary = collect_local_results(args)
    if args.run:
        wb_run.log(score_summary)
        wb_run.finish()
