#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.


import contextlib
import random
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast as autocast
from torch.nn import CrossEntropyLoss
import transformers
from transformers import AutoConfig, AutoModelForCausalLM, \
    LlamaConfig, LlamaModel, LlamaForCausalLM, \
    CLIPVisionModel, CLIPImageProcessor

from fastchat.model.caption_decoder import CaptionDecoder
# This shadows transformers.LlamaForCausalLM imported above, so it must stay
# after the `transformers` import.
from fastchat.model.modeling_llama import LlamaForCausalLM

# Special-token strings for image placeholders.
# NOTE(review): none of these is referenced in the code shown here (prompt_wrap
# splits on the literal '<ImageHere>' instead) — confirm whether other modules
# import them before changing or removing them.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"

from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)


def get_rand_des():
    """Return one short image-description instruction chosen at random.

    The choice uses the module-level ``random`` generator, so it is
    reproducible under ``random.seed``. Every entry of the list is eligible;
    the selection is tied to the list's length rather than a hard-coded index
    bound, so adding or removing prompts cannot cause an IndexError or leave
    entries unreachable.
    """
    text = ['Describe the image concisely.',
            'Provide a brief description of the given image.',
            'Offer a succinct explanation of the picture presented.',
            'Can you describe this image?',
            'Summarize the visual content of the image.',
            'Give a short and clear explanation of the subsequent image.',
            'Share a concise interpretation of the image provided.',
            'Present a compact description of the photo’s key features.',
            'Relay a brief, clear account of the picture shown.',
            'Render a clear and concise summary of the photo.',
            'Write a terse but informative summary of the picture.',
            'Create a compact narrative representing the image presented.']

    return random.choice(text)


def mos1(a, start_dim=1):
    """Mean of squares of tensor *a* over every dimension from *start_dim* onward.

    With the default ``start_dim=1`` each leading-dimension slice is reduced to
    a single value (e.g. a ``(B, ...)`` tensor yields shape ``(B,)``).
    """
    squared = a.square()
    collapsed = squared.flatten(start_dim=start_dim)
    return collapsed.mean(dim=-1)


def convert_weights_to_fp16(model: nn.Module):
    """Cast, in place, the ``weight``/``bias`` of every Conv1d/Conv2d/Linear submodule to fp16.

    Only modules that are instances of those classes (subclasses included) are
    touched; parameters owned directly by any other module keep their dtype.
    Returns ``None``.
    """
    half_types = (nn.Conv1d, nn.Conv2d, nn.Linear)

    def _to_half(module):
        if not isinstance(module, half_types):
            return
        module.weight.data = module.weight.data.half()
        if module.bias is not None:
            module.bias.data = module.bias.data.half()

    model.apply(_to_half)


