from mmengine.hooks import Hook
from mmengine.registry import HOOKS
from mmengine.logging import print_log
import torch, numpy as np
from datasets import load_from_disk
from itertools import permutations
from PIL import Image
from einops import rearrange

def _pil_to_bchw_norm(img):
    x = np.array(img.convert("RGB"), dtype=np.float32)
    return torch.from_numpy(x).permute(2,0,1) / 127.5 - 1.0

def _pil_to_chw(img):
    a = np.array(img)
    if a.ndim == 2: a = np.stack([a]*3, axis=-1)
    t = torch.from_numpy(a)
    if t.dtype != torch.uint8: t = t.to(torch.uint8)
    return t.permute(2,0,1)

def _hwc_uint8_from_samples(samples, grid_size):
    samples = torch.clamp(127.5*samples+128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()
    c,h,w = samples.shape[-3:]
    g = samples.reshape(grid_size,grid_size,c,h,w).transpose(0,3,1,4,2).reshape(grid_size*h,grid_size*w,c)
    return g

@HOOKS.register_module()
class ValidationHook(Hook):
    """Periodically compute a text-to-image loss and optionally log samples.

    Every ``interval`` training iterations, on rank 0 only, the hook walks a
    held-out dataset in batches, accumulates ``model.text2image_loss`` and,
    when ``do_images`` is set, sends one sampled image per prompt to
    ``runner.visualizer``.

    Args:
        interval: Run validation every ``interval`` training iterations.
        dataset_path: Path handed to ``datasets.load_from_disk``.
        split: Key of the split to read from the loaded dataset.
        max_items: Keep at most this many rows; ``None`` keeps every row.
        batch_size: Number of rows evaluated per forward pass.
        do_images: Whether to sample images and log them.
        cfg_prompt: Stored on the instance; not read by the code in this
            class.
        cfg: Forwarded to ``model.sample``.
        num_iter: Forwarded to ``model.sample``.
        grid_size: Forwarded to ``_hwc_uint8_from_samples``.
            NOTE(review): only one sample is drawn per prompt, so values
            other than 1 look unsupported - confirm before relying on it.
        image_size: The sampler is asked for an
            ``(image_size // 16, image_size // 16)`` image shape.
        temperature: Forwarded to ``model.sample``.
        cfg_schedule: Forwarded to ``model.sample``.
    """

    def __init__(self,
                 interval,
                 dataset_path,
                 split="test",
                 max_items=None,
                 batch_size=64,
                 do_images=True,
                 cfg_prompt="Generate an image.",
                 cfg=1.0,
                 num_iter=64,
                 grid_size=1,
                 image_size=512,
                 temperature=0.0,
                 cfg_schedule="constant"):
        self.interval = interval
        self.dataset_path = dataset_path
        self.split = split
        self.max_items = max_items
        self.batch_size = batch_size
        self.do_images = do_images
        self.cfg_prompt = cfg_prompt
        self.cfg = cfg
        self.num_iter = num_iter
        self.grid_size = grid_size
        self.image_size = image_size
        self.temperature = temperature
        self.cfg_schedule = cfg_schedule
        # Populated in ``before_train``.
        self.dataset = None

    def before_train(self, runner):
        """Load the validation split and attach a ``prompt`` column."""
        dataset = load_from_disk(self.dataset_path)[self.split]

        # Columns copied through unchanged (once per generated prompt).
        passthrough = (
            "image", "color", "pattern", "position", "shape",
            "synthetic_color", "synthetic_pattern",
            "synthetic_position", "synthetic_shape",
        )

        def build_prompts(batch):
            # Called with batch_size=1, so index 0 is the only row.
            fields = (
                batch["synthetic_color"][0],
                batch["synthetic_pattern"][0],
                batch["synthetic_position"][0],
                batch["synthetic_shape"][0],
            )
            # Only the first ordering (the natural field order) is kept.
            prompts = [" ".join(p) for p in permutations(fields)][:1]
            repeat = len(prompts)
            row = {"prompt": prompts}
            row.update({key: batch[key] * repeat for key in passthrough})
            return row

        dataset = dataset.map(
            build_prompts, batched=True, batch_size=1, remove_columns=["image"]
        )
        # ``max_items=None`` (the default) means "use everything"; comparing
        # None with an int inside ``min`` would raise TypeError.
        if self.max_items is not None:
            dataset = dataset.select(range(min(self.max_items, len(dataset))))
        self.dataset = dataset

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        """Run validation on rank 0 every ``interval`` iterations."""
        if not self.every_n_train_iters(runner, self.interval):
            return
        if getattr(runner, "rank", 0) != 0:
            return

        model = runner.model
        device = next(model.parameters()).device
        visualizer = runner.visualizer
        step = runner.iter
        m = n = self.image_size // 16

        total_loss, total_count = 0.0, 0

        model.eval()
        try:
            with torch.no_grad():
                for start in range(0, len(self.dataset), self.batch_size):
                    # Slice the rows once; indexing column-by-column would
                    # re-read (and re-decode) the same rows several times.
                    batch = self.dataset[start:start + self.batch_size]
                    prompts = batch["prompt"]

                    wrapped = [
                        model.prompt_template["INSTRUCTION"].format(
                            input=f"Generate an image: {p}.")
                        for p in prompts
                    ]
                    toks = model.tokenizer(
                        wrapped, add_special_tokens=True,
                        return_tensors="pt", padding=True)
                    toks = {k: v.to(device) for k, v in toks.items()}

                    pixel_values = torch.stack(
                        [_pil_to_bchw_norm(img) for img in batch["image"]],
                        dim=0,
                    ).to(device=device, dtype=model.dtype)
                    loss = model.text2image_loss({
                        "pixel_values": pixel_values,
                        "input_ids": toks["input_ids"],
                        "attention_mask": toks["attention_mask"],
                    })
                    # Weight by batch size so the final value is a per-item mean.
                    total_loss += loss.item() * len(prompts)
                    total_count += len(prompts)

                    # Without a visualizer there is nowhere to send images,
                    # so skip sampling entirely.
                    if not (self.do_images and visualizer):
                        continue

                    samples = model.sample(
                        **toks, num_iter=self.num_iter, cfg=self.cfg,
                        cfg_schedule=self.cfg_schedule,
                        temperature=self.temperature, progress=False,
                        image_shape=(m, n))
                    # With gm = gn = 1 this is just a channels-last permute.
                    images = rearrange(
                        samples, '(gm gn b) c h w -> b (gm h) (gn w) c',
                        gm=1, gn=1)
                    rows = zip(batch["color"], batch["pattern"],
                               batch["position"], batch["shape"], images)
                    for color, pattern, position, shape, image in rows:
                        hwc = _hwc_uint8_from_samples(
                            image[None, ...].permute(0, 3, 1, 2),
                            self.grid_size)
                        name = f"{color} {pattern} {position} {shape}"
                        visualizer.add_image(
                            name=name,
                            image=_pil_to_chw(Image.fromarray(hwc)),
                            step=step)

            if visualizer and total_count > 0:
                visualizer.add_scalar(
                    "eval/text2image_loss", total_loss / total_count, step=step)
        finally:
            # Always hand the model back in training mode, even if
            # validation raised part-way through.
            model.train()
