# from absl import logging
import numpy as np
# import tensorflow as tf
import torch  # was misspelled `torchs`; the module only ever uses `torch.*`
from matplotlib import pyplot as plt
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms


def image_process_one(sample):
    """Fold a wide rendered strip into a vertically stacked image.

    The input is converted to a ``(C, H, W)`` tensor with ``ToTensor``, cut
    along its width into consecutive non-overlapping windows, and the windows
    are concatenated along the height axis before converting back to PIL.

    The window width is 384 when the strip is 16 pixels high and 224 for
    every other height.

    Args:
        sample: Image to fold. It is handed straight to
            ``torchvision.transforms.ToTensor`` (the caller in this file
            passes a PIL image).

    Returns:
        A PIL image whose width is the window width and whose height is
        ``H * (W // window)``. Columns past the last full window are dropped.

    Raises:
        RuntimeError: Raised by ``torch.cat`` when the strip is narrower than
            one window, because no slice is produced.
    """
    # No resizing or normalisation is applied; only the tensor conversion.
    to_tensor = transforms.ToTensor()

    def my_shape(image_tensor):
        """Stack fixed-width slices of ``image_tensor`` along the height axis."""
        _, height, width = image_tensor.size()
        # Previously a 14-pixel-high strip whose width was not 224*16 left the
        # window undefined and crashed; every non-16 height now uses 224.
        window = 384 if height == 16 else 224
        slices = [
            image_tensor[:, :, start:start + window]
            for start in range(0, width - window + 1, window)
        ]
        return torch.cat(slices, dim=1)

    stacked = my_shape(to_tensor(sample))
    return transforms.ToPILImage()(stacked)

def render_text_with_pil_multiple_mask(text, font_path="src/data/base/open-sans/OpenSans-Regular.ttf", font_size=16, image_size=(840*24, 16), text_color="black", background_color="white", margin=10, n_parts=3, clip_token_num=256, n_parts_adaptive=False):
    """Split ``text`` into parts and render each part onto its own image.

    Newlines are replaced by two spaces, the text is cut into ``n_parts``
    pieces of ``len(text) // n_parts`` characters (the last piece also takes
    the remainder), and every piece is drawn left-aligned and vertically
    centred on a fresh canvas. Each canvas is then passed through
    ``image_process_one``.

    Args:
        text: String to render, or ``None`` for a blank result.
        font_path: Path of the TrueType font to load.
        font_size: Font size passed to ``ImageFont.truetype``.
        image_size: ``(width, height)`` of each canvas before folding.
        text_color: Fill colour for the text.
        background_color: Canvas background colour.
        margin: Unused; kept for interface compatibility.
        n_parts: Number of pieces to split the text into. Ignored when
            ``n_parts_adaptive`` is true.
        clip_token_num: Length of each validity-mask row.
        n_parts_adaptive: When true, ``n_parts`` is derived from the rendered
            width of the whole text divided by the canvas width (rounded up,
            at least 1).

    Returns:
        A tuple ``(images, valid_all)``. ``images`` is a list with one PIL
        image per part. ``valid_all`` is a bool tensor of shape
        ``(len(images), clip_token_num)``; in each row the first
        ``ceil(text_width / image_size[1])`` entries (at least 1, at most
        ``clip_token_num``) are ``True``.

        When ``text`` is ``None`` the result is a single untouched blank
        canvas (not passed through ``image_process_one``) and an all-``False``
        mask of shape ``(1, clip_token_num)``.

    Raises:
        ValueError: If the effective ``n_parts`` is smaller than 1.
    """
    # The None check has to come before any string method is called on text.
    if text is None:
        blank = Image.new('RGB', image_size, color=background_color)
        nothing_valid = torch.zeros((1, clip_token_num), dtype=torch.bool)
        return [blank], nothing_valid

    text = text.replace("\n", "  ")
    font = ImageFont.truetype(font_path, font_size)

    if n_parts_adaptive:
        # A throwaway 1x1 canvas is enough for measuring.
        ruler = ImageDraw.Draw(Image.new('RGB', (1, 1), color=background_color))
        bbox_all = ruler.textbbox((0, 0), text, font=font)
        alltext_width = bbox_all[2] - bbox_all[0]
        n_parts = max(int(np.ceil(alltext_width / image_size[0])), 1)
    if n_parts < 1:
        raise ValueError(f"n_parts must be at least 1, got {n_parts}")

    # Explicit cut offsets keep the split well defined even when the text is
    # shorter than n_parts (part_length == 0): the leading parts are then
    # empty and the last part carries the whole text.
    part_length = len(text) // n_parts
    cuts = [index * part_length for index in range(n_parts)]
    parts = [text[begin:end] for begin, end in zip(cuts, cuts[1:])]
    parts.append(text[cuts[-1]:])

    images = []
    valid_all = []
    for part in parts:
        # A fresh canvas per part; reusing one canvas would draw every part
        # on top of the previous ones.
        canvas = Image.new('RGB', image_size, color=background_color)
        draw = ImageDraw.Draw(canvas)
        text_width = 0
        if part:
            bbox = draw.textbbox((0, 0), part, font=font)
            text_width = bbox[2] - bbox[0]
            text_height = bbox[3] - bbox[1]
            text_position_y = (image_size[1] - text_height) / 2
            draw.text((0, text_position_y), part, fill=text_color, font=font)

        valid_token_num = max(int(np.ceil(text_width / image_size[1])), 1)
        valid_one = torch.zeros((1, clip_token_num), dtype=torch.bool)
        valid_one[:, :min(valid_token_num, clip_token_num)] = True
        valid_all.append(valid_one)
        images.append(image_process_one(canvas))

    valid_all = torch.cat(valid_all, dim=0)  # (n_parts, clip_token_num)
    return images, valid_all



