from dataclasses import dataclass, field
from typing import Callable, Optional, Tuple, List
import types

import torch
import torch.nn.functional as F

from megatron.core.transformer import TransformerConfig


@dataclass
class GpatchTransformerConfig(TransformerConfig):
    """``TransformerConfig`` extended with additional, fully-defaulted fields.

    Every field declared here carries a default, so an instance can be built
    by supplying only what the parent ``TransformerConfig`` requires.  Fields
    whose default is ``None`` are annotated ``Optional[...]``; ``None`` is the
    "not set" value for them.

    NOTE(review): the group headings below are derived from the field-name
    prefixes only.  How each value is consumed is not visible in this file --
    confirm against the code that reads the config before relying on them.
    """

    # ------------------------------------------------------------------
    # Model identification (unset by default).
    # ------------------------------------------------------------------
    model_arch: Optional[str] = None
    hf_vocab_size: Optional[int] = None

    # ------------------------------------------------------------------
    # ``num_experts`` / ``moe_*`` fields.
    # ------------------------------------------------------------------
    num_experts: Optional[int] = None
    moe_norm_topk_prob: bool = False
    moe_norm_topk_prob_eps: float = 0.0

    # ------------------------------------------------------------------
    # ``dpo*`` fields.
    # ------------------------------------------------------------------
    dpo: bool = False
    dpo_reward_models_cnt: int = 0
    dpo_policy_ref_model_cnt: int = 2

    # ------------------------------------------------------------------
    # ``rm_*`` fields (presumably reward-model losses -- TODO confirm).
    # The three list-valued coefficients are unset (``None``) by default.
    # ------------------------------------------------------------------
    rm_use_triplet_loss: bool = False
    rm_golden_margin: float = 0.5
    rm_triplet_coef: float = 0.5
    rm_triplet_focal_coef: Optional[List[float]] = None
    rm_focal_loss_coef: Optional[List[float]] = None
    rm_focal_loss_ranking_coef: Optional[List[float]] = None
    rm_num_attributes: int = 1

    # ------------------------------------------------------------------
    # Batch sizes and ``ppo_*`` / ``grpo_*`` / ``dapo_*`` fields.
    # ``-1`` is the default for both batch sizes; presumably a "not
    # configured" placeholder -- verify against the caller.
    # ------------------------------------------------------------------
    global_batch_size: int = -1
    micro_batch_size: int = -1
    ppo_loss_clip_val: float = 0.2
    ppo_enable_standardization: bool = False
    ppo_entropy_bonus: float = 0.0
    ppo_ratio_eps: float = 0.2
    ppo_dual_clip_ratio_c: Optional[float] = None
    ppo_to_offload_adam_states: bool = False
    ppo_reward_clip_val: float = 1.0
    ppo_value_truncate_head: bool = False
    use_grpo: bool = False
    grpo_advantage_epsilon: float = 1e-6
    grpo_kl_loss_beta: float = 1e-3
    rm_head_arch: str = 'single_layer'
    ppo_grpo_reward_type: str = "rm_only"
    ppo_rm_reward_alpha: float = 1.0
    ppo_rule_reward_beta: float = 1.0
    ppo_sampling_repeat: int = 1
    ppo_clip_ratio_low: Optional[float] = None
    ppo_clip_ratio_high: Optional[float] = None
    use_gspo_loss: bool = False
    ppo_clamp_kl_val: Optional[float] = None
    ppo_logps_ratio_clamp: Optional[float] = None
    ppo_resp_seq_len: int = 512
    dapo_overlong_penalty: bool = False
    dapo_overlong_buffer_len: int = 0
    dapo_overlong_penalty_factor: float = 0.0

    # ------------------------------------------------------------------
    # ``*gen_rm*`` fields.
    # ------------------------------------------------------------------
    use_gen_rm: bool = False
    ppo_gen_rm_repeat: int = 1

    # ------------------------------------------------------------------
    # Rotary-embedding related fields (going by the names -- TODO confirm).
    # ------------------------------------------------------------------
    packed_freqs: bool = False
    # ``default_factory`` gives every instance its own fresh empty list.
    mrope_section: List[int] = field(default_factory=list)

    # ------------------------------------------------------------------
    # Context-parallel field (unset by default).
    # ------------------------------------------------------------------
    context_parallel_heads_kv_stride: Optional[int] = None

    # ------------------------------------------------------------------
    # ``lora_*`` fields.
    # ------------------------------------------------------------------
    lora_r: int = 8
    lora_alpha: int = 16


@dataclass
class Gemma3TransformerConfig(GpatchTransformerConfig):
    """``GpatchTransformerConfig`` extended with additional defaulted fields.

    All fields declared here have defaults.  ``mm_projector_cls`` defaults to
    ``None`` (no projector class supplied) and is therefore ``Optional``.

    NOTE(review): the grouping comments below are inferred from field names;
    the consumers of these values are not visible in this file -- confirm.
    """

    mm_projector_cls: Optional[Callable] = None
    """Class to use for the mm projector."""

    # Image-input fields (presumably pixel sizes and a token count per
    # image -- TODO confirm units against the consumer).
    image_size: int = 896
    patch_size: int = 14
    mm_tokens_per_image: int = 256

    # Embedding / rotary / sliding-window attention fields.
    embed_scale: float = 1.0
    rope_local_base_freq: float = 10000.0
    sliding_window_pattern: int = 6
    sliding_window: int = 1024
    query_pre_attn_scalar: int = 256
