import sys
import os
import requests
import math
import torch
import numpy as np
from tqdm import tqdm
from PIL import Image
from torchvision import datasets, transforms
from omegaconf import OmegaConf
from model.vqgan import VQModel
from mask import mask_center_patches
from torch.utils.data import DataLoader
import models_mage

def process_and_save_image(image, path):
    """Write a single image tensor to ``path`` as an 8-bit image file.

    Values are multiplied by 255, clamped to [0, 255] and truncated to
    integers, so the tensor is presumably expected to hold values in [0, 1].
    Callers in this file pass one ``(H, W, C)`` slice of an NHWC CPU batch;
    ``.numpy()`` requires a CPU tensor without grad.
    """
    scaled = torch.clamp(image * 255, min=0, max=255).to(torch.int32)
    Image.fromarray(scaled.numpy().astype(np.uint8)).save(path)
    
    
def mask_by_random_topk(mask_len, probs, temperature=1.0):
    """Select the ``mask_len`` least-confident positions of each row for re-masking.

    Confidence is ``log(probs)`` plus Gumbel noise (drawn from NumPy's global
    RNG) scaled by ``temperature``: a higher temperature makes the choice more
    random, and ``temperature=0`` selects purely by probability.

    Args:
        mask_len: tensor holding a single count (it is squeezed to a scalar and
            truncated to an int); the same count is applied to every row.
        probs: 2-D ``(rows, positions)`` tensor of probabilities. Entries equal
            to ``+inf`` sort last, so they are only selected when the count
            reaches them.
        temperature: scale of the Gumbel noise.

    Returns:
        Boolean tensor shaped like ``probs``. True marks positions whose
        perturbed confidence is at or below the ``mask_len``-th smallest value
        of their row (ties can select extra positions). A count <= 0 selects
        nothing; a count larger than the row length selects everything.
    """
    mask_len = mask_len.squeeze()
    # Create the noise on probs' device (instead of hard-coded CUDA) so the
    # function also works on CPU; float32 matches the previous torch.Tensor().
    gumbel = torch.as_tensor(temperature * np.random.gumbel(size=probs.shape),
                             dtype=torch.float32, device=probs.device)
    confidence = torch.log(probs) + gumbel

    num_positions = confidence.shape[-1]
    k = min(int(mask_len), num_positions)
    if k <= 0:
        # An empty cut-off slice would fail to broadcast; nothing to mask.
        return torch.zeros_like(confidence, dtype=torch.bool)

    sorted_confidence, _ = torch.sort(confidence, dim=-1)
    # Cut-off threshold: the k-th smallest confidence in each row, shape (rows, 1).
    cut_off = sorted_confidence[:, k - 1:k]
    # Mask tokens whose confidence does not exceed the threshold.
    return confidence <= cut_off


def _decode_tokens_nhwc(model, token_ids, token_shape):
    """Decode VQGAN codebook ids with ``model.vqgan`` and return an NHWC CPU tensor.

    ``token_ids`` is a ``(batch, num_tokens)`` id tensor and ``token_shape`` the
    ``(batch, grid, grid, emb_dim)`` shape handed to ``get_codebook_entry``.
    Runs under ``torch.no_grad()`` so no autograd graph is kept for the decoder.
    """
    with torch.no_grad():
        z_q = model.vqgan.quantize.get_codebook_entry(token_ids.long(), shape=token_shape)
        decoded = model.vqgan.decode(z_q)
    return torch.einsum('nchw->nhwc', decoded).detach().cpu()


