import sys
import os
import requests
import torch
import numpy as np
from torch.utils.data import DataLoader
from PIL import Image
from torchvision import datasets, transforms
from omegaconf import OmegaConf
from model.vqgan import VQModel
from mask import mask_center_patches, mask_verge_patches, mask_top_patches, mask_bottom_patches, mask_left_patches, mask_right_patches
from tqdm import tqdm

import models_mage


# Non-random mask types mapped to their masking functions from the `mask`
# module. gen_image calls the selected function as fn(imgs, mask_ratio=...)
# and unpacks the result as (masked_images, mask).
function_map = dict(
    center=mask_center_patches,
    verge=mask_verge_patches,
    top=mask_top_patches,
    bottom=mask_bottom_patches,
    left=mask_left_patches,
    right=mask_right_patches,
)


def process_and_save_image(image, path):
    """Write one image tensor to `path` as an 8-bit image.

    The tensor is multiplied by 255, clipped to [0, 255], truncated to
    integers and cast to uint8 before being handed to PIL. It must be a CPU
    tensor that does not require grad, because `.numpy()` is called on it.
    Values are presumably in [0, 1] (ToTensor range) — TODO confirm for
    callers outside this file. In this file callers pass one (H, W, C) slice
    of a channels-last batch.
    """
    pixels = torch.clamp(image * 255, 0, 255).int()  # .int() truncates
    array_u8 = pixels.numpy().astype(np.uint8)
    Image.fromarray(array_u8).save(path)
    

# Output sub-folder name for each set_type code accepted by gen_image.
_SET_TYPE_DIRS = {'f': 'forget', 'r': 'retain'}


def _tokens_to_images(model, token_ids, cur_bsz, codebook_emb_dim=256, grid_size=16):
    """Decode VQ-GAN token ids into a channels-last (N, H, W, C) CPU image batch.

    token_ids is passed to get_codebook_entry with
    shape=(cur_bsz, grid_size, grid_size, codebook_emb_dim); presumably it is a
    flattened (cur_bsz, grid_size * grid_size) grid — TODO confirm against VQModel.
    """
    z_q = model.vqgan.quantize.get_codebook_entry(
        token_ids.long(), shape=(cur_bsz, grid_size, grid_size, codebook_emb_dim))
    return torch.einsum('nchw->nhwc', model.vqgan.decode(z_q)).detach().cpu()


def _sample_tokens(logits, sample_mode):
    """Pick one codebook id per position from logits of shape (..., vocab)."""
    if sample_mode == 'argmax':
        return logits.argmax(dim=-1)
    if sample_mode == 'random':
        return torch.distributions.categorical.Categorical(logits=logits).sample()
    raise ValueError("sample_mode must be 'argmax' or 'random', got {!r}".format(sample_mode))


