import torch
import torch.nn as nn
import torchvision
from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
import argparse
import torch.utils.checkpoint as checkpoint
import os, shutil
from PIL import Image
import time
from torch import autocast
from torch.cuda.amp import GradScaler
from transformers import CLIPModel, CLIPProcessor, AutoProcessor, AutoModel

import numpy as np

# Aesthetic Scorer
class MLPDiff(nn.Module):
    """MLP head mapping a 768-d embedding to a single scalar score.

    The submodule name ``layers`` and the position of every layer inside it
    must not change: pretrained weights are loaded into this module by
    ``state_dict`` key (``layers.<index>.weight`` / ``.bias``).
    """

    def __init__(self):
        super().__init__()
        stack = []
        # (in_features, out_features, dropout probability) for each hidden block.
        for fan_in, fan_out, drop in ((768, 1024, 0.2), (1024, 128, 0.2), (128, 64, 0.1)):
            stack += [nn.Linear(fan_in, fan_out), nn.Dropout(drop)]
        # Two final projections with no dropout in between.
        stack += [nn.Linear(64, 16), nn.Linear(16, 1)]
        self.layers = nn.Sequential(*stack)

    def forward(self, embed):
        """Return the score for ``embed`` with shape ``(..., 1)``."""
        scores = self.layers(embed)
        return scores

# Local checkpoint locations. The defaults are machine-specific; set the
# environment variables below to point at copies elsewhere.
MODEL_PATH = os.environ.get('DIFFOPT_CLIP_PATH', '/mnt/workspace/workgroup/tangzhiwei.tzw/clip-vit-large-patch14')
ASSETS_PATH = os.environ.get('DIFFOPT_AESTHETIC_ASSETS_PATH', '/mnt/workspace/workgroup/tangzhiwei.tzw/reward_optimization/reward_opt/assets')

class AestheticScorerDiff(torch.nn.Module):
    """Differentiable aesthetic predictor: CLIP image embedding -> ``MLPDiff`` score.

    Loads the CLIP model from ``MODEL_PATH`` and the MLP weights from
    ``sac+logos+ava1-l14-linearMSE.pth`` under ``ASSETS_PATH``; the module is
    put in eval mode on construction.

    Args:
        dtype: stored as ``self.dtype``; not otherwise used by this class.
    """

    def __init__(self, dtype):
        super().__init__()
        self.clip = CLIPModel.from_pretrained(MODEL_PATH)
        self.mlp = MLPDiff()
        # map_location="cpu": the weights are copied into the (CPU) MLP anyway,
        # and this avoids failing when the file was saved from a GPU that is
        # not present here.
        state_dict = torch.load(os.path.join(ASSETS_PATH, "sac+logos+ava1-l14-linearMSE.pth"), map_location="cpu")
        self.mlp.load_state_dict(state_dict)
        self.dtype = dtype
        self.eval()

    def forward(self, images):
        """Score a batch of CLIP-preprocessed images.

        Args:
            images: ``pixel_values`` for ``CLIPModel.get_image_features``
                (already resized and normalized by the caller).

        Returns:
            One score per image, shape ``(batch,)``.
        """
        embed = self.clip.get_image_features(pixel_values=images)
        # The MLP expects unit-norm embeddings.
        embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
        return self.mlp(embed).squeeze(1)

def aesthetic_loss_fn(device=None,
                     torch_dtype=None):
    """Build a differentiable loss equal to minus the predicted aesthetic score.

    Args:
        device: device the scorer is moved to.
        torch_dtype: dtype the scorer is moved to.

    Returns:
        ``loss_fn(im_pix_un, prompts=None)``. ``im_pix_un`` are images in
        [-1, 1] (assumed NCHW — TODO confirm); ``prompts`` is ignored. Returns
        one negated score per image.
    """
    target_size = 224
    # Transforms are stateless, so build them once instead of on every call.
    resize = torchvision.transforms.Resize(target_size)
    normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                                 std=[0.26862954, 0.26130258, 0.27577711])
    scorer = AestheticScorerDiff(dtype=torch_dtype).to(device, dtype=torch_dtype)
    # The scorer is frozen; gradients should only reach the input images.
    scorer.requires_grad_(False)

    def loss_fn(im_pix_un, prompts=None):
        # [-1, 1] -> [0, 1], then CLIP-style resize + normalization.
        im_pix = ((im_pix_un / 2) + 0.5).clamp(0, 1)
        im_pix = resize(im_pix)
        im_pix = normalize(im_pix).to(im_pix_un.dtype)
        rewards = scorer(im_pix)
        return -1 * rewards

    return loss_fn

