from ast import arg
from pickle import TRUE
from header import *
import torch.nn.functional as F
from .ImageBind import *
from .ImageBind import data
from .modeling_llama import LlamaForCausalLM
from transformers import StoppingCriteria, StoppingCriteriaList

import torch
from torch.nn.utils import rnn
from model.ImageBind.models import lora as LoRA
from model.ImageBind.models.imagebind_model import load_module
from peft import get_peft_model, LoraConfig, TaskType,LoraModel

class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the watched token ids have appeared often enough.

    Args:
        stops: token ids to watch for in the first sequence of the batch.
            Defaults to an empty list (never stops on its own).
        encounters: total number of occurrences (summed over all ``stops``)
            required before generation is halted.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # Fresh list per instance: a literal `[]` default would be shared.
        self.stops = stops if stops is not None else []
        self.ENCOUNTERS = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        """Return True when the stop ids occur at least ``ENCOUNTERS`` times.

        Only ``input_ids[0]`` (the first sequence) is inspected.
        """
        # Accumulate over every stop id; previously the count was overwritten
        # on each iteration, so only the last stop id was honoured.
        stop_count = sum((stop == input_ids[0]).sum().item() for stop in self.stops)
        return stop_count >= self.ENCOUNTERS

def build_one_instance(tokenizer, conversation, prompt_start):
    """Tokenize one conversation into parallel input-id / target-id lists.

    The first turn must come from the human. Its text is wrapped in the
    template that belongs to ``prompt_start`` (the prefix itself is embedded
    separately by the model). Human turns are masked with -100 in the targets
    so no loss is computed on them. Assistant ('gpt') turns are copied into
    the targets.

    Args:
        tokenizer: callable ``tokenizer(text, add_special_tokens=False)``
            returning an object with an ``input_ids`` list.
        conversation: list of dicts with keys 'from' ('human' or 'gpt') and
            'value' (the turn text).
        prompt_start: '### Human: <Img>' or '### Human: <audio>'.

    Returns:
        (text_list, input_ids, target_ids): the per-turn texts that were
        tokenized, and two equal-length lists of token ids.

    Raises:
        AssertionError: the first turn is not from the human.
        ValueError: unsupported ``prompt_start`` or an unknown role.
    """
    # Text that opens the first human turn, keyed by the prefix the model
    # places before the modality embedding.
    first_turn_prefixes = {
        '### Human: <Img>': '</Img> Answer the following question in one sentence:',
        '### Human: <audio>': '</Audio> Please answer briefly:',
    }
    text_list, input_ids, target_ids = [], [], []
    for i, turn in enumerate(conversation):
        role = turn['from']
        if i == 0:  # the first human turn
            assert role == 'human'
            try:
                first_prefix = first_turn_prefixes[prompt_start]
            except KeyError:
                # Previously this fell through and left `text` unbound.
                raise ValueError(f'Unsupported prompt_start: {prompt_start!r}') from None
            text = first_prefix + turn['value'] + '\n### Assistant:'
            supervised = False
        elif role == 'human':
            text = 'Human: ' + turn['value'] + '\n### Assistant:'
            supervised = False
        elif role == 'gpt':
            text = turn['value'] + '\n###'
            supervised = True
        else:
            raise ValueError('Wrong Role!!!')
        one_input_id = tokenizer(text, add_special_tokens=False).input_ids
        input_ids += one_input_id
        # -100 marks positions excluded from the loss (human prompts).
        target_ids += one_input_id if supervised else [-100] * len(one_input_id)
        text_list.append(text)
        assert len(input_ids) == len(target_ids)
    return text_list, input_ids, target_ids

def process_batch_instance(tokenizer, batch_of_conversations, max_tgt_len, prompt_start):
    """Tokenize, pad and truncate a batch of conversations.

    Args:
        tokenizer: tokenizer passed to ``build_one_instance``; its
            ``pad_token_id`` is used as the input padding value.
        batch_of_conversations: list of conversations (see
            ``build_one_instance``).
        max_tgt_len: maximum sequence length kept after padding.
        prompt_start: prompt prefix forwarded to ``build_one_instance``.

    Returns:
        (input_ids, target_ids, attention_mask), each a LongTensor of shape
        bsz x min(longest, max_tgt_len). Targets are padded with -100 so padding
        is ignored by the loss.
    """
    batch_input_ids, batch_target_ids, lengths = [], [], []
    for conversation in batch_of_conversations:
        _, one_input_ids, one_target_ids = build_one_instance(tokenizer, conversation, prompt_start)
        batch_input_ids.append(torch.LongTensor(one_input_ids))
        batch_target_ids.append(torch.LongTensor(one_target_ids))
        lengths.append(len(one_input_ids))
    input_ids = rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    target_ids = rnn.pad_sequence(batch_target_ids, batch_first=True, padding_value=-100)
    assert input_ids.size() == target_ids.size()
    # Build the mask from the true lengths. Comparing against pad_token_id
    # would also hide genuine EOS tokens when pad_token is set to eos_token.
    positions = torch.arange(input_ids.size(1)).unsqueeze(0)  # 1 x S
    attention_mask = positions < torch.tensor(lengths).unsqueeze(1)  # bsz x S
    input_ids = input_ids[:, :max_tgt_len]
    target_ids = target_ids[:, :max_tgt_len]
    attention_mask = attention_mask[:, :max_tgt_len]
    assert attention_mask.size() == input_ids.size()
    return input_ids, target_ids, attention_mask.long()

