import torch
import torch.nn as nn
import torchvision
from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
import argparse
import torch.utils.checkpoint as checkpoint
import os, shutil
from PIL import Image
import time
from torch import autocast
from torch.cuda.amp import GradScaler
from transformers import CLIPModel, CLIPProcessor, AutoProcessor, AutoModel
from accelerate import Accelerator
import numpy as np

# Aesthetic Scorer
class MLPDiff(nn.Module):
    """MLP head mapping a 768-d embedding to a single scalar score.

    The layers are stacked without activation functions, so in eval mode the
    whole stack is an affine map. The order of modules inside ``self.layers``
    determines the ``layers.<idx>.*`` state-dict keys, so it must not change:
    pretrained weights are loaded into this module by key.
    """

    def __init__(self):
        super().__init__()
        # (in_features, out_features, dropout probability after the layer or None)
        spec = [
            (768, 1024, 0.2),
            (1024, 128, 0.2),
            (128, 64, 0.1),
            (64, 16, None),
            (16, 1, None),
        ]
        stack = []
        for fan_in, fan_out, drop in spec:
            stack.append(nn.Linear(fan_in, fan_out))
            if drop is not None:
                stack.append(nn.Dropout(drop))
        self.layers = nn.Sequential(*stack)

    def forward(self, embed):
        """Return the raw score for each embedding (last dimension becomes 1)."""
        return self.layers(embed)

# Local CLIP checkpoint directory (the name suggests ViT-L/14) that
# AestheticScorerDiff loads with CLIPModel.from_pretrained.
# NOTE(review): machine-specific absolute path; consider making it configurable.
MODEL_PATH = '/mnt/workspace/workgroup/tangzhiwei.tzw/clip-vit-large-patch14'
# Directory containing "sac+logos+ava1-l14-linearMSE.pth", the MLPDiff state dict
# loaded by AestheticScorerDiff.
ASSETS_PATH = '/mnt/workspace/workgroup/tangzhiwei.tzw/reward_optimization/reward_opt/assets'

class AestheticScorerDiff(torch.nn.Module):
    """Differentiable aesthetic scorer: CLIP image embedding followed by an MLP head.

    The CLIP model is loaded from ``MODEL_PATH`` and the MLP weights from
    ``ASSETS_PATH/sac+logos+ava1-l14-linearMSE.pth``. The module is switched to
    eval mode, which disables the MLP's dropout. Scoring is implemented in
    ``forward`` so that calling the module goes through ``nn.Module.__call__``
    and registered hooks behave normally.
    """

    def __init__(self, dtype):
        """
        Args:
            dtype: stored as ``self.dtype`` only; casting is left to the caller
                (e.g. ``.to(device, dtype=...)``).
        """
        super().__init__()
        self.clip = CLIPModel.from_pretrained(MODEL_PATH)
        self.mlp = MLPDiff()
        # Load on CPU so the checkpoint does not require the device it was saved
        # from; load_state_dict copies into the module's own parameters.
        state_dict = torch.load(
            os.path.join(ASSETS_PATH, "sac+logos+ava1-l14-linearMSE.pth"),
            map_location="cpu",
        )
        self.mlp.load_state_dict(state_dict)
        self.dtype = dtype
        self.eval()

    def forward(self, images):
        """Score a batch of preprocessed images.

        Args:
            images: pixel values already resized/normalized for CLIP.

        Returns:
            One score per image: the MLP output on the L2-normalized CLIP image
            embedding, with dimension 1 squeezed out.
        """
        embed = self.clip.get_image_features(pixel_values=images)
        embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
        return self.mlp(embed).squeeze(1)

