import PIL
import torch
from janus.models import VLChatProcessor, MultiModalityCausalLM
from janus.models.image_processing_vlm import VLMImageProcessor
from transformers import LlamaTokenizerFast
from typing import List


def tokenize_text_janus(
    vl_chat_processor: VLChatProcessor,
    embedding_model: MultiModalityCausalLM,
    prompts: List,
    model_name: str,
    device
):
    """Build interleaved (unconditional, conditional) prompt embeddings.

    Every prompt is wrapped in a single-turn conversation, rendered with the
    processor's SFT template, suffixed with ``image_start_tag`` and tokenized.
    Sequences are left-padded with ``pad_id`` to a common length.  For each
    prompt two rows are emitted, in this order:

    * row ``2*i``     - unconditional: only the first and the last token of
      the prompt are kept, every other position holds ``pad_id``;
    * row ``2*i + 1`` - conditional: the full left-padded prompt.

    Both rows share the prompt's attention mask.

    Args:
        vl_chat_processor: Provides ``apply_sft_template_for_multi_turn_prompts``,
            ``sft_format``, ``image_start_tag``, ``tokenizer.encode`` and ``pad_id``.
        embedding_model: Model whose ``language_model.get_input_embeddings()``
            maps token ids to embeddings.
        prompts: Text prompts, one per sample.
        model_name: ``'Janus'`` or any name containing ``'Janus-Pro'``; selects
            the role tags used in the conversation.
        device: Device the token ids and attention mask are moved to before
            the embedding lookup.

    Returns:
        Tuple ``(input_embeds_prompt, attention_mask)``.  ``attention_mask`` has
        shape ``(2 * len(prompts), max_seq_len)``; ``input_embeds_prompt`` is
        whatever the embedding layer returns for token ids of that shape.

    Raises:
        ValueError: If ``model_name`` is not a recognised Janus variant.
    """
    pad_id = vl_chat_processor.pad_id

    def _encode(prompt):
        # Role tags differ between the original Janus and Janus-Pro templates.
        if model_name == 'Janus':
            user_role, assistant_role = "User", "Assistant"
        elif 'Janus-Pro' in model_name:
            user_role, assistant_role = "<|User|>", "<|Assistant|>"
        else:
            raise ValueError(f"Unknown model_name: {model_name}")

        conversation = [
            {"role": user_role, "content": prompt},
            {"role": assistant_role, "content": ""},
        ]
        sft = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=conversation,
            sft_format=vl_chat_processor.sft_format,
            system_prompt="",
        )
        text = sft + vl_chat_processor.image_start_tag
        return torch.LongTensor(vl_chat_processor.tokenizer.encode(text))

    seq_ids = [_encode(prompt) for prompt in prompts]
    batch_size = len(seq_ids)
    max_seq_len = max((ids.shape[0] for ids in seq_ids), default=0)

    # Start from all-padding / all-masked and fill each pair of rows in place.
    tokens = torch.full((batch_size * 2, max_seq_len), pad_id, dtype=torch.long)
    attention_mask = torch.zeros((batch_size * 2, max_seq_len), dtype=torch.long)

    for row, ids in enumerate(seq_ids):
        seq_len = ids.shape[0]
        if seq_len == 0:
            # Nothing to place; the pair of rows stays fully padded and masked.
            continue
        start = max_seq_len - seq_len  # left-padding: content occupies the tail

        # Conditional row: the complete prompt.
        tokens[2 * row + 1, start:] = ids
        # Unconditional row: keep only this prompt's own first and last token.
        tokens[2 * row, start] = ids[0]
        tokens[2 * row, -1] = ids[-1]
        # Both rows of the pair attend to the same positions.
        attention_mask[2 * row:2 * row + 2, start:] = 1

    tokens = tokens.to(device)
    attention_mask = attention_mask.to(device)

    input_embeds_prompt = embedding_model.language_model.get_input_embeddings()(tokens)
    return input_embeds_prompt, attention_mask