def _iterative_decode(model, gt_indices, masked, choice_temperature, num_iter, codebook_size):
    """Predict the masked tokens with confidence-based iterative decoding.

    The loop starts from ``gt_indices`` with the ``masked`` positions replaced
    by ``model.mask_token_label``. For ``num_iter`` steps it samples every
    unknown token, then re-masks the least-confident ones following a cosine
    schedule. It returns a ``(batch, num_tokens)`` long tensor in which the
    visible positions keep their ground-truth ids.
    """
    mask_token_id = model.mask_token_label
    device = gt_indices.device
    num_tokens = gt_indices.shape[1]

    token_indices = gt_indices.long().masked_fill(masked, mask_token_id)
    if not masked.any():
        # Nothing is masked, so there is nothing to predict.
        return token_indices

    sampled_ids = token_indices
    for step in tqdm(range(num_iter), desc='Processing'):
        cur_ids = token_indices.clone()
        unknown_map = cur_ids == mask_token_id

        # Prepend a column holding model.fake_class_label (the sequence the
        # model consumes starts with a class-token slot).
        class_column = torch.full((cur_ids.size(0), 1), model.fake_class_label,
                                  dtype=torch.long, device=device)
        input_ids = torch.cat([class_column, cur_ids], dim=1)
        token_all_mask = input_ids == mask_token_id
        token_drop_mask = torch.zeros_like(input_ids)

        with torch.no_grad():
            hidden = model.token_emb(input_ids)
            for blk in model.blocks:
                hidden = blk(hidden)
            hidden = model.norm(hidden)
            logits = model.forward_decoder(hidden, token_drop_mask, token_all_mask)
            # Drop the class-token slot and keep only codebook logits.
            logits = logits[:, 1:, :codebook_size]

        sampled_ids = torch.distributions.categorical.Categorical(logits=logits).sample()
        # Known tokens are never overwritten.
        sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)

        if step == num_iter - 1:
            # Last step: the re-masking below would be discarded.
            continue

        # Cosine schedule: how many tokens stay masked for the next round.
        ratio = (step + 1) / num_iter
        new_mask_ratio = np.cos(math.pi / 2. * ratio)

        probs = torch.nn.functional.softmax(logits, dim=-1)
        selected_probs = torch.gather(probs, dim=-1, index=sampled_ids.unsqueeze(-1)).squeeze(-1)
        # Known tokens get infinite confidence so they are never re-masked.
        selected_probs = torch.where(unknown_map, selected_probs.double(), np.inf).float()

        mask_len = torch.tensor([np.floor(num_tokens * new_mask_ratio)], dtype=torch.float32, device=device)
        # Keep at least one prediction this round and re-mask at least one token.
        mask_len = torch.maximum(torch.ones(1, device=device),
                                 torch.minimum(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len))

        # NOTE: only the first sample's count is used for the whole batch.
        masking = mask_by_random_topk(mask_len[0], selected_probs, choice_temperature * (1 - ratio))
        token_indices = torch.where(masking, mask_token_id, sampled_ids)

    return sampled_ids


def _save_outputs(named_batches, save_path, set_type, num_iter, start_index):
    """Save every image of each ``(subdir, batch)`` pair as a PNG.

    Files are written to ``save_path/<subdir>/<set_type>_iter<num_iter>_<index>.png``.
    Image ``i`` of a batch gets ``index = start_index + i``. The sub-folders
    are not created here.
    """
    for subdir, batch in named_batches:
        folder = save_path + '/' + subdir
        for i in range(batch.shape[0]):
            filename = '{}_iter{}_{}.png'.format(set_type, num_iter, start_index + i)
            process_and_save_image(batch[i], os.path.join(folder, filename))


