
import os
import ot
import yaml
import random
import shutil
import argparse
import warnings
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm, trange
from safetensors import safe_open
from collections import defaultdict

import torch
import torch.nn as nn
import torchvision.transforms as T
from transformers import logging
from diffusers import DDIMScheduler, StableDiffusionXLPipeline, AutoencoderKL

from sdxl_utils import *

# filter warnings
logging.set_verbosity_error()
warnings.filterwarnings('ignore', message='.*deprecated.*')

class GenDemo(nn.Module):
    """SDXL sampler that steers image colours towards a target histogram via optimal transport.

    Every attribute of ``args`` (e.g. ``steps``, ``resolution``, ``colors``, ``mode``,
    ``color_bins``, ``emb_bins``, ``tau``, ``omega``, ``perturb_steps``, ``seed``,
    ``device``) is copied onto the instance.  ``emb`` (see :meth:`load_emb`) and
    ``save_dir`` must be set by the caller before sampling / saving.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        for k, v in vars(args).items():
            setattr(self, k, v)

    def load_emb(self, emb_path):
        """Load a tuned embedding and turn it into a probability distribution.

        Reads the ``'prompt_embeddings'`` tensor from a safetensors file, squeezes
        it and applies a temperature (``self.tau``) softmax over all entries, so the
        result is non-negative and sums to 1.
        """
        with safe_open(emb_path, framework='pt', device='cpu') as f:
            emb = f.get_tensor('prompt_embeddings').squeeze()
        # softmax subtracts the max before exponentiating, so large logits or a
        # small tau cannot overflow exp() to inf/NaN (same result as exp/sum).
        logits = emb.float().flatten() / self.tau
        return torch.softmax(logits, dim=0).reshape(emb.shape)

    def load_model(self, model='stabilityai/stable-diffusion-xl-base-1.0', vae='madebyollin/sdxl-vae-fp16-fix'):
        """Load the DDIM scheduler, fp16 SDXL pipeline, a separate fp16 VAE and preprocessing.

        Args:
            model: Hugging Face id/path of the SDXL base model.
            vae: id/path of the VAE used for encoding/decoding images.
        """
        self.scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler')
        self.scheduler.set_timesteps(self.steps, device=self.device)
        print(f'[INFO] Loading model: {model}')
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            model, torch_dtype=torch.float16,
        ).to(self.device)
        self.text_encoder = self.pipe.text_encoder
        self.tokenizer = self.pipe.tokenizer
        self.unet = self.pipe.unet
        self.vae = AutoencoderKL.from_pretrained(
            vae, torch_dtype=torch.float16
        ).to(self.device)
        self.tokenizer_kwargs = dict(
            truncation=True,
            return_tensors='pt',
            padding='max_length',
            max_length=self.tokenizer.model_max_length
        )
        # resize to (resolution, resolution) and map pixel values to [-1, 1]
        self.preprocess = T.Compose([
            T.Resize((self.resolution, self.resolution)),
            T.ToTensor(),
            T.Normalize([0.5], [0.5]),
        ])

    @torch.no_grad()
    def encode_image(self, img):
        """Encode a PIL image into a (sampled, scaling-factor-scaled) fp16 VAE latent."""
        image = self.preprocess(img).unsqueeze(0).to(torch.float16).to(self.device)
        latent = self.vae.encode(image).latent_dist.sample()
        return latent * self.vae.config.scaling_factor

    @torch.no_grad()
    def latent_to_image(self, latent, save_path, save_flag=True):
        """Decode ``latent`` with the VAE.

        Saves the image to ``save_path`` and returns None when ``save_flag`` is true;
        otherwise returns the decoded image as an RGB PIL image.
        """
        image = self.vae.decode(latent / self.vae.config.scaling_factor, return_dict=False)[0]
        image = self.pipe.image_processor.postprocess(image, output_type='pil')[0]
        if save_flag:
            image.save(save_path)
        else:
            return image.convert('RGB')

    def select_colors(self, x):
        """Stack the channels named in ``self.colors`` (letters of 'rgb') along the last axis."""
        return np.stack([x[..., 'rgb'.index(c)] for c in self.colors], axis=-1)

    def compute_histogram(self, image, bins):
        """Return the quantised colour histogram of ``image`` over 4096 kept bins.

        ``image`` is an ``(H, W, 3)`` uint8 array or a path to an image file. Each
        channel is quantised to ``bins`` levels and only the channels in
        ``self.colors`` are counted.  In ``'more'`` mode every ``step``-th bin of the
        full ``bins ** len(colors)`` grid is kept (pixels in other bins are not
        counted); otherwise the whole grid is kept.

        Returns:
            ``(hist, indices)``: float counts of the kept bins and their flat
            (C-order) indices in the full grid.
        """
        if isinstance(image, str):
            image = np.array(Image.open(image)).astype(np.uint8)
        quantized_image = (image // (256 // bins)).astype(int)
        quantized_image = self.select_colors(quantized_image)
        n_ch = len(self.colors)
        full_range = [(0, bins)] * n_ch
        full_hist = np.histogramdd(quantized_image.reshape(-1, n_ch),
                                   bins=bins, range=full_range)[0].flatten()
        if self.mode == 'more':
            step = (bins ** n_ch) // 4096
            indices = np.arange(0, (bins ** n_ch), step)
        else:
            indices = np.arange((bins ** n_ch))
        assert indices.shape[0] == 4096
        hist = full_hist[indices]
        return hist, indices

    def compute_cost_matrix(self, indices, bins):
        """Squared-Euclidean cost between the centres of the kept bins.

        Side effect: stores the kept bin centres (in quantised units) in
        ``self.bin_colors``.  Works for any number of selected channels.
        """
        bin_centers = np.arange(0.5, bins, 1)  # centre of each quantisation level
        grids = np.meshgrid(*([bin_centers] * len(self.colors)), indexing='ij')
        full_bin_colors = np.stack([g.ravel() for g in grids], axis=1)
        self.bin_colors = full_bin_colors[indices]
        return ot.dist(self.bin_colors, self.bin_colors, metric='sqeuclidean')

    @staticmethod
    def emb2hist(emb, total_pixels):
        """Convert bin weights into an integer histogram summing to ``total_pixels``.

        Uses the largest-remainder method: each bin gets ``floor(p_i * total)``
        and the leftover counts go, one each, to the bins with the largest
        fractional parts (ties broken by lower index).  Every integer count thus
        differs from its real-valued target by less than 1.

        Args:
            emb: tensor of non-negative bin weights (renormalised here to sum to 1).
            total_pixels: desired histogram total.
        Returns:
            ``(hist_int, hist_raw)`` NumPy arrays: the integer histogram and the
            real-valued target ``emb * total_pixels``.
        """
        weights = emb.detach().cpu().to(torch.float64)
        # renormalise in float64 so low-precision inputs still sum to total_pixels
        hist_raw = (weights / weights.sum() * total_pixels).numpy()
        floor = np.floor(hist_raw)
        hist_int = floor.astype(np.int64)
        remaining = int(total_pixels - hist_int.sum())
        if remaining > 0:
            frac = (hist_raw - floor).ravel()
            order = np.argsort(-frac, kind='stable')
            hist_int.reshape(-1)[order[:remaining]] += 1
        assert hist_int.sum() == total_pixels, 'histogram rounding failed'
        return hist_int, hist_raw

    def _pixels_by_bin(self, image_np, indices):
        """Group flat pixel indices of ``image_np`` by position of their bin in ``indices``.

        Returns one list per kept bin holding the (ascending, row-major) indices of
        the pixels whose quantised colour falls in that bin; pixels in bins that are
        not kept are dropped.
        """
        bins = self.color_bins
        quantized = (image_np // (256 // bins)).astype(int)
        quantized = self.select_colors(quantized.reshape(-1, 3))
        n_ch = quantized.shape[1]
        flat_bin = np.zeros(quantized.shape[0], dtype=np.int64)
        for ch in range(n_ch):  # C-order flat index, first channel slowest
            flat_bin = flat_bin * bins + quantized[:, ch]
        # full-grid bin -> position in ``indices`` (-1 when the bin is not kept)
        lookup = np.full(bins ** n_ch, -1, dtype=np.int64)
        lookup[indices] = np.arange(len(indices))
        position = lookup[flat_bin]
        kept = np.flatnonzero(position >= 0)
        ordered = kept[np.argsort(position[kept], kind='stable')]
        counts = np.bincount(position[kept], minlength=len(indices))
        return [chunk.tolist() for chunk in np.split(ordered, np.cumsum(counts)[:-1])]

    def recolor(self, image_np):
        """Recolour ``image_np`` so its quantised colour histogram follows ``self.emb``.

        Solves an exact OT problem between the image histogram and the integer
        target histogram, then moves the planned number of randomly chosen pixels
        from each source bin to each target bin, painting them with the target
        bin's colour.  Only the channels in ``self.colors`` are modified.

        Side effects: sets ``self.cost_matrix``, ``self.bin_colors``,
        ``self.target_hist`` and ``self.target_hist_raw``; consumes the ``random``
        module's state.

        Args:
            image_np: ``(H, W, 3)`` uint8 RGB array.
        Returns:
            A recoloured copy of ``image_np``.
        """
        source_hist, indices = self.compute_histogram(image_np, bins=self.color_bins)
        self.cost_matrix = self.compute_cost_matrix(indices, bins=self.color_bins)
        self.target_hist, self.target_hist_raw = self.emb2hist(self.emb, int(source_hist.sum()))
        transport_plan = ot.emd(source_hist, self.target_hist, self.cost_matrix, numItermax=500000)

        pixel_bins = self._pixels_by_bin(image_np, indices)
        for bucket in pixel_bins:
            random.shuffle(bucket)

        output_image = image_np.copy()
        part = self.select_colors(image_np.copy())
        flat_part = part.reshape(-1, len(self.colors))  # view: writes land in ``part``
        quantized_colors = (self.bin_colors - 0.5).astype(int)
        scale = 255 / (self.color_bins - 1)
        cursor = [0] * len(pixel_bins)  # pixels already consumed per source bin
        # The EMD plan is sparse; visit only its positive entries, target-major and
        # source-minor, i.e. the same order as a dense double loop would.
        target_ids, source_ids = np.nonzero(transport_plan.T > 0)
        for target_bin, source_bin in zip(target_ids, source_ids):
            if self.target_hist[target_bin] == 0:
                continue
            bucket = pixel_bins[source_bin]
            start = cursor[source_bin]
            if start >= len(bucket):
                continue
            num_pixels = int(round(transport_plan[source_bin, target_bin]))
            if num_pixels <= 0:
                continue
            sampled = bucket[start:start + num_pixels]
            flat_part[sampled] = tuple(int(scale * c) for c in quantized_colors[target_bin])
            cursor[source_bin] = start + num_pixels
        for i, c in enumerate(self.colors):
            output_image[:, :, 'rgb'.index(c)] = part[..., i]
        return output_image

    @torch.no_grad()
    def sampling_loop(self):
        """Run CFG denoising for a perturbed trajectory ``x`` and a baseline ``y``.

        Both trajectories start from the same seeded noise and share one UNet call
        per step (latents ``[x, x, y, y]`` paired with prompt embeddings
        ``[neg, pos, neg, pos]``).  At each step index in ``self.perturb_steps`` the
        predicted clean latent of ``x`` is decoded, recoloured with :meth:`perturb`,
        re-encoded and re-noised to the next timestep with the current noise
        prediction (or kept clean if it is the last step).

        Sets ``self.output_latent`` (x) and ``self.baseline_latent`` (y).
        """
        h = w = self.resolution
        device = self.pipe._execution_device
        prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds \
            = self.pipe.encode_prompt(prompt=self.pos_prompt, negative_prompt=self.neg_prompt, device=device)

        timesteps, self.steps = retrieve_timesteps(
            self.scheduler, self.steps, device
        )

        # Initial noise from an explicitly seeded CPU generator, so the starting
        # latent depends only on ``self.seed``.
        generator = torch.Generator().manual_seed(self.seed)
        latents = torch.randn(1, 4, h // 8, w // 8, generator=generator).to(torch.float16).to(self.device)

        # SDXL micro-conditioning: original size = target size = (h, w), no crop.
        add_text_embeds = pooled_prompt_embeds
        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        add_time_ids = self.pipe._get_add_time_ids(
            (h, w), (0, 0), (h, w), dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        negative_add_time_ids = add_time_ids
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
        add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

        # conditioning batch [neg, pos, neg, pos] matches latents [x, x, y, y]
        prompt_embeds = prompt_embeds.to(torch.float16).to(self.device).repeat(2, 1, 1)
        add_text_embeds = add_text_embeds.to(torch.float16).to(self.device).repeat(2, 1)
        add_time_ids = add_time_ids.to(torch.float16).to(self.device).repeat(2, 1)
        added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

        x, y = latents.clone(), latents.clone()
        with trange(self.steps) as progress_bar:
            for i, t in enumerate(timesteps):
                latent_model_input = torch.cat([x, x, y, y])
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual for both trajectories in one call
                noise_pred = self.unet(
                    latent_model_input, t,
                    encoder_hidden_states=prompt_embeds,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # classifier-free guidance
                x_neg, x_pos, y_neg, y_pos = noise_pred.chunk(4)
                x_pred = x_neg + self.omega * (x_pos - x_neg)
                y_pred = y_neg + self.omega * (y_pos - y_neg)

                if i in self.perturb_steps:
                    # recolour the predicted clean image of the x trajectory
                    z_0 = self.scheduler.step(x_pred, t, x).pred_original_sample
                    curr_img = self.latent_to_image(z_0, None, save_flag=False)
                    perturbed_img = self.perturb(np.array(curr_img).astype(np.uint8))
                    z_0_ = self.encode_image(perturbed_img)
                    if i + 1 < len(self.scheduler.timesteps):
                        # re-noise to the next timestep, reusing the noise prediction
                        t_next = self.scheduler.timesteps[i + 1]
                        x = self.scheduler.add_noise(z_0_, x_pred, t_next)
                    else:
                        # last step: no next timestep, keep the clean recoloured latent
                        x = z_0_
                else:
                    x = self.scheduler.step(x_pred, t, x).prev_sample
                y = self.scheduler.step(y_pred, t, y).prev_sample

                progress_bar.update()

        self.output_latent = x
        self.baseline_latent = y

    def perturb(self, pixels):
        """Recolour an ``(H, W, 3)`` uint8 array and return it as an RGB PIL image."""
        return Image.fromarray(self.recolor(pixels), 'RGB')

    def final_perturb(self, path):
        """Recolour the image at ``path`` and save it as ``final.png`` in ``self.save_dir``.

        Returns:
            The path of the saved image.
        """
        image_np = np.array(Image.open(path)).astype(np.uint8)
        adjusted_image = Image.fromarray(self.recolor(image_np), 'RGB')
        final_path = os.path.join(self.save_dir, 'final.png')
        adjusted_image.save(final_path)
        return final_path

    def final_histogram(self):
        """Recolour ``output.png`` and print how far ``final.png`` deviates from the target."""
        output_path = os.path.join(self.save_dir, 'output.png')
        final_path = self.final_perturb(output_path)
        final_hist, _ = self.compute_histogram(final_path, bins=self.color_bins)
        diff_hist = final_hist - self.target_hist_raw
        deviation = round(np.abs(diff_hist).max(), 4)
        print(f'Max pixel deviation: {deviation}')
        pixels = int(np.abs(diff_hist).sum())
        print(f'Incorrect pixel ratio: {pixels / (self.resolution ** 2):.4f}')

def seed_everything(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU + current CUDA) with ``seed``."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)

def main(args):
    """Generate recoloured and baseline images for every available checkpoint index.

    For each index in ``range(300)`` that has a checkpoint under
    ``../exp/ckpt_test/tok{num_tokens}/`` and no existing output directory, this
    function samples with that checkpoint's colour distribution. It saves
    ``output.png``, ``baseline.png`` and ``final.png`` and prints histogram
    statistics.

    NOTE(review): ``split('\\n')`` keeps a trailing empty entry if the prompt file
    ends with a newline, which changes ``len(prompts)`` — confirm this matches how
    the checkpoints were indexed.
    """
    with open('../data/prompt.txt', 'r') as prompt_file:
        prompts = prompt_file.read().split('\n')

    demo = GenDemo(args)
    demo.load_model()

    for idx in trange(300):
        emb_path = f'../exp/ckpt_test/tok{args.num_tokens}/{idx}.safetensors'
        if not os.path.exists(emb_path):
            continue  # no checkpoint for this index

        seed_everything(args.seed)
        demo.target_hist = None
        demo.pos_prompt = prompts[idx % len(prompts)]
        print(demo.pos_prompt)

        demo.save_dir = f'./images_quant/tok{args.num_tokens}/{idx}'
        if os.path.exists(demo.save_dir):
            continue  # already generated
        os.makedirs(demo.save_dir, exist_ok=True)

        demo.emb = demo.load_emb(emb_path)
        demo.sampling_loop()
        outputs = (
            ('output.png', demo.output_latent),
            ('baseline.png', demo.baseline_latent),
        )
        for file_name, latent in outputs:
            demo.latent_to_image(latent=latent, save_path=os.path.join(demo.save_dir, file_name))
        demo.final_histogram()

def parse_steps(s):
    """Parse a comma-separated list of step indices, e.g. ``'9,19'`` -> ``[9, 19]``."""
    return [int(i) for i in s.split(',')]


def parse_colors(s):
    """Validate ``--colors``: 2 or 3 channel letters, each one of ``'rgb'``."""
    if len(s) not in (2, 3) or any(c not in 'rgb' for c in s):
        raise argparse.ArgumentTypeError(
            f"expected 2 or 3 letters from 'rgb', got {s!r}")
    return s


def build_arg_parser():
    """Build the command-line parser for the generation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_tokens', type=int, default=32)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--output_dir', type=str, default='./output')
    parser.add_argument('--resolution', type=int, default=1024,
                        help='Image resolution for height and width')
    parser.add_argument('--steps', type=int, default=50,
                        help='Number of inference steps')
    parser.add_argument('--perturb_steps', type=parse_steps, default=[9, 19, 29, 39])
    parser.add_argument('--emb_dir', type=str, default='../length/ckpt/tokens2048_x/',
                        help='Path to soft prompt embedding')
    parser.add_argument('--pos_prompt', type=str, default="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
                        help='Positive prompt for generation')
    parser.add_argument('--neg_prompt', type=str, default='worst quality, blurry, NSFW',
                        help='Negative prompt for generation')
    parser.add_argument('--omega', type=float, default=7.5,
                        help='Classifier-free guidance factor')
    parser.add_argument('--colors', type=parse_colors, default='rg',
                        help='Color channels for OT controlling')
    parser.add_argument('--mode', type=str, default='exact', choices=('exact', 'more'),
                        help='Mode for final histogram dimension')
    parser.add_argument('--tau', type=float, default=1.0,
                        help='Temperature for processing the histogram')
    return parser


def resolve_bins(args):
    """Set ``args.color_bins`` / ``args.emb_bins`` from the channel count and mode.

    Every combination yields ``emb_bins ** len(colors) == 4096`` target bins.
    Mutates and returns ``args``.
    """
    if len(args.colors) == 2 and args.mode == 'more':
        args.color_bins, args.emb_bins = 128, 64
    elif len(args.colors) == 2:
        args.color_bins, args.emb_bins = 64, 64
    elif args.mode == 'more':
        args.color_bins, args.emb_bins = 32, 16
    else:
        args.color_bins, args.emb_bins = 16, 16
    return args


if __name__ == "__main__":
    main(resolve_bins(build_arg_parser().parse_args()))