import random
from typing import List

import numpy as np
import PIL.Image as pil_image
import torch
from params_proto import ParamsProto


def image_grid(img_list: List[List[pil_image.Image]]):
    """Tile a 2D list of images into a single RGB image.

    ``img_list[i][j]`` is pasted at row ``i``, column ``j``. The tile size is
    taken from ``img_list[0][0]`` and the column count from ``img_list[0]``,
    so the input is assumed to be a rectangular grid of equally sized images.

    Returns the composed grid image.
    """
    n_rows = len(img_list)
    n_cols = len(img_list[0])
    tile_w, tile_h = img_list[0][0].size

    canvas = pil_image.new("RGB", size=(n_cols * tile_w, n_rows * tile_h))

    # Walk the grid with running pixel offsets instead of index arithmetic.
    top = 0
    for row_tiles in img_list:
        left = 0
        for tile in row_tiles:
            canvas.paste(tile, box=(left, top))
            left += tile_w
        top += tile_h

    return canvas


class Imagen(ParamsProto, prefix="imagen"):
    """Image generation from a depth map and two semantic masks.

    The foreground and background prompts are encoded separately, each one is
    restricted to its (grown) mask, the two conditionings are combined and
    passed through a depth ControlNet, and the result is sampled and decoded
    into a single image.

    The class attributes below are the tunable parameters. ``generate`` takes
    three required keyword-only inputs:

    midas_depth: pil_image.Image
    foreground_mask: pil_image.Image
    background_mask: pil_image.Image

    plus optional overrides, which are forwarded to ``Imagen._update``. It
    returns the generated image as a ``pil_image.Image``.
    """

    # Alternative prompts kept for reference:
    #   "A photo realistic Close-up view of a concrete curb that looks gray"
    #   "Close-up view of humus and foliage under a tree, dark brown, mud, from a small dog’s perspective, highlighting textures of humus"
    #   "close up view, photorealistic view of a college campus from a small dog’s perspective during golden hour, highlighting textures of grass, foliage, and buildings"

    foreground_prompt = "close up view, photo realistic tables made of (wood:1.5), weathered, cracked, 135mm IMAX, very large"
    background_prompt = (
        "sandy carpet inside a really really messy laboratory space white walls, brightly lit, stuff littered around the wall, (fov:11mm)"
    )
    negative_prompt: str = "watermark, low-quality"

    # Size of the empty latent handed to the sampler.
    width = 1280
    height = 768
    batch_size: int = 1

    # Passed to the SDTurbo scheduler as `steps` and `denoise`.
    num_steps = 5
    denoising_strength = 1.0

    # Strengths of the masked background / foreground conditionings and of the
    # depth ControlNet.
    background_strength = 0.75
    foreground_strength = 0.9
    control_strength = 0.7

    # Passed as `expand` when growing both masks.
    grow_mask_amount = 6

    # NOTE(review): `seed`, `image_counter`, `rollout_id` and `device` are not
    # read anywhere in this class; generate() draws a fresh random noise seed
    # on every call.
    seed = 100

    image_counter = 0
    rollout_id = 0

    sdxl_path = "sd_xl_turbo_1.0_fp16.safetensors"
    device = "cuda"

    def __post_init__(self):
        """Load the checkpoint, the ControlNet and the nodes reused by ``generate``."""
        from weaver.workflows.comfy_utils import import_custom_nodes, add_extra_model_paths

        add_extra_model_paths()
        import_custom_nodes()

        from ml_logger import logger

        logger.job_started()
        print(logger)
        print(logger.get_dash_url())

        # Imported only after the model paths and custom nodes are registered.
        from nodes import (
            EmptyLatentImage,
            CheckpointLoaderSimple,
            NODE_CLASS_MAPPINGS,
            VAEDecode,
            CLIPTextEncode,
            ControlNetLoader,
        )

        checkpointloadersimple = CheckpointLoaderSimple()
        self.checkpoint = checkpointloadersimple.load_checkpoint(ckpt_name=self.sdxl_path)
        self.clip_text_encode = CLIPTextEncode()
        self.empty_latent = EmptyLatentImage()

        ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
        self.ksampler = ksamplerselect.get_sampler(sampler_name="lcm")

        controlnetloader = ControlNetLoader()
        self.controlnet = controlnetloader.load_controlnet(control_net_name="controlnet_depth_sdxl_1.0.safetensors")

        self.imagetomask = NODE_CLASS_MAPPINGS["ImageToMask"]()
        self.growmask = NODE_CLASS_MAPPINGS["GrowMask"]()
        self.vaedecode = VAEDecode()

        print("loading is done.")

    # Called with parentheses: the bare `@torch.no_grad` form is only accepted
    # by newer PyTorch releases.
    @torch.no_grad()
    def to_tensor(self, img: pil_image.Image):
        """Convert ``img`` to a float tensor with its values divided by 255.

        The tensor is not moved to ``self.device``.
        """
        np_img = np.asarray(img)
        return torch.Tensor(np_img) / 255.0

    def generate(
        self,
        _deps=None,
        *,  # required
        midas_depth: pil_image.Image,
        foreground_mask: pil_image.Image,
        background_mask: pil_image.Image,
        # optional
        **deps,  #: Unpack["Imagen"],
    ) -> pil_image.Image:
        """Generate one image from a depth map and two masks.

        Args:
            _deps: optional overrides, forwarded to ``Imagen._update``.
            midas_depth: depth image fed to the ControlNet.
            foreground_mask: mask of the region driven by ``foreground_prompt``.
            background_mask: mask of the region driven by ``background_prompt``.
            **deps: further overrides, forwarded to ``Imagen._update``.

        The three images are each expanded to a ``(1, H, W, 3)`` tensor below,
        which only works for inputs that convert to 2-D arrays.

        Returns:
            The decoded image as a ``pil_image.Image``.
        """
        from nodes import (
            ConditioningSetMask,
            ConditioningCombine,
            NODE_CLASS_MAPPINGS,
            ControlNetApply,
        )
        from weaver.workflows.comfy_utils import get_value_at_index

        # we reference the class to take advantage of the namespacing
        Imagen._update(_deps, **deps)

        # Add batch and channel axes, then replicate the channel three times.
        depths_t = self.to_tensor(midas_depth)[None, ..., None].repeat([1, 1, 1, 3])
        foreground_mask_t = self.to_tensor(foreground_mask)[None, ..., None].repeat([1, 1, 1, 3])
        background_mask_t = self.to_tensor(background_mask)[None, ..., None].repeat([1, 1, 1, 3])

        assert self.batch_size == 1, "only generate one for now."

        with torch.inference_mode():
            # Named after the keyword each checkpoint entry is passed as below.
            model = get_value_at_index(self.checkpoint, 0)
            clip = get_value_at_index(self.checkpoint, 1)
            vae = get_value_at_index(self.checkpoint, 2)

            latent = self.empty_latent.generate(width=Imagen.width, height=Imagen.height, batch_size=Imagen.batch_size)

            background_cond = self.clip_text_encode.encode(text=Imagen.background_prompt, clip=clip)

            print(Imagen.foreground_prompt, Imagen.background_prompt)

            foreground_cond = self.clip_text_encode.encode(text=Imagen.foreground_prompt, clip=clip)
            negative_cond = self.clip_text_encode.encode(text=Imagen.negative_prompt, clip=clip)

            conditioningsetmask = ConditioningSetMask()
            conditioningcombine = ConditioningCombine()
            controlnetapply = ControlNetApply()
            sdturboscheduler = NODE_CLASS_MAPPINGS["SDTurboScheduler"]()
            samplercustom = NODE_CLASS_MAPPINGS["SamplerCustom"]()

            def masked_conditioning(conditioning, mask_t, strength):
                # Take the red channel of `mask_t` as a mask, grow it, and
                # attach it to `conditioning` with the given strength.
                mask = self.imagetomask.image_to_mask(channel="red", image=get_value_at_index([mask_t], 0))
                grown = self.growmask.expand_mask(
                    expand=Imagen.grow_mask_amount,
                    tapered_corners=True,
                    mask=get_value_at_index(mask, 0),
                )
                return conditioningsetmask.append(
                    strength=strength,
                    set_cond_area="default",
                    conditioning=get_value_at_index(conditioning, 0),
                    mask=get_value_at_index(grown, 0),
                )

            background_masked = masked_conditioning(background_cond, background_mask_t, Imagen.background_strength)
            foreground_masked = masked_conditioning(foreground_cond, foreground_mask_t, Imagen.foreground_strength)

            combined = conditioningcombine.combine(
                conditioning_1=get_value_at_index(background_masked, 0),
                conditioning_2=get_value_at_index(foreground_masked, 0),
            )

            controlled = controlnetapply.apply_controlnet(
                strength=Imagen.control_strength,
                conditioning=get_value_at_index(combined, 0),
                control_net=get_value_at_index(self.controlnet, 0),
                image=get_value_at_index((depths_t,), 0),
            )

            sigmas = sdturboscheduler.get_sigmas(
                steps=Imagen.num_steps,
                denoise=Imagen.denoising_strength,
                model=model,
            )

            sampled = samplercustom.sample(
                add_noise=True,
                # randint is inclusive on both ends; cap at 2**64 - 1 so the
                # seed always fits in an unsigned 64-bit integer.
                noise_seed=random.randint(1, 2**64 - 1),
                cfg=1,
                model=model,
                positive=get_value_at_index(controlled, 0),
                negative=get_value_at_index(negative_cond, 0),
                sampler=get_value_at_index(self.ksampler, 0),
                sigmas=get_value_at_index(sigmas, 0),
                latent_image=get_value_at_index(latent, 0),
            )

            (image_batch,) = self.vaedecode.decode(
                samples=get_value_at_index(sampled, 0),
                vae=vae,
            )[:1]

            # Exactly one image is expected because batch_size is 1.
            (generated_image,) = image_batch

            image_np = (generated_image * 255).cpu().numpy().astype("uint8")
            return pil_image.fromarray(image_np)


def demo():
    """Load the sample depth map and masks and run ``Imagen.generate`` twice."""
    from weaver.workflows.comfy_utils import load_local_image

    imagen = Imagen()

    depths = load_local_image("scene/render/midas_depth.png")
    foreground_masks = load_local_image("scene/render/mask.jpg")
    background_masks = load_local_image("scene/render/background.jpg")

    depths = torch.cat([depths] * 1)
    background_masks = torch.cat([background_masks] * 1)
    foreground_masks = torch.cat([foreground_masks] * 1)

    # generate() declares these three inputs keyword-only, so passing them
    # positionally raises TypeError.
    # NOTE(review): generate() annotates its inputs as PIL images, while the
    # torch.cat calls above imply load_local_image returns tensors -- confirm
    # that the loaded values have the form generate() expects.
    for _ in range(2):
        imagen.generate(
            midas_depth=depths,
            foreground_mask=foreground_masks,
            background_mask=background_masks,
        )


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo()
