import os

from model.utils import *


class IO:
    """Move single items or batches between image space and latent space.

    All heavy lifting is delegated to ``pipe``; this class only decides
    which embeddings and guidance scale to hand over.  The methods used on
    ``pipe`` are ``prompt2embed``, ``img2latent``, ``latent_forward_inversion``,
    ``latent_backward`` and ``latent2img``.

    Args:
        pipe: Pipeline object exposing the methods listed above and a
            ``device`` attribute.
        cfg_sample: Guidance scale.  When it is > 0 the unconditional and
            conditional embeddings are concatenated and passed together;
            otherwise only the unconditional embedding is used with a
            guidance scale of 0.
        noise_level: Passed to ``latent_forward_inversion`` /
            ``latent_backward``.  When it is <= 0 both steps are skipped.
        eta: Forwarded unchanged to ``pipe.latent_backward``.
    """

    # Shape of one latent, as required by the shape check in
    # ``forward_single`` and the reshape in ``backward_multi``.
    LATENT_SHAPE = (1, 4, 64, 64)

    def __init__(
        self,
        pipe,
        cfg_sample: float = 0.5,
        noise_level: float = 0.0,
        eta: float = 0.0,
    ):
        self.pipe = pipe
        self.device = str(pipe.device)
        self.noise_level = noise_level
        self.cfg_sample = cfg_sample
        self.eta = eta
        # Embedding of the empty prompt, computed once and reused everywhere.
        self.embed_uncond = pipe.prompt2embed("")

    def _guidance_inputs(self, embed_cond):
        """Return the ``(prompt_embedding, guidance_scale)`` pair to use."""
        if self.cfg_sample > 0:
            return torch.cat([self.embed_uncond, embed_cond]), self.cfg_sample
        return self.embed_uncond, 0

    def forward_single(self, input, embed_cond):
        """Convert one image (or one latent) into a latent at ``noise_level``.

        Args:
            input: Either a tensor that already has ``LATENT_SHAPE`` (used
                as-is) or anything ``pipe.img2latent`` accepts.
            embed_cond: Conditional embedding for this single item; it is
                concatenated after ``embed_uncond`` when guidance is on.

        Returns:
            The latent, passed through ``pipe.latent_forward_inversion``
            when ``noise_level > 0``.
        """
        if isinstance(input, torch.Tensor) and input.shape == self.LATENT_SHAPE:
            lat = input
        else:
            lat = self.pipe.img2latent(input)
        if self.noise_level > 0:
            prompt, scale = self._guidance_inputs(embed_cond)
            lat = self.pipe.latent_forward_inversion(
                lat, prompt, self.noise_level, guidance_scale=scale
            )
        return lat

    def backward_single(self, lat, embed_cond):
        """Convert one latent back into an image.

        Args:
            lat: Latent for a single item.
            embed_cond: Conditional embedding for this single item.

        Returns:
            Whatever ``pipe.latent2img`` returns for the latent, after
            ``pipe.latent_backward`` has been applied when ``noise_level > 0``.
        """
        if self.noise_level > 0:
            prompt, scale = self._guidance_inputs(embed_cond)
            lat = self.pipe.latent_backward(
                lat,
                prompt,
                self.noise_level,
                guidance_scale=scale,
                eta=self.eta,
            )
        return self.pipe.latent2img(lat)

    def backward_multi(self, X, embed_cond):
        """Decode a batch of latents one by one.

        Args:
            X: Tensor whose first dimension indexes the items; each item is
                reshaped to ``LATENT_SHAPE``.
            embed_cond: Tensor with one conditional embedding per item along
                its first dimension.

        Returns:
            List with one ``backward_single`` result per item.

        Raises:
            ValueError: If ``X`` and ``embed_cond`` disagree on batch size.
        """
        if X.shape[0] != embed_cond.shape[0]:
            raise ValueError(
                f"got {X.shape[0]} latents but {embed_cond.shape[0]} embeddings"
            )
        return [
            self.backward_single(
                X[i].reshape(self.LATENT_SHAPE), embed_cond[i : i + 1]
            )
            for i in range(X.shape[0])
        ]

    def forward_multi(self, imgs, embed_cond):
        """Encode a sequence of images one by one and stack the latents.

        Args:
            imgs: Sized sequence of inputs accepted by ``forward_single``.
            embed_cond: Tensor with one conditional embedding per image along
                its first dimension.

        Returns:
            The per-image latents concatenated along dimension 0.

        Raises:
            ValueError: If ``imgs`` and ``embed_cond`` disagree on batch size.
        """
        if len(imgs) != embed_cond.shape[0]:
            raise ValueError(
                f"got {len(imgs)} images but {embed_cond.shape[0]} embeddings"
            )
        # Each image gets only its own embedding (a length-1 slice), matching
        # backward_multi; passing the whole batch would put every conditional
        # embedding into each image's guidance prompt.
        lats = [
            self.forward_single(img, embed_cond[i : i + 1])
            for i, img in enumerate(imgs)
        ]
        return torch.cat(lats, dim=0)

    # Original (misspelled) public name, kept so existing callers still work.
    foward_multi = forward_multi


