import logging
from collections import OrderedDict

import torch
import torch.nn as nn

try:
    import timm    # new timm imports >= 0.8.1
    from timm.layers import Mlp, to_2tuple
    from timm.layers import RotAttentionPool2d
    from timm.layers import AttentionPool2d as AbsAttentionPool2d
except ImportError:
    timm = None

from CLIP_utils.misc import freeze_batch_norm_2d

"""
Timm Model Adapter

This module provides adapters for integrating `timm` models into the CLIP_utils architecture, 
allowing the use of various pretrained `timm` vision models as the vision tower in CLIP architectures.
Key features include:

1. **Support for Custom Pooling and Projection**: Supports attention-based pooling ('abs_attn' and 'rot_attn')
    and custom projection layer types such as 'linear' or 'mlp'.

2. **Flexible Model Configuration**: Supports dynamic reconfiguration of pretrained `timm` models,
    including modification of classification heads, pooling strategies, and dropout settings.

3. **Locking and Freezing Mechanisms**: Provides options to freeze parts or all of the model,
    including freezing batch normalization statistics, for transfer learning or fine-tuning.

4. **Gradient Checkpointing Support**: Provides functionality to enable gradient checkpointing
    to save memory during training.

5. **Dynamic Model Initialization**: Ensures compatibility with different versions of `timm`
    by handling breaking changes in the `timm` API.

Note:
- The `timm` library must be installed to use this module. If it is not installed, instantiating `TimmModel` raises a `RuntimeError`.
- Batch normalization freezing is supported by the `freeze_batch_norm_2d` function from the `CLIP_utils.misc` module.
"""

