
import torch
from torchvision import transforms
# from torchvision.transforms import functional as F
from torch.nn import functional as F
from transformers import AutoTokenizer, PretrainedConfig
from diffusers import AutoencoderKL,  DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers , DDPMScheduler
model_id = "stabilityai/stable-diffusion-x4-upscaler"
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    """Return the text-encoder class named by a checkpoint's config.

    The ``text_encoder`` sub-folder config of the checkpoint is loaded and the
    first entry of its ``architectures`` list selects the class.

    Args:
        pretrained_model_name_or_path: Passed to
            ``PretrainedConfig.from_pretrained`` as the checkpoint location.
        revision: Passed through to ``PretrainedConfig.from_pretrained``.

    Returns:
        ``CLIPTextModel`` or ``RobertaSeriesModelWithTransformation`` (the
        class itself, not an instance).

    Raises:
        ValueError: If the declared architecture is neither of the two above.
    """
    encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
    )
    architecture = encoder_config.architectures[0]

    # Reject unknown architectures first so the happy path reads linearly.
    if architecture not in ("CLIPTextModel", "RobertaSeriesModelWithTransformation"):
        raise ValueError(f"{architecture} is not supported.")

    # The class is imported lazily, only for the architecture actually needed.
    if architecture == "CLIPTextModel":
        from transformers import CLIPTextModel as encoder_cls
    else:
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
            RobertaSeriesModelWithTransformation as encoder_cls,
        )
    return encoder_cls
    
# Load the individual components of the checkpoint named by `model_id`.
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
text_encoder_cls = import_model_class_from_model_name_or_path(model_id, None)
low_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="low_res_scheduler")
text_encoder = text_encoder_cls.from_pretrained(model_id, subfolder="text_encoder")

# %%
tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", use_fast=False)

# Move all three networks to the GPU.
device = "cuda"
vae, unet, text_encoder = [network.to(device) for network in (vae, unet, text_encoder)]

from torchvision import transforms
import numpy as np
def _encode_prompt(prompt):
    """Tokenize ``prompt`` and run it through the text encoder.

    Uses the module-level ``tokenizer``, ``text_encoder`` and ``device``.

    Returns:
        A ``(token_ids, embedding)`` pair: the padded/truncated token ids on
        ``device`` and the first output of the text encoder moved to the CPU.
    """
    token_ids = tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids.to(device)
    # The embeddings are used as constants (they are detached before use), so
    # there is no reason to record an autograd graph for the encoder pass.
    with torch.no_grad():
        embedding = text_encoder(token_ids)[0].to('cpu')
    return token_ids, embedding


# Positive / negative prompts used for the conditional and unconditional
# halves of the guided prediction.
pos_text = 'a photo of a person, high quality, high resolution, clear background, noise free'
pos_token, pos_embedding = _encode_prompt(pos_text)
neg_text = 'noisy, lowres, low quality'
neg_token, neg_embedding = _encode_prompt(neg_text)

# %%
from PIL import Image
from pathlib import Path
def load_data(data_dir, shape = (512, 512)):
    """Load every image found in ``data_dir`` into a single float tensor.

    Files are selected by their final extension (``.jpg``, ``.jpeg``, ``.png``,
    case-insensitive) and ordered by their name with any ``noisy_`` substring
    removed, so that a directory of ``noisy_<name>`` files lines up index by
    index with a directory holding the matching ``<name>`` files.

    Args:
        data_dir: Directory to scan (not recursive).
        shape: ``(width, height)`` every image is resized to. Must be square.

    Returns:
        ``(file_names, images)`` where ``file_names`` is the ordered list of
        file names (relative to ``data_dir``) and ``images`` is a float tensor
        of shape ``(N, 3, H, W)`` holding pixel values in ``[0, 255]``.

    Raises:
        ValueError: If ``data_dir`` contains no matching image file.
    """
    # Imported locally so the function does not depend on a module-level
    # `import os` that happens to run before the first call.
    import os

    import numpy as np

    image_suffixes = {".jpg", ".jpeg", ".png"}
    # Look at the real (last) extension: "a.b.jpg" is an image, "x.png.txt"
    # and the hidden file ".png" are not.
    file_names = [
        name for name in os.listdir(data_dir)
        if os.path.splitext(name)[1].lower() in image_suffixes
    ]
    # Sort by the name without the 'noisy_' marker but keep the real file
    # names, so mixed or unprefixed directories never yield a missing path.
    sorted_img_path = sorted(file_names, key=lambda name: name.replace('noisy_', ''))
    if not sorted_img_path:
        raise ValueError(f"no .jpg/.jpeg/.png images found in {data_dir!r}")
    print(sorted_img_path)

    frames = []
    for name in sorted_img_path:
        # The context manager closes the underlying file once it is decoded.
        with Image.open(os.path.join(data_dir, name)) as handle:
            resized = handle.convert("RGB").resize(shape)
        frames.append(np.asarray(resized, dtype=np.uint8))
    images = np.stack(frames)
    # from B x H x W x C to B x C x H x W
    images = torch.from_numpy(images).permute(0, 3, 1, 2).float()
    # Only square targets are accepted.
    assert images.shape[-1] == images.shape[-2]
    return sorted_img_path, images