def aesthetic_loss_fn(device=None,
                     torch_dtype=None):
    """Build a loss closure returning the negated aesthetic score of each image.

    Args:
        device: device the scorer is moved to.
        torch_dtype: dtype the scorer is cast to.

    Returns:
        ``loss_fn(im_pix_un, prompts=None)``. ``im_pix_un`` holds decoded images
        in [-1, 1] (assumed NCHW — TODO confirm). The result is one loss per
        image, where lower means more aesthetic. ``prompts`` is accepted for
        interface parity with the other objectives and is ignored.
    """
    target_size = 224
    # Built once; Resize(int) scales the shorter edge to target_size.
    resize = torchvision.transforms.Resize(target_size)
    normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                                 std=[0.26862954, 0.26130258, 0.27577711])
    scorer = AestheticScorerDiff(dtype=torch_dtype).to(device, dtype=torch_dtype)
    scorer.requires_grad_(False)

    def loss_fn(im_pix_un, prompts=None):
        # [-1, 1] -> [0, 1], then CLIP-style resize + normalization.
        im_pix = ((im_pix_un / 2) + 0.5).clamp(0, 1)
        im_pix = resize(im_pix)
        im_pix = normalize(im_pix).to(im_pix_un.dtype)
        rewards = scorer(im_pix)
        return -rewards

    return loss_fn

# HPS-v2
# HPS v2 checkpoint; hps_loss_fn loads it with torch.load and reads its "state_dict" entry.
# NOTE(review): machine-specific absolute path; consider making it configurable.
HPS_V2_PATH = "/mnt/workspace/workgroup/tangzhiwei.tzw/HPS_v2_compressed.pt"
def hps_loss_fn(inference_dtype=None, device=None):
    """Build a loss closure returning the negated HPS v2 score of (image, prompt) pairs.

    An open_clip ViT-H-14 model is created from a local checkpoint. The HPS v2
    weights (the ``"state_dict"`` entry of ``HPS_V2_PATH``) are then loaded on top.

    Args:
        inference_dtype: dtype the model is cast to.
        device: device for the model and the tokenized prompts.

    Returns:
        ``loss_fn(im_pix_un, prompts)``. ``im_pix_un`` holds decoded images in
        [-1, 1] (assumed NCHW — TODO confirm). ``prompts[i]`` is the caption for
        image ``i``. The result is ``-diag(image_features @ text_features.T)``,
        i.e. one loss per pair.
    """
    from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer

    model_name = "ViT-H-14"
    # NOTE(review): ``precision`` receives a torch dtype here. The model is cast
    # explicitly with ``.to(device, dtype=inference_dtype)`` below; confirm how
    # open_clip interprets a non-string precision.
    model, _preprocess_train, _preprocess_val = create_model_and_transforms(
        model_name,
        "/mnt/workspace/workgroup/tangzhiwei.tzw/open_clip_pytorch_model.bin",
        precision=inference_dtype,
        device=device,
        jit=False,
        force_quick_gelu=False,
        force_custom_text=False,
        force_patch_dropout=False,
        force_image_size=None,
        pretrained_image=False,
        image_mean=None,
        image_std=None,
        light_augmentation=True,
        aug_cfg={},
        output_dict=True,
        with_score_predictor=False,
        with_region_predictor=False
    )

    tokenizer = get_tokenizer(model_name)

    # Named ``hps_state`` so it does not shadow the module-level ``checkpoint`` import.
    hps_state = torch.load(HPS_V2_PATH, map_location=device)
    model.load_state_dict(hps_state['state_dict'])
    model = model.to(device, dtype=inference_dtype)
    model.eval()

    target_size = 224
    resize = torchvision.transforms.Resize(target_size)
    normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                                 std=[0.26862954, 0.26130258, 0.27577711])

    def loss_fn(im_pix_un, prompts):
        # [-1, 1] -> [0, 1], then resize + normalize for the image tower.
        im_pix = ((im_pix_un / 2) + 0.5).clamp(0, 1)
        x_var = normalize(resize(im_pix)).to(im_pix.dtype)
        caption = tokenizer(prompts).to(device)
        outputs = model(x_var, caption)
        image_features, text_features = outputs["image_features"], outputs["text_features"]
        # Entry (i, i) pairs image i with prompt i.
        scores = torch.diagonal(image_features @ text_features.T)
        return -scores

    return loss_fn