class TimmModel(nn.Module):
    """
    Timm model adapter for CLIP vision towers.

    Wraps a model built by ``timm.create_model`` (stored as ``self.trunk``) and
    appends an optional pooling / projection head (``self.head``), so the
    combined module can serve as the vision encoder in CLIP architectures.
    """

    def __init__(
            self,
            model_name,
            embed_dim,
            image_size=224,
            pool='avg',
            proj='linear',
            proj_bias=False,
            drop=0.,
            drop_path=None,
            patch_drop=None,
            pretrained=False,
    ):
        """
        Initialize a timm model adapter.

        Args:
            model_name (str): Name of the timm model to create.
            embed_dim (int): Output embedding dimension.
            image_size (int or tuple): Input image size. Only normalized to a
                2-tuple and stored on ``self.image_size``; this class does not
                resize inputs itself.
            pool (str): Type of pooling ('avg', 'max', 'abs_attn', 'rot_attn', or '').
                Values other than 'abs_attn' / 'rot_attn' are forwarded to timm
                as ``global_pool``.
            proj (str): Type of projection ('linear', 'mlp', or '').
            proj_bias (bool): Whether to include bias in projection layers.
            drop (float): Dropout rate used inside the projection head.
            drop_path (float, optional): Drop path rate, forwarded to timm when set.
            patch_drop (float, optional): Patch dropout rate, forwarded to timm when set.
            pretrained (bool): Whether to load pretrained weights.

        Raises:
            RuntimeError: If timm is not installed.
            AssertionError: If ``proj`` is not a known projection type, or if an
                attention pool is requested for a model whose config declares
                no ``pool_size`` (i.e. no 2D feature map).
        """
        super().__init__()
        if timm is None:
            raise RuntimeError("Please `pip install timm` to use timm models.")

        # Fail fast, before the (possibly expensive) trunk is created. Raised
        # explicitly instead of via `assert` so the check survives `python -O`;
        # AssertionError is kept as the type for backward compatibility.
        if proj and proj not in ('linear', 'mlp'):
            raise AssertionError(f'Unknown projection type {proj}.')

        self.image_size = to_2tuple(image_size)

        # setup kwargs that may not be common across all models
        timm_kwargs = {}
        if drop_path is not None:
            timm_kwargs['drop_path_rate'] = drop_path
        if patch_drop is not None:
            timm_kwargs['patch_drop_rate'] = patch_drop

        custom_pool = pool in ('abs_attn', 'rot_attn')
        feat_size = None  # only needed (and only looked up) for attention pooling
        if not proj and not custom_pool:
            # use network classifier head as projection if no proj specified and no custom pooling used
            self.trunk = timm.create_model(
                model_name,
                num_classes=embed_dim,
                global_pool=pool,
                pretrained=pretrained,
                **timm_kwargs,
            )
            prev_chs = embed_dim
        else:
            self.trunk = timm.create_model(
                model_name,
                pretrained=pretrained,
                **timm_kwargs,
            )
            feat_size = self.trunk.default_cfg.get('pool_size', None)
            if custom_pool:
                # attention pooling operates on a 2D feature map; a missing
                # 'pool_size' is treated as "model has no 2D features"
                if not feat_size:
                    raise AssertionError(
                        f"Pooling '{pool}' requires 2D features, but model "
                        f"'{model_name}' declares no 'pool_size' in its default_cfg."
                    )
                # if attn pooling used, remove both classifier and default pool
                self.trunk.reset_classifier(0, global_pool='')
            else:
                # reset global pool if pool config set, otherwise leave as network default
                reset_kwargs = dict(global_pool=pool) if pool else {}
                self.trunk.reset_classifier(0, **reset_kwargs)
            prev_chs = self.trunk.num_features

        head_layers = OrderedDict()

        # Add custom pooling to head
        if pool == 'abs_attn':
            head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
            prev_chs = embed_dim
        elif pool == 'rot_attn':
            head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
            prev_chs = embed_dim

        # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
        if proj == 'linear':
            head_layers['drop'] = nn.Dropout(drop)
            head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
        elif proj == 'mlp':
            head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))
        # a falsy proj adds no projection layer (validated at the top of __init__)

        self.head = nn.Sequential(head_layers)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """
        Lock trunk parameters to prevent training specific parts of the model.

        Only ``self.trunk`` is affected; ``self.head`` is left untouched.

        Args:
            unlocked_groups (int): Number of trailing layer groups to leave
                unlocked for training. 0 means lock the entire trunk.
            freeze_bn_stats (bool): Whether to also freeze batch normalization
                statistics of the locked part via ``freeze_batch_norm_2d``.

        Raises:
            ValueError: If ``unlocked_groups`` is negative.
            RuntimeError: If partial freezing is requested but the required timm
                functionality is not available.
        """
        # Reject negative counts up front; otherwise every group would be
        # frozen and the loop would then fail on a non-existent group id.
        if unlocked_groups and unlocked_groups < 0:
            raise ValueError(f'unlocked_groups must be >= 0, got {unlocked_groups}.')

        if not unlocked_groups:
            # lock full model
            for param in self.trunk.parameters():
                param.requires_grad = False
            if freeze_bn_stats:
                freeze_batch_norm_2d(self.trunk)
            return

        # NOTE: partial freeze relies on timm's parameter grouping helpers.
        # Prefer the public `timm.models` location and fall back to the older
        # `timm.models.helpers` module for releases that only expose it there.
        try:
            from timm.models import group_parameters, group_modules
        except ImportError:
            try:
                from timm.models.helpers import group_parameters, group_modules
            except ImportError as err:
                raise RuntimeError(
                    'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`'
                ) from err

        matcher = self.trunk.group_matcher()
        # gparams is indexed by layer-group id; each entry lists parameter names
        gparams = group_parameters(self.trunk, matcher)
        max_layer_id = max(gparams.keys()) - unlocked_groups
        for group_idx in range(max_layer_id + 1):
            for param_name in gparams[group_idx]:
                self.trunk.get_parameter(param_name).requires_grad = False

        if freeze_bn_stats:
            # with reverse=True the mapping is module name -> layer-group id
            gmodules = group_modules(self.trunk, matcher, reverse=True)
            frozen_names = {name for name, layer_id in gmodules.items() if layer_id <= max_layer_id}
            # Skip the call when no module falls in a locked group, so an empty
            # match set can never be mistaken for "match everything".
            # NOTE(review): confirm how freeze_batch_norm_2d treats an empty match.
            if frozen_names:
                freeze_batch_norm_2d(self.trunk, frozen_names)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """
        Enable or disable gradient checkpointing on the trunk.

        Args:
            enable (bool): Whether to enable gradient checkpointing.

        Note:
            This is deliberately best-effort: if the trunk does not support
            gradient checkpointing (the call raises), a warning is logged and
            the request is ignored.
        """
        try:
            self.trunk.set_grad_checkpointing(enable)
        except Exception:
            logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')

    def forward(self, x):
        """
        Forward pass: run the trunk, then the pooling / projection head.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, channels, height, width].

        Returns:
            torch.Tensor: Head output; typically of shape [batch_size, embed_dim].
        """
        return self.head(self.trunk(x))