import argparse
from pathlib import Path

from PIL import Image, ImageDraw
import torch
from torchvision import transforms

from tokenizer.tokenizer_image.vq_model import VQ_models
from utils.data import center_crop_arr

"""
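What happens to the decoded image when the tokens at the centre of the latent
grid are replaced with
- random tokens,
- similar tokens (nearest codebook entries), or
- different tokens (farthest codebook entries)
from the vocabulary?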
"""


class Replacer:
    """Overwrite the central square of a token grid with substitute codebook ids.

    Args:
        vq_model: VQ tokenizer whose ``quantize`` module exposes an ``embedding``
            layer (the codebook, called on ids and read via ``.weight``) and
            ``n_e`` (number of codebook entries).
        shape: ``[height, width]`` of the latent token grid.
    """

    def __init__(self, vq_model, shape):
        self.vq_model = vq_model
        self.shape = shape

    @staticmethod
    def _center_slice(length: int, size: int) -> slice:
        """Return a slice of ``size`` positions centred in ``range(length)``.

        The window is clipped to ``[0, length)`` so an oversized ``size`` never
        yields a negative start (which Python slicing would wrap around).
        """
        start = max(length // 2 - size // 2, 0)
        return slice(start, min(start + size, length))

    def _replace_dist(self, center, use_sim: bool = True, num_topk: int = 6):
        """Replace each id with a random pick among its nearest/farthest codebook entries.

        Args:
            center: integer tensor of codebook ids (any shape).
            use_sim: if True, pick among the nearest entries (L2 distance);
                otherwise, pick among the farthest entries.
            num_topk: ``num_topk - 1`` candidates are considered per token,
                never including the token itself.

        Returns:
            Tensor of replacement ids with the same shape as ``center``, on the
            codebook's device.

        Raises:
            ValueError: if fewer than one candidate is available.
        """
        embedding = self.vq_model.quantize.embedding
        codebook = embedding.weight
        center = center.to(codebook.device)
        shape = center.shape
        flat = center.reshape(-1)

        num_candidates = min(num_topk - 1, codebook.shape[0] - 1)
        if num_candidates < 1:
            raise ValueError(
                f"need num_topk >= 2 and a codebook with >= 2 entries "
                f"(got num_topk={num_topk}, codebook size={codebook.shape[0]})"
            )

        with torch.no_grad():  # only indices are needed, never gradients
            xs = embedding(flat)
            # Squared L2 distance from every token to every codebook entry.
            dist = (xs**2).sum(1)[:, None] + (codebook**2).sum(1) - 2 * xs @ codebook.T
            rows = torch.arange(flat.numel(), device=flat.device)
            # Exclude each token itself explicitly instead of assuming it lands in
            # the first top-k column (true only for the nearest-neighbour search).
            dist[rows, flat] = float("inf") if use_sim else float("-inf")
            candidates = torch.topk(
                dist, num_candidates, dim=1, largest=not use_sim, sorted=True
            ).indices  # [B, num_candidates]
            picks = torch.randint(
                0, num_candidates, (flat.numel(),), device=flat.device
            )
            return candidates[rows, picks].reshape(shape)

    def _replace(self, center, mode):
        """Dispatch on ``mode``: ``"rand"``, ``"sim"`` or ``"diff"``.

        Raises:
            ValueError: for any other ``mode``.
        """
        if mode == "rand":
            # Uniform over the whole vocabulary (may occasionally reproduce the original id).
            return torch.randint_like(center, 0, self.vq_model.quantize.n_e)
        elif mode == "sim":
            return self._replace_dist(center, use_sim=True)
        elif mode == "diff":
            return self._replace_dist(center, use_sim=False)
        else:
            raise ValueError(f"unknown mode: {mode}")

    def __call__(self, indices, batch_size: int, mode="rand", size: int = 4):
        """Return a copy of ``indices`` with a centred ``size x size`` patch replaced.

        Args:
            indices: tensor with ``batch_size * shape[0] * shape[1]`` token ids;
                it is reshaped to ``(batch_size, *shape)`` and left unmodified.
            batch_size: number of images in ``indices``.
            mode: ``"rand"``, ``"sim"`` or ``"diff"`` (see ``_replace``).
            size: side length, in tokens, of the replaced square (clipped to the grid).

        Returns:
            Flattened tensor of token ids.
        """
        grid = indices.clone().reshape(batch_size, *self.shape)
        rows = self._center_slice(self.shape[0], size)
        cols = self._center_slice(self.shape[1], size)
        grid[:, rows, cols] = self._replace(grid[:, rows, cols], mode)
        return grid.reshape(-1)


def draw_box(sample, size: int = 4, patch_size: int = 16):
    """Outline, in red, the centred square of ``size x size`` tokens on ``sample``.

    The box is derived from the image's own dimensions, so it stays centred for
    every supported ``--image-size`` rather than being fixed to a 256x256 image.

    Args:
        sample: PIL image; it is drawn on in place and also returned.
        size: side length of the highlighted region, in tokens.
        patch_size: pixels per token along each axis (assumed 16 for a
            16x-downsampling tokenizer such as VQ-16 — confirm for others).

    Returns:
        The same ``sample`` object with the rectangle drawn on it.
    """

    def span(length_px):
        # Token-aligned window centred in the grid and clipped to its bounds,
        # converted back to pixel coordinates.
        num_tokens = length_px // patch_size
        start = max(num_tokens // 2 - size // 2, 0)
        end = min(start + size, num_tokens)
        return start * patch_size, end * patch_size

    width, height = sample.size
    x1, x2 = span(width)
    y1, y2 = span(height)

    draw = ImageDraw.Draw(sample)
    draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
    return sample


def main(args):
    """Encode every image in ``args.image_dir``, corrupt its centre tokens, and save decodings.

    For each input image, four PNGs are written to ``args.out_dir``:
    ``{stem}_base.png`` (plain reconstruction) and ``{stem}_rand.png``,
    ``{stem}_sim.png``, ``{stem}_diff.png`` (reconstruction after ``Replacer``
    swapped the central tokens in that mode). Each output gets a red box from
    ``draw_box``.

    Raises:
        FileNotFoundError: if ``args.image_dir`` contains no readable image files.
    """
    # Fall back to CPU so the script still runs on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    out_root = Path(args.out_dir)
    out_root.mkdir(exist_ok=True, parents=True)

    vq_model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size, codebook_embed_dim=args.codebook_embed_dim
    )
    vq_model.to(device)
    vq_model.eval()
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
    vq_model.load_state_dict(checkpoint["model"])
    del checkpoint

    # Token grid side length; assumes a 16x-downsampling tokenizer (e.g. VQ-16).
    latent_size = args.image_size // 16
    replacer = Replacer(vq_model, shape=[latent_size, latent_size])

    crop_size = args.image_size
    transform = transforms.Compose(
        [
            transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, crop_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True
            ),
        ]
    )

    def load(path):
        """Read one image as a normalised tensor, closing the file handle afterwards."""
        with Image.open(path) as image:
            return transform(image.convert("RGB"))

    @torch.no_grad()
    def run(paths):
        names = [path.stem for path in paths]
        x_all = torch.stack([load(path) for path in paths], dim=0).to(device)
        batch_size = x_all.size(0)
        _, _, [_, _, indices] = vq_model.encode(x_all)
        qzshape = [batch_size, args.codebook_embed_dim, latent_size, latent_size]

        def decode(codes):
            samples = vq_model.decode_code(codes, qzshape)  # output value is between [-1, 1]
            return (
                torch.clamp(127.5 * samples + 128.0, 0, 255)
                .permute(0, 2, 3, 1)
                .to("cpu", dtype=torch.uint8)
                .numpy()
            )

        # "base" is the unmodified reconstruction; the other suffixes double as
        # Replacer modes.
        for suffix in ("base", "rand", "sim", "diff"):
            if suffix == "base":
                codes = indices
            else:
                codes = replacer(indices, batch_size, mode=suffix)
            for name, sample in zip(names, decode(codes)):
                sample = draw_box(Image.fromarray(sample))
                sample.save(str(out_root / f"{name}_{suffix}.png"))

    # Only regular files with an extension Pillow recognises; sorted for a stable order.
    image_exts = set(Image.registered_extensions())
    paths = sorted(
        path
        for path in Path(args.image_dir).glob("*")
        if path.is_file() and path.suffix.lower() in image_exts
    )
    if not paths:
        raise FileNotFoundError(f"no image files found in {args.image_dir}")
    run(paths)
    print("done")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=(
            "Decode VQ tokens after replacing the centre tokens with "
            "random, similar, or different codebook entries."
        )
    )
    # Hyphenated spellings match the other flags; the underscore forms are kept
    # so existing command lines keep working.
    parser.add_argument(
        "--image-dir",
        "--image_dir",
        dest="image_dir",
        type=str,
        default="../../../data/samples/images",
        help="directory of input images",
    )
    parser.add_argument(
        "--out-dir",
        "--out_dir",
        dest="out_dir",
        type=str,
        default="../../../data/samples/outputs",
        help="directory where decoded PNGs are written",
    )
    parser.add_argument(
        "--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16"
    )
    parser.add_argument(
        "--vq-ckpt",
        type=str,
        help="ckpt path for vq model",
        default="../../../data/vq/vq_ds16_c2i.pt",
    )
    parser.add_argument(
        "--codebook-size",
        type=int,
        default=16384,
        help="codebook size for vector quantization",
    )
    parser.add_argument(
        "--codebook-embed-dim",
        type=int,
        default=8,
        help="codebook dimension for vector quantization",
    )
    parser.add_argument(
        "--dataset", type=str, default="imagenet", help="(currently unused)"
    )
    parser.add_argument(
        "--image-size", type=int, choices=[256, 384, 448, 512], default=256
    )
    args = parser.parse_args()
    main(args)