# pickscore
# Local PickScore model directory; pick_loss_fn loads it with AutoModel.from_pretrained.
# NOTE(review): machine-specific absolute path; consider making it configurable.
PickScore_PATH = "/mnt/workspace/workgroup/tangzhiwei.tzw/pickscore"
def pick_loss_fn(inference_dtype=None, device=None):
    """Create a loss closure that scores (image, prompt) pairs with PickScore.

    Args:
        inference_dtype: dtype the PickScore model is cast to.
        device: device for the model and the tokenized prompts.

    Returns:
        ``loss_fn(im_pix_un, prompts)``. ``im_pix_un`` holds decoded images in
        [-1, 1] and ``prompts[i]`` is the caption for image ``i``. The result is
        the negated, logit-scaled cosine similarity of each pair.
    """
    from open_clip import get_tokenizer

    pick_model = AutoModel.from_pretrained(PickScore_PATH)
    # Prompts are tokenized with open_clip's ViT-H-14 tokenizer.
    tokenize = get_tokenizer("ViT-H-14")
    pick_model = pick_model.to(device, dtype=inference_dtype)
    pick_model.eval()

    resize = torchvision.transforms.Resize(224)
    normalize = torchvision.transforms.Normalize(
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711],
    )

    def _unit(v):
        # L2-normalize along the feature dimension.
        return v / torch.norm(v, dim=-1, keepdim=True)

    def loss_fn(im_pix_un, prompts):
        pixels = ((im_pix_un / 2) + 0.5).clamp(0, 1)
        pixels = normalize(resize(pixels)).to(pixels.dtype)
        tokens = tokenize(prompts).to(device)

        image_embs = _unit(pick_model.get_image_features(pixels))
        text_embs = _unit(pick_model.get_text_features(tokens))

        # Entry (i, i) pairs prompt i with image i.
        similarity = pick_model.logit_scale.exp() * (text_embs @ image_embs.T)
        return -torch.diagonal(similarity)

    return loss_fn

def jpeg_compressibility(inference_dtype=None, device=None):
    """Build a loss closure returning each image's JPEG-encoded size in kB.

    Minimizing this loss favors more compressible images. Encoding goes through
    PIL, so the loss is not differentiable.

    Args:
        inference_dtype: dtype of the returned loss tensor.
        device: device of the returned loss tensor.

    Returns:
        ``loss_fn(im_pix_un, prompts)``. ``im_pix_un`` is a tensor of images in
        [-1, 1], laid out NCHW (it is transposed to NHWC below). ``prompts`` is
        ignored. The result is a 1-D tensor of sizes (bytes / 1000) at JPEG
        quality 95.
        NOTE(review): the sizes are cast to ``inference_dtype``; under bf16 they
        are coarsely quantized (e.g. 0.25 kB steps between 32 and 64), which may
        blur the ranking of close candidates.
    """
    import io

    def loss_fn(im_pix_un, prompts):
        images = ((im_pix_un / 2) + 0.5).clamp(0, 1)
        # Quantize to uint8 and reorder NCHW -> NHWC, the layout PIL expects.
        arrays = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
        arrays = arrays.transpose(0, 2, 3, 1)
        sizes = []
        for array in arrays:
            with io.BytesIO() as buffer:
                Image.fromarray(array).save(buffer, format="JPEG", quality=95)
                sizes.append(buffer.tell() / 1000)
        return torch.tensor(sizes, dtype=inference_dtype, device=device)

    return loss_fn


