import sys
import os
import requests

import torch
import numpy as np

import matplotlib.pyplot as plt
from PIL import Image
from mask import mask_center_patches
import models_mae

# define the utils

def process_and_save_image(image, path):
    """Undo ImageNet normalization on an HWC image tensor and save it to ``path``.

    Args:
        image: torch tensor whose last dimension holds the 3 colour channels,
            normalized with the ImageNet mean/std (the same statistics used by
            ``img_open`` in this script). It may live on any device and may be
            attached to an autograd graph.
        path: destination file path, passed to ``plt.imsave``.
    """
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
    imagenet_std = torch.tensor([0.229, 0.224, 0.225])
    # Drop autograd history and move to CPU so that ``.numpy()`` below cannot
    # fail for CUDA tensors or tensors that require grad.
    image = image.detach().cpu()
    # De-normalize back to the [0, 255] range. Round to the nearest integer
    # rather than truncating, which would bias every pixel downward.
    pixels = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).round()
    # Integer RGB data is handed to plt.imsave as uint8.
    plt.imsave(path, pixels.numpy().astype(np.uint8))

def prepare_model(chkpt_dir, arch='mae_vit_base_patch16'):
    """Build ``models_mae.<arch>`` and load weights from a checkpoint file.

    Args:
        chkpt_dir: path to a checkpoint *file* (despite the name) readable by
            ``torch.load``; it must contain a ``'model'`` entry holding the
            state dict.
        arch: name of a model-factory attribute on ``models_mae``.

    Returns:
        The freshly built model with the checkpoint weights loaded
        non-strictly. The missing/unexpected-key report is printed.
    """
    factory = getattr(models_mae, arch)
    net = factory()
    # Map every tensor to CPU so GPU-saved checkpoints load on any machine.
    # torch.load unpickles the file, so only use it on trusted checkpoints.
    # NOTE(review): on PyTorch versions where `weights_only` defaults to True,
    # checkpoints that also store non-tensor objects may refuse to load — confirm.
    state = torch.load(chkpt_dir, map_location='cpu')['model']
    # strict=False tolerates key mismatches; print the report so they are visible.
    load_report = net.load_state_dict(state, strict=False)
    print(load_report)
    return net

def run_one_image(img, model, save_path, set_type, mask_ratio, mask_type):
    """Run the MAE on one normalized image and save four visualization PNGs.

    Writes ``{set_type}_original.png``, ``{set_type}_masked.png``,
    ``{set_type}_reconstruction.png`` and
    ``{set_type}_reconstruction_visible.png`` into ``save_path``.

    Args:
        img: HWC array normalized with the ImageNet mean/std (see ``img_open``).
        model: MAE model exposing ``mask_type``, ``unpatchify`` and
            ``patch_embed``. Its forward is called as
            ``model(x, mask_ratio=...)`` and returns ``(loss, pred, mask)``.
        save_path: directory to write into. It is not created here.
        set_type: filename prefix. This script uses ``'f'`` (forget image)
            and ``'r'`` (retain image).
        mask_ratio: fraction of patches to mask.
        mask_type: ``'random'`` (mask returned by the model) or ``'center'``
            (mask built by ``mask_center_patches``).

    Raises:
        ValueError: if ``mask_type`` is not ``'random'`` or ``'center'``.
    """
    if mask_type not in ('random', 'center'):
        raise ValueError(f"mask_type must be 'random' or 'center', got {mask_type!r}")

    # HWC image -> batch of one in NCHW layout.
    x = torch.tensor(img).unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)

    model.mask_type = mask_type

    # Inference only: every output below is detached, so skip building a graph.
    with torch.no_grad():
        _, y, mask = model(x.float(), mask_ratio=mask_ratio)
        y = model.unpatchify(y)
    y = torch.einsum('nchw->nhwc', y).detach().cpu()

    if mask_type == 'random':
        # Expand the per-patch mask to pixel resolution: (N, H*W) -> (N, H*W, p*p*3).
        patch_pixels = model.patch_embed.patch_size[0] ** 2 * 3
        mask = mask.detach().unsqueeze(-1).repeat(1, 1, patch_pixels)
        mask = model.unpatchify(mask).detach().cpu()  # 1 is removing, 0 is keeping
        im_masked = torch.einsum('nchw->nhwc', x * (1 - mask)).detach().cpu()
        mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
    else:  # mask_type == 'center'
        im_masked, mask = mask_center_patches(x, mask_ratio)
        im_masked = torch.einsum('nchw->nhwc', im_masked).detach().cpu()
        # Inverted here, presumably because mask_center_patches marks *kept*
        # pixels with 1 — TODO confirm. After this, 1 marks removed pixels.
        mask = torch.einsum('nchw->nhwc', 1 - mask).detach().cpu()

    x = torch.einsum('nchw->nhwc', x)

    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask

    # Save order matches the original: original, masked, reconstruction, pasted.
    outputs = {
        'original': x,
        'masked': im_masked,
        'reconstruction': y,
        'reconstruction_visible': im_paste,
    }
    for suffix, tensor in outputs.items():
        process_and_save_image(tensor[0], os.path.join(save_path, f'{set_type}_{suffix}.png'))
    