def tokenize_text_janus_3_way(
    vl_chat_processor: VLChatProcessor,
    embedding_model: MultiModalityCausalLM,
    prompts: List,
    modified_prompts: List,
    model_name: str,
    device
):
    """Build interleaved (uncond, cond_orig, cond_modified) prompt embeddings.

    Each prompt and its modified counterpart are wrapped in a single-turn
    conversation, rendered with the processor's SFT template, suffixed with
    ``image_start_tag`` and tokenized.  All sequences (original and modified)
    are left-padded with ``pad_id`` to one common length.  For sample ``i``
    three rows are emitted:

    * row ``3*i``     - unconditional: only the first and last token of the
      original prompt are kept, everything else is ``pad_id``; it reuses the
      original prompt's attention mask;
    * row ``3*i + 1`` - the original prompt;
    * row ``3*i + 2`` - the modified prompt.

    Args:
        vl_chat_processor: Provides ``apply_sft_template_for_multi_turn_prompts``,
            ``sft_format``, ``image_start_tag``, ``tokenizer.encode`` and ``pad_id``.
        embedding_model: Model whose ``language_model.get_input_embeddings()``
            maps token ids to embeddings.
        prompts: Original text prompts.
        modified_prompts: Modified prompts, aligned one-to-one with ``prompts``.
        model_name: ``"Janus"`` or any name containing ``'Janus-Pro'``.
        device: Device the token ids and attention mask are moved to before
            the embedding lookup.

    Returns:
        Tuple ``(input_embeds_prompt, attention_mask)``.  ``attention_mask`` has
        shape ``(3 * len(prompts), max_seq_len)``.

    Raises:
        ValueError: If the two prompt lists differ in length or ``model_name``
            is not a recognised Janus variant.
    """
    # Explicit check instead of `assert`, which is removed under `python -O`.
    if len(prompts) != len(modified_prompts):
        raise ValueError(
            "prompts and modified_prompts must have the same length, got "
            f"{len(prompts)} and {len(modified_prompts)}"
        )

    batch_size = len(prompts)
    pad_id = vl_chat_processor.pad_id

    def _encode(prompt):
        # Role tags differ between the original Janus and Janus-Pro templates.
        if model_name == "Janus":
            user_role, assistant_role = "User", "Assistant"
        elif 'Janus-Pro' in model_name:
            user_role, assistant_role = "<|User|>", "<|Assistant|>"
        else:
            raise ValueError(f"Unknown model_name: {model_name}")

        conversation = [
            {"role": user_role, "content": prompt},
            {"role": assistant_role, "content": ""},
        ]
        sft = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=conversation,
            sft_format=vl_chat_processor.sft_format,
            system_prompt="",
        )
        text = sft + vl_chat_processor.image_start_tag
        return torch.LongTensor(vl_chat_processor.tokenizer.encode(text))

    def _left_pad(seqs, width):
        # Right-align every sequence inside a (len(seqs), width) pad-filled grid.
        ids = torch.full((len(seqs), width), pad_id, dtype=torch.long)
        mask = torch.zeros((len(seqs), width), dtype=torch.long)
        for row, seq in enumerate(seqs):
            # Explicit start index: a `-len:` slice would select the whole row
            # for an empty sequence.
            start = width - seq.shape[0]
            ids[row, start:] = seq
            mask[row, start:] = 1
        return ids, mask

    seqs_orig = [_encode(p) for p in prompts]
    seqs_mod = [_encode(p) for p in modified_prompts]
    # One shared width so original and modified rows can live in one tensor.
    max_seq_len = max((s.shape[0] for s in seqs_orig + seqs_mod), default=0)

    ids_orig, mask_orig = _left_pad(seqs_orig, max_seq_len)
    ids_mod, mask_mod = _left_pad(seqs_mod, max_seq_len)

    tokens = torch.full((batch_size * 3, max_seq_len), pad_id, dtype=torch.long)
    tokens[1::3] = ids_orig
    tokens[2::3] = ids_mod
    for i, seq in enumerate(seqs_orig):
        seq_len = seq.shape[0]
        if seq_len == 0:
            continue
        # Unconditional row keeps only the original prompt's first/last token.
        tokens[3 * i, max_seq_len - seq_len] = seq[0]
        tokens[3 * i, -1] = seq[-1]

    attention_mask = torch.zeros((batch_size * 3, max_seq_len), dtype=torch.long)
    attention_mask[0::3] = mask_orig   # uncond
    attention_mask[1::3] = mask_orig   # cond_orig
    attention_mask[2::3] = mask_mod    # cond_modified

    tokens = tokens.to(device)
    attention_mask = attention_mask.to(device)

    input_embeds_prompt = embedding_model.language_model.get_input_embeddings()(tokens)
    return input_embeds_prompt, attention_mask

