from tqdm import tqdm
import torch
from PIL import Image
import os
import numpy as np
import argparse
from ip2p_utils import InstructPix2Pix, seed_everything
import clip
import torch.nn.functional as F
import glob
import torchvision



def image_grid(imgs, rows, cols):
    """
    Tile a list of PIL images into a single RGB sheet.

    The images are laid out in row-major order on a ``rows`` x ``cols``
    grid. The cell size is taken from the first image, so all images are
    expected to share that size. ``len(imgs)`` must equal ``rows * cols``.
    """
    assert len(imgs) == rows*cols

    cell_w, cell_h = imgs[0].size
    sheet = Image.new('RGB', size=(cols * cell_w, rows * cell_h))

    for index, tile in enumerate(imgs):
        # divmod yields (row, column) of the cell this image belongs to
        row, col = divmod(index, cols)
        sheet.paste(tile, box=(col * cell_w, row * cell_h))
    return sheet

def pt_to_numpy(images):
    """
    Convert a 4-D pytorch tensor (N, C, H, W) to a float32 numpy array
    laid out as (N, H, W, C).

    The input is detached and cloned first, so the returned array never
    shares memory with ``images``.
    """
    detached = images.detach().clone()
    channels_last = detached.cpu().permute(0, 2, 3, 1)
    return channels_last.float().numpy()

