import torch

from datasets import concatenate_datasets, load_dataset, DatasetDict, Dataset
from accelerate import Accelerator
from rank_bm25 import BM25Okapi
import numpy as np
from tqdm import trange
from nltk.tokenize import word_tokenize

from openicl.icl_retriever import BaseRetriever, BM25Retriever
from openicl.utils.check_type import _check_str
from openicl import DatasetReader, PromptTemplate, PPLInferencer, AccEvaluator

import argparse
import json
import os
import wandb
import socket
import copy

from util.template import gen_template_classification
from util.retriever import get_retriever, get_fedretriever
from util.partition import (
    cls_iid_partition,
    cls_noniid_partition,
    datasplit_subset,
    data2usename,
    data2numcls,
)
from util.icl_bm25_fedretriever import BM25FedRetriever
from util.icl_topk_fedretriever import TopkFedRetriever
from util.icl_ppl_fed_inferencer import PPLFedInferencer
from util.icl_ppl_fed_opt_budget_inferencer import PPLFedOptBudgetInferencer
from util.misc import setup_seed


def print_score(args, pred_file, label_file):
    """Compute, print and return the accuracy of saved predictions.

    Args:
        args: Unused; kept so existing call sites keep working.
        pred_file: Path to a JSON file holding the predictions.
        label_file: Path to a JSON file holding the reference labels.

    Returns:
        Whatever ``AccEvaluator.score`` returns for these predictions/references.
    """
    with open(pred_file, "r", encoding="utf-8") as f:
        predictions = json.load(f)
    with open(label_file, "r", encoding="utf-8") as f:
        references = json.load(f)

    score = AccEvaluator().score(predictions=predictions, references=references)
    print(score)
    return score


def _embed_queries_and_save(
    retriever, retriever_model, query_dataset, information, save_path
):
    """Encode every query of ``query_dataset`` and ``torch.save`` the result.

    The retriever's own dataset reader and dataloader factory are reused to
    build the query batches, which are passed through ``retriever.forward``.
    The output is saved as ``{"res_list": res_list}``.

    Args:
        retriever: Retriever exposing ``dataset_reader``, ``create_dataloader``
            and ``forward`` (``run`` passes the first client's retriever).
        retriever_model: Model handed to ``retriever.forward``.
        query_dataset: Dataset with an ``"idx"`` column of original sample ids.
        information: Progress-bar description.
        save_path: Destination file for ``torch.save``.

    Returns:
        The ``res_list`` returned by ``retriever.forward``.
    """
    orig_idxs = query_dataset["idx"]
    datalist = retriever.dataset_reader.generate_input_field_corpus(query_dataset)
    dataloader = retriever.create_dataloader(datalist)
    res_list = retriever.forward(
        retriever_model,
        dataloader,
        orig_idxs=orig_idxs,
        process_bar=True,
        information=information,
    )
    torch.save({"res_list": res_list}, save_path)
    return res_list


