import numpy as np
import pandas as pd
import os
from datasets import Dataset, Features, Array3D, ClassLabel
from sklearn.model_selection import train_test_split
import torch
import torch.nn.functional as F
import gc
import copy
import transformers
import accelerate
from peft import LoraConfig, PeftModel, PeftConfig
from safetensors.torch import load_file
# from opacus.validators import ModuleValidator

from pe.dp import Gaussian
from pe.dp import Exponential
from pe.data import Data
from pe.constant.data import LABEL_ID_COLUMN_NAME
from pe.logging import execution_logger
from pe.llm import sft_fine_tune, sft_fine_tune_until_converge, opacus_dpsgd_fine_tune
from pe.llm import weighted_fine_tune
from pe.llm import evaluate_model_on_private_data, evaluate_model_by_sample
from pe.llm import get_per_sample_loss
from pe.llm import ghost_suite_grad_dot
from pe.llm import get_sample_grad, get_sample_grad_different_noise
from pe.constant.data import IMAGE_DATA_COLUMN_NAME, LABEL_ID_COLUMN_NAME
from sklearn.model_selection import KFold
from trl import PPOTrainer, PPOConfig
from trl import GRPOTrainer, GRPOConfig
from torch.utils.data import DataLoader
from tqdm import tqdm
import json
import glob



class PE(object):
    """The class that runs the PE algorithm."""

    def __init__(self, priv_data, population, histogram, dp=None, loggers=None, callbacks=None):
        """Constructor.

        :param priv_data: The private data
        :type priv_data: :py:class:`pe.data.Data`
        :param population: The population algorithm
        :type population: :py:class:`pe.population.Population`
        :param histogram: The histogram algorithm
        :type histogram: :py:class:`pe.histogram.Histogram`
        :param dp: The DP algorithm, defaults to None, in which case the Gaussian mechanism
            :py:class:`pe.dp.Gaussian` is used
        :type dp: :py:class:`pe.dp.DP`, optional
        :param loggers: The list of loggers, defaults to None, in which case no logger is used
        :type loggers: list[:py:class:`pe.logger.Logger`], optional
        :param callbacks: The list of callbacks, defaults to None, in which case no callback is used
        :type callbacks: list[Callable or :py:class:`pe.callback.Callback`], optional
        """
        super().__init__()
        self._priv_data = priv_data
        self._population = population
        self._histogram = histogram
        if dp is None:
            dp = Gaussian()
        self._dp = dp
        # A fresh list per instance: a literal ``[]`` default would be one object shared by every instance.
        self._loggers = [] if loggers is None else loggers
        self._callbacks = [] if callbacks is None else callbacks

    def load_checkpoint(self, checkpoint_path):
        """Load a checkpoint.

        :param checkpoint_path: The path to the checkpoint
        :type checkpoint_path: str
        :return: The synthetic data, or None when ``Data.load_checkpoint`` reports failure
        :rtype: :py:class:`pe.data.Data` or None
        """
        syn_data = Data()
        if not syn_data.load_checkpoint(checkpoint_path):
            return None
        return syn_data

    def _log_metrics(self, syn_data):
        """Run every callback on the synthetic data and forward the collected metric items to every logger.

        Nothing happens when no callbacks are registered. Each metric item is cleaned up after logging.

        :param syn_data: The synthetic data
        :type syn_data: :py:class:`pe.data.Data`
        """
        if not self._callbacks:
            return
        metric_items = []
        for callback in self._callbacks:
            # A callback may return None (or another falsy value) when it has nothing to report.
            metric_items.extend(callback(syn_data) or [])
        for logger in self._loggers:
            logger.log(iteration=syn_data.metadata.iteration, metric_items=metric_items)
        for metric_item in metric_items:
            metric_item.clean_up()

    def _get_num_samples_per_label_id(self, num_samples, fraction_per_label_id):
        """Get the number of samples per label id given the total number of samples

        :param num_samples: The total number of samples
        :type num_samples: int
        :param fraction_per_label_id: The fraction of samples for each label id. The fraction does not have to be
            normalized. When it is None, the fraction is assumed to be the same as the fraction of label ids in the
            private data
        :type fraction_per_label_id: list[float] or None
        :raises ValueError: If the length of fraction_per_label_id is not the same as the number of labels
        :raises ValueError: If the number of samples is so small that the number of samples for some label ids is zero
        :return: The number of samples per label id
        :rtype: np.ndarray
        """
        if fraction_per_label_id is None:
            execution_logger.warning(
                "fraction_per_label_id is not provided. Assuming the fraction of label ids in private data is public "
                "information."
            )
            label_counts = self._priv_data.data_frame[LABEL_ID_COLUMN_NAME].value_counts().to_dict()
            # Label ids that never occur in the private data get a weight of zero.
            fraction_per_label_id = [
                label_counts.get(i, 0) for i in range(len(self._priv_data.metadata.label_info))
            ]
        if len(fraction_per_label_id) != len(self._priv_data.metadata.label_info):
            raise ValueError("fraction_per_label_id should have the same length as the number of labels.")
        fraction_per_label_id = np.array(fraction_per_label_id)
        fraction_per_label_id = fraction_per_label_id / np.sum(fraction_per_label_id)

        # Largest-remainder rounding: floor every share, then hand the leftover samples to the labels whose
        # fractional part is largest so that the counts sum to exactly num_samples.
        target_num_samples_per_label_id = fraction_per_label_id * num_samples
        num_samples_per_label_id = np.floor(target_num_samples_per_label_id).astype(int)
        num_samples_left = num_samples - np.sum(num_samples_per_label_id)
        ids = np.argsort(target_num_samples_per_label_id - num_samples_per_label_id)[::-1]
        num_samples_per_label_id[ids[:num_samples_left]] += 1
        assert np.sum(num_samples_per_label_id) == num_samples
        if np.any(num_samples_per_label_id == 0):
            raise ValueError("num_samples is so small that the number of samples for some label ids is zero.")
        return num_samples_per_label_id

    def _clean_up_loggers(self):
        """Clean up loggers."""
        for logger in self._loggers:
            logger.clean_up()

    def evaluate(self, checkpoint_path):
        """Evaluate the synthetic data stored in a checkpoint by running the callbacks and loggers on it.

        :param checkpoint_path: The path to the checkpoint
        :type checkpoint_path: str
        :raises ValueError: If the checkpoint cannot be loaded
        """
        syn_data = self.load_checkpoint(checkpoint_path)
        if syn_data is None:
            # Fail with a clear message instead of an AttributeError on ``None.metadata`` below.
            raise ValueError(f"Failed to load checkpoint from {checkpoint_path}")
        execution_logger.info(f"Loaded checkpoint from {checkpoint_path}, iteration={syn_data.metadata.iteration}")
        self._log_metrics(syn_data)

    def run(
        self,
        num_samples_schedule,
        delta,
        epsilon=None,
        noise_multiplier=None,
        checkpoint_path=None,
        save_checkpoint=True,
        fraction_per_label_id=None,
    ):
        """Run the PE algorithm.

        :param num_samples_schedule: The schedule of the number of samples for each PE iteration. The first element is
            the number of samples for the initial data, and the rest are the number of samples for each PE iteration.
            So the length of the list is the number of PE iterations plus one
        :type num_samples_schedule: list[int]
        :param delta: The delta value of DP
        :type delta: float
        :param epsilon: The epsilon value of DP, defaults to None
        :type epsilon: float, optional
        :param noise_multiplier: The noise multiplier of the DP mechanism, defaults to None
        :type noise_multiplier: float, optional
        :param checkpoint_path: The path to load and save the checkpoint, defaults to None, in which case no
            checkpoint is loaded or saved
        :type checkpoint_path: str, optional
        :param save_checkpoint: Whether to save the checkpoint, defaults to True
        :type save_checkpoint: bool, optional
        :param fraction_per_label_id: The fraction of samples for each label id. The fraction does not have to be
            normalized. When it is None, the fraction is assumed to be the same as the fraction of label ids in the
            private data. Defaults to None
        :type fraction_per_label_id: list[float], optional
        :return: The synthetic data
        :rtype: :py:class:`pe.data.Data`
        """
        try:
            # Set privacy budget.
            self._dp.set_epsilon_and_delta(
                num_iterations=len(num_samples_schedule) - 1,
                epsilon=epsilon,
                delta=delta,
                noise_multiplier=noise_multiplier,
            )

            # Generate or load initial data.
            if checkpoint_path is not None and (syn_data := self.load_checkpoint(checkpoint_path)):
                execution_logger.info(
                    f"Loaded checkpoint from {checkpoint_path}, iteration={syn_data.metadata.iteration}"
                )
            else:
                num_samples_per_label_id = self._get_num_samples_per_label_id(
                    num_samples=num_samples_schedule[0],
                    fraction_per_label_id=fraction_per_label_id,
                )
                syn_data_list = []
                for label_id, label_info in enumerate(self._priv_data.metadata.label_info):
                    syn_data = self._population.initial(
                        label_info=label_info,
                        num_samples=num_samples_per_label_id[label_id],
                    )
                    syn_data.set_label_id(label_id)
                    syn_data_list.append(syn_data)
                syn_data = Data.concat(syn_data_list, metadata=self._priv_data.metadata)
                syn_data.data_frame.reset_index(drop=True, inplace=True)
                syn_data.metadata.iteration = 0
                self._log_metrics(syn_data)

            # Run PE iterations. A loaded checkpoint resumes right after the iteration it recorded.
            for iteration in range(syn_data.metadata.iteration + 1, len(num_samples_schedule)):
                execution_logger.info(f"PE iteration {iteration}")
                num_samples_per_label_id = self._get_num_samples_per_label_id(
                    num_samples=num_samples_schedule[iteration],
                    fraction_per_label_id=fraction_per_label_id,
                )
                syn_data_list = []
                priv_data_list = []

                # Generate synthetic data for each label.
                for label_id in range(len(self._priv_data.metadata.label_info)):
                    execution_logger.info(f"Label {label_id}")
                    sub_priv_data = self._priv_data.filter_label_id(label_id=label_id)
                    sub_syn_data = syn_data.filter_label_id(label_id=label_id)

                    # DP NN histogram.
                    sub_priv_data, sub_syn_data = self._histogram.compute_histogram(
                        priv_data=sub_priv_data, syn_data=sub_syn_data
                    )
                    priv_data_list.append(sub_priv_data)
                    sub_syn_data = self._dp.add_noise(syn_data=sub_syn_data)

                    # Generate next population.
                    sub_syn_data = self._population.next(
                        syn_data=sub_syn_data,
                        num_samples=num_samples_per_label_id[label_id],
                    )
                    sub_syn_data.set_label_id(label_id)
                    syn_data_list.append(sub_syn_data)

                syn_data = Data.concat(syn_data_list)
                syn_data.data_frame.reset_index(drop=True, inplace=True)
                syn_data.metadata.iteration = iteration

                # Fold the per-label private data returned by the histogram step back into the private data.
                new_priv_data = Data.concat(priv_data_list)
                self._priv_data = self._priv_data.merge(new_priv_data)

                # With the defaults (checkpoint_path=None, save_checkpoint=True) there is nowhere to save to,
                # so only save when a path was actually provided.
                if save_checkpoint and checkpoint_path is not None:
                    syn_data.save_checkpoint(checkpoint_path)
                self._log_metrics(syn_data)
        finally:
            self._clean_up_loggers()

        return syn_data

class PESGD(object):
    """The class that runs the PE algorithm."""

    def __init__(self, priv_data, population, histogram, seed, t_select, t_fine_tune, exp_folder, llm, slm, dp=None, syn_cluster_num=10, loggers=None, log_print_logger=None, callbacks=None, setting="SelfGen", original_llm_path=None, init_data_file='', llm_add_instruction=True, llm_additional_generation=False):
        """Constructor.

        Stores the configuration and immediately prepares the private datasets via
        :py:meth:`private_data_preparation`.

        :param priv_data: A pair whose first element is the private training data and whose second element is the
            private test data
        :type priv_data: sequence of two :py:class:`pe.data.Data`
        :param population: The population algorithm
        :type population: :py:class:`pe.population.Population`
        :param histogram: The histogram algorithm
        :type histogram: :py:class:`pe.histogram.Histogram`
        :param seed: The random seed used for the train/dev splits and the K-fold partition
        :param t_select: Stored as ``self._t_select``; not used in this constructor
        :param t_fine_tune: Stored as ``self._t_fine_tune``; not used in this constructor
        :param exp_folder: Stored as ``self._exp_folder``; not used in this constructor
        :param llm: Stored as ``self._llm``; not used in this constructor
        :param slm: Stored as ``self._slm``; not used in this constructor
        :param dp: The DP algorithm, defaults to None, in which case the Gaussian mechanism
            :py:class:`pe.dp.Gaussian` is used
        :type dp: :py:class:`pe.dp.DP`, optional
        :param syn_cluster_num: Factor of the number of folds the synthetic data is partitioned into, defaults to 10
        :type syn_cluster_num: int, optional
        :param loggers: The list of loggers, defaults to None, in which case no logger is used
        :type loggers: list[:py:class:`pe.logger.Logger`], optional
        :param log_print_logger: Stored as ``self._log_print_logger``, defaults to None
        :param callbacks: The list of callbacks, defaults to None, in which case no callback is used
        :type callbacks: list[Callable or :py:class:`pe.callback.Callback`], optional
        :param setting: Stored as ``self._setting``, defaults to "SelfGen"
        :type setting: str, optional
        :param original_llm_path: The path of the original LLM model. Must not be None
        :type original_llm_path: str
        :param init_data_file: Checkpoint path of already generated initial synthetic data; the empty string
            (default) means no such file
        :type init_data_file: str, optional
        :param llm_add_instruction: Whether the LLM SFT uses an instruction (only reported in the log message here),
            defaults to True
        :type llm_add_instruction: bool, optional
        :param llm_additional_generation: Stored as ``self._llm_additional_generation``, defaults to False
        :type llm_additional_generation: bool, optional
        :raises AssertionError: If ``original_llm_path`` is None
        """
        super().__init__()
        self._priv_data, self._priv_data_test = priv_data[0], priv_data[1]
        self._population = population
        self._histogram = histogram
        if dp is None:
            dp = Gaussian()
        self._dp = dp
        # A fresh list per instance: a literal ``[]`` default would be one object shared by every instance.
        self._loggers = [] if loggers is None else loggers
        self._log_print_logger = log_print_logger
        self._callbacks = [] if callbacks is None else callbacks
        self._seed = seed
        self._t_select = t_select
        self._t_fine_tune = t_fine_tune
        self._exp_folder = exp_folder
        self._syn_cluster_num = syn_cluster_num
        self._llm = llm
        self._slm = slm
        self._setting = setting
        self._original_llm_path = original_llm_path
        self._init_data_file = init_data_file
        self._llm_add_instruction = llm_add_instruction
        self._llm_additional_generation = llm_additional_generation
        print(f"llm SFT with{'out' if not self._llm_add_instruction else ''} instruction")
        execution_logger.info(f"llm SFT with{'out' if not self._llm_add_instruction else ''} instruction")
        assert self._original_llm_path is not None, "original_llm_path should not be None, please set it to the path of the original LLM model."

        self.private_data_preparation()  # prepare the private data for training
    
    def _get_images_and_label_from_data(self, data):
        """Getting images and labels from the data.

        :param data: The data object
        :type data: :py:class:`pe.data.Data`
        :return: The images and labels
        :rtype: tuple[np.ndarray, np.ndarray]
        """
        if data is None:
            return None, None
        else:
            images = np.stack(data[IMAGE_DATA_COLUMN_NAME].values)
            images = images.transpose((0, 3, 1, 2)) / 255.0
            images = [img for img in images]
            # print(f"{type(images)=}, {len(images)=}, {type(images[0])=}, {images[0].shape=}")
            labels = np.array(data[LABEL_ID_COLUMN_NAME].values)
            features = Features({
                "image": Array3D(dtype="float16", shape=images[0].shape),
                "labels": ClassLabel(names=[str(i) for i in range(len(np.unique(labels)))]),
            })
            return images, labels, features

    def private_data_preparation(self, test_size_ratio=0.0):
        """Prepare the private data for the PE algorithm.

        Sets ``self._priv_train_data``, ``self._priv_dev_data`` and ``self._priv_eval_data``. The train (and
        optional dev) set comes from the private training data, the eval set from the private test data. Text data
        (a ``PE.TEXT`` column) takes precedence over image data (a ``PE.IMAGE`` column); when neither column is
        present nothing is prepared.

        :param test_size_ratio: The fraction of the private training data held out as the dev set. When it is not
            positive no dev set is created and ``self._priv_dev_data`` is None. Defaults to 0.0
        :type test_size_ratio: float, optional
        """
        print(f"{self._priv_data.data_frame.columns=}")
        priv_df = self._priv_data.data_frame
        is_text = 'PE.TEXT' in priv_df.columns
        is_image = 'PE.IMAGE' in priv_df.columns
        if not (is_text or is_image):
            return

        # Split private data into train and dev sets. The split is needed by both modalities; the image branch
        # previously used train_df/dev_df without ever creating them.
        if test_size_ratio > 0.0:
            train_df, dev_df = train_test_split(priv_df, test_size=test_size_ratio, random_state=self._seed, shuffle=True)
        else:
            train_df, dev_df = priv_df, None
        eval_df = self._priv_data_test.data_frame

        if is_text:
            self._priv_train_data = Dataset.from_pandas(train_df.reset_index(drop=True))
            self._priv_dev_data = None if dev_df is None else Dataset.from_pandas(dev_df.reset_index(drop=True))
            self._priv_eval_data = Dataset.from_pandas(eval_df.reset_index(drop=True))
            # Expose the text under a "text" column. Whether to add it is decided on the train set only and then
            # applied to every split.
            if 'text' not in self._priv_train_data.column_names:
                self._priv_train_data = self._priv_train_data.add_column("text", self._priv_train_data['PE.TEXT'])
                if self._priv_dev_data is not None:
                    self._priv_dev_data = self._priv_dev_data.add_column("text", self._priv_dev_data['PE.TEXT'])
                self._priv_eval_data = self._priv_eval_data.add_column("text", self._priv_eval_data['PE.TEXT'])
            return

        def build_image_dataset(frame):
            # The dataset is created with an "image" column, so no column needs to be added afterwards.
            images, labels, features = self._get_images_and_label_from_data(frame)
            if images is None:
                return None
            return Dataset.from_dict({"image": images, "labels": labels}, features=features)

        self._priv_train_data = build_image_dataset(train_df)
        self._priv_dev_data = None if dev_df is None else build_image_dataset(dev_df)
        self._priv_eval_data = build_image_dataset(eval_df)

    def syn_data_preparation(self, syn_data, test_size_ratio=0.0):
        """Prepare the synthetic data for training.

        Sets ``self._syn_train_data`` and ``self._syn_dev_data``. Text data (a ``PE.TEXT`` column) takes precedence
        over image data (a ``PE.IMAGE`` column); when neither column is present nothing is prepared. No synthetic
        eval set is built: ``self._syn_eval_data`` is set to None when no dev split is requested and is left
        untouched otherwise.

        :param syn_data: The synthetic data
        :type syn_data: :py:class:`pe.data.Data`
        :param test_size_ratio: The fraction of the synthetic data held out as the dev set. When it is not positive
            no dev set is created and ``self._syn_dev_data`` is None. Defaults to 0.0
        :type test_size_ratio: float, optional
        """
        syn_df = syn_data.data_frame
        is_text = 'PE.TEXT' in syn_df.columns
        is_image = 'PE.IMAGE' in syn_df.columns
        if not (is_text or is_image):
            return

        if test_size_ratio > 0.0:
            # Split the data frame itself; the text branch used to pass the Data wrapper to train_test_split.
            train_df, dev_df = train_test_split(syn_df, test_size=test_size_ratio, random_state=self._seed, shuffle=True)
        else:
            train_df, dev_df = syn_df, None
            # Previously this was (partly) written to the misspelled attribute `_syn_eavl_data`.
            self._syn_eval_data = None

        if is_text:
            self._syn_train_data = Dataset.from_pandas(train_df.reset_index(drop=True))
            self._syn_dev_data = None if dev_df is None else Dataset.from_pandas(dev_df.reset_index(drop=True))
            # Expose the text under a "text" column. Whether to add it is decided on the train set only.
            if 'text' not in self._syn_train_data.column_names:
                self._syn_train_data = self._syn_train_data.add_column("text", self._syn_train_data['PE.TEXT'])
                if self._syn_dev_data is not None:
                    self._syn_dev_data = self._syn_dev_data.add_column("text", self._syn_dev_data['PE.TEXT'])
            return

        def build_image_dataset(frame):
            # The dataset is created with an "image" column, so no column needs to be added afterwards.
            images, labels, features = self._get_images_and_label_from_data(frame)
            if images is None:
                return None
            return Dataset.from_dict({"image": images, "labels": labels}, features=features)

        self._syn_train_data = build_image_dataset(train_df)
        self._syn_dev_data = None if dev_df is None else build_image_dataset(dev_df)

    def syn_data_partition_and_preparation(self, syn_data, test_size_ratio=0.05):
        """Partition the synthetic data into folds and prepare them for training.

        An optional dev set is held out first; the remaining rows are partitioned with a shuffled K-fold into
        ``syn_cluster_num * (initial_variation_api_fold + 1)`` disjoint folds. The folds are stored as
        :py:class:`datasets.Dataset` objects in ``self._syn_train_data_list`` and the dev set (or None) in
        ``self._syn_dev_data``.

        :param syn_data: The synthetic data
        :type syn_data: :py:class:`pe.data.Data`
        :param test_size_ratio: The fraction of the synthetic data held out as the dev set. When it is not positive
            no dev set is created. Defaults to 0.05
        :type test_size_ratio: float, optional
        :return: One data object per fold, in the same order as ``self._syn_train_data_list``
        :rtype: list[:py:class:`pe.data.Data`]
        """
        self._syn_dev_data = None

        if test_size_ratio > 0:
            # Honour the requested ratio; it used to be hard-coded to 0.05 regardless of the argument.
            train_df, dev_df = train_test_split(syn_data.data_frame, test_size=test_size_ratio, random_state=self._seed, shuffle=True)
            train_df = train_df.reset_index(drop=True)
            dev_df = dev_df.reset_index(drop=True)
        else:
            train_df = syn_data.data_frame.reset_index(drop=True)
            dev_df = None

        execution_logger.info(f"splitting the synthetic data into {self._syn_cluster_num=}*{self._population._initial_variation_api_fold+1=} folds for training")
        print(f"[debugging] {type(syn_data)=}, {type(self._priv_data)=}")
        num_folds = self._syn_cluster_num * (self._population._initial_variation_api_fold + 1)
        kf = KFold(n_splits=num_folds, shuffle=True, random_state=self._seed)

        fold_frames = []
        fold_data_list = []
        # Only the held-out index set of each split is used: together they partition train_df into num_folds parts.
        for _, fold_index in kf.split(train_df):
            fold_df = train_df.iloc[fold_index]
            # TODO: clumsy way to create a new Data object (copy the private data object and swap its frame),
            # but it works.
            fold_data = copy.deepcopy(self._priv_data)
            fold_data.data_frame = fold_df
            fold_data_list.append(Data.concat([fold_data], metadata=self._priv_data.metadata))
            fold_frames.append(fold_df)

        self._syn_train_data_list = [Dataset.from_pandas(fold_df.reset_index(drop=True)) for fold_df in fold_frames]
        self._syn_dev_data = Dataset.from_pandas(dev_df.reset_index(drop=True)) if dev_df is not None else None
        # Expose the text under a "text" column. Whether to add it is decided on the first fold only.
        # NOTE(review): this reads the 'PE.TEXT' column, so the method presumably only supports text data — confirm.
        if 'text' not in self._syn_train_data_list[0].column_names:
            self._syn_train_data_list = [syn_train_data.add_column("text", syn_train_data['PE.TEXT']) for syn_train_data in self._syn_train_data_list]
            if self._syn_dev_data is not None:
                self._syn_dev_data = self._syn_dev_data.add_column("text", self._syn_dev_data['PE.TEXT'])

        return fold_data_list
        # return self._syn_train_data_list, self._syn_dev_data

    def _get_num_samples_per_label_id(self, num_samples, fraction_per_label_id):
        """Get the number of samples per label id given the total number of samples

        :param num_samples: The total number of samples
        :type num_samples: int
        :param fraction_per_label_id: The fraction of samples for each label id. The fraction does not have to be
            normalized. When it is None, the fraction is assumed to be the same as the fraction of label ids in the
            private data. Defaults to None
        :type fraction_per_label_id: list[float], optional
        :raises ValueError: If the length of fraction_per_label_id is not the same as the number of labels
        :raises ValueError: If the number of samples is so small that the number of samples for some label ids is zero
        :return: The number of samples per label id
        :rtype: np.ndarray
        """
        if fraction_per_label_id is None:
            execution_logger.warning(
                "fraction_per_label_id is not provided. Assuming the fraction of label ids in private data is public "
                "information."
            )
            fraction_per_label_id = self._priv_data.data_frame[LABEL_ID_COLUMN_NAME].value_counts().to_dict()
            fraction_per_label_id = [
                0 if i not in fraction_per_label_id else fraction_per_label_id[i]
                for i in range(len(self._priv_data.metadata.label_info))
            ]
        if len(fraction_per_label_id) != len(self._priv_data.metadata.label_info):
            raise ValueError("fraction_per_label_id should have the same length as the number of labels.")
        fraction_per_label_id = np.array(fraction_per_label_id)
        fraction_per_label_id = fraction_per_label_id / np.sum(fraction_per_label_id)

        target_num_samples_per_label_id = fraction_per_label_id * num_samples
        num_samples_per_label_id = np.floor(target_num_samples_per_label_id).astype(int)
        num_samples_left = num_samples - np.sum(num_samples_per_label_id)
        ids = np.argsort(target_num_samples_per_label_id - num_samples_per_label_id)[::-1]
        num_samples_per_label_id[ids[:num_samples_left]] += 1
        assert np.sum(num_samples_per_label_id) == num_samples
        if np.any(num_samples_per_label_id == 0):
            raise ValueError("num_samples is so small that the number of samples for some label ids is zero.")
        return num_samples_per_label_id

    def _clean_up_loggers(self):
        """Clean up loggers."""
        for logger in self._loggers:
            logger.clean_up()

    def evaluate(self, checkpoint_path):
        """Evaluate the synthetic data.

        :param checkpoint_path: The path to the checkpoint
        :type checkpoint_path: str
        """
        syn_data = self.load_checkpoint(checkpoint_path)
        execution_logger.info(f"Loaded checkpoint from {checkpoint_path}, iteration={syn_data.metadata.iteration}")
        self._log_metrics(syn_data)

    def load_checkpoint(self, checkpoint_path):
        """Read synthetic data from a checkpoint into a fresh data object.

        :param checkpoint_path: The path to the checkpoint
        :type checkpoint_path: str
        :return: The loaded synthetic data, or None when ``Data.load_checkpoint`` returns a falsy value
        :rtype: :py:class:`pe.data.Data` or None
        """
        checkpoint_data = Data()
        loaded = checkpoint_data.load_checkpoint(checkpoint_path)
        return checkpoint_data if loaded else None

    def _log_metrics(self, syn_data):
        """Log metrics.

        :param syn_data: The synthetic data
        :type syn_data: :py:class:`pe.data.Data`
        """
        if not self._callbacks:
            return
        metric_items = []
        for callback in self._callbacks:
            metric_items.extend(callback(syn_data) or [])
        for logger in self._loggers:
            logger.log(iteration=syn_data.metadata.iteration, metric_items=metric_items)
        for metric_item in metric_items:
            metric_item.clean_up()

    def run(
        self,
        num_samples_schedule,
        delta,
        epsilon=None,
        noise_multiplier=None,
        checkpoint_path=None,
        save_checkpoint=True,
        fraction_per_label_id=None,
    ):
        """Run the PE algorithm.

        After the initial synthetic data is loaded or generated, every PE iteration:

        1. fine-tunes one selection model per entry of ``self._syn_train_data_list`` (one entry per synthetic
           data cluster) and scores the cluster by how much the per-sample loss on the private training data
           drops relative to the un-tuned selection model;
        2. L2-normalizes the scores across clusters for each private sample, sums them per cluster and adds
           Gaussian noise with standard deviation ``self._dp._noise_multiplier``;
        3. fine-tunes the LLM on all clusters, weighting each sample by its cluster's standardized noisy score;
        4. samples ``self._syn_cluster_num`` clusters (without replacement) from the softmax of the scores and
           asks the population algorithm for their next generation.

        :param num_samples_schedule: The schedule of the number of samples for each PE iteration. The first element is
            the number of samples for the initial data, and the rest are the number of samples for each PE iteration.
            So the length of the list is the number of PE iterations plus one
        :type num_samples_schedule: list[int]
        :param delta: The delta value of DP
        :type delta: float
        :param epsilon: The epsilon value of DP, defaults to None
        :type epsilon: float, optional
        :param noise_multiplier: The noise multiplier of the DP mechanism, defaults to None
        :type noise_multiplier: float, optional
        :param checkpoint_path: The path to load and save the checkpoint, defaults to None
        :type checkpoint_path: str, optional
        :param save_checkpoint: Whether to save the checkpoint, defaults to True
        :type save_checkpoint: bool, optional
        :param fraction_per_label_id: The fraction of samples for each label id. The fraction does not have to be
            normalized. When it is None, the fraction is assumed to be the same as the fraction of label ids in the
            private data. Defaults to None
        :type fraction_per_label_id: list[float], optional
        :return: The synthetic data
        :rtype: :py:class:`pe.data.Data`
        """
        try:
            # Set privacy budget.
            self._dp.set_epsilon_and_delta(
                num_iterations=len(num_samples_schedule) - 1,
                epsilon=epsilon,
                delta=delta,
                noise_multiplier=noise_multiplier,
            )

            # Generate or load initial data. Precedence: PE checkpoint, then previously generated initial
            # data (self._init_data_file), then a fresh generation from the population algorithm.
            if checkpoint_path is not None and (syn_data := self.load_checkpoint(checkpoint_path)):
                execution_logger.info(
                    f"Loaded checkpoint from {checkpoint_path}, iteration={syn_data.metadata.iteration}"
                )
            elif self._init_data_file != '' and (syn_data := self.load_checkpoint(self._init_data_file)):
                execution_logger.info(f"PE initial data loaded from already generated data with {len(syn_data.data_frame)} samples [finished].")
            else:
                syn_data = self._run_generate_from_scratch(
                    num_samples=num_samples_schedule[0],
                    fraction_per_label_id=fraction_per_label_id,
                )
                syn_data.metadata.iteration = 0
                self._log_metrics(syn_data)
                execution_logger.info(f"PE initial data generated with {len(syn_data.data_frame)} samples [finished].")
                if self._init_data_file != '':
                    syn_data.save_checkpoint(self._init_data_file)
                    print(f"save PE initial data to {self._init_data_file}")

            # Selection model that is carried over between clusters and iterations. It stays None until the
            # first model is built, so a run resumed at iteration > 1 can detect that there is nothing to reuse
            # (previously this case hit an unbound local variable).
            slm_for_selection = None

            # Run PE iterations.
            for iteration in range(syn_data.metadata.iteration + 1, len(num_samples_schedule)):
                execution_logger.info(f"PE iteration {iteration}")

                print(f"[debugging] {syn_data.metadata.iteration} -> {iteration}, {type(syn_data)}")

                # Baseline: per-sample loss of the current (un-tuned) selection model on the private training data.
                origianl_slm_per_sample_losses = get_per_sample_loss(
                    model=self._slm._model, tokenizer=self._slm._tokenizer, dataset=self._priv_train_data, batch_size=8,
                )
                print(f"[debugging] original slm per sample losses: {origianl_slm_per_sample_losses=}")

                # Per the original author's note, this call also sets self._syn_train_data_list and
                # self._syn_dev_data as a side effect; the returned list holds one Data object per cluster.
                syn_data_df_list = self.syn_data_partition_and_preparation(syn_data=syn_data, test_size_ratio=0.0)
                per_sample_losses_per_model = []

                for _i, syn_train_data in enumerate(self._syn_train_data_list):
                    logging_dir = self._exp_folder + f"/slm_syn_train/{iteration}/slm#{_i}/"
                    os.makedirs(logging_dir, exist_ok=True)

                    if iteration <= 1:
                        # For the first iteration, we use the original selection model
                        slm_for_selection = copy.deepcopy(self._slm._model)
                    else:
                        # For the subsequent iterations, we start from the LLM saved by the previous iteration
                        slm_for_selection = self._run_load_previous_llm(
                            self._exp_folder + f"/slm_syn_train/{iteration-1}/llm/", slm_for_selection,
                        )
                    # TODO: number of training epochs should be hyper-parameter
                    slm_sft_eval_metric, slm_trained_model, _ = sft_fine_tune(
                        slm_for_selection, self._slm._tokenizer, syn_train_data, output_dir=logging_dir,
                        per_device_train_batch_size=8, num_train_epochs=self._t_select, learning_rate=5e-6,
                        save_steps=100, logging_steps=1,
                    )
                    self._log_print_logger.single_log("info", f"Selection model fine-tuned on synthetic dataset split#{_i} in Iteration#{iteration} with evaluation metric: {slm_sft_eval_metric}")
                    per_sample_losses_per_model.append(
                        get_per_sample_loss(
                            model=slm_trained_model, tokenizer=self._slm._tokenizer, dataset=self._priv_train_data, batch_size=8,
                        )
                    )

                print(f"[debugging] per_sample_losses_per_model (absolute loss): {len(per_sample_losses_per_model)=}")

                # Turn absolute losses into loss decreases relative to the baseline (larger = more helpful cluster).
                per_sample_losses_per_model = [origianl_slm_per_sample_losses - losses for losses in per_sample_losses_per_model]

                print(f"[debugging] per_sample_losses_per_model (loss decrease): {len(per_sample_losses_per_model)=}, {per_sample_losses_per_model[0].shape=}")

                # ############# step1, RL on the llm with negative loss as negative reward #############
                # L2 normalize along dimension 0, i.e. across models for each private sample, so that every
                # private sample contributes a vector of at most unit L2 norm.
                per_sample_losses_per_model = torch.stack(per_sample_losses_per_model, dim=0) # shape: (num_models, num_private_samples)
                print(f"[debugging] per_sample_losses_per_model (stacked): {per_sample_losses_per_model.shape=}, {per_sample_losses_per_model=}")
                norms = per_sample_losses_per_model.norm(p=2, dim=0, keepdim=True)
                norms = torch.where(norms == 0, torch.ones_like(norms), norms)  # avoid division by zero
                _per_sample_normed_losses_per_model = per_sample_losses_per_model / norms
                print(f"[debugging] per_dev_model_losses after L_2 normalization: {_per_sample_normed_losses_per_model=}")
                # Sum over the private samples: one score per model, i.e. per synthetic cluster.
                per_dev_model_prob_for_training = _per_sample_normed_losses_per_model.sum(dim=1)
                print(f"[debugging] per_dev_model_prob_for_training: {per_dev_model_prob_for_training.shape=}, {per_dev_model_prob_for_training=}")
                execution_logger.info(f"[iteration {iteration}] per_dev_model_prob_for_training: {per_dev_model_prob_for_training.shape=}, {per_dev_model_prob_for_training=}")
                # DP.add_noise: Gaussian noise drawn with NumPy (float64) and added as an explicit torch tensor on
                # the scores' device, instead of relying on implicit Tensor/ndarray mixing.
                # NOTE(review): presumably the per-sample unit-norm normalization above bounds the sensitivity
                # assumed by the privacy accounting - confirm against self._dp.
                noise = torch.from_numpy(
                    np.random.normal(scale=self._dp._noise_multiplier, size=len(per_dev_model_prob_for_training))
                ).to(per_dev_model_prob_for_training.device)
                per_dev_model_prob_for_training = per_dev_model_prob_for_training + noise
                print(f"[debugging] after adding noise, per_dev_model_prob_for_training: {per_dev_model_prob_for_training.shape=}, {per_dev_model_prob_for_training=}")
                execution_logger.info(f"[iteration {iteration}] after adding noise, per_dev_model_prob_for_training: {per_dev_model_prob_for_training.shape=}, {per_dev_model_prob_for_training=}")
                # Normalize per_dev_model_prob_for_training with its mean and standard deviation
                mean = per_dev_model_prob_for_training.mean()
                std = per_dev_model_prob_for_training.std(unbiased=False)
                if std == 0:
                    std = 1.0  # avoid division by zero
                per_dev_model_prob_for_training = (per_dev_model_prob_for_training - mean) / std

                # Train the LLM on the synthetic data, weighting every sample by the score of its cluster.
                # Each syn_train_data in self._syn_train_data_list corresponds to a cluster, and its reward is per_dev_model_prob_for_training[_i]
                if iteration > 1:
                    self._llm._model = self._run_load_previous_llm(
                        self._exp_folder + f"/slm_syn_train/{iteration-1}/llm/", self._llm._model,
                    )
                all_rl_samples = []
                all_rl_rewards = []
                for _i, syn_train_data in enumerate(self._syn_train_data_list):
                    # Convert to pandas for easier handling
                    df = syn_train_data.to_pandas()
                    reward = per_dev_model_prob_for_training[_i].item() if hasattr(per_dev_model_prob_for_training[_i], 'item') else float(per_dev_model_prob_for_training[_i])
                    all_rl_samples.append(df)
                    all_rl_rewards.extend([reward] * len(df))
                # Concatenate all samples
                rl_df = pd.concat(all_rl_samples, ignore_index=True)
                rl_dataset = Dataset.from_pandas(rl_df.reset_index(drop=True))
                # Add rewards as a new column
                rl_dataset = rl_dataset.add_column("reward", all_rl_rewards)
                all_rl_rewards = torch.tensor(all_rl_rewards, dtype=torch.float16).to(next(self._llm._model.parameters()).device)  # Ensure rewards are on the same device as the model
                print(f"[debugging] rl_dataset: {rl_dataset.shape=}, {rl_dataset.column_names=}, {rl_dataset[0]=}")
                logging_dir = self._exp_folder + f"/slm_syn_train/{iteration}/llm/"
                os.makedirs(logging_dir, exist_ok=True)
                llm_sft_eval_metric, llm_trained_model, _ = weighted_fine_tune(
                    self._llm._model, self._llm._tokenizer, rl_dataset, weight=all_rl_rewards, output_dir=logging_dir,
                    per_device_train_batch_size=8, num_train_epochs=self._t_fine_tune, learning_rate=5e-6,
                    save_steps=100, logging_steps=1,
                )
                # The fine-tuned LLM becomes both the LLM and the selection-model baseline of the next iteration.
                self._llm._model = llm_trained_model
                self._slm._model = llm_trained_model
                if "SelfGen" in self._setting:
                    self._population._api._llm._model = llm_trained_model  # Update the population model with the trained LLM
                elif "LargeGen" in self._setting:
                    if self._llm_additional_generation == True:
                        # Temporarily let the population API generate with the freshly trained LLM. Plain
                        # references are kept (no deep copies of the generator) and the generator is put back
                        # even if generation fails.
                        _original_glm = self._population._api._llm._model
                        _original_glm_tokenizer = self._population._api._llm._tokenizer
                        self._population._api._llm._model = llm_trained_model
                        self._population._api._llm._tokenizer = self._llm._tokenizer
                        try:
                            _syn_data = self._run_generate_from_scratch(
                                num_samples=num_samples_schedule[0],
                                fraction_per_label_id=fraction_per_label_id,
                            )
                            # Offset the iteration id of this side evaluation beyond the range used by the
                            # regular PE iterations.
                            _syn_data.metadata.iteration = iteration + (len(num_samples_schedule)+1)*2
                            self._log_metrics(_syn_data)
                            execution_logger.info(f"LLM generated data within [{iteration=}] with {len(_syn_data.data_frame)} samples [finished].")
                        finally:
                            self._population._api._llm._model = _original_glm
                            self._population._api._llm._tokenizer = _original_glm_tokenizer

                # Evaluate the trained LLM on the private data for token prediction accuracy and loss
                eval_loss, eval_acc = evaluate_model_on_private_data(
                    model=self._llm._model,
                    tokenizer=self._llm._tokenizer,
                    dataset=self._priv_train_data,
                    batch_size=8,
                )
                print(f"Evaluation on train private data - Loss: {eval_loss:.4f}, Token Accuracy: {eval_acc:.4f}")
                execution_logger.info(f"LLM fine-tuned evaluation on train private data - Loss: {eval_loss:.4f}, Token Accuracy: {eval_acc:.4f}")
                eval_loss, eval_acc = evaluate_model_on_private_data(
                    model=self._llm._model,
                    tokenizer=self._llm._tokenizer,
                    dataset=self._priv_eval_data,
                    batch_size=8,
                )
                print(f"Evaluation on test private data - Loss: {eval_loss:.4f}, Token Accuracy: {eval_acc:.4f}")
                execution_logger.info(f"LLM fine-tuned evaluation on test private data - Loss: {eval_loss:.4f}, Token Accuracy: {eval_acc:.4f}")
                # ############# step1, RL on the llm with negative loss as negative reward #############


                # ############# step2, cluster selection for next generation variation #############
                per_dev_model_prob_for_selection = F.softmax(per_dev_model_prob_for_training, dim=0).cpu().numpy() # make everything positive using softmax
                print(f"[debugging] per_dev_model_prob_for_selection: {per_dev_model_prob_for_selection=}, {len(per_dev_model_prob_for_selection)=}")
                execution_logger.info(f"[iteration {iteration}] per_dev_model_prob_for_selection: {per_dev_model_prob_for_selection=}, {len(per_dev_model_prob_for_selection)=}")
                # Choose #self._syn_cluster_num indices randomly based on the probability given above
                selected_cluster_indices = np.random.choice(range(len(per_dev_model_prob_for_selection)), size=self._syn_cluster_num, p=per_dev_model_prob_for_selection, replace=False)
                print(f"[debugging] selected_cluster_indices: {selected_cluster_indices=}, {len(selected_cluster_indices)=}")
                execution_logger.info(f"[iteration {iteration}] selected_cluster_indices: {selected_cluster_indices=}, {len(selected_cluster_indices)=}")
                selected_data_cluster_df_list = [syn_data_df_list[_i] for _i in selected_cluster_indices]
                # ############# step2, cluster selection for next generation variation #############

                # NOTE: the result is currently not used to size the next generation (see num_samples below);
                # the call is kept so that anything the helper logs or validates still happens.
                num_samples_per_label_id = self._get_num_samples_per_label_id(
                    num_samples=num_samples_schedule[iteration],
                    fraction_per_label_id=fraction_per_label_id,
                )
                syn_data_list = []

                # Generate synthetic data for each label. The histogram / self._dp.add_noise step is not invoked
                # here; selection is driven by the noisy cluster scores computed above.
                for label_id in range(len(self._priv_data.metadata.label_info)):
                    execution_logger.info(f"Label {label_id}")
                    for _i, _syn_data_df in enumerate(selected_data_cluster_df_list):
                        execution_logger.info(f"Selected original cluster #{selected_cluster_indices[_i]} for next generation")
                        sub_syn_data = _syn_data_df.filter_label_id(label_id=label_id)
                        print(f"[debugging] checking sub_syn_data len for {label_id=}: {len(sub_syn_data.data_frame)=}")

                        # Generate next population.
                        sub_syn_data = self._population.next(
                            syn_data=sub_syn_data,
                            num_samples=len(sub_syn_data.data_frame), # for each label, next generation will generate self._population._initial_variation_api_fold variantions as duplications
                            selected=True,
                        )
                        sub_syn_data.set_label_id(label_id)
                        syn_data_list.append(sub_syn_data)

                # NOTE(review): when _keep_selected is falsy the selected parent clusters are concatenated
                # explicitly; presumably population.next already returns them otherwise - confirm.
                if not self._population._keep_selected:
                    syn_data = Data.concat(selected_data_cluster_df_list + syn_data_list)
                else:
                    syn_data = Data.concat(syn_data_list)
                syn_data.data_frame.reset_index(drop=True, inplace=True)
                syn_data.metadata.iteration = iteration
                print(f"[debugging] syn_data after concat: {len(syn_data.data_frame)=} with {syn_data.metadata.iteration=}")

                if save_checkpoint:
                    syn_data.save_checkpoint(checkpoint_path)
                self._log_metrics(syn_data)
        finally:
            self._clean_up_loggers()

        return syn_data

    def _run_generate_from_scratch(self, num_samples, fraction_per_label_id):
        """Generate a fresh synthetic population with the population algorithm's ``initial`` method.

        :param num_samples: The total number of samples to distribute over the label ids
        :type num_samples: int
        :param fraction_per_label_id: The fraction of samples for each label id, or None
        :type fraction_per_label_id: list[float], optional
        :return: The concatenated data of all labels with a fresh 0..n-1 index; ``metadata.iteration`` is left
            for the caller to set
        :rtype: :py:class:`pe.data.Data`
        """
        num_samples_per_label_id = self._get_num_samples_per_label_id(
            num_samples=num_samples,
            fraction_per_label_id=fraction_per_label_id,
        )
        per_label_data = []
        for label_id, label_info in enumerate(self._priv_data.metadata.label_info):
            label_data = self._population.initial(
                label_info=label_info,
                num_samples=num_samples_per_label_id[label_id],
            )
            label_data.set_label_id(label_id)
            per_label_data.append(label_data)
        fresh_data = Data.concat(per_label_data, metadata=self._priv_data.metadata)
        fresh_data.data_frame.reset_index(drop=True, inplace=True)
        return fresh_data

    def _run_load_previous_llm(self, load_model_path, target_model):
        """Restore the LLM that a previous iteration saved in ``load_model_path``.

        When the folder holds full weights (``model.safetensors`` without ``adapter_model.safetensors``) they are
        loaded in place into ``target_model``. Otherwise the base model is loaded from
        ``self._original_llm_path`` and the adapter stored in the folder is attached to it.

        :param load_model_path: The folder the previous iteration saved its LLM to
        :type load_model_path: str
        :param target_model: The model that receives full weights. When it is None (e.g. a run resumed from a
            checkpoint has not built a selection model yet) a deep copy of ``self._slm._model`` is used
        :return: The model carrying the restored weights
        """
        saved_files = set(os.listdir(load_model_path))
        if 'model.safetensors' in saved_files and 'adapter_model.safetensors' not in saved_files:
            weights_file = os.path.join(load_model_path, 'model.safetensors')
            print(f"load from {weights_file}")
            if target_model is None:
                target_model = copy.deepcopy(self._slm._model)
            target_model.load_state_dict(load_file(weights_file))
            return target_model
        base = transformers.AutoModelForCausalLM.from_pretrained(
            self._original_llm_path, device_map=None, torch_dtype=torch.float16, trust_remote_code=True,
        )
        # PeftModel.from_pretrained reads the adapter config from the folder itself.
        return PeftModel.from_pretrained(base, load_model_path, is_trainable=True)


    def run_dot_gradient(
        self,
        num_samples_schedule,
        delta,
        epsilon=None,
        noise_multiplier=None,
        checkpoint_path=None,
        save_checkpoint=True,
        fraction_per_label_id=None,
    ):
        """Run the PE algorithm.

        :param num_samples_schedule: The schedule of the number of samples for each PE iteration. The first element is
            the number of samples for the initial data, and the rest are the number of samples for each PE iteration.
            So the length of the list is the number of PE iterations plus one
        :type num_samples_schedule: list[int]
        :param delta: The delta value of DP
        :type delta: float
        :param epsilon: The epsilon value of DP, defaults to None
        :type epsilon: float, optional
        :param noise_multiplier: The noise multiplier of the DP mechanism, defaults to None
        :type noise_multiplier: float, optional
        :param checkpoint_path: The path to load and save the checkpoint, defaults to None
        :type checkpoint_path: str, optional
        :param save_checkpoint: Whether to save the checkpoint, defaults to True
        :type save_checkpoint: bool, optional
        :param fraction_per_label_id: The fraction of samples for each label id. The fraction does not have to be
            normalized. When it is None, the fraction is assumed to be the same as the fraction of label ids in the
            private data. Defaults to None
        :type fraction_per_label_id: list[float], optional
        :return: The synthetic data
        :rtype: :py:class:`pe.data.Data`
        """
        try:
            # Set privacy budget.
            self._dp.set_epsilon_and_delta(
                num_iterations=len(num_samples_schedule) - 1,
                epsilon=epsilon,
                delta=delta,
                noise_multiplier=noise_multiplier,
            )

            # Generate or load initial data.
            if checkpoint_path is not None and (syn_data := self.load_checkpoint(checkpoint_path)):
                execution_logger.info(
                    f"Loaded checkpoint from {checkpoint_path}, iteration={syn_data.metadata.iteration}"
                )
            elif self._init_data_file != '' and (syn_data := self.load_checkpoint(self._init_data_file)):
                execution_logger.info(f"PE initial data loaded from already generated data with {len(syn_data.data_frame)} samples [finished].")
            else:
                num_samples_per_label_id = self._get_num_samples_per_label_id(
                    num_samples=num_samples_schedule[0],
                    fraction_per_label_id=fraction_per_label_id,
                )
                syn_data_list = []
                for label_id, label_info in enumerate(self._priv_data.metadata.label_info):
                    syn_data = self._population.initial(
                        label_info=label_info,
                        num_samples=num_samples_per_label_id[label_id],
                    )
                    syn_data.set_label_id(label_id)
                    syn_data_list.append(syn_data)
                syn_data = Data.concat(syn_data_list, metadata=self._priv_data.metadata)
                syn_data.data_frame.reset_index(drop=True, inplace=True)
                syn_data.metadata.iteration = 0
                self._log_metrics(syn_data)
                execution_logger.info(f"PE initial data generated with {len(syn_data.data_frame)} samples [finished].")
                if self._init_data_file != '':
                    syn_data.save_checkpoint(self._init_data_file)
                    print(f"save PE initial data to {self._init_data_file}")

            # Run PE iterations.
            for iteration in range(syn_data.metadata.iteration + 1, len(num_samples_schedule)):
                execution_logger.info(f"PE iteration {iteration}")

                print(f"[debugging] ghostsuite gradient dot for {syn_data.metadata.iteration} -> {iteration}, {type(syn_data)}")
                
                self.syn_data_preparation(syn_data=syn_data, test_size_ratio=0.0) # with self._syn_train_data_list, self._syn_dev_data set in the function.
                
                logging_dir = self._exp_folder + f"/slm_syn_ghost_dot_grad/{iteration}/slm/"
                if not os.path.exists(logging_dir):
                    os.makedirs(logging_dir)

                if iteration <= 1:
                    # For the first iteration, we use the original selection model
                    slm_for_selection = copy.deepcopy(self._slm._model)
                else:
                    load_model_path = self._exp_folder + f"/slm_syn_train/{iteration-1}/llm/"
                    # For the subsequent iterations, we load the selection model from the previous iteration
                    if 'model.safetensors' in os.listdir(load_model_path) and ('adapter_model.safetensors' not in os.listdir(load_model_path)):
                        print(f"load from {os.path.join(load_model_path, 'model.safetensors')}")
                        state_dict = load_file(os.path.join(load_model_path, 'model.safetensors'))
                        slm_for_selection.load_state_dict(state_dict)
                    else:
                        base = transformers.AutoModelForCausalLM.from_pretrained(
                            self._original_llm_path, device_map=None, torch_dtype=torch.float16, trust_remote_code=True,
                        )
                        peft_config = PeftConfig.from_pretrained(load_model_path)
                        slm_for_selection = PeftModel.from_pretrained(base, load_model_path, is_trainable=True)
                
                # TODO: number of training epochs should be hyper-parameter
                grad_dot_results = ghost_suite_grad_dot(
                    slm_for_selection, self._slm._tokenizer, self._syn_train_data, self._priv_train_data, 
                    output_dir=logging_dir, 
                    per_device_train_batch_size=8, num_train_epochs=1, learning_rate=5e-6,
                    save_steps=100, logging_steps=1,
                )
                # self._log_print_logger.single_log("info", f"Selection model fine-tuned on synthetic dataset split#{_i} in Iteration#{iteration} with evaluation metric: {slm_sft_eval_metric}")

                print(f"[debugging] per_sample_losses_per_model (absolute loss): {len(per_sample_losses_per_model)=}")
                # print(f"[debugging] per_sample_losses_per_model (absolute loss): {per_sample_losses_per_model=}")
                
                per_sample_losses_per_model = [origianl_slm_per_sample_losses - losses for losses in per_sample_losses_per_model]

                print(f"[debugging] per_sample_losses_per_model (loss decrease): {len(per_sample_losses_per_model)=}, {per_sample_losses_per_model[0].shape=}")
                # print(f"[debugging] per_sample_losses_per_model (loss decrease): {per_sample_losses_per_model=}")

                # ############# [discarded] for exponential mechanism #############
                # per_dev_model_losses = [torch.sum(_item) for _item in per_sample_losses_per_model]
                # print(f"[debugging] per_dev_model_losses: {len(per_dev_model_losses)=}")
                # print(f"[debugging] per_dev_model_losses: {per_dev_model_losses=}")
                # per_dev_model_losses = F.softmax(torch.tensor(per_dev_model_losses), dim=0) # control the utility of the voting
                # # per_dev_model_prob = F.softmax(per_dev_model_losses*epsilon/(2*1), dim=0)
                # # print(f"[debugging] per_dev_model_prob: {per_dev_model_prob=}")
                # ############# [discarded] for exponential mechanism #############


                # ############# step1, RL on the llm with negative loss as negative reward #############
                # L2-normalize per_sample_losses_per_model along dimension 0 (i.e., across models for each private sample)
                per_sample_losses_per_model = torch.stack(per_sample_losses_per_model, dim=0) # shape: (num_models, num_private_samples)
                print(f"[debugging] per_sample_losses_per_model (stacked): {per_sample_losses_per_model.shape=}, {per_sample_losses_per_model=}")
                norms = per_sample_losses_per_model.norm(p=2, dim=0, keepdim=True)
                norms = torch.where(norms == 0, torch.ones_like(norms), norms)  # avoid division by zero
                _per_sample_normed_losses_per_model = per_sample_losses_per_model / norms
                print(f"[debugging] per_dev_model_losses after L_2 normalization: {_per_sample_normed_losses_per_model=}")
                # Sum each row of the L2-normalized losses (dim=1), giving one aggregated score per model
                per_dev_model_prob_for_training = _per_sample_normed_losses_per_model.sum(dim=1) # this aligns with per private sample cluster
                print(f"[debugging] per_dev_model_prob_for_training: {per_dev_model_prob_for_training.shape=}, {per_dev_model_prob_for_training=}")
                execution_logger.info(f"[iteration {iteration}] per_dev_model_prob_for_training: {per_dev_model_prob_for_training.shape=}, {per_dev_model_prob_for_training=}")
                # DP noise step: add Gaussian noise with scale self._dp._noise_multiplier to every aggregated score
                per_dev_model_prob_for_training = per_dev_model_prob_for_training + np.random.normal(scale=self._dp._noise_multiplier, size=len(per_dev_model_prob_for_training)) 
                print(f"[debugging] after adding noise, per_dev_model_prob_for_training: {per_dev_model_prob_for_training.shape=}, {per_dev_model_prob_for_training=}")
                execution_logger.info(f"[iteration {iteration}] after adding noise, per_dev_model_prob_for_training: {per_dev_model_prob_for_training.shape=}, {per_dev_model_prob_for_training=}")
                # Normalize per_dev_model_prob_for_training with its mean and standard deviation
                mean = per_dev_model_prob_for_training.mean()
                std = per_dev_model_prob_for_training.std(unbiased=False)
                if std == 0:
                    std = 1.0  # avoid division by zero
                per_dev_model_prob_for_training = (per_dev_model_prob_for_training - mean) / std
                # per_dev_model_prob_for_training = torch.tensor([1.0/len(self._syn_train_data_list) for _ in range(len(self._syn_train_data_list))]).to(next(self._llm._model.parameters()).device)  # uniform distribution for now, to be replaced by the actual probabilities
                
                # TODO: train LLM using the synthetic data and the related probabilities
                # per_dev_model_prob_for_training should be used to RL the llm model
                # Prepare RL training data: combine all synthetic train datasets and assign rewards
                # Each syn_train_data in self._syn_train_data_list corresponds to a cluster, and its reward is per_dev_model_prob_for_training[_i]
                if iteration > 1:
                    load_model_path = self._exp_folder + f"/slm_syn_train/{iteration-1}/llm/"
                    if 'model.safetensors' in os.listdir(load_model_path) and ('adapter_model.safetensors' not in os.listdir(load_model_path)):
                        print(f"load from {os.path.join(load_model_path, 'model.safetensors')}")
                        state_dict = load_file(os.path.join(load_model_path, 'model.safetensors'))
                        self._llm._model.load_state_dict(state_dict)
                    else:
                        base = transformers.AutoModelForCausalLM.from_pretrained(
                            self._original_llm_path, device_map=None, torch_dtype=torch.float16, trust_remote_code=True,
                        )
                        peft_config = PeftConfig.from_pretrained(load_model_path)
                        self._llm._model = PeftModel.from_pretrained(base, load_model_path, is_trainable=True)
                all_rl_samples = []
                all_rl_rewards = []
                for _i, syn_train_data in enumerate(self._syn_train_data_list):
                    # Convert to pandas for easier handling
                    df = syn_train_data.to_pandas()
                    reward = per_dev_model_prob_for_training[_i].item() if hasattr(per_dev_model_prob_for_training[_i], 'item') else float(per_dev_model_prob_for_training[_i])
                    all_rl_samples.append(df)
                    all_rl_rewards.extend([reward] * len(df))
                # Concatenate all samples
                rl_df = pd.concat(all_rl_samples, ignore_index=True)
                # If needed, convert to Dataset
                rl_dataset = Dataset.from_pandas(rl_df.reset_index(drop=True))
                # Add rewards as a new column
                rl_dataset = rl_dataset.add_column("reward", all_rl_rewards)
                all_rl_rewards = torch.tensor(all_rl_rewards, dtype=torch.float16).to(next(self._llm._model.parameters()).device)  # Ensure rewards are on the same device as the model
                # Now train self._llm using RL with (rl_dataset, reward)
                print(f"[debugging] rl_dataset: {rl_dataset.shape=}, {rl_dataset.column_names=}, {rl_dataset[0]=}")
                # rl_config = GRPOConfig(
                #     learning_rate=5e-5,
                #     batch_size=8,
                #     mini_batch_size=8,
                #     log_with=None,
                #     optimize_cuda_cache=True,
                #     project_kwargs={"output_dir": os.path.join(self._exp_folder, f"llm_ppo/{iteration}/")},
                #     epochs=1,
                # )
                # rl_trainer = GRPOTrainer(
                #     config=rl_config,
                #     model=self._llm._model,
                #     tokenizer=self._llm._tokenizer,
                #     train_dataset=rl_dataset,
                #     eval_dataset=self._priv_eval_data,
                #     reward_column="reward",
                # )
                # rl_trainer.train()
                logging_dir = self._exp_folder + f"/slm_syn_train/{iteration}/llm/"
                if not os.path.exists(logging_dir):
                    os.makedirs(logging_dir)
                llm_sft_eval_metric, llm_trained_model, _ = weighted_fine_tune(
                    self._llm._model, self._llm._tokenizer, rl_dataset, weight=all_rl_rewards, output_dir=logging_dir,
                    per_device_train_batch_size=8, num_train_epochs=self._t_fine_tune, learning_rate=5e-6,
                    save_steps=100, logging_steps=1,
                )
                self._llm._model = llm_trained_model
                self._slm._model = llm_trained_model
                if "SelfGen" in self._setting:
                    self._population._api._llm._model = llm_trained_model  # Update the population model with the trained LLM

                # Evaluate the trained LLM on the private data for token prediction accuracy and loss
                eval_loss, eval_acc = evaluate_model_on_private_data(
                    model=self._llm._model,
                    tokenizer=self._llm._tokenizer,
                    dataset=self._priv_train_data,
                    batch_size=8,
                )
                print(f"Evaluation on train private data - Loss: {eval_loss:.4f}, Token Accuracy: {eval_acc:.4f}")
                execution_logger.info(f"LLM fine-tuned evaluation on train private data - Loss: {eval_loss:.4f}, Token Accuracy: {eval_acc:.4f}")
                eval_loss, eval_acc = evaluate_model_on_private_data(
                    model=self._llm._model,
                    tokenizer=self._llm._tokenizer,
                    dataset=self._priv_eval_data,
                    batch_size=8,
                )
                print(f"Evaluation on test private data - Loss: {eval_loss:.4f}, Token Accuracy: {eval_acc:.4f}")
                execution_logger.info(f"LLM fine-tuned evaluation on test private data - Loss: {eval_loss:.4f}, Token Accuracy: {eval_acc:.4f}")
                # ############# step1, RL on the llm with negative loss as negative reward #############


                # ############# step2, cluster selection for next generation variation by treating negative as zero #############
                per_dev_model_prob_for_selection = F.softmax(per_dev_model_prob_for_training, dim=0).cpu().numpy() # make everything positive using softmax
                print(f"[debugging] per_dev_model_prob_for_selection: {per_dev_model_prob_for_selection=}, {len(per_dev_model_prob_for_selection)=}")
                execution_logger.info(f"[iteration {iteration}] per_dev_model_prob_for_selection: {per_dev_model_prob_for_selection=}, {len(per_dev_model_prob_for_selection)=}")
                # Choose #self._syn_cluster_num indices randomly based on the probability given above
                selected_cluster_indices = np.random.choice(range(len(per_dev_model_prob_for_selection)), size=self._syn_cluster_num, p=per_dev_model_prob_for_selection, replace=False)
                print(f"[debugging] selected_cluster_indices: {selected_cluster_indices=}, {len(selected_cluster_indices)=}")
                execution_logger.info(f"[iteration {iteration}] selected_cluster_indices: {selected_cluster_indices=}, {len(selected_cluster_indices)=}")
                selected_data_cluster_df_list = [syn_data_df_list[_i] for _i in selected_cluster_indices]
                # ############# step2, cluster selection for next generation variation by treating negative as zero #############

                num_samples_per_label_id = self._get_num_samples_per_label_id(
                    num_samples=num_samples_schedule[iteration],
                    fraction_per_label_id=fraction_per_label_id,
                )
                syn_data_list = []
                # priv_data_list = []

                # Generate synthetic data for each label.
                for label_id in range(len(self._priv_data.metadata.label_info)):
                    execution_logger.info(f"Label {label_id}")
                    for _i, _syn_data_df in enumerate(selected_data_cluster_df_list):
                        execution_logger.info(f"Selected original cluster #{selected_cluster_indices[_i]} for next generation")
                        sub_priv_data = self._priv_data.filter_label_id(label_id=label_id)
                        sub_syn_data = _syn_data_df.filter_label_id(label_id=label_id)
                        print(f"[debugging] checking sub_syn_data len for {label_id=}: {len(sub_syn_data.data_frame)=}")

                        # # DP NN histogram.
                        # sub_priv_data, sub_syn_data = self._histogram.compute_histogram(
                        #     priv_data=sub_priv_data, syn_data=sub_syn_data
                        # )
                        # priv_data_list.append(sub_priv_data)
                        # sub_syn_data = self._dp.add_noise(syn_data=sub_syn_data)

                        # Generate next population.
                        # sub_syn_data = self._population.next(
                        #     syn_data=sub_syn_data,
                        #     num_samples=int(num_samples_per_label_id[label_id]*per_dev_model_prob[_i].item()),
                        # )
                        sub_syn_data = self._population.next(
                            syn_data=sub_syn_data,
                            num_samples=len(sub_syn_data.data_frame), # for each label, next generation will generate self._population._initial_variation_api_fold variantions as duplications
                            selected=True,
                        )
                        sub_syn_data.set_label_id(label_id)
                        syn_data_list.append(sub_syn_data)

                # syn_data = Data.concat([syn_data] + syn_data_list)
                # syn_data = self.syn_data_concat_and_shuffle(
                #     selected_data_cluster_df_list=selected_data_cluster_df_list,
                #     syn_data_list=syn_data_list,
                #     shuffle=True,
                #     shuffle_all=True,
                # )
                if not self._population._keep_selected:
                    syn_data = Data.concat(selected_data_cluster_df_list + syn_data_list)
                else:
                    syn_data = Data.concat(syn_data_list)
                syn_data.data_frame.reset_index(drop=True, inplace=True)
                syn_data.metadata.iteration = iteration
                print(f"[debugging] syn_data after concat: {len(syn_data.data_frame)=} with {syn_data.metadata.iteration=}")

                # new_priv_data = Data.concat(priv_data_list)
                # self._priv_data = self._priv_data.merge(new_priv_data)

                if save_checkpoint:
                    syn_data.save_checkpoint(checkpoint_path)
                self._log_metrics(syn_data)
        finally:
            self._clean_up_loggers()

        return syn_data

    def run_priv_sgd(
        self,
        num_samples_schedule,
        delta,
        epsilon=None,
        noise_multiplier=None,
        checkpoint_path=None,
        save_checkpoint=True,
        fraction_per_label_id=None,
        val_sample_ratio=0.2,
        lr=5e-5,
        _scaler=1.0,
    ):
        """Run the PE algorithm.

        :param num_samples_schedule: The schedule of the number of samples for each PE iteration. The first element is
            the number of samples for the initial data, and the rest are the number of samples for each PE iteration.
            So the length of the list is the number of PE iterations plus one
        :type num_samples_schedule: list[int]
        :param delta: The delta value of DP
        :type delta: float
        :param epsilon: The epsilon value of DP, defaults to None
        :type epsilon: float, optional
        :param noise_multiplier: The noise multiplier of the DP mechanism, defaults to None
        :type noise_multiplier: float, optional
        :param checkpoint_path: The path to load and save the checkpoint, defaults to None
        :type checkpoint_path: str, optional
        :param save_checkpoint: Whether to save the checkpoint, defaults to True
        :type save_checkpoint: bool, optional
        :param fraction_per_label_id: The fraction of samples for each label id. The fraction does not have to be
            normalized. When it is None, the fraction is assumed to be the same as the fraction of label ids in the
            private data. Defaults to None
        :type fraction_per_label_id: list[float], optional
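        :param val_sample_ratio: The sub-sampling ratio of private samples within each iteration, defaults to 0.2 for this method
        :type val_sample_ratio: float
        :param lr: The base learning rate, defaults to 5e-5. The learning rate passed to the fine-tuning calls is
            ``lr * ((1/4) * len(self._priv_train_data) * val_sample_ratio / 2) * _scaler``
        :type lr: float, optional
        :param _scaler: An extra multiplicative factor applied to the learning rate, defaults to 1.0
        :type _scaler: float, optional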
        :return: The synthetic data
        :rtype: :py:class:`pe.data.Data`
        """

        # TODO: enable or disable base llm evaluation if needed
        """ base llm evaluation """
        eval_loss, eval_acc = evaluate_model_on_private_data(
            model=self._llm._model,
            tokenizer=self._llm._tokenizer,
            # dataset=self._priv_dev_data,
            dataset=self._priv_eval_data,
            batch_size=8,
            add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
        )
        print(f"Base model, evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
        execution_logger.info(f"LLM base model, evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
        """ base llm evaluation """
        
        """ base llm generation evaluation """
        if "LargeGen" in self._setting:
            if self._llm_additional_generation == True:
                _model, _offload_hook = accelerate.cpu_offload_with_hook(self._population._api._llm._model, execution_device="cuda")
                _offload_hook.offload()
                _original_glm = copy.deepcopy(self._population._api._llm)

                self._population._api._llm = copy.deepcopy(self._llm)  # Update the population model with the trained LLM     
                _num_samples_per_label_id = self._get_num_samples_per_label_id(
                    num_samples=num_samples_schedule[0],
                    fraction_per_label_id=fraction_per_label_id,
                )
                _syn_data_list = []
                for _label_id, _label_info in enumerate(self._priv_data.metadata.label_info):
                    _syn_data = self._population.initial(
                        label_info=_label_info,
                        num_samples=_num_samples_per_label_id[_label_id],
                    )
                    _syn_data.set_label_id(_label_id)
                    _syn_data_list.append(_syn_data)
                _syn_data = Data.concat(_syn_data_list, metadata=self._priv_data.metadata)
                _syn_data.data_frame.reset_index(drop=True, inplace=True)
                _syn_data.metadata.iteration = 0 + (len(num_samples_schedule)+1)*2
                self._log_metrics(_syn_data)
                execution_logger.info(f"LLM generated data using base model with {len(_syn_data.data_frame)} samples [finished].")
                
                _model, _offload_hook = accelerate.cpu_offload_with_hook(self._population._api._llm._model, execution_device="cuda")
                _offload_hook.offload()
                self._population._api._llm = copy.deepcopy(_original_glm)

        """ base llm generation evaluation """
        torch.cuda.empty_cache()
        gc.collect()


        """
        SFT the llm using private samples without dp, the upper bound and also with dp
        """

        logging_dir = self._exp_folder + f"/llm_priv_train/no_syn/"
        if not os.path.exists(logging_dir):
            os.makedirs(logging_dir)
        if epsilon > 1E7:
            # epsilon this large approximates "no privacy", so plain (non-DP) fine-tuning is used
            llm_sft_eval_metric, llm_trained_model, _ = sft_fine_tune(
                copy.deepcopy(self._llm._model), self._llm._tokenizer, self._priv_train_data, output_dir=logging_dir,
                per_device_train_batch_size=4, gradient_accumulation_steps=int((1/4)*len(self._priv_train_data)*val_sample_ratio),
                num_train_epochs=self._t_fine_tune, learning_rate=lr*((1/4)*len(self._priv_train_data)*val_sample_ratio/2)*_scaler, # 4000 sample -- 5e-5, 400 sample 5e-6
                save_steps=(100 if (not ("LargeGen" in self._setting and self._llm_additional_generation)) else 1), logging_steps=5,
                eval_dataset=self._priv_eval_data,
                add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
            )
        else:
            llm_sft_eval_metric, llm_trained_model, _ = opacus_dpsgd_fine_tune(
                copy.deepcopy(self._llm._model), self._llm._tokenizer, self._priv_train_data, output_dir=logging_dir,
                per_device_train_batch_size=4, 
                num_train_epochs=len(num_samples_schedule)-1, learning_rate=lr*((1/4)*len(self._priv_train_data)*val_sample_ratio/2)*_scaler, # 4000 sample -- 5e-5, 400 sample 5e-6
                save_steps=(100 if (not ("LargeGen" in self._setting and self._llm_additional_generation)) else 1), logging_steps=5,
                eval_dataset=self._priv_eval_data,
                epsilon=epsilon, delta=delta, val_sample_ratio=val_sample_ratio,
                add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
            )
        eval_loss, eval_acc = evaluate_model_on_private_data(
            model=llm_trained_model,
            tokenizer=self._llm._tokenizer,
            # dataset=self._priv_dev_data,
            dataset=self._priv_eval_data,
            batch_size=8,
            add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
        )
        execution_logger.info(f"parameter-tuning, learing_rate_scale={_scaler}")
        print(f"DP-SGD fine-tuned model, evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
        execution_logger.info(f"LLM DP-SGD fine-tuned on base model, evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")

        print(f"[debugging] {type(self._llm._model)=}, {type(llm_trained_model)=}")

        if "LargeGen" in self._setting:
            if self._llm_additional_generation == True:
                _original_glm = copy.deepcopy(self._population._api._llm)
                self._population._api._llm = copy.deepcopy(self._llm)  # Update the population model with the trained LLM     
                folder_names = [name for name in os.listdir(logging_dir) if os.path.isdir(os.path.join(logging_dir, name))]
                for folder in folder_names:
                    load_model_path = os.path.join(logging_dir, folder)
                    if 'model.safetensors' in os.listdir(load_model_path) and ('adapter_model.safetensors' not in os.listdir(load_model_path)):
                        print(f"load from {os.path.join(load_model_path, 'model.safetensors')}")
                        state_dict = load_file(os.path.join(load_model_path, 'model.safetensors'))
                        llm_trained_model.load_state_dict(state_dict)
                    else:
                        base = transformers.AutoModelForCausalLM.from_pretrained(
                            self._original_llm_path, device_map=None, torch_dtype=torch.float16, trust_remote_code=True,
                        )
                        peft_config = PeftConfig.from_pretrained(load_model_path)
                        llm_trained_model = PeftModel.from_pretrained(base, load_model_path, is_trainable=True)
                    
                    # Evaluate the trained LLM on the private data for token prediction accuracy and loss
                    eval_loss, eval_acc = evaluate_model_on_private_data(
                        model=llm_trained_model,
                        tokenizer=self._llm._tokenizer,
                        dataset=self._priv_eval_data,
                        batch_size=8,
                        add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
                    )
                    print(f"Train on private data with {folder}, evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
                    execution_logger.info(f"LLM fine-tuned with {folder} on private data and evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")

                    # Trained LLM generate synthetic data to see synthetic data format accuracy
                    self._population._api._llm._model = copy.deepcopy(llm_trained_model) # use the trained model for generation
                    _num_samples_per_label_id = self._get_num_samples_per_label_id(
                        num_samples=num_samples_schedule[0],
                        fraction_per_label_id=fraction_per_label_id,
                    )
                    _syn_data_list = []
                    for _label_id, _label_info in enumerate(self._priv_data.metadata.label_info):
                        _syn_data = self._population.initial(
                            label_info=_label_info,
                            num_samples=_num_samples_per_label_id[_label_id],
                        )
                        _syn_data.set_label_id(_label_id)
                        _syn_data_list.append(_syn_data)
                    _syn_data = Data.concat(_syn_data_list, metadata=self._priv_data.metadata)
                    _syn_data.data_frame.reset_index(drop=True, inplace=True)
                    _syn_data.metadata.iteration = int(folder.split('-')[-1]) + 1000*2
                    self._log_metrics(_syn_data)
                    execution_logger.info(f"LLM generated data after private fine-tuning with {len(_syn_data.data_frame)} samples [finished].")
                self._population._api._llm = copy.deepcopy(_original_glm)
        """
        SFT the llm using private samples without dp, the upper bound
        """

        try:
            # Set privacy budget.
            self._dp.set_epsilon_and_delta(
                num_iterations=self._t_fine_tune,
                epsilon=epsilon,
                delta=delta,
                noise_multiplier=noise_multiplier,
                sampling_ratio=val_sample_ratio,
            )

            # Generate or load initial data.
            if checkpoint_path is not None and (syn_data := self.load_checkpoint(checkpoint_path)):
                execution_logger.info(
                    f"Loaded checkpoint from {checkpoint_path}, iteration={syn_data.metadata.iteration}"
                )
            elif self._init_data_file != '' and (syn_data := self.load_checkpoint(self._init_data_file)):
                execution_logger.info(f"PE initial data loaded from already generated data with {len(syn_data.data_frame)} samples [finished].")
            else:
                num_samples_per_label_id = self._get_num_samples_per_label_id(
                    num_samples=num_samples_schedule[0],
                    fraction_per_label_id=fraction_per_label_id,
                )
                print(f"[debugging] before variation api for iteration 0, check memory: {torch.cuda.memory_allocated(0) / 1024 ** 3:.2f} GB, {torch.cuda.memory_reserved(0) / 1024 ** 3:.2f} GB")
                syn_data_list = []
                for label_id, label_info in enumerate(self._priv_data.metadata.label_info):
                    syn_data = self._population.initial(
                        label_info=label_info,
                        num_samples=num_samples_per_label_id[label_id],
                    )
                    syn_data.set_label_id(label_id)
                    syn_data_list.append(syn_data)
                syn_data = Data.concat(syn_data_list, metadata=self._priv_data.metadata)
                syn_data.data_frame.reset_index(drop=True, inplace=True)
                syn_data.metadata.iteration = 0
                self._log_metrics(syn_data)
                execution_logger.info(f"PE initial data generated with {len(syn_data.data_frame)} samples [finished].")
                if self._init_data_file != '':
                    syn_data.save_checkpoint(self._init_data_file)
                    print(f"save PE initial data to {self._init_data_file}")

            # For the first iteration, we instruction-fine-tune the LLM using the synthetic data
            logging_dir = self._exp_folder + f"/llm_syn_train/"
            if not os.path.exists(logging_dir):
                os.makedirs(logging_dir)

            self.syn_data_preparation(syn_data=syn_data, test_size_ratio=0.0) # with self._syn_train_data_list, self._syn_dev_data set in the function.

            llm_sft_eval_metric, llm_trained_model, _ = sft_fine_tune(
                self._llm._model, self._llm._tokenizer, self._syn_train_data, output_dir=logging_dir,
                per_device_train_batch_size=4, gradient_accumulation_steps=int((1/4)*len(self._priv_train_data)*val_sample_ratio), # 4000 sample: 5e-5, 400 sample: 5e-6
                num_train_epochs=self._t_fine_tune, learning_rate=lr*((1/4)*len(self._priv_train_data)*val_sample_ratio/2)*_scaler, # 4000 sample -- 5e-5, 400 sample 5e-6
                save_steps=100, logging_steps=5,
                eval_dataset=self._priv_eval_data,
                add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
            )
            # no need to assign the trained model to the generation model
            # if "SelfGen" in self._setting:
            #     self._population._api._llm._model = llm_trained_model  # Update the population model with the trained LLM
            self._llm._model = llm_trained_model
            torch.cuda.empty_cache()
            gc.collect()
            print(f"[debugging] after gradient deletion in iteration=0, check memory: {torch.cuda.memory_allocated(0) / 1024 ** 3:.2f} GB, {torch.cuda.memory_reserved(0) / 1024 ** 3:.2f} GB")

            # Evaluate the trained LLM on the private data for token prediction accuracy and loss
            eval_loss, eval_acc = evaluate_model_on_private_data(
                model=self._llm._model,
                tokenizer=self._llm._tokenizer,
                dataset=self._priv_train_data,
                batch_size=8,
                add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
            )
            print(f"[Iteration 0] Instruction fine-tune, evaluation on train private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
            execution_logger.info(f"[Iteration 0] Instruction fine-tune, LLM fine-tuned evaluation on train private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
            eval_loss, eval_acc = evaluate_model_on_private_data(
                model=self._llm._model,
                tokenizer=self._llm._tokenizer,
                dataset=self._priv_eval_data,
                batch_size=8,
                add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
            )
            print(f"[Iteration 0] Instruction fine-tune, evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
            execution_logger.info(f"[Iteration 0] Instruction fine-tune, LLM fine-tuned evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")

            if "LargeGen" in self._setting:
                if self._llm_additional_generation == True:
                    _original_glm = copy.deepcopy(self._population._api._llm)
                    self._population._api._llm = copy.deepcopy(self._llm)  # Update the population model with the trained LLM     
                    # self._population._api._llm._model = copy.deepcopy(llm_trained_model) # use the trained model for generation
                    _num_samples_per_label_id = self._get_num_samples_per_label_id(
                        num_samples=num_samples_schedule[0],
                        fraction_per_label_id=fraction_per_label_id,
                    )
                    _syn_data_list = []
                    for _label_id, _label_info in enumerate(self._priv_data.metadata.label_info):
                        _syn_data = self._population.initial(
                            label_info=_label_info,
                            num_samples=_num_samples_per_label_id[_label_id],
                        )
                        _syn_data.set_label_id(_label_id)
                        _syn_data_list.append(_syn_data)
                    _syn_data = Data.concat(_syn_data_list, metadata=self._priv_data.metadata)
                    _syn_data.data_frame.reset_index(drop=True, inplace=True)
                    _syn_data.metadata.iteration = 0 + 1000*4
                    self._log_metrics(_syn_data)
                    execution_logger.info(f"LLM generated data after synthetic data instruction fine-tuning with {len(_syn_data.data_frame)} samples [finished].")
                    self._population._api._llm = copy.deepcopy(_original_glm)

            """
            SFT the llm using private samples without dp, the upper bound
            """
            logging_dir = self._exp_folder + f"/llm_priv_train/base_on_syn/"
            if not os.path.exists(logging_dir):
                os.makedirs(logging_dir)
            
            if epsilon > 1E7:
                # epsilon this large approximates "no privacy", so plain (non-DP) fine-tuning is used
                llm_sft_eval_metric, llm_trained_model, _ = sft_fine_tune(
                    copy.deepcopy(self._llm._model), self._llm._tokenizer, self._syn_train_data, output_dir=logging_dir,
                    per_device_train_batch_size=4, gradient_accumulation_steps=int((1/4)*len(self._priv_train_data)*val_sample_ratio), # 4000 sample: 5e-5, 400 sample: 5e-6
                    num_train_epochs=len(num_samples_schedule)-1, learning_rate=lr*((1/4)*len(self._priv_train_data)*val_sample_ratio/2)*_scaler, # 4000 sample -- 5e-5, 400 sample 5e-6
                    save_steps=(100 if (not ("LargeGen" in self._setting and self._llm_additional_generation)) else 1), logging_steps=5,
                    eval_dataset=self._priv_eval_data,
                    add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
                )
                # eval_loss, eval_acc = evaluate_model_on_private_data(
                #     model=llm_trained_model,
                #     tokenizer=self._llm._tokenizer,
                #     # dataset=self._priv_dev_data,
                #     dataset=self._priv_eval_data,
                #     batch_size=8,
                #     add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
                # )
                # print(f"DP-SGD fine-tuned model, evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
                # execution_logger.info(f"LLM DP-SGD fine-tuned on instruction-tuned model, evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
            else:
                print(f"training_epoch={len(num_samples_schedule)-1=}")
                llm_sft_eval_metric, llm_trained_model, _ = opacus_dpsgd_fine_tune(
                    copy.deepcopy(self._llm._model), self._llm._tokenizer, self._priv_train_data, output_dir=logging_dir,
                    per_device_train_batch_size=4, 
                    num_train_epochs=len(num_samples_schedule)-1, learning_rate=lr*((1/4)*len(self._priv_train_data)*val_sample_ratio/2)*_scaler, # 4000 sample -- 5e-5, 400 sample 5e-6
                    save_steps=(100 if (not ("LargeGen" in self._setting and self._llm_additional_generation)) else 1), logging_steps=5,
                    eval_dataset=self._priv_eval_data,
                    epsilon=epsilon, delta=delta, val_sample_ratio=val_sample_ratio,
                    add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
                )
                # # assert 1 == 0
            
            eval_loss, eval_acc = evaluate_model_on_private_data(
                model=llm_trained_model,
                tokenizer=self._llm._tokenizer,
                # dataset=self._priv_dev_data,
                dataset=self._priv_eval_data,
                batch_size=8,
                add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
            )
            print(f"DP-SGD fine-tuned model, evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
            execution_logger.info(f"LLM DP-SGD fine-tuned on instruction-tuned model, evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
            
            if "LargeGen" in self._setting:
                if self._llm_additional_generation == True:
                    _original_glm = copy.deepcopy(self._population._api._llm)
                    self._population._api._llm = copy.deepcopy(self._llm)  # Update the population model with the trained LLM     
                    folder_names = [name for name in os.listdir(logging_dir) if os.path.isdir(os.path.join(logging_dir, name))]
                    for folder in folder_names:
                        load_model_path = os.path.join(logging_dir, folder)
                        if 'model.safetensors' in os.listdir(load_model_path) and ('adapter_model.safetensors' not in os.listdir(load_model_path)):
                            print(f"load from {os.path.join(load_model_path, 'model.safetensors')}")
                            state_dict = load_file(os.path.join(load_model_path, 'model.safetensors'))
                            llm_trained_model.load_state_dict(state_dict)
                        else:
                            base = transformers.AutoModelForCausalLM.from_pretrained(
                                self._original_llm_path, device_map=None, torch_dtype=torch.float16, trust_remote_code=True,
                            )
                            peft_config = PeftConfig.from_pretrained(load_model_path)
                            llm_trained_model = PeftModel.from_pretrained(base, load_model_path, is_trainable=True)
                        
                        # Evaluate the trained LLM on the private data for token prediction accuracy and loss
                        eval_loss, eval_acc = evaluate_model_on_private_data(
                            model=llm_trained_model,
                            tokenizer=self._llm._tokenizer,
                            dataset=self._priv_eval_data,
                            batch_size=8,
                            add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
                        )
                        print(f"Train on private data with {folder} based on , evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
                        execution_logger.info(f"LLM fine-tuned with {folder} on private data and evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")

                        # Trained LLM generate synthetic data to see synthetic data format accuracy
                        self._population._api._llm._model = copy.deepcopy(llm_trained_model) # use the trained model for generation
                        _num_samples_per_label_id = self._get_num_samples_per_label_id(
                            num_samples=num_samples_schedule[0],
                            fraction_per_label_id=fraction_per_label_id,
                        )
                        _syn_data_list = []
                        for _label_id, _label_info in enumerate(self._priv_data.metadata.label_info):
                            _syn_data = self._population.initial(
                                label_info=_label_info,
                                num_samples=_num_samples_per_label_id[_label_id],
                            )
                            _syn_data.set_label_id(_label_id)
                            _syn_data_list.append(_syn_data)
                        _syn_data = Data.concat(_syn_data_list, metadata=self._priv_data.metadata)
                        _syn_data.data_frame.reset_index(drop=True, inplace=True)
                        _syn_data.metadata.iteration = int(folder.split('-')[-1]) + 1000*4
                        self._log_metrics(_syn_data)
                        execution_logger.info(f"LLM generated data after private fine-tuning on instruction-tuned model with {len(_syn_data.data_frame)} samples [finished].")
                    self._population._api._llm = copy.deepcopy(_original_glm)

            """
            SFT the llm using private samples without dp, the upper bound
            """




        finally:
            self._clean_up_loggers()
        return syn_data

    def run_optimal_gradient_combination(
        self,
        num_samples_schedule,
        delta,
        epsilon=None,
        noise_multiplier=None,
        checkpoint_path=None,
        save_checkpoint=True,
        fraction_per_label_id=None,
        val_sample_ratio=0.2,
        lr=5e-5,
        metric_inverse_epsilon=1E-6,
        noise_place='coefficient',
        prompt_type='RandomK',
        _scaler=0.5,
        clip_or_normalize='normalize',
        noise_on_vote=True,
        sample_evolve=True,
        approx_strategy='opt',
        with_instruction_base=0,
        use_eigen=False,
        clip_norm=1.0,
    ):
        """Run the PE algorithm.

        :param num_samples_schedule: The schedule of the number of samples for each PE iteration. The first element is
            the number of samples for the initial data, and the rest are the number of samples for each PE iteration.
            So the length of the list is the number of PE iterations plus one
        :type num_samples_schedule: list[int]
        :param delta: The delta value of DP
        :type delta: float
        :param epsilon: The epsilon value of DP, defaults to None
        :type epsilon: float, optional
        :param noise_multiplier: The noise multiplier of the DP mechanism, defaults to None
        :type noise_multiplier: float, optional
        :param checkpoint_path: The path to load and save the checkpoint, defaults to None
        :type checkpoint_path: str, optional
        :param save_checkpoint: Whether to save the checkpoint, defaults to True
        :type save_checkpoint: bool, optional
        :param fraction_per_label_id: The fraction of samples for each label id. The fraction does not have to be
            normalized. When it is None, the fraction is assumed to be the same as the fraction of label ids in the
            private data. Defaults to None
        :type fraction_per_label_id: list[float], optional
        :param val_sample_ratio: The sub-sampling ratio of private samples within each iteration, defaults to 0.2 for this method
        :type val_sample_ratio: float
        :return: The synthetic data
        :rtype: :py:class:`pe.data.Data`
        """

        try:
            # Set privacy budget.
            self._dp.set_epsilon_and_delta(
                num_iterations=len(num_samples_schedule) - 1,
                epsilon=epsilon,
                delta=delta,
                noise_multiplier=noise_multiplier,
                sampling_ratio=val_sample_ratio,
            )

            # Generate or load initial data.
            if checkpoint_path is not None and (syn_data := self.load_checkpoint(checkpoint_path)):
                execution_logger.info(
                    f"Loaded checkpoint from {checkpoint_path}, iteration={syn_data.metadata.iteration}"
                )
            elif self._init_data_file != '' and (syn_data := self.load_checkpoint(self._init_data_file)):
                execution_logger.info(f"PE initial data loaded from already generated data with {len(syn_data.data_frame)} samples [finished].")
            else:
                num_samples_per_label_id = self._get_num_samples_per_label_id(
                    num_samples=num_samples_schedule[0],
                    fraction_per_label_id=fraction_per_label_id,
                )
                print(f"[debugging] before variation api for iteration 0, check memory: {torch.cuda.memory_allocated(0) / 1024 ** 3:.2f} GB, {torch.cuda.memory_reserved(0) / 1024 ** 3:.2f} GB")
                syn_data_list = []
                for label_id, label_info in enumerate(self._priv_data.metadata.label_info):
                    syn_data = self._population.initial(
                        label_info=label_info,
                        num_samples=num_samples_per_label_id[label_id],
                    )
                    syn_data.set_label_id(label_id)
                    syn_data_list.append(syn_data)
                syn_data = Data.concat(syn_data_list, metadata=self._priv_data.metadata)
                syn_data.data_frame.reset_index(drop=True, inplace=True)
                syn_data.metadata.iteration = 0
                self._log_metrics(syn_data)
                execution_logger.info(f"PE initial data generated with {len(syn_data.data_frame)} samples [finished].")
                if self._init_data_file != '':
                    syn_data.save_checkpoint(self._init_data_file)
                    print(f"save PE initial data to {self._init_data_file}")
        
            optimizer_state_dict = None

            # if "LargeGen" in self._setting:
            #     if self._llm_additional_generation == True:
            #         _original_glm = copy.deepcopy(self._population._api._llm)
            #         self._population._api._llm = copy.deepcopy(self._llm)  # Update the population model with the trained LLM     
            #         _num_samples_per_label_id = self._get_num_samples_per_label_id(
            #             num_samples=num_samples_schedule[0],
            #             fraction_per_label_id=fraction_per_label_id,
            #         )
            #         _syn_data_list = []
            #         for _label_id, _label_info in enumerate(self._priv_data.metadata.label_info):
            #             _syn_data = self._population.initial(
            #                 label_info=_label_info,
            #                 num_samples=_num_samples_per_label_id[_label_id],
            #             )
            #             _syn_data.set_label_id(_label_id)
            #             _syn_data_list.append(_syn_data)
            #         _syn_data = Data.concat(_syn_data_list, metadata=self._priv_data.metadata)
            #         _syn_data.data_frame.reset_index(drop=True, inplace=True)
            #         _syn_data.metadata.iteration = 0 + (len(num_samples_schedule)+1)*2
            #         self._log_metrics(_syn_data)
            #         execution_logger.info(f"LLM generated data within [iteration=0] with {len(_syn_data.data_frame)} samples [finished].")
            #         # if self._init_data_file != '':
            #         #     _syn_data.save_checkpoint(self._init_data_file)
            #         #     print(f"save PE initial data to {self._init_data_file}")
            #         # self._population._api._llm._model = copy.deepcopy(_original_glm)
            #         # self._population._api._llm._tokenizer = copy.deepcopy(_original_glm_tokenizer)
            #         self._population._api._llm = copy.deepcopy(_original_glm)

            # ###### Run PE iterations. ######
            for iteration in range(syn_data.metadata.iteration + 1, len(num_samples_schedule)+1):
                execution_logger.info(f"PE iteration {iteration}")

                print(f"[debugging] ghostsuite gradient dot for {syn_data.metadata.iteration} -> {iteration}, {type(syn_data)}")
                
                self.syn_data_preparation(syn_data=syn_data, test_size_ratio=0.0) # with self._syn_train_data_list, self._syn_dev_data set in the function.

                # TODO: check whether this optional instruction fine-tuning step further helps.
                # NOTE (from debugging): performance was also good without it.
                if iteration == 1 and with_instruction_base != 0:
                    # For the first iteration, we instruction-fine-tune the LLM using the synthetic data
                    logging_dir = self._exp_folder + f"/llm_syn_opt_grad_comb/0/llm/"
                    if not os.path.exists(logging_dir):
                        os.makedirs(logging_dir)

                    execution_logger.info(f"parameter-tuning, learing_rate_scale={_scaler} for instruction fine-tunning")
                    llm_sft_eval_metric, llm_trained_model, _ = sft_fine_tune(
                        self._llm._model, self._llm._tokenizer, self._syn_train_data, output_dir=logging_dir,
                        per_device_train_batch_size=4, gradient_accumulation_steps=int((1/4)*len(self._priv_train_data)*val_sample_ratio), # 4000 sample: 5e-5, 400 sample: 5e-6
                        num_train_epochs=self._t_fine_tune, learning_rate=lr*((1/4)*len(self._priv_train_data)*val_sample_ratio)*_scaler, # 4000 sample -- 5e-5, 400 sample 5e-6
                        save_steps=100, logging_steps=5,
                        eval_dataset=self._priv_eval_data,
                        add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
                    )
                    # no need to assign the trained model to the generation model
                    # if "SelfGen" in self._setting:
                    #     self._population._api._llm._model = llm_trained_model  # Update the population model with the trained LLM
                    self._llm._model = llm_trained_model
                    torch.cuda.empty_cache()
                    gc.collect()
                    print(f"[debugging] after gradient deletion in iteration=0, check memory: {torch.cuda.memory_allocated(0) / 1024 ** 3:.2f} GB, {torch.cuda.memory_reserved(0) / 1024 ** 3:.2f} GB")

                    # Evaluate the trained LLM on the private data for token prediction accuracy and loss
                    eval_loss, eval_acc = evaluate_model_on_private_data(
                        model=self._llm._model,
                        tokenizer=self._llm._tokenizer,
                        dataset=self._priv_train_data,
                        batch_size=8,
                        add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
                    )
                    print(f"[Iteration 0] Instruction fine-tune, evaluation on train private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
                    execution_logger.info(f"[Iteration 0] Instruction fine-tune, LLM fine-tuned evaluation on train private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
                    eval_loss, eval_acc = evaluate_model_on_private_data(
                        model=self._llm._model,
                        tokenizer=self._llm._tokenizer,
                        dataset=self._priv_eval_data,
                        batch_size=8,
                        add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
                    )
                    print(f"[Iteration 0] Instruction fine-tune, evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
                    execution_logger.info(f"[Iteration 0] Instruction fine-tune, LLM fine-tuned evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")

                logging_dir = self._exp_folder + f"/llm_syn_opt_grad_comb/{iteration}/llm/"
                if not os.path.exists(logging_dir):
                    os.makedirs(logging_dir)

                if iteration > 1: # when iteration==1, the self._llm._model is already at its starting point
                    load_model_path = self._exp_folder + f"/llm_syn_opt_grad_comb/{iteration-1}/llm/"
                    if 'model.safetensors' in os.listdir(load_model_path) and ('adapter_model.safetensors' not in os.listdir(load_model_path)):
                        print(f"load from {os.path.join(load_model_path, 'model.safetensors')}")
                        state_dict = load_file(os.path.join(load_model_path, 'model.safetensors'))
                        self._llm._model.load_state_dict(state_dict)
                    else:
                        base = transformers.AutoModelForCausalLM.from_pretrained(
                            self._original_llm_path, device_map=None, torch_dtype=torch.float16, trust_remote_code=True,
                        )
                        peft_config = PeftConfig.from_pretrained(load_model_path)
                        self._llm._model = PeftModel.from_pretrained(base, load_model_path, is_trainable=True)

                print(f"[debugging] before gradient calculate and updation, check memory: {torch.cuda.memory_allocated(0) / 1024 ** 3:.2f} GB, {torch.cuda.memory_reserved(0) / 1024 ** 3:.2f} GB")

                if noise_place=='coefficient':
                    llm_trained_model, train_sample_dp_score, train_sample_grad, val_sample_grad, optimizer_state_dict = get_sample_grad( # train_sample_grad.shape=[#train_sample, #param], val_sample_grad=[#real_sample, #param]
                        self._llm._model, self._llm._tokenizer, optimizer_state_dict, self._syn_train_data, self._priv_train_data, 
                        output_dir=logging_dir, val_sample_ratio=val_sample_ratio,
                        per_device_train_batch_size=4, num_train_epochs=-1, learning_rate=lr, # no training at all
                        save_steps=100, logging_steps=1,
                        noise_multiplier=self._dp._noise_multiplier,
                        add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
                        metric_inverse_epsilon=metric_inverse_epsilon, # for the optimal gradient combination
                        clip_or_normalize=clip_or_normalize,
                        noise_on_vote=noise_on_vote, 
                        approximate_strategy=approx_strategy,
                        use_eigen=use_eigen,
                        clip_norm=clip_norm,
                    )
                else: # noise_place=='inner_product'
                    llm_trained_model, train_sample_dp_score, train_sample_grad, val_sample_grad, optimizer_state_dict = get_sample_grad_different_noise( # train_sample_grad.shape=[#train_sample, #param], val_sample_grad=[#real_sample, #param]
                        self._llm._model, self._llm._tokenizer, optimizer_state_dict, self._syn_train_data, self._priv_train_data, 
                        output_dir=logging_dir, val_sample_ratio=val_sample_ratio,
                        per_device_train_batch_size=4, num_train_epochs=-1, learning_rate=lr, # no training at all
                        save_steps=100, logging_steps=1,
                        noise_multiplier=self._dp._noise_multiplier,
                        add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
                        metric_inverse_epsilon=metric_inverse_epsilon, # for the optimal gradient combination
                        # clip_or_normalize=clip_or_normalize,
                        use_eigen=use_eigen,
                        clip_norm=clip_norm,
                    )
                if "SelfGen" in self._setting:
                    self._population._api._llm._model = llm_trained_model  # Update the population model with the trained LLM
                elif "LargeGen" in self._setting:
                    if self._llm_additional_generation == True:
                        _original_glm = copy.deepcopy(self._population._api._llm)
                        self._population._api._llm = copy.deepcopy(self._llm)  # Update the population model with the trained LLM
                        self._population._api._llm._model = llm_trained_model  # Update the population model with the trained LLM
                        
                        _num_samples_per_label_id = self._get_num_samples_per_label_id(
                            num_samples=num_samples_schedule[0],
                            fraction_per_label_id=fraction_per_label_id,
                        )
                        _syn_data_list = []
                        for _label_id, _label_info in enumerate(self._priv_data.metadata.label_info):
                            _syn_data = self._population.initial(
                                label_info=_label_info,
                                num_samples=_num_samples_per_label_id[_label_id],
                            )
                            _syn_data.set_label_id(_label_id)
                            _syn_data_list.append(_syn_data)
                        _syn_data = Data.concat(_syn_data_list, metadata=self._priv_data.metadata)
                        _syn_data.data_frame.reset_index(drop=True, inplace=True)
                        _syn_data.metadata.iteration = iteration + (len(num_samples_schedule)+1)*2
                        self._log_metrics(_syn_data)
                        execution_logger.info(f"LLM generated data within [{iteration=}] with {len(_syn_data.data_frame)} samples [finished].")
                        self._population._api._llm = copy.deepcopy(_original_glm)

                self._llm._model = llm_trained_model
                if train_sample_grad is not None:
                    del train_sample_grad
                    del val_sample_grad
                torch.cuda.empty_cache()
                gc.collect()
                print(f"[debugging] after gradient deletion in {iteration=}, check memory: {torch.cuda.memory_allocated(0) / 1024 ** 3:.2f} GB, {torch.cuda.memory_reserved(0) / 1024 ** 3:.2f} GB")

                # Evaluate the trained LLM on the private data for token prediction accuracy and loss
                eval_loss, eval_acc = evaluate_model_on_private_data(
                    model=self._llm._model,
                    tokenizer=self._llm._tokenizer,
                    dataset=self._priv_train_data,
                    batch_size=8,
                    add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
                )
                print(f"[Iteration {iteration}] Evaluation on train private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
                execution_logger.info(f"[Iteration {iteration}] LLM fine-tuned evaluation on train private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
                eval_loss, eval_acc = evaluate_model_on_private_data(
                    model=self._llm._model,
                    tokenizer=self._llm._tokenizer,
                    dataset=self._priv_eval_data,
                    batch_size=8,
                    add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
                )
                print(f"[Iteration {iteration}] Evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
                execution_logger.info(f"[Iteration {iteration}] LLM fine-tuned evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
                # Evaluate the trained LLM on the private data for token prediction accuracy and loss

                # no sample selection or data generation in the last iteration after evaluation
                if iteration == len(num_samples_schedule):
                    break
                
                if not sample_evolve:
                    if save_checkpoint:
                        syn_data.save_checkpoint(checkpoint_path)
                    self._log_metrics(syn_data)
                    continue
                
                if prompt_type == 'TopKabs':
                    train_sample_dp_score = torch.abs(train_sample_dp_score).cpu().numpy()
                    print(f"[debugging] [iteration {iteration}] top-K(abs) train_sample_dp_score: {train_sample_dp_score=}, {len(train_sample_dp_score)=}")
                    execution_logger.info(f"[iteration {iteration}] top-K(abs) train_sample_dp_score: {train_sample_dp_score=}, {len(train_sample_dp_score)=}")
                    selected_sample_indices = np.argpartition(train_sample_dp_score, -num_samples_schedule[iteration])[-num_samples_schedule[iteration]:]
                    print(f"[debugging] [iteration {iteration}] top-K(abs) selected_sample_indices: {selected_sample_indices=}, {len(selected_sample_indices)=}")
                    execution_logger.info(f"[iteration {iteration}] top-K(abs) selected_sample_indices: {selected_sample_indices=}, {len(selected_sample_indices)=}")
                elif prompt_type == 'TopK':
                    train_sample_dp_score = train_sample_dp_score.cpu().numpy()
                    print(f"[debugging] [iteration {iteration}] top-K train_sample_dp_score: {train_sample_dp_score=}, {len(train_sample_dp_score)=}")
                    execution_logger.info(f"[iteration {iteration}] top-K train_sample_dp_score: {train_sample_dp_score=}, {len(train_sample_dp_score)=}")
                    selected_sample_indices = np.argpartition(train_sample_dp_score, -num_samples_schedule[iteration])[-num_samples_schedule[iteration]:]
                    print(f"[debugging] [iteration {iteration}] top-K selected_sample_indices: {selected_sample_indices=}, {len(selected_sample_indices)=}")
                    execution_logger.info(f"[iteration {iteration}] top-K selected_sample_indices: {selected_sample_indices=}, {len(selected_sample_indices)=}")
                elif prompt_type == 'RandomKabs':
                    train_sample_dp_score = F.softmax(torch.abs(train_sample_dp_score)).cpu().numpy()
                    print(f"[debugging] [iteration {iteration}] softmax(abs) train_sample_dp_score: {train_sample_dp_score=}, {len(train_sample_dp_score)=}")
                    execution_logger.info(f"[iteration {iteration}] softmax(abs) train_sample_dp_score: {train_sample_dp_score=}, {len(train_sample_dp_score)=}")
                    selected_sample_indices = np.random.choice(range(len(train_sample_dp_score)), size=num_samples_schedule[iteration], p=train_sample_dp_score, replace=False)
                    print(f"[debugging] [iteration {iteration}] softmax(abs) selected_sample_indices: {selected_sample_indices=}, {len(selected_sample_indices)=}")
                    execution_logger.info(f"[iteration {iteration}] softmax(abs) selected_sample_indices: {selected_sample_indices=}, {len(selected_sample_indices)=}")
                else:
                    assert prompt_type == 'RandomK'
                    train_sample_dp_score = F.softmax(train_sample_dp_score).cpu().numpy()
                    print(f"[debugging] [iteration {iteration}] softmax train_sample_dp_score: {train_sample_dp_score=}, {len(train_sample_dp_score)=}")
                    execution_logger.info(f"[iteration {iteration}] softmax train_sample_dp_score: {train_sample_dp_score=}, {len(train_sample_dp_score)=}")
                    selected_sample_indices = np.random.choice(range(len(train_sample_dp_score)), size=num_samples_schedule[iteration], p=train_sample_dp_score, replace=False)
                    print(f"[debugging] [iteration {iteration}] softmax selected_sample_indices: {selected_sample_indices=}, {len(selected_sample_indices)=}")
                    execution_logger.info(f"[iteration {iteration}] softmax selected_sample_indices: {selected_sample_indices=}, {len(selected_sample_indices)=}")

                num_samples_per_label_id = self._get_num_samples_per_label_id(
                    num_samples=num_samples_schedule[iteration],
                    fraction_per_label_id=fraction_per_label_id,
                )

                # Generate synthetic data for each label.
                syn_data_list = []
                selected_syn_data = syn_data.select_by_index(selected_sample_indices)
                for label_id, label_info in enumerate(self._priv_data.metadata.label_info):
                    execution_logger.info(f"Label {label_id}")
                    sub_priv_data = self._priv_data.filter_label_id(label_id=label_id)
                    sub_syn_data = selected_syn_data.filter_label_id(label_id=label_id)
                    print(f"[debugging] checking sub_syn_data len for {label_id=}: {len(sub_syn_data.data_frame)=}")
                    print(f"[debugging] before variation api in {iteration=} for label {label_id}, check memory: {torch.cuda.memory_allocated(0) / 1024 ** 3:.2f} GB, {torch.cuda.memory_reserved(0) / 1024 ** 3:.2f} GB")

                    # # DP NN histogram.
                    # sub_priv_data, sub_syn_data = self._histogram.compute_histogram(
                    #     priv_data=sub_priv_data, syn_data=sub_syn_data
                    # )
                    # priv_data_list.append(sub_priv_data)
                    # sub_syn_data = self._dp.add_noise(syn_data=sub_syn_data)

                    if self._population._initial_variation_api_fold > 0:
                        # Generate next population.
                        # sub_syn_data = self._population.next(
                        #     syn_data=sub_syn_data,
                        #     num_samples=int(num_samples_per_label_id[label_id]*per_dev_model_prob[_i].item()),
                        # )
                        sub_syn_data = self._population.next(
                            syn_data=sub_syn_data,
                            num_samples=len(sub_syn_data.data_frame), # for each label, next generation will generate self._population._initial_variation_api_fold variantions as duplications
                            selected=True,
                        )
                    else:
                        print(f"[debugging] no variation api for {iteration=}, just randomly generate new samples")
                        sub_syn_data = self._population.initial(
                            label_info=label_info,
                            num_samples=num_samples_per_label_id[label_id],
                        )
                    sub_syn_data.set_label_id(label_id)
                    syn_data_list.append(sub_syn_data)

                    print(f"[debugging] before variation api in {iteration=} for label {label_id}, check memory: {torch.cuda.memory_allocated(0) / 1024 ** 3:.2f} GB, {torch.cuda.memory_reserved(0) / 1024 ** 3:.2f} GB")
                    
                # syn_data = Data.concat([syn_data] + syn_data_list)
                # syn_data = self.syn_data_concat_and_shuffle(
                #     selected_data_cluster_df_list=selected_data_cluster_df_list,
                #     syn_data_list=syn_data_list,
                #     shuffle=True,
                #     shuffle_all=True,
                # )
                print(f"[debugging] in <./pe/runner/pe.py> <opt_grad> {selected_syn_data=}, {type(selected_syn_data)=}")
                if not self._population._keep_selected:
                    syn_data = Data.concat([selected_syn_data] + syn_data_list, metadata=self._priv_data.metadata)
                else:
                    syn_data = Data.concat(syn_data_list, metadata=self._priv_data.metadata)
                syn_data.data_frame.reset_index(drop=True, inplace=True)
                syn_data.metadata.iteration = iteration
                print(f"[debugging] syn_data after concat: {len(syn_data.data_frame)=} with {syn_data.metadata.iteration=}")

                if save_checkpoint:
                    syn_data.save_checkpoint(checkpoint_path)
                self._log_metrics(syn_data)
        finally:
            self._clean_up_loggers()

        return syn_data

    def run_pure_pe(
        self,
        num_samples_schedule,
        delta,
        epsilon=None,
        noise_multiplier=None,
        checkpoint_path=None,
        save_checkpoint=True,
        fraction_per_label_id=None,
        val_sample_ratio_of_conterpart=0.2,
        lr=5e-5,
        _scaler=1.0,  # for the learning rate scaling
    ):
        """Run the PE algorithm, fine-tuning and evaluating the LLM on the synthetic data of every iteration.

        For every PE iteration the current synthetic data is used to fine-tune a deep copy of ``self._llm._model``,
        the fine-tuned model is evaluated on the private train and eval splits, and then the regular PE step
        (histogram, DP noise, next population) is performed. After the last PE iteration the final synthetic data is
        used for one more fine-tuning/evaluation round that operates on ``self._llm._model`` directly.

        :param num_samples_schedule: The schedule of the number of samples for each PE iteration. The first element is
            the number of samples for the initial data, and the rest are the number of samples for each PE iteration.
            So the length of the list is the number of PE iterations plus one
        :type num_samples_schedule: list[int]
        :param delta: The delta value of DP
        :type delta: float
        :param epsilon: The epsilon value of DP, defaults to None
        :type epsilon: float, optional
        :param noise_multiplier: The noise multiplier of the DP mechanism, defaults to None
        :type noise_multiplier: float, optional
        :param checkpoint_path: The path to load and save the checkpoint, defaults to None. When it is None no
            checkpoint is loaded or saved
        :type checkpoint_path: str, optional
        :param save_checkpoint: Whether to save the checkpoint, defaults to True
        :type save_checkpoint: bool, optional
        :param fraction_per_label_id: The fraction of samples for each label id. The fraction does not have to be
            normalized. When it is None, the fraction is assumed to be the same as the fraction of label ids in the
            private data. Defaults to None
        :type fraction_per_label_id: list[float], optional
        :param val_sample_ratio_of_conterpart: Ratio that, multiplied by a quarter of the private train set size,
            gives the gradient accumulation steps and the learning rate multiplier of the fine-tuning, defaults to 0.2
        :type val_sample_ratio_of_conterpart: float, optional
        :param lr: The base learning rate; the effective learning rate is
            ``lr * (len(private train data) / 4 * val_sample_ratio_of_conterpart) * _scaler``, defaults to 5e-5
        :type lr: float, optional
        :param _scaler: Additional multiplier applied to the learning rate, defaults to 1.0
        :type _scaler: float, optional
        :return: The synthetic data
        :rtype: :py:class:`pe.data.Data`
        """
        try:
            # Set privacy budget.
            self._dp.set_epsilon_and_delta(
                num_iterations=len(num_samples_schedule) - 1,
                epsilon=epsilon,
                delta=delta,
                noise_multiplier=noise_multiplier,
                sampling_ratio=1.0,  # no sub-sampling for this method
            )

            # Generate or load initial data.
            if checkpoint_path is not None and (syn_data := self.load_checkpoint(checkpoint_path)):
                execution_logger.info(
                    f"Loaded checkpoint from {checkpoint_path}, iteration={syn_data.metadata.iteration}"
                )
            elif self._init_data_file != '' and (syn_data := self.load_checkpoint(self._init_data_file)):
                execution_logger.info(f"PE initial data loaded from already generated data with {len(syn_data.data_frame)} samples [finished].")
            else:
                syn_data = self._pure_pe_generate_population(
                    num_samples=num_samples_schedule[0],
                    fraction_per_label_id=fraction_per_label_id,
                    iteration=0,
                )
                self._log_metrics(syn_data)
                if self._init_data_file != '':
                    syn_data.save_checkpoint(self._init_data_file)
                    print(f"save PE initial data to {self._init_data_file}")

            # Run PE iterations.
            for iteration in range(syn_data.metadata.iteration + 1, len(num_samples_schedule)):
                execution_logger.info(f"PE iteration {iteration}")

                # Sets self._syn_train_data (and related splits) from the current synthetic data.
                self.syn_data_preparation(syn_data=syn_data, test_size_ratio=0.0)

                logging_dir = self._exp_folder + f"/llm_syn_pe/{iteration}/llm/"
                os.makedirs(logging_dir, exist_ok=True)

                # Fine-tune a deep copy so that self._llm._model keeps its weights across PE iterations.
                llm_trained_model = self._pure_pe_fine_tune_and_evaluate(
                    model=copy.deepcopy(self._llm._model),
                    iteration=iteration,
                    logging_dir=logging_dir,
                    val_sample_ratio_of_conterpart=val_sample_ratio_of_conterpart,
                    lr=lr,
                    _scaler=_scaler,
                )
                self._pure_pe_additional_generation(
                    llm_trained_model=llm_trained_model,
                    iteration=iteration,
                    num_samples_schedule=num_samples_schedule,
                    fraction_per_label_id=fraction_per_label_id,
                )

                num_samples_per_label_id = self._get_num_samples_per_label_id(
                    num_samples=num_samples_schedule[iteration],
                    fraction_per_label_id=fraction_per_label_id,
                )
                syn_data_list = []
                priv_data_list = []

                # Generate synthetic data for each label.
                for label_id in range(len(self._priv_data.metadata.label_info)):
                    execution_logger.info(f"Label {label_id}")
                    sub_priv_data = self._priv_data.filter_label_id(label_id=label_id)
                    sub_syn_data = syn_data.filter_label_id(label_id=label_id)

                    # DP NN histogram.
                    sub_priv_data, sub_syn_data = self._histogram.compute_histogram(
                        priv_data=sub_priv_data, syn_data=sub_syn_data
                    )
                    priv_data_list.append(sub_priv_data)
                    sub_syn_data = self._dp.add_noise(syn_data=sub_syn_data)

                    # Generate next population.
                    sub_syn_data = self._population.next(
                        syn_data=sub_syn_data,
                        num_samples=num_samples_per_label_id[label_id],
                    )
                    sub_syn_data.set_label_id(label_id)
                    syn_data_list.append(sub_syn_data)

                syn_data = Data.concat(syn_data_list)
                syn_data.data_frame.reset_index(drop=True, inplace=True)
                syn_data.metadata.iteration = iteration

                new_priv_data = Data.concat(priv_data_list)
                self._priv_data = self._priv_data.merge(new_priv_data)

                # Loading is skipped when checkpoint_path is None, so saving is skipped as well instead of
                # handing None to save_checkpoint as a path.
                if save_checkpoint and checkpoint_path is not None:
                    syn_data.save_checkpoint(checkpoint_path)
                self._log_metrics(syn_data)

            # Final round: fine-tune and evaluate on the synthetic data produced by the last PE iteration.
            iteration = len(num_samples_schedule)
            self.syn_data_preparation(syn_data=syn_data, test_size_ratio=0.0)
            logging_dir = self._exp_folder + f"/llm_syn_pe/{iteration}/llm/"
            os.makedirs(logging_dir, exist_ok=True)
            if iteration > 1:
                # Start from the weights written to the previous iteration's output directory.
                load_model_path = self._exp_folder + f"/llm_syn_pe/{iteration-1}/llm/"
                saved_files = os.listdir(load_model_path)
                if 'model.safetensors' in saved_files and 'adapter_model.safetensors' not in saved_files:
                    # Full-model weights only: load them into the existing model in place.
                    print(f"load from {os.path.join(load_model_path, 'model.safetensors')}")
                    state_dict = load_file(os.path.join(load_model_path, 'model.safetensors'))
                    self._llm._model.load_state_dict(state_dict)
                else:
                    # Otherwise treat the directory as a PEFT adapter on top of the original base model.
                    # PeftModel.from_pretrained reads the adapter config from the directory itself.
                    base = transformers.AutoModelForCausalLM.from_pretrained(
                        self._original_llm_path, device_map=None, torch_dtype=torch.float16, trust_remote_code=True,
                    )
                    self._llm._model = PeftModel.from_pretrained(base, load_model_path, is_trainable=True)

            print(f"{type(self._syn_train_data)=}")
            print(f"{self._syn_train_data.column_names=}")
            # Unlike inside the loop, the final round passes self._llm._model itself (no deep copy).
            llm_trained_model = self._pure_pe_fine_tune_and_evaluate(
                model=self._llm._model,
                iteration=iteration,
                logging_dir=logging_dir,
                val_sample_ratio_of_conterpart=val_sample_ratio_of_conterpart,
                lr=lr,
                _scaler=_scaler,
            )
            self._pure_pe_additional_generation(
                llm_trained_model=llm_trained_model,
                iteration=iteration,
                num_samples_schedule=num_samples_schedule,
                fraction_per_label_id=fraction_per_label_id,
            )

        finally:
            self._clean_up_loggers()

        return syn_data

    def _pure_pe_generate_population(self, num_samples, fraction_per_label_id, iteration):
        """Generate a fresh population with ``self._population.initial`` for every label and concatenate it.

        :param num_samples: The total number of samples to generate
        :type num_samples: int
        :param fraction_per_label_id: The fraction of samples for each label id, or None
        :type fraction_per_label_id: list[float], optional
        :param iteration: The value stored in the ``iteration`` field of the returned data's metadata
        :type iteration: int
        :return: The generated data, concatenated over all labels with a reset index
        :rtype: :py:class:`pe.data.Data`
        """
        num_samples_per_label_id = self._get_num_samples_per_label_id(
            num_samples=num_samples,
            fraction_per_label_id=fraction_per_label_id,
        )
        per_label_data = []
        for label_id, label_info in enumerate(self._priv_data.metadata.label_info):
            label_data = self._population.initial(
                label_info=label_info,
                num_samples=num_samples_per_label_id[label_id],
            )
            label_data.set_label_id(label_id)
            per_label_data.append(label_data)
        population = Data.concat(per_label_data, metadata=self._priv_data.metadata)
        population.data_frame.reset_index(drop=True, inplace=True)
        population.metadata.iteration = iteration
        return population

    def _pure_pe_fine_tune_and_evaluate(self, model, iteration, logging_dir, val_sample_ratio_of_conterpart, lr, _scaler):
        """Fine-tune ``model`` on ``self._syn_train_data`` and report loss/accuracy on the private splits.

        :param model: The model handed to ``sft_fine_tune``
        :param iteration: The iteration number used in the log messages
        :type iteration: int
        :param logging_dir: The output directory of the fine-tuning run
        :type logging_dir: str
        :param val_sample_ratio_of_conterpart: See :py:meth:`run_pure_pe`
        :type val_sample_ratio_of_conterpart: float
        :param lr: The base learning rate
        :type lr: float
        :param _scaler: Additional multiplier applied to the learning rate
        :type _scaler: float
        :return: The fine-tuned model returned by ``sft_fine_tune``
        """
        # A quarter of the private train set size (the per-device batch size is 4) times the ratio.
        # NOTE(review): int() truncates this to 0 for very small private sets -- confirm sft_fine_tune tolerates
        # gradient_accumulation_steps=0.
        accumulation_scale = (1/4)*len(self._priv_train_data)*val_sample_ratio_of_conterpart
        _, llm_trained_model, _ = sft_fine_tune(
            model, self._llm._tokenizer, self._syn_train_data, output_dir=logging_dir,
            per_device_train_batch_size=4, gradient_accumulation_steps=int(accumulation_scale),
            num_train_epochs=self._t_fine_tune, learning_rate=lr*accumulation_scale*_scaler,
            save_steps=100, logging_steps=5,
            eval_dataset=self._priv_eval_data,
            add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
        )
        # Collect unreachable Python objects first so that empty_cache can hand their cached blocks back.
        gc.collect()
        torch.cuda.empty_cache()
        print(f"[debugging] after gradient deletion in {iteration=}, check memory: {torch.cuda.memory_allocated(0) / 1024 ** 3:.2f} GB, {torch.cuda.memory_reserved(0) / 1024 ** 3:.2f} GB")

        # Evaluate the trained LLM on the private data for token prediction accuracy and loss.
        for split_name, dataset in (("train", self._priv_train_data), ("test", self._priv_eval_data)):
            eval_loss, eval_acc = evaluate_model_on_private_data(
                model=llm_trained_model,
                tokenizer=self._llm._tokenizer,
                dataset=dataset,
                batch_size=8,
                add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
            )
            print(f"[Iteration {iteration}] Evaluation on {split_name} private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
            execution_logger.info(f"[Iteration {iteration}] LLM fine-tuned evaluation on {split_name} private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
        return llm_trained_model

    def _pure_pe_additional_generation(self, llm_trained_model, iteration, num_samples_schedule, fraction_per_label_id):
        """Generate and log an extra population with the fine-tuned LLM when the "LargeGen" setting asks for it.

        The population's generation LLM is temporarily replaced by a copy of ``self._llm`` carrying
        ``llm_trained_model`` and is always put back afterwards, even if the generation fails. Nothing happens unless
        ``"LargeGen"`` is in ``self._setting`` and ``self._llm_additional_generation`` equals True.

        :param llm_trained_model: The fine-tuned model used for the generation
        :param iteration: The current iteration number
        :type iteration: int
        :param num_samples_schedule: The schedule of the number of samples; its first element is the number of
            samples generated here
        :type num_samples_schedule: list[int]
        :param fraction_per_label_id: The fraction of samples for each label id, or None
        :type fraction_per_label_id: list[float], optional
        """
        # Kept as an equality check (not plain truthiness) to preserve behaviour for non-bool flag values.
        if "LargeGen" not in self._setting or not (self._llm_additional_generation == True):
            return

        # The original generator object is never modified below, so keeping a reference is enough to restore it.
        original_generation_llm = self._population._api._llm
        try:
            self._population._api._llm = copy.deepcopy(self._llm)
            self._population._api._llm._model = llm_trained_model  # generate with the trained LLM

            _syn_data = self._pure_pe_generate_population(
                num_samples=num_samples_schedule[0],
                fraction_per_label_id=fraction_per_label_id,
                # Offset the iteration so these metrics do not collide with those of the regular PE iterations.
                iteration=iteration + (len(num_samples_schedule)+1)*2,
            )
            self._log_metrics(_syn_data)
            execution_logger.info(f"LLM generated data within [{iteration=}] with {len(_syn_data.data_frame)} samples [finished].")
        finally:
            self._population._api._llm = original_generation_llm

    def run_loss_plot(
        self,
        num_samples_schedule,
        delta,
        epsilon=None,
        noise_multiplier=None,
        checkpoint_path=None,
        save_checkpoint=True,
        fraction_per_label_id=None,
        val_sample_ratio=0.2,
        lr=5e-5,
        _scaler=1.0,
        metric_inverse_epsilon=1E-6,
        noise_place='coefficient',
        prompt_type='RandomK',
        clip_or_normalize='normalize',
        noise_on_vote=True,
        sample_evolve=True,
        approx_strategy='opt',
        with_instruction_base=0,
        use_eigen=False,
    ):
        """Run the loss plot of each training gradient update.

        :param num_samples_schedule: The schedule of the number of samples for each PE iteration. The first element is
            the number of samples for the initial data, and the rest are the number of samples for each PE iteration.
            So the length of the list is the number of PE iterations plus one
        :type num_samples_schedule: list[int]
        :param delta: The delta value of DP
        :type delta: float
        :param epsilon: The epsilon value of DP, defaults to None
        :type epsilon: float, optional
        :param noise_multiplier: The noise multiplier of the DP mechanism, defaults to None
        :type noise_multiplier: float, optional
        :param checkpoint_path: The path to load and save the checkpoint, defaults to None
        :type checkpoint_path: str, optional
        :param save_checkpoint: Whether to save the checkpoint, defaults to True
        :type save_checkpoint: bool, optional
        :param fraction_per_label_id: The fraction of samples for each label id. The fraction does not have to be
            normalized. When it is None, the fraction is assumed to be the same as the fraction of label ids in the
            private data. Defaults to None
        :type fraction_per_label_id: list[float], optional
        :param val_sample_ratio: The sub-sampling ratio of private samples within each iteration, defaults to 0.2 for this method
        :type val_sample_ratio: float
        :return: The synthetic data
        :rtype: :py:class:`pe.data.Data`
        """
        loss_points = {}
        """ base llm evaluation """
        eval_loss, eval_acc = evaluate_model_on_private_data(
            model=self._llm._model,
            tokenizer=self._llm._tokenizer,
            # dataset=self._priv_dev_data,
            dataset=self._priv_eval_data,
            batch_size=8,
            add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
        )
        print(f"Base model, evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
        execution_logger.info(f"LLM base model, evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
        """ base llm evaluation """
        loss_points['checkpoint-0'] = [(eval_loss, eval_acc)]

        """
        SFT the llm using private samples without dp
        """

        _logging_dir = self._exp_folder + f"/llm_priv_train/no_syn/"
        if not os.path.exists(_logging_dir):
            os.makedirs(_logging_dir)
        
        llm_sft_eval_metric, llm_trained_model, _ = sft_fine_tune(
            copy.deepcopy(self._llm._model), self._llm._tokenizer, self._priv_train_data, output_dir=_logging_dir,
            per_device_train_batch_size=4, gradient_accumulation_steps=int((1/4)*len(self._priv_train_data)*val_sample_ratio),
            num_train_epochs=self._t_fine_tune, learning_rate=lr*((1/4)*len(self._priv_train_data)*val_sample_ratio/2)*_scaler, # 4000 sample -- 5e-5, 400 sample 5e-6
            save_steps=1, logging_steps=1,
            eval_dataset=self._priv_eval_data,
            add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
        )
        eval_loss, eval_acc = evaluate_model_on_private_data(
            model=llm_trained_model,
            tokenizer=self._llm._tokenizer,
            # dataset=self._priv_dev_data,
            dataset=self._priv_eval_data,
            batch_size=8,
            add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
        )
        execution_logger.info(f"parameter-tuning, learing_rate_scale={_scaler}")
        print(f"SGD fine-tuned model, evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
        execution_logger.info(f"LLM SGD fine-tuned on base model, evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
        

        try:
            # Set privacy budget.
            self._dp.set_epsilon_and_delta(
                num_iterations=len(num_samples_schedule) - 1,
                epsilon=epsilon,
                delta=delta,
                noise_multiplier=noise_multiplier,
                sampling_ratio=val_sample_ratio,
            )
            # Generate or load initial data.
            if checkpoint_path is not None and (syn_data := self.load_checkpoint(checkpoint_path)):
                execution_logger.info(
                    f"Loaded checkpoint from {checkpoint_path}, iteration={syn_data.metadata.iteration}"
                )
            elif self._init_data_file != '' and (syn_data := self.load_checkpoint(self._init_data_file)):
                execution_logger.info(f"PE initial data loaded from already generated data with {len(syn_data.data_frame)} samples [finished].")
            else:
                num_samples_per_label_id = self._get_num_samples_per_label_id(
                    num_samples=num_samples_schedule[0],
                    fraction_per_label_id=fraction_per_label_id,
                )
                syn_data_list = []
                for label_id, label_info in enumerate(self._priv_data.metadata.label_info):
                    syn_data = self._population.initial(
                        label_info=label_info,
                        num_samples=num_samples_per_label_id[label_id],
                    )
                    syn_data.set_label_id(label_id)
                    syn_data_list.append(syn_data)
                syn_data = Data.concat(syn_data_list, metadata=self._priv_data.metadata)
                syn_data.data_frame.reset_index(drop=True, inplace=True)
                syn_data.metadata.iteration = 0
                self._log_metrics(syn_data)
                if self._init_data_file != '':
                    syn_data.save_checkpoint(self._init_data_file)
                    print(f"save PE initial data to {self._init_data_file}")
            self.syn_data_preparation(syn_data=syn_data, test_size_ratio=0.0) # with self._syn_train_data_list, self._syn_dev_data set in the function.

            iteration = 1

            # evaluate the base-model
            logging_dir = self._exp_folder + f"/llm_priv_train/optgrad_update/"
            if not os.path.exists(logging_dir):
                os.makedirs(logging_dir)
            if noise_place=='coefficient':
                llm_trained_model, train_sample_dp_score, train_sample_grad, val_sample_grad, optimizer_state_dict = get_sample_grad( # train_sample_grad.shape=[#train_sample, #param], val_sample_grad=[#real_sample, #param]
                    copy.deepcopy(self._llm._model), self._llm._tokenizer, None, self._syn_train_data, self._priv_train_data, 
                    output_dir=logging_dir, val_sample_ratio=val_sample_ratio,
                    per_device_train_batch_size=4, num_train_epochs=-1, learning_rate=lr, # no training at all
                    save_steps=100, logging_steps=1,
                    noise_multiplier=self._dp._noise_multiplier,
                    add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
                    metric_inverse_epsilon=metric_inverse_epsilon, # for the optimal gradient combination
                    clip_or_normalize=clip_or_normalize,
                    noise_on_vote=noise_on_vote,
                    approximate_strategy=approx_strategy,
                    use_eigen=use_eigen,
                )
            else: # noise_place=='inner_product'
                llm_trained_model, train_sample_dp_score, train_sample_grad, val_sample_grad, optimizer_state_dict = get_sample_grad_different_noise( # train_sample_grad.shape=[#train_sample, #param], val_sample_grad=[#real_sample, #param]
                    copy.deepcopy(self._llm._model), self._llm._tokenizer, None, self._syn_train_data, self._priv_train_data, 
                    output_dir=logging_dir, val_sample_ratio=val_sample_ratio,
                    per_device_train_batch_size=4, num_train_epochs=-1, learning_rate=lr, # no training at all
                    save_steps=100, logging_steps=1,
                    noise_multiplier=self._dp._noise_multiplier,
                    add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
                    metric_inverse_epsilon=metric_inverse_epsilon, # for the optimal gradient combination
                    # clip_or_normalize=clip_or_normalize,
                    use_eigen=use_eigen,
                )
            if "SelfGen" in self._setting:
                self._population._api._llm._model = llm_trained_model  # Update the population model with the trained LLM
            if train_sample_grad is not None:
                del train_sample_grad
                del val_sample_grad
            torch.cuda.empty_cache()
            gc.collect()
            eval_loss, eval_acc = evaluate_model_on_private_data(
                model=llm_trained_model,
                tokenizer=self._llm._tokenizer,
                dataset=self._priv_eval_data,
                batch_size=8,
                add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
            )
            print(f"[checkpoint-0] (base model) OptGrad Evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
            execution_logger.info(f"[[checkpoint-0] (base model) OptGrad LLM fine-tuned evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
            if not 'checkpoint-0' in loss_points.keys():
                loss_points['checkpoint-0'] = [()]
            loss_points['checkpoint-0'].append((eval_loss, eval_acc))

            ######### generate synthetic samples if needed #########
            if sample_evolve:
                if prompt_type == 'TopKabs':
                    train_sample_dp_score = torch.abs(train_sample_dp_score).cpu().numpy()
                    print(f"[debugging] [iteration {iteration}] top-K(abs) train_sample_dp_score: {train_sample_dp_score=}, {len(train_sample_dp_score)=}")
                    execution_logger.info(f"[iteration {iteration}] top-K(abs) train_sample_dp_score: {train_sample_dp_score=}, {len(train_sample_dp_score)=}")
                    selected_sample_indices = np.argpartition(train_sample_dp_score, -num_samples_schedule[iteration])[-num_samples_schedule[iteration]:]
                    print(f"[debugging] [iteration {iteration}] top-K(abs) selected_sample_indices: {selected_sample_indices=}, {len(selected_sample_indices)=}")
                    execution_logger.info(f"[iteration {iteration}] top-K(abs) selected_sample_indices: {selected_sample_indices=}, {len(selected_sample_indices)=}")
                elif prompt_type == 'TopK':
                    train_sample_dp_score = train_sample_dp_score.cpu().numpy()
                    print(f"[debugging] [iteration {iteration}] top-K train_sample_dp_score: {train_sample_dp_score=}, {len(train_sample_dp_score)=}")
                    execution_logger.info(f"[iteration {iteration}] top-K train_sample_dp_score: {train_sample_dp_score=}, {len(train_sample_dp_score)=}")
                    selected_sample_indices = np.argpartition(train_sample_dp_score, -num_samples_schedule[iteration])[-num_samples_schedule[iteration]:]
                    print(f"[debugging] [iteration {iteration}] top-K selected_sample_indices: {selected_sample_indices=}, {len(selected_sample_indices)=}")
                    execution_logger.info(f"[iteration {iteration}] top-K selected_sample_indices: {selected_sample_indices=}, {len(selected_sample_indices)=}")
                elif prompt_type == 'RandomKabs':
                    train_sample_dp_score = F.softmax(torch.abs(train_sample_dp_score)).cpu().numpy()
                    print(f"[debugging] [iteration {iteration}] softmax(abs) train_sample_dp_score: {train_sample_dp_score=}, {len(train_sample_dp_score)=}")
                    execution_logger.info(f"[iteration {iteration}] softmax(abs) train_sample_dp_score: {train_sample_dp_score=}, {len(train_sample_dp_score)=}")
                    selected_sample_indices = np.random.choice(range(len(train_sample_dp_score)), size=num_samples_schedule[iteration], p=train_sample_dp_score, replace=False)
                    print(f"[debugging] [iteration {iteration}] softmax(abs) selected_sample_indices: {selected_sample_indices=}, {len(selected_sample_indices)=}")
                    execution_logger.info(f"[iteration {iteration}] softmax(abs) selected_sample_indices: {selected_sample_indices=}, {len(selected_sample_indices)=}")
                else:
                    assert prompt_type == 'RandomK'
                    train_sample_dp_score = F.softmax(train_sample_dp_score).cpu().numpy()
                    print(f"[debugging] [iteration {iteration}] softmax train_sample_dp_score: {train_sample_dp_score=}, {len(train_sample_dp_score)=}")
                    execution_logger.info(f"[iteration {iteration}] softmax train_sample_dp_score: {train_sample_dp_score=}, {len(train_sample_dp_score)=}")
                    selected_sample_indices = np.random.choice(range(len(train_sample_dp_score)), size=num_samples_schedule[iteration], p=train_sample_dp_score, replace=False)
                    print(f"[debugging] [iteration {iteration}] softmax selected_sample_indices: {selected_sample_indices=}, {len(selected_sample_indices)=}")
                    execution_logger.info(f"[iteration {iteration}] softmax selected_sample_indices: {selected_sample_indices=}, {len(selected_sample_indices)=}")

                num_samples_per_label_id = self._get_num_samples_per_label_id(
                    num_samples=num_samples_schedule[iteration],
                    fraction_per_label_id=fraction_per_label_id,
                )

                # Generate synthetic data for each label.
                syn_data_list = []
                selected_syn_data = syn_data.select_by_index(selected_sample_indices)
                for label_id, label_info in enumerate(self._priv_data.metadata.label_info):
                    execution_logger.info(f"Label {label_id}")
                    sub_priv_data = self._priv_data.filter_label_id(label_id=label_id)
                    sub_syn_data = selected_syn_data.filter_label_id(label_id=label_id)
                    print(f"[debugging] checking sub_syn_data len for {label_id=}: {len(sub_syn_data.data_frame)=}")
                    print(f"[debugging] before variation api in {iteration=} for label {label_id}, check memory: {torch.cuda.memory_allocated(0) / 1024 ** 3:.2f} GB, {torch.cuda.memory_reserved(0) / 1024 ** 3:.2f} GB")

                    if self._population._initial_variation_api_fold > 0:
                        sub_syn_data = self._population.next(
                            syn_data=sub_syn_data,
                            num_samples=len(sub_syn_data.data_frame), # for each label, next generation will generate self._population._initial_variation_api_fold variantions as duplications
                            selected=True,
                        )
                    else:
                        print(f"[debugging] no variation api for {iteration=}, just randomly generate new samples")
                        sub_syn_data = self._population.initial(
                            label_info=label_info,
                            num_samples=num_samples_per_label_id[label_id],
                        )
                    sub_syn_data.set_label_id(label_id)
                    syn_data_list.append(sub_syn_data)

                    print(f"[debugging] before variation api in {iteration=} for label {label_id}, check memory: {torch.cuda.memory_allocated(0) / 1024 ** 3:.2f} GB, {torch.cuda.memory_reserved(0) / 1024 ** 3:.2f} GB")

                print(f"[debugging] in <./pe/runner/pe.py> <opt_grad> {selected_syn_data=}, {type(selected_syn_data)=}")
                if not self._population._keep_selected:
                    syn_data = Data.concat([selected_syn_data] + syn_data_list, metadata=self._priv_data.metadata)
                else:
                    syn_data = Data.concat(syn_data_list, metadata=self._priv_data.metadata)
                syn_data.data_frame.reset_index(drop=True, inplace=True)
                syn_data.metadata.iteration = iteration
                print(f"[debugging] syn_data after concat: {len(syn_data.data_frame)=} with {syn_data.metadata.iteration=}")

                if save_checkpoint:
                    syn_data.save_checkpoint(checkpoint_path)
                self._log_metrics(syn_data)
                self.syn_data_preparation(syn_data=syn_data, test_size_ratio=0.0) # with self._syn_train_data_list, self._syn_dev_data set in the function.
            ######### generate synthetic samples if needed #########


            logging_dir = self._exp_folder + f"/llm_priv_train/dpsgd_update/"
            if not os.path.exists(logging_dir):
                os.makedirs(logging_dir)
            print(f"num_train_epochs={int((len(num_samples_schedule)-1)*val_sample_ratio)}, {epsilon=}, {delta=}")
            llm_sft_eval_metric, llm_trained_model, _ = opacus_dpsgd_fine_tune(
                copy.deepcopy(self._llm._model), self._llm._tokenizer, self._priv_train_data, output_dir=logging_dir,
                per_device_train_batch_size=4, 
                num_train_epochs=int((len(num_samples_schedule)-1)*val_sample_ratio), learning_rate=lr*((1/4)*len(self._priv_train_data)*val_sample_ratio/2)*_scaler, # 4000 sample -- 5e-5, 400 sample 5e-6
                save_steps=100, logging_steps=5,
                eval_dataset=self._priv_eval_data,
                epsilon=epsilon, delta=delta, val_sample_ratio=val_sample_ratio,
                add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
                stop_one_batch=True,
            )
            eval_loss, eval_acc = evaluate_model_on_private_data(
                model=llm_trained_model,
                tokenizer=self._llm._tokenizer,
                # dataset=self._priv_dev_data,
                dataset=self._priv_eval_data,
                batch_size=8,
                add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
            )
            execution_logger.info(f"parameter-tuning, learing_rate_scale={_scaler}")
            print(f"[checkpoint-0] (base model) DP-SGD fine-tuned model, evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
            execution_logger.info(f"[checkpoint-0] (base model) LLM DP-SGD fine-tuned on base model, evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
            if not 'checkpoint-0' in loss_points.keys():
                loss_points['checkpoint-0'] = [(),()]
            loss_points['checkpoint-0'].append((eval_loss, eval_acc))


            # evaluate each train step
            folder_names = [name for name in os.listdir(_logging_dir) if os.path.isdir(os.path.join(_logging_dir, name))]
            # assert len(folder_names) == len(num_samples_schedule)-1, f'[ERROR] should have {len(folder_names)=} == {len(num_samples_schedule)=}'
            # for folder in folder_names:
            for iteration in range(1, len(folder_names)+1):
                folder = 'checkpoint-' + str(iteration)
                assert folder in folder_names, f'[ERROR] should have {folder=} under folder {_logging_dir}'
                load_model_path = os.path.join(_logging_dir, folder)
                if 'model.safetensors' in os.listdir(load_model_path) and ('adapter_model.safetensors' not in os.listdir(load_model_path)):
                    print(f"load from {os.path.join(load_model_path, 'model.safetensors')}")
                    state_dict = load_file(os.path.join(load_model_path, 'model.safetensors'))
                    _llm_trained_model.load_state_dict(state_dict)
                else:
                    base = transformers.AutoModelForCausalLM.from_pretrained(
                        self._original_llm_path, device_map=None, torch_dtype=torch.float16, trust_remote_code=True,
                    )
                    peft_config = PeftConfig.from_pretrained(load_model_path)
                    _llm_trained_model = PeftModel.from_pretrained(base, load_model_path, is_trainable=True)
                
                # Evaluate the trained LLM on the private data for token prediction accuracy and loss
                eval_loss, eval_acc = evaluate_model_on_private_data(
                    model=_llm_trained_model,
                    tokenizer=self._llm._tokenizer,
                    # dataset=self._priv_dev_data,
                    dataset=self._priv_eval_data,
                    batch_size=8,
                    add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
                )
                print(f"SGD Train on private data with {folder}, evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
                execution_logger.info(f"LLM SGD fine-tuned with {folder} on train private data and evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
                loss_points[folder] = [(eval_loss, eval_acc)]

                logging_dir = self._exp_folder + f"/llm_priv_train/optgrad_update/"
                if not os.path.exists(logging_dir):
                    os.makedirs(logging_dir)
                if noise_place=='coefficient':
                    llm_trained_model, train_sample_dp_score, train_sample_grad, val_sample_grad, optimizer_state_dict = get_sample_grad( # train_sample_grad.shape=[#train_sample, #param], val_sample_grad=[#real_sample, #param]
                        copy.deepcopy(_llm_trained_model), self._llm._tokenizer, optimizer_state_dict, self._syn_train_data, self._priv_train_data, 
                        output_dir=logging_dir, val_sample_ratio=val_sample_ratio,
                        per_device_train_batch_size=4, num_train_epochs=-1, learning_rate=lr, # no training at all
                        save_steps=100, logging_steps=1,
                        noise_multiplier=self._dp._noise_multiplier,
                        add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
                        metric_inverse_epsilon=metric_inverse_epsilon, # for the optimal gradient combination
                        clip_or_normalize=clip_or_normalize,
                        noise_on_vote=noise_on_vote,
                        approximate_strategy=approx_strategy,
                        use_eigen=use_eigen,
                    )
                else: # noise_place=='inner_product'
                    llm_trained_model, train_sample_dp_score, train_sample_grad, val_sample_grad, optimizer_state_dict = get_sample_grad_different_noise( # train_sample_grad.shape=[#train_sample, #param], val_sample_grad=[#real_sample, #param]
                        copy.deepcopy(_llm_trained_model), self._llm._tokenizer, optimizer_state_dict, self._syn_train_data, self._priv_train_data, 
                        output_dir=logging_dir, val_sample_ratio=val_sample_ratio,
                        per_device_train_batch_size=4, num_train_epochs=-1, learning_rate=lr, # no training at all
                        save_steps=100, logging_steps=1,
                        noise_multiplier=self._dp._noise_multiplier,
                        add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
                        metric_inverse_epsilon=metric_inverse_epsilon, # for the optimal gradient combination
                        # clip_or_normalize=clip_or_normalize,
                        use_eigen=use_eigen,
                    )
                if "SelfGen" in self._setting:
                    self._population._api._llm._model = llm_trained_model  # Update the population model with the trained LLM
                # self._llm._model = llm_trained_model
                if train_sample_grad is not None:
                    del train_sample_grad
                    del val_sample_grad
                torch.cuda.empty_cache()
                gc.collect()
                # print(f"[debugging] after gradient deletion in {iteration=}, check memory: {torch.cuda.memory_allocated(0) / 1024 ** 3:.2f} GB, {torch.cuda.memory_reserved(0) / 1024 ** 3:.2f} GB")

                eval_loss, eval_acc = evaluate_model_on_private_data(
                    model=llm_trained_model,
                    tokenizer=self._llm._tokenizer,
                    dataset=self._priv_eval_data,
                    batch_size=8,
                    add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
                )
                print(f"[{folder}] OptGrad Evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
                execution_logger.info(f"[{folder}] OptGrad LLM fine-tuned evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
                if not folder in loss_points.keys():
                    loss_points[folder] = [()]
                loss_points[folder].append((eval_loss, eval_acc))
                ######### generate synthetic samples if needed #########
                if sample_evolve:
                    if prompt_type == 'TopKabs':
                        train_sample_dp_score = torch.abs(train_sample_dp_score).cpu().numpy()
                        print(f"[debugging] [iteration {iteration}] top-K(abs) train_sample_dp_score: {train_sample_dp_score=}, {len(train_sample_dp_score)=}")
                        execution_logger.info(f"[iteration {iteration}] top-K(abs) train_sample_dp_score: {train_sample_dp_score=}, {len(train_sample_dp_score)=}")
                        selected_sample_indices = np.argpartition(train_sample_dp_score, -num_samples_schedule[iteration])[-num_samples_schedule[iteration]:]
                        print(f"[debugging] [iteration {iteration}] top-K(abs) selected_sample_indices: {selected_sample_indices=}, {len(selected_sample_indices)=}")
                        execution_logger.info(f"[iteration {iteration}] top-K(abs) selected_sample_indices: {selected_sample_indices=}, {len(selected_sample_indices)=}")
                    elif prompt_type == 'TopK':
                        train_sample_dp_score = train_sample_dp_score.cpu().numpy()
                        print(f"[debugging] [iteration {iteration}] top-K train_sample_dp_score: {train_sample_dp_score=}, {len(train_sample_dp_score)=}")
                        execution_logger.info(f"[iteration {iteration}] top-K train_sample_dp_score: {train_sample_dp_score=}, {len(train_sample_dp_score)=}")
                        selected_sample_indices = np.argpartition(train_sample_dp_score, -num_samples_schedule[iteration])[-num_samples_schedule[iteration]:]
                        print(f"[debugging] [iteration {iteration}] top-K selected_sample_indices: {selected_sample_indices=}, {len(selected_sample_indices)=}")
                        execution_logger.info(f"[iteration {iteration}] top-K selected_sample_indices: {selected_sample_indices=}, {len(selected_sample_indices)=}")
                    elif prompt_type == 'RandomKabs':
                        train_sample_dp_score = F.softmax(torch.abs(train_sample_dp_score)).cpu().numpy()
                        print(f"[debugging] [iteration {iteration}] softmax(abs) train_sample_dp_score: {train_sample_dp_score=}, {len(train_sample_dp_score)=}")
                        execution_logger.info(f"[iteration {iteration}] softmax(abs) train_sample_dp_score: {train_sample_dp_score=}, {len(train_sample_dp_score)=}")
                        selected_sample_indices = np.random.choice(range(len(train_sample_dp_score)), size=num_samples_schedule[iteration], p=train_sample_dp_score, replace=False)
                        print(f"[debugging] [iteration {iteration}] softmax(abs) selected_sample_indices: {selected_sample_indices=}, {len(selected_sample_indices)=}")
                        execution_logger.info(f"[iteration {iteration}] softmax(abs) selected_sample_indices: {selected_sample_indices=}, {len(selected_sample_indices)=}")
                    else:
                        assert prompt_type == 'RandomK'
                        train_sample_dp_score = F.softmax(train_sample_dp_score).cpu().numpy()
                        print(f"[debugging] [iteration {iteration}] softmax train_sample_dp_score: {train_sample_dp_score=}, {len(train_sample_dp_score)=}")
                        execution_logger.info(f"[iteration {iteration}] softmax train_sample_dp_score: {train_sample_dp_score=}, {len(train_sample_dp_score)=}")
                        selected_sample_indices = np.random.choice(range(len(train_sample_dp_score)), size=num_samples_schedule[iteration], p=train_sample_dp_score, replace=False)
                        print(f"[debugging] [iteration {iteration}] softmax selected_sample_indices: {selected_sample_indices=}, {len(selected_sample_indices)=}")
                        execution_logger.info(f"[iteration {iteration}] softmax selected_sample_indices: {selected_sample_indices=}, {len(selected_sample_indices)=}")

                    num_samples_per_label_id = self._get_num_samples_per_label_id(
                        num_samples=num_samples_schedule[iteration],
                        fraction_per_label_id=fraction_per_label_id,
                    )

                    # Generate synthetic data for each label.
                    syn_data_list = []
                    selected_syn_data = syn_data.select_by_index(selected_sample_indices)
                    for label_id, label_info in enumerate(self._priv_data.metadata.label_info):
                        execution_logger.info(f"Label {label_id}")
                        sub_priv_data = self._priv_data.filter_label_id(label_id=label_id)
                        sub_syn_data = selected_syn_data.filter_label_id(label_id=label_id)
                        print(f"[debugging] checking sub_syn_data len for {label_id=}: {len(sub_syn_data.data_frame)=}")
                        print(f"[debugging] before variation api in {iteration=} for label {label_id}, check memory: {torch.cuda.memory_allocated(0) / 1024 ** 3:.2f} GB, {torch.cuda.memory_reserved(0) / 1024 ** 3:.2f} GB")

                        if self._population._initial_variation_api_fold > 0:
                            sub_syn_data = self._population.next(
                                syn_data=sub_syn_data,
                                num_samples=len(sub_syn_data.data_frame), # for each label, next generation will generate self._population._initial_variation_api_fold variantions as duplications
                                selected=True,
                            )
                        else:
                            print(f"[debugging] no variation api for {iteration=}, just randomly generate new samples")
                            sub_syn_data = self._population.initial(
                                label_info=label_info,
                                num_samples=num_samples_per_label_id[label_id],
                            )
                        sub_syn_data.set_label_id(label_id)
                        syn_data_list.append(sub_syn_data)

                        print(f"[debugging] before variation api in {iteration=} for label {label_id}, check memory: {torch.cuda.memory_allocated(0) / 1024 ** 3:.2f} GB, {torch.cuda.memory_reserved(0) / 1024 ** 3:.2f} GB")

                    print(f"[debugging] in <./pe/runner/pe.py> <opt_grad> {selected_syn_data=}, {type(selected_syn_data)=}")
                    if not self._population._keep_selected:
                        syn_data = Data.concat([selected_syn_data] + syn_data_list, metadata=self._priv_data.metadata)
                    else:
                        syn_data = Data.concat(syn_data_list, metadata=self._priv_data.metadata)
                    syn_data.data_frame.reset_index(drop=True, inplace=True)
                    syn_data.metadata.iteration = iteration
                    print(f"[debugging] syn_data after concat: {len(syn_data.data_frame)=} with {syn_data.metadata.iteration=}")

                    if save_checkpoint:
                        syn_data.save_checkpoint(checkpoint_path)
                    self._log_metrics(syn_data)
                    self.syn_data_preparation(syn_data=syn_data, test_size_ratio=0.0) # with self._syn_train_data_list, self._syn_dev_data set in the function.
                ######### generate synthetic samples if needed #########                

                logging_dir = self._exp_folder + f"/llm_priv_train/dpsgd_update/"
                if not os.path.exists(logging_dir):
                    os.makedirs(logging_dir)
                print(f"num_train_epochs={int((len(num_samples_schedule)-1)*val_sample_ratio)}, {epsilon=}, {delta=}")
                llm_sft_eval_metric, llm_trained_model, _ = opacus_dpsgd_fine_tune(
                    copy.deepcopy(_llm_trained_model), self._llm._tokenizer, self._priv_train_data, output_dir=logging_dir,
                    per_device_train_batch_size=4, 
                    num_train_epochs=int((len(num_samples_schedule)-1)*val_sample_ratio), learning_rate=lr*((1/4)*len(self._priv_train_data)*val_sample_ratio/2)*_scaler, # 4000 sample -- 5e-5, 400 sample 5e-6
                    save_steps=100, logging_steps=5,
                    eval_dataset=self._priv_eval_data,
                    epsilon=epsilon, delta=delta, val_sample_ratio=val_sample_ratio,
                    add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
                    stop_one_batch=True,
                )
                eval_loss, eval_acc = evaluate_model_on_private_data(
                    model=llm_trained_model,
                    tokenizer=self._llm._tokenizer,
                    # dataset=self._priv_dev_data,
                    dataset=self._priv_eval_data,
                    batch_size=8,
                    add_instruction=self._llm_add_instruction, instruction=self._population._api._random_api_prompt_config,
                )
                execution_logger.info(f"parameter-tuning, learing_rate_scale={_scaler}")
                print(f"DP-SGD fine-tuned model, evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
                execution_logger.info(f"LLM DP-SGD fine-tuned on base model, evaluation on test private data - Loss: {eval_loss:.5f}, Token Accuracy: {eval_acc:.5f}")
                if not folder in loss_points.keys():
                    loss_points[folder] = [(),()]
                loss_points[folder].append((eval_loss, eval_acc))

            json_path = os.path.join(self._exp_folder, "loss_points.json")
            with open(json_path, "w") as f:
                json.dump(loss_points, f)
            print(f"Saved loss_points to {json_path}")

        finally:
            self._clean_up_loggers()

        return syn_data

    def run_loss_distribution(
        self,
        num_samples_schedule,
        delta,
        epsilon=None,
        noise_multiplier=None,
        checkpoint_path=None,
        save_checkpoint=True,
        fraction_per_label_id=None,
        val_sample_ratio=0.2,
        lr=5e-5,
        _scaler=1.0,
        metric_inverse_epsilon=1E-6,
        noise_place='coefficient',
        prompt_type='RandomK',
        clip_or_normalize='normalize',
        noise_on_vote=True,
        sample_evolve=True,
        approx_strategy='opt',
        use_eigen=False,
        clip_norm=1.0,
    ):
        """Evaluate previously saved LLM checkpoints on the private evaluation data.

        No training happens here. The method evaluates the model currently held in ``self._llm._model`` and then
        every checkpoint found under the experiment folder, using ``evaluate_model_by_sample``, and dumps all
        results to ``<exp_folder>/per_sample_loss_acc.json``.

        Checkpoints are discovered from the direct child folders of ``self._exp_folder`` whose name starts with
        ``llm_`` and does not contain ``_syn_train``. Exactly one or two such folders must exist:

        * a folder without ``train`` in its name is read from ``<folder>/<len(num_samples_schedule)>/llm`` and
          stored under the key ``checkpoint-<len(num_samples_schedule)>``;
        * a folder with ``train`` in its name is read from ``<folder>`` (``<folder>/no_syn`` when the name
          contains ``_priv_train``) and stored under the key ``checkpoint-sgd``.

        The model evaluated first is stored under the key ``checkpoint-base``. Note that loading a checkpoint
        updates ``self._llm._model`` (in place for full ``model.safetensors`` weights, by replacement for PEFT
        adapters).

        :param num_samples_schedule: The schedule of the number of samples for each PE iteration. Only its length
            is used here, to locate the checkpoint of the last iteration
        :type num_samples_schedule: list[int]
        :param delta: The delta value of DP. Not used by this method
        :type delta: float
        :param epsilon: The epsilon value of DP, defaults to None. Not used by this method
        :type epsilon: float, optional
        :param noise_multiplier: The noise multiplier of the DP mechanism, defaults to None. Not used by this method
        :type noise_multiplier: float, optional
        :param checkpoint_path: The path to load and save the checkpoint, defaults to None. Not used by this method
        :type checkpoint_path: str, optional
        :param save_checkpoint: Whether to save the checkpoint, defaults to True. Not used by this method
        :type save_checkpoint: bool, optional
        :param fraction_per_label_id: The fraction of samples for each label id, defaults to None. Not used by
            this method
        :type fraction_per_label_id: list[float], optional
        :param val_sample_ratio: The sub-sampling ratio of private samples, defaults to 0.2. Not used by this method
        :type val_sample_ratio: float

        The remaining keyword arguments (``lr``, ``_scaler``, ``metric_inverse_epsilon``, ``noise_place``,
        ``prompt_type``, ``clip_or_normalize``, ``noise_on_vote``, ``sample_evolve``, ``approx_strategy``,
        ``use_eigen``, ``clip_norm``) are accepted but not read by this method.

        :return: None. The results are only written to ``per_sample_loss_acc.json``
        :rtype: None
        """

        def _evaluate_current_model():
            """Evaluate ``self._llm._model`` on the private evaluation data and return ``[loss, acc]``."""
            eval_loss, eval_acc = evaluate_model_by_sample(
                model=self._llm._model,
                tokenizer=self._llm._tokenizer,
                dataset=self._priv_eval_data,
                batch_size=8,
                add_instruction=self._llm_add_instruction,
                instruction=self._population._api._random_api_prompt_config,
            )
            return [eval_loss, eval_acc]

        def _load_and_evaluate(load_model_path):
            """Load the checkpoint stored in ``load_model_path`` into ``self._llm._model`` and evaluate it."""
            # List the directory once; a missing directory still raises here, as before.
            checkpoint_files = set(os.listdir(load_model_path))
            has_full_weights = 'model.safetensors' in checkpoint_files
            has_adapter = 'adapter_model.safetensors' in checkpoint_files
            if has_full_weights and not has_adapter:
                weights_path = os.path.join(load_model_path, 'model.safetensors')
                print(f"load from {weights_path}")
                self._llm._model.load_state_dict(load_file(weights_path))
            else:
                # Anything that is not a plain full-weight checkpoint is treated as a PEFT adapter
                # sitting on top of the original base model.
                base = transformers.AutoModelForCausalLM.from_pretrained(
                    self._original_llm_path, device_map=None, torch_dtype=torch.float16, trust_remote_code=True,
                )
                # PeftModel.from_pretrained reads the adapter config from the folder itself.
                self._llm._model = PeftModel.from_pretrained(base, load_model_path, is_trainable=True)
            return _evaluate_current_model()

        loss_points = {}

        # Evaluation of the model as it is before any checkpoint is loaded.
        loss_points['checkpoint-base'] = _evaluate_current_model()
        print("Base model, evaluation done")

        # Discover the folders that hold fine-tuned LLM checkpoints.
        child_folders = [f for f in os.listdir(self._exp_folder) if os.path.isdir(os.path.join(self._exp_folder, f))]
        print(f"{child_folders=}")
        logging_folders = [
            folder for folder in child_folders
            if folder.startswith('llm_') and '_syn_train' not in folder
        ]
        print(f"{logging_folders=}")
        assert len(logging_folders) == 1 or len(logging_folders) == 2, (
            f"expected 1 or 2 'llm_*' checkpoint folders under {self._exp_folder}, got {logging_folders}"
        )

        # Only the checkpoint of the last iteration is evaluated for per-iteration folders.
        iterations_to_evaluate = [len(num_samples_schedule)]

        for logging_folder in logging_folders:
            # os.path.join keeps the path valid whether or not self._exp_folder ends with a separator.
            folder_root = os.path.join(self._exp_folder, logging_folder)

            if 'train' not in logging_folder:
                for iteration in iterations_to_evaluate:
                    load_model_path = os.path.join(folder_root, str(iteration), 'llm')
                    loss_points[f'checkpoint-{iteration}'] = _load_and_evaluate(load_model_path)
                    print(f"[Iteration {iteration}], evaluation done")
            else:
                load_model_path = folder_root
                if "_priv_train" in logging_folder:
                    load_model_path = os.path.join(load_model_path, 'no_syn')
                # NOTE: every 'train' folder writes the same key, so a later folder overwrites an earlier one.
                loss_points['checkpoint-sgd'] = _load_and_evaluate(load_model_path)
                print("Model after (dp-)sgd training, evaluation done")

        json_path = os.path.join(self._exp_folder, "per_sample_loss_acc.json")
        with open(json_path, "w") as f:
            json.dump(loss_points, f)
        print(f"Saved per_sample_loss_acc to {json_path}")

        self._clean_up_loggers()

        return None