def gen_image(imgs, model, batch_index, seed, choice_temperature, num_iter, save_path, set_type, mask_ratio, mask_type,
              batch_size=None):
    """Mask a batch of images, reconstruct it with MAGE, and save visualizations.

    For every image, one PNG is written into each of the sub-folders
    ``original``, ``masked``, ``reconstruction``, ``reconstruction_visible``
    (ground-truth tokens at visible positions, predictions at masked ones) and
    ``reconstruction_visible_new`` (pixel-space paste of original and
    reconstruction). The sub-folders must already exist under ``save_path``.

    Args:
        imgs: ``(N, C, H, W)`` image batch on the model's device.
        model: MAGE model; it must expose ``forward_encoder``, ``token_emb``,
            ``blocks``, ``norm``, ``forward_decoder``, ``vqgan``,
            ``mask_token_label`` and ``fake_class_label``.
        batch_index: index of this batch in its loader; used for file numbering.
        seed: seeds both torch and NumPy for reproducible masking/sampling.
        choice_temperature: Gumbel temperature used when re-masking tokens.
        num_iter: number of iterative decoding steps (>= 1).
        save_path: root output folder.
        set_type: ``'f'`` or ``'r'``; used as the file-name prefix.
        mask_ratio: passed to ``model.forward_encoder`` and ``mask_center_patches``.
        mask_type: ``'random'`` or ``'center'``.
        batch_size: loader batch size used to number files across batches.
            Defaults to the module-level ``bsz`` for backward compatibility.

    Raises:
        ValueError: for an unknown ``mask_type``/``set_type``, ``num_iter < 1``,
            or token ids outside the codebook.
    """
    if mask_type not in ('random', 'center'):
        raise ValueError("mask_type must be 'random' or 'center', got {!r}".format(mask_type))
    if set_type not in ('f', 'r'):
        raise ValueError("set_type must be 'f' or 'r', got {!r}".format(set_type))
    if num_iter < 1:
        raise ValueError("num_iter must be >= 1, got {}".format(num_iter))

    # Presumably disables the model's own training-time masking/dropping;
    # exact semantics are defined in models_mage (TODO confirm).
    model.mask_flag = False
    model.drop_flag = False
    model.mask_type = mask_type

    torch.manual_seed(seed)
    np.random.seed(seed)
    codebook_emb_dim = 256
    codebook_size = 1024
    cur_bsz = imgs.shape[0]

    x = torch.einsum('nchw->nhwc', imgs).detach().cpu()

    # Tokenize the images; the encoder also reports which tokens it masked.
    with torch.no_grad():
        _, gt_indices, _, token_all_mask = model.forward_encoder(imgs, mask_ratio)

    if not torch.all((gt_indices >= 0) & (gt_indices < codebook_size)):
        raise ValueError("gt_indices has values out of range. It should be within [0, {})".format(codebook_size))

    device = gt_indices.device
    gt_indices = gt_indices.long()
    # Column 0 of the encoder mask is the prepended class token; drop it.
    masked = (token_all_mask[:, 1:] == 1).to(device)
    num_tokens = gt_indices.shape[1]
    grid = math.isqrt(num_tokens)
    token_shape = (cur_bsz, grid, grid, codebook_emb_dim)

    sampled_ids = _iterative_decode(model, gt_indices, masked, choice_temperature, num_iter, codebook_size)

    # Full reconstruction from the predicted tokens.
    y = _decode_tokens_nhwc(model, sampled_ids, token_shape)
    # VQ-GAN reconstruction with the visible tokens taken from the ground truth.
    im_paste = _decode_tokens_nhwc(model, torch.where(masked, sampled_ids, gt_indices), token_shape)

    if mask_type == 'random':
        # Visualize the masked input by decoding with masked tokens set to
        # codebook id 1; reusing the encoder's mask keeps the picture
        # consistent with what was actually reconstructed.
        x_mask = _decode_tokens_nhwc(model, gt_indices.masked_fill(masked, 1), token_shape)
        # Pixel-level mask with 1 = visible, matching the paste formula below.
        visible = (~masked).float().reshape(cur_bsz, 1, grid, grid)
        mask = torch.nn.functional.interpolate(visible, size=tuple(imgs.shape[-2:]), mode='nearest')
    else:
        # NOTE(review): the mask from mask_center_patches is assumed to be
        # 4-D with 1 = visible, as the paste formula below implies.
        x_mask, mask = mask_center_patches(imgs, mask_ratio=mask_ratio)
        x_mask = torch.einsum('nchw->nhwc', x_mask).detach().cpu()

    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
    # Original pixels where visible, reconstruction elsewhere.
    im_paste_new = x * mask + y * (1 - mask)

    loader_batch_size = bsz if batch_size is None else batch_size
    _save_outputs(
        [('original', x),
         ('masked', x_mask),
         ('reconstruction', y),
         ('reconstruction_visible', im_paste),
         ('reconstruction_visible_new', im_paste_new)],
        save_path, set_type, num_iter, loader_batch_size * batch_index,
    )


