import torch
import torch.nn as nn
import torchvision
from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
import argparse
import torch.utils.checkpoint as checkpoint
import os, shutil
from PIL import Image
import time
from torch import autocast
from torch.cuda.amp import GradScaler
from transformers import CLIPModel, CLIPProcessor, AutoProcessor, AutoModel
import sys
# from optimizers import SignSGD

import numpy as np

# Aesthetic Scorer
class MLPDiff(nn.Module):
    """Aesthetic regression head mapping a 768-dim embedding to one score.

    A plain stack of Linear layers (768 -> 1024 -> 128 -> 64 -> 16 -> 1) with
    Dropout after the first three; there are no nonlinearities in between.
    The submodule order/indices inside ``self.layers`` must not change, because
    pretrained weights are loaded by key (``layers.0``, ``layers.2``, ...).
    """

    def __init__(self):
        super().__init__()
        modules = [
            nn.Linear(768, 1024), nn.Dropout(0.2),
            nn.Linear(1024, 128), nn.Dropout(0.2),
            nn.Linear(128, 64), nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        ]
        self.layers = nn.Sequential(*modules)

    def forward(self, embed):
        """Return the raw score(s) for ``embed`` (last dim must be 768)."""
        out = self.layers(embed)
        return out

# Local directory of a CLIP checkpoint (the directory name suggests ViT-L/14);
# AestheticScorerDiff loads its CLIPModel from here via from_pretrained.
MODEL_PATH = '/mnt/workspace/workgroup/tangzhiwei.tzw/clip-vit-large-patch14'
# Directory holding the aesthetic MLP weights file
# "sac+logos+ava1-l14-linearMSE.pth", which is loaded into MLPDiff.
ASSETS_PATH = '/mnt/workspace/workgroup/tangzhiwei.tzw/reward_optimization/reward_opt/assets'

class AestheticScorerDiff(torch.nn.Module):
    """Differentiable aesthetic predictor: CLIP image embedding -> MLPDiff.

    The CLIP image features are L2-normalised before being fed to the MLP.

    Args:
        dtype: stored on the instance as ``self.dtype``; callers are expected
            to move/cast the module themselves with ``.to(...)``.
    """

    def __init__(self, dtype):
        super().__init__()
        self.clip = CLIPModel.from_pretrained(MODEL_PATH)
        self.mlp = MLPDiff()
        # Load onto CPU first so loading does not depend on the device the
        # weights were saved from; the module is moved afterwards by the caller.
        state_dict = torch.load(
            os.path.join(ASSETS_PATH, "sac+logos+ava1-l14-linearMSE.pth"),
            map_location="cpu",
        )
        self.mlp.load_state_dict(state_dict)
        self.dtype = dtype
        self.eval()

    def forward(self, images):
        """Score a batch of CLIP-preprocessed images.

        Implemented as ``forward`` (not ``__call__``) so ``nn.Module.__call__``
        still runs hooks; ``scorer(images)`` works exactly as before.

        Args:
            images: pixel tensor accepted by ``CLIPModel.get_image_features``
                (presumably normalised (B, 3, 224, 224) — TODO confirm).

        Returns:
            Tensor of one score per image (the MLP output with dim 1 squeezed).
        """
        embed = self.clip.get_image_features(pixel_values=images)
        embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
        return self.mlp(embed).squeeze(1)

def aesthetic_loss_fn(device=None,
                     torch_dtype=None):
    """Build a differentiable loss equal to the negated aesthetic score.

    Args:
        device: device the scorer is moved to.
        torch_dtype: dtype the scorer is cast to.

    Returns:
        ``loss_fn(im_pix_un, prompts=None)``. ``im_pix_un`` is an image batch
        that is mapped with ``x / 2 + 0.5`` and clamped to [0, 1] (so presumably
        in [-1, 1]); ``prompts`` is ignored. Returns ``-score`` per image.
    """
    target_size = 224
    # CLIP normalisation statistics.
    normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                                std=[0.26862954, 0.26130258, 0.27577711])
    # Built once instead of on every call. Resize(int) only fixes the shorter
    # side, so a centre crop is needed for non-square inputs to reach
    # target_size x target_size; for square inputs the crop is a no-op.
    resize = torchvision.transforms.Resize(target_size)
    crop = torchvision.transforms.CenterCrop(target_size)
    scorer = AestheticScorerDiff(dtype=torch_dtype).to(device, dtype=torch_dtype)
    scorer.requires_grad_(False)

    def loss_fn(im_pix_un, prompts=None):
        im_pix = ((im_pix_un / 2) + 0.5).clamp(0, 1)
        im_pix = crop(resize(im_pix))
        im_pix = normalize(im_pix).to(im_pix_un.dtype)
        rewards = scorer(im_pix)
        return -1 * rewards

    return loss_fn