def tokenize_text_llamagen(
    vl_chat_processor,
    embedding_model,
    prompts,
    model_name,
    device
):
    """Build interleaved (unconditional, conditional) text embeddings for LlamaGen.

    ``embedding_model[0]`` encodes the prompts via ``get_text_embeddings`` and
    ``embedding_model[1]`` supplies ``cls_embedding.uncond_embedding``.
    ``vl_chat_processor`` and ``model_name`` are not read by this function.

    Each prompt's embedding rows are rotated by its number of valid tokens so
    that the first ``valid_num`` rows end up at the tail, and the mask is
    reversed along the last axis.
    NOTE(review): this presumes the encoder's mask has its ones at the leading
    positions (right-padding) - confirm against the T5 wrapper.

    Returns:
        ``(inputs, masks)`` with ``2 * B`` rows each: even rows hold the
        unconditional embedding, odd rows the masked conditional embedding;
        both rows of a pair carry the same (reversed) mask.
    """
    text_encoder = embedding_model[0]
    generator = embedding_model[1]

    raw_embeds, raw_mask = text_encoder.get_text_embeddings(prompts)
    raw_embeds = raw_embeds.to(device)
    raw_mask = raw_mask.to(device)

    shifted = []
    for idx in range(min(len(raw_embeds), len(raw_mask))):
        valid_num = int(raw_mask[idx].sum().item())
        print(f'  prompt {idx} token len: {valid_num}')
        # Rolling by -valid_num moves the first valid_num rows to the end.
        shifted.append(torch.roll(raw_embeds[idx], shifts=-valid_num, dims=0))

    cond_mask = torch.flip(raw_mask, dims=[-1])
    # Zero out every position the reversed mask marks as padding.
    cond_embeds = torch.stack(shifted, dim=0) * cond_mask.unsqueeze(-1)

    batch, seq_len, hidden = cond_embeds.shape

    # Broadcast the learned unconditional embedding to the conditional shape.
    uncond_embeds = torch.zeros_like(cond_embeds) + generator.cls_embedding.uncond_embedding
    uncond_embeds = uncond_embeds.to(device=device, dtype=cond_embeds.dtype)

    # Stacking on a new axis 1 and flattening yields u0, c0, u1, c1, ...
    paired_inputs = torch.stack((uncond_embeds, cond_embeds), dim=1)
    paired_inputs = paired_inputs.reshape(2 * batch, seq_len, hidden)
    paired_masks = cond_mask.repeat_interleave(2, dim=0)

    return paired_inputs.contiguous(), paired_masks.contiguous()
    
    

