import re
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from lib import nethook
import timm

# Create a wrapper class to match the expected interface
class ModelWrapper:
    """Adapt a vision model to a keyword-argument, dict-returning interface.

    The wrapped ``model`` must expose ``default_cfg`` (stored as ``config``)
    and either a ``transformers`` attribute whose items each have ``blocks``
    (PiT-style) or a ``blocks`` attribute (ViT-style).  Calling the wrapper
    with ``pixel_values=<input>`` returns ``{"logits": <model output>}``.

    Attributes:
        config: The wrapped model's ``default_cfg``.
        model: Inner ``torch.nn.Module`` that holds the original model as
            ``model.model`` and exposes ``blocks`` and ``patch_embed``.
    """

    def __init__(self, model):
        """Wrap ``model`` and install tuple-unpacking hooks on its LayerNorms.

        Args:
            model: Callable model exposing ``default_cfg`` plus ``blocks`` or
                ``transformers``.

        Raises:
            ValueError: If ``model`` has neither ``blocks`` nor
                ``transformers``.
        """
        self.config = model.default_cfg

        # Nested structure so callers can reach ``wrapper.model.blocks`` etc.
        class NestedModel(torch.nn.Module):
            """Holds the original model and exposes its transformer blocks."""

            def __init__(self, model):
                super().__init__()
                self.model = model

                if hasattr(model, 'transformers'):  # PiT
                    self.transformers = model.transformers
                    # Flatten the blocks of every stage into one plain list.
                    self.blocks = [
                        block
                        for stage in model.transformers
                        for block in stage.blocks
                    ]
                elif hasattr(model, 'blocks'):  # ViT
                    self.blocks = model.blocks
                else:
                    raise ValueError("Model must have either 'blocks' or 'transformers' attribute.")

                self.patch_embed = getattr(model, 'patch_embed', None)

            def forward(self, **kwargs):
                """Run the wrapped model on ``pixel_values``.

                Returns:
                    ``{"logits": output}``; when the wrapped model returns a
                    tuple, only its first element is used.

                Raises:
                    ValueError: If ``pixel_values`` is not supplied.
                """
                if 'pixel_values' not in kwargs:
                    raise ValueError("ViT models expect pixel_values")
                output = self.model(kwargs['pixel_values'])
                if isinstance(output, tuple):
                    output = output[0]  # Assuming first element contains logits
                return {"logits": output}

        self.model = NestedModel(model)

        # LayerNorm takes a single tensor; strip any tuple wrapping before
        # the layer runs.
        for module in self.model.modules():
            if isinstance(module, torch.nn.LayerNorm):
                module.register_forward_pre_hook(self._handle_tuple_input)

    def _handle_tuple_input(self, module, inputs):
        """Forward pre-hook reducing ``inputs`` to one positional argument.

        Args:
            module: The module about to run (unused).
            inputs: Positional arguments passed to the module.

        Returns:
            ``inputs`` unchanged when it is empty or not a tuple; otherwise a
            one-element tuple holding the first argument, unwrapped once more
            if that argument is itself a tuple.
        """
        if not inputs:  # nothing to unwrap
            return inputs
        first = inputs[0]
        if isinstance(first, tuple):
            return (first[0],)
        if isinstance(inputs, tuple):
            return (first,)
        return inputs

    def __call__(self, **kwargs):
        """Run the model on ``pixel_values`` and return ``{"logits": ...}``.

        Raises:
            ValueError: If ``pixel_values`` is not among the keyword
                arguments (it is the only supported input).
        """
        if 'pixel_values' not in kwargs:
            raise ValueError("No input provided")
        return self.model(**kwargs)

    def forward(self, x):
        """Positional-argument entry point; equivalent to ``self(pixel_values=x)``.

        Args:
            x: Model input, or a tuple whose first element is the input.

        Returns:
            ``{"logits": output}`` from the inner model.
        """
        if isinstance(x, tuple):
            x = x[0]
        # The inner forward only accepts keyword arguments, so the input has
        # to be passed by name; a positional call raises TypeError.
        return self.model(pixel_values=x)

class ModelAndTokenizer:
    """
    An object to hold on to (or automatically download and hold)
    a GPT-style language model and tokenizer.  Counts the number
    of layers.

    Attributes:
        tokenizer: The supplied tokenizer, or one loaded from ``model_name``.
        model: The supplied model, or one loaded from ``model_name``.
        layer_names: Names from ``model.named_modules()`` that match
            ``transformer.h.<n>``, ``transformer.layers.<n>``,
            ``gpt_neox.h.<n>`` or ``gpt_neox.layers.<n>``.
        num_layers: ``len(layer_names)``.
    """

    # Compiled once; selects the per-layer modules by their dotted name.
    _LAYER_NAME_RE = re.compile(r"^(transformer|gpt_neox)\.(h|layers)\.\d+$")

    def __init__(
        self,
        model_name=None,
        model=None,
        tokenizer=None,
        low_cpu_mem_usage=False,
        torch_dtype=None,
    ):
        """Store the given model/tokenizer, loading whichever one is missing.

        Args:
            model_name: Name passed to ``from_pretrained``; required when
                ``model`` or ``tokenizer`` is not supplied.
            model: Pre-built model; used as-is (not moved or frozen).
            tokenizer: Pre-built tokenizer.
            low_cpu_mem_usage: Forwarded to the model's ``from_pretrained``.
            torch_dtype: Forwarded to the model's ``from_pretrained``.

        Raises:
            ValueError: If something must be loaded but ``model_name`` is
                ``None``.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if model_name is None and (tokenizer is None or model is None):
            raise ValueError(
                "model_name is required when model or tokenizer is not provided"
            )

        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
        if model is None:
            model = AutoModelForCausalLM.from_pretrained(
                model_name, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype
            )
            nethook.set_requires_grad(False, model)
            model.eval()
            # Only move to GPU when one exists, so loading also works on
            # CPU-only machines.
            if torch.cuda.is_available():
                model.cuda()

        self.tokenizer = tokenizer
        self.model = model
        self.layer_names = [
            name
            for name, _ in model.named_modules()
            if self._LAYER_NAME_RE.match(name)
        ]
        self.num_layers = len(self.layer_names)

    def __repr__(self):
        """Return a summary with the model type, layer count and tokenizer type."""
        return (
            f"ModelAndTokenizer(model: {type(self.model).__name__} "
            f"[{self.num_layers} layers], "
            f"tokenizer: {type(self.tokenizer).__name__})"
        )