def get_encoder_render_image_mask(encoder_token_id,
    tokenizer = None,
    n_parts = 28,
    clip_token_num=576,
    font_size = 10,
    font_path = "data/renderers/noto_renderer/GoNotoCurrent.ttf",
    encoder_text = None,
    render_type = "list_to_one", # list_to_one list_to_n
    n_parts_adaptive = False, #override n_parts
    add_instruction = None,
    image_size=None,
    ):
    """Render encoder input text to images and return them with their masks.

    The text comes from ``encoder_text`` when it is given; otherwise it is
    obtained with ``tokenizer.decode(encoder_token_id,
    skip_special_tokens=True)``. ``add_instruction``, if given, is appended
    to the text before rendering.

    Args:
        encoder_token_id: Token ids to decode; only used when
            ``encoder_text`` is ``None``.
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=...)``;
            required when ``encoder_text`` is ``None``.
        n_parts: Forwarded to ``render_text_with_pil_multiple_mask``.
        clip_token_num: Forwarded to ``render_text_with_pil_multiple_mask``.
        font_size: Forwarded to ``render_text_with_pil_multiple_mask``.
        font_path: Forwarded to ``render_text_with_pil_multiple_mask``.
        encoder_text: Already-decoded text to render instead of decoding.
        render_type: Unused here; kept for interface compatibility.
        n_parts_adaptive: Forwarded to ``render_text_with_pil_multiple_mask``.
        add_instruction: Optional string appended to the text.
        image_size: Optional canvas size; when ``None`` the renderer's own
            default is used.

    Returns:
        A tuple ``(render_image, valid, texts)``: the renderer's two outputs
        followed by the exact text that was rendered.

    Raises:
        ValueError: If both ``encoder_text`` and ``tokenizer`` are ``None``.
    """
    if encoder_text is not None:
        # Previously this path assigned nothing and the return statement
        # failed with UnboundLocalError.
        texts = encoder_text
    else:
        if tokenizer is None:
            raise ValueError("either encoder_text or tokenizer must be provided")
        texts = tokenizer.decode(encoder_token_id, skip_special_tokens=True)

    if add_instruction is not None:
        texts += add_instruction

    render_kwargs = {
        "font_path": font_path,
        "font_size": font_size,
        "n_parts": n_parts,
        "clip_token_num": clip_token_num,
        "n_parts_adaptive": n_parts_adaptive,
    }
    # Only override the renderer's default canvas size when one was given.
    if image_size is not None:
        render_kwargs["image_size"] = image_size

    render_image, valid = render_text_with_pil_multiple_mask(texts, **render_kwargs)
    return render_image, valid, texts

