"""
This example follows the experimental settings of the GPT-3.5 PubMed experiments in the ICML 2024 Spotlight paper,
"Differentially Private Synthetic Data via Foundation Model APIs 2: Text" (https://arxiv.org/abs/2403.01749), except
that the model is changed from GPT-3.5 to GPT-4o-mini-2024-07-18 as the original GPT-3.5 model version used in the
paper is no longer available.

The generation model in this script is created with :py:class:`pe.llm.AzureOpenAILLM` (Azure OpenAI API); it can be
replaced by :py:class:`pe.llm.OpenAILLM` to call the OpenAI API instead.

When using the OpenAI API (:py:class:`pe.llm.OpenAILLM`), the following environment variables are required:
* ``OPENAI_API_KEY``: OpenAI API key. You can get it from https://platform.openai.com/account/api-keys. Multiple keys
    can be separated by commas, and the key with the lowest current workload will be used for each request.

When using the Azure OpenAI API (:py:class:`pe.llm.AzureOpenAILLM`), the following environment variables are required:
* ``AZURE_OPENAI_API_KEY``: Azure OpenAI API key. You can get it from https://portal.azure.com/. Multiple keys can
    be separated by commas. The key can also be "AZ_CLI", in which case the Azure CLI will be used to authenticate
    the requests, and the environment variable ``AZURE_OPENAI_API_SCOPE`` needs to be set. See Azure OpenAI
    authentication documentation for more information:
    https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/switching-endpoints#microsoft-entra-id-authentication
* ``AZURE_OPENAI_API_ENDPOINT``: Azure OpenAI endpoint. Multiple endpoints can be separated by commas. You can get
    it from https://portal.azure.com/.
* ``AZURE_OPENAI_API_VERSION``: Azure OpenAI API version. You can get it from https://portal.azure.com/.
Assume that $x_1$ API keys and $x_2$ endpoints are provided. To use $X$ endpoint/API key pairs, each of $x_1$ and
$x_2$ must equal either $X$ or 1, and the maximum of $x_1$ and $x_2$ must be $X$. If some $x_i$ equals 1, that single
value is shared across all endpoints/API keys. For each request, the API key + endpoint pair with the lowest current
workload will be used.

These environment variables can be set in a .env file in the same directory as this script. For example:
```
OPENAI_API_KEY=your_openai_api_key
```
See https://github.com/theskumar/python-dotenv for more information about the .env file.

For detailed information about parameters and APIs, please consult the documentation of the Private Evolution library:
https://microsoft.github.io/DPSDA/.
"""

from dotenv import load_dotenv

from pe.data.text import PubMed
from pe.logging import setup_logging
from pe.runner import PE
from pe.population import PEPopulation
from pe.dp import Gaussian, Exponential
from pe.api.text import LLMAugPE
from pe.llm import OpenAILLM, AzureOpenAILLM, HuggingfaceLLM
from pe.embedding.text import SentenceTransformer
from pe.histogram import NearestNeighbors
from pe.callback import SaveCheckpoints
from pe.callback import ComputeFID, ComputePrecisionRecall
from pe.callback import SaveTextToCSV
from pe.logger import CSVPrint
from pe.logger import LogPrint
from pe.constant.data import VARIATION_API_FOLD_ID_COLUMN_NAME

import pandas as pd
import os
import numpy as np

import argparse
# from datasets import Dataset
# from sklearn.model_selection import train_test_split
import torch
# import gc
from peft import LoraConfig, get_peft_model, TaskType
import accelerate

# from pe.llm import sft_fine_tune, sft_fine_tune_until_converge, get_per_sample_loss
from pe.runner import PESGD


# Turn on pandas' Copy-on-Write mode for the whole process (module-level side effect, applied at import time).
pd.options.mode.copy_on_write = True


