
import torch
from torchvision import transforms
# from torchvision.transforms import functional as F
from torch.nn import functional as F
from transformers import AutoTokenizer, PretrainedConfig

from torchvision import transforms
import numpy as np

# %%
from PIL import Image
from pathlib import Path
def load_data(data_dir, shape=(512, 512)):
    """Load every image in ``data_dir`` into one float tensor batch.

    A file is loaded when its (case-insensitive) extension is ``.jpg``,
    ``.jpeg`` or ``.png`` and its name does not start with ``.`` (hidden
    files / macOS ``._*`` resource forks are skipped). Files are ordered by
    their name with any ``noisy_`` prefix removed, so a directory of
    ``noisy_<name>`` files sorts in the same order as a directory holding the
    matching ``<name>`` files.

    Args:
        data_dir: directory to scan (not recursive).
        shape: size passed to ``PIL.Image.resize``, i.e. ``(width, height)``.

    Returns:
        ``(names, images)``: ``names`` are the file names exactly as they
        exist on disk, in load order; ``images`` is a float32 tensor of shape
        ``(N, 3, height, width)`` holding 0-255 pixel values.

    Raises:
        ValueError: if ``data_dir`` contains no supported image file.
        AssertionError: if the resized images are not square.
    """
    import os
    import numpy as np

    image_exts = {".jpg", ".jpeg", ".png"}
    names = [
        name
        for name in os.listdir(data_dir)
        if not name.startswith(".") and os.path.splitext(name)[1].lower() in image_exts
    ]
    if not names:
        raise ValueError(f"no .jpg/.jpeg/.png images found in {data_dir!r}")

    # Sort on the name without the 'noisy_' prefix so noisy and clean
    # directories line up index-by-index; the real name breaks ties.
    sorted_img_path = sorted(names, key=lambda name: (name.replace("noisy_", ""), name))
    print(sorted_img_path)

    arrays = []
    for name in sorted_img_path:
        # Context manager closes the underlying file handle once pixels are read.
        with Image.open(os.path.join(data_dir, name)) as img:
            arrays.append(np.asarray(img.convert("RGB").resize(shape), dtype=np.uint8))

    # from B x H x W x C to B x C x H x W
    images = torch.from_numpy(np.stack(arrays)).permute(0, 3, 1, 2).float()
    assert images.shape[-1] == images.shape[-2]
    return sorted_img_path, images


import requests
from PIL import Image
import torch 
from tqdm import tqdm

# step_size = 0.05
import argparse
# Command-line interface, e.g.:
#   python <script> --input_dir <images> --clean_img_path <clean_images> \
#       --output_dir <out> [--perturb_r 16] [--steps 100] [--step_scale 0.01]
parser = argparse.ArgumentParser()
parser.add_argument('--step_scale', type=float, default=0.01,
                    help='Step-size divisor: step = (perturb_r / 127.5) / (steps * step_scale)')
parser.add_argument('--steps', type=int, default=100, help='Number of steps')
parser.add_argument('--input_dir', type=str, required=True,
                    help='Input directory containing data files')
parser.add_argument('--output_dir', type=str, required=True,
                    help='Output directory to save processed files')
parser.add_argument('--clean_img_path', type=str, required=True,
                    help='Directory containing the clean reference images')
parser.add_argument('--perturb_r', type=float, default=16,
                    help='L-inf perturbation budget in 0-255 pixel units')
args = parser.parse_args()
# Both values appear in the step_size denominator below; reject them up front
# instead of failing with a ZeroDivisionError.
if args.steps <= 0:
    parser.error('--steps must be a positive integer')
if args.step_scale <= 0:
    parser.error('--step_scale must be positive')

noise_eps = args.perturb_r / 255
# Half of the 0-255 range: divides a pixel-unit delta into the [-1, 1] model range.
scale = 127.5
noise_r = noise_eps * 255
steps = args.steps
step_scale = args.step_scale
step_size = (noise_r / scale) / (steps * step_scale)


from basicsr.utils.registry import ARCH_REGISTRY

import os 
pretrain_model_url = {
    'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
}
from basicsr.utils.download_util import load_file_from_url
device = 'cuda'
# Architecture hyper-parameters must match the checkpoint loaded below
# (load_state_dict is strict by default).
net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
                                      connect_list=['32', '64', '128', '256']).to(device)
