# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Xueyan Zou (xueyan@cs.wisc.edu)
# --------------------------------------------------------

import torch
from torch import nn
from torch.nn import functional as F

from timm.models.layers import trunc_normal_

from .build import register_model
from ..utils import configurable
from .LangEncoder import build_tokenizer, build_lang_encoder
from xdecoder.utils.prompt_engineering import prompt_engineering, get_prompt_templates


class LanguageEncoder(nn.Module):
    """Text encoder that maps strings to projected embedding vectors.

    Strings are tokenized with ``tokenizer``, encoded by ``lang_encoder``,
    pooled to one vector per string (see ``_pool``) and projected through
    ``lang_proj``. When built via ``from_config``, ``lang_proj`` has shape
    ``(MODEL.TEXT.WIDTH, MODEL.DIM_PROJ)``. Computed embeddings are cached
    on the module as attributes named ``'<name>_text_embeddings'`` or
    ``'<name>_token_embeddings'``.

    Token tensors are moved to the device of ``lang_proj``, so the module
    works wherever it was placed with ``.to(...)``.
    """

    @configurable
    def __init__(
        self,
        tokenizer,
        tokenizer_type,
        lang_encoder,
        lang_projection,
        max_token_num,
        queue_operator,
    ):
        super().__init__()
        # seg
        self.tokenizer = tokenizer
        self.tokenizer_type = tokenizer_type
        self.lang_encoder = lang_encoder
        self.lang_proj = lang_projection
        self.max_token_num = max_token_num
        # Learnable temperature; compute_similarity scales logits by its exp().
        self.logit_scale = nn.Parameter(torch.ones([]))

        # captioning & retrieval: every entry is registered as a buffer.
        for key, value in queue_operator.items():
            self.register_buffer(key, value)

    @classmethod
    def from_config(cls, cfg):
        """Build the constructor kwargs from a nested config dict (``cfg['MODEL']['TEXT']`` etc.)."""
        # build up text encoder for seg
        tokenizer = build_tokenizer(cfg['MODEL']['TEXT'])
        tokenizer_type = cfg['MODEL']['TEXT']['TOKENIZER']
        lang_encoder = build_lang_encoder(cfg['MODEL']['TEXT'], tokenizer, cfg['VERBOSE'])
        max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']

        dim_lang = cfg['MODEL']['TEXT']['WIDTH']
        dim_projection = cfg['MODEL']['DIM_PROJ']
        lang_projection = nn.Parameter(torch.empty(dim_lang, dim_projection))
        trunc_normal_(lang_projection, std=.02)

        # Queue buffers for captioning/retrieval are disabled
        # (original note: "tested not working better").
        queue_operator = {}

        return {
            "tokenizer": tokenizer,
            "tokenizer_type": tokenizer_type,
            "lang_encoder": lang_encoder,
            "lang_projection": lang_projection,
            "max_token_num": max_token_num,
            "queue_operator": queue_operator,
        }

    # ------------------------------------------------------------------ #
    # private helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _strip_class_suffixes(class_name):
        """Remove the '-other', '-merged' and '-stuff' markers from a class name."""
        return class_name.replace('-other', '').replace('-merged', '').replace('-stuff', '')

    def _input_device(self):
        """Device that token tensors are moved to before encoding (the device of ``lang_proj``)."""
        return self.lang_proj.device

    def _tokenize(self, txts):
        """Tokenize a string or list of strings, padded/truncated to ``max_token_num``, as PyTorch tensors."""
        return self.tokenizer(
            txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
        )

    def _pool(self, hidden, input_ids):
        """Select one hidden state per sequence from ``hidden`` (indexed as ``[batch, position, ...]``).

        For the 'clip' tokenizer the position of the largest token id is used.
        This assumes the end-of-text token has the highest id, as in CLIP's
        vocabulary. Otherwise position 0 is used.
        """
        if self.tokenizer_type == 'clip':
            rows = torch.arange(hidden.size(0), device=hidden.device)
            return hidden[rows, input_ids.argmax(dim=-1)]
        return hidden[:, 0]

    @staticmethod
    def _l2_normalize(x):
        """L2-normalize along the last dim; 1e-7 guards against division by zero."""
        return x / (x.norm(dim=-1, keepdim=True) + 1e-7)

    # ------------------------------------------------------------------ #
    # public API
    # ------------------------------------------------------------------ #
    def get_text_embeddings(self, class_names, name='default', is_eval=False, add_bgd=False, prompt=True, norm=True, store_buffer=None):
        """Encode ``class_names`` and store the result as ``'<name>_text_embeddings'``.

        Args:
            class_names: sequence of class-name strings.
            name: prefix of the attribute the embeddings are stored under.
            is_eval: if False, encode one prompted text per class with
                gradients enabled. If True, under ``torch.no_grad()``, average
                each class's embeddings over all ``get_prompt_templates()``
                templates (or just the raw name when ``prompt`` is False) and
                re-normalize the mean.
            add_bgd: append an extra "A background in coco." entry. On the
                training path this only happens when ``prompt`` is True.
            prompt: wrap class names in prompt templates; if False the raw
                names are encoded.
            norm: L2-normalize per-text embeddings inside ``forward_language``.
            store_buffer: unused; kept for interface compatibility.

        Returns:
            None. A tensor with one row per encoded entry is stored via setattr.
        """
        if not is_eval:
            if prompt:
                # randomly sample one template per class
                arbitary_concepts = [
                    prompt_engineering(self._strip_class_suffixes(clss), topk=10000, suffix='.')
                    for clss in class_names
                ]
                if add_bgd:
                    arbitary_concepts.append("A background in coco.")
            else:
                # NOTE: add_bgd is intentionally not applied on this path (preserved behavior).
                arbitary_concepts = class_names

            input_ids = []
            attention_masks = []
            for txt in arbitary_concepts:
                tokens = self._tokenize(txt)
                # Drop only the batch dim of size 1; a bare squeeze() would also
                # collapse a length-1 sequence dim.
                input_ids.append(tokens['input_ids'][0])
                attention_masks.append(tokens['attention_mask'][0])

            device = self._input_device()
            arbitary_tokens = torch.stack(input_ids).to(device)
            arbitary_attention_masks = torch.stack(attention_masks).to(device)
            text_emb = self.forward_language((arbitary_tokens, arbitary_attention_masks), norm=norm)
            setattr(self, '{}_text_embeddings'.format(name), text_emb)
            return

        def extract_mean_emb(txts):
            # Mean of the per-template embeddings, re-normalized to unit length.
            tokens = self._tokenize(txts)
            device = self._input_device()
            clss_embedding = self.forward_language(
                (tokens['input_ids'].to(device), tokens['attention_mask'].to(device)), norm=norm
            )
            clss_embedding = clss_embedding.mean(dim=0)
            clss_embedding /= clss_embedding.norm()
            return clss_embedding

        with torch.no_grad():
            clss_embeddings = []
            if prompt:
                templates = get_prompt_templates()
                for clss in class_names:
                    stripped = self._strip_class_suffixes(clss)
                    txts = [template.format(stripped) for template in templates]
                    clss_embeddings.append(extract_mean_emb(txts))
            else:
                for clss in class_names:
                    clss_embeddings.append(extract_mean_emb([clss]))

            if add_bgd:
                clss_embeddings.append(extract_mean_emb(["A background in coco."]))

            text_emb = torch.stack(clss_embeddings, dim=0)
            setattr(self, '{}_text_embeddings'.format(name), text_emb)

    def reset_text_embeddings(self, name='default'):
        """No-op; kept for interface compatibility."""
        pass

    def get_text_token_embeddings(self, txts, name='default', token=False, norm=False):
        """Encode ``txts`` and return per-token and pooled embeddings.

        Args:
            txts: raw text(s) when ``token`` is False; otherwise a mapping that
                already holds ``'input_ids'`` and ``'attention_mask'`` tensors.
            name: prefix of the attribute the result is stored under
                (``'<name>_token_embeddings'``).
            token: whether ``txts`` is already tokenized.
            norm: L2-normalize both outputs.

        Returns:
            dict with keys ``"tokens"``, ``"token_emb"`` and ``"class_emb"``.
        """
        if not token:
            device = self._input_device()
            tokens = {key: value.to(device) for key, value in self._tokenize(txts).items()}
        else:
            tokens = txts
        token_emb, class_emb = self.forward_language_token((tokens['input_ids'], tokens['attention_mask']), norm=norm)
        ret = {
            "tokens": tokens,
            "token_emb": token_emb,
            "class_emb": class_emb,
        }
        setattr(self, '{}_token_embeddings'.format(name), ret)
        return ret

    def forward_language(self, texts, norm=True):
        """Encode ``texts = (input_ids, attention_mask)`` into one projected vector per sequence."""
        x = self.lang_encoder(*texts)
        x = self._pool(x['last_hidden_state'], texts[0])
        x = x @ self.lang_proj
        if norm:
            x = self._l2_normalize(x)
        return x

    def forward_language_token(self, texts, norm=False):
        """Encode ``texts = (input_ids, attention_mask)`` and return ``(token_emb, class_emb)``.

        ``token_emb`` projects every token's hidden state. ``class_emb``
        projects the pooled state only (see ``_pool``).
        """
        x = self.lang_encoder(*texts)
        token_x = x['last_hidden_state']
        class_x = self._pool(token_x, texts[0])

        class_x = class_x @ self.lang_proj
        token_x = token_x @ self.lang_proj

        if norm:
            class_x = self._l2_normalize(class_x)
            token_x = self._l2_normalize(token_x)

        return token_x, class_x

    def compute_similarity(self, v_emb, name='default', fake=False):
        """Return ``exp(logit_scale) * normalize(v_emb) @ text_emb^T``, or None when ``fake``.

        ``text_emb`` is the cached ``'<name>_text_embeddings'`` attribute and
        must have been produced by ``get_text_embeddings`` first.
        """
        if fake:
            return None
        v_emb = self._l2_normalize(v_emb)
        t_emb = getattr(self, '{}_text_embeddings'.format(name))
        output = self.logit_scale.exp() * v_emb @ t_emb.unsqueeze(0).transpose(1, 2)
        return output

    def get_single_text_embedding(self, text, norm=True):
        """Return the projected embedding of one text string, without the batch dimension."""
        tokens = self._tokenize(text)
        device = self._input_device()
        text_emb = self.forward_language(
            (tokens['input_ids'].to(device), tokens['attention_mask'].to(device)), norm=norm
        )
        return text_emb.squeeze(0)  # Remove batch dimension for single text

    def forward_xmask3d(self, outputs, targets=None, labels=None, test_labels=None, istrain=True):
        """Produce label text embeddings plus a learnable null embedding.

        Args:
            outputs: model outputs (unused; kept for compatibility).
            targets: training targets (unused; kept for compatibility).
            labels: label strings used when ``istrain`` is True.
            test_labels: label strings used when ``istrain`` is False.
            istrain: if True, encode ``labels`` via the training path of
                ``get_text_embeddings``. If False, encode ``test_labels`` via
                its template-averaged, no-grad evaluation path.

        Returns:
            dict with ``"text_embed"`` (one row per label, or an empty
            ``(0, DIM_PROJ)`` tensor when there are no labels), ``"null_embed"``
            (shape ``(1, DIM_PROJ)``) and ``"labels"`` (each label wrapped in
            its own list).
        """
        # Lazily create the null embedding, initialized from the text "background".
        # NOTE(review): it is registered on first call, i.e. possibly after the
        # optimizer / checkpoint loading were set up; confirm that is intended.
        if not hasattr(self, 'xmask3d_null_embed'):
            # No graph is needed: nn.Parameter detaches its input anyway.
            with torch.no_grad():
                null_text_emb = self.get_single_text_embedding("background", norm=True)
            self.xmask3d_null_embed = nn.Parameter(null_text_emb.clone())

        if istrain:
            current_labels = labels if labels is not None else []
            split, is_eval = 'xmask3d_train', False
        else:
            current_labels = test_labels if test_labels is not None else []
            split, is_eval = 'xmask3d_test', True

        if len(current_labels) > 0:
            self.get_text_embeddings(current_labels, name=split, is_eval=is_eval, prompt=True, norm=True)
            text_embed = getattr(self, '{}_text_embeddings'.format(split))
        else:
            text_embed = torch.empty(0, self.lang_proj.shape[1], device=self.lang_proj.device)

        return {
            "text_embed": text_embed,
            "null_embed": self.xmask3d_null_embed.unsqueeze(0),
            "labels": [[label] for label in current_labels],
        }

@register_model
def get_language_model(cfg, **kwargs):
    """Registry entry point: build a ``LanguageEncoder`` from ``cfg``.

    Extra keyword arguments are accepted for registry compatibility and ignored.
    """
    model = LanguageEncoder(cfg)
    return model