
import os
import ot
import yaml
import random
import shutil
import argparse
import warnings
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm, trange
from safetensors import safe_open
from collections import defaultdict

import torch
import torch.nn as nn
import torchvision.transforms as T
from transformers import logging
from diffusers import DDIMScheduler, StableDiffusionXLPipeline, AutoencoderKL

from sdxl_utils import *

# filter warnings
logging.set_verbosity_error()
warnings.filterwarnings('ignore', message='.*deprecated.*')

class GenDemo(nn.Module):
    """SDXL sampler that steers the colour histogram of its output with optimal transport.

    Two latents are denoised in lockstep from the same initial noise: ``x`` is
    recoloured at the step indices listed in ``perturb_steps`` and ``y`` is an
    unmodified baseline. Recolouring transports a histogram over two colour
    channels (``colors``, packed into a 16-bit code) onto a target histogram
    derived from a learned embedding, using lookup tables loaded from ``.npy``
    files in the current working directory.

    Every command-line option in ``args`` is also exposed as an attribute
    (``self.resolution``, ``self.steps``, ``self.tau``, ``self.colors``, ...).
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        # Expose every CLI option as an attribute.
        for k, v in vars(args).items():
            setattr(self, k, v)
        # Precomputed lookup tables, read from the current working directory.
        # num2rand maps a packed two-channel colour code to its histogram bin
        # (see compute_histogram); rand2num is presumably its inverse and is
        # not used in this class.
        self.rand2num = np.load('indices_rand2num.npy')
        self.num2rand = np.load('indices_num2rand.npy')
        # Cost between source bins and target bins, passed to ot.emd.
        self.cost_matrix = np.load('cost_matrix.npy')
        # target_indices[source_bin, target_bin] picks which colour of
        # bin_to_colors[target_bin] a transported pixel receives.
        self.target_indices = np.load('target_indices.npy')
        self.color_to_bin = np.load('color_to_bin.npy').astype(int)
        self.bin_to_colors = np.load('bin_to_colors.npy').astype(int)
        # Packed codes ((c0 << 8) + c1) for every colour in bin_to_colors.
        self.bin_to_indices = (self.bin_to_colors[..., 0] << 8) + self.bin_to_colors[..., 1]

    def load_emb(self, emb_path):
        """Load a tuned embedding and softmax-normalise it into a probability vector.

        The ``prompt_embeddings`` tensor stored in the safetensors file at
        ``emb_path`` is squeezed, divided by ``self.tau`` and normalised with a
        softmax taken over all of its elements.

        Returns:
            A NumPy array (at least float32) with the squeezed shape, summing to 1.
        """
        with safe_open(emb_path, framework='pt', device='cpu') as f:
            emb = f.get_tensor('prompt_embeddings').squeeze()
        # Promote half/bfloat16 to float32: exp() overflows early in half
        # precision and bfloat16 tensors cannot be converted by .numpy().
        logits = emb.to(torch.promote_types(emb.dtype, torch.float32)) / self.tau
        # Subtracting the max leaves the softmax unchanged but keeps exp() finite.
        probs = (logits - logits.max()).exp()
        return (probs / probs.sum()).numpy()

    def load_model(self, model='stabilityai/stable-diffusion-xl-base-1.0', vae='madebyollin/sdxl-vae-fp16-fix'):
        """Load a DDIM scheduler, the SDXL pipeline and an fp16 VAE onto ``self.device``.

        Args:
            model: Hub id or path of the SDXL checkpoint (pipeline and scheduler).
            vae: Hub id or path of the VAE used for all encoding/decoding.
        """
        self.scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler')
        self.scheduler.set_timesteps(self.steps, device=self.device)
        print(f'[INFO] Loading model: {model}')
        # Load the replacement VAE first and hand it to the pipeline, so the
        # checkpoint's own VAE is not loaded onto the device as well.
        self.vae = AutoencoderKL.from_pretrained(
            vae, torch_dtype=torch.float16
        ).to(self.device)
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            model, vae=self.vae, torch_dtype=torch.float16,
        ).to(self.device)
        self.text_encoder = self.pipe.text_encoder
        self.tokenizer = self.pipe.tokenizer
        self.unet = self.pipe.unet
        self.tokenizer_kwargs = dict(
            truncation=True,
            return_tensors='pt',
            padding='max_length',
            max_length=self.tokenizer.model_max_length
        )
        # Resize to the working resolution and map pixel values to [-1, 1].
        self.preprocess = T.Compose([
            T.Resize((self.resolution, self.resolution)),
            T.ToTensor(),
            T.Normalize([0.5], [0.5]),
        ])

    @torch.no_grad()
    def encode_image(self, img):
        """Encode a PIL image into a scaled float16 VAE latent on ``self.device``.

        The latent is sampled from the VAE's latent distribution, so the result
        is stochastic.
        """
        image = self.preprocess(img).unsqueeze(0).to(torch.float16).to(self.device)
        latent = self.vae.encode(image).latent_dist.sample()
        return latent * self.vae.config.scaling_factor

    @torch.no_grad()
    def latent_to_image(self, latent, save_path, save_flag=True):
        """Decode ``latent`` to a PIL image.

        If ``save_flag`` is true, the image is saved to ``save_path`` and
        ``None`` is returned. Otherwise the image is returned converted to RGB.
        """
        image = self.vae.decode(latent / self.vae.config.scaling_factor, return_dict=False)[0]
        image = self.pipe.image_processor.postprocess(image, output_type='pil')[0]
        if save_flag:
            image.save(save_path)
        else:
            return image.convert('RGB')

    def select_colors(self, x):
        """Stack the channels named in ``self.colors`` (letters of ``'rgb'``) along the last axis."""
        return np.stack([x[..., 'rgb'.index(c)] for c in self.colors], axis=-1)

    def compute_histogram(self, image, reduce=True):
        """Histogram of an image over the two channels in ``self.colors``.

        Each pixel's two selected channel values are packed as ``(c0 << 8) + c1``
        and mapped through ``self.num2rand`` to a bin index.

        Args:
            image: File path, or an H x W x C array whose channels are in r, g, b order.
            reduce: If true, sum consecutive groups of bins into 4096 bins.

        Returns:
            The 4096-bin histogram if ``reduce`` is true. Otherwise a tuple of
            the full ``np.bincount`` histogram (minlength 65536) and the
            flattened per-pixel bin indices.
        """
        if isinstance(image, str):
            with Image.open(image) as img:
                image = np.array(img)
        image = self.select_colors(image.astype(int))
        image = (image[..., 0] << 8) + image[..., 1]
        image = self.num2rand[image].flatten()
        hist = np.bincount(image, minlength=65536)
        if reduce:
            return hist.reshape(4096, -1).sum(axis=1)
        return hist, image

    @staticmethod
    def emb2hist(emb, total_pixels):
        """Scale a probability vector into an integer histogram summing to ``total_pixels``.

        Uses largest-remainder rounding. Every bin is floored, then the pixels
        still missing go one each to the bins with the largest fractional parts.
        In the normal case every count therefore stays within one pixel of its
        unrounded value.

        Returns:
            ``(hist_int, hist_raw)``: the integer histogram and the unrounded
            ``emb * total_pixels``.
        """
        hist = emb * total_pixels
        hist_raw = hist.copy()
        hist_int = np.floor(hist).astype(int)
        frac = hist - hist_int
        diff = int(total_pixels - hist_int.sum())
        if diff > 0:
            # A stable sort breaks ties by bin index, keeping results deterministic.
            order = np.argsort(-frac, kind='stable')
            hist_int[order[:diff]] += 1
        elif diff < 0:
            # Only reachable if float round-off makes emb sum slightly above 1.
            # Take pixels back from non-empty bins with the smallest remainders.
            nonempty = np.flatnonzero(hist_int > 0)
            order = nonempty[np.argsort(frac[nonempty], kind='stable')]
            hist_int[order[:-diff]] -= 1
        assert np.sum(hist_int) == total_pixels
        return hist_int, hist_raw

    @staticmethod
    def _pixels_by_bin(image_flat, counts):
        """Group flat pixel indices by bin, then shuffle each group with ``random``.

        ``counts`` must be ``np.bincount(image_flat)`` padded to the number of
        bins. Within a group, indices start in ascending order before
        ``random.shuffle`` runs on each group in bin order.
        """
        order = np.argsort(image_flat, kind='stable')
        groups = [chunk.tolist() for chunk in np.split(order, np.cumsum(counts)[:-1])]
        for group in groups:
            random.shuffle(group)
        return groups

    def recolor(self, image_np):
        """Recolour an image so its two-channel histogram matches ``self.target_hist``.

        Solves an exact OT plan (``ot.emd``) from the image's histogram to the
        target histogram. Pixels are then moved, in random order, from each
        source bin to the target bins the plan assigns. Each moved pixel takes
        the colour ``bin_to_colors[target_bin][target_indices[source, target]]``.
        Only the channels in ``self.colors`` are rewritten.

        The target histogram is built lazily from ``self.emb`` on first use.

        Args:
            image_np: H x W x C image array with channels in r, g, b order.

        Returns:
            A uint8 copy of ``image_np`` with the selected channels recoloured.
        """
        if getattr(self, 'target_hist', None) is None:
            self.target_hist, self.target_hist_raw = self.emb2hist(self.emb, self.resolution ** 2)
        source_hist, image_flat = self.compute_histogram(image_np, reduce=False)
        transport_plan = ot.emd(source_hist, self.target_hist, self.cost_matrix, numItermax=500000)

        pixel_bins = self._pixels_by_bin(image_flat, source_hist)
        taken = [0] * len(pixel_bins)  # pixels already moved out of each source bin

        output_image = image_np.copy()
        part = self.select_colors(image_np.copy())
        flat_part = part.reshape(-1, len(self.colors))  # view into `part`

        # Visit only the plan's positive entries, ordered by target bin and
        # then source bin. This is the same order as a dense target/source
        # double loop, without touching every cell of the dense plan.
        src_bins, tgt_bins = np.nonzero(transport_plan)
        flows = transport_plan[src_bins, tgt_bins]
        keep = (flows > 0) & (self.target_hist[tgt_bins] != 0)
        src_bins, tgt_bins, flows = src_bins[keep], tgt_bins[keep], flows[keep]
        order = np.lexsort((src_bins, tgt_bins))
        for source_bin, target_bin, flow in zip(
            src_bins[order].tolist(), tgt_bins[order].tolist(), flows[order].tolist()
        ):
            # Plans between integer histograms should be integral. round()
            # guards against float noise that int() alone would truncate.
            num_pixels = int(round(flow))
            start = taken[source_bin]
            sampled_indices = pixel_bins[source_bin][start:start + num_pixels]
            taken[source_bin] = start + num_pixels
            target_color_idx = self.target_indices[source_bin, target_bin]
            flat_part[sampled_indices] = self.bin_to_colors[target_bin][target_color_idx]

        for i, c in enumerate(self.colors):
            output_image[:, :, 'rgb'.index(c)] = part[..., i]
        return output_image.astype(np.uint8)

    @torch.no_grad()
    def sampling_loop(self):
        """Run guided DDIM sampling for a recoloured latent and a baseline latent.

        Both latents start from the same noise (seeded by ``self.seed``) and use
        classifier-free guidance with weight ``self.omega``. At each step index
        in ``self.perturb_steps``:

        1. The clean-image estimate of ``x`` is decoded.
        2. It is recoloured with :meth:`perturb`, then re-encoded.
        3. It is re-noised to the next timestep. On the final step it is kept
           as the clean latent.

        ``y`` is always denoised normally.

        Sets ``self.output_latent`` (``x``) and ``self.baseline_latent`` (``y``).
        ``self.steps`` is overwritten with the value from ``retrieve_timesteps``.
        """
        h = w = self.resolution
        device = self.pipe._execution_device
        (prompt_embeds, negative_prompt_embeds,
         pooled_prompt_embeds, negative_pooled_prompt_embeds) = self.pipe.encode_prompt(
            prompt=self.pos_prompt, negative_prompt=self.neg_prompt, device=device)

        timesteps, self.steps = retrieve_timesteps(self.scheduler, self.steps, device)

        # Seeded CPU generator: the starting noise depends only on self.seed.
        generator = torch.Generator().manual_seed(self.seed)
        latents = torch.randn(1, 4, h // 8, w // 8, generator=generator).to(torch.float16).to(self.device)

        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        add_time_ids = self.pipe._get_add_time_ids(
            (h, w), (0, 0), (h, w), dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        # UNet batch layout is [x_neg, x_pos, y_neg, y_pos]: concatenate
        # (negative, positive) conditioning, then repeat it once per latent.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
        add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(torch.float16).to(self.device).repeat(2, 1, 1)
        add_text_embeds = add_text_embeds.to(torch.float16).to(self.device).repeat(2, 1)
        add_time_ids = add_time_ids.to(torch.float16).to(self.device).repeat(2, 1)
        added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

        x, y = latents.clone(), latents.clone()
        with trange(self.steps) as progress_bar:
            for i, t in enumerate(timesteps):
                latent_model_input = self.scheduler.scale_model_input(torch.cat([x, x, y, y]), t)
                noise_pred = self.unet(
                    latent_model_input, t,
                    encoder_hidden_states=prompt_embeds,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                x_neg, x_pos, y_neg, y_pos = noise_pred.chunk(4)
                x_pred = x_neg + self.omega * (x_pos - x_neg)
                y_pred = y_neg + self.omega * (y_pos - y_neg)

                if i in self.perturb_steps:
                    # Recolour the current clean estimate of x in pixel space.
                    z_0 = self.scheduler.step(x_pred, t, x).pred_original_sample
                    curr_img = self.latent_to_image(z_0, None, save_flag=False)
                    perturbed_img = self.perturb(np.array(curr_img).astype(np.uint8))
                    z_0_ = self.encode_image(perturbed_img)
                    if i + 1 < len(self.scheduler.timesteps):
                        # Re-noise to the next timestep, reusing the guided noise prediction.
                        t_next = self.scheduler.timesteps[i + 1]
                        x = self.scheduler.add_noise(z_0_, x_pred, t_next)
                    else:
                        # Final step: there is no later timestep, so keep the clean latent.
                        x = z_0_
                else:
                    x = self.scheduler.step(x_pred, t, x).prev_sample
                y = self.scheduler.step(y_pred, t, y).prev_sample

                progress_bar.update()

        self.output_latent = x
        self.baseline_latent = y

    def perturb(self, pixels):
        """Recolour a uint8 RGB array with :meth:`recolor` and return it as a PIL image."""
        return Image.fromarray(self.recolor(pixels), 'RGB')

    def final_perturb(self, path):
        """Recolour the image at ``path`` and save it as ``final.png`` in ``self.save_dir``."""
        with Image.open(path) as img:
            image_np = np.array(img).astype(np.uint8)
        adjusted_image = Image.fromarray(self.recolor(image_np), 'RGB')
        adjusted_image.save(os.path.join(self.save_dir, 'final.png'))

    def final_histogram(self):
        """Recolour ``output.png`` into ``final.png`` and print its deviation from the target.

        Prints two metrics, comparing the reduced histogram of ``final.png``
        with ``self.target_hist_raw``:

        - the largest absolute per-bin difference;
        - the summed absolute difference, as a fraction of ``self.resolution ** 2``.
        """
        self.final_perturb(os.path.join(self.save_dir, 'output.png'))
        final_hist = self.compute_histogram(os.path.join(self.save_dir, 'final.png'), reduce=True)
        diff_hist = final_hist - self.target_hist_raw
        deviation = round(np.abs(diff_hist).max(), 4)
        print(f'Max pixel deviation: {deviation}')
        pixels = int(np.abs(diff_hist).sum())
        print(f'Incorrect pixel ratio: {pixels / (self.resolution ** 2):.4f}')

def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random number generators with ``seed``."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

def main(args):
    """Generate recoloured and baseline images for every available embedding checkpoint.

    For each index 0..299 with a checkpoint at
    ``../exp/ckpt_test/tok{num_tokens}/{i}.safetensors``, the prompt is taken
    from ``../data/prompt.txt``, cycling by index. Results are written to
    ``./images_bin/tok{num_tokens}/{i}``. Indices whose output directory
    already exists are skipped.
    """
    with open('../data/prompt.txt', 'r') as prompt_file:
        # NOTE(review): a trailing newline in prompt.txt yields an empty last
        # entry, which then takes part in the idx % len(prompts) cycle.
        prompts = prompt_file.read().split('\n')
    demo = GenDemo(args)
    demo.load_model()
    ckpt_root = f'../exp/ckpt_test/tok{args.num_tokens}'
    out_root = f'./images_bin/tok{args.num_tokens}'
    for idx in trange(300):
        emb_path = f'{ckpt_root}/{idx}.safetensors'
        if not os.path.exists(emb_path):
            continue
        seed_everything(args.seed)
        demo.target_hist = None  # force the target histogram to be rebuilt from the new embedding
        demo.pos_prompt = prompts[idx % len(prompts)]
        print(demo.pos_prompt)
        demo.save_dir = f'{out_root}/{idx}'
        if os.path.exists(demo.save_dir):
            continue
        os.makedirs(demo.save_dir, exist_ok=True)
        demo.emb = demo.load_emb(emb_path)
        demo.sampling_loop()
        for latent, filename in ((demo.output_latent, 'output.png'),
                                 (demo.baseline_latent, 'baseline.png')):
            demo.latent_to_image(latent=latent, save_path=os.path.join(demo.save_dir, filename))
        demo.final_histogram()
        

def parse_steps(s):
    """Parse a comma-separated list of step indices, e.g. ``'10,20'`` -> ``[10, 20]``.

    Empty entries are ignored, so ``''`` gives ``[]`` and a trailing comma is allowed.
    """
    return [int(tok) for tok in s.split(',') if tok.strip()]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_tokens', type=int, default=32)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--output_dir', type=str, default='./output')
    parser.add_argument('--resolution', type=int, default=1024,
                        help='Image resolution for height and width')
    parser.add_argument('--steps', type=int, default=50,
                        help='Number of inference steps')
    parser.add_argument('--perturb_steps', type=parse_steps, default=[],
                        help='Comma-separated denoising step indices at which to recolour the latent')
    parser.add_argument('--emb_dir', type=str, default='../length/ckpt/tokens2048_x/',
                        help='Path to soft prompt embedding')
    parser.add_argument('--pos_prompt', type=str, default="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
                        help='Positive prompt for generation')
    parser.add_argument('--neg_prompt', type=str, default='worst quality, blurry, NSFW',
                        help='Negative prompt for generation')
    parser.add_argument('--omega', type=float, default=7.5,
                        help='Classifier-free guidance factor')
    parser.add_argument('--colors', type=str, default='rg',
                        help='Color channels for OT controlling')
    parser.add_argument('--tau', type=float, default=1.0,
                        help='Temperature for processing the histogram')
    args = parser.parse_args()
    main(args)