def tokenize_text_llamagen_3_way(
    vl_chat_processor,
    embedding_model,
    prompts,
    modified_prompts,
    model_name,
    device
):
    """Build interleaved (uncond, cond, cond_modified) text embeddings for LlamaGen.

    ``embedding_model[0]`` encodes the prompts via ``get_text_embeddings`` and
    ``embedding_model[1]`` supplies ``cls_embedding.uncond_embedding``.
    ``vl_chat_processor`` and ``model_name`` are not read by this function.

    Original and modified prompts are encoded together in one call, ordered
    ``p0, pm0, p1, pm1, ...``.  Each row of embeddings is rotated by its
    number of valid tokens so those rows end up at the tail, and the mask is
    reversed along the last axis.
    NOTE(review): this presumes the encoder's mask has its ones at the leading
    positions (right-padding) - confirm against the T5 wrapper.

    Args:
        vl_chat_processor: Unused.
        embedding_model: Indexable pair ``(text_encoder, gpt_model)``.
        prompts: Original text prompts.
        modified_prompts: Modified prompts, aligned one-to-one with ``prompts``.
        model_name: Unused.
        device: Device the tensors are moved to / created on.

    Returns:
        ``(inputs, masks)`` with ``3 * B`` rows each.  Row ``3*i`` is the
        unconditional embedding with the union of the two conditional masks,
        row ``3*i + 1`` the original prompt, row ``3*i + 2`` the modified one.

    Raises:
        ValueError: If ``prompts`` and ``modified_prompts`` differ in length
            (``zip`` would otherwise silently drop the unmatched tail).
    """
    if len(prompts) != len(modified_prompts):
        raise ValueError(
            "prompts and modified_prompts must have the same length, got "
            f"{len(prompts)} and {len(modified_prompts)}"
        )

    t5_model = embedding_model[0]
    gpt_model = embedding_model[1]

    # Interleave as p0, pm0, p1, pm1, ... so a single encoder call covers both.
    combined_prompts = [
        text for pair in zip(prompts, modified_prompts) for text in pair
    ]

    pair_embeds, pair_mask = t5_model.get_text_embeddings(combined_prompts)
    pair_embeds = pair_embeds.to(device)
    pair_mask = pair_mask.to(device)

    new_attention_mask = torch.flip(pair_mask, dims=[-1])
    # Move each row's first `valid_len` positions to the end of the row.
    rotated_rows = []
    for row_emb, row_mask in zip(pair_embeds, pair_mask):
        valid_len = int(row_mask.sum().item())
        rotated_rows.append(
            torch.cat([row_emb[valid_len:], row_emb[:valid_len]], dim=0)
        )
    pair_embeds_rot = torch.stack(rotated_rows, dim=0)
    # Zero out every position the reversed mask marks as padding.
    c_pair = pair_embeds_rot * new_attention_mask.unsqueeze(-1)

    B2, T, H = c_pair.shape
    # Internal invariant: the encoder returned one row per combined prompt.
    assert B2 % 2 == 0
    B = B2 // 2

    # De-interleave back into original / modified halves.
    cond_rows = c_pair[0::2]
    cond_mod_rows = c_pair[1::2]
    cond_mask = new_attention_mask[0::2]
    cond_mod_mask = new_attention_mask[1::2]

    # The unconditional row attends wherever either conditional row does.
    union_mask = (cond_mask.bool() | cond_mod_mask.bool()).to(cond_mask.dtype)
    # Broadcast the learned unconditional embedding to the conditional shape.
    uncond_vec = torch.zeros_like(cond_rows) + gpt_model.cls_embedding.uncond_embedding
    uncond_vec = uncond_vec.to(device=device, dtype=c_pair.dtype)

    # Every row is overwritten below, so uninitialised storage is safe here.
    inputs_interleaved = torch.empty((3 * B, T, H), device=device, dtype=c_pair.dtype)
    masks_interleaved = torch.empty((3 * B, T), device=device, dtype=cond_mask.dtype)

    inputs_interleaved[0::3] = uncond_vec
    inputs_interleaved[1::3] = cond_rows
    inputs_interleaved[2::3] = cond_mod_rows

    masks_interleaved[0::3] = union_mask
    masks_interleaved[1::3] = cond_mask
    masks_interleaved[2::3] = cond_mod_mask

    return inputs_interleaved.contiguous(), masks_interleaved.contiguous()

