from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import torch
import numpy as np


class PrePackProcessor:
    """Pack one document and several questions into a single token sequence.

    For every group of ``parallel_tokens`` questions the processor builds one
    row laid out as ``[document tokens | interleaved question tokens]``, where
    the question part is interleaved token by token (token 0 of every
    question, then token 1 of every question, ...).  Alongside the ids it
    produces matching position ids and a square 0/1 mask in which a question
    token can see the document and the earlier tokens of its own question,
    but not the tokens of the other questions in the group.

    The tokenizer is expected to behave like a Hugging Face tokenizer: it is
    called with ``return_tensors="pt"`` and must return a mapping holding
    ``input_ids`` and ``attention_mask`` tensors.
    """

    def __init__(self, tokenizer, packing_fn=None):
        """Store the tokenizer and derive the pad token id and default device.

        Args:
            tokenizer: Tokenizer used for both documents and questions.
            packing_fn: Accepted for interface compatibility; it is not used
                by any code in this class.
        """
        self.tokenizer = tokenizer
        self.pad_token = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def parralel_process_tensor_batch_efficient2(self, doc, batch_q, parallel_tokens, base_prompt_length, base_q_length):
        """Pack a single document or a list of documents with their questions.

        Args:
            doc: One document string, or a list of document strings.
            batch_q: The questions for ``doc``.  When ``doc`` is a list this
                must be a parallel list with one question batch per document.
            parallel_tokens: Number of questions interleaved into one row.
            base_prompt_length: Length every document is padded/truncated to.
                When ``None`` it is set to the longest tokenized document.
            base_q_length: Length every question is padded/truncated to.

        Returns:
            A 9-tuple ``(None, None, mask, None, None, None, ids,
            position_ids, doc_position_ids)``.  For a list of documents the
            four tensors are the per-document results concatenated along
            dim 0.

        Raises:
            TypeError: If ``doc`` is neither a ``str`` nor a ``list``.
            ValueError: If ``doc`` is an empty list.
        """
        if base_prompt_length is None:
            tokenized_doc = self.tokenizer(doc, return_tensors="pt", add_special_tokens=False, padding="longest")
            base_prompt_length = tokenized_doc["input_ids"].shape[-1]

        if isinstance(doc, str):
            return self.parralel_process_efficient2(doc, batch_q, parallel_tokens, base_prompt_length, base_q_length)

        if isinstance(doc, list):
            masks, f_tokens, f_positions, doc_positions = [], [], [], []
            for item_doc, item_batch_q in zip(doc, batch_q):
                packed = self.parralel_process_efficient2(item_doc, item_batch_q, parallel_tokens, base_prompt_length, base_q_length)
                masks.append(packed[2])
                f_tokens.append(packed[6])
                f_positions.append(packed[7])
                doc_positions.append(packed[8])
            if not masks:
                raise ValueError("doc and batch_q must contain at least one item")
            # Placeholder slots are None, exactly as in the single-document
            # path, so both branches return tuples of the same shape.
            return (
                None,
                None,
                torch.cat(masks),
                None,
                None,
                None,
                torch.cat(f_tokens),
                torch.cat(f_positions),
                torch.cat(doc_positions),
            )

        raise TypeError(f"doc must be a str or a list of str, got {type(doc).__name__}")

    @staticmethod
    def _interleave_questions(q_tensor, reduced_batch_size, parallel_tokens):
        """Interleave each group of ``parallel_tokens`` rows token by token.

        ``q_tensor`` has one row per question.  Consecutive rows are grouped
        ``parallel_tokens`` at a time; within a group the output row is
        ``q0[0], q1[0], ..., q0[1], q1[1], ...``.
        """
        q_len = q_tensor.shape[-1]
        grouped = q_tensor.reshape(reduced_batch_size, parallel_tokens, q_len)
        return grouped.transpose(1, 2).reshape(reduced_batch_size, parallel_tokens * q_len)

    @staticmethod
    def _build_parallel_mask(doc_len, q_len, parallel_tokens):
        """Return the ``(total, total)`` 0/1 visibility mask for one packed row.

        Every position may see the document columns; an interleaved question
        position may additionally see question positions that belong to the
        same question (same index modulo ``parallel_tokens``).  The result is
        lower-triangular, so no position sees a later one.
        """
        inter_len = parallel_tokens * q_len
        # Interleaved slot k belongs to question k % parallel_tokens.
        owner = torch.arange(inter_len) % parallel_tokens
        same_question = (owner[:, None] == owner[None, :]).long()
        doc_columns = torch.ones(doc_len + inter_len, doc_len).long()
        question_columns = torch.cat((torch.zeros(doc_len, inter_len).long(), same_question), 0)
        return torch.tril(torch.cat((doc_columns, question_columns), 1))

    def parralel_process_efficient2(self, doc, batch_q, parallel_tokens, base_prompt_length, base_q_length):
        """Pack one document with a batch of questions.

        Args:
            doc: The document text.
            batch_q: List of question strings; its length must be a multiple
                of ``parallel_tokens``.
            parallel_tokens: Number of questions interleaved into one row.
            base_prompt_length: ``max_length`` used to pad/truncate the document.
            base_q_length: ``max_length`` used to pad/truncate each question.

        Returns:
            A 9-tuple ``(None, None, mask, None, None, None, ids,
            position_ids, doc_position_ids)`` where, with
            ``B = len(batch_q) // parallel_tokens`` and
            ``T = doc_len + parallel_tokens * q_len``:

            * ``mask`` is ``(B, T, T)``: the structural visibility mask with
              the columns of padded tokens zeroed out,
            * ``ids`` and ``position_ids`` are ``(B, T)``,
            * ``doc_position_ids`` are the position ids of the document alone.

        Raises:
            ValueError: If the number of questions is not a multiple of
                ``parallel_tokens``.
            ZeroDivisionError: If ``parallel_tokens`` is 0.

        Note:
            This sets ``add_bos_token = False`` on the shared tokenizer.
        """
        self.tokenizer.add_bos_token = False

        tokenized_doc = self.tokenizer(doc, return_tensors="pt", add_special_tokens=False, padding="max_length", truncation=True, max_length=base_prompt_length)
        tokenized_questions = self.tokenizer(batch_q, return_tensors="pt", add_special_tokens=False, padding="max_length", truncation=True, max_length=base_q_length)

        number_ori_q, lenth_ori_q = tokenized_questions["input_ids"].shape
        lenth_ori_doc = tokenized_doc["input_ids"].shape[-1]

        reduced_batch_size, leftover = divmod(number_ori_q, parallel_tokens)
        if leftover:
            raise ValueError(
                f"number of questions ({number_ori_q}) must be a multiple of "
                f"parallel_tokens ({parallel_tokens})"
            )

        # Position of each token = number of attended tokens up to it, 0-based.
        doc_position_ids = tokenized_doc["attention_mask"].long().cumsum(-1) - 1
        question_position_ids = tokenized_questions["attention_mask"].long().cumsum(-1)
        if lenth_ori_doc > 0:
            # Questions continue counting from the last position of the first
            # document row.  NOTE(review): with left-padded documents this is
            # the last real token; with right padding it is the last pad slot,
            # which carries the same cumulative value.
            question_position_ids = question_position_ids + doc_position_ids[0][-1]

        doc_fields = {
            "input_ids": tokenized_doc["input_ids"],
            "attention_mask": tokenized_doc["attention_mask"],
            "position_ids": doc_position_ids,
        }
        question_fields = {
            "input_ids": tokenized_questions["input_ids"],
            "attention_mask": tokenized_questions["attention_mask"],
            "position_ids": question_position_ids,
        }

        # Each output row is the document followed by one interleaved group.
        merged = {}
        for key, q_tensor in question_fields.items():
            interleaved = self._interleave_questions(q_tensor, reduced_batch_size, parallel_tokens)
            doc_rows = doc_fields[key].repeat(reduced_batch_size, 1)
            merged[key] = torch.cat((doc_rows, interleaved), -1)

        structural_mask = self._build_parallel_mask(lenth_ori_doc, lenth_ori_q, parallel_tokens)
        # Broadcasting zeroes the columns of padded tokens for every query row.
        threed_mask = merged["attention_mask"][:, None, :] * structural_mask

        return (
            None,
            None,
            threed_mask.long(),
            None,
            None,
            None,
            merged["input_ids"].long(),
            merged["position_ids"].long(),
            doc_position_ids.long(),
        )