def run(args):
    """Run the budget-search stage of the federated ICL pipeline.

    Steps:
      1. Load the dataset and optionally carve a server-side proxy subset out of
         ``args.proxy_split`` (the remainder replaces that split). Then partition
         the training split across ``args.num_clients`` clients (iid or
         non-iid), tagging each local sample with a ``cid`` column.
      2. Build one retriever per client over its local training data.
      3. Call ``PPLFedOptBudgetInferencer.proxy_data_opt_budget`` on the proxy
         subset (if any) and on the test queries. Save each set's query
         embeddings to ``args.log_dir``.

    The chosen test-query indices are written to
    ``<log_dir>/query_subset_orig_idxs.json``.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` block).

    Raises:
        ValueError: If ``args.partition`` is neither ``"iid"`` nor ``"noniid"``.
    """
    # ===== 1. Data Setting Preparation =====
    dataset = load_dataset(data2usename[args.dataset])

    # Fall back to the validation split when the dataset has no test split.
    test_split = "test" if "test" in dataset else "validation"

    # Slice a subset for the proxy dataset on the server; the rest of that split
    # stays available for partitioning/querying.
    if args.proxy_split is not None:
        proxy_data, remain_data = datasplit_subset(
            dataset[args.proxy_split],
            subset_num=args.proxy_size,
            split=args.proxy_split,
            verbose=True,
            return_remain=True,
        )
        dataset[args.proxy_split] = remain_data

    # Partition the training data into clients.
    if args.partition == "iid":
        fed_dataset = cls_iid_partition(
            dataset=dataset,
            split="train",
            data_name=args.dataset,
            num_clients=args.num_clients,
            test_split=test_split,
            subset_num=args.subset_num,
        )
    elif args.partition == "noniid":
        fed_dataset = cls_noniid_partition(
            dataset=dataset,
            split="train",
            data_name=args.dataset,
            major_classes_num=args.major_classes_num,
            num_clients=args.num_clients,
            test_split=test_split,
            subset_num=args.subset_num,
        )
    else:
        raise ValueError(
            f"args.partition can only be 'iid' or 'noniid', rather than '{args.partition}'."
        )

    # Tag every local training sample with its client id. The column is a
    # fresh list of immutable ints, so no deep copy is needed.
    for cid in range(args.num_clients):
        split_name = f"train-client{cid}"
        local_data = fed_dataset[split_name]
        fed_dataset[split_name] = local_data.add_column(
            "cid", [cid] * len(local_data)
        )
    print("Add client ID column for each local training dataset.")

    # Save the original indices of the selected test-query subset.
    subset_orig_idxs = fed_dataset[test_split]["idx"]
    subset_orig_idxs_file = os.path.join(args.log_dir, "query_subset_orig_idxs.json")
    with open(subset_orig_idxs_file, "w") as f:
        json.dump(subset_orig_idxs, f)

    # One DatasetReader per client over its local training data.
    fed_reader = [
        DatasetReader(
            fed_dataset[f"train-client{cid}"],
            input_columns=["sentence"],
            output_column="label",
        )
        for cid in range(args.num_clients)
    ]

    # Prompt template for classification.
    tp_dict = gen_template_classification(args)
    template = PromptTemplate(tp_dict, {"sentence": "</text>"}, ice_token="</E>")

    # One retriever per client; each client's local train data is its corpus.
    RETRIEVER = get_fedretriever(args)
    fed_retrievers = [
        RETRIEVER(reader, ice_num=args.local_ice_num) for reader in fed_reader
    ]

    # ===== 2. Search the optimal budget for each query set =====
    setup_seed(args.seed)
    inferencer = PPLFedOptBudgetInferencer(
        model_name=args.model, output_json_filepath=args.log_dir, args=args
    )

    # ----- optimal budget for the proxy dataset -----
    if args.proxy_split is not None:
        print("=" * 20 + "Get the optimal budget for proxy dataset" + "=" * 20)
        # NOTE(review): the proxy set passes local_ice_num=args.server_ice_num
        # while the test set below passes args.local_ice_num - confirm this
        # asymmetry is intended.
        # The unpacked results are only used by the disabled evaluation below.
        proxy_ice_source, proxy_per_client_budgets, proxy_preds = (
            inferencer.proxy_data_opt_budget(
                fed_retrievers,
                query_dataset=proxy_data,
                ice_template=template,
                output_json_filepath=args.log_dir,
                inference=False,
                args=args,
                local_ice_num=args.server_ice_num,
                prefix="proxy",
            )
        )

        _embed_queries_and_save(
            fed_retrievers[0],
            inferencer.retriever_model,
            proxy_data,
            "Embedding proxy data queries...",
            os.path.join(args.log_dir, "proxy_forward_result.pt"),
        )
        print("Proxy forward result saved.")
        print("-" * 60)

    # ----- optimal budget for the test dataset -----
    print("=" * 20 + "Get the optimal budget for test dataset" + "=" * 20)
    test_ice_source, test_per_client_budgets, test_preds = (
        inferencer.proxy_data_opt_budget(
            fed_retrievers,
            query_dataset=fed_dataset[test_split],
            ice_template=template,
            output_json_filepath=args.log_dir,
            inference=False,
            args=args,
            local_ice_num=args.local_ice_num,
            prefix="test",
        )
    )

    _embed_queries_and_save(
        fed_retrievers[0],
        inferencer.retriever_model,
        fed_dataset[test_split],
        "Embedding test data queries...",
        os.path.join(args.log_dir, "test_forward_result.pt"),
    )
    print("Test forward result saved.")
    print("-" * 60)

    # The steps below are currently disabled and kept for reference.
    # ----- save & evaluation on proxy dataset -----
    # proxy_prediction_file = os.path.join(args.log_dir, f"proxy_prediction_{'debug' if args.debug else 'run'}.json")
    # with open(proxy_prediction_file, "w") as f:
    #     json.dump(proxy_preds, f)

    # proxy_label_file = os.path.join(args.log_dir, f"proxy_label_{'debug' if args.debug else 'run'}.json")
    # with open(proxy_label_file, "w") as f:
    #     json.dump(proxy_data["label"], f)

    # # ===== 3. Execute Federated Inference Pipeline =====
    # setup_seed(args.seed)

    # predictions=None

    # # ===== 4. Save Prediction =====
    # prediction_file = os.path.join(args.log_dir, f"prediction_{'debug' if args.debug else 'run'}.json")
    # with open(prediction_file, "w") as f:
    #     json.dump(predictions, f)

    # label_file = os.path.join(args.log_dir, f"label_{'debug' if args.debug else 'run'}.json")
    # with open(label_file, "w") as f:
    #     json.dump(fed_dataset[test_split]["label"], f)


