from typing import Optional
from src.model_loading.common.enums.model_enums import BitPrecision, ModelFamily, ModelSize, QuantizationMethod
from src.model_loading.common.models.identifier import ModelIdentifier


class ModelStringifier:
    """Handles conversion between ModelIdentifiers and their string representations.

    ``to_string`` emits lower-case tokens joined by ``_`` in this order::

        <family>_<size>[_chat][_<quantization>[_<N>bit]][_local]

    for example ``llama3_8b_chat_gptq_4bit_local``. ``AQLM_PV`` is written as
    ``aqlm-pv`` (hyphen) so that the token survives splitting on ``_``.

    ``from_string`` requires the family token to come first, but looks for the
    size, ``chat``, quantization, ``<N>bit`` and ``local`` tokens anywhere after
    it.

    Both directions read the same token tables (the ``_*_tokens`` helpers), so
    the family, size and quantization spellings used by the serializer and the
    parser cannot drift apart.
    """

    @staticmethod
    def _family_tokens() -> dict:
        """Return the table mapping a family token to its ModelFamily."""
        return {
            "tinyllama": ModelFamily.TINYLLAMA,
            "llama2": ModelFamily.LLAMA2,
            "llama3": ModelFamily.LLAMA3,
            "llama31": ModelFamily.LLAMA31,
            "llama32": ModelFamily.LLAMA32,
            "bloomz": ModelFamily.BLOOMZ,
            "gpt2": ModelFamily.GPT2,
            "opt": ModelFamily.OPT,
        }

    @staticmethod
    def _size_tokens() -> dict:
        """Return the table mapping a size token to its ModelSize."""
        return {
            "125m": ModelSize.M125,
            "350m": ModelSize.M350,
            "1.3b": ModelSize.B1_3,
            "2.7b": ModelSize.B2_7,
            "6.7b": ModelSize.B6_7,
            "13b": ModelSize.B13,
            "1b": ModelSize.B1,
            "3b": ModelSize.B3,
            "7b": ModelSize.B7,
            "8b": ModelSize.B8,
            "70b": ModelSize.B70,
        }

    @staticmethod
    def _quantization_tokens() -> dict:
        """Return the table mapping a quantization token to its QuantizationMethod."""
        return {
            "bnb": QuantizationMethod.BNB,
            "awq": QuantizationMethod.AWQ,
            "gptq": QuantizationMethod.GPTQ,
            "hqq": QuantizationMethod.HQQ,
            "quanto": QuantizationMethod.QUANTO,
            # Hyphenated on purpose: "_" is the token separator.
            "aqlm-pv": QuantizationMethod.AQLM_PV,
            "aqlm": QuantizationMethod.AQLM,
            "qoq": QuantizationMethod.QOQ,
            "quarot": QuantizationMethod.QUAROT,
        }

    @staticmethod
    def _parse_bits(components: list) -> "Optional[BitPrecision]":
        """Return the precision named by the first recognised ``<N>bit`` token.

        Tokens ending in ``bit`` whose prefix is not an integer (e.g. ``bit``,
        ``xbit``) or whose integer is not a supported width are skipped, so they
        cannot clobber a valid precision found elsewhere. Returns None when no
        token matches.
        """
        bit_tokens = {
            1: BitPrecision.INT1,
            2: BitPrecision.INT2,
            3: BitPrecision.INT3,
            4: BitPrecision.INT4,
            8: BitPrecision.INT8,
        }
        for component in components:
            if not component.endswith("bit"):
                continue
            try:
                precision = bit_tokens.get(int(component[:-3]))
            except ValueError:
                # Not a numeric width; keep looking.
                continue
            if precision is not None:
                return precision
        return None

    @staticmethod
    def to_string(model: ModelIdentifier) -> str:
        """Convert a ModelIdentifier to its string representation.

        Args:
            model: Identifier to serialize. Its ``family``, ``size``,
                ``is_chat``, ``quantization``, ``bits`` and ``is_local``
                attributes are read.

        Returns:
            The ``_``-joined token string described in the class docstring.
            A family or quantization method without a table entry falls back to
            its lower-cased enum name. A size without a table entry is omitted,
            and ``from_string`` then rejects the result.
        """
        family_names = {
            member: token for token, member in ModelStringifier._family_tokens().items()
        }
        size_names = {
            member: token for token, member in ModelStringifier._size_tokens().items()
        }
        quant_names = {
            member: token for token, member in ModelStringifier._quantization_tokens().items()
        }

        components = [family_names.get(model.family, model.family.name.lower())]

        size_token = size_names.get(model.size)
        if size_token is not None:
            components.append(size_token)

        if model.is_chat:
            components.append("chat")

        # Bits are only written alongside a quantization method.
        if model.quantization:
            components.append(
                quant_names.get(model.quantization, model.quantization.name.lower())
            )
            if model.bits:
                components.append(f"{model.bits.value}bit")

        if model.is_local:
            components.append("local")

        return "_".join(components)

    @staticmethod
    def from_string(model_str: str) -> Optional[ModelIdentifier]:
        """Convert a string representation to a ModelIdentifier using factory method.

        Matching is case-insensitive and ignores surrounding whitespace. Both
        ``aqlm_pv`` and ``aqlm-pv`` are accepted for ``AQLM_PV``.

        Args:
            model_str: String such as ``"llama3_8b_chat_gptq_4bit"``.

        Returns:
            The identifier built by ``ModelIdentifier.from_components``, or None
            when the first token is not a known family or no known size token
            follows it. Quantization and bits are None when no matching token is
            present. ``chat`` and ``local`` may appear anywhere.
        """
        # Rewrite the underscore spelling of AQLM-PV to the hyphenated token so
        # the split below does not break it into "aqlm" + "pv".
        normalized = model_str.strip().lower().replace("aqlm_pv", "aqlm-pv")
        components = normalized.split("_")

        family = ModelStringifier._family_tokens().get(components[0])
        if family is None:
            return None

        size_tokens = ModelStringifier._size_tokens()
        size = next((size_tokens[c] for c in components[1:] if c in size_tokens), None)
        if size is None:
            return None

        quant_tokens = ModelStringifier._quantization_tokens()
        quantization = next(
            (quant_tokens[c] for c in components if c in quant_tokens), None
        )

        return ModelIdentifier.from_components(
            family=family,
            size=size,
            is_chat="chat" in components,
            quantization=quantization,
            bits=ModelStringifier._parse_bits(components),
            is_local="local" in components,
        )