def white_loss_fn(device=None, torch_dtype=None):
    """Build a loss that favours bright images.

    ``device`` and ``torch_dtype`` are accepted for signature parity with the
    other ``*_loss_fn`` factories and are not used.

    Returns:
        ``loss_fn(im_pix_un, prompts=None)`` -> negated mean over every element
        of ``im_pix_un`` (a 0-d tensor); ``prompts`` is ignored.
    """
    def loss_fn(im_pix_un, prompts=None):
        # Mean intensity is the reward; negating it makes minimisation brighten.
        return -(im_pix_un.mean())

    return loss_fn


def black_loss_fn(device=None, torch_dtype=None):
    """Build a loss that favours dark images.

    ``device`` and ``torch_dtype`` are accepted for signature parity with the
    other ``*_loss_fn`` factories and are not used.

    Returns:
        ``loss_fn(im_pix_un, prompts=None)`` -> mean over every element of
        ``im_pix_un`` (a 0-d tensor); ``prompts`` is ignored.
    """
    def loss_fn(im_pix_un, prompts=None):
        # The mean intensity itself is the loss, so minimisation darkens.
        return im_pix_un.mean()

    return loss_fn

def contrast_loss_fn(device=None, torch_dtype=None):
    """Build a loss that favours high contrast.

    ``device`` and ``torch_dtype`` are accepted for signature parity with the
    other ``*_loss_fn`` factories and are not used.

    Returns:
        ``loss_fn(im_pix_un, prompts=None)`` -> negated (unbiased) variance of
        ``im_pix_un`` summed over dim 1 (presumably the channel axis — TODO
        confirm); ``prompts`` is ignored.
    """
    def loss_fn(im_pix_un, prompts=None):
        channel_total = torch.sum(im_pix_un, dim=1)
        spread = channel_total.var()
        return spread.neg()

    return loss_fn

# HPS-v2
# HPS-v2 checkpoint file; hps_loss_fn loads it with torch.load and copies its
# 'state_dict' entry into the open_clip model.
HPS_V2_PATH = "/mnt/workspace/workgroup/tangzhiwei.tzw/HPS_v2_compressed.pt"
def hps_loss_fn(inference_dtype=None, device=None):
    """Build a differentiable loss equal to the negated HPS-v2 score.

    Creates an open_clip ViT-H-14 model, loads the HPS-v2 weights from
    ``HPS_V2_PATH`` and returns a closure scoring images against prompts.

    Args:
        inference_dtype: passed to open_clip as ``precision`` and used as the
            model dtype in ``.to(...)``.
        device: device for the model, the checkpoint and tokenized prompts.

    Returns:
        ``loss_fn(im_pix_un, prompts)``. ``im_pix_un`` is mapped with
        ``x / 2 + 0.5`` and clamped to [0, 1] (so presumably in [-1, 1]).
        Returns ``-diag(image_features @ text_features.T)``: one value per
        image/prompt pair.
    """
    from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer

    model_name = "ViT-H-14"
    # Only the model is needed: images are preprocessed below with tensor
    # transforms so gradients can flow back to the pixels.
    model, _, _ = create_model_and_transforms(
        model_name,
        "/mnt/workspace/workgroup/tangzhiwei.tzw/open_clip_pytorch_model.bin",
        precision=inference_dtype,
        device=device,
        jit=False,
        force_quick_gelu=False,
        force_custom_text=False,
        force_patch_dropout=False,
        force_image_size=None,
        pretrained_image=False,
        image_mean=None,
        image_std=None,
        light_augmentation=True,
        aug_cfg={},
        output_dict=True,
        with_score_predictor=False,
        with_region_predictor=False
    )
    tokenizer = get_tokenizer(model_name)

    # Named `state` so it does not shadow the module-level `checkpoint`
    # (torch.utils.checkpoint) alias.
    state = torch.load(HPS_V2_PATH, map_location=device)
    model.load_state_dict(state['state_dict'])
    model = model.to(device, dtype=inference_dtype)
    model.eval()

    target_size = 224
    normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                                std=[0.26862954, 0.26130258, 0.27577711])
    # Built once. Resize(int) only fixes the shorter side; the centre crop makes
    # non-square inputs target_size x target_size and is a no-op for square ones.
    resize = torchvision.transforms.Resize(target_size)
    crop = torchvision.transforms.CenterCrop(target_size)

    def loss_fn(im_pix_un, prompts):
        im_pix = ((im_pix_un / 2) + 0.5).clamp(0, 1)
        x_var = crop(resize(im_pix))
        x_var = normalize(x_var).to(im_pix.dtype)
        caption = tokenizer(prompts).to(device)
        outputs = model(x_var, caption)
        image_features, text_features = outputs["image_features"], outputs["text_features"]
        logits = image_features @ text_features.T
        scores = torch.diagonal(logits)
        return -scores

    return loss_fn

