"""
This example follows the experimental settings of the GPT-2 PubMed experiments in the ICML 2024 Spotlight paper,
"Differentially Private Synthetic Data via Foundation Model APIs 2: Text" (https://arxiv.org/abs/2403.01749).

The ``model_name_or_path`` parameter can be set to other models on HuggingFace. Note that we use the FastChat
library (https://github.com/lm-sys/FastChat) to manage the conversation template. If the conversation template of your
desired model is not available in FastChat, please register the conversation template in the FastChat library. See the
following link for an example:
https://github.com/microsoft/DPSDA/blob/main/pe/llm/huggingface/register_fastchat/gpt2.py

For detailed information about parameters and APIs, please consult the documentation of the Private Evolution library:
https://microsoft.github.io/DPSDA/.
"""

from pe.data.text import PubMed
from pe.logging import setup_logging
from pe.runner import PE
from pe.population import PEPopulation
from pe.dp import Gaussian, Exponential
from pe.api.text import LLMAugPE
from pe.llm import HuggingfaceLLM
from pe.llm import NONE_INSTRUCT_MODELS
from pe.embedding.text import SentenceTransformer
from pe.histogram import NearestNeighbors
from pe.callback import SaveCheckpoints
from pe.callback import ComputeFID, ComputePrecisionRecall, ComputeFormatMatch
from pe.callback import SaveTextToCSV
from pe.logger import CSVPrint
from pe.logger import LogPrint
from pe.constant.data import VARIATION_API_FOLD_ID_COLUMN_NAME

import pandas as pd
import os
import numpy as np

import argparse
# from datasets import Dataset
# from sklearn.model_selection import train_test_split
import torch
# import gc
from peft import LoraConfig, get_peft_model, TaskType
import accelerate

# from pe.llm import sft_fine_tune, sft_fine_tune_until_converge, get_per_sample_loss
from pe.runner import PESGD

from huggingface_hub import login


pd.options.mode.copy_on_write = True

# Authenticate against the HuggingFace Hub with a token taken from the
# environment. Credentials must never be committed to source control: the
# access token that used to be hard-coded on this line has been exposed and
# should be revoked. When HF_TOKEN is unset the explicit login is skipped, so
# any credentials already cached by `huggingface-cli login` keep working.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)


