from datasets import concatenate_datasets, load_dataset, DatasetDict, Dataset
from openicl import DatasetReader, PromptTemplate, PPLInferencer, AccEvaluator
from util.noise import gen_noise_classification
from util.template import gen_template_classification
from util.retriever import get_retriever
from util.verification import data_verification
from util.partition import (
    cls_iid_partition,
    cls_noniid_partition,
    data2usename,
    data2numcls,
)
from util.retriever import get_retriever, get_server_retriever
from openicl.icl_inferencer.icl_base_inferencer import PPLInferencerOutputHandler
from util.misc import save_json
from queue import Queue

from util.concatenate import collect_samples, merge_concatenate, simple_concatenate
from util.icl_bm25_fedretriever import BM25FedRetriever

from typing import List, Union, Optional, Tuple, Dict
from openicl import DatasetReader, PromptTemplate
from openicl.icl_retriever import BaseRetriever, BM25Retriever, TopkRetriever
from openicl.utils.check_type import _check_str
from accelerate import Accelerator
from rank_bm25 import BM25Okapi
import numpy as np
from tqdm import trange
from nltk.tokenize import word_tokenize
import argparse
import json
import os
import random

from matplotlib import pyplot as plt


import tqdm
from openicl.icl_dataset_reader import DatasetEncoder
from openicl.utils.collators import DataCollatorWithPaddingAndCuda
from openicl.utils.logging import get_logger
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
import faiss
import copy
from collections import Counter

import sys
from sys import getsizeof

from util.partition import *
from util.icl_ppl_fed_opt_budget_inferencer import PPLFedOptBudgetInferencer
from util.retriever import get_retriever, get_fedretriever
from models import get_model
from util.misc import AverageMeter


def read_json(file_name):
    """Load and return the JSON document stored at ``file_name``.

    The file is decoded as UTF-8 (the encoding mandated for JSON) instead of
    the platform's locale default, so non-ASCII content reads back the same
    on every OS.

    Args:
        file_name: Path of the JSON file to read.

    Returns:
        The deserialized Python object (dict, list, str, number, ...).
    """
    with open(file_name, encoding="utf-8") as f:
        return json.load(f)


def print_score(args):
    """Print the accuracy of saved predictions against saved labels.

    Reads ``prediction_<mode>.json`` and ``label_<mode>.json`` from
    ``args.log_dir``, where ``<mode>`` is ``"debug"`` when ``args.debug`` is
    truthy and ``"run"`` otherwise, then prints the result of
    ``AccEvaluator.score``.

    Args:
        args: Namespace providing ``log_dir`` and ``debug``.
    """
    mode = "debug" if args.debug else "run"
    prediction = read_json(os.path.join(args.log_dir, f"prediction_{mode}.json"))
    reference = read_json(os.path.join(args.log_dir, f"label_{mode}.json"))

    evaluator = AccEvaluator()
    score = evaluator.score(predictions=prediction, references=reference)
    print(score)


def make_args(**kwargs):
    """Build an ``argparse.Namespace`` whose attributes are the given kwargs.

    Every keyword is registered as an optional ``--<name>`` argument with the
    supplied value as its default. Parsing an explicit empty argv list means
    the real command line is ignored, so each attribute keeps its default.
    """
    parser = argparse.ArgumentParser()
    for name, default in kwargs.items():
        parser.add_argument(f"--{name}", default=default)
    return parser.parse_args(args=[])


