#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
import argparse
import os
import shutil
import sys
import time
from functools import partial
from knn_cuda import KNN
import deepspeed
import numpy as np
import torch
import tqdm
import transformers
from peft import LoraConfig, get_peft_model
from einops import rearrange
from llava import conversation as conversation_lib

from llava.model.Uni3D.models import uni3d as modelss

from typing import List, Optional, Tuple, Union

import torch.nn as nn
from torch.nn import CrossEntropyLoss

from transformers import AutoConfig, AutoModelForCausalLM, \
                         LlamaConfig, LlamaModel, LlamaForCausalLM

from transformers.modeling_outputs import CausalLMOutputWithPast

from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
from llava.model.language_model.llava_llama import (LlavaLlamaForCausalLM,
                                                     LlavaLlamaModel)


import torch
import torch.nn.functional as F
from time import time
import numpy as np
from utils import *

class LisaMetaModel:
    """Mixin that adds a point-cloud encoder and a text projection head.

    The constructor forwards ``config`` to the next class in the MRO, so this
    class is meant to be combined with a base whose ``__init__`` accepts a
    config object.
    """

    def __init__(
        self,
        config,
        **kwargs,
    ):
        # ``kwargs`` is accepted for signature compatibility; it is not used here.
        super().__init__(config)

        self.config = config
        self.initialize_lisa_modules(self.config)

    def initialize_lisa_modules(self, config):
        """Create ``point_encoder`` and the trainable ``text_hidden_fcs`` head."""
        self.point_encoder = PointCloudEncoder(config)

        # Two-layer MLP mapping hidden_size -> hidden_size -> 256.
        hidden_dim = config.hidden_size
        projected_dim = 256
        projection = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, projected_dim),
            nn.Dropout(0.0),
        )
        self.text_hidden_fcs = nn.ModuleList([projection])

        # Keep the head in training mode with gradients enabled on every parameter.
        self.text_hidden_fcs.train()
        self.text_hidden_fcs.requires_grad_(True)

class LisaModel(LisaMetaModel, LlavaLlamaModel):
    """``LlavaLlamaModel`` combined with the ``LisaMetaModel`` mixin.

    After the parent constructors run, a fixed set of multimodal options is
    written onto ``self.config``.
    """

    def __init__(
        self,
        config,
        **kwargs,
    ):
        super().__init__(config, **kwargs)

        # Options pinned on every instance, applied in this order.
        pinned_options = (
            ("mm_vision_select_feature", "patch"),
            ("image_aspect_ratio", "square"),
            ("image_grid_pinpoints", None),
            ("tune_mm_mlp_adapter", False),
            ("freeze_mm_mlp_adapter", True),
            ("pretrain_mm_mlp_adapter", None),
            ("mm_use_im_patch_token", False),
        )
        for option_name, option_value in pinned_options:
            setattr(self.config, option_name, option_value)