# pickscore
# Directory of the PickScore model; pick_loss_fn loads it with
# transformers.AutoModel.from_pretrained.
PickScore_PATH = "/mnt/workspace/workgroup/tangzhiwei.tzw/pickscore"
def pick_loss_fn(inference_dtype=None, device=None):
    """Build a differentiable loss equal to the negated PickScore.

    Args:
        inference_dtype: dtype the PickScore model is cast to.
        device: device for the model and tokenized prompts.

    Returns:
        ``loss_fn(im_pix_un, prompts)``. ``im_pix_un`` is mapped with
        ``x / 2 + 0.5`` and clamped to [0, 1] (so presumably in [-1, 1]).
        Returns ``-(logit_scale.exp() * cos(image_i, text_i))`` with one value
        per image. The number of prompts must be 1 (shared by every image) or
        equal to the number of images. Previously only the first pair was
        scored and a 0-d tensor was returned, which broke ``loss_fn(...)[0]``.
    """
    from open_clip import get_tokenizer

    model_name = "ViT-H-14"
    model = AutoModel.from_pretrained(PickScore_PATH)
    tokenizer = get_tokenizer(model_name)
    model = model.to(device, dtype=inference_dtype)
    model.eval()

    target_size = 224
    normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                                std=[0.26862954, 0.26130258, 0.27577711])
    # Built once. Resize(int) only fixes the shorter side; the centre crop makes
    # non-square inputs target_size x target_size and is a no-op for square ones.
    resize = torchvision.transforms.Resize(target_size)
    crop = torchvision.transforms.CenterCrop(target_size)

    def loss_fn(im_pix_un, prompts):
        im_pix = ((im_pix_un / 2) + 0.5).clamp(0, 1)
        x_var = crop(resize(im_pix))
        x_var = normalize(x_var).to(im_pix.dtype)
        caption = tokenizer(prompts).to(device)
        image_embs = model.get_image_features(x_var)
        image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)

        text_embs = model.get_text_features(caption)
        text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
        # Row-wise dot product of the unit embeddings: a single prompt
        # broadcasts over all images, otherwise image i pairs with prompt i.
        scores = model.logit_scale.exp() * (image_embs * text_embs).sum(dim=-1)
        return -scores

    return loss_fn


# CLIP score evaluation

def clip_score(inference_dtype=None, device=None):
    """Build a (non-differentiable) CLIP image-text score for evaluation.

    Uses the same CLIP checkpoint as the aesthetic scorer (``MODEL_PATH``).

    Args:
        inference_dtype: dtype the CLIP model is cast to.
        device: device for the model and its inputs.

    Returns:
        ``loss_fn(image, prompt)`` -> the first entry of CLIP's
        ``logits_per_image`` for ``image`` vs ``[prompt]`` as a numpy scalar.
        Despite the name it is a score (higher = better match), not a loss.
    """
    model = CLIPModel.from_pretrained(MODEL_PATH)
    processor = CLIPProcessor.from_pretrained(MODEL_PATH)

    model = model.to(device=device, dtype=inference_dtype)
    model.eval()

    @torch.no_grad()
    def loss_fn(image, prompt):
        inputs = processor(text=[prompt], images=image, return_tensors="pt", padding=True)

        # Move inputs to the device. Floating inputs (pixel_values) are also
        # cast to the model dtype so a half-precision model does not receive
        # the processor's float32 pixels; integer ids are left untouched.
        model_inputs = {}
        for key, value in inputs.items():
            value = value.to(device)
            if torch.is_floating_point(value):
                value = value.to(model.dtype)
            model_inputs[key] = value

        outputs = model(**model_inputs)
        # .float() so bfloat16 logits can be converted to numpy.
        return outputs.logits_per_image.float().cpu().numpy()[0][0]

    return loss_fn