def test_inference():
    """Replay federated ICL inference on the proxy set with pre-computed budgets.

    Pipeline (sst5, 4 IID clients, top-k retrieval, artifacts under ``log_dir``):
      1. Split a proxy subset off ``args.proxy_split`` and partition the
         training split across clients, tagging each sample with its client id.
      2. Embed the proxy queries and read each query's per-client ICE budget
         from ``<log_dir>/proxy_opt_client_budget.json``.
      3. Let each client retrieve its budgeted local ICE, merge them into one
         server-side pool per query and, when ``concat == "reorder"``, let a
         server-side retriever select/order the pooled ICE.
      4. Compute the perplexity of every (query, label) prompt, predict the
         lowest-PPL label and print the accuracy against the proxy labels.

    Local and server ICE indices are saved as JSON under ``log_dir``.
    """
    data_name = "sst5"
    print(f"==== Data: {data_name} ===")
    args = make_args(
        **{
            "dataset": data_name,
            "num_class": data2numcls[data_name],
            "noise_p": 0.0,
            "debug": False,
            "run": True,
            "num_clients": 4,
            "model": "EleutherAI/gpt-neo-2.7B",
            "retriever": "topk",
            "partition": "iid",
            "proxy_split": "test",
            "proxy_size": 500,
            "log_dir": "debug",
            "local_ice_num": 32,
            "server_ice_num": 32,
            "subset_num": None,
            "concat": "reorder",
        }
    )

    os.makedirs(args.log_dir, exist_ok=True)
    # Every input/output artifact of this run lives under log_dir.
    output_json_filepath = args.log_dir

    # ---- 1. Data: proxy subset for the server + IID client partition ----
    dataset = load_dataset(data2usename[args.dataset])
    test_split = "test" if "test" in dataset else "validation"

    # slice a subset for proxy dataset on server; the remainder stays in the split
    proxy_data, remain_data = datasplit_subset(
        dataset[args.proxy_split],
        subset_num=args.proxy_size,
        split=args.proxy_split,
        verbose=True,
        return_remain=True,
    )
    dataset[args.proxy_split] = remain_data

    # partition data into clients
    fed_dataset = cls_iid_partition(
        dataset=dataset,
        split="train",
        data_name=args.dataset,
        num_clients=args.num_clients,
        test_split=test_split,
        subset_num=args.subset_num,
    )

    # tag every local training sample with the id of the client that owns it
    for cid in range(args.num_clients):
        split_name = f"train-client{cid}"
        local_sample_num = len(fed_dataset[split_name])
        fed_dataset[split_name] = fed_dataset[split_name].add_column(
            "cid", [cid] * local_sample_num
        )
    print("Add client ID column for each local training dataset.")

    # save selected test query subset
    subset_orig_idxs = fed_dataset[test_split]["idx"]
    subset_orig_idxs_file = os.path.join(args.log_dir, "query_subset_orig_idxs.json")
    with open(subset_orig_idxs_file, "w", encoding="utf-8") as f:
        json.dump(subset_orig_idxs, f)

    # ---- 2. Per-client readers/retrievers, prompt template, inferencer ----
    fed_reader = [
        DatasetReader(
            fed_dataset[f"train-client{cid}"],
            input_columns=["sentence"],
            output_column="label",
        )
        for cid in range(args.num_clients)
    ]

    tp_dict = gen_template_classification(args)
    template = PromptTemplate(tp_dict, {"sentence": "</text>"}, ice_token="</E>")

    # each client's local training data builds that client's retrieval corpus
    RETRIEVER = get_fedretriever(args)
    retrievers = [
        RETRIEVER(fed_reader[cid], ice_num=args.local_ice_num)
        for cid in range(args.num_clients)
    ]

    inferencer = PPLFedOptBudgetInferencer(
        model_name=args.model, output_json_filepath=args.log_dir, args=args
    )

    # ---- 3. Embed the proxy queries (they play the role of test queries) ----
    query_dataset = proxy_data
    num_clients = len(retrievers)
    query_num = len(query_dataset)
    output_handler = PPLInferencerOutputHandler(inferencer.accelerator)
    output_json_filename = inferencer.output_json_filename

    query_orig_idxs = query_dataset["idx"]
    query_datalist = retrievers[0].dataset_reader.generate_input_field_corpus(
        query_dataset
    )
    query_dataloader = retrievers[0].create_dataloader(query_datalist)
    query_res_list = retrievers[0].forward(
        inferencer.retriever_model,
        query_dataloader,
        orig_idxs=query_orig_idxs,
        process_bar=True,
        information="Embedding test data queries...",
    )

    # ---- 4. Pre-computed per-client ICE budgets for every query ----
    # JSON object keys are strings; convert them back to integer original indices.
    budget_file = os.path.join(args.log_dir, "proxy_opt_client_budget.json")
    proxy_client_budget = {
        int(key): value for key, value in read_json(budget_file).items()
    }
    # The budget file must list exactly the proxy queries, in the same order.
    assert proxy_data["idx"] == list(proxy_client_budget.keys())

    # Transpose budgets_per_query[q][cid] into local_budgets[cid][q].
    budgets_per_query = [proxy_client_budget[res["idx"]] for res in query_res_list]
    local_budgets = [list(per_client) for per_client in zip(*budgets_per_query)]

    # ---- 5. Local retrieval under the budgets ----
    ice_idx_dict = {}
    for cid in range(num_clients):
        print(f"Client {cid} retrieves local ICE samples...")
        ice_idx_dict[cid] = retrievers[cid].retrieve_with_budget(
            query_res_list,
            ice_budgets=local_budgets[cid],
            model=inferencer.retriever_model,
        )

    samples_for_all_query = collect_samples(ice_idx_dict, retrievers)

    # per query: original training-set indices of the ICE each client returned
    local_sample_orig_idx = {}
    for qid in range(query_num):
        q_orig_idx = query_res_list[qid]["idx"]
        local_sample_orig_idx[q_orig_idx] = {
            cid: samples_for_all_query[qid][cid]["idx"] for cid in range(num_clients)
        }

    local_ice_index_file = f"{output_json_filepath}/local_ice_indices.json"
    save_json(local_sample_orig_idx, local_ice_index_file)

    # ---- 6. Server side: merge all clients' ICE into one pool per query ----
    server_ice_datasets = []
    for qid in trange(
        query_num,
        disable=False,
        desc="Construct server side ICE dataset for each query: ",
    ):
        server_ice_datasets.append(
            merge_concatenate(
                samples_for_all_query[qid],
                chunk_sample_num=inferencer.args.server_ice_num,
            )
        )
    # NOTE: an earlier (removed) debug check compared these pools' "idx" sets
    # against debug/proxy_server_ice_indices.json.

    # ---- 7. Server side: select/order the pooled ICE and render them ----
    # assumes class labels are 0..num_class-1 (sst5: 0..4)
    labels = list(range(args.num_class))
    ice_template = template
    prompt_template = None

    ice = []
    ice_dataset_retrievers = []
    server_sample_orig_idx = {}
    for qid in trange(
        query_num,
        disable=False,
        desc="Server side generating ICE: ",
    ):
        data_dict = DatasetDict(
            {
                "train": server_ice_datasets[qid],
                "test": Dataset.from_list([query_dataset[qid]]),
            }
        )
        data = DatasetReader(
            data_dict,
            input_columns=retrievers[0].dataset_reader.input_columns,
            output_column=retrievers[0].dataset_reader.output_column,
        )
        per_query_retriever = get_server_retriever(args)(
            data, ice_num=args.server_ice_num
        )

        # "reorder": let the server retriever pick/order from the pooled ICE;
        # otherwise keep the merged pool in its original order.
        if args.concat == "reorder":
            ice_idxs = per_query_retriever.retrieve(
                model=inferencer.retriever_model, use_trange=False
            )[0]
        else:
            ice_idxs = list(range(len(server_ice_datasets[qid])))

        q_orig_idx = per_query_retriever.test_ds[0]["idx"]
        server_sample_orig_idx[q_orig_idx] = per_query_retriever.index_ds.select(
            ice_idxs
        )["idx"]
        # render the selected samples of this query's pool as the ICE text
        ice.append(
            per_query_retriever.generate_ice(ice_idxs, ice_template=ice_template)
        )
        ice_dataset_retrievers.append(per_query_retriever)

    output_handler.save_ice(ice)
    server_ice_index_file = f"{output_json_filepath}/server_ice_indices.json"
    save_json(server_sample_orig_idx, server_ice_index_file)

    # ---- 8. Perplexity of every (query, label) prompt ----
    ppl = []
    for label in labels:
        index = 0
        prompt_list = [
            ice_dataset_retrievers[qid].generate_label_prompt(
                0,
                ice[qid],
                label,
                ice_template=ice_template,
                prompt_template=prompt_template,
                remain_sep=False,
            )
            for qid in range(query_num)
        ]

        print(f"Calculating PPL for prompts labeled '{label}'")
        sub_ppl_list = []
        for start in trange(0, len(prompt_list), inferencer.batch_size, disable=False):
            sub_prompt_list = prompt_list[start : start + inferencer.batch_size]

            with torch.no_grad():
                # one score per prompt in the batch
                sub_res = inferencer._get_ppl(sub_prompt_list).tolist()

            for batch_offset, (res, prompt) in enumerate(zip(sub_res, sub_prompt_list)):
                sub_ppl_list.append(res)
                output_handler.save_prompt_and_ppl(
                    label,
                    # drop the leading ICE text (assumes the prompt starts with it)
                    prompt[len(ice[start + batch_offset]) :],
                    prompt,
                    res,
                    index,
                )
                index += 1
        ppl.append(sub_ppl_list)

    # ---- 9. Predict the label with the lowest PPL for each query ----
    sub_predictions = [
        labels[single_ppl.index(min(single_ppl))] for single_ppl in zip(*ppl)
    ]
    output_handler.save_predictions(sub_predictions)

    # ---- 10. Write results and report accuracy ----
    output_handler.subprocess_write_to_json(output_json_filepath, output_json_filename)

    output_handler.merge_to_main_process(output_json_filepath, output_json_filename)
    output_handler.write_to_json(output_json_filepath, output_json_filename)

    predictions = [
        sample["prediction"] for sample in output_handler.results_dict.values()
    ]
    references = proxy_data["label"]
    evaluator = AccEvaluator()
    score = evaluator.score(predictions=predictions, references=references)
    print(score)