def white_loss_fn(device=None,
                     torch_dtype=None):
    """Return a loss that rewards bright images: the negated mean pixel value.

    ``device`` and ``torch_dtype`` are accepted only for signature parity with
    the other loss factories and are ignored; so is ``prompts``.
    """
    def loss_fn(im_pix_un, prompts=None):
        brightness = im_pix_un.mean()
        return brightness.neg()

    return loss_fn


def black_loss_fn(device=None,
                     torch_dtype=None):
    """Return a loss that rewards dark images: the mean pixel value itself.

    ``device``, ``torch_dtype`` and ``prompts`` are accepted for signature
    parity with the other loss factories and are ignored.
    """
    def loss_fn(im_pix_un, prompts=None):
        return torch.mean(im_pix_un)

    return loss_fn

def contrast_loss_fn(device=None,
                     torch_dtype=None):
    """Return a loss that rewards contrast: minus the (unbiased) variance of the
    per-pixel sum over dimension 1 (presumably the channel axis — TODO confirm).

    ``device``, ``torch_dtype`` and ``prompts`` are accepted for signature
    parity with the other loss factories and are ignored.
    """
    def loss_fn(im_pix_un, prompts=None):
        channel_total = torch.sum(im_pix_un, dim=1)
        spread = torch.var(channel_total)
        return spread.neg()

    return loss_fn

# HPS-v2
HPS_V2_PATH = os.environ.get("DIFFOPT_HPS_V2_PATH", "/mnt/workspace/workgroup/tangzhiwei.tzw/HPS_v2_compressed.pt")
def hps_loss_fn(inference_dtype=None, device=None):
    """Build a differentiable loss equal to minus the HPS-v2 image/prompt score.

    Creates the ``ViT-H-14`` open_clip model from the HPS-v2 package, loads the
    HPS-v2 weights from ``HPS_V2_PATH`` (overridable through the
    ``DIFFOPT_HPS_V2_PATH`` environment variable) and moves the model to
    ``device`` / ``inference_dtype`` in eval mode.

    Returns:
        ``loss_fn(im_pix_un, prompts)``: ``im_pix_un`` are images in [-1, 1]
        (assumed NCHW — TODO confirm) and ``prompts`` is whatever the open_clip
        tokenizer accepts (a string or list of strings). Returns
        ``-diag(image_features @ text_features.T)``, one negated score per
        (image i, prompt i) pair.
    """
    from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer

    model_name = "ViT-H-14"
    model, _preprocess_train, _preprocess_val = create_model_and_transforms(
        model_name,
        "/mnt/workspace/workgroup/tangzhiwei.tzw/open_clip_pytorch_model.bin",
        precision=inference_dtype,
        device=device,
        jit=False,
        force_quick_gelu=False,
        force_custom_text=False,
        force_patch_dropout=False,
        force_image_size=None,
        pretrained_image=False,
        image_mean=None,
        image_std=None,
        light_augmentation=True,
        aug_cfg={},
        output_dict=True,
        with_score_predictor=False,
        with_region_predictor=False
    )

    # Overlay the HPS-v2 weights on the base model. Named ``state`` (not
    # ``checkpoint``) so the module-level torch.utils.checkpoint alias is not
    # shadowed. torch.load unpickles: only point this at trusted files.
    state = torch.load(HPS_V2_PATH, map_location=device)
    model.load_state_dict(state['state_dict'])
    tokenizer = get_tokenizer(model_name)
    model = model.to(device, dtype=inference_dtype)
    model.eval()

    target_size = 224
    # Transforms are stateless, so build them once instead of on every call.
    resize = torchvision.transforms.Resize(target_size)
    normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                                 std=[0.26862954, 0.26130258, 0.27577711])

    def loss_fn(im_pix_un, prompts):
        # [-1, 1] -> [0, 1], then CLIP-style resize + normalization.
        im_pix = ((im_pix_un / 2) + 0.5).clamp(0, 1)
        x_var = resize(im_pix)
        x_var = normalize(x_var).to(im_pix.dtype)
        caption = tokenizer(prompts).to(device)
        outputs = model(x_var, caption)
        image_features, text_features = outputs["image_features"], outputs["text_features"]
        logits = image_features @ text_features.T
        # Pair image i with prompt i.
        scores = torch.diagonal(logits)
        return -scores

    return loss_fn

