import os
import warnings
from typing import List, Optional, Union

import torch

from CLIP_utils.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from CLIP_utils.model import build_model_from_openai_state_dict, get_cast_dtype, convert_weights_to_lp
from CLIP_utils.pretrained import *

"""
OpenAI pretrained model functions

Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.

Modifications in `CLIP_utils` compared to the original OpenAI implementation:
1. **Extended Model Management**:
   - Introduced functions like `list_openai_models()` to enumerate available models and manage pretrained models.
   - Enabled dynamic construction of models via `_build_vision_tower` and `_build_text_tower`, supporting customized configurations.

2. **Precision Management**:
   - Added support for precision selection (`fp32`, `fp16`, `bf16`) through the `precision` argument.
   - Introduced `convert_weights_to_lp()` to enable conversion to low-precision (e.g., `bf16`) for improved efficiency on hardware like TPUs or newer GPUs.

3. **Device Flexibility**:
   - Enhanced device management with explicit device assignment using `device` argument (e.g., `cpu`, `cuda`).

4. **Pretrained Model Support**:
   - Integrated support for managing pretrained model checkpoints via `pretrained.py`.
   - Cache management added via `cache_dir` argument, allowing custom directory paths for storing downloaded weights.

5. **Compatibility with Non-JIT Models**:
   - Extended compatibility to handle both JIT-based models and non-JIT state_dict-based models.
   - Always rebuilds a non-JIT model from the loaded weights (the JIT archive's state dict, or a saved state_dict checkpoint) using `build_model_from_openai_state_dict()`.

6. **Enhanced Preprocessing**:
   - Attached `OPENAI_DATASET_MEAN` and `OPENAI_DATASET_STD` to `model.visual` as `image_mean` and `image_std` for consistency with preprocessing pipelines.

7. **Resizing Positional Embeddings**:
   - Added functionality to resize positional embeddings (`resize_pos_embed()`) dynamically when the grid size of the positional embedding does not match the pretrained state_dict.

These enhancements make the CLIP model framework more flexible, extensible, and user-friendly for a variety of use cases.
"""


__all__ = ["list_openai_models", "load_openai_model"]


def list_openai_models() -> List[str]:
    """
    List the model names registered under the 'openai' pretrained tag.

    The lookup is delegated to `list_pretrained_models_by_tag`, filtering on
    the 'openai' tag.

    Returns:
        List[str]: Model names; each one is accepted as the `name` argument
        of `load_openai_model()`.
    """
    openai_tag = 'openai'
    model_names = list_pretrained_models_by_tag(openai_tag)
    return model_names


def load_openai_model(
        name: str,
        precision: Optional[str] = None,
        device: Optional[Union[str, torch.device]] = None,
        cache_dir: Optional[str] = None,
):
    """
    Load a CLIP model from OpenAI pretrained weights.

    The checkpoint is first tried as a TorchScript (JIT) archive; if that fails
    it is loaded as a plain saved state dict. In both cases a fresh non-JIT
    model is built from the weights, moved to `device`, and cast according to
    `precision`.

    Args:
        name: A model name listed by `list_openai_models()`, or the path to a model
              checkpoint containing the state_dict.
        precision: Model precision, can be 'fp32', 'fp16', or 'bf16'. If None, defaults
                  to 'fp32' when the target device is a CPU device, else 'fp16'.
        device: The device to put the loaded model on (string or `torch.device`).
                If None, uses CUDA if available, otherwise CPU.
        cache_dir: The directory to cache the downloaded model weights. If None, uses
                  the default cache directory.

    Returns:
        The loaded CLIP model, with `visual.image_mean` / `visual.image_std` set
        to the OpenAI dataset statistics.

    Raises:
        RuntimeError: If the specified model name is not found in the registry of
                     available models and is not a valid file path.
        KeyError: If the model cannot be built from the loaded weights and the
                 checkpoint has no wrapped "state_dict" entry to fall back on.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if precision is None:
        # Decide on the device *type*: a `torch.device` never compares equal to a
        # plain string, and strings such as 'cpu:0' would not match 'cpu' either,
        # so a direct `device == 'cpu'` test would wrongly select fp16 on CPU.
        precision = 'fp32' if torch.device(device).type == 'cpu' else 'fp16'

    # Resolve the checkpoint location: registry URL first, then local file path.
    pretrained_url = get_pretrained_url(name, 'openai')
    if pretrained_url:
        model_path = download_pretrained_from_url(pretrained_url, cache_dir=cache_dir)
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")

    try:
        # loading JIT archive
        jit_model = torch.jit.load(model_path, map_location="cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        jit_model = None
        state_dict = torch.load(model_path, map_location="cpu")

    # Build a non-jit model from the OpenAI jitted model state dict
    cast_dtype = get_cast_dtype(precision)
    # Explicit None check (rather than truthiness) so a loaded-but-empty object
    # is never confused with "JIT load succeeded".
    weights = jit_model.state_dict() if state_dict is None else state_dict
    try:
        model = build_model_from_openai_state_dict(weights, cast_dtype=cast_dtype)
    except KeyError:
        # The fallback only applies to a checkpoint dict that wraps the weights
        # under a "state_dict" entry; otherwise surface the original KeyError
        # instead of masking it with a secondary error.
        if not isinstance(state_dict, dict) or "state_dict" not in state_dict:
            raise
        # Drops the first 7 characters of every key — presumably a 'module.'
        # prefix left by a wrapped (e.g. DataParallel) model; TODO confirm.
        sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
        model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)

    # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
    model = model.to(device)
    # FIXME support pure fp16/bf16 precision modes
    if precision != 'fp16':
        model.float()
        if precision == 'bf16':
            # for bf16, convert back to low-precision
            convert_weights_to_lp(model, dtype=torch.bfloat16)

    # add mean / std attributes for consistency with OpenCLIP models
    model.visual.image_mean = OPENAI_DATASET_MEAN
    model.visual.image_std = OPENAI_DATASET_STD
    return model