import json
import logging
import os
import pathlib
import re
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union

import torch

from CLIP_utils.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from CLIP_utils.model import CLIP, convert_to_custom_text_state_dict, \
    resize_pos_embed, get_cast_dtype
from CLIP_utils.openai_models import load_openai_model
from CLIP_utils.pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, \
    list_pretrained_tags_by_model, download_pretrained_from_hf
from CLIP_utils.transform import image_transform, AugmentationCfg
from CLIP_utils.tokenizer import HFTokenizer, tokenize


# Prefix that marks a model name as a Hugging Face Hub reference.
HF_HUB_PREFIX = 'hf-hub:'
# Locations scanned for model architecture configs.
_MODEL_CONFIG_PATHS = [Path(__file__).parent.joinpath("model_configs/")]
# Registry (dictionary) mapping model_name -> model architecture config.
_MODEL_CONFIGS = {}


def _natural_key(string_):
    """
    Build a sort key that orders embedded numbers numerically ("m2" before "m10").

    The input is lower-cased and split into alternating text / digit-run chunks;
    digit chunks are turned into ints, everything else stays a string.

    Args:
        string_: The string to derive a natural sort key from.

    Returns:
        A list of str and int chunks usable as a ``key=`` value for sorting.
    """
    key = []
    for chunk in re.split(r'(\d+)', string_.lower()):
        if chunk.isdigit():
            key.append(int(chunk))
        else:
            key.append(chunk)
    return key


def _rescan_model_configs():
    """
    Scan all registered config locations and (re)populate the global model config registry.

    Every entry of ``_MODEL_CONFIG_PATHS`` may be a single ``.json`` file or a directory
    whose ``*.json`` files are collected. A config is registered under its file stem only
    if it contains the keys ``embed_dim``, ``vision_cfg`` and ``text_cfg``. Entries already
    in the registry are kept (and overwritten when a file with the same stem is found).
    The registry is finally rebuilt in natural-sort order of the model names.
    """
    global _MODEL_CONFIGS

    config_ext = ('.json',)
    config_files = []
    for config_path in _MODEL_CONFIG_PATHS:
        if config_path.is_file() and config_path.suffix in config_ext:
            config_files.append(config_path)
        elif config_path.is_dir():
            for ext in config_ext:
                config_files.extend(config_path.glob(f'*{ext}'))

    required_keys = ('embed_dim', 'vision_cfg', 'text_cfg')
    for cf in config_files:
        # Read JSON as UTF-8 explicitly rather than relying on the platform's default
        # encoding; the file handle is only needed while parsing.
        with open(cf, 'r', encoding='utf-8') as f:
            model_cfg = json.load(f)
        if all(a in model_cfg for a in required_keys):
            _MODEL_CONFIGS[cf.stem] = model_cfg

    _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))


_rescan_model_configs()  # populate the model config registry once, at import time


def list_models():
    """
    List the model architectures currently present in the config registry.

    Returns:
        A new list with the registered model architecture names, in registry order.
    """
    return list(_MODEL_CONFIGS)


def add_model_config(path):
    """
    Register an additional config location and refresh the model config registry.

    Args:
        path: A config file or a directory of config files; anything that is not
            already a ``Path`` is wrapped in one.
    """
    config_location = path if isinstance(path, Path) else Path(path)
    _MODEL_CONFIG_PATHS.append(config_location)
    _rescan_model_configs()


def get_model_config(model_name):
    """
    Look up the architecture config registered under ``model_name``.

    Args:
        model_name: Name of the model whose configuration is requested.

    Returns:
        A deep copy of the registered config dict (so callers may mutate it freely),
        or None when no config is registered under that name.
    """
    if model_name not in _MODEL_CONFIGS:
        return None
    return deepcopy(_MODEL_CONFIGS[model_name])


def get_tokenizer(model_name):
    """
    Get the appropriate tokenizer for the specified model.

    Names starting with ``HF_HUB_PREFIX`` are handed (without the prefix) to
    ``HFTokenizer``. Otherwise the model config is looked up: if its ``text_cfg``
    names an ``hf_tokenizer_name`` an ``HFTokenizer`` for it is created, else the
    default ``tokenize`` function is returned.

    Args:
        model_name: Name of the model to get a tokenizer for.

    Returns:
        A tokenizer: either an HFTokenizer instance or the ``tokenize`` function.

    Raises:
        RuntimeError: If no model config is registered under ``model_name``.
    """
    if model_name.startswith(HF_HUB_PREFIX):
        return HFTokenizer(model_name[len(HF_HUB_PREFIX):])

    config = get_model_config(model_name)
    if config is None:
        # get_model_config returns None for unknown names; fail with a clear message
        # instead of an opaque "'NoneType' object is not subscriptable".
        raise RuntimeError(f'Model config for {model_name} not found.')

    text_cfg = config['text_cfg']
    if 'hf_tokenizer_name' in text_cfg:
        return HFTokenizer(text_cfg['hf_tokenizer_name'])
    return tokenize