# sampling algorithm
class SequentialDDIM:
    """Step-by-step DDIM sampler with classifier-free guidance and caller-supplied noise.

    Every stochastic term comes from the ``noise_vectors`` passed to ``initialize``:
    ``noise_vectors[-1]`` is the starting latent and ``noise_vectors[t]`` is the
    noise injected at step index ``t``. Nothing is drawn internally, so the whole
    trajectory is a deterministic function of ``noise_vectors``.

    Step indices refer to the scheduler timesteps in *ascending* order. Index 0 is
    the final (least noisy) step; index ``num_steps - 1`` is the first step taken.

    Usage: ``initialize`` -> repeat ``prepare_model_kwargs`` / UNet call / ``step``
    until ``is_finished()`` -> ``get_last_sample()``.
    """

    def __init__(self, timesteps = 100, scheduler = None, eta = 0.0, cfg_scale = 4.0, device = "cuda", opt_timesteps = 50):
        """
        Args:
            timesteps: number of denoising steps. Presumably this must equal
                ``len(scheduler.timesteps)``, i.e. ``set_timesteps`` was called with
                the same value — TODO confirm.
            scheduler: scheduler exposing ``timesteps``, ``alphas_cumprod`` and
                ``scale_model_input``.
            eta: scales the injected noise. 0 injects none; 1 uses the full
                ``noise_thr`` (DDIM sigma).
            cfg_scale: classifier-free guidance weight.
            device: device of the timestep tensor built in ``prepare_model_kwargs``.
            opt_timesteps: steps with index ``t <= opt_timesteps`` run with autograd
                enabled; steps with a larger index run under ``torch.no_grad()``.
        """
        self.eta = eta 
        self.timesteps = timesteps
        self.num_steps = timesteps
        self.scheduler = scheduler
        self.device = device
        self.cfg_scale = cfg_scale
        self.opt_timesteps = opt_timesteps 

        # compute some coefficients in advance
        # Pair each scheduler timestep (descending order) with the next one; the
        # last step is paired with timestep 0. Both lists are then reversed to
        # ascending order so that index 0 is the final step.
        scheduler_timesteps = self.scheduler.timesteps.tolist()
        scheduler_prev_timesteps = scheduler_timesteps[1:]
        scheduler_prev_timesteps.append(0)
        self.scheduler_timesteps = scheduler_timesteps[::-1]
        scheduler_prev_timesteps = scheduler_prev_timesteps[::-1]
        # NOTE: despite the names, these lists hold 1 - alphas_cumprod (i.e. 1 - abar)
        # at the current timestep and at the timestep the step moves to.
        alphas_cumprod = [1 - self.scheduler.alphas_cumprod[t] for t in self.scheduler_timesteps]
        alphas_cumprod_prev = [1 - self.scheduler.alphas_cumprod[t] for t in scheduler_prev_timesteps]

        # now_coeff = 1 - abar_t, next_coeff = 1 - abar_prev; the m_* variants are abar_t / abar_prev.
        now_coeff = torch.tensor(alphas_cumprod)
        next_coeff = torch.tensor(alphas_cumprod_prev)
        now_coeff = torch.clamp(now_coeff, min = 0)
        next_coeff = torch.clamp(next_coeff, min = 0)
        m_now_coeff = torch.clamp(1 - now_coeff, min = 0)
        m_next_coeff = torch.clamp(1 - next_coeff, min = 0)
        # noise_thr = sqrt((1 - abar_prev) / (1 - abar_t)) * sqrt(1 - abar_t / abar_prev),
        # the DDIM sigma_t at eta = 1.
        self.noise_thr = torch.sqrt(next_coeff / now_coeff) * torch.sqrt(1 - (1 - now_coeff) / (1 - next_coeff))
        self.nl = self.noise_thr * self.eta
        # No noise is injected on the final step (index 0).
        self.nl[0] = 0.
        # The update is x_prev = coeff_x * x_t + coeff_d * eps + nl * z, where
        #   coeff_x = sqrt(abar_prev / abar_t)
        #   coeff_d = sqrt(1 - abar_prev - nl^2) - sqrt(1 - abar_t) * coeff_x
        # This is the DDIM update written in terms of x_t and the noise prediction eps.
        m_nl_next_coeff = torch.clamp(next_coeff - self.nl**2, min = 0)
        self.coeff_x = torch.sqrt(m_next_coeff) / torch.sqrt(m_now_coeff)
        self.coeff_d = torch.sqrt(m_nl_next_coeff) - torch.sqrt(now_coeff) * self.coeff_x

    def is_finished(self):
        """True once ``timesteps`` steps have been applied since ``initialize``."""
        return self._is_finished

    def get_last_sample(self):
        """Return the most recent sample (the final latent once finished)."""
        return self._samples[0]

    def prepare_model_kwargs(self, prompt_embeds = None):
        """Build the UNet keyword arguments for the current step.

        The current latent is duplicated so that a single UNet call evaluates both
        guidance branches. ``prompt_embeds[0]`` is repeated for the first half of
        the batch (unconditional) and ``prompt_embeds[1]`` for the second half
        (conditional).
        """

        # _samples grows by one per step, so this is the current step index.
        t_ind = self.num_steps - len(self._samples)
        t = self.scheduler_timesteps[t_ind]
        batch = len(self._samples[0])

        uncond_embeds = torch.stack([prompt_embeds[0]] * batch)
        cond_embeds = torch.stack([prompt_embeds[1]] * batch)
   
        model_kwargs = {
            "sample": torch.concat([self._samples[0], self._samples[0]]),
            "timestep": torch.tensor([t] * 2 * batch, device = self.device),
            "encoder_hidden_states": torch.concat([uncond_embeds, cond_embeds])
        }

        model_kwargs["sample"] = self.scheduler.scale_model_input(model_kwargs["sample"],t)
    
        return model_kwargs


    def step(self, model_output):
        """Apply one guided update and push the new sample to the front of the list.

        ``model_output[0]`` (the first element of the UNet output) is split into
        the unconditional and conditional halves laid out by
        ``prepare_model_kwargs``. The guided combination is used as the noise
        prediction ``eps`` in the update.
        """
        model_output_uncond, model_output_text = model_output[0].chunk(2)
        direction = model_output_uncond + self.cfg_scale * (model_output_text - model_output_uncond)

        t = self.num_steps - len(self._samples)

        # Only steps with index <= opt_timesteps keep an autograd graph.
        if t <= self.opt_timesteps:
            now_sample = self.coeff_x[t] * self._samples[0] + self.coeff_d[t] * direction  + self.nl[t] * self.noise_vectors[t]
        else:
            with torch.no_grad():
                now_sample = self.coeff_x[t] * self._samples[0] + self.coeff_d[t] * direction  + self.nl[t] * self.noise_vectors[t]

        # Newest sample first; earlier samples stay in the list.
        self._samples.insert(0, now_sample)
        
        if len(self._samples) > self.timesteps:
            self._is_finished = True

    def initialize(self, noise_vectors):
        """Reset state and start from ``noise_vectors[-1]``.

        ``noise_vectors`` is indexed by step along its first dimension. It needs
        ``timesteps + 1`` entries: one per step plus the starting latent.
        """
        self._is_finished = False

        self.noise_vectors = noise_vectors

        self._samples = [self.noise_vectors[-1]]
  

