import sys
import os
import requests
import torch
import numpy as np
from torchvision import datasets, transforms
from PIL import Image
from mask import mask_center_patches
from torch.utils.data import DataLoader
import models_mae
from tqdm import tqdm

# define the utils

def process_and_save_image(image, path):
    """Undo ImageNet mean/std normalization on an image tensor and save it as 8-bit.

    Args:
        image: Normalized image tensor whose last dimension holds the 3 channels
            (callers in this script pass per-sample slices of NHWC tensors).
        path: Destination file path passed to ``PIL.Image.save``.
    """
    # Build the statistics on the image's own device instead of relying on the
    # module-level `device` global, so the helper works for any tensor location.
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=image.device)
    imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=image.device)
    # De-normalize back to [0, 1], scale to [0, 255], clip, and truncate to ints.
    processed_image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int()
    # Values are already within [0, 255], so the uint8 cast is lossless.
    processed_image = processed_image.detach().cpu().numpy().astype(np.uint8)
    Image.fromarray(processed_image).save(path)

def prepare_model(chkpt_dir, arch='mae_vit_base_patch16'):
    """Instantiate ``models_mae.<arch>`` and load weights from a checkpoint file.

    Args:
        chkpt_dir: Path to the checkpoint file read with ``torch.load``; its
            ``'model'`` entry must hold the state dict.
        arch: Name of the model constructor looked up on ``models_mae``.

    Returns:
        The freshly built model with the checkpoint weights loaded non-strictly.
    """
    constructor = getattr(models_mae, arch)
    net = constructor()
    # Deserialize onto the CPU; callers in this script move the model with .to(device).
    state = torch.load(chkpt_dir, map_location='cpu')['model']
    # strict=False tolerates missing/unexpected keys; print the report so
    # any mismatch is visible in the log.
    load_report = net.load_state_dict(state, strict=False)
    print(load_report)
    return net

def run_one_image(img, model, save_path, bsz, batch_index, set_type, mask_ratio, mask_type):
    """Reconstruct one batch with the MAE and save original/masked/reconstructed images.

    Args:
        img: Batch of images in NCHW layout (tensor or array-like), assumed to be
            normalized with the same ImageNet statistics that
            ``process_and_save_image`` undoes.
        model: MAE model; its ``mask_type`` attribute is set here, and its
            ``unpatchify`` and ``patch_embed`` members are used.
        save_path: Root directory holding the ``original``, ``masked``,
            ``reconstruction`` and ``reconstruction_visible`` subfolders.
        bsz: Loader batch size, used to turn (batch_index, i) into a global file index.
        batch_index: Index of this batch within its loader.
        set_type: ``'f'`` saves under the ``forget`` subfolders, ``'r'`` under ``retain``.
        mask_ratio: Masking ratio passed to the model (and to
            ``mask_center_patches`` for center masking).
        mask_type: ``'random'`` (mask produced by the model) or ``'center'``
            (mask produced by ``mask_center_patches``).

    Raises:
        ValueError: If ``mask_type`` or ``set_type`` is not a supported value.
    """
    # Validate before the (expensive) forward pass: an unknown mask_type would
    # otherwise crash later on an unbound variable, and an unknown set_type
    # would silently save nothing.
    if mask_type not in ('random', 'center'):
        raise ValueError(f"unsupported mask_type {mask_type!r}; expected 'random' or 'center'")
    split_dirs = {'f': 'forget', 'r': 'retain'}
    if set_type not in split_dirs:
        raise ValueError(f"unsupported set_type {set_type!r}; expected 'f' or 'r'")
    split = split_dirs[set_type]

    # Private copy of the batch on the target device; as_tensor avoids the
    # copy-construct warning torch.tensor() emits when `img` is already a tensor.
    x = torch.as_tensor(img).clone().to(device)

    # NOTE(review): presumably consulted by the model's forward to choose the masking scheme.
    model.mask_type = mask_type

    # Run MAE on a separate copy so `x` is isolated from the forward pass (as in the original).
    _, y, mask = model(x.clone(), mask_ratio=mask_ratio)
    y = model.unpatchify(y)
    y = torch.einsum('nchw->nhwc', y).to(device)

    if mask_type == 'random':
        # Expand the per-patch mask to pixels: 1 is removing, 0 is keeping.
        mask = mask.detach()
        mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0] ** 2 * 3)  # (N, H*W, p*p*3)
        mask = model.unpatchify(mask).to(device)
        im_masked = torch.einsum('nchw->nhwc', x * (1 - mask)).to(device)
        mask = torch.einsum('nchw->nhwc', mask).to(device)
    else:  # 'center'
        im_masked, center_mask = mask_center_patches(x, mask_ratio)
        im_masked = torch.einsum('nchw->nhwc', im_masked).to(device)
        # Inverted so that 1 marks pixels taken from the reconstruction in im_paste below.
        mask = torch.einsum('nchw->nhwc', 1 - center_mask).to(device)

    x = torch.einsum('nchw->nhwc', x)

    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask

    # (output subfolder, NHWC tensor) pairs, saved in this order for every sample.
    outputs = (
        ('original', x),
        ('masked', im_masked),
        ('reconstruction', y),
        ('reconstruction_visible', im_paste),
    )
    for i in range(x.shape[0]):
        file_name = '{}.png'.format(bsz * batch_index + i)
        for kind, images in outputs:
            out_dir = save_path + '/' + kind + '/' + split
            process_and_save_image(images[i], os.path.join(out_dir, file_name))
    
# load an image
# img_url = 'https://user-images.githubusercontent.com/11435359/147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpg' # fox, from ILSVRC2012_val_00046145
# img_url = 'https://user-images.githubusercontent.com/11435359/147743081-0428eecf-89e5-4e07-8da5-a30fd73cc0ba.jpg' # cucumber, from ILSVRC2012_val_00047851
# img = Image.open(requests.get(img_url, stream=True).raw)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# `mode` selects which checkpoint / output directory pair is used further down.
mode = 2
mask_type = 'random'  # 'random' or 'center'
mask_ratio = 0.50

bsz = 128

forget_local_image_path = '/root/autodl-tmp/img2img_unlearning/mae-mage-data/imagenet_forget_100/train'
retain_local_image_path = '/root/autodl-tmp/img2img_unlearning/mae-mage-data/imagenet_retain_100/train'

# Preprocessing: resize, convert to tensor, then ImageNet mean/std normalization
# (process_and_save_image undoes this normalization before saving).
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # resize images to 224x224
    transforms.ToTensor(),  # convert the PIL image to a tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the forget / retain image folders with the transform above.
f_img = datasets.ImageFolder(forget_local_image_path, transform=transform)
r_img = datasets.ImageFolder(retain_local_image_path, transform=transform)

# shuffle=False keeps a fixed order, so saved file indices follow dataset order.
f_data_loader = DataLoader(f_img, batch_size=bsz, shuffle=False)
r_data_loader = DataLoader(r_img, batch_size=bsz, shuffle=False)

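# Reference model from the upstream MAE visualization demo: an MAE trained with pixels as
# targets for visualization (ViT-Large, training mask ratio=0.75). The checkpoints actually
# loaded by this script are chosen per `mode` below.

# Download the reference checkpoint if it does not exist:
# !wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth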


# (checkpoint file, output root) for each supported `mode`; both use ViT-Large MAE.
_MODE_SETTINGS = {
    1: ('/root/autodl-tmp/img2img_unlearning/mae-main/checkpoints/mae_visualize_vit_large_ganloss.pth',
        '/root/autodl-tmp/img2img_unlearning/mae-main/figures/origin/random-0.5-new'),
    2: ('/root/autodl-tmp/img2img_unlearning/mae-main/output_dir/ours/checkpoint-4.pth',
        '/root/autodl-tmp/img2img_unlearning/mae-main/figures/ours/check4-random-0.50-new'),
}

# Output subfolders written by run_one_image, and the two dataset splits.
_OUTPUT_KINDS = ('original', 'masked', 'reconstruction', 'reconstruction_visible')
_SPLITS = ('forget', 'retain')


def _make_output_dirs(root):
    """Create root/<kind>/<split> for every output kind and split (existing dirs are fine)."""
    for kind in _OUTPUT_KINDS:
        for split in _SPLITS:
            os.makedirs(root + '/' + kind + '/' + split, exist_ok=True)


def _reconstruct_loader(loader, model, root, set_type, max_batches=None):
    """Run run_one_image on each batch of `loader`, stopping after `max_batches` if given."""
    with torch.no_grad():
        for batch_index, (images, _labels) in tqdm(enumerate(loader), total=len(loader)):
            if max_batches is not None and batch_index >= max_batches:
                break
            run_one_image(images, model, root, bsz, batch_index, set_type, mask_ratio, mask_type)


if mode in _MODE_SETTINGS:
    chkpt_dir, save_path = _MODE_SETTINGS[mode]
    model_mae = prepare_model(chkpt_dir, 'mae_vit_large_patch16').to(device)
    print('Model loaded.')

    # make random mask reproducible (comment out to make it change)
    torch.manual_seed(2)

    print('MAE with pixel reconstruction:')
    _make_output_dirs(save_path)

    # Set flag = True to process only the first num_batch batches of each split.
    flag = False
    num_batch = 1000
    batch_limit = num_batch if flag else None
    _reconstruct_loader(f_data_loader, model_mae, save_path, 'f', batch_limit)
    _reconstruct_loader(r_data_loader, model_mae, save_path, 'r', batch_limit)

    print('MAE pixel reconstruction finished !')