def resolve_ice_nums(args):
    """Fill in and reconcile the in-context-example (ICE) counts on ``args``.

    Mutates ``args`` in place:

    * ``local_ice_num`` defaults to ``overall_local_ice_num // num_clients``
      when only the overall count was given.
    * ``server_ice_num == -1`` means "use all local examples", i.e.
      ``local_ice_num * num_clients``.
    * When the server keeps fewer examples than all clients return together,
      ``concat`` is forced to ``"reorder"``.

    Args:
        args: Namespace with ``local_ice_num``, ``overall_local_ice_num``,
            ``num_clients``, ``server_ice_num`` and ``concat`` attributes.

    Returns:
        The same ``args`` object.

    Raises:
        ValueError: If neither ``local_ice_num`` nor ``overall_local_ice_num``
            is set (previously this surfaced as an opaque ``TypeError``).
    """
    if args.local_ice_num is None:
        if args.overall_local_ice_num is None:
            raise ValueError(
                "Either --local_ice_num or --overall_local_ice_num must be given."
            )
        args.local_ice_num = args.overall_local_ice_num // args.num_clients

    total_local_ice = args.local_ice_num * args.num_clients
    if args.server_ice_num == -1:
        args.server_ice_num = total_local_ice
    if total_local_ice > args.server_ice_num:
        args.concat = "reorder"
    return args


