import argparse, os, sys, glob
import clip
import torch
import torch.nn as nn
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange, repeat
from torchvision.utils import make_grid
import scann
import time
from multiprocessing import cpu_count

from ldm.util import instantiate_from_config, parallel_data_prefetch
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder

# Names of the retrieval databases this script accepts. A name is validated in
# `Searcher.__init__`, offered as the `choices` of the `--database` CLI option,
# and used as the sub-directory name under `data/rdm/retrieval_databases/` and
# `data/rdm/searchers/`.
DATABASES = [
    "openimages",
    "artbench-art_nouveau",
    "artbench-baroque",
    "artbench-expressionism",
    "artbench-impressionism",
    "artbench-post_impressionism",
    "artbench-realism",
    "artbench-romanticism",
    "artbench-renaissance",
    "artbench-surrealism",
    "artbench-ukiyo_e",
]


def chunk(it, size):
    """Return an iterator over consecutive tuples of up to `size` items from `it`.

    The last tuple may be shorter than `size`. Iteration ends as soon as an
    empty tuple would be produced, i.e. when `it` is exhausted.
    """
    source = iter(it)

    def take_next():
        # Pull the next (at most) `size` items off the shared iterator.
        return tuple(islice(source, size))

    # Two-argument iter(): call `take_next` until it returns the sentinel `()`.
    return iter(take_next, ())