import requests
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionUpscalePipeline
import torch

model_id = "stabilityai/stable-diffusion-x4-upscaler"
# The half-precision pipeline is loaded in addition to the individual
# components; this script reads its scheduler, its VAE config and its
# `prepare_latents` helper.
pipeline = StableDiffusionUpscalePipeline.from_pretrained(
    model_id,
    revision="fp16",
    torch_dtype=torch.float16,
)
noise_scheduler = pipeline.scheduler

# Put the networks in eval mode and disable gradients for all their weights.
for frozen_module in (vae, unet, text_encoder):
    frozen_module.eval()
    frozen_module.requires_grad_(False)

from tqdm import tqdm

# step_size = 0.05
import argparse
# Command-line interface. The three directory arguments are required: every
# one of them is used unconditionally further down, so a missing value is
# reported here instead of surfacing as a late error.
parser = argparse.ArgumentParser()
parser.add_argument('--step_scale', type=float, default=0.01,
                    help='Step scale; the per-step size is (perturb_r / 127.5) / (steps * step_scale)')
parser.add_argument('--steps', type=int, default=100, help='Number of steps')
#  --input_dir $1 --output_dir $2 \
parser.add_argument('--input_dir', type=str, required=True,
                    help='Input directory containing data files')
parser.add_argument('--output_dir', type=str, required=True,
                    help='Output directory to save processed files')
parser.add_argument('--clean_img_path', type=str, required=True,
                    help='Directory containing the clean reference images')
parser.add_argument('--no_codeformer', action='store_true',
                    help='whether use codeformer in the process or not')

parser.add_argument('--perturb_r', type=float, default=16,
                    help='Perturbation radius on the 0-255 pixel scale')
args = parser.parse_args()

# Perturbation budget: `noise_r` is in 0-255 pixel units, and dividing by
# `scale` (127.5) expresses it on the [-1, 1] range the images are mapped to.
noise_eps = args.perturb_r/255
scale = 127.5
noise_r = noise_eps * 255
steps = args.steps
step_scale = args.step_scale
step_size = (noise_r/scale) / (steps * step_scale)


from basicsr.utils.registry import ARCH_REGISTRY

import os 
# Download location of the CodeFormer weights.
pretrain_model_url = {
    'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
}
from basicsr.utils.download_util import load_file_from_url
device='cuda'
# Build the CodeFormer network from the basicsr architecture registry.
net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
                                            connect_list=['32', '64', '128', '256']).to(device)

from diffshortcut.generic.tools import get_project_root
ADB_PROJECT_ROOT=get_project_root()
# Fetch the checkpoint into <project root>/weights/CodeFormer/ and load the
# 'params_ema' entry of the checkpoint into the network.
ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'],
                                    model_dir=ADB_PROJECT_ROOT+'/weights/CodeFormer/', progress=True, file_name=None)
checkpoint = torch.load(ckpt_path)['params_ema']
net.load_state_dict(checkpoint)

# Treat CodeFormer like the other networks in this script: inference mode and
# no weight gradients. Gradients still flow through it to its *input*; only
# the unused per-parameter gradients are no longer computed and accumulated.
net.eval()
net.requires_grad_(False)


# loop here 

# Input images and clean reference images are paired by position in their
# sorted file lists, so every input needs a clean counterpart.
img_paths, imgs = load_data(args.input_dir, shape=(512, 512))
hq_imgs_paths, hq_imgs=load_data(args.clean_img_path, shape=(512, 512))
if len(hq_imgs_paths) < len(img_paths):
    # Fail before any optimisation starts rather than with an IndexError
    # part-way through the run.
    raise ValueError(
        f"found {len(img_paths)} input images in {args.input_dir!r} but only "
        f"{len(hq_imgs_paths)} clean images in {args.clean_img_path!r}"
    )



device='cuda'
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
from torchvision.transforms.functional import normalize
from tqdm import tqdm

# Main loop: for every input image, run `steps` signed-gradient updates on the
# pixels (kept within +/- noise_r/scale of the original and within [-1, 1]),
# where the gradient comes from the upscaler UNet's prediction loss against a
# target built from the matching clean image. Each result is written to
# `args.output_dir` as soon as it is ready.

# Create the output directory before any GPU work so a bad path fails early.
os.makedirs(args.output_dir, exist_ok=True)

# ---- loop-invariant setup -------------------------------------------------
noise_scheduler = pipeline.scheduler
# NOTE: despite the name this is a number, used as the classifier-free
# guidance scale in the noise_pred formula below.
do_classifier_free_guidance = 9.0
# Timestep handed to `low_scheduler.add_noise` and passed to the UNet as
# `class_labels`.
noise_level_r = 20
# `w` argument forwarded to the CodeFormer network.
w = 0.5
min_max = (-1, +1)
batch_multiplier = 2
num_channels_latents = pipeline.vae.config.latent_channels
# Largest allowed per-pixel deviation from the original, on the [-1, 1] scale.
max_delta = noise_r / scale