def test_whole_inference():
    """Run end-to-end budgeted federated inference on the proxy set.

    Builds the sst5 / 4-client IID setup (outputs under ``debug2``), loads a
    saved "DMLP" budget model from ``debug/budget_model.pt`` into the
    inferencer, runs ``PPLFedOptBudgetInferencer.inference`` on the proxy data
    and prints the accuracy against the proxy labels.
    """
    data_name = "sst5"
    print(f"==== Data: {data_name} ===")
    args = make_args(
        **{
            "dataset": data_name,
            "num_class": data2numcls[data_name],
            "noise_p": 0.0,
            "debug": False,
            "run": True,
            "num_clients": 4,
            "model": "EleutherAI/gpt-neo-2.7B",
            "retriever": "topk",
            "partition": "iid",
            "proxy_split": "test",
            "proxy_size": 500,
            "log_dir": "debug2",
            "local_ice_num": 32,
            "server_ice_num": 32,
            "subset_num": None,
            "concat": "reorder",
        }
    )

    os.makedirs(args.log_dir, exist_ok=True)

    # load dataset
    dataset = load_dataset(data2usename[args.dataset])
    test_split = "test" if "test" in dataset else "validation"

    # slice a subset for proxy dataset on server; the remainder stays in the split
    proxy_data, remain_data = datasplit_subset(
        dataset[args.proxy_split],
        subset_num=args.proxy_size,
        split=args.proxy_split,
        verbose=True,
        return_remain=True,
    )
    dataset[args.proxy_split] = remain_data

    # partition data into clients
    fed_dataset = cls_iid_partition(
        dataset=dataset,
        split="train",
        data_name=args.dataset,
        num_clients=args.num_clients,
        test_split=test_split,
        subset_num=args.subset_num,
    )

    # tag every local training sample with the id of the client that owns it
    for cid in range(args.num_clients):
        split_name = f"train-client{cid}"
        local_sample_num = len(fed_dataset[split_name])
        fed_dataset[split_name] = fed_dataset[split_name].add_column(
            "cid", [cid] * local_sample_num
        )
    print("Add client ID column for each local training dataset.")

    # save selected test query subset
    subset_orig_idxs = fed_dataset[test_split]["idx"]
    subset_orig_idxs_file = os.path.join(args.log_dir, "query_subset_orig_idxs.json")
    with open(subset_orig_idxs_file, "w", encoding="utf-8") as f:
        json.dump(subset_orig_idxs, f)

    # make datareader for all clients' local training data
    fed_reader = [
        DatasetReader(
            fed_dataset[f"train-client{cid}"],
            input_columns=["sentence"],
            output_column="label",
        )
        for cid in range(args.num_clients)
    ]

    # get template for classification
    tp_dict = gen_template_classification(args)
    template = PromptTemplate(tp_dict, {"sentence": "</text>"}, ice_token="</E>")

    # each client's local training data builds that client's retrieval corpus
    RETRIEVER = get_fedretriever(args)
    fed_retrievers = [
        RETRIEVER(fed_reader[cid], ice_num=args.local_ice_num)
        for cid in range(args.num_clients)
    ]

    inferencer = PPLFedOptBudgetInferencer(
        model_name=args.model, output_json_filepath=args.log_dir, args=args
    )

    # NOTE(review): the checkpoint is read from "debug", not args.log_dir
    # ("debug2") — presumably produced by an earlier run; confirm.
    budget_model_path = "debug/budget_model.pt"
    # size is (dim,); assumed to equal the retriever embedding size — TODO confirm
    embedding_size = (768,)
    print("Try to create budget model")
    model = get_model(
        model_name="DMLP",
        output_size=args.num_clients,
        data_shape=embedding_size,
        width=512,
    )
    model = model.to(inferencer.device)
    print("Try to load saved model parameter...")
    # map_location lets a checkpoint saved on another device (e.g. a GPU that is
    # not available here) be loaded onto the inferencer's device.
    model.load_state_dict(
        torch.load(budget_model_path, map_location=inferencer.device)
    )
    print("Load model parameter done")
    inferencer.budget_model = model

    predictions = inferencer.inference(
        fed_retrievers,
        proxy_data,
        ice_template=template,
        concat=args.concat,
        output_json_filepath=args.log_dir,
        args=args,
    )
    references = proxy_data["label"]
    evaluator = AccEvaluator()
    score = evaluator.score(predictions=predictions, references=references)
    print(score)