def load_model_from_config(config, ckpt, verbose=False):
    """Build the model described by `config.model` and load weights from `ckpt`.

    Args:
        config: configuration object exposing a `model` entry that
            `instantiate_from_config` understands.
        ckpt: path to a checkpoint file containing a `"state_dict"` entry
            (and optionally `"global_step"`).
        verbose: if true, print the missing / unexpected state-dict keys.

    Returns:
        The instantiated model in eval mode. It is moved to CUDA when a CUDA
        device is available; otherwise it stays on the CPU.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(review): depending on the PyTorch version, torch.load may unpickle
    # arbitrary objects -- only load checkpoints from trusted sources.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)

    # strict=False: tolerate key mismatches and only report them on request.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose and len(missing) > 0:
        print("missing keys:")
        print(missing)
    if verbose and len(unexpected) > 0:
        print("unexpected keys:")
        print(unexpected)

    # Only move to the GPU when one exists, so CPU-only hosts do not crash here.
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model


class Searcher(object):
    """Nearest-neighbour search over a stored embedding database using ScaNN.

    On construction the object loads a retriever model, the `.npz` embedding
    database found under `data/rdm/retrieval_databases/<database>` and a
    serialized ScaNN searcher from `data/rdm/searchers/<database>`.
    """

    def __init__(self, database, retriever_version='ViT-L/14'):
        """Load retriever, database and searcher for `database`.

        Args:
            database: one of the names listed in `DATABASES`.
            retriever_version: model identifier handed to
                `FrozenClipImageEmbedder`.
        """
        assert database in DATABASES, f'Unknown database "{database}", expected one of {DATABASES}'
        self.database_name = database
        self.searcher_savedir = f'data/rdm/searchers/{self.database_name}'
        self.database_path = f'data/rdm/retrieval_databases/{self.database_name}'
        self.retriever = self.load_retriever(version=retriever_version)
        # Placeholder whose keys define which entries are collected from
        # multi-file databases; replaced by arrays in `load_database`.
        self.database = {'embedding': [],
                         'img_id': [],
                         'patch_coords': []}
        # Make sure the attribute always exists; `search` checks it for None.
        self.searcher = None
        self.load_database()
        self.load_searcher()

    def train_searcher(self, k,
                       metric='dot_product',
                       searcher_savedir=None):
        """Build a brute-force ScaNN searcher over the row-normalized embeddings.

        Args:
            k: number of neighbours passed to the ScaNN builder.
            metric: distance measure passed to the ScaNN builder.
            searcher_savedir: if given, directory the built searcher is
                serialized to (created when missing).
        """
        print('Start training searcher')
        embeddings = self.database['embedding']
        # L2-normalize each row before handing the data to ScaNN.
        normalized = embeddings / np.linalg.norm(embeddings, axis=1)[:, np.newaxis]
        searcher = scann.scann_ops_pybind.builder(normalized, k, metric)
        self.searcher = searcher.score_brute_force().build()
        print('Finish training searcher')

        if searcher_savedir is not None:
            print(f'Save trained searcher under "{searcher_savedir}"')
            os.makedirs(searcher_savedir, exist_ok=True)
            self.searcher.serialize(searcher_savedir)

    def load_single_file(self, saved_embeddings):
        """Load every array of one `.npz` archive into `self.database`.

        Args:
            saved_embeddings: path of the `.npz` file.
        """
        # The context manager closes the archive's file handle once all
        # arrays have been read into memory.
        with np.load(saved_embeddings) as compressed:
            self.database = {key: compressed[key] for key in compressed.files}
        print('Finished loading of clip embeddings.')

    def load_multi_files(self, data_archive):
        """Collect the arrays of several opened archives into per-key lists.

        Args:
            data_archive: sequence of objects exposing `.files` and item
                access by key (as returned by `np.load` for `.npz` files).

        Returns:
            dict mapping each key of `self.database` to the list of arrays
            read from the archives, in archive order.
        """
        out_data = {key: [] for key in self.database}
        for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
            for key in d.files:
                out_data[key].append(d[key])

        return out_data

    def load_database(self):
        """Populate `self.database` from the `.npz` files in `self.database_path`.

        A single file is loaded directly; several files are loaded through
        `parallel_data_prefetch` and concatenated per key.

        Raises:
            ValueError: if the directory contains no `.npz` file.
        """
        print(f'Load saved patch embedding from "{self.database_path}"')
        file_content = glob.glob(os.path.join(self.database_path, '*.npz'))

        if len(file_content) == 1:
            self.load_single_file(file_content[0])
        elif len(file_content) > 1:
            data = [np.load(f) for f in file_content]
            prefetched_data = parallel_data_prefetch(self.load_multi_files, data,
                                                     n_proc=min(len(data), cpu_count()), target_data_type='dict')

            # NOTE(review): the axis=1 / [0] combination depends on the layout
            # returned by parallel_data_prefetch -- confirm against that helper.
            self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in
                             self.database}
        else:
            raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?')

        print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.')

    def load_retriever(self, version='ViT-L/14', ):
        """Create a `FrozenClipImageEmbedder` in eval mode, on CUDA if available."""
        model = FrozenClipImageEmbedder(model=version)
        if torch.cuda.is_available():
            model.cuda()
        model.eval()
        return model

    def load_searcher(self):
        """Load the serialized ScaNN searcher from `self.searcher_savedir`."""
        print(f'load searcher for database {self.database_name} from {self.searcher_savedir}')
        self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir)
        print('Finished loading searcher.')

    def search(self, x, k):
        """Return the `k` nearest database entries for each query in `x`.

        Args:
            x: query embeddings as a torch tensor or numpy array. A 3-D input
                is reduced to `x[:, 0]` before searching.
            k: number of neighbours to retrieve per query.

        Returns:
            dict with the keys `nn_embeddings` (neighbour embeddings,
            normalized along the last axis), `img_ids`, `patch_coords`,
            `queries` (the un-normalized queries), `exec_time` (seconds spent
            in the batched search), `nns` (neighbour indices) and
            `q_embeddings` (normalized queries).
        """
        if self.searcher is None and self.database['embedding'].shape[0] < 2e4:
            self.train_searcher(k)   # quickly fit searcher on the fly for small databases
        assert self.searcher is not None, 'Cannot search with uninitialized searcher'
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        if len(x.shape) == 3:
            x = x[:, 0]
        query_embeddings = x / np.linalg.norm(x, axis=1)[:, np.newaxis]

        start = time.time()
        nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k)
        end = time.time()

        # Gather the stored entries that belong to the returned indices.
        out_embeddings = self.database['embedding'][nns]
        out_img_ids = self.database['img_id'][nns]
        out_pc = self.database['patch_coords'][nns]

        out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis],
               'img_ids': out_img_ids,
               'patch_coords': out_pc,
               'queries': x,
               'exec_time': end - start,
               'nns': nns,
               'q_embeddings': query_embeddings}

        return out

    def __call__(self, x, n):
        """Shorthand for `search(x, n)`."""
        return self.search(x, n)


# Script entry point: parse the CLI options, load the diffusion model and the
# CLIP text encoder, optionally retrieve nearest neighbours for each prompt,
# sample images and write them (plus an optional grid) below --outdir.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc)
    # TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt?
    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render"
    )

    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2img-samples"
    )

    parser.add_argument(
        "--skip_grid",
        action='store_true',
        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
    )

    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )

    parser.add_argument(
        "--n_repeat",
        type=int,
        default=1,
        help="number of repeats in CLIP latent space",
    )

    parser.add_argument(
        "--plms",
        action='store_true',
        help="use plms sampling",
    )

    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=1,
        help="sample this often",
    )

    parser.add_argument(
        "--H",
        type=int,
        default=768,
        help="image height, in pixel space",
    )

    parser.add_argument(
        "--W",
        type=int,
        default=768,
        help="image width, in pixel space",
    )

    parser.add_argument(
        "--n_samples",
        type=int,
        default=3,
        help="how many samples to produce for each given prompt. A.k.a batch size",
    )

    parser.add_argument(
        "--n_rows",
        type=int,
        default=0,
        help="rows in the grid (default: n_samples)",
    )

    parser.add_argument(
        "--scale",
        type=float,
        default=5.0,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )

    parser.add_argument(
        "--from-file",
        type=str,
        help="if specified, load prompts from this file",
    )

    parser.add_argument(
        "--config",
        type=str,
        default="configs/retrieval-augmented-diffusion/768x768.yaml",
        help="path to config which constructs model",
    )

    parser.add_argument(
        "--ckpt",
        type=str,
        default="models/rdm/rdm768x768/model.ckpt",
        help="path to checkpoint of model",
    )

    parser.add_argument(
        "--clip_type",
        type=str,
        default="ViT-L/14",
        help="which CLIP model to use for retrieval and NN encoding",
    )
    parser.add_argument(
        "--database",
        type=str,
        default='artbench-surrealism',
        choices=DATABASES,
        help="The database used for the search, only applied when --use_neighbors=True",
    )
    parser.add_argument(
        "--use_neighbors",
        default=False,
        action='store_true',
        help="Include neighbors in addition to text prompt for conditioning",
    )
    parser.add_argument(
        "--knn",
        default=10,
        type=int,
        help="The number of included neighbors, only applied when --use_neighbors=True",
    )

    opt = parser.parse_args()

    config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(config, f"{opt.ckpt}")

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    clip_text_encoder = FrozenCLIPTextEmbedder(opt.clip_type).to(device)

    if opt.plms:
        sampler = PLMSSampler(model)
    else:
        sampler = DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir

    batch_size = opt.n_samples
    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
    if not opt.from_file:
        prompt = opt.prompt
        assert prompt is not None
        # One batch that repeats the single prompt `batch_size` times.
        data = [batch_size * [prompt]]

    else:
        print(f"reading prompts from {opt.from_file}")
        with open(opt.from_file, "r") as f:
            data = f.read().splitlines()
            # One prompt per line, grouped into batches; the last batch may be smaller.
            data = list(chunk(data, batch_size))

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    # Continue numbering after files already present from earlier runs.
    base_count = len(os.listdir(sample_path))
    grid_count = len(os.listdir(outpath)) - 1  # do not count the "samples" directory

    print(f"sampling scale for cfg is {opt.scale:.2f}")

    searcher = None
    if opt.use_neighbors:
        # Use the CLIP variant selected on the command line for retrieval as well.
        searcher = Searcher(opt.database, retriever_version=opt.clip_type)

    with torch.no_grad():
        with model.ema_scope():
            for n in trange(opt.n_iter, desc="Sampling"):
                all_samples = list()
                for prompts in tqdm(data, desc="data"):
                    print("sampling prompts:", prompts)
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)
                    c = clip_text_encoder.encode(prompts)
                    uc = None
                    if searcher is not None:
                        nn_dict = searcher(c, opt.knn)
                        # Append the neighbour embeddings on whatever device the
                        # text embedding lives on (works without CUDA too).
                        c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).to(c.device)], dim=1)
                    if opt.scale != 1.0:
                        # All-zero conditioning serves as the unconditional input for guidance.
                        uc = torch.zeros_like(c)
                    shape = [16, opt.H // 16, opt.W // 16]  # note: currently hardcoded for f16 model
                    samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                                     conditioning=c,
                                                     batch_size=c.shape[0],
                                                     shape=shape,
                                                     verbose=False,
                                                     unconditional_guidance_scale=opt.scale,
                                                     unconditional_conditioning=uc,
                                                     eta=opt.ddim_eta,
                                                     )

                    x_samples_ddim = model.decode_first_stage(samples_ddim)
                    # Map from [-1, 1] to [0, 1] and clip anything outside.
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                    for x_sample in x_samples_ddim:
                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                        Image.fromarray(x_sample.astype(np.uint8)).save(
                            os.path.join(sample_path, f"{base_count:05}.png"))
                        base_count += 1
                    all_samples.append(x_samples_ddim)

                if not opt.skip_grid and all_samples:
                    # additionally, save as grid
                    # Concatenating along the batch axis equals stacking and then
                    # flattening, but also works when the last batch is smaller.
                    grid = torch.cat(all_samples, 0)
                    grid = make_grid(grid, nrow=n_rows)

                    # to image
                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
                    grid_count += 1

    print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")
