"""
SynthTIGER
Copyright (c) 2021-present NAVER Corp.
MIT license
"""

import os
import cv2
import numpy as np
from PIL import Image, ImageDraw
import json
import random

from synthtiger import components, layers, templates, utils

# Shard id used by SynthTiger to namespace its outputs: each sample's index is
# shard * 10000 + idx, images/masks/glyph masks go under
# images/<shard>/, masks/<shard>/, glyph_masks/<shard>/, and the glyph
# coordinates are dumped to glyph_coords_<shard>.json.
# NOTE(review): an idx >= 10000 overlaps the next shard's index range.
shard = 4


class SynthTiger(templates.Template):
    """Render one corpus line per sample on a plain white background.

    Besides the composited image and its alpha mask, every sample carries
    per-character (glyph) bounding boxes and a glyph-only alpha mask.
    Labels are drawn without replacement: the filtered corpus is shuffled
    once and reshuffled only after every line has been used.

    Recognised ``config`` keys (all optional):
        coord_output, mask_output, glyph_coord_output, glyph_mask_output:
            booleans toggling the corresponding outputs of ``save`` /
            ``end_save`` (default ``True``).
        vertical: lay characters out vertically (default ``False``).
        quality: ``[low, high]`` inclusive range for the JPEG quality.
        visibility_check: stored but not read by any method of this class.
        font_size: size forced onto every sampled font (default ``40``).
        shard: shard id used in output paths and sample indices
            (defaults to the module-level ``shard``).
        corpus, font, color, layout, pad: keyword arguments forwarded to
            the corresponding synthtiger components.
    """

    def __init__(self, config=None):
        if config is None:
            config = {}

        self.coord_output = config.get("coord_output", True)
        self.mask_output = config.get("mask_output", True)
        self.glyph_coord_output = config.get("glyph_coord_output", True)
        self.glyph_mask_output = config.get("glyph_mask_output", True)
        self.vertical = config.get("vertical", False)
        self.quality = config.get("quality", [95, 95])
        self.visibility_check = config.get("visibility_check", False)
        # Both used to be hard-coded; the old values remain the defaults.
        self.font_size = config.get("font_size", 40)
        self.shard = config.get("shard", shard)

        self.corpus = components.BaseCorpus(**config.get("corpus", {}))
        self.font = components.BaseFont(**config.get("font", {}))
        self.color = components.RGB(**config.get("color", {}))
        self.layout = components.FlowLayout(**config.get("layout", {}))
        self.pad = components.Switch(components.Pad(), **config.get("pad", {}))

        self.corpus_data = self._load_corpus_data()
        self.sampled_labels = []
        self.current_index = 0
        self._shuffle_corpus_data()

    def _load_corpus_data(self):
        """Return every corpus line that passes the corpus length/charset filters.

        Lines are stripped of surrounding whitespace. Empty lines are skipped
        because a label without characters yields no glyph boxes, which
        ``save`` needs to compute the word box.
        """
        corpus_data = []
        for path in self.corpus.paths:
            with open(path, "r", encoding="utf-8") as file:
                for line in file:
                    text = line.strip()
                    if not text:
                        continue
                    if self.corpus._check_length(text) and self.corpus._check_charset(text):
                        corpus_data.append(text)
        return corpus_data

    def _shuffle_corpus_data(self):
        """Start a new pass over the corpus in a fresh random order."""
        self.sampled_labels = random.sample(self.corpus_data, len(self.corpus_data))
        self.current_index = 0

    def _get_next_label(self):
        """Return the next label of the current pass, reshuffling when exhausted.

        Raises:
            ValueError: if no corpus line survived loading/filtering.
        """
        if not self.corpus_data:
            raise ValueError(
                "SynthTiger corpus is empty: check corpus paths and length/charset filters"
            )
        if self.current_index >= len(self.sampled_labels):
            self._shuffle_corpus_data()
        label = self.sampled_labels[self.current_index]
        self.current_index += 1
        return label

    def generate(self):
        """Produce one sample dict, or ``0`` if text generation yields nothing.

        Keys: image, label, quality, mask (alpha of the padded text),
        bboxes (per-character boxes), glyph_mask and glyph_bboxes.
        """
        print("generate..")
        quality = np.random.randint(self.quality[0], self.quality[1] + 1)
        color = self.color.sample()
        result = self._generate_text(color)

        if result is None:
            print("Data skipped")
            return 0

        fg_image, label, bboxes, glyph_fg_image, glyph_bboxes = result
        # numpy shape is (height, width); the background layer takes (width, height).
        fg_image_shape = fg_image.shape[:2][::-1]
        bg_image = self._generate_background(fg_image_shape)
        image = utils.blend_image(fg_image, bg_image, mode="normal")

        data = {
            "image": image,
            "label": label,
            "quality": quality,
            "mask": fg_image[..., 3],
            "bboxes": bboxes,
            "glyph_mask": glyph_fg_image[..., 3],
            "glyph_bboxes": glyph_bboxes,
        }
        return data

    def init_save(self, root):
        """Create ``root`` and open the per-run label/coordinate outputs."""
        os.makedirs(root, exist_ok=True)
        gt_path = os.path.join(root, "gt.txt")
        coords_path = os.path.join(root, "coords.txt")
        self.gt_file = open(gt_path, "w", encoding="utf-8")
        if self.coord_output:
            try:
                self.coords_file = open(coords_path, "w", encoding="utf-8")
            except OSError:
                # Don't leak the already-open gt file when the second open fails.
                self.gt_file.close()
                raise
        if self.glyph_coord_output:
            self.glyph_coords = []

    def save(self, root, data, idx):
        """Write one sample's image/masks and append its label and coordinates.

        ``data`` is the dict from ``generate``; falsy values (skipped
        samples) are ignored. Glyph boxes are ``[x, y, width, height]`` and
        the word box is their union.
        """
        if not data:
            return

        label = data["label"]
        quality = data["quality"]
        glyph_bboxes = data["glyph_bboxes"]

        glyph_boxes = [
            [float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3])]
            for bbox in glyph_bboxes
        ]
        x_min = min(box[0] for box in glyph_boxes)
        y_min = min(box[1] for box in glyph_boxes)
        x_max = max(box[0] + box[2] for box in glyph_boxes)
        y_max = max(box[1] + box[3] for box in glyph_boxes)
        word_bbox = {
            "label": label,
            "bbox": [x_min, y_min, x_max - x_min, y_max - y_min],
        }

        base_idx = self.shard * 10000 + idx

        if self.glyph_coord_output:
            # The glyph boxes were produced one per element of split_text(label)
            # in _generate_text, so pair them with that split rather than
            # indexing the raw label string.
            chars = utils.split_text(label, reorder=False)
            self.glyph_coords.append(
                {
                    "idx": base_idx,
                    "word_bbox": word_bbox,
                    "characters": [
                        {"character": char, "bbox": box}
                        for char, box in zip(chars, glyph_boxes)
                    ],
                }
            )

        image_key = os.path.join("images", str(self.shard), f"{base_idx}.jpg")
        mask_key = os.path.join("masks", str(self.shard), f"{base_idx}.png")
        glyph_mask_key = os.path.join("glyph_masks", str(self.shard), f"{base_idx}.png")

        image_path = os.path.join(root, image_key)
        mask_path = os.path.join(root, mask_key)
        glyph_mask_path = os.path.join(root, glyph_mask_key)

        image = Image.fromarray(data["image"][..., :3].astype(np.uint8))
        os.makedirs(os.path.dirname(image_path), exist_ok=True)
        image.save(image_path, quality=quality)
        if self.mask_output:
            mask = Image.fromarray(data["mask"].astype(np.uint8))
            os.makedirs(os.path.dirname(mask_path), exist_ok=True)
            mask.save(mask_path)
        if self.glyph_mask_output:
            glyph_mask = Image.fromarray(data["glyph_mask"].astype(np.uint8))
            os.makedirs(os.path.dirname(glyph_mask_path), exist_ok=True)
            glyph_mask.save(glyph_mask_path)

        self.gt_file.write(f"{image_key}\t{label}\n")
        if self.coord_output:
            bx, by, bw, bh = word_bbox["bbox"]
            self.coords_file.write(f"{image_key}\t{bx},{by},{bw},{bh}\n")

        print(label)
        print(image_path)

    def end_save(self, root):
        """Close the per-run files and dump the accumulated glyph coordinates."""
        self.gt_file.close()
        if self.coord_output:
            self.coords_file.close()
        if self.glyph_coord_output:
            json_path = os.path.join(root, f"glyph_coords_{self.shard}.json")
            with open(json_path, "w", encoding="utf-8") as json_file:
                json.dump(self.glyph_coords, json_file, ensure_ascii=False, indent=4)

    def _generate_text(self, color):
        """Render the next label; return (rgba, label, boxes, glyph_rgba, glyph_boxes).

        Character boxes are expressed relative to the padded text layer's
        top-left corner. The glyph image is rendered over the same (padded)
        extent as the main text image.
        """
        label = self._get_next_label()
        chars = utils.split_text(label, reorder=False)

        font_meta = self.font.sample({"vertical": self.vertical})
        font_meta["size"] = self.font_size

        char_layers = [layers.TextLayer(char, **font_meta) for char in chars]

        self.layout.apply(char_layers, {"meta": {"vertical": self.vertical}})
        # Copy after layout so the glyph layers share the laid-out positions.
        char_glyph_layers = [char_layer.copy() for char_layer in char_layers]
        text_layer = layers.Group(char_layers).merge()
        text_glyph_layer = text_layer.copy()

        pad = self.pad.sample()
        self.color.apply([text_layer, text_glyph_layer], color)
        self.pad.apply([text_layer], pad)

        for char_layer in char_layers:
            char_layer.topleft -= text_layer.topleft
        for char_glyph_layer in char_glyph_layers:
            char_glyph_layer.topleft -= text_layer.topleft

        text_out = text_layer.output()
        text_bboxes = [char_layer.bbox for char_layer in char_layers]
        text_glyph_out = text_glyph_layer.output(bbox=text_layer.bbox)
        text_glyph_bboxes = [char_glyph_layer.bbox for char_glyph_layer in char_glyph_layers]

        return text_out, label, text_bboxes, text_glyph_out, text_glyph_bboxes

    def _generate_background(self, text_size):
        """Return a white rectangle of ``text_size`` (width, height) as an image array."""
        bg_layer = layers.RectLayer(text_size)
        self.color.apply([bg_layer], {"rgb": (255, 255, 255)})
        bg_out = bg_layer.output()
        return bg_out