import torch.nn as nn
from transformers import ViTImageProcessor
from einops import rearrange, repeat
from .dino import ViTModel


class DinoWrapper(nn.Module):
    """
    Dino v1 wrapper using huggingface transformer implementation.

    Combines the project-local ``ViTModel`` (imported from ``.dino``), its
    Hugging Face image processor, and a small MLP that embeds a 16-value
    camera vector into the backbone's hidden size. The camera embedding is
    handed to the backbone through its ``adaln_input`` argument.
    """
    def __init__(self, model_name: str, freeze: bool = True, local_files_only: bool = False):
        """
        Initialize DINO wrapper.

        Args:
            model_name: Hugging Face model name (e.g., "facebook/dino-vitb16") or local path
            freeze: Whether to freeze the backbone parameters (the camera
                embedder is never frozen here)
            local_files_only: If True, only use local cached files (no internet required)
        """
        super().__init__()
        self.model, self.processor = self._build_dino(model_name, local_files_only=local_files_only)
        hidden_size = self.model.config.hidden_size
        # Two-layer MLP: 16 camera values -> hidden_size -> hidden_size.
        self.camera_embedder = nn.Sequential(
            nn.Linear(16, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True)
        )
        if freeze:
            self._freeze()

    def forward(self, image, camera):
        """
        Encode images conditioned on per-view camera vectors.

        Args:
            image: Image tensor. A 5-D input is treated as ``(b, n, c, h, w)``
                and flattened to ``(b * n, c, h, w)``; any other rank is
                passed to the processor unchanged.
            camera: Camera tensor whose last dimension is 16. A 3-D embedding
                ``(b, n, d)`` is flattened to ``(b * n, d)``; an embedding
                that is already 2-D is used as is.

        Returns:
            ``last_hidden_state`` of the backbone output.
        """
        if image.ndim == 5:
            image = rearrange(image, 'b n c h w -> (b n) c h w')
        dtype = image.dtype
        # The processor is fed float32 with rescaling and resizing disabled;
        # the result is then moved to the backbone's device and cast back to
        # the caller's dtype.
        # NOTE(review): presumably the input is already scaled and sized as
        # the backbone expects -- confirm with the data pipeline.
        inputs = self.processor(
            images=image.float(),
            return_tensors="pt",
            do_rescale=False,
            do_resize=False,
        ).to(self.model.device).to(dtype)

        # Embed the camera and flatten the view axis so the embedding lines up
        # with the flattened image batch. A 2-D embedding has no view axis to
        # merge; any other rank still goes through rearrange (and fails there
        # if it is not 3-D).
        camera_embeddings = self.camera_embedder(camera)
        if camera_embeddings.ndim != 2:
            camera_embeddings = rearrange(camera_embeddings, 'b n d -> (b n) d')

        outputs = self.model(**inputs, adaln_input=camera_embeddings, interpolate_pos_encoding=True)
        return outputs.last_hidden_state

    def _freeze(self) -> None:
        """
        Put the backbone in eval mode and disable gradients for its parameters.

        Only ``self.model`` is affected; other submodules are left untouched.
        NOTE: ``nn.Module.train()`` recurses into children, so calling
        ``.train()`` on this wrapper (or a parent) later switches the backbone
        back to training mode while its parameters stay frozen.
        """
        print("======== Freezing DinoWrapper ========")
        self.model.eval()
        self.model.requires_grad_(False)

    @staticmethod
    def _load_pretrained(model_name: str, local_files_only: bool):
        """Load the (model, processor) pair for ``model_name`` without a pooling layer."""
        model = ViTModel.from_pretrained(
            model_name, local_files_only=local_files_only, add_pooling_layer=False
        )
        processor = ViTImageProcessor.from_pretrained(model_name, local_files_only=local_files_only)
        return model, processor

    @staticmethod
    def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5, local_files_only: bool = False):
        """
        Build DINO model from Hugging Face.

        Resolution order: an existing local directory, then (if requested)
        the local cache only, then the Hugging Face hub. On a network-
        unreachable error the local cache is tried as a fallback; on a proxy
        error the download is retried after a cooldown.

        Args:
            model_name: Hugging Face model name or local path
            proxy_error_retries: Number of retries for proxy errors
            proxy_error_cooldown: Cooldown time (seconds) between retries
            local_files_only: If True, only use local cached files

        Returns:
            Tuple ``(model, processor)``.

        Raises:
            RuntimeError: If the model cannot be loaded from the cache, either
                because ``local_files_only`` was set or because the network is
                unreachable and the cache fallback failed.
            requests.exceptions.RequestException, OSError: Any other download
                error, including a proxy error once retries are exhausted.
        """
        import os
        import time
        import requests

        if os.path.isdir(model_name):
            print(f"Loading DINO model from local directory: {model_name}")
            return DinoWrapper._load_pretrained(model_name, local_files_only=True)

        if local_files_only:
            print(f"Loading DINO model from local cache: {model_name}")
            try:
                loaded = DinoWrapper._load_pretrained(model_name, local_files_only=True)
            except Exception as e:
                raise RuntimeError(
                    f"Cannot load model '{model_name}' from cache. "
                    f"Please download the model first or check if it exists in cache. Error: {e}"
                ) from e
            print("✅ Successfully loaded model from cache")
            return loaded

        network_markers = ("Network is unreachable", "Failed to establish", "[Errno 101]")
        retries_left = proxy_error_retries
        # Loop instead of recursing so that a failed attempt's exception is
        # not chained into the traceback of the next attempt.
        while True:
            try:
                print(f"Loading DINO model from Hugging Face: {model_name}")
                return DinoWrapper._load_pretrained(model_name, local_files_only=False)
            except (requests.exceptions.RequestException, OSError) as err:
                # ConnectionError is a subclass of OSError, so it is covered.
                error_msg = str(err)

                if any(marker in error_msg for marker in network_markers):
                    print("⚠️  Network error: Cannot reach Hugging Face. Trying to load from cache...")
                    try:
                        loaded = DinoWrapper._load_pretrained(model_name, local_files_only=True)
                    except Exception as cache_err:
                        print(f"❌ Failed to load from cache: {cache_err}")
                        raise RuntimeError(
                            f"Cannot load model '{model_name}': Network unreachable and model not found in cache. "
                            f"Please download the model first or check your network connection."
                        ) from err
                    print("✅ Successfully loaded model from cache!")
                    return loaded

                is_proxy_error = (
                    "ProxyError" in error_msg
                    or isinstance(err, requests.exceptions.ProxyError)
                )
                if not is_proxy_error or retries_left <= 0:
                    # Bare raise keeps the original exception and traceback.
                    raise

                print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds... ({retries_left} retries left)")
                time.sleep(proxy_error_cooldown)
                retries_left -= 1