class LISAForCausalLM(LlavaLlamaForCausalLM):
    def __init__(
        self,
        config,
        extra_seg_token=False,
        **kwargs,
    ):
        """Build the language model, the segmentation modules and the loss terms.

        Args:
            config: Model configuration; ``hidden_size`` and ``vocab_size`` are
                read from it.
            extra_seg_token: Accepted for interface compatibility. The flag
                stored on the instance is always ``False`` (see note below).
            **kwargs: Must contain ``seg_token_idx``; it is removed here and
                the remaining entries are forwarded to ``LisaModel``.

        Raises:
            KeyError: If ``seg_token_idx`` is not supplied.
        """
        self.seg_token_idx = kwargs.pop("seg_token_idx")

        super().__init__(config)

        self.model = LisaModel(config, **kwargs)

        # Width shared by the fusion module and the text projector output.
        fusion_dim = 512

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.h_fusion = h_fusion(fusion_dim)
        self.sim_score = SimScore()
        # NOTE(review): ``Seg_Decoder`` is stored without being called. If it is
        # a class rather than a ready-made module/function, ``self.decoder(h)``
        # in ``model_forward`` would construct an instance instead of decoding
        # -- confirm against its definition.
        self.decoder = Seg_Decoder
        self.loss_ca = L_ca()
        self.cross_entropy = nn.CrossEntropyLoss()
        self.post_init()
        # NOTE(review): the ``extra_seg_token`` argument is not honoured; the
        # extra-token branch of ``model_forward`` is therefore never taken.
        # Kept as-is because that branch inserts ``-1`` ids into ``input_ids``
        # -- confirm whether it is meant to be enabled.
        self.extra_seg_token = False
        # The projector consumes token embeddings, whose width is the model's
        # hidden size, so derive the input width from the config instead of
        # assuming 4096.
        self.text_projector = nn.Linear(config.hidden_size, fusion_dim)

    def get_visual_embs(self, points, rgbs):
        return self.model.point_encoder(points, rgbs)

    def forward(self, **kwargs):
        if "past_key_values" in kwargs:
            return super().forward(**kwargs)
        return self.model_forward(**kwargs)
    
    def insert_false_before_true(self, arr):
        true_indices = torch.where(arr)[0]
        new_length = len(arr) + len(true_indices)
        new_arr = torch.zeros(new_length, dtype=bool).to(arr.device)
        offset = 0
        for idx in true_indices:
            adjusted_idx = idx + offset
            new_arr[adjusted_idx] = False
            new_arr[adjusted_idx + 1] = True
            offset += 1
        old_indices = torch.arange(len(arr)).to(arr.device)
        mask = torch.zeros(len(arr), dtype=bool).to(arr.device)
        mask[true_indices] = True
        offset_arr = torch.zeros(len(arr), dtype=int).to(arr.device)
        offset_arr[true_indices] = 1
        cumulative_offset = torch.cumsum(offset_arr, dim=0)
        new_indices = old_indices + cumulative_offset
        new_arr[new_indices[~mask]] = arr[~mask]
        return new_arr
    
    def insert_num_after_mask(self, input_ids, mask, num=-1):
        true_indices = torch.where(mask)[0]
        new_length = len(input_ids) + len(true_indices)
        new_input_ids = torch.full((new_length,), -1, dtype=input_ids.dtype).to(input_ids.device)
        offset = 0
        for idx in true_indices:
            adjusted_idx = idx + offset
            new_input_ids[adjusted_idx] = input_ids[idx]
            new_input_ids[adjusted_idx + 1] = num
            offset += 1
        old_indices = torch.arange(len(input_ids)).to(input_ids.device)
        mask_processed = torch.zeros(len(input_ids), dtype=bool).to(input_ids.device)
        mask_processed[true_indices] = True
        offset_arr = torch.zeros(len(input_ids), dtype=int).to(input_ids.device)
        offset_arr[true_indices] = 1
        cumulative_offset = torch.cumsum(offset_arr, dim=0)
        new_indices = old_indices + cumulative_offset
        new_input_ids[new_indices[~mask_processed]] = input_ids[~mask_processed]
        return new_input_ids

    def model_forward(
        self,
        points:torch.FloatTensor,
        colors:torch.FloatTensor,
        input_ids: torch.LongTensor,
        labels: torch.LongTensor,
        attention_masks: torch.Tensor,
        offset: torch.LongTensor,
        seg_label:torch.FloatTensor,
        logist_label:torch.FloatTensor,
        seg_type_ids:List,
        return_lm_out:bool=False,
        **kwargs,
    ):
        """Run the language model and the segmentation head and compute both losses.

        Args:
            points: Point coordinates, one entry per point cloud.
            colors: Per-point colours; concatenated to ``points`` on the last axis.
            input_ids: Token ids, one row per conversation.
            labels: Language-model targets aligned with ``input_ids``.
            attention_masks: Attention mask aligned with ``input_ids``.
            offset: Boundaries that map conversation rows to point clouds;
                cloud ``i`` owns rows ``offset[i]:offset[i + 1]``.
            seg_label: Per-cloud segmentation targets.
            logist_label: Unused by this method.
            seg_type_ids: For each cloud, a sequence of token-id sequences; each
                inner sequence is embedded, projected and mean-pooled.
            return_lm_out: Selects which tuple is returned (see Returns).
            **kwargs: Ignored.

        Returns:
            ``(loss, pred_seg, seg_label, lm_out)`` when ``return_lm_out`` is
            true, where ``loss = ce_loss + seg_loss``; otherwise
            ``(ce_loss, seg_loss, pred_seg, seg_label)``.

        Raises:
            AssertionError: If the number of encoded point clouds differs from
                ``len(offset) - 1`` or there is not exactly one projection head.
        """
        # Create every helper tensor on the inputs' device rather than on the
        # default CUDA device.
        device = input_ids.device

        point_embeddings = self.get_visual_embs(points, colors)
        batch_size = point_embeddings.shape[0]
        assert batch_size == len(offset) - 1

        # Column j is True when the token at column j + 1 is the seg token; a
        # False column is appended to restore the original width.
        seg_token_mask = input_ids[:, 1:] == self.seg_token_idx
        seg_token_mask = torch.cat(
            [
                seg_token_mask,
                torch.zeros((seg_token_mask.shape[0], 1), dtype=torch.bool, device=device),
            ],
            dim=1,
        )

        if self.extra_seg_token:
            # Insert one extra slot per seg token in every sequence, then pad
            # all rows to a common length.
            input_seg_mask = input_ids == self.seg_token_idx
            max_token_num = seg_token_mask.sum(dim=-1).max().item()
            mask_device = seg_token_mask.device
            seg_token_mask_list = []
            input_ids_list = []
            label_list = []
            attention_mask_list = []
            for i in range(seg_token_mask.shape[0]):
                seg_num = seg_token_mask[i].sum().item()
                padding = max_token_num - seg_num
                seg_token_mask_list.append(torch.cat([
                    self.insert_false_before_true(seg_token_mask[i]),
                    torch.zeros(padding, dtype=torch.bool).to(mask_device),
                ], dim=0))
                input_ids_list.append(torch.cat([
                    self.insert_num_after_mask(input_ids[i], input_seg_mask[i]),
                    torch.ones(padding, dtype=torch.int).to(mask_device)*2,
                ], dim=0))
                label_list.append(torch.cat([
                    self.insert_num_after_mask(labels[i], input_seg_mask[i], -100),
                    torch.ones(padding, dtype=torch.int).to(mask_device)*-100,
                ], dim=0))
                attention_mask_list.append(torch.cat([
                    self.insert_num_after_mask(attention_masks[i], input_seg_mask[i], True),
                    torch.ones(padding, dtype=torch.bool).to(mask_device),
                ], dim=0))
            seg_token_mask = torch.stack(seg_token_mask_list)
            input_ids = torch.stack(input_ids_list)
            labels = torch.stack(label_list)
            attention_masks = torch.stack(attention_mask_list)

        # NOTE(review): 255 is hard-coded; presumably the number of extra
        # sequence positions introduced by the point tokens -- confirm.
        seg_token_mask = torch.cat(
            [torch.zeros((seg_token_mask.shape[0], 255), dtype=torch.bool, device=device), seg_token_mask],
            dim=1,
        )

        # Repeat each point cloud once per conversation row that refers to it.
        points = torch.cat((points, colors), dim=-1)
        points_list = []
        for start_i, end_i in zip(offset[:-1], offset[1:]):
            cloud_index = len(points_list)
            points_list.append(
                points[cloud_index]
                .unsqueeze(0)
                .expand(end_i - start_i, -1, -1)
                .contiguous()
            )
        points_input = torch.cat(points_list, dim=0)

        output = super().forward(
                points=points_input,
                attention_mask=attention_masks,
                input_ids=input_ids,
                labels=labels,
                output_hidden_states=True,
            )
        final_hidden = output.hidden_states[-1]

        assert len(self.model.text_hidden_fcs) == 1
        last_hidden_state = self.model.text_hidden_fcs[0](final_hidden)

        # Left-pad the mask with False if the model's sequence is longer than
        # the mask built above.
        sequence_length = last_hidden_state.size(1)
        padding_length = max(0, sequence_length - seg_token_mask.size(1))
        seg_token_mask = torch.cat(
            [
                torch.zeros((seg_token_mask.shape[0], padding_length), dtype=torch.bool, device=device),
                seg_token_mask,
            ],
            dim=1,
        )
        pred_embeddings = last_hidden_state[seg_token_mask]
        seg_token_counts = seg_token_mask.int().sum(-1)

        # No seg token anywhere: fall back to one zero embedding per row.
        # Reduce to a scalar first -- comparing the per-row vector directly is
        # ambiguous as a truth value when there is more than one row.
        if bool(seg_token_counts.sum() == 0):
            pred_embeddings = torch.zeros_like(last_hidden_state[:, 0, :], dtype=last_hidden_state.dtype).to(last_hidden_state.device)
            seg_token_counts = seg_token_counts + 1

        # Turn per-row counts into per-cloud slice boundaries.
        seg_token_offset = seg_token_counts.cumsum(-1)
        seg_token_offset = torch.cat(
            [torch.zeros(1, dtype=torch.long, device=device), seg_token_offset], dim=0
        )
        seg_token_offset = seg_token_offset[offset]

        pred_embeddings = [
            pred_embeddings[start_i:end_i]
            for start_i, end_i in zip(seg_token_offset[:-1], seg_token_offset[1:])
        ]

        # One mean-pooled, projected embedding per id sequence, stacked per cloud.
        embed_tokens = self.get_model().embed_tokens
        text_embeddings = [
            torch.cat([
                self.text_projector(embed_tokens(torch.tensor(x).to(device))).mean(dim=0, keepdim=True) for x in y
            ], dim=0)
            for y in seg_type_ids
        ]

        pred_seg = []
        seg_loss = 0
        # Only the number of per-cloud slices is used below; the segmentation
        # head is driven by the point and text embeddings.
        num_clouds = len(pred_embeddings)
        for i in range(num_clouds):
            h = self.h_fusion(point_embeddings[i].unsqueeze(0), text_embeddings[i])

            seg = torch.softmax(self.decoder(h), dim=0).squeeze(-1)
            pred_seg.append(seg)
            seg_loss += self.loss_ca(seg, seg_label[i])
            # NOTE(review): ``nn.CrossEntropyLoss`` is given softmax outputs
            # here rather than raw scores -- confirm this is intended.
            seg_loss += self.cross_entropy(seg, seg_label[i].argmax(0))
        # NOTE(review): the divisor is the number of clouds plus one.
        seg_loss = seg_loss / (num_clouds + 1)

        ce_loss = output.loss
        loss = ce_loss + seg_loss

        if return_lm_out:
            # Project to the vocabulary only when the caller asks for it.
            lm_out = self.lm_head(final_hidden)
            return loss,pred_seg,seg_label, lm_out
        return ce_loss, seg_loss, pred_seg,seg_label
