import argparse

import numpy as np
import torch
from einops import repeat

from stable_diffusion.ldm.models.sige_autoencoder import SIGEAutoencoderKL
from utils import check_safety, load_model_from_config, put_watermark
from sige.utils import downsample_mask
from utils import load_img
from stable_diffusion.runners.base_runner import BaseRunner
from .sampler import HighlightDDIMSampler
from einops import rearrange
from PIL import Image
import os

class HighlightRunner(BaseRunner):
    """Runner that samples one image with ``HighlightDDIMSampler`` and optionally saves it.

    ``BaseRunner`` is expected to provide ``self.args``, ``self.model`` and
    ``self.wm_encoder`` (used below) — defined outside this file, TODO confirm.
    """

    @staticmethod
    def modify_commandline_options(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Extend the base runner's CLI options with the output image size.

        Args:
            parser: Parser to extend.

        Returns:
            The same parser with ``--H`` and ``--W`` (ints, default 512) registered.
        """
        parser = super(HighlightRunner, HighlightRunner).modify_commandline_options(parser)
        parser.add_argument("--H", type=int, default=512)
        parser.add_argument("--W", type=int, default=512)
        return parser

    def __init__(self, args):
        super().__init__(args)
        self.sampler = HighlightDDIMSampler(args, self.model)

    def run(self, interval=None, prompt_i=None, threshold=None, save_pkl=False):
        """Call ``generate`` with autograd disabled and inside ``model.ema_scope()``.

        All arguments are forwarded unchanged to ``generate``.
        """
        model = self.model

        with torch.no_grad():
            with model.ema_scope():
                self.generate(interval, prompt_i=prompt_i, threshold=threshold, save_pkl=save_pkl)

    def generate(self, interval=None, prompt_i=None, threshold=None, save_pkl=False):
        """Sample one latent for ``args.prompt``, decode it, and optionally save it.

        Args:
            interval: Forwarded to ``sampler.sample``; also used in the output file name.
            prompt_i: Forwarded to ``sampler.sample``.
            threshold: Forwarded to ``sampler.sample``; also used in the output file name.
            save_pkl: Forwarded to ``sampler.sample``; when true the decoded image is
                also written out via ``save_samples``.
        """
        args = self.args
        model = self.model

        prompts = [args.prompt]
        # The unconditional embedding is only needed for classifier-free guidance
        # (scale != 1). Default to None so the sampler call below never hits an
        # unbound name when guidance is disabled.
        uc = None
        if args.scale != 1.0:
            uc = model.get_learned_conditioning([""])
        c = model.get_learned_conditioning(prompts)

        # No init image is encoded (x0/x_T are None below), so the VAE encoder is
        # dropped to free memory. Guarded so repeated calls skip the deletion.
        first_stage = model.first_stage_model
        if getattr(first_stage, "encoder", None) is not None:
            del first_stage.encoder
            first_stage.encoder = None
            torch.cuda.empty_cache()

        init_latent = None
        shape = (args.C, args.H // args.f, args.W // args.f)

        samples, _intermediates = self.sampler.sample(
            S=args.ddim_steps,
            conditioning=c,
            batch_size=1,
            shape=shape,
            verbose=False,
            unconditional_guidance_scale=args.scale,
            unconditional_conditioning=uc,
            eta=args.ddim_eta,
            x_T=None,
            x0=init_latent,
            prompts=prompts[0],
            interval=interval,
            prompt_i=prompt_i,
            threshold=threshold,
            save_pkl=save_pkl,
        )

        samples = model.decode_first_stage(samples)
        if save_pkl:
            # NOTE(review): the prompt text (not ``prompt_i``) is what ends up in the
            # file name via save_samples' ``prompt_i`` parameter.
            self.save_samples(samples, interval, prompts[0], threshold)

    def save_samples(self, samples, interval=None, prompt_i=None, threshold=None):
        """Write decoded samples as watermarked PNG files under ``args.output_path``.

        ``args.output_path`` is treated as a directory and is created if missing.
        Files are named ``{prompt_i}_I{interval}_TH{threshold}.png``. When the batch
        holds more than one image, an ``_{index}`` suffix is added before ``.png``.
        Nothing is written when ``args.output_path`` is None.

        Args:
            samples: 4-D tensor, permuted channel-last below; values are mapped from
                [-1, 1] to [0, 1] and clamped.
            interval: Value embedded in the file name.
            prompt_i: Value embedded in the file name (the caller passes the prompt text).
            threshold: Value embedded in the file name.
        """
        output_dir = self.args.output_path
        if output_dir is None:
            return
        # The output path is used as a directory, so create it itself (previously
        # only its parent was created and saving failed for a fresh directory).
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        base_name = "{}_I{}_TH{}".format(prompt_i, interval, threshold)

        images = torch.clamp((samples + 1) / 2, min=0, max=1)
        images = images.cpu().permute(0, 2, 3, 1).numpy()  # -> (batch, H, W, C)
        # NOTE(review): the safety checker's output is intentionally discarded — the
        # unfiltered images are saved, as before. The call is kept so any side
        # effects it has are preserved; confirm whether it can be removed.
        _checked_image, _ = check_safety(images)

        for i, image in enumerate(images):
            # Keep the historic file name for a single image; disambiguate larger
            # batches so later images no longer overwrite earlier ones.
            suffix = "" if len(images) == 1 else "_{}".format(i)
            opath = os.path.join(output_dir, base_name + suffix + ".png")
            img = Image.fromarray((255.0 * image).astype(np.uint8))
            img = put_watermark(img, self.wm_encoder)
            img.save(opath)