def sequential_sampling(pipeline, unet, sampler, prompt_embeds, noise_vectors):
    """Run ``sampler`` to completion with ``unet`` and return its final sample.

    Args:
        pipeline: unused; kept for call-site compatibility.
        unet: callable invoked as ``unet(**kwargs)`` with the kwargs produced by
            ``sampler.prepare_model_kwargs``.
        sampler: object implementing ``initialize`` / ``is_finished`` /
            ``prepare_model_kwargs`` / ``step`` / ``get_last_sample``
            (e.g. ``SequentialDDIM``).
        prompt_embeds: forwarded to ``sampler.prepare_model_kwargs``.
        noise_vectors: forwarded to ``sampler.initialize``.

    Returns:
        ``sampler.get_last_sample()`` once ``sampler.is_finished()`` is true.
    """
    sampler.initialize(noise_vectors)
    while not sampler.is_finished():
        model_kwargs = sampler.prepare_model_kwargs(prompt_embeds = prompt_embeds)
        sampler.step(unet(**model_kwargs))
    return sampler.get_last_sample()


def decode_latent(decoder, latents, scaling_factor=None):
    """Decode VAE latents to images.

    Args:
        decoder: VAE exposing ``decode(z).sample``.
        latents: latent tensor produced by the sampler.
        scaling_factor: latent scaling to undo before decoding. Defaults to
            ``decoder.config.scaling_factor`` when available. Otherwise it falls
            back to 0.18215, the Stable Diffusion v1 VAE value.

    Returns:
        The ``.sample`` of the decoder output.
    """
    if scaling_factor is None:
        config = getattr(decoder, "config", None)
        scaling_factor = getattr(config, "scaling_factor", None) or 0.18215
    return decoder.decode(latents / scaling_factor).sample

