import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Literal, Optional

from ..core import flatten_dict
from ..import_utils import is_bitsandbytes_available, is_torchvision_available


@dataclass
class DDPOConfig:
    """
    Configuration class for DDPOTrainer
    """

    # common parameters
    exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")]
    """the name of this experiment (by default is the file name without the extension name)"""
    run_name: Optional[str] = ""
    """Run name for wandb logging and checkpoint saving."""
    seed: int = 0
    """Seed value for random generations"""
    log_with: Optional[Literal["wandb", "tensorboard"]] = None
    """Log with either 'wandb' or 'tensorboard', check  https://huggingface.co/docs/accelerate/usage_guides/tracking for more details"""
    tracker_kwargs: dict = field(default_factory=dict)
    """Keyword arguments for the tracker (e.g. wandb_project)"""
    accelerator_kwargs: dict = field(default_factory=dict)
    """Keyword arguments for the accelerator"""
    project_kwargs: dict = field(default_factory=dict)
    """Keyword arguments for the accelerator project config (e.g. `logging_dir`)"""
    tracker_project_name: str = "trl"
    """Name of project to use for tracking"""
    logdir: str = "logs"
    """Top-level logging directory for checkpoint saving."""

    # hyperparameters
    num_epochs: int = 100
    """Number of epochs to train."""
    save_freq: int = 1
    """Number of epochs between saving model checkpoints."""
    num_checkpoint_limit: int = 5
    """Number of checkpoints to keep before overwriting old ones."""
    mixed_precision: str = "fp16"
    """Mixed precision training."""
    allow_tf32: bool = True
    """Allow tf32 on Ampere GPUs."""
    resume_from: Optional[str] = ""
    """Resume training from a checkpoint."""
    sample_num_steps: int = 50
    """Number of sampler inference steps."""
    sample_eta: float = 1.0
    """Eta parameter for the DDIM sampler."""
    sample_guidance_scale: float = 5.0
    """Classifier-free guidance weight."""
    sample_batch_size: int = 1
    """Batch size (per GPU!) to use for sampling."""
    sample_num_batches_per_epoch: int = 2
    """Number of batches to sample per epoch."""
    train_batch_size: int = 1
    """Batch size (per GPU!) to use for training."""
    train_use_8bit_adam: bool = False
    """Whether to use the 8bit Adam optimizer from bitsandbytes."""
    train_learning_rate: float = 3e-4
    """Learning rate."""
    train_adam_beta1: float = 0.9
    """Adam beta1."""
    train_adam_beta2: float = 0.999
    """Adam beta2."""
    train_adam_weight_decay: float = 1e-4
    """Adam weight decay."""
    train_adam_epsilon: float = 1e-8
    """Adam epsilon."""
    train_gradient_accumulation_steps: int = 1
    """Number of gradient accumulation steps."""
    train_max_grad_norm: float = 1.0
    """Maximum gradient norm for gradient clipping."""
    train_num_inner_epochs: int = 1
    """Number of inner epochs per outer epoch."""
    train_cfg: bool = True
    """Whether or not to use classifier-free guidance during training."""
    train_adv_clip_max: float = 5
    """Clip advantages to the range."""
    train_clip_range: float = 1e-4
    """The PPO clip range."""
    train_timestep_fraction: float = 1.0
    """The fraction of timesteps to train on."""
    per_prompt_stat_tracking: bool = False
    """Whether to track statistics for each prompt separately."""
    per_prompt_stat_tracking_buffer_size: int = 16
    """Number of reward values to store in the buffer for each prompt."""
    per_prompt_stat_tracking_min_count: int = 16
    """The minimum number of reward values to store in the buffer."""
    async_reward_computation: bool = False
    """Whether to compute rewards asynchronously."""
    max_workers: int = 2
    """The maximum number of workers to use for async reward computation."""
    negative_prompts: Optional[str] = ""
    """Comma-separated list of prompts to use as negative examples."""

    def to_dict(self):
        output_dict = {}
        for key, value in self.__dict__.items():
            output_dict[key] = value
        return flatten_dict(output_dict)

    def __post_init__(self):
        if self.log_with not in ["wandb", "tensorboard"]:
            warnings.warn(
                "Accelerator tracking only supports image logging if `log_with` is set to 'wandb' or 'tensorboard'."
            )

        if self.log_with == "wandb" and not is_torchvision_available():
            warnings.warn("Wandb image logging requires torchvision to be installed")

        if self.train_use_8bit_adam and not is_bitsandbytes_available():
            raise ImportError(
                "You need to install bitsandbytes to use 8bit Adam. "
                "You can install it with `pip install bitsandbytes`."
            )