# Text conditioning is constant: [negative, positive] matches the
# (uncond, text) order of `model_pred.chunk(2)` below.
neg_embedding = neg_embedding.detach().clone().to(device)
pos_embedding = pos_embedding.detach().clone().to(device)
prompt_embeds = torch.cat([neg_embedding, pos_embedding])

imgs_list = []
for img_idx, img_name in enumerate(img_paths):
    # Input image: [0, 255] -> [-1, +1].
    tensor = imgs[img_idx].unsqueeze(0).to(device)
    tensor = tensor / 255
    tensor = tensor * 2 - 1
    images = tensor
    ori_image = tensor.detach().clone()

    # Clean reference: same scaling, then VAE-encoded once. It only serves as
    # a constant target, so no autograd graph is recorded for the encoder.
    tensor_hq = hq_imgs[img_idx].unsqueeze(0).to(device)
    tensor_hq = tensor_hq / 255
    tensor_hq = tensor_hq * 2 - 1
    with torch.no_grad():
        tensor_hq_latent = vae.encode(tensor_hq).latent_dist.sample() * vae.config.scaling_factor
    latent_target = tensor_hq_latent.detach().clone()

    for stepi in tqdm(range(steps)):
        # Fresh leaf tensor for this step's gradient.
        images_stepi = images.detach().to(device)
        images_stepi.requires_grad = True

        if args.no_codeformer:
            _tensor = images_stepi + 0.0
        else:
            # Map to [0, 1], then normalise with mean/std 0.5 (back onto the
            # [-1, +1] range) before feeding CodeFormer.
            img = (images_stepi + 1) / 2
            img = normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            img = img.to(device)

            output = net(img, w=w, adain=True)[0]
            # Clamp to [-1, +1]. The 0-255 round trip that follows is kept
            # exactly as before so results stay numerically unchanged.
            _tensor = output.float().clamp_(*min_max)
            _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) * 255.0
            _tensor = torch.clamp(_tensor, 0, 255)
            _tensor = _tensor / 255
            _tensor = _tensor * 2 - 1

        # Downsample to 128x128: this is the low-resolution image that gets
        # concatenated to the latents as UNet input.
        images_stepi_from_codeformer = F.interpolate(_tensor, size=(128, 128), mode='bilinear', align_corners=True)

        noise_level = torch.tensor([noise_level_r], dtype=torch.long, device=device)
        low_res_noise = torch.randn(images_stepi_from_codeformer.shape, device=device, layout=torch.strided)
        images_stepi_corr = low_scheduler.add_noise(images_stepi_from_codeformer, low_res_noise, noise_level)

        # Duplicate for the unconditional / conditional halves of the batch.
        images_stepi_twice = torch.cat([images_stepi_corr] * batch_multiplier * 1)
        class_labels = torch.cat([noise_level] * images_stepi_twice.shape[0])

        height, width = images_stepi_twice.shape[2:]
        latents = pipeline.prepare_latents(
            1 * 1,
            num_channels_latents,
            height,
            width,
            device=device,
            generator=None,
            latents=None,
            dtype=None
        )
        bsz = latents.shape[0]
        timesteps = torch.randint(
            0,
            noise_scheduler.config.num_train_timesteps,
            (bsz,),
            device=latents.device,
        )
        timesteps = timesteps.long()

        noise = torch.randn_like(latents)

        latents = torch.cat([latents] * 2)
        latents = noise_scheduler.add_noise(latents, noise, timesteps)

        latent_model_input = torch.cat([latents, images_stepi_twice], dim=1)

        model_pred = unet(latent_model_input, timesteps, prompt_embeds, class_labels=class_labels,).sample

        noise_pred_uncond, noise_pred_text = model_pred.chunk(2)
        noise_pred = noise_pred_uncond + do_classifier_free_guidance * (noise_pred_text - noise_pred_uncond)

        # Get the target for loss depending on the prediction type
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latent_target, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")

        # Only the gradient w.r.t. the pixels is needed; asking for exactly
        # that avoids accumulating gradients into any network parameter.
        (grad,) = torch.autograd.grad(loss, images_stepi)

        with torch.no_grad():
            images_stepi = images_stepi.detach()
            # Signed-gradient step, then keep pixels in range ...
            images_stepi.add_(torch.sign(grad), alpha=step_size).clamp_(-1, +1)
            # ... and keep the total change within the perturbation budget.
            delta_images = images_stepi - ori_image.detach().to(device)
            delta_images.clamp_(min=-max_delta, max=+max_delta)
            images = (ori_image + delta_images)
            images = torch.clamp(images, -1, 1)
        # Hand cached, currently unused GPU memory back to the driver.
        torch.cuda.empty_cache()

    # [-1, +1] -> [0, 1], CHW -> HWC for saving.
    images_renorm = (images + 1) / 2
    numpy_arr = images_renorm[0].cpu().numpy().transpose(1, 2, 0)
    output_path = os.path.join(args.output_dir, os.path.basename(img_name))
    # Write immediately so finished images survive a failure on a later one.
    plt.imsave(output_path, numpy_arr)
    imgs_list.append(
        (output_path, numpy_arr)
    )