import os
import cv2
import numpy as np
from PIL import Image

from argparse import ArgumentParser
from accelerate import Accelerator
from diffusers import StableDiffusionPipeline
import torch
from PIL import Image
from mmengine.config import Config
from utils.data.coco import COCODataset
from mmdet.datasets.coco import CocoDataset
import copy
from mmdet.registry import DATASETS
from tqdm import tqdm
import math
from pipelines.recon_helper import ReConHelper
import ast
import gc

# DATASETS = Registry('datasets')

########################
# Set random seed
#########################
from accelerate.utils import set_seed

# Runs at import time, before the CLI is parsed, so this global seed is always 0
# and is independent of ``--seed``. ``--seed`` is only used further down to name
# the output directory and to seed the per-image generator passed to the pipeline.
# NOTE(review): accelerate documents set_seed as seeding the Python, NumPy and
# torch RNGs -- confirm against the installed accelerate version.
set_seed(0)

########################
# Parsers
#########################
parser = ArgumentParser(description="Generation script")
# nargs="?" makes the positional optional; without it argparse treats the
# positional as required and the declared default can never take effect.
parser.add_argument(
    "ckpt_path",
    type=str,
    nargs="?",
    default="controlnet_recon",
    help="Checkpoint identifier; used to name the output directory and, when it "
    'contains "controlnet", to enable Canny edge extraction.',
)
parser.add_argument(
    "--split",
    type=str,
    default="train",
    help='Dataset split; "train" selects the train dataloader config, anything '
    "else selects the val dataloader config.",
)
parser.add_argument("--nsamples", type=int, default=1)
parser.add_argument(
    "--cfg_scale",
    type=float,
    default=4.0,
    help="Value passed to the pipeline as guidance_scale.",
)
parser.add_argument(
    "--strength",
    type=float,
    default=1.0,
    help="Value passed to the pipeline as strength.",
)
parser.add_argument(
    "--save_dir", type=str, default="outputs", help="Root directory for outputs."
)
# argparse runs ``type`` on string defaults too, so the default below is
# delivered as a real list of floats, exactly like a user-supplied value.
parser.add_argument(
    "--det_steps",
    type=ast.literal_eval,
    default="[0.75,0.5,0.25,0.1]",
    help='Python literal list of float values, e.g. "[0.75,0.5,0.25,0.1]"',
)
parser.add_argument("--deepcache", action="store_true")
parser.add_argument("--controlnet", action="store_true")
parser.add_argument("--num_cache_steps", type=int, default=5)
parser.add_argument("--bonus", type=float, default=0.0)

parser.add_argument(
    "--seed",
    type=int,
    default=None,
    help="Seed for the per-image generator (the script falls back to 333).",
)
parser.add_argument("--use_decouple_text_embedding", action="store_true")
parser.add_argument("--use_keep_strength_map", action="store_true")
parser.add_argument("--num_inference_steps", type=int, default=25)
parser.add_argument(
    "--random_split",
    type=int,
    default=None,
    help="Seed for a shuffled dataset split; when omitted the dataset is split "
    "into contiguous chunks instead.",
)
parser.add_argument(
    "--total_split",
    default=8,
    type=int,
    help="Number of parts the dataset is divided into.",
)
parser.add_argument(
    "--split_id",
    default=0,
    type=int,
    help="Index of the dataset part to process (0 <= split_id < total_split).",
)

# copy from training script
parser.add_argument(
    "--dataset_config_name",
    type=str,
    default=None,
    help="Path to the dataset config file loaded with mmengine Config.fromfile.",
)

parser.add_argument(
    "--prompt_version",
    type=str,
    default="v1",
    help="Text prompt version forwarded to the dataset config (default: v1).",
)

parser.add_argument(
    "--num_bucket_per_side",
    type=int,
    default=None,
    nargs="+",
    help="Location bucket number along each side (i.e., total bucket number = num_bucket_per_side * num_bucket_per_side) ",
)