# sampling algorithm
class SequentialDDIM:
    """Reward-guided DDIM sampler that advances one denoising step per call.

    The trajectory is kept in ``self._samples`` (newest latent first). Each
    :meth:`step` combines the classifier-free-guided model output with the
    gradient of ``log(mean Monte-Carlo reward)``, computed by decoding perturbed
    latents with ``pipeline.vae`` and scoring them with ``loss_fn``.

    Args:
        timesteps: number of denoising steps; sampling finishes once more than
            ``timesteps`` latents are stored.
        pipeline: diffusers pipeline; only ``pipeline.vae`` is used here.
        scheduler: provides ``timesteps``, ``alphas_cumprod`` and
            ``scale_model_input`` (presumably already configured with
            ``set_timesteps`` — TODO confirm).
        loss_fn: ``loss_fn(images, prompt)`` whose element ``[0]`` is the
            negated reward of the (single) image.
        prompt: forwarded to ``loss_fn``.
        eta: scales the per-step noise level (0 = deterministic DDIM).
        cfg_scale: classifier-free guidance scale.
        device: device for timestep tensors and the score accumulator.
        opt_timesteps: steps with index ``t <= opt_timesteps`` are computed
            under the ambient grad mode, later ones under ``torch.no_grad()``.
        mc_samples: Monte-Carlo samples per step. ``None`` keeps the legacy
            behaviour of reading ``int(sys.argv[1])`` at each step.
    """

    def __init__(self, timesteps = 100, pipeline = None, scheduler = None, loss_fn = None, prompt = None, eta = 0.0, cfg_scale = 4.0, device = "cuda", opt_timesteps = 50, mc_samples = None):
        self.eta = eta
        self.timesteps = timesteps
        self.num_steps = timesteps
        self.pipeline = pipeline
        self.scheduler = scheduler
        self.loss_fn = loss_fn
        self.prompt = prompt
        self.device = device
        self.cfg_scale = cfg_scale
        self.opt_timesteps = opt_timesteps
        self.mc_samples = mc_samples

        # Pair each scheduler timestep with the one after it (0 after the last),
        # then reverse both lists; assuming scheduler.timesteps is descending
        # (diffusers convention), index 0 is then the least noisy step.
        scheduler_timesteps = self.scheduler.timesteps.tolist()
        scheduler_prev_timesteps = scheduler_timesteps[1:] + [0]
        self.scheduler_timesteps = scheduler_timesteps[::-1]
        scheduler_prev_timesteps = scheduler_prev_timesteps[::-1]

        # These are 1 - alpha_bar values, not alpha_bar itself.
        one_minus_abar_now = [1 - self.scheduler.alphas_cumprod[t] for t in self.scheduler_timesteps]
        one_minus_abar_prev = [1 - self.scheduler.alphas_cumprod[t] for t in scheduler_prev_timesteps]

        now_coeff = torch.clamp(torch.tensor(one_minus_abar_now), min = 0)
        next_coeff = torch.clamp(torch.tensor(one_minus_abar_prev), min = 0)
        m_now_coeff = torch.clamp(1 - now_coeff, min = 0)
        m_next_coeff = torch.clamp(1 - next_coeff, min = 0)
        # sqrt((1-abar_prev)/(1-abar_t)) * sqrt(1 - abar_t/abar_prev): the DDIM
        # sigma_t formula with eta = 1.
        self.noise_thr = torch.sqrt(next_coeff / now_coeff) * torch.sqrt(1 - (1 - now_coeff) / (1 - next_coeff))
        self.nl = self.noise_thr * self.eta
        self.nl[0] = 0.  # no noise is injected on the final step
        m_nl_next_coeff = torch.clamp(next_coeff - self.nl**2, min = 0)
        self.coeff_x = torch.sqrt(m_next_coeff) / torch.sqrt(m_now_coeff)
        self.coeff_d = torch.sqrt(m_nl_next_coeff) - torch.sqrt(now_coeff) * self.coeff_x

    def is_finished(self):
        """Return True once all denoising steps have been taken."""
        return self._is_finished

    def get_last_sample(self):
        """Return the most recent latent."""
        return self._samples[0]

    def prepare_model_kwargs(self, prompt_embeds = None):
        """Build UNet inputs for the current step.

        The current latent is duplicated (presumably unconditional + text
        branches for classifier-free guidance, matching ``prompt_embeds``).
        """
        t_ind = self.num_steps - len(self._samples)
        t = self.scheduler_timesteps[t_ind]

        model_kwargs = {
            "sample": torch.stack([self._samples[0], self._samples[0]]),
            "timestep": torch.tensor([t, t], device = self.device),
            "encoder_hidden_states": prompt_embeds
        }

        model_kwargs["sample"] = self.scheduler.scale_model_input(model_kwargs["sample"], t)

        return model_kwargs

    def step(self, model_output):
        """Apply one reward-guided DDIM update and store the new latent.

        Args:
            model_output: UNet output whose element ``[0]`` stacks the
                unconditional and text-conditioned predictions.
        """
        model_output_uncond, model_output_text = model_output[0].chunk(2)
        direction = model_output_uncond + self.cfg_scale * (model_output_text - model_output_uncond)
        direction = direction[0]
        # Fresh leaf copy of the current latent so the reward gradient can be
        # taken with respect to it.
        now_sample = self._samples[0].clone().requires_grad_(True)

        t = self.num_steps - len(self._samples)
        x_0_hat = now_sample + self.noise_thr[t] * direction

        n = self.mc_samples if self.mc_samples is not None else int(sys.argv[1])
        noise_scale = self.noise_thr[t] / torch.sqrt(1 + self.noise_thr[t] ** 2)
        scores = torch.zeros(1, device = self.device)
        for _ in range(n):
            # Gaussian perturbation. The previous torch.rand_like drew uniform
            # [0, 1) noise, which shifted every latent by +0.5 * scale on average.
            x_0_hat_noise = x_0_hat + noise_scale * torch.randn_like(x_0_hat)
            img_sample = decode_latent(self.pipeline.vae, x_0_hat_noise)
            scores += -self.loss_fn(img_sample.to(dtype=torch.float32), self.prompt)[0]

        # log of the Monte-Carlo mean reward; NaN if that mean is <= 0.
        mc_score = torch.log(scores / n)
        grad = torch.autograd.grad(mc_score, now_sample)[0]

        direction += grad

        if t <= self.opt_timesteps:
            now_sample = self.coeff_x[t] * self._samples[0] + self.coeff_d[t] * direction + self.nl[t] * self.noise_vectors[t]
        else:
            with torch.no_grad():
                now_sample = self.coeff_x[t] * self._samples[0] + self.coeff_d[t] * direction + self.nl[t] * self.noise_vectors[t]

        self._samples.insert(0, now_sample)
        print(t)
        if len(self._samples) > self.timesteps:
            self._is_finished = True

    def initialize(self, noise_vectors):
        """Reset the sampler; ``noise_vectors[-1]`` becomes the starting latent.

        The starting latent is detached unless every step is optimised
        (``num_steps == opt_timesteps``).
        """
        self._is_finished = False

        self.noise_vectors = noise_vectors

        if self.num_steps == self.opt_timesteps:
            self._samples = [self.noise_vectors[-1]]
        else:
            self._samples = [self.noise_vectors[-1].detach()]