class BVP_IO(IO):
    """``IO`` variant that decodes a latent sequence and saves it as one image.

    Args:
        pipe: Pipeline object, forwarded to ``IO``.
        noise_level: Forwarded to ``IO``.
        cfg_sample: Forwarded to ``IO``.
        eta: Forwarded to ``IO``.
        out_dir: Directory the image sequences are written to.
        output_image_num: Number of frames per sequence; must be below 100.
        output_start_images: Whether the sequence at iteration 0 is saved by
            ``output_interpolation_results``.
        out_interval: Stored on the instance; not read by this class.

    Raises:
        ValueError: If ``output_image_num`` is 100 or more.
    """

    def __init__(
        self,
        pipe,
        noise_level: float = 0.0,
        cfg_sample: float = 0.5,
        eta: float = 0.0,
        out_dir: str = "./",
        output_image_num: int = 15,
        output_start_images: bool = False,
        out_interval: int = -1,
    ):
        # Validate before doing any pipeline work or tensor allocation.
        if output_image_num >= 100:
            raise ValueError(
                f"output_image_num must be < 100, got {output_image_num}"
            )
        # IO.__init__ takes (pipe, cfg_sample, noise_level, eta), in that order.
        super().__init__(pipe, cfg_sample, noise_level, eta)
        self.out_dir = out_dir
        self.output_image_num = output_image_num
        self.output_start_images = output_start_images
        self.out_interval = out_interval
        # Evenly spaced values in [0, 1], one per output frame.
        self.out_t = torch.linspace(0, 1, self.output_image_num).to(self.device)

    def backward_bvp(self, lats, embed_cond):
        """Decode every latent in ``lats``; thin wrapper over ``backward_multi``."""
        return self.backward_multi(lats, embed_cond)

    def save_bvp_sequence(self, lats, out_name, **embed_cond_args):
        """Decode ``lats`` and save the frames side by side as one PNG.

        Args:
            lats: Latents to decode; ``backward_multi`` requires as many
                items as there are entries in ``self.out_t``.
            out_name: File name without extension; ``.png`` is appended.
            **embed_cond_args: Must contain ``embed_cond``, a single
                conditional embedding that is repeated once per frame.

        Returns:
            The list of decoded images.

        Raises:
            KeyError: If ``embed_cond`` is not supplied.
        """
        frame_count = self.out_t.shape[0]
        embed_cond = embed_cond_args["embed_cond"].repeat(frame_count, 1, 1)
        imgs = self.backward_bvp(lats, embed_cond)
        # display_alongside is not defined in this file; presumably it comes
        # from the model.utils star import and returns an object with .save().
        img_long = display_alongside(imgs)
        # Create the target directory so saving does not fail on a fresh run.
        # An empty out_dir means the current directory and is skipped.
        if self.out_dir:
            os.makedirs(self.out_dir, exist_ok=True)
        out_path = os.path.join(self.out_dir, f"{out_name}.png")
        img_long.save(out_path)
        # Report the path that was actually written.
        print(f"Image sequence saved to {out_path}")
        return imgs

    def output_interpolation_results(
        self, iter, out_name, lats, file_name="test", **embed_cond_args
    ):
        """Save the sequence at the start and/or the end of a run.

        Args:
            iter: Current iteration index; 0 counts as the start.
            out_name: Stage tag; only the value ``"final"`` triggers a save.
                It is not used as the file name.
            lats: Latents to decode and save.
            file_name: Name (without extension) of the saved file.
            **embed_cond_args: Forwarded to ``save_bvp_sequence``.

        Returns:
            The decoded images when a sequence was saved, otherwise ``None``.
        """
        is_start = iter == 0 and self.output_start_images
        is_final = out_name == "final"
        if not (is_start or is_final):
            return None
        if is_start:
            print("Output start image sequence")
        if is_final:
            print("Output final image sequence")
        return self.save_bvp_sequence(lats, out_name=file_name, **embed_cond_args)
