#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.


from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from fastchat.model.caption_decoder import CaptionDecoder
import transformers
from transformers import AutoConfig, AutoModelForCausalLM, \
    LlamaConfig, LlamaModel, LlamaForCausalLM, \
    CLIPVisionModel, CLIPImageProcessor

from fastchat.model.modeling_t5 import T5ForConditionalGeneration
import numpy as np
import random

# Special-token strings for image placeholders in text prompts.
# NOTE(review): the prompt-wrapping code in this file splits on '<ImageHere>'
# rather than any of these tokens; confirm whether they are still needed.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
# Label value to exclude from the loss; -100 is the default ignore_index of
# torch.nn.CrossEntropyLoss.
IGNORE_INDEX = -100
# Tokenizer special-token strings.
# NOTE(review): BOS and UNK both reuse the EOS string "</s>"; confirm this is intentional.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"

from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)


# Short image-description instructions sampled by get_rand_des().
_CAPTION_PROMPTS = (
    'Describe the image concisely.',
    'Provide a brief description of the given image.',
    'Offer a succinct explanation of the picture presented.',
    'Can you describe this image?',
    'Summarize the visual content of the image.',
    'Give a short and clear explanation of the subsequent image.',
    'Share a concise interpretation of the image provided.',
    'Present a compact description of the photo’s key features.',
    'Relay a brief, clear account of the picture shown.',
    'Render a clear and concise summary of the photo.',
    'Write a terse but informative summary of the picture.',
    'Create a compact narrative representing the image presented.',
)


def get_rand_des():
    """Return one image-description instruction chosen uniformly at random.

    Draws from the module-level ``random`` generator, so results are
    reproducible under ``random.seed``. The choice is bounded by the length
    of ``_CAPTION_PROMPTS``, so prompts can be added or removed without
    updating a hard-coded index.
    """
    return random.choice(_CAPTION_PROMPTS)


def mos1(a, start_dim=1):  # mean of square
    """Return the mean of squared entries of ``a`` over dims ``start_dim`` onward.

    Every dimension from ``start_dim`` to the end is flattened into one axis
    and averaged, so the result keeps only the dimensions before ``start_dim``.
    """
    squared = a ** 2
    return torch.flatten(squared, start_dim=start_dim).mean(dim=-1)