if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Command-line interface
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser(description='summary generator')
    parser.add_argument("--method", type=str, default='ReTrain', help="['ReTrain', 'GhostGradDot'] The training method to use. 'GhostGradDot' runs PESGD.run_dot_gradient; any other value runs PESGD.run.")
    parser.add_argument('--seed', type=int, default=0, metavar='seed', help='random seed (default: 0)')
    # NOTE(review): --gpu is parsed but never read in this script; kept so existing command lines keep working.
    parser.add_argument('--gpu', default=0, type=int, help='gpu device id')
    parser.add_argument('--gen_model_name', type=str, default='gpt-4o-mini-2024-07-18', help="The model for generation, might be an API.")
    parser.add_argument("--fine_tune_model_name", type=str, default='Qwen/Qwen3-4B', help="The model fine-tuned on synthetic dataset.")
    parser.add_argument("--selection_model_name", type=str, default='Qwen/Qwen3-4B', help="The model fine-tuned on private dataset.")
    parser.add_argument("--fine_tune_model_train_iter", type=int, default=10, help="Numbers of iterations for training the target model for fine-tune. Default 10.")
    parser.add_argument("--selection_model_train_iter", type=int, default=10, help="Numbers of iterations for training each selection model. Default 10.")
    parser.add_argument("--setting", type=str, default='LargeGen', help="['SelfGen', 'LargeGen']. SelfGen: The model fine-tuned is the model used as generation; LargeGen: Huge LLM API are used for generation.")

    parser.add_argument("--max_completion_tokens", type=int, default=1024, help="The maximum number of tokens to generate in the response. Should be related to task.")

    parser.add_argument("--dp_mechanism", type=str, default='Gaussian', help="['Exponential', 'Gaussian'] The differential privacy mechanism to use.")
    parser.add_argument("--dp_epsilon", type=float, default=1.0, help="[1,2,4,inf(>=100000000.0)], (epsilon, delta)-DP protection of the histogram for each party, delta=0.0 for exponential mechanism")
    # NOTE(review): --dp_delta only ends up in the results folder name below; the delta actually passed to the
    # runner is derived from the private dataset size. Confirm this is intended.
    parser.add_argument("--dp_delta", type=float, default=0.0, help="[0.0, 1E-5, 1E-5/(args.steps+1)], (epsilon, delta)-DP protection of the histogram for each party")
    parser.add_argument("--dp_syn_cluster_num", type=int, default=10, help="How many clusters to use for clustering the synthetic data. Default 10.")

    parser.add_argument("--lora_r", type=int, default=8, help="LoRA rank, default 8.")
    parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha, default 32.")
    parser.add_argument("--lora_dropout", type=float, default=0.1, help="LoRA dropout rate, default 0.1")
    parser.add_argument("--target_modules", default=['q_proj', 'v_proj'], nargs='+', type=str)

    parser.add_argument("--variation_api_fold", type=int, default=2, help="How many variations to apply to the initial synthetic data. Default 2, which means N*(2+1) samples will be generated.")

    parser.add_argument("--num_synthetic_per_iter", type=int, default=1000, help="Numbers of synthetic samples to generate per iteration (before api fold). Default 1000.")
    parser.add_argument("--num_iter", type=int, default=10, help="Numbers of iterations for synthetic samples. Default 10.")

    args = parser.parse_args()

    # Validate through argparse rather than `assert`: assertions are stripped under `python -O`, which would
    # let the unsupported setting through silently.
    if args.setting == 'SelfGen':
        parser.error("SelfGen is not supported in this script.")

    # ------------------------------------------------------------------
    # Reproducibility and output locations
    # ------------------------------------------------------------------
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # One results folder per hyper-parameter combination; only the last path component of each model name is used.
    exp_folder = (
        f"./results/{args.method}/{args.setting}/text/pubmed/"
        f"{args.dp_mechanism}_{args.dp_epsilon}_{args.dp_delta}_{args.dp_syn_cluster_num}cluster_{args.variation_api_fold}fold/"
        f"{args.gen_model_name.split('/')[-1]}/"
        f"{args.fine_tune_model_name.split('/')[-1]}/"
        f"{args.selection_model_name.split('/')[-1]}/"
        f"[{args.num_synthetic_per_iter}]_{args.num_iter}_select{args.selection_model_train_iter}_finetune{args.fine_tune_model_train_iter}/"
        f"seed{args.seed}/"
    )
    checkpoint_path = os.path.join(exp_folder, "checkpoint")
    current_folder = os.path.dirname(os.path.abspath(__file__))

    load_dotenv()
    setup_logging(log_file=os.path.join(exp_folder, "log.txt"))

    # ------------------------------------------------------------------
    # Data and models
    # ------------------------------------------------------------------
    # load the private dataset
    data = PubMed(root_dir="./data/pubmed", split="train_small")

    # glm: generation model (API); llm: model to fine-tune; slm: selection model.
    glm = AzureOpenAILLM(max_completion_tokens=1000, model=args.gen_model_name, temperature=1.2, num_threads=4)
    llm = HuggingfaceLLM(max_completion_tokens=args.max_completion_tokens, model_name_or_path=args.fine_tune_model_name, temperature=1.0, device_map=None)
    slm = HuggingfaceLLM(max_completion_tokens=args.max_completion_tokens, model_name_or_path=args.selection_model_name, temperature=1.0, device_map=None)
    accelerator = accelerate.Accelerator()

    def _build_lora_config():
        """Return a fresh LoRA config from the CLI hyper-parameters (one separate instance per wrapped model)."""
        return LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            target_modules=args.target_modules,
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type=TaskType.CAUSAL_LM,
        )

    # Selection model: add LoRA adapters, prepare with accelerate, then attach a CPU-offload hook (execution
    # device "cuda") and offload immediately. The debug prints report device/dtype after each step.
    slm._model = get_peft_model(slm._model, _build_lora_config())
    slm._model = accelerator.prepare(slm._model)
    print(f"[debugging] slm._model: {slm._model.device=}, {slm._model.dtype=}")
    slm._model, slm_offload_hook = accelerate.cpu_offload_with_hook(slm._model, execution_device="cuda")
    print(f"[debugging] slm._model: {slm._model.device=}, {slm._model.dtype=}")
    slm_offload_hook.offload()
    print(f"[debugging] slm._model: {slm._model.device=}, {slm._model.dtype=}")

    # Fine-tune model: same wrapping sequence as the selection model, without the debug prints.
    llm._model = get_peft_model(llm._model, _build_lora_config())
    llm._model = accelerator.prepare(llm._model)
    llm._model, llm_offload_hook = accelerate.cpu_offload_with_hook(llm._model, execution_device="cuda")
    llm_offload_hook.offload()

    # ------------------------------------------------------------------
    # Private Evolution components
    # ------------------------------------------------------------------
    # Prompt files are resolved relative to this script's own directory.
    api = LLMAugPE(
        llm=glm,
        random_api_prompt_file=os.path.join(current_folder, "random_api_prompt.json"),
        variation_api_prompt_file=os.path.join(current_folder, "variation_api_prompt.json"),
        min_word_count=25,
        word_count_std=36,
        token_to_word_ratio=5,
        max_completion_tokens_limit=1200,
        blank_probabilities=0.6,
    )
    embedding = SentenceTransformer(model="sentence-t5-base")
    histogram = NearestNeighbors(
        embedding=embedding,
        mode="L2",
        lookahead_degree=0,
    )
    population = PEPopulation(
        api=api,
        initial_variation_api_fold=args.variation_api_fold,
        next_variation_api_fold=args.variation_api_fold,
        keep_selected=True,
        selection_mode="random",  # selection_mode="random" for gradient version of PE
    )

    # Callbacks and loggers; all file outputs go under exp_folder.
    save_checkpoints = SaveCheckpoints(checkpoint_path)
    compute_fid = ComputeFID(
        priv_data=data, embedding=embedding, filter_criterion={VARIATION_API_FOLD_ID_COLUMN_NAME: -1}
    )
    compute_precision_recall = ComputePrecisionRecall(
        priv_data=data,
        embedding=embedding,
        filter_criterion={VARIATION_API_FOLD_ID_COLUMN_NAME: -1},
        num_precision_neighbors=10,  # default is 4
        num_recall_neighbors=10,  # default is 5
    )
    save_text_to_csv = SaveTextToCSV(output_folder=os.path.join(exp_folder, "synthetic_text"))

    csv_print = CSVPrint(output_folder=exp_folder)
    log_print = LogPrint()

    # DP delta is derived from the private dataset size: 1 / (n * ln(n)).
    num_private_samples = len(data.data_frame)
    delta = 1.0 / num_private_samples / np.log(num_private_samples)

    pe_runner = PESGD(
        priv_data=data,
        population=population,
        histogram=histogram,
        seed=args.seed,
        t_select=args.selection_model_train_iter,  # How many times to train the selection model
        t_fine_tune=args.fine_tune_model_train_iter,  # How many times to train the target model
        exp_folder=exp_folder,
        dp=Gaussian() if args.dp_mechanism.lower() == "gaussian" else Exponential(),
        syn_cluster_num=args.dp_syn_cluster_num,
        loggers=[csv_print, log_print],
        log_print_logger=log_print,
        callbacks=[save_checkpoints, save_text_to_csv, compute_fid, compute_precision_recall],
        llm=llm,
        slm=slm,
        setting=args.setting,
        original_llm_path=args.fine_tune_model_name,
    )

    # ------------------------------------------------------------------
    # Run
    # ------------------------------------------------------------------
    # Both entry points take the same arguments; the schedule has num_iter + 1 entries of equal size.
    run_kwargs = {
        "num_samples_schedule": [args.num_synthetic_per_iter] * (args.num_iter + 1),
        "delta": delta,
        "epsilon": args.dp_epsilon,
        "checkpoint_path": checkpoint_path,
    }
    run_entry_point = pe_runner.run_dot_gradient if args.method == 'GhostGradDot' else pe_runner.run
    run_entry_point(**run_kwargs)