def sequential_sampling(pipeline, unet, sampler, prompt_embeds, noise_vectors):
    """Run ``sampler`` to completion and return its final latent.

    Args:
        pipeline: not used here; kept for interface compatibility.
        unet: called as ``unet(sample, timestep, encoder_hidden_states)``.
        sampler: object with the ``SequentialDDIM`` stepping interface.
        prompt_embeds: forwarded to ``sampler.prepare_model_kwargs``.
        noise_vectors: forwarded to ``sampler.initialize``.

    Returns:
        ``sampler.get_last_sample()`` after the sampler reports it is finished.
    """
    sampler.initialize(noise_vectors)

    while not sampler.is_finished():
        model_kwargs = sampler.prepare_model_kwargs(prompt_embeds = prompt_embeds)
        # Activation checkpointing: the UNet forward is recomputed during
        # backward instead of keeping its activations in memory.
        model_output = checkpoint.checkpoint(
            unet,
            model_kwargs["sample"],
            model_kwargs["timestep"],
            model_kwargs["encoder_hidden_states"],
            use_reentrant=False,
        )
        sampler.step(model_output)

    return sampler.get_last_sample()


def decode_latent(decoder, latent, scaling_factor=0.18215):
    """Decode one latent (without batch dim) through a VAE decoder.

    Args:
        decoder: object with ``decode(z)`` returning something with ``.sample``
            (e.g. a diffusers VAE).
        latent: a single latent; a batch dim of 1 is added before decoding.
        scaling_factor: latent scale removed before decoding. The default
            0.18215 is the value this file always used; other VAEs may expose a
            different value (e.g. ``vae.config.scaling_factor``).

    Returns:
        ``decoder.decode(latent[None] / scaling_factor).sample``.
    """
    return decoder.decode(latent.unsqueeze(0) / scaling_factor).sample

