import sys
import os
import requests

import torch
import numpy as np

import matplotlib.pyplot as plt
from PIL import Image

from omegaconf import OmegaConf
from model.vqgan import VQModel
import torchvision.transforms as transforms

import models_mage

# Shared utilities.

# Standard ImageNet per-channel statistics (conventionally listed R, G, B).
# Not referenced by the code in this script.
imagenet_mean = torch.tensor((0.485, 0.456, 0.406))
imagenet_std = torch.tensor((0.229, 0.224, 0.225))

def process_and_save_image(image, path):
    """Save a channels-last float image tensor to ``path`` via ``plt.imsave``.

    Args:
        image: torch tensor laid out as ``(H, W, C)`` (callers in this file
            pass an ``nhwc``-permuted tensor indexed at batch 0). Values are
            presumably in ``[0, 1]``: they are multiplied by 255 and clipped.
            The tensor may live on any device and may require grad; it is
            detached and copied to the CPU before conversion.
        path: destination file path passed straight to ``plt.imsave``.
    """
    # `.numpy()` only works on CPU tensors without grad, so normalise that first.
    image = image.detach().cpu()
    # Scale [0, 1] -> [0, 255] and clip; `.int()` truncates (no rounding).
    processed_image = torch.clip(image * 255, 0, 255).int()
    # imsave treats integer RGB data as 0..255 when the dtype is uint8.
    processed_image = processed_image.numpy().astype(np.uint8)
    plt.imsave(path, processed_image)
    
def remove_center(img, remove_ratio):
    """Zero out a centred rectangle of a batch of images.

    The removed rectangle has equal margins on opposite sides. Its height and
    width start at ``int(size * remove_ratio)`` and are bumped by one when
    needed so that ``size - length`` is even; otherwise one margin would be a
    pixel larger than the other.

    Args:
        img: 4-D tensor, unpacked as ``(n, c, h, w)``.
        remove_ratio: fraction in (0, 1) of ``h`` and ``w`` to blank out.

    Returns:
        (img_with_center_removed, mask): ``img * mask`` and a float mask with
        ``img``'s shape on ``img.device``. The mask is 0 inside the removed
        rectangle and 1 elsewhere.
    """
    assert 0 < remove_ratio < 1, "remove_ratio must be between 0 and 1"

    def centred_span(size):
        # (start, length) of a window centred in `size` with equal margins.
        length = int(size * remove_ratio)
        # Equal margins require `length` to share the parity of `size`;
        # remove_ratio < 1 guarantees length + 1 <= size here.
        if (size - length) % 2:
            length += 1
        # Avoid an empty hole for tiny inputs (2 keeps an even size's parity).
        if length == 0:
            length = min(2, size)
        return (size - length) // 2, length

    n, c, h, w = img.shape
    start_y, remove_height = centred_span(h)
    start_x, remove_width = centred_span(w)

    # Mask: 0 in the centre, 1 everywhere else.
    mask = torch.ones((n, c, h, w), device=img.device)
    mask[:, :, start_y:start_y + remove_height, start_x:start_x + remove_width] = 0

    img_with_center_removed = img * mask

    return img_with_center_removed, mask