def random_split_dataset(dataset, total_splits, split_index, seed):
    """Return one part of a seeded random partition of ``dataset``.

    The indices ``0 .. len(dataset) - 1`` are shuffled with a generator seeded
    by ``seed`` and cut into ``total_splits`` consecutive chunks of
    ``ceil(len(dataset) / total_splits)`` indices each (the last chunk may be
    shorter). The same ``seed`` always yields the same permutation, so calls
    with different ``split_index`` values produce disjoint parts.

    Args:
        dataset: Any sized, indexable dataset.
        total_splits: Number of parts to divide the dataset into (>= 1).
        split_index: Zero-based index of the part to return. An index whose
            chunk starts past the end of the dataset yields an empty subset.
        seed: Seed for the shuffling permutation.

    Returns:
        A ``torch.utils.data.Subset`` over the selected indices.

    Raises:
        ValueError: If ``total_splits`` is smaller than 1.
    """
    if total_splits < 1:
        raise ValueError(f"total_splits must be >= 1, got {total_splits}")

    total_data_number = len(dataset)
    number_per_split = math.ceil(total_data_number / total_splits)

    # Shuffle with a private generator so the global torch RNG is not reseeded
    # as a side effect of splitting the dataset.
    generator = torch.Generator().manual_seed(seed)
    indices = torch.randperm(total_data_number, generator=generator).tolist()

    # Determine the start and end index of the requested chunk.
    start_idx = split_index * number_per_split
    end_idx = min((split_index + 1) * number_per_split, total_data_number)

    return torch.utils.data.Subset(dataset, indices[start_idx:end_idx])


def sequential_split_indices(total_data_number, total_splits, split_id):
    """Return the contiguous dataset indices belonging to part ``split_id``.

    The dataset is cut into ``total_splits`` chunks of
    ``ceil(total_data_number / total_splits)`` indices. Both chunk bounds are
    clamped to ``total_data_number`` so a chunk that starts or ends past the
    end of the dataset is shortened (possibly to empty) instead of producing
    out-of-range indices.
    """
    number_per_split = math.ceil(total_data_number / total_splits)
    start = min(number_per_split * split_id, total_data_number)
    end = min(number_per_split * (split_id + 1), total_data_number)
    return list(range(start, end))


def output_paths_for(root, rel_path, num_images):
    """Return the file paths that the generated images of one sample go to.

    A single image is stored at ``root/rel_path``. Several images are stored
    next to it as ``<stem>_<idx>.jpg``, where ``<stem>`` is ``rel_path``
    without its extension.
    """
    target = os.path.join(root, rel_path)
    if num_images == 1:
        return [target]
    stem, _ = os.path.splitext(target)
    return ["{}_{}.jpg".format(stem, idx) for idx in range(num_images)]