def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a list of PIL images.

    ``images`` is either one (H, W, C) image or a batch (N, H, W, C) with
    values nominally in [0, 1]. Values outside that range are clipped
    before the uint8 conversion. Single-channel inputs become mode "L"
    images; everything else is handed to ``Image.fromarray`` unchanged.
    """
    if images.ndim == 3:
        images = images[None, ...]
    # Clip first: casting out-of-range floats to uint8 wraps around (or is
    # platform dependent) and would show up as speckle artefacts.
    images = (np.clip(images, 0.0, 1.0) * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images; squeeze only
        # the channel axis so that a spatial axis of length 1 survives
        pil_images = [Image.fromarray(image.squeeze(-1), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]

    return pil_images


def numpy_to_pil_rgb(images):
    """
    Convert a numpy image or a batch of images to a list of PIL images,
    converting multi-channel images to mode "RGB".

    ``images`` is either one (H, W, C) image or a batch (N, H, W, C) with
    values nominally in [0, 1]. Values outside that range are clipped
    before the uint8 conversion. Note that single-channel inputs are
    returned as mode "L" images and are NOT converted to RGB.
    """
    if images.ndim == 3:
        images = images[None, ...]
    # Clip first: casting out-of-range floats to uint8 wraps around (or is
    # platform dependent) and would show up as speckle artefacts.
    images = (np.clip(images, 0.0, 1.0) * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images; squeeze only
        # the channel axis so that a spatial axis of length 1 survives
        pil_images = [Image.fromarray(image.squeeze(-1), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image).convert("RGB") for image in images]

    return pil_images

def PIL2Tensor(image, resize=True, size=(512, 512)):
    """
    Convert a sequence of images to a float32 tensor of shape (N, 3, H, W)
    with values in [0, 1].

    Args:
        image: iterable of PIL images (array-likes accepted by
            ``Image.fromarray`` are also supported).
        resize: when True every image is resized to ``size``; when False
            the images keep their size and must all share it, otherwise
            stacking fails.
        size: target (width, height) used when ``resize`` is True.

    Raises:
        ValueError: if ``image`` is empty.
    """
    image = list(image)
    if not image:
        raise ValueError("PIL2Tensor needs at least one image")

    frames = []
    for item in image:
        # Convert PIL images directly: a round trip through a numpy array
        # drops the palette of mode "P" images and decodes them as gray
        # index maps instead of their real colours.
        if isinstance(item, Image.Image):
            pil = item
        else:
            pil = Image.fromarray(np.asarray(item))
        pil = pil.convert("RGB")
        if resize:
            pil = pil.resize(size)
        frames.append(np.asarray(pil))

    batch = np.stack(frames, axis=0).astype(np.float32) / 255.0
    # (N, H, W, C) -> (N, C, H, W)
    batch = batch.transpose(0, 3, 1, 2)
    return torch.from_numpy(batch)

def rbf_kernel(X, Y, gamma=-1, ad=1):
    """
    Compute an RBF kernel matrix between two batches and its gradient.

    Args:
        X: tensor of shape (Bx, ...); each row is one sample.
        Y: tensor of shape (By, ...) with the same trailing shape as X.
        gamma: kernel coefficient in ``exp(-gamma * ||x - y||^2)``. A
            negative value selects the median heuristic, which derives
            gamma from the median pairwise squared distance.
        ad: multiplicative adjustment applied to gamma.

    Returns:
        K: (Bx, By) kernel matrix.
        dK_dX: tensor shaped like X whose i-th entry is
            ``sum_j d K[i, j] / d X[i]``.
    """
    # reshape (rather than view) also accepts non-contiguous inputs
    X_flat = X.reshape(X.size(0), -1)
    Y_flat = Y.reshape(Y.size(0), -1)

    # Compute the pairwise squared Euclidean distances between the samples.
    # CUDA autocast is only requested for CUDA inputs, so CPU-only runs do
    # not trigger the "CUDA is not available" warning.
    with torch.autocast(device_type="cuda", enabled=X.is_cuda):
        dists = torch.cdist(X_flat, Y_flat, p=2) ** 2

    if gamma < 0:  # use median trick
        # With one or two samples (or duplicated samples) the median
        # distance is 0, which would give gamma = inf and NaNs in K.
        # Flooring the median keeps the result finite and leaves every
        # non-degenerate case unchanged.
        median = torch.median(dists).clamp_min(torch.finfo(dists.dtype).eps)
        bandwidth = torch.sqrt(0.5 * median / np.log(dists.size(0) + 1))
        gamma = 1 / (2 * bandwidth ** 2)

    gamma = gamma * ad
    # Compute the RBF kernel using the squared distances and gamma
    K = torch.exp(-gamma * dists)
    # Append one singleton axis per feature dimension so K broadcasts
    # against the (Bx, By, ...) tensor of pairwise differences.
    K_expanded = K.reshape(K.shape + (1,) * (X.dim() - 1))
    dK = -2 * gamma * K_expanded * (X.unsqueeze(1) - Y.unsqueeze(0))
    dK_dX = torch.sum(dK, dim=1)

    return K, dK_dX


def _combine_ip2p_guidance(noise_pred, guidance_scale, image_guidance_scale):
    """Collapse a stacked (text, image, unconditional) prediction into one guided estimate."""
    pred_text, pred_image, pred_uncond = noise_pred.chunk(3)
    return (
        pred_uncond
        + guidance_scale * (pred_text - pred_image)
        + image_guidance_scale * (pred_image - pred_uncond)
    )


def _save_latents_grid(diff_model, latents, size, rows, cols, path):
    """Decode ``latents``, resize every frame to ``size`` and save them as one grid image."""
    # Decoding is for visualisation only; without no_grad the decode would be
    # recorded by autograd because the latents require grad.
    with torch.no_grad():
        rgb = diff_model.latents_to_img(latents)
    frames = numpy_to_pil(pt_to_numpy(rgb))
    frames = [frame.resize(size) for frame in frames]
    image_grid(frames, rows=rows, cols=cols).save(path)


def main(args):
    """
    Optimise the latents of eight frames with InstructPix2Pix noise predictions.

    Frames ``00000`` .. ``00007`` are read from ``args.data_path`` (prefixed
    with ``frame_`` when ``args.frame`` is set), encoded to latents and updated
    with SGD for ``args.num_steps`` steps. The update direction is the
    difference between the guided target prediction and either the guided
    source prediction or, with ``args.noise``, the injected noise. With
    ``args.svgd`` the direction is additionally mixed across frames with an
    RBF kernel.

    Outputs: ``org.jpg`` (decoded initial latents) is written whenever
    ``args.save_path`` is given; intermediate and final grids are written when
    ``args.save_image`` is set.

    Raises:
        ValueError: if the requested grid or prompt configuration does not
            match the number of frames, or if saving is requested without a
            save path.
    """
    seed_everything(args.seed)

    frame_idx = [f'{i:05d}' for i in range(8)]

    # Fail fast on configuration errors, before the model is loaded.
    if args.save_image and args.save_path is None:
        raise ValueError("--save_image requires --save_path")
    writes_grids = args.save_image or args.save_path is not None
    if writes_grids and args.rows * args.cols != len(frame_idx):
        raise ValueError(
            f"rows*cols ({args.rows}*{args.cols}) must equal the number of "
            f"frames ({len(frame_idx)})"
        )

    diff_model = InstructPix2Pix(args.device, fp16=args.fp16)

    # load data
    prefix = 'frame_' if args.frame else ''
    tgt_images = [
        Image.open(os.path.join(args.data_path, f'{prefix}{idx}.' + args.data_format))
        for idx in frame_idx
    ]
    # Decoded frames are resized back to the size of the last input frame.
    src_size = tgt_images[-1].size

    # Source and target start from the same pixels; convert once and copy.
    tgt_batch = PIL2Tensor(tgt_images).to(args.device)
    src_batch = tgt_batch.clone()

    if args.fp16:
        tgt_batch = tgt_batch.half()
        src_batch = src_batch.half()

    batch_size = tgt_batch.size(0)

    # Get text embeddings
    if args.use_dir:
        dir_prompts = [", side view", ", front view", ", back view", ", side view"]
        neg_dir_prompts = ["front view, back view",
                    "back view, side view",
                    "front view, side view",
                    "front view, back view"]
        if len(dir_prompts) != batch_size:
            # One prompt per frame is required; otherwise the text embeddings
            # and the latent batch passed to the UNet have different sizes.
            raise ValueError(
                f"--use_dir defines {len(dir_prompts)} directional prompts "
                f"but {batch_size} frames were loaded"
            )
        src_prompts = [args.src_prompt+dir_p for dir_p in dir_prompts]
        tgt_prompts = [args.tgt_prompt+dir_p for dir_p in dir_prompts]
        neg_prompts = [args.neg_prompt+neg_p for neg_p in neg_dir_prompts]
    else:
        src_prompts = [args.src_prompt] * batch_size
        tgt_prompts = [args.tgt_prompt] * batch_size
        neg_prompts = [args.neg_prompt] * batch_size

    src_text_embeddings = diff_model.pipe._encode_prompt(
        src_prompts, device=args.device, num_images_per_prompt=1,
        do_classifier_free_guidance=True, negative_prompt=neg_prompts
    )
    tgt_text_embeddings = diff_model.pipe._encode_prompt(
        tgt_prompts, device=args.device, num_images_per_prompt=1,
        do_classifier_free_guidance=True, negative_prompt=neg_prompts
    )

    # initialize src and tgt latents
    src_latents = diff_model.imgs_to_latent(src_batch).detach().clone().to(args.device)
    src_latents.requires_grad = False

    tgt_latents = diff_model.imgs_to_latent(tgt_batch).detach().clone().to(args.device)
    tgt_latents.requires_grad = True

    # Image conditioning for the three stacked branches consumed by
    # _combine_ip2p_guidance: source latents twice, then zeros.
    uncond_image_latents = torch.zeros_like(tgt_latents)
    image_cond_latents = torch.cat([src_latents, src_latents, uncond_image_latents], dim=0)

    # Reference decode of the untouched latents. Skipped when no save path was
    # given (os.path.join(None, ...) would raise a TypeError).
    if args.save_path is not None:
        os.makedirs(args.save_path, exist_ok=True)
        _save_latents_grid(diff_model, tgt_latents, src_size, args.rows, args.cols,
                           os.path.join(args.save_path, 'org.jpg'))

    # optimizer
    optimizer = torch.optim.SGD([tgt_latents], lr=args.lr, weight_decay=args.wd)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=args.decay_iter,
                                                gamma=args.decay_rate)

    for step in tqdm(range(args.num_steps)):
        optimizer.zero_grad()
        # sample timestep (one shared timestep for the whole batch)
        t = torch.randint(args.min_step, args.max_step + 1,
                          [1], dtype=torch.long, device=args.device)

        with torch.no_grad():
            # add the same noise to source and target latents
            noise = torch.randn_like(src_latents)
            src_latents_noisy = diff_model.scheduler.add_noise(src_latents, noise, t)
            tgt_latents_noisy = diff_model.scheduler.add_noise(tgt_latents, noise, t)

            src_model_input = torch.cat([src_latents_noisy] * 3)
            src_model_input = torch.cat([src_model_input, image_cond_latents], dim=1)
            src_noise_pred = diff_model.unet(
                src_model_input, t, encoder_hidden_states=src_text_embeddings).sample

            tgt_model_input = torch.cat([tgt_latents_noisy] * 3)
            tgt_model_input = torch.cat([tgt_model_input, image_cond_latents], dim=1)
            tgt_noise_pred = diff_model.unet(
                tgt_model_input, t, encoder_hidden_states=tgt_text_embeddings).sample

        # perform guidance (high scale from paper!)
        src_noise_pred = _combine_ip2p_guidance(
            src_noise_pred, args.guidance_scale, args.image_guidance_scale)
        tgt_noise_pred = _combine_ip2p_guidance(
            tgt_noise_pred, args.guidance_scale, args.image_guidance_scale)

        w_t = (1 - diff_model.alphas[t])
        # Update direction: target prediction relative to the injected noise
        # (--noise) or relative to the source prediction.
        if args.noise:
            direction = tgt_noise_pred - noise
        else:
            direction = tgt_noise_pred - src_noise_pred

        if args.svgd:
            with torch.cuda.amp.autocast():
                K, dK_dX = rbf_kernel(tgt_latents_noisy, tgt_latents_noisy, gamma=-1, ad=0.5)
                # Kernel-weighted mix of the per-frame directions over the
                # batch axis, plus the kernel-gradient term.
                svgd_direction = torch.matmul(direction.transpose(0,3), K).transpose(0,3) + dK_dX

                grad = w_t * svgd_direction / K.size(0)
        else:
            grad = w_t * direction

        # tgt_latents is a leaf tensor, so this stores ``grad`` in its .grad
        # for the optimizer step.
        tgt_latents.backward(gradient=grad, retain_graph=True)

        optimizer.step()
        scheduler.step()

        if args.save_image:
            if step > 0 and step % args.save_freq == 0:
                _save_latents_grid(diff_model, tgt_latents, src_size, args.rows, args.cols,
                                   os.path.join(args.save_path, f'{step}.jpg'))

    if args.save_image:
        _save_latents_grid(diff_model, tgt_latents, src_size, args.rows, args.cols,
                           os.path.join(args.save_path, f'{args.save_name}.jpg'))

if __name__ == '__main__':
    # Command-line entry point: parse the options, derive the CUDA device
    # string and the output file name, then run the optimisation.
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=42)
    # CUDA device index; turned into 'cuda:<index>' below.
    parser.add_argument("--device", type=int, default=3)
    parser.add_argument("--src_prompt", type=str, default="")
    parser.add_argument("--tgt_prompt", type=str, required=True)
    parser.add_argument("--neg_prompt", type=str, default="")
    parser.add_argument("--eval_prompt", type=str, default="")
    parser.add_argument("--min_step", type=int, default=20)
    parser.add_argument("--max_step", type=int, default=500)
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--wd", type=float, default=0.0)
    parser.add_argument("--guidance_scale", type=float, default=7.5)
    parser.add_argument("--image_guidance_scale", type=float, default=1.5)
    parser.add_argument("--num_steps", type=int, default=200)
    parser.add_argument("--decay_iter", type=int, default=20)
    parser.add_argument("--decay_rate", type=float, default=0.9)
    parser.add_argument("--svgd", action="store_true")
    parser.add_argument("--data_path", type=str, required=True)
    parser.add_argument("--save_path", type=str,
                        help="output directory (required with --save_image)")
    parser.add_argument("--use_dir", action="store_true")
    parser.add_argument("--save_freq", type=int, default=20)
    parser.add_argument("--data_format", type=str, default="png")
    # main() always loads eight frames and image_grid requires
    # rows * cols == number of images, so the default grid is 2 x 4
    # (the former 2 x 2 default could never pass that check).
    parser.add_argument("--rows", default=2, type=int)
    parser.add_argument("--cols", default=4, type=int)
    parser.add_argument("--frame", action='store_true')
    parser.add_argument("--eval", action='store_true')
    parser.add_argument("--save_image", action='store_true')
    parser.add_argument("--fp16", action='store_true')
    parser.add_argument("--noise", action='store_true')
    args = parser.parse_args()

    # Without this check os.makedirs(None) below fails with a TypeError.
    if args.save_image and args.save_path is None:
        parser.error("--save_image requires --save_path")

    args.device = f'cuda:{args.device}'
    if args.save_image:
        args.save_name = f'ip2p_svgd_{args.svgd}_lr_{args.lr}_gs_{args.guidance_scale}_ig_{args.image_guidance_scale}_dr_{args.decay_rate}_max_{args.max_step}_ns_{args.num_steps}'
        os.makedirs(args.save_path, exist_ok=True)

    main(args)