# pickscore
PickScore_PATH = os.environ.get("DIFFOPT_PICKSCORE_PATH", "/mnt/workspace/workgroup/tangzhiwei.tzw/pickscore")
def pick_loss_fn(inference_dtype=None, device=None):
    """Build a differentiable loss equal to minus the PickScore of each image.

    Loads the model at ``PickScore_PATH`` (overridable through the
    ``DIFFOPT_PICKSCORE_PATH`` environment variable) with ``AutoModel`` and
    tokenizes prompts with open_clip's ``ViT-H-14`` tokenizer.

    Returns:
        ``loss_fn(im_pix_un, prompts)``: ``im_pix_un`` are images in [-1, 1]
        (assumed NCHW — TODO confirm) and ``prompts`` is a string or list of
        strings. Returns ``-logit_scale.exp() * cos(text_i, image_i)``, one
        value per (image i, prompt i) pair. Previously only the [0][0] entry
        was returned, so every image after the first was ignored.
    """
    from open_clip import get_tokenizer

    model_name = "ViT-H-14"
    model = AutoModel.from_pretrained(PickScore_PATH)

    tokenizer = get_tokenizer(model_name)
    model = model.to(device, dtype=inference_dtype)
    model.eval()

    target_size = 224
    # Transforms are stateless, so build them once instead of on every call.
    resize = torchvision.transforms.Resize(target_size)
    normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                                 std=[0.26862954, 0.26130258, 0.27577711])

    def loss_fn(im_pix_un, prompts):
        # [-1, 1] -> [0, 1], then CLIP-style resize + normalization.
        im_pix = ((im_pix_un / 2) + 0.5).clamp(0, 1)
        x_var = resize(im_pix)
        x_var = normalize(x_var).to(im_pix.dtype)
        caption = tokenizer(prompts).to(device)
        image_embs = model.get_image_features(x_var)
        image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)

        text_embs = model.get_text_features(caption)
        text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
        # Pair prompt i with image i (same convention as hps_loss_fn).
        scores = model.logit_scale.exp() * torch.diagonal(text_embs @ image_embs.T)
        return -scores

    return loss_fn


# CLIP score evaluation

def clip_score(inference_dtype=None, device=None):
    """Build a (non-differentiable) CLIP image/text similarity scorer for evaluation.

    Loads the CLIP model and processor from ``MODEL_PATH`` and moves the model
    to ``device`` / ``inference_dtype``.

    Returns:
        ``score_fn(image, prompt)``, which takes whatever the CLIP processor
        accepts as ``images`` (e.g. a PIL image) and a single prompt string.
        It returns ``logits_per_image[0][0]`` as a NumPy scalar.
    """
    model = CLIPModel.from_pretrained(MODEL_PATH)
    processor = CLIPProcessor.from_pretrained(MODEL_PATH)

    model = model.to(device=device, dtype=inference_dtype)
    model.eval()

    @torch.no_grad()
    def score_fn(image, prompt):
        inputs = processor(text=[prompt], images=image, return_tensors="pt", padding=True)
        inputs = {key: value.to(device) for key, value in inputs.items()}
        # The processor emits float32 pixels; match the model's dtype so that
        # fp16/bf16 models do not fail with a dtype mismatch.
        inputs["pixel_values"] = inputs["pixel_values"].to(dtype=model.dtype)

        outputs = model(**inputs)
        # .float() first: NumPy has no bfloat16 (no-op for float32 models).
        return outputs.logits_per_image.float().cpu().numpy()[0][0]

    return score_fn


# sampling algorithm
class SequentialDDIM:
    """DDIM sampler driven one UNet call at a time, with externally supplied noise.

    The starting latent is ``noise_vectors[-1]``, and the stochastic term of the
    update with index ``t`` is ``nl[t] * noise_vectors[t]``. All randomness
    therefore comes from the tensor passed to ``initialize``, so a loss on the
    final sample can be back-propagated into it. Classifier-free guidance is
    applied in ``step``.

    Step indices count down: the first ``step`` call uses index
    ``num_steps - 1`` (the largest timestep) and the last uses index 0.

    Typical loop: call ``initialize(noise_vectors)``. Then, while not
    ``is_finished()``, call the UNet on ``prepare_model_kwargs(...)`` and pass
    its output to ``step``. Finally, read ``get_last_sample()``.
    """

    def __init__(self, timesteps = 100, scheduler = None, eta = 0.0, cfg_scale = 4.0, device = "cuda", opt_timesteps = 50):
        """Precompute the per-step DDIM update coefficients.

        Args:
            timesteps: number of sampling steps. Presumably must equal
                ``len(scheduler.timesteps)``, since step indices are derived
                from it — TODO confirm.
            scheduler: required despite the ``None`` default. Must provide
                ``timesteps`` (assumed descending), ``alphas_cumprod`` and
                ``scale_model_input``.
            eta: multiplies the DDIM noise std; 0 removes the noise term.
            cfg_scale: classifier-free guidance weight.
            device: device on which the timestep tensor for the UNet is created.
            opt_timesteps: updates with index ``t <= opt_timesteps`` are recorded
                by autograd; updates with larger ``t`` run under
                ``torch.no_grad()``.
        """
        self.eta = eta 
        self.timesteps = timesteps
        self.num_steps = timesteps
        self.scheduler = scheduler
        self.device = device
        self.cfg_scale = cfg_scale
        self.opt_timesteps = opt_timesteps 

        # compute some coefficients in advance
        # Pair each timestep with the one that follows it during sampling (0
        # after the last one), then reverse both lists so that index 0 refers
        # to the smallest timestep, i.e. the final update.
        scheduler_timesteps = self.scheduler.timesteps.tolist()
        scheduler_prev_timesteps = scheduler_timesteps[1:]
        scheduler_prev_timesteps.append(0)
        self.scheduler_timesteps = scheduler_timesteps[::-1]
        scheduler_prev_timesteps = scheduler_prev_timesteps[::-1]
        # NOTE: despite their names, these lists hold 1 - alphas_cumprod[t]
        # for the current and the following timestep respectively.
        alphas_cumprod = [1 - self.scheduler.alphas_cumprod[t] for t in self.scheduler_timesteps]
        alphas_cumprod_prev = [1 - self.scheduler.alphas_cumprod[t] for t in scheduler_prev_timesteps]

        # Below, a_t = alphas_cumprod at the current timestep and
        # a_prev = alphas_cumprod at the following one.
        # now_coeff = 1 - a_t, next_coeff = 1 - a_prev (clamped at 0).
        now_coeff = torch.tensor(alphas_cumprod)
        next_coeff = torch.tensor(alphas_cumprod_prev)
        now_coeff = torch.clamp(now_coeff, min = 0)
        next_coeff = torch.clamp(next_coeff, min = 0)
        # m_now_coeff = a_t, m_next_coeff = a_prev (clamped at 0).
        m_now_coeff = torch.clamp(1 - now_coeff, min = 0)
        m_next_coeff = torch.clamp(1 - next_coeff, min = 0)
        # DDIM sigma at eta = 1: sqrt((1 - a_prev) / (1 - a_t)) * sqrt(1 - a_t / a_prev).
        self.noise_thr = torch.sqrt(next_coeff / now_coeff) * torch.sqrt(1 - (1 - now_coeff) / (1 - next_coeff))
        self.nl = self.noise_thr * self.eta
        # No noise is injected by the final (index 0) update.
        self.nl[0] = 0.
        m_nl_next_coeff = torch.clamp(next_coeff - self.nl**2, min = 0)
        # With sigma = nl, the update x_prev = coeff_x * x_t + coeff_d * eps + nl * z
        # expands to sqrt(a_prev) * x0_pred + sqrt(1 - a_prev - sigma^2) * eps + sigma * z,
        # where x0_pred = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t).
        self.coeff_x = torch.sqrt(m_next_coeff) / torch.sqrt(m_now_coeff)
        self.coeff_d = torch.sqrt(m_nl_next_coeff) - torch.sqrt(now_coeff) * self.coeff_x

    def is_finished(self):
        """Return True once ``timesteps`` updates have been applied.

        Only valid after ``initialize``, which creates ``_is_finished``.
        """
        return self._is_finished

    def get_last_sample(self):
        """Return the most recent latent (samples are prepended to ``_samples``)."""
        return self._samples[0]

    def prepare_model_kwargs(self, prompt_embeds = None):
        """Build the UNet inputs for the next update.

        The current latent is duplicated into a batch of two: ``step`` treats
        the first half of the output as unconditional and the second as
        conditional. ``prompt_embeds`` is therefore presumably the
        unconditional embedding stacked before the conditional one — TODO
        confirm against the caller.

        Returns:
            dict with ``sample``, ``timestep`` and ``encoder_hidden_states``.
        """

        # Number of updates already applied determines the current index.
        t_ind = self.num_steps - len(self._samples)
        t = self.scheduler_timesteps[t_ind]
   
        model_kwargs = {
            "sample": torch.stack([self._samples[0], self._samples[0]]),
            "timestep": torch.tensor([t, t], device = self.device),
            "encoder_hidden_states": prompt_embeds
        }

        model_kwargs["sample"] = self.scheduler.scale_model_input(model_kwargs["sample"],t)

        return model_kwargs


    def step(self, model_output):
        """Apply one guided DDIM update using ``model_output``.

        ``model_output[0]`` is split in half along dim 0 into (unconditional,
        conditional) predictions, presumably the UNet output for the batch
        built by ``prepare_model_kwargs`` — TODO confirm. Only the first
        element of the guided prediction is used (single-latent batch).
        """
        model_output_uncond, model_output_text = model_output[0].chunk(2)
        # Classifier-free guidance.
        direction = model_output_uncond + self.cfg_scale * (model_output_text - model_output_uncond)
        direction = direction[0]

        t = self.num_steps - len(self._samples)

        # Only the last opt_timesteps + 1 updates (t counts down to 0) keep
        # autograd history; earlier ones are computed without a graph.
        if t <= self.opt_timesteps:
            now_sample = self.coeff_x[t] * self._samples[0] + self.coeff_d[t] * direction  + self.nl[t] * self.noise_vectors[t]
        else:
            with torch.no_grad():
                now_sample = self.coeff_x[t] * self._samples[0] + self.coeff_d[t] * direction  + self.nl[t] * self.noise_vectors[t]

        self._samples.insert(0, now_sample)
        
        # The initial latent plus `timesteps` updates -> timesteps + 1 entries.
        if len(self._samples) > self.timesteps:
            self._is_finished = True

    def initialize(self, noise_vectors):
        """Reset the sampler to start from ``noise_vectors[-1]``.

        Args:
            noise_vectors: indexable by 0..num_steps-1 (per-update noise) and
                -1 (initial latent); the caller in this file allocates
                ``num_steps + 1`` entries.
        """
        self._is_finished = False

        self.noise_vectors = noise_vectors

        # The initial latent keeps its gradient only when every update is
        # optimized. NOTE(review): if opt_timesteps > num_steps, all updates
        # record grad but the initial latent is still detached.
        if self.num_steps == self.opt_timesteps:
            self._samples = [self.noise_vectors[-1]]
        else:
            self._samples = [self.noise_vectors[-1].detach()]

def sequential_sampling(pipeline, unet, sampler, prompt_embeds, noise_vectors):
    """Run ``sampler`` to completion, evaluating ``unet`` under activation checkpointing.

    Args:
        pipeline: unused; kept so existing call sites keep working.
        unet: callable ``unet(sample, timestep, encoder_hidden_states)`` whose
            output is passed to ``sampler.step``.
        sampler: object implementing ``initialize`` / ``is_finished`` /
            ``prepare_model_kwargs`` / ``step`` / ``get_last_sample``
            (e.g. ``SequentialDDIM``).
        prompt_embeds: forwarded to ``sampler.prepare_model_kwargs``.
        noise_vectors: forwarded to ``sampler.initialize``.

    Returns:
        ``sampler.get_last_sample()`` once the sampler reports it is finished.
    """
    sampler.initialize(noise_vectors)

    while not sampler.is_finished():
        model_kwargs = sampler.prepare_model_kwargs(prompt_embeds = prompt_embeds)
        # Checkpointing recomputes the UNet forward during backward instead of
        # storing its activations for every sampling step.
        model_output = checkpoint.checkpoint(
            unet,
            model_kwargs["sample"],
            model_kwargs["timestep"],
            model_kwargs["encoder_hidden_states"],
            use_reentrant=False,
        )
        sampler.step(model_output)

    return sampler.get_last_sample()


def decode_latent(decoder, latent, scaling_factor=0.18215):
    """Decode one unbatched latent into an image tensor.

    Args:
        decoder: VAE-like object whose ``decode(z)`` returns an object with a
            ``.sample`` attribute (e.g. ``pipeline.vae``).
        latent: a single latent without a batch axis; a batch axis of size 1
            is added before decoding.
        scaling_factor: factor the latents were scaled by. Defaults to the
            value previously hard-coded here (0.18215). For other VAEs, pass
            their own factor (e.g. ``decoder.config.scaling_factor`` for
            diffusers VAEs — confirm for your model).

    Returns:
        ``decoder.decode(latent[None] / scaling_factor).sample``.
    """
    img = decoder.decode(latent.unsqueeze(0) / scaling_factor).sample
    return img

def to_img(img):
    """Convert the first image of an NCHW batch to an HWC ``uint8`` NumPy array.

    Pixels are mapped with ``127.5 * x + 128``, clamped to [0, 255] and
    truncated to ``uint8``. For ``x`` in [-1, 1] this is round-half-up of
    ``(x + 1) * 127.5``.
    """
    pixels = img.cpu().float().mul(127.5).add(128.0).clamp(0, 255)
    channels_last = pixels.permute(0, 2, 3, 1).to(dtype=torch.uint8)
    batch = channels_last.numpy()
    return batch[0]

def _parse_args():
    """Parse and validate the command-line options for ``main``."""
    parser = argparse.ArgumentParser(description='Diffusion Optimization with Differentiable Objective')
    parser.add_argument('--model', type=str, default="/mnt/workspace/workgroup/tangzhiwei.tzw/sdv1-5-full-diffuser", help='path to the model')
    parser.add_argument('--prompt', type=str, default="black duck", help='prompt for the optimization')
    parser.add_argument('--num_steps', type=int, default=50, help='number of DDIM sampling steps')
    parser.add_argument('--eta', type=float, default=1.0, help='noise scale')
    parser.add_argument('--guidance_scale', type=float, default=5.0, help='guidance scale')
    parser.add_argument('--device', type=str, default="cuda", help='device for optimization')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--opt_steps', type=int, default=100, help='number of optimization steps')
    parser.add_argument('--no_reg', action='store_true', help='disable the noise regularization term')
    parser.add_argument('--opt_time', type=int, default=50, help='sampling steps with index <= opt_time receive gradients (index 0 is the final step)')
    parser.add_argument('--log_interval', type=int, default=1, help='log interval')
    parser.add_argument('--objective', type=str, default="white", help='objective for optimization', choices = ["aesthetic", "hps", "pick", "white", "black", "contrast"])
    parser.add_argument('--precision', choices = ["bf16", "fp16", "fp32"], default="fp16", help='precision for optimization')
    parser.add_argument('--gamma', type=float, default=1, help='weight of the objective loss when regularization is enabled (the regularizer weight is 0.1)')
    parser.add_argument('--subsample', type=int, default=1, help='subsample factor: the regularizer splits the noise into blocks of 4**subsample entries')
    parser.add_argument('--lr', type=float, default=0.01, help='stepsize for optimization')
    parser.add_argument('--output', type=str, default="diff_opt_logs", help='output path')
    parser.add_argument('--prefix', type=str, default="")
    args = parser.parse_args()
    # Both are used unguarded below (final image save / modulo).
    if args.opt_steps < 1:
        parser.error("--opt_steps must be at least 1")
    if args.log_interval < 1:
        parser.error("--log_interval must be at least 1")
    return args


def _resolve_prompt(prompt):
    """Substitute the "random" / "randc" placeholders in ``prompt``.

    "random" is replaced by a line drawn from the animal/activity list file and
    "randc" by a colour name. NOTE(review): this uses NumPy's global RNG, which
    is never seeded here, so the chosen words can differ between runs with the
    same ``--seed``.
    """
    if "random" in prompt:
        with open("/mnt/workspace/workgroup/tangzhiwei.tzw/spec_samp/simple_animals_activities.txt", encoding="utf-8") as f:
            animals = [line.strip("\n") for line in f]
        prompt = prompt.replace("random", np.random.choice(animals))

    if "randc" in prompt:
        colors = ["red", "green", "blue", "yellow", "purple", "pink", "orange", "brown", "gray", "black", "white", "gold", "silver"]
        prompt = prompt.replace("randc", np.random.choice(colors))
    return prompt


def _noise_log_probs(noise_flat, subsample_num, subsample_dim):
    """Block statistics of a flattened noise vector compared with a standard normal.

    ``noise_flat`` is viewed as ``subsample_num`` rows of ``subsample_dim``
    entries.

    Returns:
        ``(mean_M, cov_M, mean_log_prob, cov_log_prob, mean_prob, cov_prob)``:

        - ``mean_M`` is the norm of the row mean.
        - ``cov_M`` is the spectral norm of (uncentered second-moment matrix - I).
        - Each ``*_log_prob`` is a tail-bound-style log-probability, clamped to
          at most ``-log 2`` so that ``*_prob = 2 * exp(*_log_prob)`` is at
          most 1.
    """
    rows = noise_flat.view(subsample_num, subsample_dim)

    mean = rows.mean(dim = 0)
    # Scaling rows by 1/sqrt(n) makes rows.T @ rows = (1/n) * sum_i x_i x_i^T.
    rows = rows / np.sqrt(subsample_num)
    cov = rows.T @ rows

    mean_M = torch.norm(mean)
    cov_M = torch.linalg.matrix_norm(cov - torch.eye(subsample_dim, device = cov.device), ord = 2)

    mean_log_prob = - (subsample_num * mean_M ** 2) / 2 / subsample_dim
    mean_log_prob = torch.clamp(mean_log_prob, max = - np.log(2))
    mean_prob = 2 * torch.exp(mean_log_prob)

    cov_diff = torch.clamp(torch.sqrt(1+cov_M) - 1 - np.sqrt(subsample_dim/subsample_num), min = 0)
    cov_log_prob = - subsample_num * (cov_diff ** 2) / 2
    cov_log_prob = torch.clamp(cov_log_prob, max = - np.log(2))
    cov_prob = 2 * torch.exp(cov_log_prob)

    return mean_M, cov_M, mean_log_prob, cov_log_prob, mean_prob, cov_prob


def main():
    """Optimize the DDIM noise of a Stable Diffusion sampler against a differentiable objective.

    Each optimization step does the following:

    - Samples an image from ``noise_vectors`` (per-step noise plus the initial
      latent at index -1) with ``SequentialDDIM``.
    - Decodes the image and evaluates the chosen objective.
    - Unless ``--no_reg`` is given, adds a regularizer. It penalizes block means
      of the noise that drift from 0 and block second moments that drift from
      the identity, both for the natural block order and for 100 random
      shuffles.

    Images (with scores encoded in the file names) are written to a run
    directory under ``--output``. An existing directory with the same name is
    deleted first.
    """
    args = _parse_args()

    # load model
    pipeline = StableDiffusionPipeline.from_pretrained(args.model).to(device = args.device)
    # freeze parameters of models to save more memory
    pipeline.vae.requires_grad_(False)
    pipeline.text_encoder.requires_grad_(False)
    pipeline.unet.requires_grad_(False)
    # disable safety checker
    pipeline.safety_checker = None
    pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
    # set the number of steps
    pipeline.scheduler.set_timesteps(args.num_steps)
    unet = pipeline.unet

    # Every evaluator is loaded regardless of the objective, because all of
    # them are reported at each logging step.
    cs_evaluator = clip_score(inference_dtype = torch.float32, device = args.device)
    aesthetic_evaluator = aesthetic_loss_fn(torch_dtype = torch.float32, device = args.device)
    hps_evaluator = hps_loss_fn(inference_dtype = torch.float32, device = args.device)
    pick_evaluator = pick_loss_fn(inference_dtype = torch.float32, device = args.device)

    if args.objective == "aesthetic":
        loss_fn = aesthetic_evaluator
    elif args.objective == "hps":
        loss_fn = hps_evaluator
    elif args.objective == "pick":
        loss_fn = pick_evaluator
    elif args.objective == "white":
        loss_fn = white_loss_fn(torch_dtype = torch.float32, device = args.device)
    elif args.objective == "black":
        loss_fn = black_loss_fn(torch_dtype = torch.float32, device = args.device)
    elif args.objective == "contrast":
        loss_fn = contrast_loss_fn(torch_dtype = torch.float32, device = args.device)
    else:
        raise ValueError("Invalid objective")

    args.seed = args.seed + 100

    torch.manual_seed(args.seed)
    # Index -1 is the initial latent; indices 0..num_steps-1 are the per-step
    # noises consumed by SequentialDDIM.
    noise_vectors = torch.randn(args.num_steps + 1, 4, 64, 64, device = args.device)
    noise_vectors.requires_grad_(True)
    optimizer = torch.optim.AdamW([{"params": noise_vectors, "lr": args.lr}])

    args.prompt = _resolve_prompt(args.prompt)

    prompt_embeds = pipeline._encode_prompt(
                        args.prompt,
                        args.device,
                        1,
                        True,
                    )

    output_path = os.path.join(args.output, args.prefix + f"obj:{args.objective},lr:{args.lr},gamma:{args.gamma},no_reg:{args.no_reg},fct:{args.subsample},pt:{args.prompt},st:{args.num_steps},et:{args.eta},gd:{args.guidance_scale},sd:{args.seed},ot:{args.opt_time},prec:{args.precision}")

    if os.path.exists(output_path):
        shutil.rmtree(output_path)
    os.makedirs(output_path)

    best_reward = -1e9

    # start optimization
    use_amp = args.precision != "fp32"
    grad_scaler = GradScaler(enabled=use_amp, init_scale = 8192)
    amp_dtype = torch.bfloat16 if args.precision == "bf16" else torch.float16

    # Noise entries seen by the regularizer. With eta > 0 it sees the first
    # opt_time + 1 noise tensors. With eta == 0 the per-step noise is
    # multiplied by zero, so it sees only the initial latent.
    if args.eta > 0:
        dim = noise_vectors[:(args.opt_time + 1)].numel()
    else:
        dim = noise_vectors[-1].numel()

    subsample_dim = round(4 ** args.subsample)
    # Fail early instead of inside the first optimization step's .view().
    if subsample_dim < 1 or dim % subsample_dim != 0:
        raise ValueError(f"--subsample={args.subsample} gives block size {subsample_dim}, which does not divide the {dim} regularized noise entries")
    subsample_num = dim // subsample_dim

    print("="*20)
    print(subsample_dim, subsample_num)
    print("="*20)

    shuffled_times = 100

    for i in range(args.opt_steps):
        optimizer.zero_grad()

        start_time = time.time()
        # autocast wraps only the forward pass and loss computation.
        with autocast(device_type="cuda", dtype=amp_dtype, enabled=use_amp):
            ddim_sampler = SequentialDDIM(timesteps = args.num_steps,
                                          scheduler = pipeline.scheduler,
                                          eta = args.eta,
                                          cfg_scale = args.guidance_scale,
                                          device = args.device,
                                          opt_timesteps = args.opt_time)

            sample = sequential_sampling(pipeline, unet, ddim_sampler, prompt_embeds = prompt_embeds, noise_vectors = noise_vectors)
            sample = decode_latent(pipeline.vae, sample)

            losses = loss_fn(sample, [args.prompt] * sample.shape[0])
            loss = losses.mean()

            reward = -loss.item()
            best_reward = max(best_reward, reward)

            # sequential subsampling
            if args.eta > 0:
                noise_vectors_flat = noise_vectors[:(args.opt_time + 1)].flatten()
            else:
                noise_vectors_flat = noise_vectors[-1].flatten()

            (seq_mean_M, seq_cov_M, seq_mean_log_prob, seq_cov_log_prob,
             seq_mean_prob, seq_cov_prob) = _noise_log_probs(noise_vectors_flat, subsample_num, subsample_dim)

            # Same statistics after randomly permuting the noise entries.
            shuffled_mean_prob_list = []
            shuffled_cov_prob_list = []
            shuffled_mean_log_prob_list = []
            shuffled_cov_log_prob_list = []
            shuffled_mean_M_list = []
            shuffled_cov_M_list = []

            for _ in range(shuffled_times):
                (sh_mean_M, sh_cov_M, sh_mean_log_prob, sh_cov_log_prob,
                 sh_mean_prob, sh_cov_prob) = _noise_log_probs(noise_vectors_flat[torch.randperm(dim)], subsample_num, subsample_dim)

                shuffled_mean_prob_list.append(sh_mean_prob.item())
                shuffled_cov_prob_list.append(sh_cov_prob.item())

                shuffled_mean_log_prob_list.append(sh_mean_log_prob)
                shuffled_cov_log_prob_list.append(sh_cov_log_prob)

                shuffled_mean_M_list.append(sh_mean_M.item())
                shuffled_cov_M_list.append(sh_cov_M.item())

            print("="*10, i, "="*10)
            print(seq_mean_M.item(), seq_cov_M.item())
            print(max(shuffled_mean_M_list), max(shuffled_cov_M_list))
            print(seq_mean_prob.item(), seq_cov_prob.item())
            print(min(shuffled_mean_prob_list), min(shuffled_cov_prob_list))

            print("reward", reward, best_reward)
            print("scaler",  grad_scaler.get_scale())

            reg_loss = - (seq_mean_log_prob + seq_cov_log_prob + (sum(shuffled_mean_log_prob_list) + sum(shuffled_cov_log_prob_list)) / shuffled_times)

            if not args.no_reg:
                loss = args.gamma * loss + 0.1 * reg_loss

        # Backward and the optimizer update run outside autocast, as the
        # PyTorch AMP docs recommend. Backward ops still use the dtypes that
        # autocast chose for the corresponding forward ops.
        grad_scaler.scale(loss).backward()
        grad_scaler.unscale_(optimizer)

        torch.nn.utils.clip_grad_norm_([noise_vectors], 1.0)

        grad_scaler.step(optimizer)
        grad_scaler.update()

        end_time = time.time()
        print("time", end_time - start_time)

        total_prob = [torch.min(seq_mean_prob, seq_cov_prob).item()]
        total_prob.extend(shuffled_mean_prob_list)
        total_prob.extend(shuffled_cov_prob_list)

        min_prob = min(total_prob)

        if i % args.log_interval == 0:
            img = to_img(sample)
            pil_img = Image.fromarray(img)
            with torch.no_grad():
                cs_score = cs_evaluator(pil_img, args.prompt)
                aesthetic_score = - aesthetic_evaluator(sample.to(dtype=torch.float32), args.prompt)[0].item()
                hps_score = - hps_evaluator(sample.to(dtype=torch.float32), args.prompt)[0].item()
                pick_score = - pick_evaluator(sample.to(dtype=torch.float32), args.prompt).item()

            # Auxiliary score on the uint8 image: the mean of the 1000 most
            # extreme pixel values in the direction opposite to the objective.
            if args.objective == "black":
                img_flat = img.flatten()
                img_flat.sort()
                aux_score = img_flat[-1000:].mean()
                aux_info = f"_{aux_score}"
            elif args.objective == "white":
                img_flat = img.flatten()
                img_flat.sort()
                aux_score = img_flat[:1000].mean()
                aux_info = f"_{aux_score}"
            else:
                aux_info = ""

            print("cs_score", cs_score, "aesthetic_score", aesthetic_score, "hps_score", hps_score, "pick_score", pick_score)
            pil_img.save(os.path.join(output_path, f"{i}_{reward}_{min_prob}_{cs_score}_{aesthetic_score}_{hps_score}_{pick_score}{aux_info}.png"))
            print("saved image")

        print("="*20)

    # The last iteration's sample (opt_steps >= 1 is validated in _parse_args).
    img = to_img(sample)
    img = Image.fromarray(img)
    img.save(os.path.join(output_path, f"{i}_{reward}.png"))
    print("saved image")

if __name__ == "__main__":
    # Run the optimization only when executed as a script, not on import.
    main()