def to_img(img):
    """Convert one CHW image tensor in [-1, 1] to an HWC uint8 NumPy array.

    Values are mapped with ``127.5 * x + 128``. The extra 0.5 over the usual
    ``127.5 * x + 127.5`` makes the truncating uint8 cast round to nearest.

    The arithmetic runs in at least float32. bfloat16 cannot represent the
    fractional values this trick relies on (e.g. 191.75 would round to 192
    before the cast).
    """
    img = img.detach().cpu()
    img = img.to(torch.promote_types(img.dtype, torch.float32))
    img = torch.clamp(127.5 * img + 128.0, 0, 255)
    return img.permute(1, 2, 0).to(dtype=torch.uint8).numpy()

def _resolve_inference_dtype(precision):
    """Map the ``--precision`` choice to a torch dtype ("bf16", "fp16", otherwise fp32)."""
    if precision == "bf16":
        return torch.bfloat16
    if precision == "fp16":
        return torch.float16
    return torch.float32


def _build_loss_fn(objective, inference_dtype, device):
    """Create the loss closure for ``objective``; raise ValueError for unknown names."""
    if objective == "aesthetic":
        return aesthetic_loss_fn(torch_dtype = inference_dtype, device = device)
    if objective == "hps":
        return hps_loss_fn(inference_dtype = inference_dtype, device = device)
    if objective == "pick":
        return pick_loss_fn(inference_dtype = inference_dtype, device = device)
    if objective == "jpeg":
        return jpeg_compressibility(inference_dtype = inference_dtype, device = device)
    raise ValueError("Invalid objective")


def _save_best_image(pipeline, unet, sampler, prompt_embeds, best_noise, inference_dtype, path):
    """Re-sample from ``best_noise``, decode, and save the first image to ``path``.

    ``best_noise`` has the (num_steps + 1, 4, 64, 64) layout of the optimized
    noise; a batch dimension of 1 is inserted before sampling.
    """
    with torch.no_grad():
        sample = sequential_sampling(pipeline, unet, sampler, prompt_embeds = prompt_embeds,
                                     noise_vectors = best_noise.unsqueeze(1).to(dtype = inference_dtype))
        sample = decode_latent(pipeline.vae, sample)
    Image.fromarray(to_img(sample[0])).save(path)
    print("saved image")