def load_state_dict(checkpoint_path: str, map_location='cpu'):
    """
    Load a state dictionary from a checkpoint file.

    If the loaded object is a dict with a ``'state_dict'`` entry, that entry is used;
    otherwise the loaded object itself is treated as the state dict. When the first
    key carries a ``'module.'`` prefix, that prefix is stripped from every key that
    has it; keys without the prefix are left untouched. An empty state dict is
    returned as-is.

    Args:
        checkpoint_path: Path to the checkpoint file.
        map_location: Device to map the loaded tensors to.

    Returns:
        The state dictionary from the checkpoint.
    """
    # NOTE(security): torch.load unpickles the file. Only load checkpoints from
    # trusted sources (or restrict loading with weights_only=True where supported).
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    prefix = 'module.'
    # Default of '' keeps an empty state dict from raising StopIteration.
    first_key = next(iter(state_dict), '')
    if first_key.startswith(prefix):
        # Strip exactly the prefix, and only where present, rather than blindly
        # chopping the first 7 characters off every key.
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
    return state_dict


def load_checkpoint(model, checkpoint_path, strict=True):
    """
    Read a checkpoint from disk and load its weights into ``model``.

    Args:
        model: The model that receives the weights.
        checkpoint_path: Path to the checkpoint file.
        strict: Forwarded to ``model.load_state_dict``; whether the checkpoint keys
            must match the model's keys exactly.

    Returns:
        Whatever ``model.load_state_dict`` returns (its report of incompatible keys).
    """
    weights = load_state_dict(checkpoint_path)

    # Old-format checkpoint: it has a top-level 'positional_embedding' entry while the
    # model has no such attribute, so convert it to the custom-text layout first.
    is_old_format = 'positional_embedding' in weights and not hasattr(model, 'positional_embedding')
    if is_old_format:
        weights = convert_to_custom_text_state_dict(weights)

    resize_pos_embed(weights, model)
    return model.load_state_dict(weights, strict=strict)


