import torch
import torch.nn as nn
import timm
import torchvision.transforms as T
from abc import ABC, abstractmethod

try:
    from transformers import AutoModel, AutoTokenizer, AutoConfig
except ImportError:
    print("Warning: 'transformers' library not found. NLPEncoderManager will not be available.")
    print("Please install it: pip install transformers")
    AutoModel = None

# Expected width of the pooled feature vector for each (encoder family, size)
# key.  Keys are lowercase strings; sizes for resnet are numeric strings
# ("50", not 50).  The table is a hint: the vision manager prefers the width it
# probes from the instantiated model and falls back to this value only when
# that probe fails.
FEATURE_DIM_MAP = {
    (family, size): dim
    for family, dims_by_size in {
        "vit": {"tiny": 192, "small": 384, "base": 768, "large": 1024, "huge": 1280},
        "dinov2": {"small": 384, "base": 768, "large": 1024, "giant": 1536},
        "resnet": {"18": 512, "34": 512, "50": 2048, "101": 2048},
        "efficientnet": {"b0": 1280, "b1": 1280},
        "convnext": {"tiny": 768, "small": 768, "base": 1024, "large": 1536},
    }.items()
    for size, dim in dims_by_size.items()
}

class BaseEncoderManager(ABC):
    """Shared configuration and lifecycle for encoder managers.

    Reads the ``"encoder"`` section of ``cfg`` and calls the subclass hook
    ``_load_model``, which must assign ``self.model`` (and should set
    ``self.feature_dim``). It then freezes all parameters unless ``tune`` is
    enabled, and puts the model in eval mode.

    Recognised ``cfg["encoder"]`` keys (defaults in parentheses):
        type ("vit"), from_pretrained (True), tune (False), precompute (True),
        use_cache (True), precompute_batch_size (256).
    """

    def __init__(self, cfg, device):
        """
        Args:
            cfg: Mapping with an optional ``"encoder"`` sub-mapping. A missing
                or null ``"encoder"`` entry is treated as an empty mapping.
            device: Device handed to subclasses for placing the model/inputs.

        Raises:
            RuntimeError: If ``_load_model`` did not assign ``self.model``.
        """
        self.cfg = cfg
        self.device = device
        # ``or {}`` also covers an explicit null section (e.g. a bare
        # ``encoder:`` line in YAML), which ``.get(key, {})`` passes through
        # as None and which would crash every ``self.ecfg.get`` below.
        self.ecfg = cfg.get("encoder") or {}

        self.encoder_type    = self.ecfg.get("type", "vit").lower()
        self.from_pretrained = bool(self.ecfg.get("from_pretrained", True))
        self.tune            = bool(self.ecfg.get("tune", False))
        self.precompute      = bool(self.ecfg.get("precompute", True))
        # NOTE: attribute is named save_cache but is driven by the "use_cache" key.
        self.save_cache      = bool(self.ecfg.get("use_cache", True))
        self.precompute_batch_size = int(self.ecfg.get("precompute_batch_size", 256))

        self.model = None
        self.feature_dim = 0

        self._load_model()
        if self.model is None:
            raise RuntimeError(
                f"{type(self).__name__}._load_model() did not set self.model"
            )

        # Frozen encoders never receive gradients; this is also what makes
        # feature precomputation valid (see can_precompute).
        if not self.tune:
            for p in self.model.parameters():
                p.requires_grad = False

        self.model.eval()

    @abstractmethod
    def _load_model(self):
        """Build ``self.model`` on ``self.device`` and set ``self.feature_dim``."""
        raise NotImplementedError

    @abstractmethod
    def get_transform(self):
        """Return the callable that turns raw inputs into model-ready inputs."""
        raise NotImplementedError

    @abstractmethod
    @torch.no_grad()
    def encode(self, x):
        """Return one feature vector per input in the batch ``x``.

        Subclasses must override this. Decorators on this abstract method are
        not inherited by overrides.
        """
        raise NotImplementedError

    def can_precompute(self):
        """True when features may be cached: the encoder is frozen and precompute is on."""
        return (not self.tune) and self.precompute

    def get_info(self):
        """Return a shallow copy of the encoder config plus the resolved ``feature_dim``."""
        info = self.ecfg.copy()
        info["feature_dim"] = self.feature_dim
        return info