def build_arg_parser():
    """Build the command-line parser for this experiment script.

    Every option has a default, so parsing an empty argument list succeeds.
    Integer options documented as "0 = off" are used as boolean switches by
    the script (they are compared against 0).

    Returns:
        argparse.ArgumentParser: the fully configured parser.
    """
    parser = argparse.ArgumentParser(description='summary generator')

    # --- task / method / models -------------------------------------------------
    parser.add_argument("--task", type=str, default='pubmed',
                        help="['pubmed', 'biorxiv', 'congressional'] The dataset/task to run on.")
    parser.add_argument("--method", type=str, default='OptGradCombination',
                        help="['ReTrain', 'GhostGradDot', 'OptGradCombination', 'purePE', 'topQPE', 'privSGD'] The training method to run.")
    parser.add_argument('--seed', type=int, default=0, metavar='seed', help='random seed (default: 0)')
    parser.add_argument('--gpu', default=0, type=int, help='gpu device id')
    parser.add_argument('--gen_model_name', type=str, default='Qwen/Qwen3-4B',
                        help="The model for generation, might be an API.")
    parser.add_argument("--fine_tune_model_name", type=str, default='Qwen/Qwen3-4B',
                        help="The model fine-tuned on synthetic dataset.")
    parser.add_argument("--selection_model_name", type=str, default='Qwen/Qwen3-4B',
                        help="The model fine-tuned on private dataset.")
    parser.add_argument("--fine_tune_model_train_iter", type=int, default=10,
                        help="Numbers of iterations for training the target model for fine-tune. Default 10.")
    parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate. Default as 5e-5.")
    parser.add_argument("--lr_scaler", type=float, default=0.5, help="Learning rate scaler for grid search.")
    parser.add_argument("--selection_model_train_iter", type=int, default=10,
                        help="Numbers of iterations for training each selection model. Default 10.")
    parser.add_argument("--setting", type=str, default='SelfGen',
                        help="['SelfGen', 'LargeGen']. SelfGen: The model fine-tuned is the model used as generation; LargeGen: Huge LLM API are used for generation.")

    parser.add_argument("--max_completion_tokens", type=int, default=1024,
                        help="The maximum number of tokens to generate in the response. Should be related to task.")

    # --- differential privacy ---------------------------------------------------
    parser.add_argument("--dp_mechanism", type=str, default='Gaussian',
                        help="['Exponential', 'Gaussian'] The differential privacy mechanism to use.")
    parser.add_argument("--dp_epsilon", type=float, default=1.0,
                        help="[1,2,4,inf(>=100000000.0)], (epsilon, delta)-DP protection of the histogram for each party, delta=0.0 for exponential mechanism")
    parser.add_argument("--dp_delta", type=float, default=0.0,
                        help="[0.0, 1E-5, 1E-5/(args.steps+1)], (epsilon, delta)-DP protection of the histogram for each party")
    parser.add_argument("--dp_clip_norm", type=float, default=1.0, help="Clipping threshold for DP noise addition")
    parser.add_argument("--dp_syn_cluster_num", type=int, default=10,
                        help="How many clusters to use for clustering the synthetic data. Default 10.")

    parser.add_argument("--metric_inverse_epsilon", type=float, default=1e-6,
                        help="The small epsilon used to inverse GTG. Default 1e-6.")

    # --- LoRA -------------------------------------------------------------------
    parser.add_argument("--lora_r", type=int, default=8, help="LoRA rank, default 8.")
    parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha, default 32.")
    parser.add_argument("--lora_dropout", type=float, default=0.1, help="LoRA dropout rate, default 0.1")
    parser.add_argument("--target_modules", default=['q_proj', 'v_proj'], nargs='+', type=str)

    # --- synthetic data generation ----------------------------------------------
    parser.add_argument("--variation_api_fold", type=int, default=2,
                        help="How many variations to apply to the initial synthetic data. Default 2, which means N*(2+1) samples will be generated.")

    parser.add_argument("--num_synthetic_per_iter", type=int, default=1000,
                        help="Numbers of synthetic samples to generate per iteration (before api fold). Default 1000.")
    parser.add_argument("--num_iter", type=int, default=10,
                        help="Numbers of iterations for synthetic samples. Default 10.")
    parser.add_argument("--priv_select_ratio_per_iter", type=float, default=1.0,
                        help="Ratio of private samples selected per iteration (forwarded to the runner as the validation sample ratio). Default 1.0")
    parser.add_argument("--prompt_select_strategy", type=str, default='RandomK',
                        help="Method for seed sample selection. ['RandomK', 'RandomKabs', 'TopK', 'TopKabs']")

    parser.add_argument("--train_priv_size", type=str, default='small',
                        help="Number of samples used as train private data. ['small'=400, 'small2', 'middle'=4000, 'total'=all]")

    # --- switches (0 = off, non-zero = on) --------------------------------------
    parser.add_argument("--debug", type=int, default=0,
                        help="debug mode, will use small test private dataset to accelerate the process")
    parser.add_argument("--with_instruction_base", type=int, default=0,
                        help="whether to use the instruction base (0 = off, non-zero = on); forwarded to the runner")
    parser.add_argument("--loss_distribution", type=int, default=0,
                        help="whether to calculate and plot loss distribution change before and after training")
    parser.add_argument("--llm_additional_generation", type=int, default=0,
                        help="whether to enable additional generation with the fine-tuned LLM (0 = off, non-zero = on); forwarded to the runner")

    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    # Validate the setting/method combination. Explicit exceptions are raised
    # instead of `assert`, because assertions are stripped when Python runs
    # with -O and invalid configurations would then pass silently.
    if args.setting == 'SelfGen':
        # In SelfGen the fine-tuned model is also the generator, so both names must match.
        if args.fine_tune_model_name != args.gen_model_name:
            raise ValueError(f"Should have Generation-LM={args.gen_model_name} equals to Fine-tuned-LM={args.fine_tune_model_name}, but got different models instead.")
    elif ('OptGradCombination' in args.method) or args.method == 'purePE':
        # NOTE(review): because this is an `elif`, the reset only happens when
        # the setting is NOT 'SelfGen' -- confirm that this coupling is intended.
        args.dp_syn_cluster_num = 0 # no-use argument, set to zero

    skip_train = False
    # Methods using the eigenvalues of the private gradient need both companion flags in the method name.
    if 'Eigen' in args.method and not ('FixSample' in args.method and 'NoisyRealGrad' in args.method):
        raise ValueError("[ERROR] If eigen value of private gradient is used, FixSample and NoisyRealGrad are required.")
    if 'LossDistribution' in args.method:
        # 'LossDistribution' is a modifier embedded in the method name: strip it
        # (and every underscore) to recover the base method, turn on the
        # loss-distribution run and remember that training must be skipped.
        args.method = args.method.replace('LossDistribution', '')
        args.method = args.method.replace('_', '')
        args.loss_distribution = 1
        skip_train = True

    # Seed torch and numpy from the command-line seed.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Load the private training split selected by --train_priv_size; any value
    # other than the ones listed falls back to the full (very large) "train" split.
    data_root = f"./data/{args.task}"
    train_split = {
        'small': "train_small",
        'small2': "train_small2",
        'middle': "train_middle",
    }.get(args.train_priv_size, "train")
    data = PubMed(root_dir=data_root, split=train_split)
    # Debug runs evaluate on the small test split.
    priv_data_test = PubMed(root_dir=data_root, split="test_small" if args.debug != 0 else "test")

    # Optional suffixes of the method folder; both only apply to OptGradCombination methods.
    uses_opt_grad_combination = 'OptGradCombination' in args.method
    _prompt_selection = ''
    if uses_opt_grad_combination and 'FixSample' not in args.method:
        _prompt_selection = f'_{args.prompt_select_strategy}'
    _with_instruction_base = ''
    if uses_opt_grad_combination and args.with_instruction_base != 0:
        _with_instruction_base = '_withInstructionBase'

    # Output folder encoding the complete experiment configuration, one path
    # component per adjacent literal below.
    exp_folder = (
        f"./results_new{'_debug' if (args.debug!=0) else ''}"
        f"/{args.method}{_prompt_selection}{_with_instruction_base}"
        f"/{args.setting}/text/{args.task}"
        f"/{args.dp_mechanism}_{args.dp_epsilon}_{args.dp_delta}_{args.dp_syn_cluster_num}cluster_{args.variation_api_fold}fold_priv{len(data.data_frame)}"
        f"/{args.gen_model_name.split('/')[-1]}"
        f"/{args.fine_tune_model_name.split('/')[-1]}"
        f"/{args.selection_model_name.split('/')[-1]}"
        f"/[{args.num_synthetic_per_iter}]_{args.num_iter}_select{args.selection_model_train_iter}_finetune{args.fine_tune_model_train_iter}_priveratio{args.priv_select_ratio_per_iter}_lr{args.lr}_MetInvEps{args.metric_inverse_epsilon}_clipNorm{args.dp_clip_norm}"
        f"/seed{args.seed}/"
    )
    current_folder = os.path.dirname(os.path.abspath(__file__))

    setup_logging(log_file=os.path.join(exp_folder, "log.txt"))

    # Folder for the initial synthetic data. Its path depends only on task,
    # fold count, generation model, sample count and seed, not on the method.
    init_data_folder = (
        f"./results_new{'_debug' if (args.debug!=0) else ''}"
        f"/_initial_data/text/{args.task}"
        f"/{args.variation_api_fold}fold"
        f"/{args.gen_model_name.split('/')[-1]}"
        f"/[{args.num_synthetic_per_iter}]"
        f"/seed{args.seed}/synthetic_text/"
    )
    init_data_checkpoint = f"{init_data_folder}000000000/"
    if not os.path.exists(init_data_checkpoint):
        os.makedirs(init_data_checkpoint, exist_ok=True)

    # # Split priv_train_data into train and dev sets
    # train_df, dev_df = train_test_split(data.data_frame, test_size=0.1, random_state=args.seed, shuffle=True)
    # priv_train_data = Dataset.from_psandas(train_df.reset_index(drop=True))
    # priv_dev_data = Dataset.from_pandas(dev_df.reset_index(drop=True))
    # if not 'text' in priv_train_data.column_names:
    #     priv_train_data = priv_train_data.add_column("text", priv_train_data['PE.TEXT'])
    #     priv_dev_data = priv_dev_data.add_column("text", priv_dev_data['PE.TEXT'])

    # Instruction prompting is enabled unless the fine-tune model is listed in
    # NONE_INSTRUCT_MODELS. The same flag is passed to all three models below.
    args.llm_add_instruction = (args.fine_tune_model_name not in NONE_INSTRUCT_MODELS)
    # args.llm_add_instruction = True
        

    # Three HuggingFace-backed models share the token budget and temperature:
    #   glm - generation model (the only one created with device_map='auto')
    #   llm - model that gets fine-tuned
    #   slm - selection model
    glm = HuggingfaceLLM(max_completion_tokens=args.max_completion_tokens, model_name_or_path=args.gen_model_name, temperature=1.0, device_map='auto', gen_with_instruction=args.llm_add_instruction) # generation model
    llm = HuggingfaceLLM(max_completion_tokens=args.max_completion_tokens, model_name_or_path=args.fine_tune_model_name, temperature=1.0, device_map=None, gen_with_instruction=args.llm_add_instruction) # fine-tune model
    slm = HuggingfaceLLM(max_completion_tokens=args.max_completion_tokens, model_name_or_path=args.selection_model_name, temperature=1.0, device_map=None, gen_with_instruction=args.llm_add_instruction) # selection model
    accelerator = accelerate.Accelerator()

    # # glm._model, glm_offload_hook = accelerate.cpu_offload_with_hook(glm._model, execution_device="cuda")
    # # glm_offload_hook.offload()

    # --- selection model: LoRA wrap -> accelerator.prepare -> CPU-offload hook ---
    # LoRA hyper-parameters come from the --lora_* / --target_modules options.
    slm_lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=args.target_modules,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type=TaskType.CAUSAL_LM, 
    )
    # print(f"{slm._model=}")
    # NOTE(review): `_model` is a private attribute of HuggingfaceLLM that is
    # replaced in place here; this relies on the wrapper's internals.
    slm._model = get_peft_model(slm._model, slm_lora_config)
    slm._model = accelerator.prepare(slm._model)
    # The three debug prints report device/dtype after prepare, after the hook
    # is attached, and after the explicit offload() call.
    print(f"[debugging] slm._model: {slm._model.device=}, {slm._model.dtype=}")
    # Attach an offload hook with "cuda" as execution device, then offload at
    # once -- presumably so the model stays on CPU until it is actually run
    # (TODO confirm against the accelerate documentation).
    slm._model, slm_offload_hook = accelerate.cpu_offload_with_hook(slm._model, execution_device="cuda")
    print(f"[debugging] slm._model: {slm._model.device=}, {slm._model.dtype=}")
    slm_offload_hook.offload()
    print(f"[debugging] slm._model: {slm._model.device=}, {slm._model.dtype=}")
    # # for _, module in slm._model.named_modules():
    # #     if "lora" in module.__class__.__name__.lower():
    # #        module.to(dtype=torch.float16, device=next(slm._model.parameters()).device)

    # --- fine-tune model: same sequence and same LoRA settings, without debug prints ---
    llm_lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=args.target_modules,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type=TaskType.CAUSAL_LM, 
    )
    llm._model = get_peft_model(llm._model, llm_lora_config)
    llm._model = accelerator.prepare(llm._model)
    llm._model, llm_offload_hook = accelerate.cpu_offload_with_hook(llm._model, execution_device="cuda")
    llm_offload_hook.offload()
    # for _, module in llm._model.named_modules():
    #     if "lora" in module.__class__.__name__.lower():
    #         module.to(dtype=torch.float16, device=next(llm._model.parameters()).device)
    
    # for name, param in llm._model.named_parameters():
    #     print(f"{name}: shape={param.shape}, values={param.data.flatten()[:50]}")
    # print(f"{llm._model.dtype=}, {llm._model.config=}")

    # Generation API driven by the generation model; the random/variation
    # prompt templates are per-task JSON files located next to this script.
    api = LLMAugPE(
        llm=glm,
        random_api_prompt_file=os.path.join(current_folder, f"prompt/{args.task}/random_api_prompt.json"),
        variation_api_prompt_file=os.path.join(current_folder, f"prompt/{args.task}/variation_api_prompt.json"),
    )

    # Sentence embedding shared by the histogram and the evaluation callbacks.
    embedding = SentenceTransformer(model="sentence-t5-base")
    histogram = NearestNeighbors(embedding=embedding, mode="L2", lookahead_degree=0)

    # The same fold count is used for the initial and for every later variation round.
    population = PEPopulation(
        api=api,
        initial_variation_api_fold=args.variation_api_fold,
        next_variation_api_fold=args.variation_api_fold,
        keep_selected=True,
        selection_mode="random",  # selection_mode="random" for gradient version of PE
    )

    # Callbacks: checkpointing, FID, precision/recall, format match and CSV dump.
    # Both metric callbacks filter on variation-fold id -1.
    checkpoint_dir = os.path.join(exp_folder, "checkpoint")
    save_checkpoints = SaveCheckpoints(checkpoint_dir)
    compute_fid = ComputeFID(
        priv_data=data,
        embedding=embedding,
        filter_criterion={VARIATION_API_FOLD_ID_COLUMN_NAME: -1},
    )
    compute_precision_recall = ComputePrecisionRecall(
        priv_data=data,
        embedding=embedding,
        filter_criterion={VARIATION_API_FOLD_ID_COLUMN_NAME: -1},
        num_precision_neighbors=10,  # default is 4
        num_recall_neighbors=10,  # default is 5
    )
    comput_format_match = ComputeFormatMatch(format_type=args.task)
    synthetic_text_dir = os.path.join(exp_folder, "synthetic_text")
    save_text_to_csv = SaveTextToCSV(output_folder=synthetic_text_dir)

    # Loggers.
    csv_print = CSVPrint(output_folder=exp_folder)
    log_print = LogPrint()

    # DP delta: 1 / (n * ln n) for n private samples, capped at 1e-5.
    num_private_samples = len(data.data_frame)
    delta = min(1.0 / num_private_samples / np.log(num_private_samples), 1e-5)

    _iter = 0

    # # fine-tune the selection model on the private dataset
    # logging_dir = exp_folder + f"/slm/{_iter}/"
    # if not os.path.exists(logging_dir):
    #     os.makedirs(logging_dir)
    # glm._model = glm._model.to('cpu') 
    # llm._model = llm._model.to('cpu')
    # torch.cuda.empty_cache()
    # gc.collect()
    # slm_sft_eval_metric, slm._model, slm._tokenizer = sft_fine_tune(
    #     slm._model, slm._tokenizer, priv_train_data, output_dir=logging_dir,
    #     per_device_train_batch_size=8, num_train_epochs=10, learning_rate=5e-5,
    #     save_steps=100, logging_steps=5,
    # )
    # slm_sft_eval_metric, slm._model, slm._tokenizer = sft_fine_tune_until_converge(
    #     slm._model, slm._tokenizer, priv_train_data, output_dir=logging_dir,
    #     per_device_train_batch_size=8, max_epochs=20, learning_rate=5e-5,
    #     save_steps=100, logging_steps=5,
    #     min_delta=0.001, patience=3,
    # )

    # log_print.single_log("info", f"Selection model fine-tuned on private dataset in Iteration#{_iter} with evaluation metric: {slm_sft_eval_metric}")

    # # Compute per-sample loss on the dev set using the selection model
    # per_sample_losses_llm = get_per_sample_loss(
    #     model=llm._model, tokenizer=llm._tokenizer, dataset=priv_dev_data, batch_size=8,
    # )
    # per_sample_losses_slm = get_per_sample_loss(
    #     model=slm._model, tokenizer=slm._tokenizer, dataset=priv_dev_data, batch_size=8,
    # )
    # # per_sample_losses now contains the loss for each sample in priv_dev_data
    
    # Pick the DP mechanism: Gaussian when requested (case-insensitive),
    # Exponential for any other value.
    if args.dp_mechanism.lower() == "gaussian":
        dp_mechanism = Gaussian()
    else:
        dp_mechanism = Exponential()

    # The format-match callback is only attached for tasks whose name contains 'archehrqa'.
    runner_callbacks = [save_checkpoints, save_text_to_csv, compute_fid, compute_precision_recall]
    if 'archehrqa' in args.task:
        runner_callbacks = runner_callbacks + [comput_format_match]

    # Runner that ties together the data, population, histogram, models and logging.
    pe_runner = PESGD(
        priv_data=[data, priv_data_test],  # [train split, test split]
        population=population,
        histogram=histogram,
        seed=args.seed,
        t_select=args.selection_model_train_iter,  # How many times to train the selection model
        t_fine_tune=args.fine_tune_model_train_iter,  # How many times to train the target model
        exp_folder=exp_folder,
        dp=dp_mechanism,
        syn_cluster_num=args.dp_syn_cluster_num,
        loggers=[csv_print, log_print],
        log_print_logger=log_print,
        callbacks=runner_callbacks,
        llm=llm,
        slm=slm,
        setting=args.setting,
        original_llm_path=args.fine_tune_model_name,
        init_data_file=init_data_checkpoint,  # this is a folder
        llm_add_instruction=args.llm_add_instruction,
        llm_additional_generation=args.llm_additional_generation,
    )

    # Optional loss-distribution analysis run before (or instead of) training.
    # The option values are derived from marker substrings in the method name.
    if args.loss_distribution != 0:
        pe_runner.run_loss_distribution(
            num_samples_schedule=[args.num_synthetic_per_iter] * (args.num_iter+1),
            delta=delta,
            epsilon=args.dp_epsilon,
            checkpoint_path=os.path.join(exp_folder, "checkpoint"),
            val_sample_ratio=args.priv_select_ratio_per_iter,
            lr=args.lr,
            _scaler=args.lr_scaler,
            metric_inverse_epsilon=args.metric_inverse_epsilon,
            noise_place='inner_product' if ('V2' in args.method) else 'coefficient', # V2 is of bad performance
            prompt_type=args.prompt_select_strategy,
            clip_or_normalize='clip' if ('Clip' in args.method or 'InnerProduct' in args.method) else 'normalize',
            noise_on_vote=('NoisyRealGrad' not in args.method),
            sample_evolve=('FixSample' not in args.method),
            approx_strategy='opt' if ('InnerProduct' not in args.method) else ('woResidual' if ('NoResidual' in args.method) else 'wResidual'),
            use_eigen=('Eigen' in args.method),
            clip_norm=args.dp_clip_norm,
        )
        if skip_train:
            # Stop here on purpose. An explicit SystemExit is used instead of a
            # failing `assert`: assertions are removed under `python -O`, which
            # would let the script fall through into training. The exit status
            # stays non-zero, as it was with the assertion error.
            raise SystemExit("No training, ends here")


    # ------------------------------------------------------------------
    # Dispatch to the runner entry point selected by --method.
    # Branch order matters: exact-name and substring checks are evaluated
    # top to bottom, exactly one branch runs.
    # ------------------------------------------------------------------
    method = args.method

    # privSGD scales the iteration count by the private selection ratio
    # (unless epsilon is effectively infinite); this has to happen before
    # the sample schedule below is built.
    if method == 'privSGD' and args.dp_epsilon <= 1E7:
        args.num_iter = int(args.num_iter * args.priv_select_ratio_per_iter) # adjust the number of iterations for privSGD

    # Arguments passed to every entry point.
    common_kwargs = dict(
        num_samples_schedule=[args.num_synthetic_per_iter] * (args.num_iter+1),
        delta=delta,
        epsilon=args.dp_epsilon,
        checkpoint_path=os.path.join(exp_folder, "checkpoint"),
    )
    # Arguments shared by the gradient-combination style entry points.
    gradient_kwargs = dict(
        val_sample_ratio=args.priv_select_ratio_per_iter,
        lr=args.lr,
        _scaler=args.lr_scaler,
        metric_inverse_epsilon=args.metric_inverse_epsilon,
        prompt_type=args.prompt_select_strategy,
        noise_on_vote=('NoisyRealGrad' not in method),
        with_instruction_base=args.with_instruction_base,
        use_eigen=('Eigen' in method),
    )
    # Options that the 'opt' approximation derives from markers in the method name.
    opt_strategy_kwargs = dict(
        noise_place='inner_product' if ('V2' in method) else 'coefficient', # V2 is of bad performance
        clip_or_normalize='clip' if ('Clip' in method) else 'normalize',
        sample_evolve=('FixSample' not in method),
        approx_strategy='opt',
    )

    if method == 'GhostGradDot':
        pe_runner.run_dot_gradient(**common_kwargs)
    elif 'OptGradCombination' in method:
        pe_runner.run_optimal_gradient_combination(
            **common_kwargs,
            **gradient_kwargs,
            **opt_strategy_kwargs,
            clip_norm=args.dp_clip_norm,
        )
    elif 'GradCombination' in method:
        # Non-"Opt" variant: fixed samples, clipping, noise on the coefficient,
        # with or without the residual term.
        pe_runner.run_optimal_gradient_combination(
            **common_kwargs,
            **gradient_kwargs,
            noise_place='coefficient',
            clip_or_normalize='clip',
            sample_evolve=False,
            approx_strategy='woResidual' if ('NoResidual' in method) else 'wResidual',
            clip_norm=args.dp_clip_norm,
        )
    elif method == 'purePE':
        pe_runner.run_pure_pe(
            **common_kwargs,
            val_sample_ratio_of_conterpart=args.priv_select_ratio_per_iter,
            lr=args.lr,
            _scaler=args.lr_scaler,
        )
    elif method == 'topQPE':
        pe_runner.run_topQ_pe(
            **common_kwargs,
            lr=args.lr,
            Q=8, # top-Q voting
        )
    elif method == 'privSGD':
        pe_runner.run_priv_sgd(
            **common_kwargs,
            val_sample_ratio=args.priv_select_ratio_per_iter,
            lr=args.lr,
            _scaler=args.lr_scaler,
        )
    elif 'LossPlot' in method:
        # Same options as the OptGradCombination branch, minus clip_norm.
        pe_runner.run_loss_plot(
            **common_kwargs,
            **gradient_kwargs,
            **opt_strategy_kwargs,
        )
    else:
        # 'ReTrain' and every unrecognised method name use the plain PE run.
        pe_runner.run(**common_kwargs)

