import copy
import os
from dataclasses import asdict, dataclass, fields
from typing import Any, Dict, List, Optional, Union

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

# Module-level logger from transformers' logging wrapper; not referenced by the
# definitions below.
logger = logging.get_logger(__name__)

@dataclass
class Aimv2VisionConfig(PretrainedConfig):
    """Configuration of the AIMv2 vision encoder.

    The annotated class attributes are dataclass fields whose class-level
    values act as defaults. Because the class defines its own ``__init__``,
    ``@dataclass`` does not generate one; keyword arguments given to the
    constructor are forwarded unchanged to ``PretrainedConfig.__init__``.

    Named presets live in the module-level ``AIMV2VISION_CONFIG`` registry and
    are loaded with :meth:`from_pretrained`.
    """

    model_type = 'aimv2_vision_model'
    base_config_key = "vision_config"

    hidden_size: int = 1024
    intermediate_size: int = 2816
    num_hidden_layers: int = 24
    num_attention_heads: int = 8
    num_channels: int = 3
    image_size: int = 224
    patch_size: int = 14
    rms_norm_eps: float = 1e-5
    attention_dropout: float = 0.0
    qkv_bias: bool = False
    mlp_bias: bool = False
    hidden_act: str = "silu"
    initializer_range: float = 0.02
    use_head: bool = True
    is_native: bool = False
    drop_path_rate: float = 0.0
    # Normalised to a fresh empty list in __init__ when left as None.
    skiplink_layers: Optional[List] = None
    # NOTE(review): defaults to flash_attention_2, which presumably needs the
    # flash-attn package when the model is built — confirm this is intended.
    _attn_implementation: str = "flash_attention_2"

    def __init__(self, **kwargs):
        """Forward all keywords to ``PretrainedConfig`` and normalise ``skiplink_layers``."""
        super().__init__(**kwargs)
        # Replace a None default (class-level or passed explicitly) with a
        # per-instance list so no two configs share a mutable value.
        if self.skiplink_layers is None:
            self.skiplink_layers = []

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
        """Build a config from a named preset in ``AIMV2VISION_CONFIG``.

        Args:
            pretrained_model_name_or_path: Preset name (a ``str`` or a
                path-like object whose ``os.fspath`` value is the name).
            **kwargs: Field values that override the preset's values.

        Returns:
            A new instance of ``cls``.

        Raises:
            NotImplementedError: If the name is not a registered preset;
                loading from a hub id or a local directory is not supported.
        """
        name = pretrained_model_name_or_path
        if isinstance(name, os.PathLike):
            name = os.fspath(name)
        preset = AIMV2VISION_CONFIG.get(name)
        if preset is None:
            raise NotImplementedError(
                f"Unknown Aimv2 vision config {name!r}; "
                f"available presets: {sorted(AIMV2VISION_CONFIG)}"
            )
        # asdict deep-copies field values, so the registered preset is never
        # mutated; kwargs (including skiplink_layers) override preset values.
        config_dict = asdict(preset)
        config_dict.update(kwargs)
        return cls(**config_dict)

    @classmethod
    def from_config(cls, config: 'Aimv2VisionConfig', **kwargs):
        """Return a new config copied from ``config``'s dataclass fields.

        ``kwargs`` override the copied values. ``config`` must be a dataclass
        instance (``dataclasses.asdict`` raises ``TypeError`` otherwise).
        """
        config_dict = asdict(config)
        config_dict.update(kwargs)
        return cls(**config_dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the config to a plain dict.

        The base-class ``to_dict`` copies the instance ``__dict__``. Dataclass
        fields that were never passed to ``__init__`` exist only as class-level
        defaults, so the base class would drop them. They are added back here
        so the serialized form carries every public field. Private fields
        (leading underscore, e.g. ``_attn_implementation``) are left to the
        base class.

        The returned dict does not alias mutable state of ``self``.
        """
        output = super().to_dict()
        for field_ in fields(self):
            if field_.name.startswith("_") or field_.name in output:
                continue
            output[field_.name] = copy.deepcopy(getattr(self, field_.name))
        return output


class Aimv2Config(PretrainedConfig):
    """Placeholder for a top-level AIMv2 configuration.

    Adds nothing to ``PretrainedConfig``: it defines no ``model_type``, no
    fields and no overridden methods.
    """


# Hyperparameters shared by both AIMv2-large patch-14 presets. They are listed
# in the same keyword order as the per-preset calls originally used, so each
# config's instance attributes are set in an unchanged order.
_AIMV2_LARGE_PATCH14_SHARED = dict(
    hidden_size=1024,
    intermediate_size=2816,
    num_hidden_layers=24,
    num_attention_heads=8,
    num_channels=3,
    patch_size=14,
    rms_norm_eps=1e-5,
    attention_dropout=0.0,
    qkv_bias=False,
    mlp_bias=False,
    hidden_act="silu",
    initializer_range=0.02,
    use_head=False,
)

# Named presets looked up by Aimv2VisionConfig.from_pretrained. The two
# entries differ only in image_size and is_native.
AIMV2VISION_CONFIG: Dict[str, Aimv2VisionConfig] = {
    'aimv2-large-patch14-448': Aimv2VisionConfig(
        image_size=448,
        **_AIMV2_LARGE_PATCH14_SHARED,
        is_native=False,
    ),
    'aimv2-large-patch14-native': Aimv2VisionConfig(
        image_size=224,
        **_AIMV2_LARGE_PATCH14_SHARED,
        is_native=True,
    ),
}