def test_budget_model_prediciton():
    """Inspect a saved budget model's per-client budget predictions.

    Builds the sst5 / 4-client IID setup (outputs under ``debug``), embeds the
    proxy queries with the first client's retriever, loads a "DMLP" budget
    model from ``debug/budget_model.pt`` (input size taken from the query
    embeddings), then prints the first five budget predictions and a
    ``Counter`` of the per-query budget totals.
    """
    data_name = "sst5"
    print(f"==== Data: {data_name} ===")
    args = make_args(
        **{
            "dataset": data_name,
            "num_class": data2numcls[data_name],
            "noise_p": 0.0,
            "debug": False,
            "run": True,
            "num_clients": 4,
            "model": "EleutherAI/gpt-neo-2.7B",
            "retriever": "topk",
            "partition": "iid",
            "proxy_split": "test",
            "proxy_size": 500,
            "log_dir": "debug",
            "local_ice_num": 32,
            "server_ice_num": 32,
            "subset_num": None,
            "concat": "reorder",
        }
    )

    os.makedirs(args.log_dir, exist_ok=True)

    # load dataset
    dataset = load_dataset(data2usename[args.dataset])
    test_split = "test" if "test" in dataset else "validation"

    # slice a subset for proxy dataset on server; the remainder stays in the split
    proxy_data, remain_data = datasplit_subset(
        dataset[args.proxy_split],
        subset_num=args.proxy_size,
        split=args.proxy_split,
        verbose=True,
        return_remain=True,
    )
    dataset[args.proxy_split] = remain_data

    # partition data into clients
    fed_dataset = cls_iid_partition(
        dataset=dataset,
        split="train",
        data_name=args.dataset,
        num_clients=args.num_clients,
        test_split=test_split,
        subset_num=args.subset_num,
    )

    # tag every local training sample with the id of the client that owns it
    for cid in range(args.num_clients):
        split_name = f"train-client{cid}"
        local_sample_num = len(fed_dataset[split_name])
        fed_dataset[split_name] = fed_dataset[split_name].add_column(
            "cid", [cid] * local_sample_num
        )
    print("Add client ID column for each local training dataset.")

    # save selected test query subset
    subset_orig_idxs = fed_dataset[test_split]["idx"]
    subset_orig_idxs_file = os.path.join(args.log_dir, "query_subset_orig_idxs.json")
    with open(subset_orig_idxs_file, "w", encoding="utf-8") as f:
        json.dump(subset_orig_idxs, f)

    # make datareader for all clients' local training data
    fed_reader = [
        DatasetReader(
            fed_dataset[f"train-client{cid}"],
            input_columns=["sentence"],
            output_column="label",
        )
        for cid in range(args.num_clients)
    ]

    # get template for classification (built for parity with the other runs)
    tp_dict = gen_template_classification(args)
    template = PromptTemplate(tp_dict, {"sentence": "</text>"}, ice_token="</E>")

    # each client's local training data builds that client's retrieval corpus
    RETRIEVER = get_fedretriever(args)
    retrievers = [
        RETRIEVER(fed_reader[cid], ice_num=args.local_ice_num)
        for cid in range(args.num_clients)
    ]

    inferencer = PPLFedOptBudgetInferencer(
        model_name=args.model, output_json_filepath=args.log_dir, args=args
    )

    # embed the proxy queries with the first client's retriever
    query_dataset = proxy_data
    query_orig_idxs = query_dataset["idx"]
    query_datalist = retrievers[0].dataset_reader.generate_input_field_corpus(
        query_dataset
    )
    query_dataloader = retrievers[0].create_dataloader(query_datalist)
    query_res_list = retrievers[0].forward(
        inferencer.retriever_model,
        query_dataloader,
        orig_idxs=query_orig_idxs,
        process_bar=True,
        information="Embedding test data queries...",
    )

    # the budget model's input size follows the query embedding shape
    embedding_size = query_res_list[0]["embed"].shape  # size is (dim,)
    print("Try to create budget model")
    model = get_model(
        model_name="DMLP",
        output_size=args.num_clients,
        data_shape=embedding_size,
        width=512,
    )
    model = model.to(inferencer.device)
    print("Try to load saved model parameter...")
    # map_location lets a checkpoint saved on another device (e.g. a GPU that is
    # not available here) be loaded onto the inferencer's device.
    model.load_state_dict(
        torch.load("debug/budget_model.pt", map_location=inferencer.device)
    )
    print("Load model parameter done")
    inferencer.budget_model = model

    budgets_prediciton = inferencer._get_budgets_model_prediction(query_res_list)
    # distribution of the total (summed over clients) budget per query
    cnt_budget = Counter(map(sum, budgets_prediciton))
    print(f"{budgets_prediciton[:5]}")
    print(cnt_budget)


