#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
import argparse
import os
import shutil
import sys
import time
from functools import partial
from knn_cuda import KNN
import deepspeed
import numpy as np
import torch
import tqdm
import transformers
from peft import LoraConfig, get_peft_model
from einops import rearrange
from llava import conversation as conversation_lib

from llava.model.Uni3D.models import uni3d as modelss

from typing import List, Optional, Tuple, Union

import torch.nn as nn
from torch.nn import CrossEntropyLoss

from transformers import AutoConfig, AutoModelForCausalLM, \
                         LlamaConfig, LlamaModel, LlamaForCausalLM

from transformers.modeling_outputs import CausalLMOutputWithPast

from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
from llava.model.language_model.llava_llama import (LlavaLlamaForCausalLM,
                                                     LlavaLlamaModel)


import torch
import torch.nn.functional as F
from time import time
import numpy as np
from utils import *

class LisaMetaModel:
    """Mixin that adds a frozen Uni3D point encoder and a text projection head.

    Meant to be combined with a LLaVA model class (see ``LisaModel``): its
    ``__init__`` forwards ``config`` to the next class in the MRO and then
    builds the extra modules via ``initialize_lisa_modules``.
    """

    # Default location of the pretrained Uni3D-B checkpoint; override it by
    # setting ``uni3d_ckpt_path`` on the model config.
    DEFAULT_UNI3D_CKPT_PATH = "./models/Uni3D/modelzoo/uni3d-b/model.pt"

    def __init__(
        self,
        config,
        **kwargs,
    ):
        super(LisaMetaModel, self).__init__(config)

        self.config = config
        self.initialize_lisa_modules(self.config)

    def initialize_lisa_modules(self, config):
        """Create the Uni3D point backbone and the ``text_hidden_fcs`` head.

        The Uni3D weights are read from ``config.uni3d_ckpt_path`` when that
        attribute is set, otherwise from ``DEFAULT_UNI3D_CKPT_PATH``. Every
        backbone parameter is frozen. ``text_hidden_fcs`` maps LM hidden
        states of width ``config.hidden_size`` to 256 dims and is trainable.

        Args:
            config: Model config; ``hidden_size`` (and optionally
                ``uni3d_ckpt_path``) are read.
        """
        self.point_model = modelss.create_uni3d()

        ckpt_path = getattr(config, "uni3d_ckpt_path", None) or self.DEFAULT_UNI3D_CKPT_PATH
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        sd = checkpoint["module"]
        # Checkpoints saved from a DDP-wrapped model prefix keys with "module.";
        # strip exactly that prefix (and only from keys that carry it).
        prefix = "module."
        if sd and next(iter(sd)).startswith(prefix):
            sd = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in sd.items()}
        self.point_model.load_state_dict(sd)

        # Freeze the whole point backbone.
        for param in self.point_model.parameters():
            param.requires_grad = False

        print("using Uni3D as the point backbone!")

        # Projection from LM hidden states to the 256-d segmentation embedding.
        in_dim = config.hidden_size
        out_dim = 256
        text_fc = [
            nn.Linear(in_dim, in_dim),
            nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim),
            nn.Dropout(0.0),
        ]
        self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)])
        self.text_hidden_fcs.train()
        for param in self.text_hidden_fcs.parameters():
            param.requires_grad = True

class LisaModel(LisaMetaModel, LlavaLlamaModel):
    """LLaVA-Llama backbone augmented with the LISA point-cloud modules.

    After the parent constructors have run, a fixed set of multimodal options
    is written onto ``self.config``, overriding whatever it held before.
    """

    def __init__(
        self,
        config,
        **kwargs,
    ):
        super().__init__(config, **kwargs)

        # (attribute, forced value) pairs, applied in this order.
        forced_settings = (
            ("mm_vision_select_feature", "patch"),
            ("image_aspect_ratio", "square"),
            ("image_grid_pinpoints", None),
            ("tune_mm_mlp_adapter", False),
            ("freeze_mm_mlp_adapter", True),
            ("pretrain_mm_mlp_adapter", None),
            ("mm_use_im_patch_token", False),
        )
        cfg = self.config
        for attr_name, attr_value in forced_settings:
            setattr(cfg, attr_name, attr_value)

