import json
import os
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F

from .llama import Transformer, ModelArgs, RMSNorm
from .utils import sample_top_p

from .tokenizer_llama3 import Tokenizer

from transformers import LlamaTokenizer
# from TransformerExplainability.baselines.ViT.ViT_new import vit_base_patch16_224 as vit
from TransformerExplainability.baselines.ViT.ViT_new_3d import vit_base_patch16_224 as vit
from model.longformer.modeling_longformer import LongformerModel,LongformerLMHead
from model.longformer.configuration_longformer import LongformerConfig
from transformers import AutoTokenizer
from torch.nn import CrossEntropyLoss

# Optional evaluation-metric dependencies (added at the top of the file).
# NOTE: the nltk.download(...) calls below run every time this module is
# imported. Only ImportError is handled here; METRICS_AVAILABLE records
# whether all of the imports succeeded.
try:
    import nltk
    nltk.download('punkt_tab')
    nltk.download('wordnet') # download the required WordNet corpus
    nltk.download('omw-1.4') # download the Open Multilingual WordNet
    from rouge_score import rouge_scorer
    from bert_score import BERTScorer
    METRICS_AVAILABLE = True
except ImportError:
    METRICS_AVAILABLE = False
    print("Metrics packages not fully installed. Some metrics will be unavailable.")



from .MFL import CrossModalAttention
from typing import List, Optional, Tuple, TypedDict

class CompletionPrediction(TypedDict, total=False):
    """Shape of one entry returned by ``Cardio_LLaMA.text_completion``.

    Every key is optional (``total=False``). ``tokens`` and ``logprobs`` are
    filled in when log-probabilities are requested; ``completion_ids`` is
    filled in otherwise.
    """
    generation: str
    tokens: List[str]  # not required
    logprobs: List[float]  # not required
    completion_ids: List[int]  # not required; raw generated token ids
from transformers import PretrainedConfig

class CustomConfig(PretrainedConfig):
    """``PretrainedConfig`` subclass that carries ``hidden_size`` plus any extra keywords."""

    def __init__(self, hidden_size=None, **kwargs):
        """Forward ``kwargs`` to the base config, then store them as attributes too."""
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        # Mirror each extra keyword onto the instance as a plain attribute.
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])

    def to_dict(self):
        """Return the base-class dict overlaid with the instance ``__dict__``."""
        serialized = super().to_dict()
        serialized.update(self.__dict__)
        return serialized
    
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Gather ``tensor`` from every process and concatenate along dim 0.

    When ``torch.distributed`` is unavailable or no process group has been
    initialised (e.g. a single-process run), the local tensor is returned
    detached instead of raising.

    *** Warning ***: torch.distributed.all_gather has no gradient.

    :param tensor: local tensor to share with the other ranks
    :return: concatenation of the tensors of all ranks along dim 0
    """
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        # No process group: the "gathered" result is just the local tensor.
        return tensor.detach()

    # One receive buffer per rank; all_gather overwrites their contents.
    tensors_gather = [torch.ones_like(tensor)
        for _ in range(dist.get_world_size())]
    dist.all_gather(tensors_gather, tensor, async_op=False)

    return torch.cat(tensors_gather, dim=0)

class PixelShuffle3d(nn.Module):
    '''
    Three-dimensional counterpart of pixel shuffle: channel blocks of size
    scale**3 are rearranged into the depth, height and width axes.
    '''
    def __init__(self, scale):
        '''
        :param scale: upsample scale applied to each of the three spatial axes
        '''
        super().__init__()
        self.scale = scale

    def forward(self, input):
        '''
        :param input: 5-D tensor (batch, channels, depth, height, width)
        :return: tensor (batch, channels // scale**3, depth*scale, height*scale, width*scale)
        '''
        s = self.scale
        n, c, d, h, w = input.size()
        out_channels = c // s ** 3

        # Split the channel axis into (out_channels, s, s, s) sub-blocks.
        blocks = input.contiguous().view(n, out_channels, s, s, s, d, h, w)

        # Interleave every scale factor right after its spatial axis.
        interleaved = blocks.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()

        return interleaved.view(n, out_channels, d * s, h * s, w * s)
 

class Cardio_LLaMA(nn.Module):
    """ Multimodal model that couples a 3D ViT image encoder and a Longformer text encoder with a LLaMA decoder.
    """

    def __init__(self, llama_ckpt_dir, llama_tokenizer, model_args, stage=1,
                 load_llama=True):
        """Build the encoders, fusion module, tokenizer and LLaMA decoder.

        Args:
            llama_ckpt_dir: directory holding ``params.json`` and, when
                ``load_llama`` is true, the ``*.pth`` checkpoint shard(s).
            llama_tokenizer: value handed to the tokenizer constructor.
            model_args: object exposing ``llama_type``; stored as ``self.args``.
            stage: training stage. Stage 1 additionally creates the
                reconstruction decoders, the momentum encoders and the
                feature queues.
            load_llama: when true, copy the checkpoint weights into
                ``self.llama``.
        """
        super().__init__()

        self.args = model_args
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # 4. llama
        with open(os.path.join(llama_ckpt_dir, "params.json"), "r") as f:
            params = json.loads(f.read())
        bias_lora = True

        self.model_args: ModelArgs = ModelArgs(
                max_seq_len=1024, max_batch_size=4, w_bias=bias_lora, w_lora=bias_lora,
                **params)  # max_batch_size only affects inference
        self.config = CustomConfig()
        self.config.hidden_size = self.model_args.dim
        print(f"model args: {self.model_args}")


        # 1. ViT Encoder
        print(f'Initialize ViT...')
        self.visual_encoder = vit(img_size=128,patch_size=16,in_chans=1)
        # Projects 768-dim ViT features to the LLaMA hidden size.
        self.vit_proj = nn.Linear(768, self.model_args.dim)
        print(f'ViT initialized...')

        print(f'Initial text encoder...')
        # NOTE(review): max_txt_len is not read anywhere in this class; the
        # tokenizer calls use max_length=600 -- confirm which is intended.
        self.max_txt_len = 256
        # NOTE(review): hard-coded absolute path to the Longformer weights.
        text_encoder_path = '/data/qiuhui/code/Aknowledge-based-2D/model/longformer-base-4096'
        self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_path)
        # text encoder and decoder
        text_config = LongformerConfig.from_pretrained(text_encoder_path)
        self.text_encoder = LongformerModel.from_pretrained(text_encoder_path,config=text_config,add_pooling_layer=False)
        # Projects 768-dim text features to the LLaMA hidden size.
        self.txt_proj = nn.Linear(768, self.model_args.dim)
        self.text_config = text_config
        print(f'Text encoder initialized...')

        print(f'Initial mfl...')
        self.mfl = CrossModalAttention(input_dim=self.model_args.dim)
        print(f'Mfl initialized...')

        if stage == 1:
            # 1x1x1 convolution to 16**3 channels followed by a 16x 3D pixel
            # shuffle, used to reconstruct the input volume from patch tokens.
            self.visual_decoder = nn.Sequential(
                nn.Conv3d(
                    in_channels=768,
                    out_channels=16**3 * 1,
                    kernel_size=1,
                ),
                PixelShuffle3d(16),
            )
            self.text_decoder = LongformerLMHead(config=text_config)
            # create momentum encoders
            self.visual_encoder_m = vit(img_size=128,patch_size=16,in_chans=1)
            self.text_encoder_m = LongformerModel.from_pretrained(text_encoder_path,config=text_config,add_pooling_layer=False)


            # (online, momentum) pairs used by copy_params / _momentum_update.
            self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
                                [self.text_encoder,self.text_encoder_m],
                            ]
            self.copy_params()

            # create the queue
            self.register_buffer("image_queue", torch.randn(768, 57600))
            self.register_buffer("text_queue", torch.randn(768, 57600))
            self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

            # Normalise each queue column (dim 0 is the 768-dim feature axis).
            self.image_queue = F.normalize(self.image_queue, dim=0)
            self.text_queue = F.normalize(self.text_queue, dim=0)

            self.queue_size = 57600
            self.momentum = 0.995
            self.temp = nn.Parameter(0.07*torch.ones([]))

        # 5. tokenizer
        # NOTE(review): if llama_type matches none of the branches below,
        # self.tokenizer is never assigned.
        if self.args.llama_type=='7B':
            self.tokenizer = LlamaTokenizer.from_pretrained(llama_tokenizer)
            self.tokenizer.pad_id = self.tokenizer.eos_id
        elif self.args.llama_type=='llama3' or self.args.llama_type=='8B':
            self.tokenizer = Tokenizer(llama_tokenizer)
            self.tokenizer.pad_id = self.tokenizer.eos_id

        # Build LLaMA while the default tensor type is CUDA half (when CUDA is
        # available), then restore the default to CPU float32.
        if torch.cuda.is_available():
            torch.set_default_tensor_type(torch.cuda.HalfTensor)
        self.llama = Transformer(self.model_args)
        torch.set_default_tensor_type(torch.FloatTensor)
        if load_llama:
            print(f"Loading LLaMA Checkpoint...")
            ckpts = sorted(Path(llama_ckpt_dir).glob("*.pth"))

            """
            Adapted from https://github.com/cedrickchee/llama/blob/main/chattyllama/combined/inference.py
            """
            # For each parameter kind: the dimension along which checkpoint
            # shards are laid side by side (0 or -1), or None when the tensor
            # is taken from the first shard only.
            key_to_dim = {
                "w1": 0,
                "w2": -1,
                "w3": 0,
                "wo": -1,
                "wq": 0,
                "wk": 0,
                "wv": 0,
                "output": 0,
                "tok_embeddings": -1,
                "ffn_norm": None,
                "attention_norm": None,
                "norm": None,
                "rope": None,
            }
            for i, ckpt in enumerate(ckpts):
                checkpoint = torch.load(ckpt, map_location="cpu")
                for parameter_name, parameter in self.llama.named_parameters():
                    # Second-to-last dotted component, e.g. "wk" for
                    # 'layers.0.attention.wk.weight'.
                    short_name = parameter_name.split(".")[-2] #'layers.0.attention.wk.weight'
                    # gate / lora / bias parameters are not read from the checkpoint.
                    if "gate" in parameter_name or "lora" in parameter_name or "bias" in parameter_name:
                        continue
                    if key_to_dim[short_name] is None and i == 0:
                        parameter.data = checkpoint[parameter_name]
                    elif key_to_dim[short_name] == 0:
                        size = checkpoint[parameter_name].size(0) #[4096, 4096] same size
                        parameter.data[size * i: size * (i + 1), :] = checkpoint[
                            parameter_name
                        ]
                    elif key_to_dim[short_name] == -1:
                        size = checkpoint[parameter_name].size(-1) #[32000,4096] /[4096,4096]/ [11008, 4096] same size
                        parameter.data[:, size * i: size * (i + 1)] = checkpoint[
                            parameter_name
                        ]

                del checkpoint
            print(f"LLaMA Checkpoint Loaded")

        # 6. training criterion
        # Targets equal to 0 are ignored by the language-modelling loss.
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
        self.l2_loss = torch.nn.MSELoss()
        self.stage = stage
        self.set_default_trainability(self.stage)

    def get_trainable_params(self, stage=1):
        trainable = {}
        if stage == 1:
            # import pdb;pdb.set_trace()
            for name, para in self.named_parameters():
                if 'visual_encoder' in name:
                    trainable[name] = para
                if "text_encoder" in name:
                    trainable[name] = para

        elif stage == 2 or stage == 3:
            for name, para in self.named_parameters():
                if 'vit_proj' in name:
                    trainable[name] = para
                if "txt_proj" in name:
                    trainable[name] = para
                if "llama." in name:
                    if 'norm' in name or 'bias' in name or 'lora' in name:
                        trainable[name] = para
                if "tok_embeddings" in name:
                    trainable[name] = para

        return trainable

    def set_default_trainability(self, stage=1):
        for key, value in self.named_parameters():
            value.requires_grad = False
        trainable_params = self.get_trainable_params(stage)
        print(f"Trainable Params: {trainable_params.keys()}")
        for key, value in trainable_params.items():
            value.data = value.data.float()
            value.requires_grad = True

    @torch.no_grad()
    def copy_params(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data.copy_(param.data)  # initialize
                param_m.requires_grad = False  # not update by gradient

    @torch.no_grad()        
    def _momentum_update(self):
        for model_pair in self.model_pairs:           
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, image_feat, text_feat):
        # gather keys before updating queue
        image_feats = concat_all_gather(image_feat)
        text_feats = concat_all_gather(text_feat)

        batch_size = image_feats.shape[0]

        ptr = int(self.queue_ptr)
        assert self.queue_size % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
        self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
        ptr = (ptr + batch_size) % self.queue_size  # move pointer

        self.queue_ptr[0] = ptr 


    def forward(self, input2_tensor, labels, input2_mask, image, non_image, label):

        h = self.llama.tok_embeddings(input2_tensor.to(self.device))
        token_img = torch.tensor(self.tokenizer.encode('<img>', bos=False, eos=False, allowed_special={"<img>","<non_img>"}), dtype=torch.int64, device=h.device)
        token_txt = torch.tensor(self.tokenizer.encode('<non_img>', bos=False, eos=False, allowed_special={"<img>","<non_img>"}), dtype=torch.int64, device=h.device)

        
        img_embedding,_ = self.visual_encoder(image.unsqueeze(1))
        img_feat = F.normalize(img_embedding[:,0,:], dim=-1)
        
        
        text = self.text_tokenizer(non_image, padding="max_length", truncation=True, max_length=600, return_tensors="pt",).to(image.device)
        global_attention_mask = torch.zeros(text.input_ids.shape, dtype=torch.long, device=text.input_ids.device)
        global_attention_mask[:,0] = 1
        text_output = self.text_encoder(text.input_ids,attention_mask=text.attention_mask,global_attention_mask=global_attention_mask,output_attentions=True)
        txt_embedding = text_output.last_hidden_state
        txt_feat = F.normalize(txt_embedding[:,0,:], dim=-1)

        if self.stage==1:
            # get momentum features
            with torch.no_grad():
                self._momentum_update()
                image_embeds_m = self.visual_encoder_m(image.unsqueeze(1))[0]
                image_feat_m = F.normalize(image_embeds_m[:,0,:],dim=-1)  
                image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)                   
                
                text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,global_attention_mask=global_attention_mask)
                text_feat_m = F.normalize(text_output_m.last_hidden_state[:,0,:],dim=-1) 
                text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)

                sim_i2t_m = image_feat_m @ text_feat_all / self.temp
                sim_t2i_m = text_feat_m @ image_feat_all / self.temp

                sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device)
                sim_targets.fill_diagonal_(1)
                alpha = 0.4
                sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
                sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets

            sim_i2t = img_feat @ text_feat_all / self.temp
            sim_t2i = txt_feat @ image_feat_all / self.temp
                                
            loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
            loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() 

            
            loss_itc = (loss_i2t+loss_t2i)/2
            self._dequeue_and_enqueue(image_feat_m, text_feat_m)

            vision_sequence_output = img_embedding[:, 1:]
            batch_size, sequence_length, num_channels = vision_sequence_output.shape
            depth = height = width = round(pow(sequence_length,1.0/3))
            vision_sequence_output = vision_sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, depth, height, width)

            # Reconstruct image
            reconstructed_image = self.visual_decoder(vision_sequence_output)
            reconstruction_loss = F.l1_loss(image, reconstructed_image, reduction="none")
            image_res_loss = reconstruction_loss.sum() / (torch.ones_like(reconstruction_loss).sum() + 1e-5)
            # Reconstruct text
            prediction_scores = self.text_decoder(txt_embedding)
            loss_fct = CrossEntropyLoss()
            labels_res = text.input_ids.to(prediction_scores.device)
            text_res_loss = loss_fct(prediction_scores.view(-1, self.text_config.vocab_size), labels_res.view(-1))
            loss_res = text_res_loss+image_res_loss
            print('ITC loss: ', loss_itc)
            print('Restore Loss: ', loss_res, 'text loss: ',text_res_loss, 'img loss: ', image_res_loss)

            return None, loss_itc+loss_res
        
        
        img_feat_proj = self.vit_proj(img_feat)
        txt_feat_proj = self.txt_proj(txt_feat)

        img_feat_proj, txt_feat_proj = self.mfl(img_feat_proj, txt_feat_proj)
        img_positions = (input2_tensor == token_img).nonzero(as_tuple=True)
        for batch_idx, pos in zip(img_positions[0], img_positions[1]):
                h[batch_idx, pos, :] = img_feat_proj[batch_idx]

        txt_positions = (input2_tensor == token_txt).nonzero(as_tuple=True)
        for batch_idx, pos in zip(txt_positions[0], txt_positions[1]):
                h[batch_idx, pos, :] = txt_feat_proj[batch_idx]
            
        seqlen_now = h.shape[1]
        freqs_cis = self.llama.freqs_cis.to(h.device)
        freqs_cis = freqs_cis[:seqlen_now]
        mask = torch.full((1, 1, seqlen_now, seqlen_now), float("-inf"), device=h.device)
        mask = torch.triu(mask, diagonal=0 + 1).type_as(h)

        
        with torch.amp.autocast('cuda'):
            for layer in self.llama.layers:
                h = layer(h, 0, freqs_cis, mask)
  
            h = self.llama.norm(h)
            output = self.llama.output(h)
            output = output[:, :-1, :]
            labels = labels[:, 1:]

            if labels.sum() == 0:
                c_loss = output.mean() * 0
            else:
                assert self.llama.vocab_size == self.model_args.vocab_size
                c_loss = self.criterion(output.reshape(-1, self.llama.vocab_size), labels.flatten().to(self.device))
            
            # print('output: ',self.tokenizer.decode(output[0].argmax(-1).detach().cpu().tolist()).replace('<|end_of_text|>',''))
            # print('labels: ',self.tokenizer.decode(labels[0].detach().cpu().tolist()).replace('<|end_of_text|>',''))
            # 获取预测结果和标签文本
            pred_tokens = output.argmax(-1)
            label_tokens = labels
            
            # 将token转换为文本
            pred_texts = []
            label_texts = []
            
            for i in range(pred_tokens.shape[0]):
                pred_text = self.tokenizer.decode(pred_tokens[i].detach().cpu().tolist()).replace('<|end_of_text|>','').strip()
                label_text = self.tokenizer.decode(label_tokens[i].detach().cpu().tolist()).replace('<|end_of_text|>','').strip()
                
                pred_texts.append(pred_text)
                label_texts.append(label_text)
            
            # print('output: ', pred_texts[0])
            # print('labels: ', label_texts[0])
            # 计算评估指标
            metrics = self.calculate_metrics(pred_texts, label_texts)
            print('Metrics:', metrics)
            
            return output, c_loss

    def calculate_metrics(self, predictions, references):
        """
        计算BLEU, METEOR, ROUGE, BERTScore等指标
        """
        try:
            import nltk
            from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
            from nltk.translate.meteor_score import meteor_score
            from rouge_score import rouge_scorer
            from bert_score import BERTScorer
            
            # 确保nltk数据已下载
            try:
                nltk.data.find('tokenizers/punkt')
            except LookupError:
                nltk.download('punkt')
            
            metrics = {}
            
            # 计算BLEU
            bleu_scores = []
            smoothie = SmoothingFunction().method1
            for pred, ref in zip(predictions, references):
                # 将文本tokenize
                pred_tokens = nltk.word_tokenize(pred.lower())
                ref_tokens = nltk.word_tokenize(ref.lower())
                
                bleu = sentence_bleu([ref_tokens], pred_tokens, 
                                    weights=(0.25, 0.25, 0.25, 0.25),
                                    smoothing_function=smoothie)
                bleu_scores.append(bleu)
            
            metrics['bleu'] = sum(bleu_scores) / len(bleu_scores)
        
            # 计算METEOR（需要确保有METEOR jar文件）
            try:
                meteor_scores = []
                for pred, ref in zip(predictions, references):
                    # 使用简单的nltk版本，或者安装nltk的METEOR支持
                    score = meteor_score([ref], pred)
                    meteor_scores.append(score)
                metrics['meteor'] = sum(meteor_scores) / len(meteor_scores)
            except:
                metrics['meteor'] = 0.0
                print("METEOR calculation failed, make sure METEOR is properly installed")
            # 计算ROUGE
            rouge_scores = {'rouge1': [], 'rouge2': [], 'rougeL': []}
            scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
            
            for pred, ref in zip(predictions, references):
                scores = scorer.score(ref, pred)
                for key in rouge_scores.keys():
                    rouge_scores[key].append(scores[key].fmeasure)
            
            for key in rouge_scores.keys():
                metrics[key] = sum(rouge_scores[key]) / len(rouge_scores[key])
            
            # 计算BERTScore
            try:
                bert_scorer = BERTScorer(lang="en", rescale_with_baseline=True)
                P, R, F1 = bert_scorer.score(predictions, references)
                metrics['bertscore_precision'] = P.mean().item()
                metrics['bertscore_recall'] = R.mean().item()
                metrics['bertscore_f1'] = F1.mean().item()
            except:
                metrics['bertscore_precision'] = 0.0
                metrics['bertscore_recall'] = 0.0
                metrics['bertscore_f1'] = 0.0
                print("BERTScore calculation failed")
            
            return metrics
        except ImportError as e:
            print(f"Some metrics packages not installed: {e}")
            print("Please install: pip install nltk rouge-score bert-score")
            return {}
    
    @torch.inference_mode()
    def forward_inference(self, tokens, start_pos: int, img_feat=None, txt_feat=None):
        h = self.llama.tok_embeddings(tokens)

        token_img = torch.tensor(self.tokenizer.encode('<img>', bos=False, eos=False, allowed_special={"<img>","<non_img>"}), dtype=torch.int64, device=h.device)
        token_txt = torch.tensor(self.tokenizer.encode('<non_img>', bos=False, eos=False, allowed_special={"<img>","<non_img>"}), dtype=torch.int64, device=h.device)

        if self.stage==3:

            img_feat_proj = self.vit_proj(img_feat)
            txt_feat_proj = self.txt_proj(txt_feat)
            img_positions = (tokens == token_img).nonzero(as_tuple=True)
            for batch_idx, pos in zip(img_positions[0], img_positions[1]):
                    h[batch_idx, pos, :] = img_feat_proj[batch_idx]

            txt_positions = (tokens == token_txt).nonzero(as_tuple=True)
            for batch_idx, pos in zip(txt_positions[0], txt_positions[1]):
                    h[batch_idx, pos, :] = txt_feat_proj[batch_idx]
        
        seqlen_now = h.shape[1]
        freqs_cis = self.llama.freqs_cis.to(h.device)
        freqs_cis = freqs_cis[start_pos:start_pos + seqlen_now]
        mask = None
        mask = torch.full((1, 1, seqlen_now, seqlen_now), float("-inf"), device=h.device)
        mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

        for layer in self.llama.layers:
            h = layer(h, 0, freqs_cis, mask)

        h = self.llama.norm(h)
        output = self.llama.output(h[:, -1, :])

        return output.float()

    @torch.inference_mode()
    def generate(
            self,
            prompt_tokens,
            imgs=None,
            non_imgs=None,
            max_gen_len: int = 512,
            cache_size=10,
            cache_t=20,
            cache_weight=0.5,
            temperature: float = 0.6,
            top_p: float = 0.9,
            logprobs: bool = False,
            echo: bool = False,
    ):
        bsz = len(prompt_tokens)
        params = self.llama.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
        
        with torch.amp.autocast('cuda'):

            img_embedding,_ = self.visual_encoder(imgs.unsqueeze(1))
            img_feat = F.normalize(img_embedding[:,0,:], dim=-1)
            
            
            text = self.text_tokenizer(non_imgs, padding="max_length", truncation=True, max_length=600, return_tensors="pt",).to(imgs.device)
            global_attention_mask = torch.zeros(text.input_ids.shape, dtype=torch.long, device=text.input_ids.device)
            global_attention_mask[:,0] = 1
            text_output = self.text_encoder(text.input_ids,attention_mask=text.attention_mask,global_attention_mask=global_attention_mask,output_attentions=True)
            txt_embedding = text_output.last_hidden_state
            txt_feat = F.normalize(txt_embedding[:,0,:], dim=-1)

        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)
        assert max_prompt_len <= params.max_seq_len
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

        pad_id = self.tokenizer.pad_id
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
        
        prev_pos = 0
        eos_reached = torch.tensor([False] * bsz, device="cuda")
        input_text_mask = tokens != pad_id
   
        stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens), device="cuda")

        for cur_pos in range(min_prompt_len, total_len):
            with torch.amp.autocast('cuda'):
                
                logits = self.forward_inference(tokens[:, prev_pos:cur_pos], prev_pos,
                                                                          img_feat, txt_feat) #[1,1,4096]
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1)

            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
           
            # if bsz == 1 and self.tokenizer.decode(tokens[0, cur_pos - 2:cur_pos + 1]) == "\n###":
            eos_reached |= (~input_text_mask[:, cur_pos]) & (
                torch.isin(next_token, stop_tokens)
            )
            # prev_pos = cur_pos
            if all(eos_reached):
                break   
          
        if logprobs:
            token_logprobs = token_logprobs.tolist()
        out_tokens, out_logprobs = [], []
        for i, toks in enumerate(tokens.tolist()):
            # cut to max gen len
            start = 0 if echo else len(prompt_tokens[i])
            toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
            probs = None
            if logprobs:
                probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
            # cut to after eos tok if any
            for stop_token in self.tokenizer.stop_tokens:
                try:
                    eos_idx = toks.index(stop_token)
                    toks = toks[:eos_idx]
                    probs = probs[:eos_idx] if logprobs else None
                except ValueError:
                    pass
            out_tokens.append(toks)
            out_logprobs.append(probs)
        return (out_tokens, out_logprobs if logprobs else None)
 
    @torch.inference_mode()
    def text_completion(
        self,
        prompts: List[str],
        img,
        non_img,
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
        echo: bool = False,
        add_special_token: bool = False
    ) -> List[CompletionPrediction]:
        """
        Perform text completion for a list of prompts using the language generation model.

        Args:
            prompts (List[str]): List of text prompts for completion.
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
            max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence.
                If not provided, it's set to the model's maximum sequence length minus 1.
            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.

        Returns:
            List[CompletionPrediction]: List of completion predictions, each containing the generated text completion.

        Note:
            This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness.
            If logprobs is True, token log probabilities are computed for each generated token.

        """
        if max_gen_len is None:
            max_gen_len = self.model.params.max_seq_len - 1
        if add_special_token:
            prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False, allowed_special={"<img>","<non_img>"}) for x in prompts]
        else:
            prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
            echo=echo,
            imgs=img,
            non_imgs=non_img,
            cache_size=10,
            cache_t=20,
            cache_weight=0.5,
        )
        if logprobs:
            return [
                {
                    "generation": self.tokenizer.decode(t),
                    "tokens": [self.tokenizer.decode([x]) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
            ]
        return [{"generation": self.tokenizer.decode(t), "completion_ids": t} for t in generation_tokens]