if __name__ == "__main__":
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="sst2", help="Dataset name")
    parser.add_argument(
        "--subset_num",
        type=int,
        default=None,
        help="Number of subset test set for query",
    )
    parser.add_argument("--proxy_split", default=None, type=str)
    parser.add_argument("--proxy_size", default=500, type=int)
    parser.add_argument(
        "--partition", default="iid", choices=["iid", "noniid"], type=str
    )
    parser.add_argument("--num_clients", type=int, default=3)
    parser.add_argument("--major_classes_num", default=-1, type=int)
    parser.add_argument(
        "--model",
        type=str,
        default="EleutherAI/gpt-neo-2.7B",
        help="Pretrained LLM model name",
    )
    parser.add_argument("--local_ice_num", default=None, type=int)
    parser.add_argument(
        "--overall_local_ice_num",
        default=None,
        type=int,
        help="Optional choice, automatically assign local_ice_num based on overall_local_ice_num and num_clients",
    )
    parser.add_argument(
        "--server_ice_num",
        default=-1,
        type=int,
        help="Server side ICE Number, server_ice_num <= num_clients * local_ice_num",
    )
    parser.add_argument(
        "--concat", default="simple", type=str, choices=["simple", "merge", "reorder"]
    )
    parser.add_argument(
        "--retriever", type=str, default="bm25", help="Server side Retriever Type"
    )  # use 'bm25'
    parser.add_argument(
        "--log_dir",
        type=str,
        default="fed_icl_log",
        help="Logging directory",
    )
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--run", action="store_true")
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--proj_name", default="FL-ICL-debug")
    parser.add_argument("--group_name", default=None, type=str)

    args = parser.parse_args()

    args.num_classes = data2numcls[args.dataset]
    try:
        resolve_ice_nums(args)
    except ValueError as err:
        # Report a missing ICE count as a CLI usage error, not a traceback.
        parser.error(str(err))

    # ----- prepare output locations -----
    model = args.model.replace("/", "_")
    part_name = f"{args.partition}_clients={args.num_clients}"
    if args.partition == "noniid":
        part_name += f"_majorclass={args.major_classes_num}"

    data_folder = f"{args.dataset}"
    if args.subset_num is not None:
        data_folder += f"_query-num={args.subset_num}"
    if args.proxy_split is not None:
        data_folder += f"_proxy={args.proxy_split}-{args.proxy_size}"

    # Shared tag for both the log directory and the wandb run name.
    setting = (
        f"model={model}_retriever=fed{args.retriever}"
        f"_local-ice={args.local_ice_num}_server-ice={args.server_ice_num}"
        f"_concat={args.concat}"
    )

    host_name = socket.gethostname()
    if "server_name" in host_name.lower():
        cache_root = "cache/root/directory"
    else:
        raise ValueError("Check the server hostname for log_dir initialization.")

    args.log_dir = os.path.join(
        cache_root,
        args.log_dir,
        f"{data_folder}/{part_name}/{setting}/seed={args.seed}",
    )
    os.makedirs(args.log_dir, exist_ok=True)

    # ----- perform the FL-ICL pipeline -----
    if args.run:
        run_name = f"fed_{data_folder}_{part_name}_{setting}_seed={args.seed}"
        wb_run = wandb.init(
            config=args, project=args.proj_name, name=run_name, group=args.group_name
        )

        # Skip non-iid configurations where num_clients * major_classes_num is
        # smaller than the number of classes; the wandb run is closed first.
        if (
            args.partition == "noniid"
            and args.major_classes_num * args.num_clients < args.num_classes
        ):
            wb_run.finish()
            sys.exit()

        run(args)

        # # read result from saved files and calculate performance score
        # prediction_file = os.path.join(
        #     args.log_dir, f"prediction_{'debug' if args.debug else 'run'}.json"
        # )
        # label_file = os.path.join(
        #     args.log_dir, f"label_{'debug' if args.debug else 'run'}.json"
        # )
        # scores = print_score(args, prediction_file, label_file)

        # proxy_prediction_file = os.path.join(
        #     args.log_dir, f"proxy_prediction_{'debug' if args.debug else 'run'}.json"
        # )
        # proxy_label_file = os.path.join(
        #     args.log_dir, f"proxy_label_{'debug' if args.debug else 'run'}.json"
        # )
        # proxy_scores = print_score(args, proxy_prediction_file, proxy_label_file)

        # results = {
        #     "test_accuracy": scores["accuracy"],
        #     "proxy_accuracy": proxy_scores["accuracy"],
        # }
        # wb_run.log(results)

        wb_run.finish()