def test_proxy_data_opt_budget():
    """Optimise per-client budgets on the proxy set, train a budget model, compare concat modes.

    Builds the sst5 / 4-client IID setup (outputs under ``debug2``), then:
      1. calls ``proxy_data_opt_budget`` (``inference=False``) to obtain the
         per-client budgets for the proxy queries;
      2. trains an "SMLP" budget model on those budgets via
         ``train_local_budget_model`` and saves the train/eval loss histories
         to ``<log_dir>/train_hist.pt``;
      3. runs ``inference`` on the proxy data with the "reorder", "merge" and
         "simple" concat modes and prints each mode's accuracy.
    """
    data_name = "sst5"
    print(f"==== Data: {data_name} ===")
    args = make_args(
        **{
            "dataset": data_name,
            "num_class": data2numcls[data_name],
            "noise_p": 0.0,
            "debug": False,
            "run": True,
            "num_clients": 4,
            "model": "EleutherAI/gpt-neo-2.7B",
            "retriever": "topk",
            "partition": "iid",
            "proxy_split": "test",
            "proxy_size": 500,
            "log_dir": "debug2",
            "local_ice_num": 32,
            "server_ice_num": 32,
            "subset_num": None,
            "concat": "reorder",
        }
    )

    os.makedirs(args.log_dir, exist_ok=True)

    # load dataset
    dataset = load_dataset(data2usename[args.dataset])
    test_split = "test" if "test" in dataset else "validation"

    # slice a subset for proxy dataset on server; the remainder stays in the split
    proxy_data, remain_data = datasplit_subset(
        dataset[args.proxy_split],
        subset_num=args.proxy_size,
        split=args.proxy_split,
        verbose=True,
        return_remain=True,
    )
    dataset[args.proxy_split] = remain_data

    # partition data into clients
    fed_dataset = cls_iid_partition(
        dataset=dataset,
        split="train",
        data_name=args.dataset,
        num_clients=args.num_clients,
        test_split=test_split,
        subset_num=args.subset_num,
    )

    # tag every local training sample with the id of the client that owns it
    for cid in range(args.num_clients):
        split_name = f"train-client{cid}"
        local_sample_num = len(fed_dataset[split_name])
        fed_dataset[split_name] = fed_dataset[split_name].add_column(
            "cid", [cid] * local_sample_num
        )
    print("Add client ID column for each local training dataset.")

    # save selected test query subset
    subset_orig_idxs = fed_dataset[test_split]["idx"]
    subset_orig_idxs_file = os.path.join(args.log_dir, "query_subset_orig_idxs.json")
    with open(subset_orig_idxs_file, "w", encoding="utf-8") as f:
        json.dump(subset_orig_idxs, f)

    # make datareader for all clients' local training data
    fed_reader = [
        DatasetReader(
            fed_dataset[f"train-client{cid}"],
            input_columns=["sentence"],
            output_column="label",
        )
        for cid in range(args.num_clients)
    ]

    # get template for classification
    tp_dict = gen_template_classification(args)
    template = PromptTemplate(tp_dict, {"sentence": "</text>"}, ice_token="</E>")

    # each client's local training data builds that client's retrieval corpus
    RETRIEVER = get_fedretriever(args)
    fed_retrievers = [
        RETRIEVER(fed_reader[cid], ice_num=args.local_ice_num)
        for cid in range(args.num_clients)
    ]

    inferencer = PPLFedOptBudgetInferencer(
        model_name=args.model, output_json_filepath=args.log_dir, args=args
    )

    # ---- 1. Per-client budgets on the proxy data (no PPL inference) ----
    print("=" * 20)
    print("Perform proxy data without inference")
    # only the per-client budgets are needed here; ICE sources/preds are unused
    _, proxy_per_client_budgets, _ = inferencer.proxy_data_opt_budget(
        fed_retrievers,
        query_dataset=proxy_data,
        ice_template=template,
        output_json_filepath=args.log_dir,
        inference=False,
        args=args,
        local_ice_num=args.server_ice_num,
    )

    # ---- 2. Train the budget model on those budgets ----
    train_loss_hist, eval_loss_hist, _ = inferencer.train_local_budget_model(
        fed_retrievers,
        query_dataset=proxy_data,
        opt_client_budget=proxy_per_client_budgets,
        model_name="SMLP",
        model_width=380,
        epochs=200,
        lr=0.01,
        batch_size=8,
        train_ratio=0.8,
        seed=0,
    )
    torch.save(
        {"train_loss_hist": train_loss_hist, "eval_loss_hist": eval_loss_hist},
        os.path.join(args.log_dir, "train_hist.pt"),
    )

    # ---- 3. Compare server-side concatenation strategies on the proxy set ----
    proxy_reference = proxy_data["label"]
    evaluator = AccEvaluator()
    for concat, title in (
        ("reorder", "Reorder"),
        ("merge", "Merge"),
        ("simple", "Simple concate"),
    ):
        predictions = inferencer.inference(
            fed_retrievers,
            query_dataset=proxy_data,
            ice_template=template,
            concat=concat,
            output_json_filepath=args.log_dir,
            args=args,
        )
        score = evaluator.score(predictions=predictions, references=proxy_reference)
        print(f"{title} inference on proxy:", score)


if __name__ == "__main__":
    # Run the proxy-budget experiment. To run a different experiment instead,
    # swap in one of these entry points:
    #   test_inference()
    #   test_budget_model_prediciton()
    #   test_whole_inference()
    test_proxy_data_opt_budget()