if __name__ == "__main__":
    # parse_known_args() ignores unrecognized flags instead of rejecting them.
    args = parser.parse_known_args()[0]

    print("{}".format(args).replace(", ", ",\n"))

    # Fail fast with a readable message, before any model is loaded.
    if args.dataset_config_name is None:
        parser.error("--dataset_config_name is required")
    if args.total_split < 1:
        parser.error("--total_split must be >= 1")

    ########################
    # Build pipeline
    #########################
    # ckpt_path only names the output directory and gates the Canny branch
    # below; the model weights loaded here are the fixed hub ids.
    ckpt_path = args.ckpt_path
    from pipelines.pipeline_controlnet_recon import (
        StableDiffusionControlNetImg2ImgPipeline,
    )
    from diffusers import ControlNetModel

    controlnet = ControlNetModel.from_pretrained(
        "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
    )
    pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        torch_dtype=torch.float16,
        safety_checker=None,
    )
    # The CLI defaults of --det_steps / --num_cache_steps equal the values that
    # used to be hard-coded here, so default runs are unchanged while the flags
    # are now honoured.
    recon_helper = ReConHelper(
        pipe,
        enable_fast_sampling=True,
        enable_region_guided_rectification=True,
        enable_region_aligned_cross_attention=True,
        perception_steps=args.det_steps,
        num_cache_steps=args.num_cache_steps,
        debug_mode=False,
        is_controlnet=True,
        device="cuda",
        debug_log_dir="./example_output/controlnet_recon",
    )

    # disabled_safety_checker   # https://github.com/huggingface/diffusers/issues/3422
    pipe.safety_checker = None
    pipe = pipe.to("cuda")

    print("Inferencing...")

    ########################
    # Build dataset
    #########################
    dataset_cfg = Config.fromfile(args.dataset_config_name)
    train_cfg = dataset_cfg.train_dataloader.dataset
    val_cfg = dataset_cfg.val_dataloader.dataset
    for split_cfg in (train_cfg, val_cfg):
        split_cfg.update(
            dict(
                prompt_version=args.prompt_version,
                split=args.split,
                num_bucket_per_side=args.num_bucket_per_side,
            )
        )
        # NOTE(review): presumably pipeline[3] is a random augmentation that is
        # switched off for generation -- confirm against the dataset config.
        split_cfg.pipeline[3]["prob"] = 0.0

    active_cfg = train_cfg if args.split == "train" else val_cfg
    # NOTE(review): assumes pipeline[2] is the resize step holding the target
    # (width, height) -- confirm against the dataset config.
    width, height = active_cfg.pipeline[2].scale

    dataset = DATASETS.build(active_cfg)
    dataset.split = args.split

    CLASSES = dataset.CLASSES
    # ! IMPORTANT: build text embeddings.
    from compel import Compel

    compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)

    dataset.build_prompts(pipe)
    category_conditions = dataset.category_conditions
    background_condition = dataset.background_condition
    print(len(category_conditions))
    print(len(CLASSES))

    # Select the part of the dataset this process is responsible for.
    if args.random_split is not None:
        dataset = random_split_dataset(
            dataset, args.total_split, args.split_id, args.random_split
        )
    else:
        mask = sequential_split_indices(
            len(dataset), args.total_split, args.split_id
        )
        dataset = torch.utils.data.Subset(dataset, mask)

    data_number = len(dataset)

    print("Image resolution: {} x {}".format(width, height))
    print(data_number)

    if data_number == 0:
        # Nothing assigned to this split_id; dataset[0] below would raise.
        print(
            "Split {} of {} is empty, nothing to generate.".format(
                args.split_id, args.total_split
            )
        )
        raise SystemExit(0)

    item = dataset[0]
    print(item["text"])

    ########################
    # Generation
    #########################
    scale = args.cfg_scale
    n_samples = args.nsamples

    seed = args.seed if args.seed is not None else 333

    root = os.path.join(args.save_dir, ckpt_path.replace("/", "--"), f"seed_{seed}")
    print(f"Path in: {root}")

    for data in tqdm(dataset, desc="Processing Prompts", unit="prompt"):

        path = data["img_path"]
        img_height = data["height"]
        img_width = data["width"]
        # Resume support: skip samples whose single-image output already exists.
        if os.path.exists(os.path.join(root, path)):
            print(f"Path: {os.path.join(root, path)} exists, continue...")
            continue

        # run generation
        prompt = data["text"]
        prompt = n_samples * [prompt]

        print(data["labels"])
        print(prompt)

        if recon_helper.enable_region_aligned_cross_attention:
            context_condition = compel.build_conditioning_tensor(data["caption"])
            print("Context prompt:", data["caption"])
            print("Object labels:", data["labels"])
            # One condition per labelled object, followed by the caption
            # condition, stacked along a new dim 1.
            obj_conditions = [
                category_conditions[CLASSES.index(x)] for x in data["labels"]
            ] + [context_condition]

            text_embedding = torch.concat(
                [x.unsqueeze(1) for x in obj_conditions], dim=1
            )
        else:
            text_embedding = None
        if text_embedding is not None:
            # The text prompt is dropped whenever prompt_embeds is supplied.
            prompt = None

        edge_image = None
        if "controlnet" in ckpt_path:
            low_threshold = 100
            high_threshold = 200
            edge_image = cv2.Canny(
                np.array(data["pil_image"]), low_threshold, high_threshold
            )
            # Replicate the single-channel edge map into three channels.
            edge_image = edge_image[:, :, None]
            edge_image = np.concatenate([edge_image, edge_image, edge_image], axis=2)
            edge_image = Image.fromarray(edge_image)
        # NOTE(review): edge_image is computed but never passed to pipe() below;
        # verify whether the pipeline expects it (e.g. as a control image).

        pil_image = Image.fromarray(data["pil_image"])
        recon_helper.gt_bboxes = data["bboxes"]
        recon_helper.gt_labels = data["labels"]

        images = pipe(
            prompt,
            image=pil_image,
            prompt_embeds=text_embedding,
            guidance_scale=scale,
            strength=args.strength,
            generator=torch.manual_seed(seed),
            num_inference_steps=args.num_inference_steps,
            height=int(height),
            width=int(width),
            recon_helper=recon_helper,
        ).images

        # save results, resized to the width/height recorded in the sample
        out_paths = output_paths_for(root, path, len(images))
        for image, out_path in zip(images, out_paths):
            # Samples may live in different sub-directories, so make sure the
            # target directory exists for every file that is written.
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            image = Image.fromarray(np.asarray(image), mode="RGB")
            image = image.resize((img_width, img_height))
            image.save(out_path)

        gc.collect()
        torch.cuda.empty_cache()