class LISAForCausalLM(LlavaLlamaForCausalLM):
    """LLaVA-Llama causal LM with a point-cloud segmentation head.

    Hidden states at ``[SEG]`` positions, Uni3D point features
    (``get_visual_embs``) and text embeddings of per-sample segment-type token
    ids are combined to predict per-point masks.
    """

    def __init__(
        self,
        config,
        extra_seg_token=False,
        **kwargs,
    ):
        """Build the model.

        Args:
            config: Model config; ``hidden_size`` and ``vocab_size`` are read.
            extra_seg_token: When True, ``model_forward`` inserts a placeholder
                token after every ``[SEG]`` token and reads the hidden state at
                the shifted position instead.
            **kwargs: Must contain ``seg_token_idx`` (vocabulary id of the
                ``[SEG]`` token); remaining items are forwarded to ``LisaModel``.

        Raises:
            KeyError: If ``seg_token_idx`` is missing from ``kwargs``.
        """
        # Popped before super().__init__ so it is not forwarded anywhere else.
        self.seg_token_idx = kwargs.pop("seg_token_idx")

        super().__init__(config)

        self.model = LisaModel(config, **kwargs)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.projection = nn.Sequential(
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
        )

        self.Geometry_Correlation = Curvature_guided_Geometric_Correlation(512)
        self.sim_score = SimScore()
        self.propagation_2 = PointNetFeaturePropagation(in_channel=768 + 3, mlp=[768, 768])
        self.propagation_1 = PointNetFeaturePropagation(in_channel=768 + 3, mlp=[768, 768])
        self.propagation_0 = PointNetFeaturePropagation(in_channel=768 + 3 + 3, mlp=[768, 512])
        self.dgcnn_pro_1 = DGCNN_Propagation(k=4)
        self.dgcnn_pro_2 = DGCNN_Propagation(k=4)
        self.decoder = Decoder(512)
        self.loss_ca = L_ca()
        self.cross_entropy = nn.CrossEntropyLoss()
        self.post_init()
        # Honour the constructor argument (it used to be overwritten with False).
        self.extra_seg_token = extra_seg_token
        # Projects LM token embeddings (width config.hidden_size, 4096 in the
        # original setup) to 512 dims. Created after post_init(), so it keeps
        # PyTorch's default Linear initialisation.
        self.text_projector = nn.Linear(config.hidden_size, 512)

    def get_visual_embs(self, points, rgbs):
        """Encode a point cloud with Uni3D and propagate features to every point.

        Args:
            points: xyz coordinates; concatenated with ``rgbs`` on the last dim
                (a comment in the original code shows a [2, 2048, 6] result).
            rgbs: Per-point colours, or None to use a constant 0.4.

        Returns:
            ``propagation_0`` output transposed to (batch, points, channels);
            presumably [B, N, 512] -- TODO confirm against utils.
        """
        if rgbs is None:
            rgbs = torch.full_like(points, 0.4)
        points = torch.cat((points, rgbs), dim=-1)
        h4, h8, h12, pts, center_level_0, center_level_1, center_level_2, center_level_3 = \
            self.model.point_model.encode_pc(points)

        # Intermediate Uni3D block outputs, moved to channel-first layout.
        h4 = h4.permute(0, 2, 1)
        h8 = h8.permute(0, 2, 1)
        h12 = h12.permute(0, 2, 1)

        f_level_1 = center_level_1
        f_level_2 = center_level_2
        f_level_3 = h12
        f_level_2 = self.propagation_2(center_level_2, center_level_3, f_level_2, h8)
        f_level_1 = self.propagation_1(center_level_1, center_level_3, f_level_1, h4)

        f_level_2 = self.dgcnn_pro_2(center_level_3, f_level_3, center_level_2, f_level_2)
        f_level_1 = self.dgcnn_pro_1(center_level_2, f_level_2, center_level_1, f_level_1)
        f_level_0 = self.propagation_0(center_level_0, center_level_1, points.transpose(1, 2), f_level_1)

        return f_level_0.transpose(1, 2)

    def forward(self, **kwargs):
        """Dispatch: plain LM forward during cached generation, else training forward."""
        if "past_key_values" in kwargs:
            return super().forward(**kwargs)
        return self.model_forward(**kwargs)

    def insert_false_before_true(self, arr):
        """Return a copy of 1-D bool ``arr`` with a False inserted before every True.

        Each True therefore moves one slot to the right; the result has length
        ``len(arr) + arr.sum()``.
        """
        true_indices = torch.where(arr)[0]
        new_length = len(arr) + len(true_indices)
        new_arr = torch.zeros(new_length, dtype=bool).to(arr.device)
        offset = 0
        for idx in true_indices:
            adjusted_idx = idx + offset
            new_arr[adjusted_idx] = False
            new_arr[adjusted_idx + 1] = True
            offset += 1
        # Scatter the untouched (False) entries to their shifted positions.
        old_indices = torch.arange(len(arr)).to(arr.device)
        mask = torch.zeros(len(arr), dtype=bool).to(arr.device)
        mask[true_indices] = True
        offset_arr = torch.zeros(len(arr), dtype=int).to(arr.device)
        offset_arr[true_indices] = 1
        cumulative_offset = torch.cumsum(offset_arr, dim=0)
        new_indices = old_indices + cumulative_offset
        new_arr[new_indices[~mask]] = arr[~mask]
        return new_arr

    def insert_num_after_mask(self, input_ids, mask, num=-1):
        """Return a copy of 1-D ``input_ids`` with ``num`` inserted after each masked entry.

        Args:
            input_ids: 1-D tensor to extend.
            mask: 1-D bool tensor of the same length selecting entries to follow.
            num: Value to insert (cast to ``input_ids.dtype``).
        """
        true_indices = torch.where(mask)[0]
        new_length = len(input_ids) + len(true_indices)
        new_input_ids = torch.full((new_length,), -1, dtype=input_ids.dtype).to(input_ids.device)
        offset = 0
        for idx in true_indices:
            adjusted_idx = idx + offset
            new_input_ids[adjusted_idx] = input_ids[idx]
            new_input_ids[adjusted_idx + 1] = num
            offset += 1
        # Scatter the unmasked entries to their shifted positions.
        old_indices = torch.arange(len(input_ids)).to(input_ids.device)
        mask_processed = torch.zeros(len(input_ids), dtype=bool).to(input_ids.device)
        mask_processed[true_indices] = True
        offset_arr = torch.zeros(len(input_ids), dtype=int).to(input_ids.device)
        offset_arr[true_indices] = 1
        cumulative_offset = torch.cumsum(offset_arr, dim=0)
        new_indices = old_indices + cumulative_offset
        new_input_ids[new_indices[~mask_processed]] = input_ids[~mask_processed]
        return new_input_ids

    def _seg_type_text_embeddings(self, seg_type_ids, device):
        """Embed the segment-type token-id lists of every sample.

        For each sample ``y`` and each token-id list ``x`` in ``y``, ``x`` is
        looked up in the LM input embedding table, passed through
        ``text_projector`` and mean-pooled over tokens. Each sample yields one
        512-d row per entry of ``y``.
        """
        embed_tokens = self.get_model().embed_tokens
        return [
            torch.cat(
                [
                    self.text_projector(embed_tokens(torch.tensor(x, device=device))).mean(dim=0, keepdim=True)
                    for x in y
                ],
                dim=0,
            )
            for y in seg_type_ids
        ]

    def model_forward(
        self,
        points: torch.FloatTensor,
        colors: torch.FloatTensor,
        input_ids: torch.LongTensor,
        labels: torch.LongTensor,
        attention_masks: torch.Tensor,
        offset: torch.LongTensor,
        seg_label: torch.FloatTensor,
        logist_label: torch.FloatTensor,
        seg_type_ids: List,
        return_lm_out: bool = False,
        **kwargs,
    ):
        """Training forward pass: LM cross-entropy plus point-segmentation loss.

        Args:
            points: Per-sample point coordinates.
            colors: Per-point colours, or None for the constant 0.4 that
                ``get_visual_embs`` also uses.
            input_ids, labels, attention_masks: Token batch, one row per
                conversation.
            offset: Length ``batch_size + 1``; conversations
                ``offset[i]:offset[i + 1]`` belong to sample ``i``.
            seg_label: Per-sample segmentation targets.
            logist_label: Unused.
            seg_type_ids: Per sample, a list of token-id lists.
            return_lm_out: Also return the LM-head logits.

        Returns:
            ``(loss, pred_seg, seg_label, lm_out)`` if ``return_lm_out``,
            otherwise ``(ce_loss, seg_loss, pred_seg, seg_label)``.
        """
        if colors is None:
            # Same default as get_visual_embs, so the concatenation below works too.
            colors = torch.full_like(points, 0.4)

        point_embeddings = self.get_visual_embs(points, colors)
        batch_size = point_embeddings.shape[0]
        assert batch_size == len(offset) - 1

        # Position i is marked when token i + 1 is [SEG]: the hidden state at i
        # is the one that predicts the [SEG] token.
        seg_token_mask = input_ids[:, 1:] == self.seg_token_idx
        seg_token_mask = torch.cat(
            [
                seg_token_mask,
                torch.zeros((seg_token_mask.shape[0], 1), dtype=torch.bool, device=seg_token_mask.device),
            ],
            dim=1,
        )

        if self.extra_seg_token:
            input_seg_mask = input_ids == self.seg_token_idx
            max_token_num = seg_token_mask.sum(dim=-1).max()
            seg_token_mask_list = []
            input_ids_list = []
            label_list = []
            attention_mask_list = []
            for i in range(seg_token_mask.shape[0]):
                seg_num = seg_token_mask[i].sum().item()
                padding = max_token_num - seg_num
                # The token after [SEG] is the one we want, so True entries are
                # shifted right and placeholder ids are inserted after [SEG].
                seg_token_mask_list.append(torch.cat([self.insert_false_before_true(seg_token_mask[i]), torch.zeros(padding, dtype=torch.bool).to(seg_token_mask.device)], dim=0))
                input_ids_list.append(torch.cat([self.insert_num_after_mask(input_ids[i], input_seg_mask[i]), torch.ones(padding, dtype=torch.int).to(seg_token_mask.device) * 2], dim=0))
                label_list.append(torch.cat([self.insert_num_after_mask(labels[i], input_seg_mask[i], -100), torch.ones(padding, dtype=torch.int).to(seg_token_mask.device) * -100], dim=0))
                attention_mask_list.append(torch.cat([self.insert_num_after_mask(attention_masks[i], input_seg_mask[i], True), torch.ones(padding, dtype=torch.bool).to(seg_token_mask.device)], dim=0))
            seg_token_mask = torch.stack(seg_token_mask_list)
            input_ids = torch.stack(input_ids_list)
            labels = torch.stack(label_list)
            attention_masks = torch.stack(attention_mask_list)

        # TODO: maybe could be deleted
        # Left-pad by 255 positions; presumably the number of point tokens the
        # LLaVA backbone splices into the sequence -- TODO confirm.
        seg_token_mask = torch.cat(
            [torch.zeros((seg_token_mask.shape[0], 255), dtype=torch.bool, device=seg_token_mask.device), seg_token_mask],
            dim=1,
        )

        # Repeat each sample's point cloud once per conversation of that sample.
        points = torch.cat((points, colors), dim=-1)
        points_list = []
        for i in range(len(offset) - 1):
            start_i, end_i = offset[i], offset[i + 1]
            points_i = (
                points[i]
                .unsqueeze(0)
                .expand(end_i - start_i, -1, -1)
                .contiguous()
            )
            points_list.append(points_i)
        points_input = torch.cat(points_list, dim=0)

        output = super().forward(
            points=points_input,
            attention_mask=attention_masks,
            input_ids=input_ids,
            labels=labels,
            output_hidden_states=True,
        )
        output_hidden_states = output.hidden_states
        assert len(self.model.text_hidden_fcs) == 1
        # Logits are only recomputed when the caller asked for them.
        lm_out = self.lm_head(output_hidden_states[-1]) if return_lm_out else None
        last_hidden_state = self.model.text_hidden_fcs[0](output_hidden_states[-1])  # [rows, seq, 256]

        # Left-pad the mask if the backbone produced a longer sequence.
        sequence_length = last_hidden_state.size(1)
        padding_length = max(0, sequence_length - seg_token_mask.size(1))
        seg_token_mask = torch.cat(
            [torch.zeros((seg_token_mask.shape[0], padding_length), dtype=torch.bool, device=seg_token_mask.device), seg_token_mask],
            dim=1,
        )
        pred_embeddings = last_hidden_state[seg_token_mask]
        seg_token_counts = seg_token_mask.int().sum(-1)  # [rows]

        # Rows without any [SEG] token get one all-zero placeholder embedding,
        # handled per row (a bare `if counts == 0` is ambiguous for >1 row).
        if (seg_token_counts == 0).any():
            per_row = torch.split(pred_embeddings, seg_token_counts.tolist(), dim=0)
            placeholder = last_hidden_state.new_zeros((1, last_hidden_state.shape[-1]))
            pred_embeddings = torch.cat(
                [rows if rows.shape[0] > 0 else placeholder for rows in per_row], dim=0
            )
            seg_token_counts = seg_token_counts.clamp(min=1)

        seg_token_offset = seg_token_counts.cumsum(-1)
        seg_token_offset = torch.cat(
            [torch.zeros(1, dtype=torch.long, device=seg_token_offset.device), seg_token_offset], dim=0
        )
        # Regroup per-row offsets into per-sample offsets.
        seg_token_offset = seg_token_offset[offset]

        pred_embeddings = [
            pred_embeddings[seg_token_offset[i]:seg_token_offset[i + 1]]
            for i in range(len(seg_token_offset) - 1)
        ]
        text_embeddings = self._seg_type_text_embeddings(seg_type_ids, input_ids.device)  # per sample [n, 512]

        pred_seg = []
        seg_loss = 0
        for i in range(len(pred_embeddings)):
            # NOTE(review): training uses only the text embeddings here, while
            # evaluate() uses projection(pred_embeddings[i]) + text_embeddings[i].
            hseg = text_embeddings[i]  # [n, 512]
            phi_a = self.Geometry_Correlation(point_embeddings[i].unsqueeze(0), hseg)
            seg = torch.softmax(self.decoder(phi_a), dim=0).squeeze(-1)
            pred_seg.append(seg)
            seg_loss += self.loss_ca(seg, seg_label[i])
            seg_loss += self.cross_entropy(seg, seg_label[i].argmax(0))
        seg_loss = seg_loss / (len(pred_embeddings) + 1)
        ce_loss = output.loss
        loss = ce_loss + seg_loss

        if return_lm_out:
            return loss, pred_seg, seg_label, lm_out
        return ce_loss, seg_loss, pred_seg, seg_label

    def evaluate(
        self,
        points,
        input_ids,
        max_new_tokens=32,
        seg_type_ids=None,
        tokenizer=None,
        use_cache=True
    ):
        """Generate a response and decode per-point masks from its [SEG] tokens.

        Args:
            points: Point cloud whose last dim holds xyz in ``[:3]`` and
                colour in ``[3:]``.
            input_ids: Prompt token ids.
            max_new_tokens: Generation budget.
            seg_type_ids: Per sample, a list of token-id lists (default: none).
            tokenizer: Unused.
            use_cache: Forwarded to ``generate``.

        Returns:
            ``(output_ids, pred_masks)``: the newly generated ids and, per
            sample, a one-hot mask whose width equals the score dimension.
        """
        if seg_type_ids is None:
            seg_type_ids = []
        with torch.no_grad():
            outputs = self.generate(
                points=points,
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                num_beams=1,
                output_hidden_states=True,
                return_dict_in_generate=True,
                use_cache=use_cache
            )
            output_ids = outputs.sequences
            output_ids = output_ids[:, input_ids.shape[1]:]

            # Last-layer hidden state at the last position of every generation
            # step -> [bs, new_gen_seq_len, hidden_size].
            hidden_states = torch.cat([h[-1][:, -1:, :] for h in outputs.hidden_states], dim=-2)

            # Mask of generated [SEG] tokens. Samples that produced none fall
            # back to their last generated token; this must happen before the
            # counts/offsets below are computed, or that sample gets nothing.
            seg_token_mask = output_ids == self.seg_token_idx
            no_seg = seg_token_mask.sum(dim=-1) == 0
            seg_token_mask[no_seg, -1] = True

            seg_token_counts = seg_token_mask.int().sum(-1)  # [bs]
            seg_token_offset = seg_token_counts.cumsum(-1)
            seg_token_offset = torch.cat(
                [torch.zeros(1, dtype=torch.long, device=seg_token_offset.device), seg_token_offset], dim=0
            )

            assert len(self.model.text_hidden_fcs) == 1
            last_hidden_state = self.model.text_hidden_fcs[0](hidden_states)
            pred_embeddings = last_hidden_state[:, -seg_token_mask.shape[-1]:][seg_token_mask]
            pred_embeddings = [
                pred_embeddings[seg_token_offset[i]:seg_token_offset[i + 1]]
                for i in range(len(seg_token_offset) - 1)
            ]

            text_embeddings = self._seg_type_text_embeddings(seg_type_ids, input_ids.device)  # per sample [n, 512]

            point_embeddings = self.get_visual_embs(points[:, :, :3], points[:, :, 3:])

            pred_masks = []
            for i in range(len(pred_embeddings)):
                hseg = self.projection(pred_embeddings[i]) + text_embeddings[i]
                scores = self.sim_score(point_embeddings[i], hseg)
                # Fix the one-hot width to the score dimension so it does not
                # depend on which classes happen to win the argmax.
                mask = torch.nn.functional.one_hot(scores.argmax(dim=-1), num_classes=scores.shape[-1])
                pred_masks.append(mask)

        return output_ids, pred_masks