#PROMPT_START = '### Human: <Img>'
class OpenLLAMAPEFTModel(nn.Module):

    '''LoRA for LLaMa model.

    Couples an ImageBind encoder (with LoRA adapters inserted into its
    modality trunks) with a PEFT-LoRA LLaMA decoder. Each sample's ImageBind
    embedding is projected by ``llama_proj`` into the decoder's embedding
    space and spliced into the input sequence as
    ``[bos] [prompt_start] [modality embedding] [conversation tokens]``.
    '''

    # Prefixes placed before the modality embedding. The first-turn templates
    # in build_one_instance are keyed by these exact strings.
    _IMAGE_PROMPT_START = '### Human: <Img>'
    _AUDIO_PROMPT_START = '### Human: <audio>'
    # Token id watched by the generation stopping criterion; presumably a
    # piece of the '###' turn separator in the LLaMA vocabulary (TODO confirm).
    _STOP_TOKEN_ID = 2277
    # Directory the visual-encoder LoRA weights are loaded from unless the
    # 'lora_checkpoint_dir' argument overrides it.
    _DEFAULT_LORA_CHECKPOINT_DIR = "./.checkpoints/lora"

    def __init__(self, **args):
        """Build the encoder, decoder and projection.

        Required ``args`` keys: 'imagebind_ckpt_path', 'vicuna_ckpt_path',
        'max_tgt_len', 'lora_layer_idxs', 'lora_modality_names', 'lora_r',
        'lora_alpha', 'lora_dropout'. Optional: 'lora_checkpoint_dir'. All args
        are kept in ``self.args``.
        """
        super().__init__()
        self.args = args
        imagebind_ckpt_path = args['imagebind_ckpt_path']
        vicuna_ckpt_path = args['vicuna_ckpt_path']
        max_tgt_len = args['max_tgt_len']
        lora_layer_idxs = args['lora_layer_idxs']
        lora_modality_names = args['lora_modality_names']
        lora_checkpoint_dir = args.get('lora_checkpoint_dir', self._DEFAULT_LORA_CHECKPOINT_DIR)

        print(f'Initializing visual encoder from {imagebind_ckpt_path} ...')
        self.visual_encoder, self.visual_hidden_size = \
            imagebind_model.imagebind_huge(pretrained=True, store_path=imagebind_ckpt_path)

        # 1. Insert LoRA adapters into the selected modality trunks, then load
        #    saved LoRA weights from lora_checkpoint_dir.
        self.visual_encoder.modality_trunks.update(
            LoRA.apply_lora_modality_trunks(
                self.visual_encoder.modality_trunks,
                rank=4,
                layer_idxs=lora_layer_idxs,
                modality_names=lora_modality_names,
            )
        )
        LoRA.load_lora_modality_trunks(
            self.visual_encoder.modality_trunks,
            checkpoint_dir=lora_checkpoint_dir,
        )

        # 2. Within the trunks only LoRA parameters stay trainable; all
        #    preprocessors are frozen. Postprocessors and heads are not
        #    touched here.
        for name, param in self.visual_encoder.modality_trunks.named_parameters():
            param.requires_grad = 'lora' in name.lower()
        for modality_preprocessor in self.visual_encoder.modality_preprocessors.children():
            modality_preprocessor.requires_grad_(False)
        print("=== [LoRA] Trainable LoRA Parameters in visual_encoder.modality_trunks ===")

        print(f'Initializing language decoder from {vicuna_ckpt_path} ...')
        # LoRA adapters on the decoder's attention projections.
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=self.args['lora_r'],
            lora_alpha=self.args['lora_alpha'],
            lora_dropout=self.args['lora_dropout'],
            target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],
        )
        self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path)
        self.llama_model = get_peft_model(self.llama_model, peft_config)
        self.llama_model.print_trainable_parameters()

        self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False)
        # EOS doubles as the padding token; sequences are padded on the right.
        self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
        self.llama_tokenizer.padding_side = "right"
        print('Language decoder initialized.')

        # Maps one ImageBind embedding to one LLaMA input embedding.
        self.llama_proj = nn.Linear(
            self.visual_hidden_size, self.llama_model.config.hidden_size
        )

        self.max_tgt_len = max_tgt_len
        # CUDA device index; construction therefore requires CUDA.
        self.device = torch.cuda.current_device()

    # ------------------------------------------------------------------ helpers

    def _embed_tokens(self, token_ids):
        """Look up decoder input embeddings (the PEFT wrapper nests the base model)."""
        return self.llama_model.model.model.embed_tokens(token_ids)

    def _tokenize(self, text):
        """Tokenize a single string without special tokens; ids are on ``self.device``."""
        return self.llama_tokenizer(
            text, return_tensors="pt", add_special_tokens=False).to(self.device).input_ids

    def _bos_embeds(self, batch_size, reference_ids):
        """Embed one BOS token per sample (bsz x 1 x embed_dim).

        The BOS ids take their dtype and device from ``reference_ids``.
        """
        bos = torch.full([batch_size, 1], self.llama_tokenizer.bos_token_id,
                         dtype=reference_ids.dtype, device=reference_ids.device)
        return self._embed_tokens(bos)

    def _encode_modality(self, modality, load_fn, paths):
        """Run one modality through ImageBind and project it into LLaMA space.

        Args:
            modality: ImageBind modality key, used for both input and output.
            load_fn: ``data.load_and_transform_*`` loader, called as
                ``load_fn(paths, self.device)``.
            paths: file paths to load.

        Returns:
            (inputs_llama, atts_llama): bsz x 1 x llama_size embeddings and a
            bsz x 1 all-ones attention mask.
        """
        inputs = {modality: load_fn(paths, self.device)}
        # Cast to the decoder's dtype before running the encoder.
        inputs = {key: value.to(self.llama_model.dtype) for key, value in inputs.items()}
        # No torch.no_grad(): the encoder's LoRA parameters are trainable.
        embeddings = self.visual_encoder(inputs)
        modality_embeds = embeddings[modality]  # bsz x 1024
        inputs_llama = self.llama_proj(modality_embeds).unsqueeze(1)  # bsz x 1 x llama_size
        atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device)  # bsz x 1
        return inputs_llama, atts_llama

    # ----------------------------------------------------------------- encoders

    def encode_video(self, video_paths):
        """Encode videos via ImageBind's VISION branch; see ``_encode_modality``."""
        return self._encode_modality(ModalityType.VISION, data.load_and_transform_video_data, video_paths)

    def encode_audio(self, audio_paths):
        """Encode audio files via ImageBind's AUDIO branch; see ``_encode_modality``."""
        return self._encode_modality(ModalityType.AUDIO, data.load_and_transform_audio_data, audio_paths)

    def encode_thermal(self, thermal_paths):
        """Encode thermal images via ImageBind's THERMAL branch; see ``_encode_modality``."""
        return self._encode_modality(ModalityType.THERMAL, data.load_and_transform_thermal_data, thermal_paths)

    def encode_image(self, image_paths):
        """Encode images via ImageBind's VISION branch; see ``_encode_modality``."""
        return self._encode_modality(ModalityType.VISION, data.load_and_transform_vision_data, image_paths)

    # ----------------------------------------------------------------- training

    def prompt_wrap(self, img_embeds, input_ids, target_ids, attention_mask, prompt_start):
        '''Splice the modality embedding between the prompt prefix and the text.

        Args:
            img_embeds: bsz x 1 x embed_dim modality embeddings.
            input_ids, target_ids, attention_mask: bsz x s2.
            prompt_start: text prefix placed before the modality embedding.

        Returns:
            (inputs_embeds, targets, attention_mask), each spanning
            bsz x (1 + s1 + 1 + s2): bos, prefix (s1 tokens), modality, text.
        '''
        input_ids = input_ids.to(self.device)  # bsz x s2
        target_ids = target_ids.to(self.device)  # bsz x s2
        attention_mask = attention_mask.to(self.device)  # bsz x s2

        batch_size = img_embeds.shape[0]
        p_before_ids = self._tokenize(prompt_start)
        p_before_embeds = self._embed_tokens(p_before_ids).expand(batch_size, -1, -1)  # bsz x s1 x embed_dim
        p_after_embeds = self._embed_tokens(input_ids).expand(batch_size, -1, -1)  # bsz x s2 x embed_dim
        bos_embeds = self._bos_embeds(batch_size, p_before_ids)  # bsz x 1 x embed_dim
        inputs_embeds = torch.cat([bos_embeds, p_before_embeds, img_embeds, p_after_embeds], dim=1)

        # bos + prefix + modality embedding receive no loss and full attention.
        prefix_len = 1 + p_before_embeds.size(1) + 1
        empty_targets = torch.full([batch_size, prefix_len], -100, dtype=torch.long, device=self.device)
        targets = torch.cat([empty_targets, target_ids], dim=1)
        assert inputs_embeds.size(1) == targets.size(1)

        atts_prefix = torch.ones([batch_size, prefix_len], dtype=torch.long, device=self.device)
        attention_mask = torch.cat([atts_prefix, attention_mask], dim=1)
        assert attention_mask.size() == targets.size()
        return inputs_embeds, targets, attention_mask

    def forward(self, inputs):
        '''Compute the LM loss and token accuracy for one training batch.

        Exactly one modality is encoded, chosen in the order image, audio,
        video, thermal, and the prompt template is chosen in the same branch.
        Audio uses the audio template; image, video and thermal use the image
        template. ``inputs['output_texts']`` holds the conversations in the
        format consumed by ``build_one_instance``.

        Returns:
            (loss, gen_acc): the decoder loss and the fraction of supervised
            tokens predicted correctly (0.0 when none remain after truncation).

        Raises:
            ValueError: no modality paths were provided.
        '''
        if inputs.get('image_paths', None):
            img_embeds, _ = self.encode_image(inputs['image_paths'])
            prompt_start = self._IMAGE_PROMPT_START
        elif inputs.get('audio_paths', None):
            img_embeds, _ = self.encode_audio(inputs['audio_paths'])
            prompt_start = self._AUDIO_PROMPT_START
        elif inputs.get('video_paths', None):
            img_embeds, _ = self.encode_video(inputs['video_paths'])
            prompt_start = self._IMAGE_PROMPT_START
        elif inputs.get('thermal_paths', None):
            img_embeds, _ = self.encode_thermal(inputs['thermal_paths'])
            prompt_start = self._IMAGE_PROMPT_START
        else:
            raise ValueError("No valid modality input provided.")

        input_ids, target_ids, attention_mask = process_batch_instance(
            self.llama_tokenizer, inputs['output_texts'], self.max_tgt_len, prompt_start=prompt_start)
        inputs_embeds, targets, attention_mask = self.prompt_wrap(
            img_embeds, input_ids, target_ids, attention_mask, prompt_start=prompt_start)

        outputs = self.llama_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
        )
        loss = outputs.loss

        # Token accuracy: logits at position t predict token t + 1, so
        # predictions [1:-1] line up with targets [2:].
        chosen_tokens = torch.max(outputs.logits, dim=-1)[1][:, 1:-1]
        labels = targets[:, 2:]
        gen_acc = (chosen_tokens.reshape(-1) == labels.reshape(-1)).to(torch.long)
        valid_mask = (labels != -100).reshape(-1)
        valid_tokens = gen_acc & valid_mask
        num_valid = valid_mask.sum().item()
        # Truncation can leave no supervised tokens; avoid dividing by zero.
        gen_acc = valid_tokens.sum().item() / num_valid if num_valid else 0.0
        return loss, gen_acc

    # --------------------------------------------------------------- generation

    def extract_multimodal_feature(self, inputs):
        '''Encode every provided modality and fuse them by summation.

        Missing or empty '*_paths' entries are skipped.

        Returns:
            1 x 1 x llama_size tensor: the sum over all encoded embeddings.

        Raises:
            ValueError: no modality paths were provided.
        '''
        encoders = (
            ('image_paths', self.encode_image),
            ('audio_paths', self.encode_audio),
            ('video_paths', self.encode_video),
            ('thermal_paths', self.encode_thermal),
        )
        features = []
        for key, encode in encoders:
            paths = inputs.get(key)
            if paths:
                embeds, _ = encode(paths)
                features.append(embeds)
        if not features:
            raise ValueError("No valid modality input provided.")
        return torch.cat(features).sum(dim=0).unsqueeze(0)

    def prepare_generation_embedding(self, inputs):
        '''Build the decoder input embeddings for generation.

        ``inputs['modality_embeds']`` is treated as a list-like cache. If it
        holds exactly one tensor, that tensor is reused. Otherwise features are
        extracted and appended to it (mutating the caller's object).

        The audio template is used whenever 'audio_paths' is given; every other
        case uses the image template. The templates match those used in
        training by build_one_instance, including the '</Audio>' closing tag.

        Returns:
            bsz x (1 + s1 + 1 + s2) x embed_dim tensor.
        '''
        prompt = inputs['prompt']
        if len(inputs['modality_embeds']) == 1:
            feature_embeds = inputs['modality_embeds'][0]
        else:
            feature_embeds = self.extract_multimodal_feature(inputs)
            inputs['modality_embeds'].append(feature_embeds)

        batch_size = feature_embeds.shape[0]
        if inputs.get('audio_paths', None):
            p_before = self._AUDIO_PROMPT_START
            text = '</Audio> Please answer briefly:' + prompt + '\n### Assistant:'
        else:
            p_before = self._IMAGE_PROMPT_START
            text = '</Img> Answer the following question in one sentence:' + prompt + '\n### Assistant:'

        p_before_ids = self._tokenize(p_before)
        p_before_embeds = self._embed_tokens(p_before_ids).expand(batch_size, -1, -1)  # bsz x s1 x embed_dim
        p_after_embeds = self._embed_tokens(self._tokenize(text)).expand(batch_size, -1, -1)  # bsz x s2 x embed_dim
        bos_embeds = self._bos_embeds(batch_size, p_before_ids)  # bsz x 1 x embed_dim
        return torch.cat([bos_embeds, p_before_embeds, feature_embeds, p_after_embeds], dim=1)

    def generate(self, inputs):
        '''Sample a text reply conditioned on the prompt and modality inputs.

            inputs = {
                'image_paths': optional,
                'audio_paths': optional,
                'video_paths': optional,
                'thermal_paths': optional,
                'prompt': human input prompt,
                'max_tgt_len': maximum number of new tokens,
                'top_p': top_p,
                'temperature': temperature,
                'modality_embeds': list used as a feature cache
                    (see prepare_generation_embedding),
            }
            'mode' and 'modality_cache' may also be present; this method does
            not read them.

        Returns the decoded reply of the first sequence.
        '''
        input_embeds = self.prepare_generation_embedding(inputs)
        stopping_criteria = StoppingCriteriaList(
            [StoppingCriteriaSub(stops=[self._STOP_TOKEN_ID], encounters=1)])
        outputs = self.llama_model.generate(
            inputs_embeds=input_embeds,
            max_new_tokens=inputs['max_tgt_len'],
            top_p=inputs['top_p'],
            temperature=inputs['temperature'],
            do_sample=True,
            use_cache=True,
            stopping_criteria=stopping_criteria,
        )
        output_ids = outputs[0]
        # Only trim when the stop token actually ended generation. In that case
        # drop it and the token before it (presumably the newline of the
        # '\n###' separator). Otherwise every token is real reply text.
        if output_ids.numel() > 0 and output_ids[-1].item() == self._STOP_TOKEN_ID:
            output_ids = output_ids[:-2]
        return self.llama_tokenizer.decode(output_ids, skip_special_tokens=True)
    