class VisionEncoderManager(BaseEncoderManager):
    """Image encoder backed by a ``timm`` model with its classifier removed.

    Extra ``cfg["encoder"]`` keys (defaults in parentheses):
        size ("base") -- size token used to build the timm model name; numeric
            values (e.g. ``50`` for resnet) are accepted and stringified.
        patch (14 for "dinov2", otherwise 16) -- patch-size token for ViTs.
    """

    def _load_model(self):
        """Create the timm model and determine ``self.feature_dim``."""
        # str() so YAML numbers such as ``size: 50`` work and match the
        # string keys used in FEATURE_DIM_MAP.
        self.model_size = str(self.ecfg.get("size", "base")).lower()
        # timm publishes DINOv2 weights only with 14-pixel patches
        # (``vit_<size>_patch14_dinov2``); a 16 default names a missing model.
        default_patch = 14 if self.encoder_type == "dinov2" else 16
        self.patch_size = int(self.ecfg.get("patch", default_patch))

        model_name = self._resolve_model_name(
            self.encoder_type, self.model_size, self.patch_size
        )

        self.model = timm.create_model(
            model_name,
            pretrained=self.from_pretrained,
            num_classes=0,  # drop the classifier head so the model returns features
        ).to(self.device)

        # The table value is only a fallback for _infer_feature_dim; the width
        # measured from the real model always wins when the probe succeeds.
        self.feature_dim = FEATURE_DIM_MAP.get((self.encoder_type, self.model_size))
        self.feature_dim = self._infer_feature_dim()

    def _resolve_model_name(self, et, sz, pt):
        """Map (encoder type, size, patch) to a timm model name.

        Raises:
            ValueError: If ``et`` is not a supported vision encoder type.
        """
        if et == "vit":
            return f"vit_{sz}_patch{pt}_224"

        if et == "dinov2":
            return f"vit_{sz}_patch{pt}_dinov2"

        if et == "resnet":
            return f"resnet{sz}"

        if et == "efficientnet":
            return f"efficientnet_{sz}"

        if et == "convnext":
            return f"convnext_{sz}"

        raise ValueError(f"Unknown encoder_type={et} in VisionEncoderManager")

    def _model_cfg(self):
        """Return the model's timm ``default_cfg`` mapping, or ``{}`` if it has none."""
        return getattr(self.model, "default_cfg", None) or {}

    def _input_hw(self):
        """Return ``(height, width)`` from ``default_cfg["input_size"]``, else 224x224."""
        input_size = self._model_cfg().get("input_size")
        if input_size:
            _, h, w = input_size
            return int(h), int(w)
        return 224, 224

    @torch.no_grad()
    def _infer_feature_dim(self):
        """Return the model's output width by running an all-zeros dummy batch.

        If the forward pass fails, returns the current ``self.feature_dim``
        (the FEATURE_DIM_MAP hint) when it is set.

        Raises:
            RuntimeError: If the probe fails and no fallback width is known.
        """
        self.model.eval()
        h, w = self._input_hw()

        try:
            x = torch.zeros(2, 3, h, w, device=self.device)
            y = self.model(x)
            return int(y.shape[-1])
        except Exception as e:
            print(f"Warning: Failed to infer feature dim automatically. {e}")
            if self.feature_dim:
                return self.feature_dim
            raise RuntimeError("Could not infer feature_dim.") from e

    def get_transform(self):
        """Build the eval-time image -> normalized tensor pipeline.

        The pipeline resizes, center-crops to the model's input size, converts
        to a tensor and normalizes. It uses the model's declared mean/std, or
        ImageNet statistics when the model has none.
        """
        mcfg = self._model_cfg()
        h, w = self._input_hw()
        mean = mcfg.get("mean", (0.485, 0.456, 0.406))
        std = mcfg.get("std", (0.229, 0.224, 0.225))

        return T.Compose([
            # Resizing the shorter side to the larger target edge guarantees
            # the (h, w) crop fits without padding; for square inputs this is
            # the same as Resize(h).
            T.Resize(max(h, w), interpolation=T.InterpolationMode.BICUBIC, antialias=True),
            T.CenterCrop((h, w)),
            T.ToTensor(),
            T.Normalize(mean, std),
        ])

    @torch.no_grad()
    def encode(self, x):
        """Move the image batch ``x`` to ``self.device`` and return the model's features."""
        self.model.eval()
        return self.model(x.to(self.device))