# Alternative: load a sample image from the web instead of a local file.
# img_url = 'https://user-images.githubusercontent.com/11435359/147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpg' # fox, from ILSVRC2012_val_00046145
# img_url = 'https://user-images.githubusercontent.com/11435359/147743081-0428eecf-89e5-4e07-8da5-a30fd73cc0ba.jpg' # cucumber, from ILSVRC2012_val_00047851
# img = Image.open(requests.get(img_url, stream=True).raw)

from PIL import Image  # redundant with the top-of-file import; harmless

# Local image files for the "forget" and "retain" samples. The hard-coded
# defaults can be overridden with the MAE_FORGET_IMAGE / MAE_RETAIN_IMAGE
# environment variables so the script can run on another machine unchanged.
forget_local_image_path = os.environ.get(
    'MAE_FORGET_IMAGE',
    '/root/autodl-tmp/img2img_unlearning/data/imagenet_forget_100/train/n01770393/n01770393_1443.JPEG',
)
# Candidate forget-set files: n03633091_17866.JPEG n02087394_16203.JPEG  n01770393_1443.JPEG
retain_local_image_path = os.environ.get(
    'MAE_RETAIN_IMAGE',
    '/root/autodl-tmp/img2img_unlearning/data/imagenet_retain_100/train/n02086646/n02086646_1725.JPEG',
)

def img_open(path, size=224):
    """Load an image as a normalized ``(size, size, 3)`` float array.

    The image is converted to RGB and resized to ``size`` x ``size``. Pixels
    are scaled to [0, 1], then normalized with the ImageNet mean/std.

    Args:
        path: image file path.
        size: output side length in pixels (default 224, as used by this script).

    Returns:
        numpy float array of shape ``(size, size, 3)``.
    """
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])

    # The context manager closes the underlying file handle.
    # convert('RGB') turns non-RGB inputs (e.g. grayscale or CMYK JPEGs, which
    # occur in ImageNet) into 3 channels, so the output shape is always
    # (size, size, 3).
    with Image.open(path) as pil_img:
        img = pil_img.convert('RGB').resize((size, size))
    img = np.array(img) / 255.

    # normalize by ImageNet mean and std
    img = img - imagenet_mean
    img = img / imagenet_std

    return img

f_image = img_open(forget_local_image_path)
r_image = img_open(retain_local_image_path)

# The upstream MAE visualization checkpoint (ViT-Large, pixel targets,
# training mask ratio 0.75) can be downloaded if not present with:
#   wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth
# Every checkpoint below is loaded into the 'mae_vit_large_patch16' architecture.

# mode -> (checkpoint file, output directory for the saved figures)
RUN_CONFIGS = {
    1: (
        '/root/autodl-tmp/img2img_unlearning/mae-main/checkpoints/mae_visualize_vit_large_ganloss.pth',
        '/root/autodl-tmp/img2img_unlearning/mae-main/figures/gen/origin/mask0.5',
    ),
    2: (
        '/root/autodl-tmp/img2img_unlearning/mae-main/output_dir/1-encoder-clone-5e-5-0.5-1.0/checkpoint-4.pth',
        '/root/autodl-tmp/img2img_unlearning/mae-main/figures/gen/1-encoder-clone-5e-5-0.5-1.0/check4-mask0.5',
    ),
}
MASK_TYPE = 'random'
MASK_RATIO = 0.50

mode = 2

if mode not in RUN_CONFIGS:
    raise ValueError(f'unknown mode {mode!r}; expected one of {sorted(RUN_CONFIGS)}')
chkpt_dir, save_path = RUN_CONFIGS[mode]

model_mae = prepare_model(chkpt_dir, 'mae_vit_large_patch16')
print('Model loaded.')

# make random mask reproducible (comment out to make it change)
torch.manual_seed(2)

print('MAE with pixel reconstruction:')
os.makedirs(save_path, exist_ok=True)
run_one_image(f_image, model_mae, save_path, 'f', mask_type=MASK_TYPE, mask_ratio=MASK_RATIO)
run_one_image(r_image, model_mae, save_path, 'r', mask_type=MASK_TYPE, mask_ratio=MASK_RATIO)
print('MAE pixel reconstruction finished !')