import copy
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional, Union

import torch.nn as nn
from trl import ScriptArguments, ModelConfig, SFTConfig
from transformers.trainer import PreTrainedModel

from redflag.data_utils import DEFAULT_SAMPLER_CFG
from redflag.sft_trainer_utils import SeqLossWeighting

# TODO(review): the original note here was a bare "TODO" with no description;
# confirm what is still outstanding for these constants.

# Dataset names; used in this file as the default selection for both the
# training and the evaluation dataset arguments.
ALL_DATASETS = [
    "harm_complete",
    "harm_refuse",
    "benign",
]

# NOTE(review): not referenced in the visible part of this file; presumably
# default parameters for a geometric distribution -- confirm against its users.
GEOMETRIC_INIT = dict(probs=0.1)


@dataclass
class RedFlagScriptArguments(ScriptArguments):
    """Script-level arguments for RedFlag runs, extending TRL's ``ScriptArguments``.

    Mutable defaults are copied per instance. Returning the module-level
    ``ALL_DATASETS`` / ``DEFAULT_SAMPLER_CFG`` objects directly would let an
    in-place edit on one instance silently change the globals and every other
    instance that relies on the defaults.
    """

    # Names of the datasets to train on; defaults to a fresh copy of ALL_DATASETS.
    train_datasets: List[str] = field(default_factory=lambda: list(ALL_DATASETS))
    # NOTE(review): annotated as a dict, yet the default is a *list* of dataset
    # names. The annotation is left untouched because the consumer is not
    # visible here -- confirm which shape(s) the evaluation code accepts.
    eval_datasets: Dict[str, Any] = field(default_factory=lambda: list(ALL_DATASETS))
    # Sampler configuration; deep-copied so nested values are not shared either.
    insert_sampler: dict = field(default_factory=lambda: copy.deepcopy(DEFAULT_SAMPLER_CFG))
    # NOTE(review): presumably a probability in [0, 1] -- confirm with the data pipeline.
    drop_rf_proba: float = 0.0
    # NOTE(review): the three fields below are interpreted by the training
    # script, which is not visible here; their exact semantics are unconfirmed.
    resume_checkpoint: bool = True
    restart_count: int = 0
    load_dotenv: bool = False


@dataclass
class RedFlagModelConfig(ModelConfig):
    """TRL ``ModelConfig`` extended with two optional, unset-by-default fields."""

    # NOTE(review): presumably options for initialising embeddings -- the
    # expected keys are not visible in this file; confirm with the consumer.
    init_embed: dict | None = field(default=None)

    # Explicit pad token id; ``None`` leaves the value unspecified here.
    pad_token_id: int | None = field(default=None)


@dataclass
class AdvAttackConfig:
    """Settings container for the adversarial-attack component.

    NOTE(review): this class only stores values; how each one is used is
    decided by the attack code, which is not visible here. The per-field
    remarks below are inferred from names and should be confirmed.
    """

    # Presumably the number of attack iterations.
    iters: int = field(default=8)
    # Presumably the perturbation budget.
    eps: float = field(default=0.005)
    # Optimizer settings; a new dict is built for every instance.
    opt_config: dict = field(
        default_factory=lambda: {
            "type": "sign",
            "lr": 1e-4,
        }
    )
    prefill_length: int = field(default=24)
    # One of: "auto", "float32", "bfloat16", "float16" (not validated here).
    attack_precision: str = field(default="auto")
    maximize_loss_weight: float = field(default=0.0)


@dataclass
class EMAConfig:
    """Options for Exponential Moving Average (EMA) of weights during training.

    NOTE(review): the values are only stored here. The field names resemble a
    common warmup-style EMA schedule (decay bounded by a minimum and maximum,
    shaped by ``inv_gamma`` and ``power``); confirm against the EMA
    implementation that consumes this config.
    """

    # Master switch; EMA is off by default.
    use_ema: bool = field(default=False)

    # Decay value and its lower bound.
    ema_decay: float = field(default=0.999)
    ema_min_decay: float = field(default=0.0)

    # Step threshold for EMA updates (0 by default).
    ema_update_after_step: int = field(default=0)

    # Warmup toggle and its shape parameters.
    ema_use_warmup: bool = field(default=True)
    ema_inv_gamma: float = field(default=1.0)
    ema_power: float = field(default=2 / 3)


@dataclass
class RedFlagConfig(SFTConfig):
    """Trainer configuration for RedFlag training, extending TRL's ``SFTConfig``.

    Adds reference-model options, loss-mixing coefficients (``alpha_*``),
    cutoffs, and optional nested configs for adversarial attacks and EMA.

    NOTE(review): this class only declares fields; how each value is used is
    defined by the trainer, which is not visible here. Remarks on individual
    fields are inferred from names unless stated otherwise.
    """

    # --- Reference model -------------------------------------------------
    # Either an already-built model/module or a string identifier.
    ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None

    # --- Loss selection and weighting ------------------------------------
    utility_loss_mode: str = "kl"
    rf_xent_mode: str = "rf_only"
    alpha_rf_xent: float = 1.0
    alpha_kl_redflag: float = 1.0
    alpha_kl_ref: float = 1.0
    rf_xent_cutoff: float = 0.5
    # Presumably keyword arguments used when ``ref_model`` is given as a string.
    ref_model_init_kwargs: Optional[Dict] = None
    alpha_away_rf: float = 0.0
    away_rf_cutoff: float = -5.0

    # --- Reference-model sourcing ----------------------------------------
    use_base_model_as_ref: bool = True
    copy_base_model_as_ref: bool = False
    kl_fix: bool = True

    # A length is an integer count, not a flag. The previous ``bool | None``
    # annotation made argument parsers treat ``--max_length 512`` as a boolean
    # and reject it. The default (None) is unchanged.
    max_length: Optional[int] = None

    # NOTE(review): presumably a probability in [0, 1] -- confirm with the trainer.
    drop_prompt_attn_mask_prob: float = 0.0

    # --- Optional per-sequence loss weighting -----------------------------
    kl_weighting: Optional[SeqLossWeighting] = None
    xent_weighting: Optional[SeqLossWeighting] = None

    # --- Optional nested configs (both default to None) -------------------
    adv_attack: Optional[AdvAttackConfig] = field(default_factory=lambda: None)
    ema_config: Optional[EMAConfig] = field(default_factory=lambda: None)