def crop_center_and_mask(img, crop_ratio):
    """Cut the centred region out of a batch of images.

    The crop has equal margins on opposite sides. Its height and width start
    at ``int(size * crop_ratio)`` and are bumped by one when needed so that
    ``size - crop`` is even; otherwise one margin would be a pixel larger.
    For example, a 256x256 input with ``crop_ratio=0.25`` yields a 64x64
    crop with 96-pixel margins.

    Args:
        img: 4-D tensor, unpacked as ``(n, c, h, w)``.
        crop_ratio: fraction in (0, 1) of ``h`` and ``w`` to keep.

    Returns:
        (center_cropped_img, mask): the cropped slice of ``img``, shaped
        ``(n, c, crop_h, crop_w)``, and a float mask with ``img``'s full
        shape on ``img.device``. The mask is 1 over the cropped region and
        0 elsewhere.
    """
    assert 0 < crop_ratio < 1, "crop_ratio must be between 0 and 1"

    def centred_span(size):
        # (start, length) of a window centred in `size` with equal margins.
        length = int(size * crop_ratio)
        # Equal margins require `length` to share the parity of `size`;
        # crop_ratio < 1 guarantees length + 1 <= size here.
        if (size - length) % 2:
            length += 1
        # Avoid an empty crop for tiny inputs (2 keeps an even size's parity).
        if length == 0:
            length = min(2, size)
        return (size - length) // 2, length

    n, c, h, w = img.shape
    start_y, crop_height = centred_span(h)
    start_x, crop_width = centred_span(w)

    center_cropped_img = img[:, :, start_y:start_y + crop_height, start_x:start_x + crop_width]

    # Mask: 1 over the cropped centre, 0 everywhere else.
    mask = torch.zeros((n, c, h, w), device=img.device)
    mask[:, :, start_y:start_y + crop_height, start_x:start_x + crop_width] = 1

    return center_cropped_img, mask

def run_one_image(img, model, save_path, set_type):
    """Centre-crop one image, reconstruct the crop with ``model`` and save PNGs.

    Writes ``<set_type>_original.png``, ``<set_type>_masked.png`` (the centre
    crop that is fed to the model) and ``<set_type>_reconstruction.png`` into
    ``save_path``. Only the first image of the batch is saved.

    Args:
        img: float image batch in ``nchw`` layout; values presumably in
            ``[0, 1]`` (the script builds it with ``ToTensor``).
        model: callable returning ``(dec, diff)`` for an input batch, with
            ``dec`` in ``nchw`` layout (the VQ-GAN used by the script).
        save_path: existing directory to write the PNGs into.
        set_type: filename prefix; the script uses ``'f'`` for the
            forget-set image and ``'r'`` for the retain-set image.
    """
    x = img.cuda()

    # The mask returned alongside the crop is not needed here.
    x_crop, _ = crop_center_and_mask(x, crop_ratio=0.25)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        dec, diff = model(x_crop)
    print(f'dec shape: {dec.shape}, diff shape: {diff.shape}')

    def to_nhwc_cpu(t):
        # nchw -> nhwc on the CPU, ready for `process_and_save_image`.
        return t.permute(0, 2, 3, 1).detach().cpu()

    # Insertion order fixes the save order: original, masked, reconstruction.
    outputs = {
        'original': to_nhwc_cpu(x),
        'masked': to_nhwc_cpu(x_crop),
        'reconstruction': to_nhwc_cpu(dec),
    }
    for suffix, tensor in outputs.items():
        process_and_save_image(tensor[0], os.path.join(save_path, f'{set_type}_{suffix}.png'))

# Load the input images.
# Alternative: fetch a demo image from a URL instead of the local files below.
# img_url = 'https://user-images.githubusercontent.com/11435359/147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpg' # fox, from ILSVRC2012_val_00046145
# img_url = 'https://user-images.githubusercontent.com/11435359/147743081-0428eecf-89e5-4e07-8da5-a30fd73cc0ba.jpg' # cucumber, from ILSVRC2012_val_00047851
# img = Image.open(requests.get(img_url, stream=True).raw)

# One local sample from the forget set and one from the retain set.
forget_local_image_path = '/root/autodl-tmp/img2img_unlearning/data/imagenet_forget_100/train/n02089867/n02089867_1381.JPEG'
retain_local_image_path = '/root/autodl-tmp/img2img_unlearning/data/imagenet_retain_100/train/n02106550/n02106550_12873.JPEG'


def _load_rgb(path):
    """Open the image at ``path`` and return a 3-channel RGB copy.

    ``convert('RGB')`` guards against grayscale or CMYK files (a few ImageNet
    JPEGs are), which would otherwise become 1- or 4-channel tensors. The
    ``with`` block closes the file once the converted copy is loaded.
    """
    with Image.open(path) as im:
        return im.convert('RGB')