class Mlp(nn.Module):
    """Feed-forward block: Linear -> activation -> Dropout -> Linear -> Dropout.

    Args:
        in_features: Size of the input's last dimension.
        hidden_features: Hidden width; falls back to ``in_features`` when falsy.
        out_features: Output width; falls back to ``in_features`` when falsy.
        act_layer: Activation module class, instantiated with no arguments.
        drop: Dropout probability used after the activation and after ``fc2``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        width_out = out_features if out_features else in_features
        width_hidden = hidden_features if hidden_features else in_features
        # Attribute names (and their creation order) define state_dict keys and
        # parameter-initialization order, so they are kept stable.
        self.fc1 = nn.Linear(in_features, width_hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(width_hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the two linear layers to the last dimension of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class T5Load(object):
    """Load a FastChat-T5 checkpoint as a frozen, eval-mode model on ``device``.

    The loaded model is stored as ``self.fastchat``.

    Args:
        device: Target passed to ``.to()`` after loading.
        pretrained_path: Currently unused by this loader.
        hidden_dim: Currently unused by this loader.
        **kwargs: Extra keyword arguments forwarded to
            ``T5ForConditionalGeneration.from_pretrained`` (e.g. ``torch_dtype``).
            Do not pass ``pretrained_model_name_or_path`` or
            ``low_cpu_mem_usage``; they are already supplied.
    """

    def __init__(self, device, pretrained_path, hidden_dim=-1, **kwargs):
        # NOTE(review): the checkpoint path is hard-coded and ignores
        # ``pretrained_path``; confirm which checkpoint this loader should use.
        self.fastchat = T5ForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path="/data3/xiangyu/checkpoints_flant5_3b/checkpoint-1000",
            # pretrained_model_name_or_path="lmsys/fastchat-t5-3b-v1.0",
            low_cpu_mem_usage=True,
            **kwargs)
        # Freeze the model that was actually loaded: inference mode, target
        # device, and no gradients on any parameter.
        self.fastchat.eval()
        self.fastchat.to(device)
        self.fastchat.requires_grad_(False)


class DiffusionLlmForCausalLM(nn.Module):
    """Feed caption-decoder image features into a FastChat-T5 model.

    768-dim image features (from ``CaptionDecoder.decode_prefix`` or passed in
    directly) are projected to 2048 dims by ``fastchat_proj``. The projected
    features are concatenated with T5 token embeddings of a text prompt and
    passed to ``self.fastchat`` through ``inputs_embeds``.
    """

    def __init__(self, pretrained_path="/home/data2/xiangyu/Code/unidiffuser/models/caption_decoder.pth", hidden_dim=64,
                 freeze_llm=True, freeze_diffusion=True,
                 llm_path="/home/data2/xiangyu/Data/checkpoints_flant5_3b/checkpoint-1000",
                 tokenizer_path='lmsys/fastchat-t5-3b-v1.0', device="cuda"):
        """Load the caption decoder, the T5 model and the T5 tokenizer.

        Args:
            pretrained_path: Checkpoint path passed to ``CaptionDecoder``.
            hidden_dim: ``hidden_dim`` passed to ``CaptionDecoder``.
            freeze_llm: If true, disable gradients on every T5 parameter.
            freeze_diffusion: Currently unused; this constructor never freezes
                the caption decoder.
            llm_path: Path or hub id of the T5 checkpoint.
            tokenizer_path: Path or hub id of the T5 tokenizer.
            device: Device for the caption decoder and the T5 model.
        """
        super().__init__()
        self.caption_decoder = CaptionDecoder(device=device, pretrained_path=pretrained_path, hidden_dim=hidden_dim)
        print("finishing loading caption decoder")
        self.fastchat = T5ForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path=llm_path,
            low_cpu_mem_usage=True,
            use_cache=True,
        ).to(device)
        print("finishing loading fastchat")
        # NOTE: a later ``.train()`` on this module recursively switches T5 back
        # to training mode.
        self.fastchat = self.fastchat.eval()
        if freeze_llm:
            # requires_grad_ on the module already covers every parameter.
            self.fastchat.requires_grad_(False)
        # Registered as its own attribute, so the encoder's parameters also
        # appear under the ``encoder.`` prefix in this module's state_dict.
        self.encoder = self.fastchat.get_encoder()
        # Maps 768-dim caption features to the 2048-dim T5 embedding width.
        self.fastchat_proj = Mlp(in_features=768, hidden_features=768 * 4, out_features=2048,
                                 act_layer=nn.GELU, drop=0.)
        self.tokenizer = transformers.T5Tokenizer.from_pretrained(
            pretrained_model_name_or_path=tokenizer_path,
            model_max_length=2048,
            padding_side="right",
            use_fast=False,
        )

    @staticmethod
    def _ones_mask(embeds, device):
        """Return an all-ones ``torch.long`` mask of shape ``embeds.shape[:-1]`` on ``device``."""
        return torch.ones(embeds.shape[:-1], dtype=torch.long, device=device)

    def _tokenize(self, text, device):
        """Tokenize ``text`` as a one-sequence batch of PyTorch tensors on ``device``."""
        # NOTE(review): default special-token handling is kept; T5 tokenizers
        # typically append '</s>' to every fragment — confirm that is intended.
        return self.tokenizer(text, return_tensors="pt", max_length=None).to(device)

    def encode_img(self, image):
        """Encode ``image`` with the caption decoder and project it for T5.

        Returns:
            ``(inputs_fastchat, atts_image)``: projected features and an
            all-ones ``torch.long`` mask of shape ``inputs_fastchat.shape[:-1]``
            placed on ``image``'s device.
        """
        device = image[-1].device
        caption_d64 = self.caption_decoder.encode_prefix(image)
        caption_d768 = self.caption_decoder.decode_prefix(caption_d64)
        inputs_fastchat = self.fastchat_proj(caption_d768)
        return inputs_fastchat, self._ones_mask(inputs_fastchat, device)

    def proj_image(self, tmp):
        """Project precomputed 768-dim image features into the T5 embedding space.

        Returns:
            ``(inputs_fastchat, atts_image)``: projected features and an
            all-ones ``torch.long`` mask of shape ``inputs_fastchat.shape[:-1]``
            placed on ``tmp``'s device.
        """
        device = tmp[-1].device
        inputs_fastchat = self.fastchat_proj(tmp)
        return inputs_fastchat, self._ones_mask(inputs_fastchat, device)

    def prompt_wrap(self, img_embeds, atts_img, prompt):
        """Surround ``img_embeds`` with the embedded text of ``prompt``.

        ``prompt`` must contain exactly one ``'<ImageHere>'`` marker (otherwise
        unpacking the split raises ``ValueError``). The text on each side is
        tokenized, embedded and broadcast over the batch. A falsy ``prompt``
        returns the inputs unchanged.

        Returns:
            ``(wrapped_embeds, wrapped_mask)``; the mask repeats the first
            column of ``atts_img`` across the whole wrapped length.
        """
        if not prompt:
            return img_embeds, atts_img
        device = img_embeds.device
        batch_size = img_embeds.shape[0]
        p_before, p_after = prompt.split('<ImageHere>')
        embed_tokens = self.fastchat.get_input_embeddings()
        p_before_embeds = embed_tokens(self._tokenize(p_before, device).input_ids).expand(batch_size, -1, -1)
        p_after_embeds = embed_tokens(self._tokenize(p_after, device).input_ids).expand(batch_size, -1, -1)
        wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
        wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
        return wrapped_img_embeds, wrapped_atts_img

    def qa_prompt_wrap(self, img_embeds, atts_img, prompt, labels):
        """Build the encoder input ``[prefix, image, middle, question + suffix]``.

        ``prompt`` must contain one ``'<ImageHere>'`` marker, followed later by
        one ``'<question>'`` marker. Each per-sample token-id tensor in
        ``labels`` (the question) gets the tokenized text after ``'<question>'``
        appended. The results are right-padded with the tokenizer's pad id and
        embedded. The caller's ``labels`` sequence is not modified. A falsy
        ``prompt`` returns ``(img_embeds, atts_img)`` unchanged.

        Returns:
            ``(wrapped_embeds, attention_mask)``. The mask is ones for the
            prompt fragments, ``atts_img`` for the image, and ``0`` on question
            padding positions.
        """
        if not prompt:
            return img_embeds, atts_img
        device = img_embeds.device
        batch_size = img_embeds.shape[0]
        pad_id = self.tokenizer.pad_token_id
        p_before, p_after = prompt.split('<ImageHere>')
        a, b = p_after.split('<question>')
        p_before_tokens = self._tokenize(p_before, device)
        p_after_tokens = self._tokenize(a, device)
        suffix_ids = self._tokenize(b, device).input_ids[0]

        # Build new tensors rather than assigning into ``labels``, so the
        # caller's list is left untouched.
        questions = torch.nn.utils.rnn.pad_sequence(
            [torch.cat([q.to(device), suffix_ids]) for q in labels],
            batch_first=True,
            padding_value=pad_id,
        )
        question_mask = questions.ne(pad_id)
        a_mask = p_before_tokens.attention_mask.expand(batch_size, -1)
        b_mask = p_after_tokens.attention_mask.expand(batch_size, -1)
        attention_mask = torch.cat([a_mask, atts_img, b_mask, question_mask], dim=1)

        embed_tokens = self.fastchat.get_input_embeddings()
        p_before_embeds = embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
        p_after_embeds = embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
        question_embeds = embed_tokens(questions)
        wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds, question_embeds], dim=1)
        return wrapped_img_embeds, attention_mask

    def get_model(self):
        """Return the wrapped T5 model."""
        return self.fastchat

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            encode_text: Optional[torch.Tensor] = None,
            labels=None,
            original_labels=None,
    ):
        """Run T5 on a VQA prompt built around projected image features.

        Args:
            input_ids: Unused; kept for interface compatibility.
            attention_mask: Unused; the mask is rebuilt by ``qa_prompt_wrap``.
            encode_text: Precomputed 768-dim image features, projected by
                ``proj_image``.
            labels: Per-sample question token-id tensors spliced into the
                encoder input (not the training targets).
            original_labels: Target token ids passed to T5 as ``labels``.

        Returns:
            The output object of ``self.fastchat`` (loss and logits included).
        """
        img_embeds, atts_image = self.proj_image(encode_text)
        vqa_prompt = '###Human: Please answer question from this image: <Img><ImageHere></Img> ' \
                     + ' Question: <question>' + '\n' + '### Assistant: '
        img_embeds, atts_image = self.qa_prompt_wrap(img_embeds, atts_image, vqa_prompt, labels)
        return self.fastchat(
            inputs_embeds=img_embeds,
            attention_mask=atts_image,
            labels=original_labels,
        )

    def prepare_inputs_for_generation(
            self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Assemble per-step model inputs for autoregressive generation.

        When ``past_key_values`` is truthy, only the last token of
        ``input_ids`` is kept. ``inputs_embeds`` is used instead of
        ``input_ids`` only when there is no cache (``past_key_values is None``).
        ``use_cache`` and ``images`` are read from ``kwargs`` and default to
        ``None``.
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]

        # Use ``inputs_embeds`` only on the first step; later steps feed token ids.
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs
