# Copyright (c) 2024 Bytedance Ltd. and/or its affiliate
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np
import transformers
from transformers import AutoConfig, AutoModelForCausalLM
from logging import getLogger
from typing import Optional
from REC.utils.enum_type import InputType
from REC.model.basemodel import BaseModel, all_gather
from REC.model.HLLM.modeling_llama import LlamaForCausalLM
from REC.model.HLLM.modeling_mistral import MistralForCausalLM
from REC.model.HLLM.modeling_bert import BertModel
from REC.model.HLLM.baichuan.modeling_baichuan import BaichuanForCausalLM
from REC.model.HLLM.modeling_qwen2_vl import Qwen2VLForConditionalGeneration


class HLLM_V(BaseModel):
    input_type = InputType.SEQ

    def __init__(self, config, dataload):
        """Build both LLMs, the optional item summary token and the NCE settings.

        Args:
            config: Mapping-like configuration object. Keys read here:
                ``item_pretrain_dir``, ``user_pretrain_dir``,
                ``gradient_checkpointing``, ``use_ft_flash_attn``,
                ``item_llm_init``, ``user_llm_init``, ``item_emb_token_n``,
                ``item_emb_pretrain`` (only when a summary token is used),
                ``loss``, ``nce_thres``, ``num_negatives`` and
                ``load_pretrain``.
            dataload: Not used by this constructor.

        Raises:
            NotImplementedError: If ``item_emb_token_n`` is larger than 1, or
                if ``loss`` is anything other than ``'nce'``.
        """
        super(HLLM_V, self).__init__()
        self.logger = getLogger()

        # Settings consumed by create_llm() must be in place before it is called.
        self.item_pretrain_dir = config['item_pretrain_dir']
        self.user_pretrain_dir = config['user_pretrain_dir']
        self.gradient_checkpointing = config['gradient_checkpointing']
        self.use_ft_flash_attn = config['use_ft_flash_attn']

        self.logger.info("create item llm")
        self.item_llm = self.create_llm(self.item_pretrain_dir, config['item_llm_init'])
        self.logger.info("create user llm")
        self.user_llm = self.create_llm(self.user_pretrain_dir, config['user_llm_init'])

        token_count = config['item_emb_token_n']
        self.item_emb_token_n = token_count
        if token_count > 1:
            raise NotImplementedError(f"Not support item_emb_token_n {self.item_emb_token_n} > 1")

        if token_count <= 0:
            # No learnable summary token (the "mean pooling" configuration).
            self.item_emb_tokens = None
        else:
            # Learnable summary token of shape (1, token_count, hidden_size),
            # drawn from N(0, 0.02^2).
            summary_token = nn.Parameter(
                torch.zeros(1, token_count, self.item_llm.config.hidden_size)
            )
            summary_token.data.normal_(mean=0.0, std=0.02)
            self.item_emb_tokens = summary_token

            token_ckpt_path = config['item_emb_pretrain']
            if token_ckpt_path:
                # NOTE(review): torch.load unpickles the file; only point this
                # at trusted checkpoints.
                pretrained_token = torch.load(token_ckpt_path, map_location='cpu')
                self.logger.info(f"load item_emb_token from {token_ckpt_path} with {pretrained_token.size()}")
                self.item_emb_tokens.data = nn.Parameter(pretrained_token)

        self.loss = config['loss']
        if self.loss != 'nce':
            raise NotImplementedError("Only nce is supported")

        # Learnable log-temperature, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        # A falsy configured threshold (None, 0, ...) falls back to 0.99.
        self.nce_thres = config['nce_thres'] or 0.99
        self.num_negatives = config['num_negatives']
        self.logger.info(f"nce thres setting to {self.nce_thres}")

        full_ckpt_path = config['load_pretrain']
        if full_ckpt_path:
            state_dict = torch.load(full_ckpt_path, map_location="cpu")
            # strict=False: key mismatches are logged rather than raised.
            msg = self.load_state_dict(state_dict, strict=False)
            self.logger.info(f"{msg.missing_keys = }")
            self.logger.info(f"{msg.unexpected_keys = }")

    def create_llm(self, pretrain_dir, init=True):
        """Instantiate a language model from ``pretrain_dir``.

        The Hugging Face config is loaded from ``pretrain_dir`` and adjusted so
        that the model uses the configured gradient-checkpointing flag, does
        not use the KV cache, and returns hidden states in a dict-style output.

        Llama, Mistral, BERT, Baichuan and Qwen2-VL configs are mapped to the
        model classes imported by this file; any other config falls back to
        ``AutoModelForCausalLM.from_pretrained``.

        Args:
            pretrain_dir: Directory (or hub id) holding the config and weights.
            init: When truthy, load the pretrained weights. Otherwise build the
                model from the config alone and move it to CUDA. Ignored by
                the ``AutoModelForCausalLM`` fallback, which always loads
                weights.

        Returns:
            The constructed model.
        """
        self.logger.info(f"******* create LLM {pretrain_dir} *******")
        hf_config = AutoConfig.from_pretrained(pretrain_dir, trust_remote_code=True)
        self.logger.info(f"hf_config: {hf_config}")
        hf_config.gradient_checkpointing = self.gradient_checkpointing
        hf_config.use_cache = False
        hf_config.output_hidden_states = True
        hf_config.return_dict = True

        self.logger.info("xxxxx starting loading checkpoint")

        # Pick the model class. The elif chain is kept lazy so a config
        # class is only looked up on `transformers` when earlier checks fail.
        if isinstance(hf_config, transformers.LlamaConfig):
            model_cls, family = LlamaForCausalLM, 'llama'
        elif isinstance(hf_config, transformers.MistralConfig):
            model_cls, family = MistralForCausalLM, 'mistral'
        elif isinstance(hf_config, transformers.BertConfig):
            model_cls, family = BertModel, 'bert'
        elif getattr(hf_config, "model_type", None) == "baichuan":
            model_cls, family = BaichuanForCausalLM, 'baichuan'
        elif isinstance(hf_config, transformers.Qwen2VLConfig):
            model_cls, family = Qwen2VLForConditionalGeneration, 'qwen2vl'
        else:
            # Fixed: this branch previously passed `self.local_dir`, an
            # attribute that is never assigned in this class.
            return AutoModelForCausalLM.from_pretrained(
                pretrain_dir, config=hf_config
            )

        hf_config.use_ft_flash_attn = self.use_ft_flash_attn
        self.logger.info(f'Using flash attention {hf_config.use_ft_flash_attn} for {family}')
        self.logger.info(f'Init {init} for {family}')
        if init:
            return model_cls.from_pretrained(pretrain_dir, config=hf_config)
        # Randomly initialised model; only this path moves the model to CUDA.
        return model_cls(config=hf_config).cuda()


    def nce_loss(self, cur_embs, target_pos, target_neg, user_attention_mask):
        """Build contrastive logits and labels for the NCE objective.

        All embeddings are L2-normalised along the last dimension. Negatives
        are gathered across processes via ``all_gather`` and flattened into a
        single pool. Any negative whose similarity to the positive target
        exceeds ``self.nce_thres`` is masked to the dtype's minimum value.

        Args:
            cur_embs: Predicted (user-side) embeddings.
            target_pos: Positive target embeddings, aligned with ``cur_embs``.
            target_neg: Negative item embeddings; the last dimension is the
                embedding size.
            user_attention_mask: Mask whose truthy entries select the rows of
                the logits that are kept.

        Returns:
            Tuple ``(logits, labels)``: the positive logit occupies column 0
            of ``logits`` and ``labels`` is an all-zero int64 vector with one
            entry per kept row.
        """
        # Bound the learnable log-temperature in place, outside autograd.
        with torch.no_grad():
            self.logit_scale.clamp_(0, np.log(100))
        temperature = self.logit_scale.exp()

        emb_dim = target_neg.size(-1)

        query = cur_embs / cur_embs.norm(dim=-1, keepdim=True)
        positives = target_pos / target_pos.norm(dim=-1, keepdim=True)
        pos_logits = F.cosine_similarity(query, positives, dim=-1).unsqueeze(-1)

        negatives = target_neg / target_neg.norm(dim=-1, keepdim=True)

        # Pool of negatives from every process, laid out as [dim, num].
        gathered = all_gather(negatives, sync_grads=True)
        neg_pool = gathered.reshape(-1, emb_dim).transpose(-1, -2)

        neg_logits = torch.matmul(query, neg_pool)
        pos_to_neg_sim = torch.matmul(positives, neg_pool)
        # Negatives nearly identical to the positive are treated as false
        # negatives and pushed to the lowest representable score.
        too_similar = pos_to_neg_sim > self.nce_thres
        neg_logits[too_similar] = torch.finfo(neg_logits.dtype).min

        all_logits = torch.cat([pos_logits, neg_logits], dim=-1)
        kept = user_attention_mask.bool()
        logits = all_logits[kept] * temperature
        labels = torch.zeros(logits.size(0), device=logits.device, dtype=torch.int64)
        return logits, labels

    def forward_item_emb(
        self,
        input_ids,
        position_ids,
        cu_input_lens,
        emb_token_n,
        emb_tokens,
        llm,
        pixel_values:Optional[torch.tensor] = None,
        image_grid_thw:Optional[torch.tensor] = None,
    ):
        # inputs_embeds = llm.get_input_embeddings()(input_ids)
        # emb_pos = cu_input_lens.cumsum(dim=0, dtype=torch.int32)
        # if emb_token_n > 0: #这个放到qwen2 forward里
        #     inputs_embeds[emb_pos - 1] = emb_tokens 
        # if pixel_values == None:
        #     model_out = llm(
        #     inputs_embeds=inputs_embeds.unsqueeze(0),
        #     #input_ids=input_ids.unsqueeze(0),
        #     position_ids=position_ids.unsqueeze(0),
        #     cu_input_lens=cu_input_lens,
        #     )
        emb_pos = cu_input_lens.cumsum(dim=0, dtype=torch.int32)
        # inputs_embeds = llm.get_input_embeddings()(input_ids)
        # if emb_token_n > 0: #这个放到qwen2 forward里
        #     inputs_embeds[emb_pos - 1] = emb_tokens 
        # model_out = llm(
        #         inputs_embeds=inputs_embeds.unsqueeze(0),
        #         #input_ids=input_ids.unsqueeze(0),
        #         position_ids=position_ids.unsqueeze(0),
        #         cu_input_lens=cu_input_lens,
        #         #pixel_values=pixel_values.unsqueeze(0),
        #         #pixel_values=pixel_values,
        #         #image_grid_thw=image_grid_thw.unsqueeze(0),
        #         #image_grid_thw=image_grid_thw
        #     )
        if pixel_values.shape[0] == 0: #纯文本，不用image
            inputs_embeds = llm.get_input_embeddings()(input_ids)
            if emb_token_n > 0: #这个放到qwen2 forward里
                inputs_embeds[emb_pos - 1] = emb_tokens 
            model_out = llm(
                inputs_embeds=inputs_embeds.unsqueeze(0),
                #input_ids=input_ids.unsqueeze(0),
                position_ids=position_ids.unsqueeze(0),
                cu_input_lens=cu_input_lens,
                #pixel_values=pixel_values.unsqueeze(0),
                #pixel_values=pixel_values,
                #image_grid_thw=image_grid_thw.unsqueeze(0),
                #image_grid_thw=image_grid_thw
            )
        else: #用image
            model_out = llm(
                #inputs_embeds=inputs_embeds.unsqueeze(0),
                input_ids=input_ids,
                position_ids=position_ids.unsqueeze(0),
                cu_input_lens=cu_input_lens,
                #pixel_values=pixel_values.unsqueeze(0),
                pixel_values=pixel_values,
                #image_grid_thw=image_grid_thw.unsqueeze(0),
                image_grid_thw=image_grid_thw,
                emb_tokens = emb_tokens,
                emb_pos = emb_pos
            )
        model_out = model_out.hidden_states[-1].squeeze(0)

        if emb_token_n > 0:
            emb = model_out[emb_pos - 1]
        # else:
        #     max_len = cu_input_lens.max().item()
        #     cu_seqlens = F.pad(cu_input_lens.cumsum(dim=0, dtype=torch.int32), (1, 0))
        #     seqs = [model_out[start:end] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])]
        #     padded_seqs = [
        #         F.pad(
        #             seqs[i],
        #             (0, 0) * (seqs[i].dim() - 1) + (0, max_len - cu_input_lens[i]),
        #             value=0.0,
        #         )
        #         for i in range(cu_input_lens.size(0))
        #     ]
        #     out = torch.stack(padded_seqs)
        #     emb = out.sum(dim=1) / cu_input_lens.unsqueeze(1)

        return emb

    def forward(self, interaction, mode='train'): #model = "train"
        if mode == 'predict':
            return self.predict(interaction[0], interaction[1], interaction[2])
        if mode == 'compute_item':
            return self.compute_item(interaction)
        user_attention_mask = interaction['attention_mask']
        N, S = user_attention_mask.shape
        pos_input_ids, pos_cu_input_lens, pos_position_ids, pos_pixel_values, pos_image_grid_thw = interaction['pos_input_ids'], interaction['pos_cu_input_lens'], interaction['pos_position_ids'], interaction['pos_pixel_values'], interaction['pos_image_grid_thw']
        #pos_input_ids, pos_cu_input_lens, pos_position_ids = interaction['pos_input_ids'], interaction['pos_cu_input_lens'], interaction['pos_position_ids']
        neg_input_ids, neg_cu_input_lens, neg_position_ids, neg_pixel_values, neg_image_grid_thw = interaction['neg_input_ids'], interaction['neg_cu_input_lens'], interaction['neg_position_ids'], interaction['neg_pixel_values'], interaction['neg_image_grid_thw']
        #neg_input_ids, neg_cu_input_lens, neg_position_ids = interaction['neg_input_ids'], interaction['neg_cu_input_lens'], interaction['neg_position_ids']
        pos_embedding = self.forward_item_emb(pos_input_ids, pos_position_ids, pos_cu_input_lens, self.item_emb_token_n, self.item_emb_tokens, self.item_llm,pos_pixel_values, pos_image_grid_thw)
        #pos_embedding = self.forward_item_emb(pos_input_ids, pos_position_ids, pos_cu_input_lens, self.item_emb_token_n, self.item_emb_tokens, self.item_llm)
        pos_embedding = pos_embedding.reshape(N, S+1, -1)
        #neg_embedding = self.forward_item_emb(neg_input_ids, neg_position_ids, neg_cu_input_lens, self.item_emb_token_n, self.item_emb_tokens, self.item_llm)
        neg_embedding = self.forward_item_emb(neg_input_ids, neg_position_ids, neg_cu_input_lens, self.item_emb_token_n, self.item_emb_tokens, self.item_llm,neg_pixel_values, neg_image_grid_thw)
        neg_embedding = neg_embedding.reshape(N, -1, self.item_llm.config.hidden_size)

        target_pos_embs = pos_embedding[:, 1:]
        target_neg_embs = neg_embedding

        user_embedding = self.user_llm(inputs_embeds=pos_embedding[:, :-1], attention_mask=user_attention_mask).hidden_states[-1]

        model_out = {}
        logits, labels = self.nce_loss(user_embedding, target_pos_embs, target_neg_embs, user_attention_mask)
        model_out['loss'] = F.cross_entropy(logits, labels)
        model_out['nce_samples'] = (logits > torch.finfo(logits.dtype).min/100).sum(dim=1).float().mean()  # samples after filtering same negatives
        for k in [1, 5, 10, 50, 100]:
            if k > logits.size(1):
                break
            indices = logits.topk(k, dim=1).indices
            model_out[f"nce_top{k}_acc"] = labels.view(-1, 1).eq(indices).any(dim=1).float().mean()
        return model_out

    @torch.no_grad()
    def predict(self, item_seq, time_seq, item_feature):
        attention_mask = (item_seq > 0).int()

        #转化为user_llm的 dtype
        pos_embedding = item_feature[item_seq]

        user_embedding = self.user_llm(inputs_embeds=pos_embedding, attention_mask=attention_mask).hidden_states[-1]
        seq_output = user_embedding[:, -1]
        seq_output = seq_output / seq_output.norm(dim=-1, keepdim=True)
        item_feature = item_feature / item_feature.norm(dim=-1, keepdim=True)

        return torch.matmul(seq_output, item_feature.t())

    @torch.no_grad()
    def compute_item_all(self):
        return self.item_embedding.weight

    @torch.no_grad()
    def compute_item(self, interaction):#这一块没改
        pos_input_ids, pos_cu_input_lens, pos_position_ids, pos_pixel_values, pos_image_grid_thw = interaction['pos_input_ids'], interaction['pos_cu_input_lens'], interaction['pos_position_ids'], interaction['pos_pixel_values'], interaction['pos_image_grid_thw']
        pos_embedding = self.forward_item_emb(pos_input_ids, pos_position_ids, pos_cu_input_lens, self.item_emb_token_n, self.item_emb_tokens, self.item_llm,pos_pixel_values, pos_image_grid_thw)
        N = pos_cu_input_lens.size(0)
        pos_embedding = pos_embedding.view(N, -1)

        return pos_embedding