def create_model(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
):
    """
    Create a CLIP model with the specified options.

    Args:
        model_name: Name of the model architecture to create.
        pretrained: Name of pretrained weights to load or path to checkpoint file.
        precision: Numerical precision - one of 'fp32', 'fp16', 'bf16', 'pure_fp16', 'pure_bf16', or 'amp'.
        device: Device to create the model on.
        jit: Whether to apply torch.jit.script to the model.
        force_quick_gelu: Force use of QuickGELU activation.
        force_custom_text: Force use of CustomTextCLIP (not implemented here; raises ValueError).
        force_patch_dropout: Override the model config with a specific patch dropout value.
        force_image_size: Override the model config with a specific image size.
        pretrained_image: Whether to load pretrained weights into image tower (only for timm models).
        pretrained_hf: Whether to load pretrained weights for HuggingFace text models. Currently
            has no effect, because every custom-text / HF text model path raises ValueError.
        cache_dir: Directory to cache downloaded model weights.
        output_dict: Whether the model should return a dict of outputs.
        require_pretrained: Whether to raise an error if pretrained weights are not found.

    Returns:
        A CLIP model instance.

    Raises:
        RuntimeError: If the pretrained config, the model config or the pretrained weights
            cannot be found, or if ``require_pretrained`` is set and nothing was loaded.
        ValueError: If the pretrained config has no 'hf_hub' entry, or if the resolved
            config asks for a CoCa / custom-text model (not implemented).
        AssertionError: If ``pretrained_image`` is requested for a non-timm vision tower.
    """
    # Accept both the legacy 'hf_hub:' spelling this function used to hard-code and
    # 'hf-hub:', the value of the module-level HF_HUB_PREFIX that get_tokenizer checks,
    # so the two entry points agree on what counts as a hub model name.
    hf_hub_prefixes = ('hf_hub:', 'hf-hub:')
    has_hf_hub_prefix = model_name.startswith(hf_hub_prefixes)
    if has_hf_hub_prefix:
        # Get hf_hub path from pretrained config
        pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
        if not pretrained_cfg:
            raise RuntimeError(f"Pretrained config not found for {model_name}-{pretrained}")
        hf_hub = pretrained_cfg.get('hf_hub', '')
        if not hf_hub:
            raise ValueError(f"'hf_hub' not specified in pretrained config for {model_name}-{pretrained}")
        model_id = hf_hub.strip('/')

        checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
        config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)

        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        pretrained_cfg = config['preprocess_cfg']
        model_cfg = config['model_cfg']
    else:
        model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
        checkpoint_path = None
        pretrained_cfg = {}
        model_cfg = None

    if isinstance(device, str):
        device = torch.device(device)

    if pretrained and pretrained.lower() == 'openai':
        logging.info(f'Loading pretrained {model_name} from OpenAI.')
        model = load_openai_model(
            model_name,
            precision=precision,
            device=device,
            cache_dir=cache_dir,
        )
    else:
        model_cfg = model_cfg or get_model_config(model_name)
        if model_cfg is None:
            logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
            raise RuntimeError(f'Model config for {model_name} not found.')
        logging.info(f'Loaded {model_name} model config.')

        if force_quick_gelu:
            # override for use of QuickGELU on non-OpenAI transformer models
            model_cfg["quick_gelu"] = True

        if force_patch_dropout is not None:
            # override the default patch dropout value
            model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout

        if force_image_size is not None:
            # override model config's image size
            model_cfg["vision_cfg"]["image_size"] = force_image_size

        is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
        if pretrained_image:
            if not is_timm_model:
                # explicit raise (same exception type as before) so the check is not
                # stripped when Python runs with -O
                raise AssertionError('pretrained image towers currently only supported for timm models')
            # pretrained weight loading for timm models set via vision_cfg
            model_cfg['vision_cfg']['timm_model_pretrained'] = True

        # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
        cast_dtype = get_cast_dtype(precision)
        is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
        custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model

        if custom_text:
            # CoCa / CustomTextCLIP are not available in this module; fail explicitly.
            if "coca" in model_name:
                raise ValueError('Coca is not implemented')
            raise ValueError('CustomTextCLIP is not implemented')

        model = CLIP(**model_cfg, cast_dtype=cast_dtype)

        if precision in ("fp16", "bf16"):
            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
            # manual mixed precision that matches original OpenAI behaviour
            if is_timm_model:
                # FIXME this is a bit janky, create timm based model in low-precision and
                # then cast only LayerNormFp32 instances back to float32 so they don't break.
                # Why? The convert_weights_to_lp fn only works with native models.
                model.to(device=device, dtype=dtype)
                from CLIP_utils.transformer import LayerNormFp32

                def _convert_ln(m):
                    if isinstance(m, LayerNormFp32):
                        m.weight.data = m.weight.data.to(torch.float32)
                        m.bias.data = m.bias.data.to(torch.float32)
                model.apply(_convert_ln)
            else:
                # NOTE(review): convert_weights_to_lp was referenced without any import
                # (NameError on this path). It is assumed to live in CLIP_utils.model next
                # to get_cast_dtype -- confirm the location.
                from CLIP_utils.model import convert_weights_to_lp
                model.to(device=device)
                convert_weights_to_lp(model, dtype=dtype)
        elif precision in ("pure_fp16", "pure_bf16"):
            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
            model.to(device=device, dtype=dtype)
        else:
            model.to(device=device)

        pretrained_loaded = False
        if pretrained:
            checkpoint_path = ''
            pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
            if pretrained_cfg:
                checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
            elif os.path.exists(pretrained):
                # `pretrained` is a local checkpoint path rather than a known tag
                checkpoint_path = pretrained

            if checkpoint_path:
                logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
                load_checkpoint(model, checkpoint_path)
            else:
                error_str = (
                    f'Pretrained weights ({pretrained}) not found for model {model_name}. '
                    f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}).')
                logging.warning(error_str)
                raise RuntimeError(error_str)
            pretrained_loaded = True
        elif has_hf_hub_prefix:
            logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
            load_checkpoint(model, checkpoint_path)
            pretrained_loaded = True

        if require_pretrained and not pretrained_loaded:
            # callers of create_model_from_pretrained always expect pretrained weights
            raise RuntimeError(
                f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')

        # set image / mean metadata from pretrained_cfg if available, or use default
        model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
        model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD

    if output_dict and hasattr(model, "output_dict"):
        model.output_dict = True

    if jit:
        model = torch.jit.script(model)

    return model


def create_loss(args):
    """
    Build the loss object selected by the run arguments.

    ``args.distill`` selects DistillClipLoss; otherwise a model name containing "coca"
    (case-insensitive) selects CoCaLoss, which additionally receives the caption and
    contrastive loss weights; everything else gets ClipLoss. All three share the same
    distributed-training keyword arguments.

    Args:
        args: Argument namespace carrying the loss configuration attributes read below.

    Returns:
        An instance of the selected loss class.
    """
    # NOTE(review): DistillClipLoss / CoCaLoss / ClipLoss are not imported in the visible
    # part of this module -- confirm they are brought into scope elsewhere.
    if args.distill:
        loss_cls = DistillClipLoss
        extra_kwargs = {}
    elif "coca" in args.model.lower():
        loss_cls = CoCaLoss
        extra_kwargs = dict(
            caption_loss_weight=args.coca_caption_loss_weight,
            clip_loss_weight=args.coca_contrastive_loss_weight,
        )
    else:
        loss_cls = ClipLoss
        extra_kwargs = {}

    return loss_cls(
        **extra_kwargs,
        local_loss=args.local_loss,
        gather_with_grad=args.gather_with_grad,
        cache_labels=True,
        rank=args.rank,
        world_size=args.world_size,
        use_horovod=args.horovod,
    )


def create_model_and_transforms(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
):
    """
    Build a CLIP model together with its training and validation image transforms.

    The model is produced by ``create_model``; all model-related arguments are passed
    through unchanged. Normalization statistics come from ``image_mean`` / ``image_std``
    when given, otherwise from the attributes of the same name on ``model.visual``
    (None if absent).

    Args:
        model_name: Name of the model architecture to create.
        pretrained: Name of pretrained weights to load or path to checkpoint file.
        precision: Numerical precision for model computation.
        device: Device to create the model on.
        jit: Whether to apply torch.jit.script to the model.
        force_quick_gelu: Force use of QuickGELU activation.
        force_custom_text: Force use of CustomTextCLIP.
        force_patch_dropout: Override the model config with a specific patch dropout value.
        force_image_size: Override the model config with a specific image size.
        pretrained_image: Whether to load pretrained weights into image tower.
        pretrained_hf: Whether to load pretrained weights for HuggingFace text models.
        image_mean: Image normalization mean values.
        image_std: Image normalization std values.
        aug_cfg: Configuration for image augmentations (training transform only).
        cache_dir: Directory to cache downloaded model weights.
        output_dict: Whether the model should return a dict of outputs.

    Returns:
        A tuple (model, train_transform, val_transform).
    """
    model_kwargs = dict(
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_image_size=force_image_size,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
        output_dict=output_dict,
    )
    model = create_model(model_name, pretrained, **model_kwargs)

    # explicit arguments win; otherwise fall back to the stats stored on the visual tower
    image_mean = image_mean or getattr(model.visual, 'image_mean', None)
    image_std = image_std or getattr(model.visual, 'image_std', None)
    norm_kwargs = dict(mean=image_mean, std=image_std)

    preprocess_train = image_transform(
        model.visual.image_size,
        is_train=True,
        aug_cfg=aug_cfg,
        **norm_kwargs,
    )
    preprocess_val = image_transform(
        model.visual.image_size,
        is_train=False,
        **norm_kwargs,
    )
    return model, preprocess_train, preprocess_val


def create_model_from_pretrained(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        return_transform: bool = True,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        cache_dir: Optional[str] = None,
):
    """
    Build a CLIP model that must come with pretrained weights, optionally with its
    inference transform.

    The model is produced by ``create_model`` with ``require_pretrained=True``.
    Normalization statistics come from ``image_mean`` / ``image_std`` when given,
    otherwise from the attributes of the same name on ``model.visual`` (None if absent).

    Args:
        model_name: Name of the model architecture to create.
        pretrained: Name of pretrained weights to load or path to checkpoint file.
        precision: Numerical precision for model computation.
        device: Device to create the model on.
        jit: Whether to apply torch.jit.script to the model.
        force_quick_gelu: Force use of QuickGELU activation.
        force_custom_text: Force use of CustomTextCLIP.
        force_image_size: Override the model config with a specific image size.
        return_transform: Whether to return a preprocess transform along with the model.
        image_mean: Image normalization mean values.
        image_std: Image normalization std values.
        cache_dir: Directory to cache downloaded model weights.

    Returns:
        The model alone when ``return_transform`` is False, otherwise a tuple
        (model, preprocess_transform).
    """
    model_kwargs = dict(
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_image_size=force_image_size,
        cache_dir=cache_dir,
        require_pretrained=True,
    )
    model = create_model(model_name, pretrained, **model_kwargs)

    if return_transform:
        # explicit arguments win; otherwise fall back to the stats stored on the visual tower
        image_mean = image_mean or getattr(model.visual, 'image_mean', None)
        image_std = image_std or getattr(model.visual, 'image_std', None)
        preprocess = image_transform(
            model.visual.image_size,
            is_train=False,
            mean=image_mean,
            std=image_std,
        )
        return model, preprocess

    return model