class NLPEncoderManager(BaseEncoderManager):
    """Text encoder backed by a Hugging Face ``transformers`` ``AutoModel``.

    Extra ``cfg["encoder"]`` keys (defaults in parentheses):
        model_name_or_path ("bert-base-uncased") -- hub id or local path used
            for the model, its config and the tokenizer.
        max_length (128) -- fixed length used for padding and truncation.

    ``encode`` returns the first-token hidden state (``[CLS]`` for BERT-style
    models). When the encoder ``type`` is ``"sbert"`` it instead returns the
    attention-masked mean of all token states.
    """

    def __init__(self, cfg, device):
        """Fail fast when ``transformers`` is unavailable, then run the base setup.

        Raises:
            ImportError: If the module-level ``transformers`` import failed.
        """
        if AutoModel is None:
            raise ImportError("NLPEncoderManager requires 'transformers' library.")
        # Declared before super().__init__, which calls _load_model and
        # assigns the real tokenizer.
        self.tokenizer = None
        super().__init__(cfg, device)

    def _load_model(self):
        """Load the model (pretrained weights or a fresh init) and its tokenizer."""
        self.model_name_or_path = self.ecfg.get("model_name_or_path", "bert-base-uncased")

        if self.from_pretrained:
            self.model = AutoModel.from_pretrained(self.model_name_or_path).to(self.device)
        else:
            # Same architecture as the named checkpoint, built via from_config
            # instead of loading the checkpoint's weights.
            config = AutoConfig.from_pretrained(self.model_name_or_path)
            self.model = AutoModel.from_config(config).to(self.device)

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
        self.feature_dim = self.model.config.hidden_size

    def get_transform(self):
        """Return a callable that tokenizes a list of strings into PyTorch tensors.

        Every batch is padded and truncated to exactly ``max_length`` tokens.
        """
        max_length = int(self.ecfg.get("max_length", 128))

        def tokenize_batch(texts: list[str]):
            return self.tokenizer(
                texts,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
                max_length=max_length,
            )

        return tokenize_batch

    @torch.no_grad()
    def encode(self, x):
        """Encode a tokenized batch into one vector per sequence.

        Args:
            x: Mapping of tensors, e.g. the output of ``get_transform()``. Every
                value is moved to ``self.device`` and passed to the model as a
                keyword argument.

        Returns:
            One row per sequence of width ``self.feature_dim``. This assumes
            ``last_hidden_state`` is laid out as (batch, seq, hidden).
        """
        self.model.eval()
        inputs = {k: v.to(self.device) for k, v in x.items()}
        outputs = self.model(**inputs)

        if self.encoder_type == "sbert":
            # .get: a batch without an attention mask is pooled over all tokens
            # instead of raising KeyError.
            return self._mean_pooling(outputs, inputs.get("attention_mask"))

        last_hidden_state = outputs.last_hidden_state
        cls_embedding = last_hidden_state[:, 0]
        return cls_embedding

    def _mean_pooling(self, model_output, attention_mask):
        """Mean pooling over non-padding tokens, as used by Sentence-BERT.

        Args:
            model_output: Model output exposing ``last_hidden_state``.
            attention_mask: Nonzero for real tokens, 0 for padding. ``None``
                treats every position as a real token.

        Returns:
            Masked mean of the token embeddings, computed in float32.
        """
        token_embeddings = model_output.last_hidden_state
        if attention_mask is None:
            attention_mask = torch.ones(
                token_embeddings.shape[:2], device=token_embeddings.device
            )
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * mask, 1)
        # Clamp avoids division by zero for rows that are entirely padding.
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

# Encoder "type" values routed to each manager by create_encoder_manager.
VISION_ENCODER_TYPES = {"vit", "dinov2", "resnet", "efficientnet", "convnext"}
NLP_ENCODER_TYPES = {"bert", "roberta", "distilbert", "albert", "sbert"}


def create_encoder_manager(cfg, device) -> "BaseEncoderManager":
    """Instantiate the encoder manager matching ``cfg["encoder"]["type"]``.

    Args:
        cfg: Mapping with an optional ``"encoder"`` sub-mapping. A missing or
            null section means the default type ``"vit"``.
        device: Passed through to the manager constructor.

    Returns:
        A VisionEncoderManager or NLPEncoderManager instance.

    Raises:
        ValueError: If the (case-insensitive) type is not in either type set.
    """
    # ``or {}`` also handles an explicit null "encoder" entry.
    ecfg = cfg.get("encoder") or {}
    encoder_type = ecfg.get("type", "vit").lower()

    if encoder_type in VISION_ENCODER_TYPES:
        return VisionEncoderManager(cfg, device)

    if encoder_type in NLP_ENCODER_TYPES:
        return NLPEncoderManager(cfg, device)

    supported = ", ".join(sorted(VISION_ENCODER_TYPES | NLP_ENCODER_TYPES))
    raise ValueError(f"Unsupported encoder type: {encoder_type} (supported: {supported})")