def main():
    """Optimize the per-step DDIM noise for one prompt with a rank-based gradient estimate.

    Each iteration works as follows:
      1. Perturb the current noise with ``batch_per_device`` Gaussian candidates
         (scale ``mu``) per process, plus the unperturbed "center".
      2. Sample and decode all of them, then score them with the chosen objective.
      3. Step AdamW along a rank-weighted combination of the gathered candidates.

    If there has been no improvement for more than ``tol`` iterations, the
    optimization restarts from the best noise found so far. The restart uses a
    higher learning rate, halves ``mu`` (never below 0.01) and doubles ``tol``.
    The best image is saved every ``log_interval`` iterations and once at the end.
    """
    parser = argparse.ArgumentParser(description='Diffusion Optimization with Differentiable Objective')
    parser.add_argument('--model', type=str, default="/mnt/workspace/workgroup/tangzhiwei.tzw/sdv1-5-full-diffuser", help='path to the model')
    parser.add_argument('--prompt', type=str, default="red and green eagle", help='prompt for the optimization')
    parser.add_argument('--num_steps', type=int, default=50, help='number of steps for optimization')
    parser.add_argument('--eta', type=float, default=1.0, help='noise scale')
    parser.add_argument('--guidance_scale', type=float, default=5.0, help='guidance scale')
    parser.add_argument('--seed', type=int, default=123, help='random seed')
    parser.add_argument('--opt_steps', type=int, default=50, help='number of optimization steps')
    parser.add_argument('--opt_time', type=int, default=50)
    parser.add_argument('--log_interval', type=int, default=1, help='log interval')
    parser.add_argument('--objective', type=str, default="aesthetic", help='objective for optimization', choices = ["aesthetic", "hps", "pick", "jpeg"])
    parser.add_argument('--precision', choices = ["bf16", "fp16", "fp32"], default="fp32", help='precision for optimization')
    parser.add_argument('--output', type=str, default="rank_opt_logs", help='output path')
    parser.add_argument('--batch_per_device', type=int, default=4, help='batch size per device')
    parser.add_argument('--mu', type=float, default=0.01, help='control the exploration of ranksgd')
    parser.add_argument('--tol', type=int, default=20)
    args = parser.parse_args()

    inference_dtype = _resolve_inference_dtype(args.precision)

    accelerator = Accelerator()

    # load model
    pipeline = StableDiffusionPipeline.from_pretrained(args.model).to(device = accelerator.device)
    # freeze parameters of models to save more memory
    pipeline.vae.requires_grad_(False)
    pipeline.text_encoder.requires_grad_(False)
    pipeline.unet.requires_grad_(False)
    # disable safety checker
    pipeline.safety_checker = None
    pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
    # set the number of steps
    pipeline.scheduler.set_timesteps(args.num_steps)

    pipeline.vae.to(dtype = inference_dtype)
    pipeline.text_encoder.to(dtype = inference_dtype)
    pipeline.unet.to(dtype = inference_dtype)

    unet = pipeline.unet

    loss_fn = _build_loss_fn(args.objective, inference_dtype, accelerator.device)

    torch.manual_seed(args.seed)
    # One noise tensor per denoising step plus the initial latent.
    noise_vectors = torch.randn(args.num_steps + 1, 4, 64, 64, device = accelerator.device)

    # make sure all devices have the same initial noise vectors
    noise_vectors = noise_vectors.unsqueeze(0)
    noise_vectors = accelerator.gather(noise_vectors)
    noise_vectors = noise_vectors[0]

    # make sure all devices have different randomness
    torch.manual_seed(args.seed + accelerator.process_index)
    torch.manual_seed(torch.randint(0, 1000, (1,)).item() + accelerator.process_index)

    noise_vectors.requires_grad_(True)
    # The key must be spelled "weight_decay": AdamW ignores unknown keys and would
    # silently apply its default weight decay to the noise vectors.
    optimizer = torch.optim.AdamW([{"params": noise_vectors, "lr": 1e-3, "weight_decay": 0}])

    # Seeded so every process makes the same substitutions and optimizes the same prompt.
    prompt_rng = np.random.RandomState(args.seed)
    if "random" in args.prompt:
        with open("/mnt/workspace/workgroup/tangzhiwei.tzw/spec_samp/simple_animals_activities.txt", encoding="utf-8") as f:
            animals = [line.strip("\n") for line in f]
        args.prompt = args.prompt.replace("random", prompt_rng.choice(animals))

    if "randc" in args.prompt:
        colors = ["red", "green", "blue", "yellow", "purple", "pink", "orange", "brown", "gray", "black", "white", "gold", "silver"]
        args.prompt = args.prompt.replace("randc", prompt_rng.choice(colors))

    # Row 0 is the unconditional embedding, row 1 the prompt (CFG layout used by SequentialDDIM).
    prompt_embeds = pipeline._encode_prompt(
                        args.prompt,
                        accelerator.device,
                        1,
                        True,
                    ).to(dtype = inference_dtype)

    output_path = os.path.join(args.output, f"obj:{args.objective},pt:{args.prompt},st:{args.num_steps},et:{args.eta},gd:{args.guidance_scale},sd:{args.seed},ot:{args.opt_time},prec:{args.precision}")

    if accelerator.is_local_main_process:
        if os.path.exists(output_path):
            shutil.rmtree(output_path)
        os.makedirs(output_path)

    # The sampler's coefficients depend only on the scheduler; initialize() resets
    # its per-run state, so one instance is reused for every sampling call.
    ddim_sampler = SequentialDDIM(timesteps = args.num_steps,
                                  scheduler = pipeline.scheduler,
                                  eta = args.eta,
                                  cfg_scale = args.guidance_scale,
                                  device = accelerator.device,
                                  opt_timesteps = args.opt_time)

    no_improve = 0
    current_best_noise = None
    current_best_loss = float("inf")
    use_mu = args.mu
    tol = args.tol
    # start optimization
    for i in range(args.opt_steps):
        optimizer.zero_grad()
        no_improve += 1
        start_time = time.time()
        with torch.no_grad():
            # Candidates are Gaussian perturbations of the current noise. The
            # unperturbed noise is appended as the last batch element (the "center").
            noise_vectors_flat = noise_vectors.detach().unsqueeze(1).to(dtype=inference_dtype)
            cand_noise_vectors = noise_vectors_flat + use_mu * torch.randn(args.num_steps + 1, args.batch_per_device , 4, 64, 64, device = accelerator.device, dtype = inference_dtype)
            cand_noise_vectors = torch.concat([cand_noise_vectors, noise_vectors_flat], dim = 1)

            samples = sequential_sampling(pipeline, unet, ddim_sampler, prompt_embeds = prompt_embeds, noise_vectors = cand_noise_vectors)
            samples = decode_latent(pipeline.vae, samples)

            losses = loss_fn(samples, [args.prompt] * samples.shape[0])

        center_loss = losses[-1].item()
        # Candidate-major layout without the center, gathered from all processes.
        cand_noise_vectors = accelerator.gather(cand_noise_vectors.permute(1, 0, 2, 3, 4)[:-1])
        losses = accelerator.gather(losses[:-1])
        loss = losses.min().item()

        if center_loss < current_best_loss:
            current_best_loss = center_loss
            current_best_noise = noise_vectors.detach().clone().to(dtype = torch.float32)
            no_improve = 0

        if loss < current_best_loss:
            current_best_loss = loss
            # clone(): keep only this candidate, not a view into the gathered batch.
            current_best_noise = cand_noise_vectors[losses.argmin()].clone()
            no_improve = 0

        if no_improve > tol:
            # restart the optimization
            if accelerator.is_local_main_process:
                print("Restart the optimization")
            # Copy, so that optimizer steps on noise_vectors cannot modify the recorded best noise.
            noise_vectors = current_best_noise.detach().clone().to(dtype = torch.float32)
            noise_vectors.requires_grad_(True)
            optimizer = torch.optim.AdamW([{"params": noise_vectors, "lr": 5e-3, "weight_decay": 0}])
            no_improve = 0
            use_mu = max(0.01, use_mu / 2)
            tol *= 2

        else:
            # Centered rank weights (sum to zero): the best candidate gets the most negative
            # weight, so a descent step moves towards low-loss candidates.
            losses_rank = losses.argsort().argsort() + 1
            weights = (2 * losses_rank - len(losses_rank) - 1).to(dtype = inference_dtype)
            rank_grad_est = torch.einsum("k,kijlb->ijlb", weights, cand_noise_vectors)
            # Normalize separately for each step index (dim 0).
            grad_norm = torch.sqrt(torch.sum(rank_grad_est ** 2, dim = [1, 2, 3], keepdim = True))
            rank_grad_est = rank_grad_est / (grad_norm + 1e-3)
            noise_vectors.grad = rank_grad_est.to(dtype = torch.float32)

            optimizer.step()

        end_time = time.time()

        # some auxiliary information
        with torch.no_grad():
            noise_norm = torch.norm(noise_vectors)
            noise_max = noise_vectors.max()

        if accelerator.is_local_main_process:
            # .float(): NumPy has no bfloat16, so .numpy() on bf16 losses would raise.
            print(losses.sort().values.float().cpu().numpy())
            print(i, center_loss, current_best_loss, use_mu, no_improve, noise_norm.item(), noise_max.item(), end_time - start_time, len(losses))

            if i % args.log_interval == 0:
                _save_best_image(pipeline, unet, ddim_sampler, prompt_embeds, current_best_noise, inference_dtype,
                                 os.path.join(output_path, f"{i}_{-current_best_loss}.png"))

    if accelerator.is_local_main_process:
        if current_best_noise is None:
            print("No optimization step produced a result; nothing to save.")
        else:
            _save_best_image(pipeline, unet, ddim_sampler, prompt_embeds, current_best_noise, inference_dtype,
                             os.path.join(output_path, f"{args.opt_steps - 1}_{-current_best_loss}.png"))

if __name__ == '__main__':
    # Run the optimization only when executed as a script, not on import.
    main()