class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features: size of the input's last dimension.
        hidden_features: hidden width; a falsy value (``None``/0) means ``in_features``.
        out_features: output width; a falsy value (``None``/0) means ``in_features``.
        act_layer: activation *class*, instantiated with no arguments.
        drop: dropout probability, applied after the activation and after ``fc2``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_dim = hidden_features or in_features
        output_dim = out_features or in_features
        # Registration order (fc1, act, fc2, drop) is kept so state_dict keys and
        # parameter-initialisation order stay the same for existing checkpoints.
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class DiffusionLlmForCausalLM(nn.Module):
    """LLaMA causal LM conditioned on caption-decoder features.

    Feature tensors are projected by ``fastchat_proj`` (an ``Mlp`` mapping 768
    to 4096 dims) into the language model's input-embedding space. They are
    spliced into a text prompt and passed to ``fastchat`` (a
    ``LlamaForCausalLM``) as ``inputs_embeds``, followed by the embedded target
    text.
    """

    @property
    def device(self):
        """Device of the first registered parameter.

        NOTE(review): assumes all parameters share one device.
        """
        # next() avoids building a list of every parameter just to read one.
        return next(self.parameters()).device

    def maybe_autocast(self, dtype=torch.bfloat16):
        """Return an autocast context when off CPU, else a no-op context.

        Args:
            dtype: autocast dtype used when the model is not on CPU
                (defaults to ``torch.bfloat16``).
        """
        if self.device != torch.device("cpu"):
            return torch.cuda.amp.autocast(dtype=dtype)
        return contextlib.nullcontext()

    def __init__(self, pretrained_path="/home/data2/xiangyu/Code/unidiffuser/models/caption_decoder.pth", hidden_dim=64,
                 freeze_llm=False, freeze_proj=True, precision="bf16", freeze_diffusion=True,
                 llm_path="/home/data2/xiangyu/Data/Vicuna-7b"):
        """Load the caption decoder, the LLaMA model and its tokenizer, and build the projector.

        Args:
            pretrained_path: checkpoint passed to ``CaptionDecoder``.
            hidden_dim: ``hidden_dim`` passed to ``CaptionDecoder``.
            freeze_llm: if true, freeze ONLY ``model.embed_tokens.weight`` of
                the LLM; all other LLM parameters stay trainable.
            freeze_proj: if true, disable gradients for ``fastchat_proj``.
            precision: ``"fp16"`` casts ``fastchat_proj``'s Linear weights to
                fp16; any other value leaves them in their default dtype.
            freeze_diffusion: accepted for compatibility but currently unused.
            llm_path: directory of the LLaMA weights and tokenizer (both are
                loaded from the same path).
        """
        super().__init__()
        self.caption_decoder = CaptionDecoder(device="cuda", pretrained_path=pretrained_path, hidden_dim=hidden_dim)
        print("finishing loading caption decoder")
        self.fastchat = LlamaForCausalLM.from_pretrained(
            pretrained_model_name_or_path=llm_path,
            torch_dtype=torch.float16,
        ).cuda()
        print("finishing loading fastchat")
        if freeze_llm:
            # Only the token-embedding table is frozen.
            for name, param in self.fastchat.named_parameters():
                if name == "model.embed_tokens.weight":
                    param.requires_grad = False
        params_grad = [n for n, p in self.fastchat.named_parameters() if p.requires_grad]
        print(params_grad)
        # NOTE(review): freeze_diffusion is not applied; the caption decoder's
        # requires_grad flags are whatever CaptionDecoder sets.
        self.fastchat_proj = Mlp(in_features=768, hidden_features=768 * 4, out_features=4096, act_layer=nn.GELU,
                                 drop=0.)
        if precision == "fp16":
            convert_weights_to_fp16(self.fastchat_proj)

        if freeze_proj:
            # requires_grad_ already recurses over every parameter of the module.
            self.fastchat_proj.requires_grad_(False)

        self.tokenizer = transformers.LlamaTokenizer.from_pretrained(
            pretrained_model_name_or_path=llm_path,
            use_fast=False,
        )
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def encode_img(self, image):
        """Encode *image* with the caption decoder and project it into LLM embedding space.

        Returns:
            ``(inputs_fastchat, atts_image)``: the projected features and an
            all-ones ``torch.long`` mask shaped like the features minus their
            last dimension.
        """
        device = image[-1].device
        caption_d64 = self.caption_decoder.encode_prefix(image)
        caption_d768 = self.caption_decoder.decode_prefix(caption_d64)
        # NOTE(review): unlike proj_image, this runs the projector outside autocast.
        inputs_fastchat = self.fastchat_proj(caption_d768)
        atts_image = torch.ones(inputs_fastchat.size()[:-1], dtype=torch.long, device=device)
        return inputs_fastchat, atts_image

    def proj_image(self, tmp):
        """Project already-decoded features *tmp* with ``fastchat_proj`` (under autocast off CPU).

        Returns:
            ``(inputs_fastchat, atts_image)``: the projected features and an
            all-ones ``torch.long`` mask shaped like the features minus their
            last dimension.
        """
        device = tmp[-1].device
        with self.maybe_autocast():
            inputs_fastchat = self.fastchat_proj(tmp)
        atts_image = torch.ones(inputs_fastchat.size()[:-1], dtype=torch.long, device=device)
        return inputs_fastchat, atts_image

    def prompt_wrap(self, img_embeds, atts_img, prompt):
        """Splice *img_embeds* into *prompt* at the ``<ImageHere>`` placeholder.

        Args:
            img_embeds: projected features; dim 0 is treated as batch and dim 1
                as sequence (they are concatenated with token embeddings on dim 1).
            atts_img: mask whose first column is broadcast over the wrapped length.
            prompt: text containing exactly one ``<ImageHere>``; a falsy prompt
                returns the inputs unchanged.

        Returns:
            ``(wrapped_img_embeds, wrapped_atts_img)``.

        Raises:
            ValueError: if *prompt* does not contain exactly one ``<ImageHere>``.
        """
        if not prompt:
            return img_embeds, atts_img
        placeholder = '<ImageHere>'
        if prompt.count(placeholder) != 1:
            raise ValueError(
                f"prompt must contain exactly one {placeholder!r}, got {prompt.count(placeholder)}: {prompt!r}"
            )
        batch_size = img_embeds.shape[0]
        p_before, p_after = prompt.split(placeholder)
        # NOTE(review): add_special_tokens is left at the tokenizer default here
        # (the label tokenisation in forward passes add_special_tokens=False). If
        # the default prepends BOS, each segment carries an extra BOS; confirm
        # against how checkpoints were trained before changing it.
        p_before_tokens = self.tokenizer(
            p_before, return_tensors="pt", max_length=None).to(img_embeds.device)
        p_after_tokens = self.tokenizer(
            p_after, return_tensors="pt", max_length=None).to(img_embeds.device)
        with self.maybe_autocast():
            p_before_embeds = self.fastchat.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
            p_after_embeds = self.fastchat.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
            wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
        wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
        return wrapped_img_embeds, wrapped_atts_img

    def get_model(self):
        """Return the wrapped ``LlamaForCausalLM``."""
        return self.fastchat

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            query=None,
            labels=None,
            original_labels=None,
    ):
        """Compute the LM loss for *labels* conditioned on *original_labels* features.

        Sequence layout is ``[BOS, prompt-with-features, label tokens]``; only
        the non-padding label tokens are supervised.

        Args:
            input_ids: unused; accepted for interface compatibility.
            attention_mask: ignored; rebuilt from the BOS, prompt and label masks.
            query: unused.
            labels: target text (str or list of str), tokenised without special
                tokens and truncated to 77 tokens.
            original_labels: feature tensor fed to ``fastchat_proj``; moved to CUDA.

        Returns:
            ``{"loss": loss}`` with the loss reported by ``fastchat``.
        """
        original_labels = original_labels.cuda()
        img_embeds, atts_img = self.proj_image(original_labels)
        vqa_prompt = '###Human: <Img><ImageHere></Img> '
        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, vqa_prompt)
        device = img_embeds.device
        batch_size = img_embeds.shape[0]

        self.tokenizer.padding_side = "right"
        to_regress_tokens = self.tokenizer(
            labels,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=77,
            add_special_tokens=False
        ).to(device)

        # Padding positions are excluded from the loss. NOTE(review): pad_token
        # is set to eos_token in __init__, so an EOS inside labels would be
        # masked too — confirm the model is meant to learn stopping another way.
        targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == self.tokenizer.pad_token_id, -100
        )
        # Ignore (-100) the BOS slot (+1) and every prompt/feature position.
        # Allocated on `device` (not the default "cuda" device) so it always
        # matches `targets` for the concatenation below.
        empty_targets = torch.full(
            (atts_img.shape[0], atts_img.shape[1] + 1), -100, dtype=torch.long, device=device
        )
        targets = torch.cat([empty_targets, targets], dim=1)

        bos = torch.full(
            (batch_size, 1),
            self.tokenizer.bos_token_id,
            dtype=to_regress_tokens.input_ids.dtype,
            device=to_regress_tokens.input_ids.device,
        )
        with self.maybe_autocast():
            bos_embeds = self.fastchat.model.embed_tokens(bos)
            to_regress_embeds = self.fastchat.model.embed_tokens(to_regress_tokens.input_ids)
        atts_bos = atts_img[:, :1]

        inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)
        with self.maybe_autocast():
            outputs = self.fastchat(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=targets
            )

        return {"loss": outputs.loss}

    def prepare_inputs_for_generation(
            self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Build the per-step model inputs dict for incremental generation.

        With a (truthy) cache only the last token id is kept. ``inputs_embeds``
        are used only when there is no cache (the first step); otherwise
        ``input_ids`` are passed. ``use_cache`` and ``images`` are forwarded
        from *kwargs* (``None`` when absent).
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]

        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs
