import argparse, os, sys, glob
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm
import numpy as np
import torch
from main import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
import copy
import random
# from torch.cuda.amp import autocast, GradScaler

def image_preprocess(image_path, size=(256, 384)):
    """Load an image file as a float tensor for the inpainting model.

    Args:
        image_path: path of the image file.
        size: target ``(width, height)`` passed to ``PIL.Image.resize``
            (defaults to the previously hard-coded 256x384).

    Returns:
        A ``float32`` tensor of shape ``(1, 3, height, width)`` with values
        in ``[0, 1]``.
    """
    # Context manager so the underlying file handle is released deterministically.
    with Image.open(image_path) as img:
        # Resize first, then convert, in the same order as before so outputs are unchanged.
        resized = img.resize(size, Image.Resampling.LANCZOS)
    image = np.array(resized.convert("RGB")).astype(np.float32) / 255.0
    # HWC -> NCHW with a leading batch dimension of 1.
    image = image[None].transpose(0, 3, 1, 2)
    return torch.from_numpy(image)


def mask_preprocess(mask_path, size=(256, 384)):
    """Load a mask file as a binary float tensor.

    Args:
        mask_path: path of the mask image.
        size: target ``(width, height)`` passed to ``PIL.Image.resize``
            (defaults to the previously hard-coded 256x384).

    Returns:
        A ``float32`` tensor of shape ``(1, 1, height, width)`` whose values
        are exactly 0 or 1 (grayscale intensity >= 0.5 becomes 1).
    """
    # Context manager so the underlying file handle is released deterministically.
    with Image.open(mask_path) as img:
        # Resize first, then convert, in the same order as before so outputs are unchanged.
        resized = img.resize(size, Image.Resampling.LANCZOS)
    mask = np.array(resized.convert("L")).astype(np.float32) / 255.0
    # Binarise in one vectorised step: >= 0.5 -> 1, everything else -> 0.
    mask = (mask >= 0.5).astype(np.float32)
    return torch.from_numpy(mask[None, None])


def make_batch(image, mask, device=None):
    """Assemble the conditioning batch for the inpainting model.

    Args:
        image: image tensor, expected in ``[0, 1]``.
        mask: mask tensor broadcastable against ``image``; where it is 1 the
            image is zeroed in ``masked_image``.
        device: optional device; when given, every tensor in the batch is
            moved there (this argument used to be accepted but ignored).
            ``None`` leaves tensors where they are.

    Returns:
        dict with keys ``"image"``, ``"mask"`` and ``"masked_image"``, each
        rescaled from ``[0, 1]`` to ``[-1, 1]``.
    """
    masked_image = (1 - mask) * image
    batch = {}
    for key, value in (("image", image), ("mask", mask), ("masked_image", masked_image)):
        if device is not None:
            # .to() on the tensor's current device returns it unchanged, so autograd is preserved.
            value = value.to(device=device)
        batch[key] = value * 2.0 - 1.0
    return batch


def find_color_with_max_difference(img):
    """Pick the pure colour that is farthest from the image's average colour.

    Candidates are pure blue, green and red, checked in that order; on a tie
    the first one wins. Distance is the sum of squared per-channel
    differences, with channels scaled to [0, 1].

    Args:
        img: a PIL image of any mode (converted to RGB first), or an array of
            shape ``(H, W)``, ``(H, W, 3)`` or ``(H, W, 4)``. Alpha is ignored
            and grayscale is replicated across the three channels.

    Returns:
        The chosen colour as an ``(R, G, B)`` tuple of ints.
    """
    if hasattr(img, "convert"):
        # PIL image: normalise palette / grayscale / RGBA modes to plain RGB.
        img = img.convert("RGB")
    pixels = np.asarray(img)
    if pixels.ndim == 2:
        # Grayscale array: replicate the single channel, like PIL's L -> RGB.
        pixels = np.stack([pixels] * 3, axis=-1)
    # Drop any alpha channel; only RGB takes part in the comparison.
    pixels = pixels[..., :3]

    # Truncate to ints as before, so the choice is unchanged for RGB input.
    average_color = np.mean(pixels, axis=(0, 1)).astype(int)

    candidates = [(0, 0, 255), (0, 255, 0), (255, 0, 0)]
    differences = [np.sum(((average_color - np.array(color)) / 255) ** 2) for color in candidates]
    return candidates[int(np.argmax(differences))]


from torchvision import transforms
# Single-step torchvision pipeline (just ``ToTensor``); the attack loop uses it
# to turn the PIL masks returned by ``shift_mask`` into tensors.
preprocess = transforms.Compose(transforms=[transforms.ToTensor()])


def generate_target(image, size):
    """Create a solid-colour target image that contrasts with ``image``.

    Args:
        image: path of the source image; its average colour selects the
            target colour via ``find_color_with_max_difference``.
        size: ``(width, height)`` of the returned target.

    Returns:
        A ``float32`` tensor of shape ``(1, 3, height, width)`` with values
        in ``[-1, 1]``.
    """
    with Image.open(image) as img:
        # Convert to RGB so palette, grayscale and RGBA files all give a
        # 3-channel average colour (they previously crashed the colour picker).
        max_diff_color = find_color_with_max_difference(img.convert("RGB"))
    # A solid-colour image is identical at any resolution, so build it at
    # ``size`` directly instead of allocating full size and resizing.
    target = np.array(Image.new("RGB", size, max_diff_color)).astype(np.float32) / 255.0
    # HWC -> NCHW with a leading batch dimension of 1.
    target = target[None].transpose(0, 3, 1, 2)
    return torch.from_numpy(target) * 2.0 - 1.0


def shift_mask(image_array, direction):
    """Return the binary mask translated by a random offset.

    The offsets ``dy``/``dx`` are drawn (in that order) with
    ``np.random.randint(1, max(2, span // 2))``, where ``span`` is the extent
    of the mask's bounding box along that axis. Pure vertical or horizontal
    moves use twice the offset; diagonal moves use it once per axis. Pixels
    pushed past the border are clamped onto the edge row/column.

    Args:
        image_array: 2-D array ``(height, width)``; entries equal to 1 belong
            to the mask.
        direction: one of ``'up'``, ``'down'``, ``'left'``, ``'right'``,
            ``'upleft'``, ``'upright'``, ``'downleft'``, ``'downright'``.

    Returns:
        A PIL ``'L'`` image of the same size, with shifted mask pixels set to
        255 and all others 0. An empty mask yields an all-zero image.

    Raises:
        ValueError: if ``direction`` is not one of the names above.
    """
    # (k_y, k_x) multipliers applied to the random offsets (dy, dx).
    multipliers = {
        'up': (-2, 0),
        'upleft': (-1, -1),
        'upright': (-1, 1),
        'down': (2, 0),
        'downleft': (1, -1),
        'downright': (1, 1),
        'left': (0, -2),
        'right': (0, 2),
    }
    height, width = image_array.shape
    ys, xs = np.where(image_array == 1)

    if ys.size == 0:
        # Nothing to move; previously np.min raised on the empty selection.
        if direction not in multipliers:
            raise ValueError("Invalid direction")
        return Image.new('L', (width, height), color=0)

    # randint's upper bound is exclusive, and it raises when low >= high. Keep
    # it >= 2 so masks narrower than 4 px still get a shift of 1.
    dy = np.random.randint(1, max(2, (ys.max() - ys.min()) // 2))
    dx = np.random.randint(1, max(2, (xs.max() - xs.min()) // 2))

    try:
        k_y, k_x = multipliers[direction]
    except KeyError:
        raise ValueError("Invalid direction") from None

    new_ys = np.clip(ys + k_y * dy, 0, height - 1)
    new_xs = np.clip(xs + k_x * dx, 0, width - 1)

    # Vectorised replacement for the former per-pixel putpixel loop.
    shifted = np.zeros((height, width), dtype=np.uint8)
    shifted[new_ys, new_xs] = 255
    return Image.fromarray(shifted)


def mse_masked(image1, image2, mask):
    """Mean squared error between two images, restricted to a mask.

    Elements where ``mask > 0.5`` are compared; every other element is
    replaced by NaN and skipped by ``torch.nanmean``. ``mask`` is broadcast
    against the images, so a single-channel mask selects all channels.

    Args:
        image1, image2: tensors of broadcast-compatible shape.
        mask: tensor broadcastable to the images.

    Returns:
        A 0-dim tensor with the masked MSE. It is NaN if the mask selects
        nothing. Gradients outside the mask are zero.
    """
    selected = mask > 0.5
    # Build the NaN fill on the inputs' device and dtype. The old CPU float32
    # tensor could not be mixed with CUDA tensors in torch.where on older PyTorch.
    nan_fill = torch.tensor(float('nan'), dtype=image1.dtype, device=image1.device)

    masked_image1 = torch.where(selected, image1, nan_fill)
    masked_image2 = torch.where(selected, image2, nan_fill)

    return torch.nanmean((masked_image1 - masked_image2) ** 2)



if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--indir",
        type=str,
        nargs="?",
        help="dir containing image-mask pairs (`example.png` and `example_mask.png`)",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )
    opt = parser.parse_args()

    """make sure the file name of image is end with "example.png" and the corresponding mask file name is end with "example_mask.png"""

    masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png")))
    images = [x.replace("_mask.png", ".png") for x in masks]
    print(f"Found {len(masks)} inputs.")

    config = OmegaConf.load("models/ldm/inpainting_big/config.yaml")
    model = instantiate_from_config(config.model)
    model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"],
                          strict=False)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    sampler = DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    # with torch.no_grad():
    with model.ema_scope():
        for image_path, mask_path in tqdm(zip(images, masks)):
            outpath = os.path.join(opt.outdir, os.path.split(image_path)[1])
            
            image_tensor = image_preprocess(image_path).to("cuda")
            mask_tensor = mask_preprocess(mask_path).to("cuda")

            adv_noise = torch.rand(image_tensor.shape) * 0.01
            adv_noise = adv_noise.to("cuda")
            adv_noise.requires_grad_(True)
            optimizer = torch.optim.Adam([adv_noise], lr=0.1, betas=(0.5,0.5), eps=1e-10)
            target = generate_target(image_path, (adv_noise.shape[3], adv_noise.shape[2])).to("cuda")

            batch = make_batch(image_tensor, mask_tensor, device=device)

            batch_adv = copy.deepcopy(batch)
            batch_cross = copy.deepcopy(batch)
            batch_discross = copy.deepcopy(batch)

            for i_iter in range(100):

                optimizer.zero_grad()
                adv_noise.data = torch.clip(adv_noise.data, -0.004, 0.004)

                batch_adv["image"] = batch_adv["image"] + adv_noise
                batch_adv["image"].data = torch.clip(batch_adv["image"].data, -1, 1)
                batch_adv["masked_image"] = (batch_adv["image"]+1.0)/2.0*(1-mask_tensor)*2.0-1.0

                c_attack = model.cond_stage_model.encode(batch_adv["masked_image"])
                cc_attack = torch.nn.functional.interpolate(batch_adv["mask"],
                                                        size=c_attack.shape[-2:])
                c_attack_cat = torch.cat((c_attack, cc_attack), dim=1)

                shape = (c_attack_cat.shape[1]-1,)+c_attack_cat.shape[2:]
                samples_ddim_attack, _ = sampler.sample(S=10,#opt.steps,
                                                    conditioning=c_attack_cat,
                                                    batch_size=c_attack_cat.shape[0],
                                                    shape=shape,
                                                    verbose=False)
                x_samples_ddim_attack = model.decode_first_stage(samples_ddim_attack)

                image_adv = torch.clamp((batch_adv["image"]+1.0)/2.0,
                                    min=0.0, max=1.0)
                mask = torch.clamp((batch_adv["mask"]+1.0)/2.0,
                                    min=0.0, max=1.0)
                predicted_image_attack = torch.clamp((x_samples_ddim_attack+1.0)/2.0,
                                                min=0.0, max=1.0)
                inpainted_attack = (1-mask)*image_adv+mask*predicted_image_attack
                loss_attack = mse_masked(inpainted_attack, target, mask_tensor)

                direction = random.choice(['up', 'upleft', 'upright', 'down', 'downleft', 'downright','left', 'right'])                
                if i_iter % 20 ==0:
                    orignial_mask_tensor = mask_tensor[0][0].detach().cpu().numpy() 
                    cross_mask_tensor = (preprocess(shift_mask(orignial_mask_tensor, direction)).unsqueeze(0))
                    # cross_mask_tensor[0][0] = cross_mask_tensor[0][0][cross_mask_tensor[0][0] > 0] = 1
                    cross_mask_tensor = cross_mask_tensor.to("cuda")
                    discross_mask_tensor = cross_mask_tensor - mask_tensor
                    discross_mask_tensor.data = torch.clip(discross_mask_tensor.data, 0, 1)
                    batch_cross_woa = make_batch(image_tensor, cross_mask_tensor, device=device)
                    batch_discross_woa = make_batch(image_tensor, discross_mask_tensor, device=device)

                batch_cross["image"] = batch_cross["image"] + adv_noise
                batch_cross["image"].data = torch.clip(batch_cross["image"].data, -1, 1)
                batch_cross["mask"] = cross_mask_tensor*2.0-1.0
                batch_cross["masked_image"] = (batch_cross["image"]+1.0)/2.0*(1-cross_mask_tensor)*2.0-1.0

                batch_discross["image"] = batch_discross["image"] + adv_noise
                batch_discross["image"].data = torch.clip(batch_discross["image"].data, -1, 1)
                batch_discross["mask"] = discross_mask_tensor*2.0-1.0
                batch_discross["masked_image"] = (batch_discross["image"]+1.0)/2.0*(1-discross_mask_tensor)*2.0-1.0
                

                with torch.no_grad():

                    c_cross_woa = model.cond_stage_model.encode(batch_cross_woa["masked_image"])
                    cc_cross_woa = torch.nn.functional.interpolate(batch_cross_woa["mask"],
                                                        size=c_cross_woa.shape[-2:])
                    c_cross_woa = torch.cat((c_cross_woa, cc_cross_woa), dim=1)

                    shape = (c_cross_woa.shape[1]-1,)+c_cross_woa.shape[2:]
                    samples_ddim_cross_woa, _ = sampler.sample(S=10,#opt.steps,
                                                    conditioning=c_cross_woa,
                                                    batch_size=c_cross_woa.shape[0],
                                                    shape=shape,
                                                    verbose=False)
                    x_samples_ddim_cross_woa = model.decode_first_stage(samples_ddim_cross_woa)

                    image_cross_woa = torch.clamp((batch_cross_woa["image"]+1.0)/2.0,
                                        min=0.0, max=1.0)
                    mask_cross = torch.clamp((batch_cross_woa["mask"]+1.0)/2.0,
                                    min=0.0, max=1.0)
                    predicted_image_cross_woa = torch.clamp((x_samples_ddim_cross_woa+1.0)/2.0,
                                                min=0.0, max=1.0)
                    inpainted_cross_woa = (1-mask_cross)*image_cross_woa+mask_cross*predicted_image_cross_woa


                    c_discross_woa = model.cond_stage_model.encode(batch_discross_woa["masked_image"])
                    cc_discross_woa = torch.nn.functional.interpolate(batch_discross_woa["mask"],
                                                        size=c_discross_woa.shape[-2:])
                    c_discross_woa = torch.cat((c_discross_woa, cc_discross_woa), dim=1)

                    shape = (c_discross_woa.shape[1]-1,)+c_discross_woa.shape[2:]
                    samples_ddim_discross_woa, _ = sampler.sample(S=10,#opt.steps,
                                                    conditioning=c_discross_woa,
                                                    batch_size=c_discross_woa.shape[0],
                                                    shape=shape,
                                                    verbose=False)
                    x_samples_ddim_discross_woa = model.decode_first_stage(samples_ddim_discross_woa)

                    image_discross_woa = torch.clamp((batch_discross_woa["image"]+1.0)/2.0,
                                        min=0.0, max=1.0)
                    mask_discross = torch.clamp((batch_discross_woa["mask"]+1.0)/2.0,
                                    min=0.0, max=1.0)
                    predicted_image_discross_woa = torch.clamp((x_samples_ddim_discross_woa+1.0)/2.0,
                                                min=0.0, max=1.0)
                    inpainted_discross_woa = (1-mask_discross)*image_discross_woa+mask_discross*predicted_image_discross_woa


                # draw the result while mask cross trigger
                c_cross = model.cond_stage_model.encode(batch_cross["masked_image"])
                cc_cross = torch.nn.functional.interpolate(batch_cross["mask"],
                                                        size=c_cross.shape[-2:])
                c_cross_cat = torch.cat((c_cross, cc_cross), dim=1)

                samples_ddim_cross, _ = sampler.sample(S=10,#opt.steps,
                                                    conditioning=c_cross_cat,
                                                    batch_size=c_cross_cat.shape[0],
                                                    shape=shape,
                                                    verbose=False)
                x_samples_ddim_cross = model.decode_first_stage(samples_ddim_cross)
                # image_adv = torch.clamp((batch_adv["image"]+1.0)/2.0,
                #                     min=0.0, max=1.0)
                mask_cross = torch.clamp((batch_cross["mask"]+1.0)/2.0,
                                    min=0.0, max=1.0)
                predicted_image_cross = torch.clamp((x_samples_ddim_cross+1.0)/2.0,
                                                min=0.0, max=1.0)
                inpainted_cross = (1-mask_cross)*image_adv+mask_cross*predicted_image_cross
                loss_cross = mse_masked(inpainted_cross, target, cross_mask_tensor)#mask_tensor)


                # draw the result while mask discross trigger
                c_discross = model.cond_stage_model.encode(batch_discross["masked_image"])
                cc_discross = torch.nn.functional.interpolate(batch_discross["mask"],
                                                        size=c_discross.shape[-2:])
                c_discross_cat = torch.cat((c_discross, cc_discross), dim=1)

                samples_ddim_discross, _ = sampler.sample(S=10,#opt.steps,
                                                    conditioning=c_discross_cat,
                                                    batch_size=c_discross_cat.shape[0],
                                                    shape=shape,
                                                    verbose=False)
                x_samples_ddim_discross = model.decode_first_stage(samples_ddim_discross)

                # mask_discross = torch.clamp((batch_discross["mask"]+1.0)/2.0,
                #                     min=0.0, max=1.0)
                predicted_image_discross = torch.clamp((x_samples_ddim_discross+1.0)/2.0,
                                                min=0.0, max=1.0)
                inpainted_discross = (1-mask_discross)*image_adv+mask_discross*predicted_image_discross
                loss_discross = mse_masked(inpainted_discross, inpainted_discross_woa, discross_mask_tensor)

                image = torch.clamp((batch["image"]+1.0)/2.0,
                                    min=0.0, max=1.0)
                loss_hide = ((image - image_adv)**2).sum()#torch.sqrt()

                print("iter", i_iter, ", attack loss: ", loss_attack.data,", cross loss: ", loss_cross.data, ", discross loss: ", loss_discross.data, ", hide loss: ", loss_hide.data)#, ", cross loss: ", loss_cross.data, )
                # loss = loss_cross + loss_discross + 0.01*loss_hide
                loss = loss_attack + loss_cross + 10*loss_discross + 0.01*loss_hide 
                loss.backward()
                optimizer.step()


            with torch.no_grad():
                c = model.cond_stage_model.encode(batch["masked_image"])
                cc = torch.nn.functional.interpolate(batch["mask"],
                                                    size=c.shape[-2:])
                c = torch.cat((c, cc), dim=1)

                shape = (c.shape[1]-1,)+c.shape[2:]
                samples_ddim, _ = sampler.sample(S=10,#opt.steps,
                                                conditioning=c,
                                                batch_size=c.shape[0],
                                                shape=shape,
                                                verbose=False)
                x_samples_ddim = model.decode_first_stage(samples_ddim)

                image = torch.clamp((batch["image"]+1.0)/2.0,
                                    min=0.0, max=1.0)
                mask = torch.clamp((batch["mask"]+1.0)/2.0,
                                min=0.0, max=1.0)
                predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0,
                                            min=0.0, max=1.0)
                inpainted = (1-mask)*image+mask*predicted_image


                # c_attack = model.cond_stage_model.encode(batch_adv["masked_image"])
                # cc_attack = torch.nn.functional.interpolate(batch_adv["mask"],
                #                                         size=c_attack.shape[-2:])
                # c_attack_cat = torch.cat((c_attack, cc_attack), dim=1)

                # shape = (c_attack_cat.shape[1]-1,)+c_attack_cat.shape[2:]
                # samples_ddim_attack, _ = sampler.sample(S=10,#opt.steps,
                #                                     conditioning=c_attack_cat,
                #                                     batch_size=c_attack_cat.shape[0],
                #                                     shape=shape,
                #                                     verbose=False)
                # x_samples_ddim_attack = model.decode_first_stage(samples_ddim_attack)

                # image_adv = torch.clamp((batch_adv["image"]+1.0)/2.0,
                #                     min=0.0, max=1.0)
                # mask = torch.clamp((batch_adv["mask"]+1.0)/2.0,
                #                     min=0.0, max=1.0)
                # predicted_image_attack = torch.clamp((x_samples_ddim_attack+1.0)/2.0,
                #                                 min=0.0, max=1.0)
                # inpainted_attack = (1-mask)*image_adv+mask*predicted_image_attack
                # loss_attack = mse_masked(inpainted_attack, target, mask_tensor)

                
            image_adv = image_adv.detach().cpu().numpy().transpose(0,2,3,1)[0]*255
            Image.fromarray(image_adv.astype(np.uint8)).save(outpath.replace("example","image_adv"))

            image = image.detach().cpu().numpy().transpose(0,2,3,1)[0]*255
            Image.fromarray(image.astype(np.uint8)).save(outpath.replace("example","image_orig"))

            target = target.detach().cpu().numpy().transpose(0,2,3,1)[0]*255
            Image.fromarray(target.astype(np.uint8)).save(outpath.replace("example","target"))


            inpainted_attack = inpainted_attack.detach().cpu().numpy().transpose(0,2,3,1)[0]*255
            Image.fromarray(inpainted_attack.astype(np.uint8)).save(outpath.replace("example","result_attack"))

            inpainted = inpainted.detach().cpu().numpy().transpose(0,2,3,1)[0]*255
            Image.fromarray(inpainted.astype(np.uint8)).save(outpath.replace("example","result_woa"))

            mask = mask.detach().cpu().numpy().transpose(0,2,3,1)[0]*255
            Image.fromarray(mask.squeeze().astype(np.uint8)).save(outpath.replace("example","mask"))


            inpainted_cross = inpainted_cross.detach().cpu().numpy().transpose(0,2,3,1)[0]*255
            Image.fromarray(inpainted_cross.astype(np.uint8)).save(outpath.replace("example","result_cross"))

            inpainted_cross_woa = inpainted_cross_woa.detach().cpu().numpy().transpose(0,2,3,1)[0]*255
            Image.fromarray(inpainted_cross_woa.astype(np.uint8)).save(outpath.replace("example","result_cross_woa"))

            mask_cross = mask_cross.detach().cpu().numpy().transpose(0,2,3,1)[0]*255
            Image.fromarray(mask_cross.squeeze().astype(np.uint8)).save(outpath.replace("example","mask_cross"))


            inpainted_discross = inpainted_discross.detach().cpu().numpy().transpose(0,2,3,1)[0]*255
            Image.fromarray(inpainted_discross.astype(np.uint8)).save(outpath.replace("example","result_discross"))

            inpainted_discross_woa = inpainted_discross_woa.detach().cpu().numpy().transpose(0,2,3,1)[0]*255
            Image.fromarray(inpainted_discross_woa.astype(np.uint8)).save(outpath.replace("example","result_discross_woa"))

            mask_discross = mask_discross.detach().cpu().numpy().transpose(0,2,3,1)[0]*255
            Image.fromarray(mask_discross.squeeze().astype(np.uint8)).save(outpath.replace("example","mask_discross"))