from diffshortcut.generic.tools import get_project_root
ADB_PROJECT_ROOT = get_project_root()
# Checkpoint lives under <project root>/weights/CodeFormer; load_file_from_url
# fetches it from the release URL (presumably only when not already cached).
ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'],
                               model_dir=os.path.join(ADB_PROJECT_ROOT, 'weights', 'CodeFormer'),
                               progress=True, file_name=None)
# Load on CPU so the file opens regardless of the device it was saved from;
# load_state_dict then copies the values into the parameters already on `device`.
checkpoint = torch.load(ckpt_path, map_location='cpu')['params_ema']
net.load_state_dict(checkpoint)
# The network is only used for inference (the attack optimises input pixels),
# so run it in eval mode as is standard for pretrained models.
net.eval()


# loop here 

img_paths, imgs = load_data(args.input_dir, shape=(512, 512))
hq_imgs_paths, hq_imgs = load_data(args.clean_img_path, shape=(512, 512))
# The perturbation loop pairs images by sorted position (imgs[i] with
# hq_imgs[i]); fail now rather than with an IndexError part-way through.
if len(hq_imgs_paths) < len(img_paths):
    raise ValueError(
        f"{args.clean_img_path!r} has {len(hq_imgs_paths)} images but "
        f"{args.input_dir!r} has {len(img_paths)}; every input needs a clean counterpart"
    )



device = 'cuda'
from torchvision.transforms.functional import normalize
# Imported at module level: the save loop at the end of the script uses plt too.
import matplotlib.pyplot as plt

# Only the input pixels are optimised. Freezing the network stops every
# backward pass from computing weight gradients that were otherwise
# accumulated into the parameters' .grad on each step and never used.
net.requires_grad_(False)
# Create the destination before the (long) optimisation so a bad path fails fast.
os.makedirs(args.output_dir, exist_ok=True)

fidelity_w = 0.5              # CodeFormer fidelity weight passed as `w`
max_delta = noise_r / scale   # L-inf budget expressed in the [-1, 1] model range


def _to_model_range(pixels):
    """Turn a 0-255 (C, H, W) tensor into a (1, C, H, W) batch on `device` in [-1, 1]."""
    return pixels.unsqueeze(0).to(device) / 255 * 2 - 1


imgs_list = []
for img_idx, img_path in enumerate(img_paths):
    # Starting point of the attack and fixed centre of the L-inf ball.
    ori_image = _to_model_range(imgs[img_idx])
    # Reference the restored output is compared against; constant, so no grad.
    tensor_hq = _to_model_range(hq_imgs[img_idx])
    images = ori_image.clone()

    for _ in tqdm(range(steps)):
        images_stepi = images.detach().requires_grad_(True)

        # CodeFormer's input preprocessing: [-1, 1] -> [0, 1] -> normalize(0.5, 0.5).
        img = normalize((images_stepi + 1) / 2, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        output = net(img, w=fidelity_w, adain=True)[0]
        # Restrict the restoration to the valid [-1, 1] range before comparing.
        restored = output.float().clamp(-1, 1)

        loss = F.mse_loss(restored, tensor_hq.float(), reduction="mean")
        # Gradient w.r.t. the input only; nothing is accumulated into .grad.
        (grad,) = torch.autograd.grad(loss, images_stepi)

        with torch.no_grad():
            # Step along sign(grad) (raises the loss for a positive step_size),
            # then project back into the L-inf ball around ori_image and the
            # valid [-1, 1] range.
            stepped = (images_stepi + step_size * grad.sign()).clamp(-1, 1)
            delta = (stepped - ori_image).clamp(-max_delta, max_delta)
            images = (ori_image + delta).clamp(-1, 1)

    # [-1, 1] -> [0, 1]; take the single batch element as an H x W x C array.
    numpy_arr = ((images + 1) / 2)[0].cpu().numpy().transpose(1, 2, 0)
    output_path = os.path.join(args.output_dir, os.path.basename(img_path))
    imgs_list.append((output_path, numpy_arr))
    
# Write every perturbed image collected by the optimisation loop to disk.
# plt is imported here so this step does not depend on an import made
# inside the loop body.
import matplotlib.pyplot as plt

os.makedirs(args.output_dir, exist_ok=True)
for output_path, pixels in imgs_list:
    plt.imsave(output_path, pixels)