def _keep_mask_from_tokens(token_mask, height, width, grid_size=16):
    """Expand a per-token mask (1 = masked) to an (N, 1, height, width) pixel mask with 1 = visible."""
    n = token_mask.shape[0]
    keep = (1 - token_mask.float()).reshape(n, 1, grid_size, grid_size)
    keep = keep.repeat_interleave(height // grid_size, dim=2)
    return keep.repeat_interleave(width // grid_size, dim=3)


def gen_image(imgs, model, batch_index, seed, save_path, set_type, mask_ratio, mask_type,
              sample_mode='argmax', batch_size=None, save_outputs=('reconstruction_visible',)):
    """Reconstruct a batch with MAGE and save the selected outputs as PNG files.

    Args:
        imgs: NCHW image batch on the model's device (the loaders in this
            script produce 256x256 ToTensor images).
        model: MAGE model. Its mask_flag, drop_flag and mask_type attributes
            are overwritten.
        batch_index: index of this batch in its DataLoader, used to number files.
        seed: seed applied to torch and numpy before the batch is processed.
        save_path: root directory. Files are written to
            <save_path>/<output>/<forget|retain>/<index>.png.
        set_type: 'f' (forget split) or 'r' (retain split).
        mask_ratio: passed to model.forward_encoder and, for non-random mask
            types, to the masking function.
        mask_type: 'random' or a key of function_map.
        sample_mode: 'argmax' (default, previous behaviour) or 'random'
            (sample from the predicted logits).
        batch_size: DataLoader batch size used to compute file indices.
            Defaults to the module-level `bsz`.
        save_outputs: which outputs to write. Any of 'original', 'masked',
            'reconstruction', 'reconstruction_visible',
            'reconstruction_visible_new'.

    Raises:
        ValueError: unknown set_type or sample_mode, or encoder token ids
            outside [0, 1024).
    """
    if set_type not in _SET_TYPE_DIRS:
        raise ValueError("set_type must be 'f' or 'r', got {!r}".format(set_type))
    if batch_size is None:
        batch_size = bsz  # module-level DataLoader batch size

    # Inference configuration for the MAGE model.
    model.mask_flag = False
    model.drop_flag = False
    model.mask_type = mask_type

    torch.manual_seed(seed)
    np.random.seed(seed)
    codebook_size = 1024
    cur_bsz = imgs.shape[0]
    height, width = imgs.shape[2], imgs.shape[3]

    x = torch.einsum('nchw->nhwc', imgs).detach().cpu()

    # Pure inference: keep every model call, including the VQ-GAN decodes, out of autograd.
    with torch.no_grad():
        latent, gt_indices, token_drop_mask, token_all_mask = model.forward_encoder(imgs, mask_ratio)

        if not torch.all((gt_indices >= 0) & (gt_indices < codebook_size)):
            raise ValueError("gt_indices has values out of range. It should be within [0, {})".format(codebook_size))

        # Drop nothing at inference time; token_all_mask alone marks what must be predicted.
        token_drop_mask = torch.zeros_like(token_all_mask)
        logits = model.forward_decoder(latent, token_drop_mask, token_all_mask)
        # Skip position 0 (presumably a class token — TODO confirm) and any
        # vocabulary entries beyond the VQ-GAN codebook.
        logits = logits[:, 1:, :codebook_size]
        sampled_ids = _sample_tokens(logits, sample_mode)

        # Reconstruction from predicted tokens at every position.
        y = _tokens_to_images(model, sampled_ids, cur_bsz)

        token_mask = token_all_mask[:, 1:]  # 1 = masked token
        if mask_type == 'random':
            # A random pixel mask could be drawn here, but it would not be
            # guaranteed to match the one used during restoration. So the
            # masked view is rebuilt from the encoder's own token mask: masked
            # positions get codebook id 1.
            mask_indices = gt_indices.clone()
            mask_indices[token_mask == 1] = 1
            x_mask = _tokens_to_images(model, mask_indices, cur_bsz)
            # Previously `mask` was undefined on this path (NameError below).
            mask = _keep_mask_from_tokens(token_mask.cpu(), height, width)
        else:
            x_mask, mask = function_map[mask_type](imgs, mask_ratio=mask_ratio)
            x_mask = torch.einsum('nchw->nhwc', x_mask).detach().cpu()

        # Pixel-space paste: original pixels where mask == 1, prediction elsewhere.
        mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
        im_paste_new = x * mask + y * (1 - mask)

        # Token-space paste: ground-truth ids where visible, predicted ids where masked.
        im_paste_id = torch.where(token_mask == 1, sampled_ids, gt_indices.long())
        im_paste = _tokens_to_images(model, im_paste_id, cur_bsz)

    outputs = {
        'original': x,
        'masked': x_mask,
        'reconstruction': y,
        'reconstruction_visible': im_paste,
        'reconstruction_visible_new': im_paste_new,
    }
    split_dir = _SET_TYPE_DIRS[set_type]
    for i in range(cur_bsz):
        index = batch_size * batch_index + i
        for name in save_outputs:
            process_and_save_image(
                outputs[name][i],
                os.path.join('{}/{}/{}'.format(save_path, name, split_dir), '{}.png'.format(index)))

# Loaded but not referenced below; kept so a missing/invalid config still fails early.
config = OmegaConf.load('config/vqgan.yaml').model
bsz = 128  # DataLoader batch size; gen_image also reads this global to number output files
seed = 16  # gen_image re-seeds torch/numpy with this before every batch

mode = 2  # 1: original pre-trained MAGE checkpoint; any other value: fine-tuned checkpoint
mask_type = 'center'  # 'random' or a key of function_map
mask_ratio = 0.25

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Local image folders (torchvision ImageFolder layout: one sub-folder per class).
forget_local_image_path = '/root/autodl-tmp/img2img_unlearning/mae-mage-data/imagenet_forget_100/train'
retain_local_image_path = '/root/autodl-tmp/img2img_unlearning/mae-mage-data/imagenet_retain_100/train'

# Resize every image to 256x256 and convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

f_img = datasets.ImageFolder(forget_local_image_path, transform=transform)
r_img = datasets.ImageFolder(retain_local_image_path, transform=transform)

# shuffle=False keeps gen_image's batch_index -> file index mapping stable across runs.
f_data_loader = DataLoader(f_img, batch_size=bsz, shuffle=False)
r_data_loader = DataLoader(r_img, batch_size=bsz, shuffle=False)

MAGE_ROOT = '/root/autodl-tmp/img2img_unlearning/mage-main'
vqgan_ckpt_path = MAGE_ROOT + '/checkpoints/vqgan_jax_strongaug.ckpt'

# The output folder suffix is derived from the mask settings. It used to be
# hard-coded as 'bottom-0.25' regardless of mask_type, which mislabeled results.
if mode == 1:
    checkpoint_path = MAGE_ROOT + '/checkpoints/magec-vitl-1600.pth'
    save_path = '{}/figures/Argmax-origin-mage-large/{}-{}'.format(MAGE_ROOT, mask_type, mask_ratio)
else:
    run_name = 'freeze_1-1encoder-clone-1e-4-all-6.3-random-x'
    checkpoint_path = '{}/output_dir/{}/checkpoint-3.pth'.format(MAGE_ROOT, run_name)
    save_path = '{}/figures/{}/Argmax-check3-{}-{}'.format(MAGE_ROOT, run_name, mask_type, mask_ratio)


def _load_mage_model(ckpt_path):
    """Build MAGE ViT-L/16, load checkpoint['model'] from ckpt_path, and return it on `device` in eval mode."""
    net = models_mage.__dict__['mage_vit_large_patch16'](
        norm_pix_loss=False,
        mask_ratio_mu=0.55, mask_ratio_std=0.25,
        mask_ratio_min=0.50, mask_ratio_max=1.0,
        vqgan_ckpt_path=vqgan_ckpt_path,
    ).to(device)  # previously .cuda(), which crashed on CPU-only machines
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    net.load_state_dict(checkpoint['model'])
    net.eval()
    return net


def _make_output_dirs(root):
    """Create <root>/<output>/<forget|retain> for every output kind gen_image can write."""
    for kind in ('original', 'masked', 'reconstruction',
                 'reconstruction_visible', 'reconstruction_visible_new'):
        for split in ('forget', 'retain'):
            os.makedirs('{}/{}/{}'.format(root, kind, split), exist_ok=True)


def _run_split(data_loader, set_type, net, out_root, start, stop):
    """Run gen_image on batches with start <= batch_index < stop (stop=None: to the end)."""
    for batch_index, (images, _labels) in tqdm(enumerate(data_loader), total=len(data_loader)):
        if batch_index < start:
            continue
        if stop is not None and batch_index >= stop:
            break
        gen_image(images.to(device), net, batch_index, seed, out_root, set_type, mask_ratio, mask_type)


model = _load_mage_model(checkpoint_path)
print('Model loaded.')

# Seed the global RNG before iterating (gen_image additionally re-seeds per batch).
torch.manual_seed(2)

print('VQ-GAN with pixel reconstruction:')
_make_output_dirs(save_path)

# flag=True: only process batches with min_num_batch <= batch_index < max_num_batch.
flag = True
min_num_batch = 11
max_num_batch = 12
first_batch, stop_batch = (min_num_batch, max_num_batch) if flag else (0, None)

_run_split(f_data_loader, 'f', model, save_path, first_batch, stop_batch)
_run_split(r_data_loader, 'r', model, save_path, first_batch, stop_batch)

print('MAGE pixel reconstruction finished !')