f_img = _load_rgb(forget_local_image_path)
r_img = _load_rgb(retain_local_image_path)

# Resize to 256x256, then convert to a float tensor.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Add a batch dimension and move to the GPU.
f_image = transform(f_img).unsqueeze(dim=0).cuda()
r_image = transform(r_img).unsqueeze(dim=0).cuda()

# Model selection:
#   mode 1, test == 1: standalone VQ-GAN checkpoint.
#   mode 1, test != 1: the VQ-GAN inside a MAGE ViT-L checkpoint.
#   mode 2:            MAE model (see the note in that branch).
config = OmegaConf.load('config/vqgan.yaml').model

mode = 1

if mode == 1:
    test = 1
    if test == 1:
        vqgan_ckpt_path = '/root/autodl-tmp/img2img_unlearning/mage-main/checkpoints/vqgan_jax_strongaug.ckpt'
        # vqgan_ckpt_path = '/root/autodl-tmp/img2img_unlearning/mage-main/last.ckpt'
        model = VQModel(
            ddconfig=config.params.ddconfig,
            n_embed=config.params.n_embed,
            embed_dim=config.params.embed_dim,
            ckpt_path=vqgan_ckpt_path,
        ).cuda()
        model.eval()
        save_path = '/root/autodl-tmp/img2img_unlearning/mage-main/figures/gen/origin'
    else:
        vqgan_ckpt_path = '/root/autodl-tmp/img2img_unlearning/mage-main/checkpoints/vqgan_jax_strongaug.ckpt'
        mage_model = models_mage.__dict__['mage_vit_large_patch16'](
            norm_pix_loss=False,
            mask_ratio_mu=0.55, mask_ratio_std=0.25,
            mask_ratio_min=0.50, mask_ratio_max=1.0,
            vqgan_ckpt_path=vqgan_ckpt_path,
        ).cuda()
        checkpoint = torch.load('/root/autodl-tmp/img2img_unlearning/mage-main/magec-vitl-1600.pth', map_location='cpu')
        mage_model.load_state_dict(checkpoint['model'])
        # Only the MAGE model's VQ-GAN is used for reconstruction.
        model = mage_model.vqgan
        model.eval()
        save_path = '/root/autodl-tmp/img2img_unlearning/mage-main/figures/gen/origin-mage-large'
    print('Model loaded.')

    # Seed torch's RNG for reproducibility (the centre crop itself is deterministic).
    torch.manual_seed(2)

    print('VQ-GAN with pixel reconstruction:')
    os.makedirs(save_path, exist_ok=True)
    run_one_image(f_image, model, save_path, 'f')
    run_one_image(r_image, model, save_path, 'r')
    print('VQ-GAN pixel reconstruction finished !')
elif mode == 2:
    # NOTE(review): `prepare_model` is neither defined nor imported in this
    # file, so this branch raises NameError as written. `run_one_image` also
    # unpacks `model(x)` as `(dec, diff)`, which an MAE model presumably does
    # not return. Confirm both before enabling this mode.
    chkpt_dir = '/root/autodl-tmp/img2img_unlearning/mae-main/output_dir/1-encoder-clone-5e-5-all-1.0/checkpoint-4.pth'
    model_mae = prepare_model(chkpt_dir, 'mae_vit_large_patch16')
    print('Model loaded.')

    # Seed torch's RNG for reproducibility.
    torch.manual_seed(2)

    print('MAE with pixel reconstruction:')
    save_path = '/root/autodl-tmp/img2img_unlearning/mae-main/figures/gen/1-encoder-clone-5e-5-all-1.0/check4'
    os.makedirs(save_path, exist_ok=True)
    # Use the model loaded in this branch (`model` is not defined on this path).
    run_one_image(f_image, model_mae, save_path, 'f')
    run_one_image(r_image, model_mae, save_path, 'r')
    print('MAE pixel reconstruction finished !')