def to_img(img):
    """Convert the first image of a batch to an (H, W, C) uint8 numpy array.

    Values are mapped with ``x * 127.5 + 128`` and clamped to [0, 255]; the
    extra +0.5 (versus +127.5) combined with the truncating uint8 cast rounds
    to the nearest integer. Channels are moved last via ``permute(0, 2, 3, 1)``.
    """
    pixels = img.cpu().float() * 127.5 + 128.0
    pixels = pixels.clamp(0, 255)
    pixels = pixels.permute(0, 2, 3, 1)
    array = pixels.to(dtype=torch.uint8).numpy()
    return array[0]


def main():
    """Sample 1000 random prompts with reward guidance and save the images.

    Usage: ``python <script> N`` where ``N`` (``sys.argv[1]``) is the number of
    Monte-Carlo samples per step. It names the output directory here and is
    read again inside ``SequentialDDIM.step``. The output directory is deleted
    and recreated on every run.
    """
    # Validate arguments before the expensive pipeline load.
    if len(sys.argv) < 2:
        raise SystemExit(f"usage: {sys.argv[0]} NUM_MC_SAMPLES")
    n = int(sys.argv[1])

    timesteps = 50
    pipeline = StableDiffusionPipeline.from_pretrained('/mnt/workspace/workgroup/tangzhiwei.tzw/sdv1-5-full-diffuser').to('cuda')
    # freeze parameters of models to save more memory
    pipeline.vae.requires_grad_(False)
    pipeline.text_encoder.requires_grad_(False)
    pipeline.unet.requires_grad_(False)
    # disable safety checker
    pipeline.safety_checker = None
    pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
    # set the number of steps
    pipeline.scheduler.set_timesteps(timesteps)
    unet = pipeline.unet

    aesthetic_evaluator = aesthetic_loss_fn(torch_dtype = torch.float32, device = 'cuda')
    loss_fn = aesthetic_evaluator

    # Seed both RNGs: torch drives the noise, numpy drives prompt selection.
    torch.manual_seed(123)
    np.random.seed(123)

    with open("/mnt/workspace/workgroup/tangzhiwei.tzw/spec_samp/simple_animals_activities.txt", encoding="utf-8") as f:
        # strip() also removes '\r'; blank lines would otherwise become prompts.
        animals = [line.strip() for line in f if line.strip()]

    colors = ["red", "green", "blue", "yellow", "purple", "pink", "orange", "brown", "gray", "black", "white", "gold", "silver"]

    output_path = f"/mnt/workspace/workgroup/tangzhiwei.tzw/reward_optimization/output_n{n}"
    # Results from a previous run with the same n are discarded.
    if os.path.exists(output_path):
        shutil.rmtree(output_path)
    os.makedirs(output_path)

    for ii in range(1000):
        prompt = np.random.choice(animals)
        prompt = " ".join([np.random.choice(colors), prompt])

        noise_vectors = torch.randn(timesteps + 1, 4, 64, 64, device = 'cuda')

        print(prompt)

        prompt_embeds = pipeline._encode_prompt(
                            prompt,
                            'cuda',
                            1,
                            True,
                        )

        ddim_sampler = SequentialDDIM(timesteps = timesteps,
                                        pipeline = pipeline,
                                        scheduler = pipeline.scheduler,
                                        loss_fn = loss_fn,
                                        eta = 1.0,
                                        cfg_scale = 5.0,
                                        device = 'cuda',
                                        opt_timesteps = timesteps)
        sample = sequential_sampling(pipeline, unet, ddim_sampler, prompt_embeds = prompt_embeds, noise_vectors = noise_vectors)

        # Final decode and scoring need no gradients.
        with torch.no_grad():
            sample = decode_latent(pipeline.vae, sample)
            aesthetic_score = - aesthetic_evaluator(sample.to(dtype=torch.float32), prompt)[0].item()

        img = to_img(sample)
        IMG = Image.fromarray(img)
        print(ii, "aesthetic_score", aesthetic_score)
        IMG.save(os.path.join(output_path, f"{ii}_{prompt.replace(' ', '_')}_{aesthetic_score}.png"),)  # Must specify desired format here


if __name__ == "__main__":
    main()