from PIL import Image

# ---- Experiment configuration ----------------------------------------------
config = OmegaConf.load('config/vqgan.yaml').model
bsz = 128      # DataLoader batch size (gen_image also uses it to number output files)
seed = 16
temp = 4.5     # choice temperature passed to gen_image
iters = 6      # NOTE(review): currently unused in this file

mode = 2       # 1 = original MAGE checkpoint, anything else = fine-tuned checkpoint
mask_type = 'center'
mask_ratio = 0.25

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Image folders of the "forget" and "retain" splits (ImageFolder layout).
forget_local_image_path = '/root/autodl-tmp/img2img_unlearning/mae-mage-data/imagenet_forget_100/train'
retain_local_image_path = '/root/autodl-tmp/img2img_unlearning/mae-mage-data/imagenet_retain_100/train'

# Resize every image to 256x256, then convert it to a tensor.
transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])

# One dataset and one unshuffled loader per split (forget first, then retain).
f_img, r_img = (
    datasets.ImageFolder(folder, transform=transform)
    for folder in (forget_local_image_path, retain_local_image_path)
)
f_data_loader, r_data_loader = (
    DataLoader(dataset, batch_size=bsz, shuffle=False)
    for dataset in (f_img, r_img)
)


# ---- Model loading and visualization run ----------------------------------
# Load a MAGE ViT-Large checkpoint (chosen by `mode`), reconstruct the first
# batch of the forget and retain splits, and save images under `save_path`.
# No masking is wanted at inference time.

vqgan_ckpt_path = '/root/autodl-tmp/img2img_unlearning/mage-main/checkpoints/vqgan_jax_strongaug.ckpt'
if mode == 1:
    # Original (pre-unlearning) MAGE checkpoint.
    checkpoint_path = '/root/autodl-tmp/img2img_unlearning/mage-main/checkpoints/magec-vitl-1600.pth'
    save_path = '/root/autodl-tmp/img2img_unlearning/mage-main/figures/gen/origin-mage-large'
else:
    # Fine-tuned checkpoint.
    checkpoint_path = '/root/autodl-tmp/img2img_unlearning/mage-main/output_dir/freeze_1-1encoder-clone-1e-4-one-center-x/checkpoint-3.pth'
    save_path = '/root/autodl-tmp/img2img_unlearning/mage-main/figures/gen/freeze_1-1encoder-clone-1e-4-one-center-x/check3-center'

mage_model = models_mage.__dict__['mage_vit_large_patch16'](norm_pix_loss=False,
                                 mask_ratio_mu=0.55, mask_ratio_std=0.25,
                                 mask_ratio_min=0.50, mask_ratio_max=1.0,
                                 vqgan_ckpt_path=vqgan_ckpt_path).to(device)
# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
mage_model.load_state_dict(checkpoint['model'])
model = mage_model
model.eval()
print('Model loaded.')

# make random mask reproducible (comment out to make it change);
# gen_image additionally reseeds with `seed` on every call.
torch.manual_seed(2)

print('VQ-GAN with pixel reconstruction:')
for subdir in ('original', 'masked', 'reconstruction', 'reconstruction_visible', 'reconstruction_visible_new'):
    os.makedirs(save_path + '/' + subdir, exist_ok=True)

num_batches = 1  # only the first batch of each split is visualized
for set_type, data_loader in (('f', f_data_loader), ('r', r_data_loader)):
    for batch_index, (images, _) in enumerate(data_loader):
        if batch_index >= num_batches:
            # Stop here instead of iterating (and loading) the whole dataset.
            break
        gen_image(images.to(device), model, batch_index, seed, temp, 1, save_path, set_type, mask_ratio, mask_type)

print('MAGE pixel reconstruction finished !')
