from typing import Union, List
import gc
import torch


def str2torch_dtype(dtype: str, default=torch.float32) -> torch.dtype:
    """Map a short precision tag to its PyTorch data type.

    Args:
        dtype: Precision tag; "fp16", "bf16" and "fp32" are the recognized values.
        default: Value handed back when the tag is not one of the recognized ones.

    Returns:
        The PyTorch data type matching the tag, or ``default`` for any other input.
    """
    # Ordered (tag, dtype) pairs; compared with == so any input type is tolerated.
    known_tags = (
        ("fp16", torch.float16),
        ("bf16", torch.bfloat16),
        ("fp32", torch.float32),
    )
    for tag, resolved in known_tags:
        if dtype == tag:
            return resolved
    return default


def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]],
                         dtype=torch.float32) -> list:
    """Cast the trainable parameters of one or more modules to ``dtype`` in place.

    Only parameters with ``requires_grad=True`` are touched; frozen parameters
    keep their current data type.

    Args:
        model: Single PyTorch module, or a list/tuple of modules, whose
            trainable parameters will be cast.
        dtype: Target data type for the trainable parameters.

    Returns:
        List of the trainable parameters (the same ``Parameter`` objects held
        by the modules), in iteration order, after casting.
    """
    # Wrap a bare module so both call styles share one loop. Tuples are
    # unpacked like lists: a tuple is not a module and has no .parameters().
    modules = list(model) if isinstance(model, (list, tuple)) else [model]
    training_params = []
    for module in modules:
        for param in module.parameters():
            # Frozen parameters are left in whatever precision they already use.
            if not param.requires_grad:
                continue
            # Rebind .data instead of replacing the Parameter, so the module
            # keeps holding the same Parameter object.
            param.data = param.to(dtype)
            training_params.append(param)
    return training_params


def flush_vram():
    """Release unused GPU memory by collecting garbage and emptying the CUDA cache."""
    # Collect first: tensors that are only kept alive by reference cycles must
    # be freed before their blocks can be released from the caching allocator.
    gc.collect()
    torch.cuda.empty_cache()


def quantization(model):
    """Quantize a model's weights in place using the optimum.quanto library.

    The model is quantized with ``qfloat8`` weights, frozen, and then
    ``flush_vram`` is called. A progress message is printed to stdout.
    Nothing is returned; the passed-in model object is modified.

    Args:
        model: PyTorch model whose weights will be quantized to float8 precision.
    """
    # Imported lazily so optimum.quanto is only required when quantization is used.
    from optimum.quanto import freeze, qfloat8, quantize
    quantization_type = qfloat8
    print("Quantizing transformer")
    quantize(model, weights=quantization_type)
    freeze(model)
    flush_vram()
