import copy
import hashlib
import io
import json
import os
from collections import defaultdict
from functools import partial
from time import time
from typing import Callable, Dict

import accelerate
import click
import torch
from datasets import load_dataset
from datasets.utils.logging import disable_progress_bar
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm

from misc.scores import get_score
from misc.utils import EasyDict


# Command-line entry point. Documented with comments instead of a docstring
# because click would surface a docstring as the command's help text.
@click.command(context_settings=dict(show_default=True))
# General options
@click.option(
    "--cache",
    type=str,
    default="./data/cache",
    help="Cache directory for the `load_dataset`.",
)
@click.option(
    "--version",
    type=click.Choice(["v1"]),
    default="v1",
    help="The version of the open-image-preferences dataset to load.",
)
@click.option(
    "--split",
    type=str,
    default="train",
    help="The split of the dataset to load.",
)
@click.option(
    "--score",
    type=click.Choice(
        ["pickscore", "hpsv2", "aesthetic", "clip", "imagereward"]),
    default="pickscore",
    help="The score to filter the dataset.",
)
@click.option(
    "--top",
    type=int,
    default=500,
    help="The number of top images (sorted by score) to retrieve.",
)
@click.option(
    "--per_prompt/--no-per_prompt",
    default=False,
    help=(
        "Save the top images per prompt. "
        "Default is to save the top images overall."
    ),
)
@click.option(
    "--num_proc",
    type=int,
    default=8,
    help=(
        "The number of processes to use when loading "
        "and filtering the dataset."
    ),
)
@click.option(
    "--batch_size",
    type=int,
    default=64,
    help="The batch size to use when filtering the dataset.",
)
@click.option(
    "--output",
    type=str,
    default="./data/openimagepreferencesv1_pickscore_500",
    help="The path where the labels should be saved.",
)
def main(**kwargs):
    # Wrap the parsed options so they can be read as attributes.
    options = EasyDict(kwargs)

    # One score cache file per (version, split, score) combination, stored
    # inside the cache directory.
    cache_filename = (
        f"openimagepreferences-{options.version}_{options.split}_{options.score}.jsonl")
    score_cache_path = os.path.join(options.cache, cache_filename)

    # First fill the score cache, then export the images selected from it.
    calc_scores(options, score_cache_path)
    save_expert(options, score_cache_path)


def _relative_error(reference: float, value: float) -> float:
    """Return the error of ``value`` relative to the magnitude of ``reference``.

    The denominator is ``abs(reference)`` so that a negative reference score
    still yields a non-negative ratio. A zero reference gives ``0.0`` when both
    values agree exactly and ``inf`` otherwise, instead of raising
    ``ZeroDivisionError``.
    """
    error = abs(reference - value)
    scale = abs(reference)
    if scale == 0:
        return 0.0 if error == 0 else float("inf")
    return error / scale


def calc_scores(args: EasyDict, score_cache_path: str):
    """Score the images of all prompts missing from the cache and append them.

    Every line of ``score_cache_path`` is a JSON object of the form
    ``{"prompt": ..., "uid2score": {uid: score}}``. Prompts already present in
    that file are skipped, and a prompt is appended only once all of its
    ``<id>_chosen`` / ``<id>_rejected`` images have been scored.

    Args:
        args: Parsed options. This function reads ``version``, ``split``,
            ``score``, ``batch_size`` and ``num_proc``.
        score_cache_path: Path of the JSONL score cache to read and extend.

    Raises:
        SystemExit: On every non-main process once its work is finished
            (exit status 0), so only the main process returns to the caller.
        AssertionError: If some pending UID cannot be found in the dataset.
    """
    accelerator = accelerate.Accelerator()
    device = accelerator.device

    # Create the cache directory
    if accelerator.is_main_process:
        cache_dir = os.path.dirname(score_cache_path)
        # `dirname` is empty for a bare file name; `makedirs("")` would raise.
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
    else:
        disable_progress_bar()

    # Load the already done prompts
    done_prompts = set()
    if os.path.exists(score_cache_path):
        accelerator.print(f"Loading done_prompts from {score_cache_path} ... ", end="")
        start_time = time()
        with open(score_cache_path, encoding="utf-8") as f:
            for line in f:
                prompt_uid2score = json.loads(line)
                done_prompts.add(prompt_uid2score['prompt'])
        elapsed_time = time() - start_time
        accelerator.print(f"{elapsed_time:.2f}s")

    # Load the dataset without images to speed up the process
    dataset = load_dataset(
        f"data-is-better-together/open-image-preferences-{args.version}-binarized", split=args.split)
    dataset = dataset.select_columns(['id', 'prompt'])
    dataset = dataset.map(
        lambda examples, indices: {
            "index": indices,
            "uid_chosen": [f"{example_id}_chosen" for example_id in examples['id']],
            "uid_rejected": [f"{example_id}_rejected" for example_id in examples['id']]},
        batched=True,
        with_indices=True,
        batch_size=args.batch_size,
        num_proc=args.num_proc,
        desc="Adding UIDs",
    )

    # Prepare the mapping from prompt to the UIDs that still need a score
    start_time = time()
    prompt2uids = defaultdict(set)
    with tqdm(
        dataset,
        desc="Collecting UIDs",
        leave=False,
        disable=not accelerator.is_main_process,
    ) as pbar:
        for example in pbar:
            prompt = example['prompt']
            if prompt not in done_prompts:
                prompt2uids[prompt].add(example['uid_chosen'])
                prompt2uids[prompt].add(example['uid_rejected'])
            pbar.set_postfix_str(
                f"Unique: {len(prompt2uids) + len(done_prompts)}, "
                f"Todo: {len(prompt2uids)}, "
                f"Done: {len(done_prompts)}")
    elapsed_time = time() - start_time
    accelerator.print(f"Collecting UIDs ... {elapsed_time:.2f}s")

    # Filter the dataset index to read as few rows as possible
    def filter_fn(examples, prompt2uids):
        # Keeps a row only the first time one of its UIDs is seen. The
        # mapping is consumed in place, so it doubles as a "not found" list.
        keep = []
        for prompt, uid_chosen, uid_rejected in zip(
            examples['prompt'],
            examples['uid_chosen'],
            examples['uid_rejected'],
        ):
            keep_example = False
            if prompt in prompt2uids:
                pending = prompt2uids[prompt]
                for uid in (uid_chosen, uid_rejected):
                    if uid in pending:
                        keep_example = True
                        pending.remove(uid)
                if not pending:
                    del prompt2uids[prompt]
            keep.append(keep_example)
        return keep
    # Work on a copy: the original mapping is still needed while scoring.
    prompt2uids_copy = copy.deepcopy(prompt2uids)
    dataset = dataset.filter(
        partial(filter_fn, prompt2uids=prompt2uids_copy),
        batched=True,
        batch_size=args.batch_size,
        load_from_cache_file=False,
        desc="Filtering Indices",
    )
    # Collect the indices that contains the remaining prompts
    indices = sorted(dataset['index'])
    # Sanity check
    assert len(prompt2uids_copy) == 0, f"Remaining prompts: {len(prompt2uids_copy)}"

    # Exit if there are no more prompts to process
    if not indices:
        is_main_process = accelerator.is_main_process
        accelerator.end_training()
        if not is_main_process:
            # `raise SystemExit` does not depend on the `site`-provided
            # `exit` helper being available.
            raise SystemExit(0)
        return

    # Load the score model and image processor
    compute_score, processor = get_score(args.score, device)

    # Load the full dataset
    openimagepreferences = load_dataset(
        f"data-is-better-together/open-image-preferences-{args.version}-binarized",
        split=args.split,
        num_proc=args.num_proc)
    # Select only the necessary columns
    openimagepreferences = openimagepreferences.select_columns(
        ['id', 'prompt', 'chosen', 'rejected'])
    openimagepreferences = openimagepreferences.map(
        lambda examples: {
            "uid_chosen": [f"{example_id}_chosen" for example_id in examples['id']],
            "uid_rejected": [f"{example_id}_rejected" for example_id in examples['id']]},
        batched=True,
        batch_size=args.batch_size,
        num_proc=args.num_proc,
        desc="Adding UIDs",
    )
    # Select the remaining indices
    openimagepreferences = openimagepreferences.select(indices)

    # Decode the stored image bytes with PIL and run them through `processor`
    def transform(examples: Dict[str, list], processor: Callable) -> Dict[str, list]:
        examples['image_chosen'] = []
        examples['image_rejected'] = []
        for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
            # `.copy()` forces the image to be loaded before the buffer closes.
            with io.BytesIO(chosen['bytes']) as chosen_bytes:
                image_chosen = Image.open(chosen_bytes).copy()
            with io.BytesIO(rejected['bytes']) as rejected_bytes:
                image_rejected = Image.open(rejected_bytes).copy()
            examples["image_chosen"].append(processor(image_chosen))
            examples["image_rejected"].append(processor(image_rejected))
        return examples
    openimagepreferences = openimagepreferences.with_transform(partial(transform, processor=processor))

    def collate_fn(items):
        # List of dictionaries to dictionary of lists
        batch = {key: [item[key] for item in items] for key in items[0]}
        # Stack the torch.Tensor objects
        for key in batch:
            if isinstance(batch[key][0], torch.Tensor):
                batch[key] = torch.stack(batch[key])
        return batch
    loader = DataLoader(
        openimagepreferences,
        batch_size=args.batch_size,
        num_workers=args.num_proc,
        collate_fn=collate_fn,
    )

    # Let the accelerator handle the last batch
    loader = accelerator.prepare(loader)

    # Run inference on the open-image-preferences dataset
    rel_err_pass = 0
    prompt2uid2score = defaultdict(dict)
    progress_bar = tqdm(
        loader,
        desc=args.score,
        disable=not accelerator.is_main_process,
    )
    for batch in progress_bar:
        prompts = batch['prompt']
        # NOTE(review): `compute_score` is assumed to return a 1-D tensor with
        # one score per image; confirm against `misc.scores.get_score`.
        scores_chosen = compute_score(batch['image_chosen'], prompts)
        scores_rejected = compute_score(batch['image_rejected'], prompts)

        # Gather the results
        prompts = accelerator.gather_for_metrics(prompts, use_gather_object=True)
        # Repeat the prompts so they line up with [chosen..., rejected...]
        prompts = prompts + prompts
        scores_chosen = accelerator.gather_for_metrics(scores_chosen)
        scores_rejected = accelerator.gather_for_metrics(scores_rejected)
        scores = torch.cat([scores_chosen, scores_rejected])
        uids_chosen = accelerator.gather_for_metrics(batch['uid_chosen'], use_gather_object=True)
        uids_rejected = accelerator.gather_for_metrics(batch['uid_rejected'], use_gather_object=True)
        uids = uids_chosen + uids_rejected

        # Update the prompt2uid2score mapping
        if accelerator.is_main_process:
            for prompt, score, uid in zip(prompts, scores, uids):
                if prompt not in prompt2uids:
                    continue
                pending = prompt2uids[prompt]
                if uid in pending:
                    prompt2uid2score[prompt][uid] = score.item()
                    pending.remove(uid)
                    # Save the prompt if all the uids are processed
                    if not pending:
                        line = json.dumps({
                            'prompt': prompt,
                            # The scores are not read again once the prompt
                            # is done, so release them from memory here.
                            'uid2score': prompt2uid2score.pop(prompt),
                        })
                        # Append right away so an interrupted run keeps
                        # every finished prompt.
                        with open(score_cache_path, 'a', encoding="utf-8") as f:
                            f.write(line + '\n')
                        done_prompts.add(prompt)
                        prompt2uids.pop(prompt)
                elif uid in prompt2uid2score[prompt]:
                    # The same image was scored twice: check that both scores
                    # agree, relative to the magnitude of the first one.
                    rel_err = _relative_error(
                        prompt2uid2score[prompt][uid], score.item())
                    if rel_err > 0.01:
                        progress_bar.write(
                            f"Large Relative Error: {rel_err:.4f}. "
                            f"UID: {uid}. prompt: {prompt}")
                    else:
                        rel_err_pass += 1
            # Refresh the progress text once per batch
            num_done = len(done_prompts)
            num_total = len(prompt2uids) + num_done
            if num_total:
                progress_bar.set_postfix_str(
                    f"Finished prompts: {num_done}/{num_total}"
                    f" ({num_done / num_total:.2%}). "
                    f"Relative Error Pass: {rel_err_pass}")
        accelerator.wait_for_everyone()

    progress_bar.close()

    is_main_process = accelerator.is_main_process
    accelerator.end_training()
    if not is_main_process:
        # Only the main process continues with the export step.
        raise SystemExit(0)
    torch.cuda.empty_cache()


def _select_top_uids(score_cache_path, top, per_prompt):
    """Pick the highest-scoring image UIDs recorded in the score cache.

    Args:
        score_cache_path: JSONL file whose lines are
            ``{"prompt": ..., "uid2score": {uid: score}}`` objects.
        top: Number of images to keep (per prompt or overall). At least one
            image is kept whenever any is available.
        per_prompt: Rank the images within each prompt instead of globally.

    Returns:
        A ``(prompt2uids, num_images, total_score)`` tuple: the selected UIDs
        grouped by prompt, how many images were actually selected, and the
        sum of their scores.
    """
    with open(score_cache_path, encoding="utf-8") as f:
        # Blank lines are skipped so they cannot break the JSON parsing.
        records = [json.loads(line) for line in f if line.strip()]

    prompt2uids = defaultdict(set)  # Target prompt and its uids
    num_images = 0
    total_score = 0
    if per_prompt:
        for record in records:
            prompt = record['prompt']
            assert prompt not in prompt2uids, f"Duplicate prompt: {prompt}"
            ranked = sorted(
                ((score, uid) for uid, score in record['uid2score'].items()),
                reverse=True)
            size = max(min(top, len(ranked)), 1)
            selected = ranked[:size]
            if not selected:
                continue
            prompt2uids[prompt] = {uid for _, uid in selected}
            total_score += sum(score for score, _ in selected)
            # Count what was really selected, not the requested size.
            num_images += len(selected)
    else:
        ranked = sorted(
            ((score, uid, record['prompt'])
             for record in records
             for uid, score in record['uid2score'].items()),
            reverse=True)
        size = max(min(top, len(ranked)), 1)
        selected = ranked[:size]
        for score, uid, prompt in selected:
            prompt2uids[prompt].add(uid)
            total_score += score
        num_images = len(selected)
    return prompt2uids, num_images, total_score


def save_expert(args, score_cache_path):
    """Export the top-scoring images listed in the score cache to ``args.output``.

    Each selected prompt gets a directory named after the SHA-1 hex digest of
    the prompt, containing ``caption.txt`` (the prompt) and one ``<uid>.png``
    per selected image. Images that already exist on disk are not rewritten.

    Args:
        args: Parsed options. This function reads ``per_prompt``, ``top``,
            ``version``, ``split``, ``batch_size``, ``num_proc`` and
            ``output``.
        score_cache_path: JSONL score cache to select the images from.

    Raises:
        AssertionError: If a selected UID cannot be found in the dataset or a
            prompt occurs twice in the cache (``per_prompt`` mode).
    """
    print(f"Loading prompt2uid2score from {score_cache_path} ... ", end="")
    start_time = time()
    prompt2uids, num_images, total_score = _select_top_uids(
        score_cache_path, args.top, args.per_prompt)
    elapsed_time = time() - start_time
    print(f"{elapsed_time:.2f}s")
    if num_images == 0:
        # Nothing was scored: avoid dividing by zero and loading the dataset.
        print("No scored images found; nothing to save.")
        return
    print(f"Average score: {total_score / num_images:.4f}")

    # Prepare the remaining indices
    # Load the dataset without images to speed up the process
    dataset = load_dataset(
        f"data-is-better-together/open-image-preferences-{args.version}-binarized", split=args.split)
    dataset = dataset.select_columns(['id', 'prompt'])
    dataset = dataset.map(
        lambda examples, indices: {
            "index": indices,
            "uid_chosen": [f"{example_id}_chosen" for example_id in examples['id']],
            "uid_rejected": [f"{example_id}_rejected" for example_id in examples['id']]},
        batched=True,
        with_indices=True,
        batch_size=args.batch_size,
        num_proc=args.num_proc,
        desc="Adding UIDs",
    )

    # Filter the dataset index to read as few rows as possible
    def filter_fn(examples, prompt2uids):
        # Keeps a row only the first time one of its selected UIDs is seen.
        # The mapping is consumed in place, so leftovers mean "not found".
        keep = []
        for prompt, uid_chosen, uid_rejected in zip(
            examples['prompt'], examples['uid_chosen'], examples['uid_rejected']
        ):
            keep_example = False
            if prompt in prompt2uids:
                pending = prompt2uids[prompt]
                for uid in (uid_chosen, uid_rejected):
                    if uid in pending:
                        keep_example = True
                        pending.remove(uid)
                if not pending:
                    del prompt2uids[prompt]
            keep.append(keep_example)
        return keep
    # Work on a copy: the original mapping is still needed while saving.
    prompt2uids_copy = copy.deepcopy(prompt2uids)
    dataset = dataset.filter(
        partial(filter_fn, prompt2uids=prompt2uids_copy),
        batched=True,
        batch_size=args.batch_size,
        load_from_cache_file=False,
        desc="Filtering Indices",
    )
    # Collect the indices that contains the remaining prompts
    indices = sorted(dataset['index'])
    # Sanity check
    assert len(prompt2uids_copy) == 0, f"Remaining prompts: {len(prompt2uids_copy)}"

    # Load the full dataset
    openimagepreferences = load_dataset(
        f"data-is-better-together/open-image-preferences-{args.version}-binarized",
        split=args.split,
        num_proc=args.num_proc)
    openimagepreferences = openimagepreferences.select(indices)
    openimagepreferences = openimagepreferences.select_columns(
        ['id', 'prompt', 'chosen', 'rejected'])
    openimagepreferences = openimagepreferences.map(
        lambda examples: {
            "uid_chosen": [f"{example_id}_chosen" for example_id in examples['id']],
            "uid_rejected": [f"{example_id}_rejected" for example_id in examples['id']]},
        batched=True,
        batch_size=args.batch_size,
        num_proc=args.num_proc,
        desc="Adding UIDs",
    )

    def collate_fn(items):
        # List of dictionaries to dictionary of lists
        batch = {key: [item[key] for item in items] for key in items[0]}
        # Stack the torch.Tensor objects
        for key in batch:
            if isinstance(batch[key][0], torch.Tensor):
                batch[key] = torch.stack(batch[key])
        return batch
    loader = DataLoader(
        openimagepreferences,
        batch_size=args.batch_size,
        num_workers=args.num_proc,
        collate_fn=collate_fn,
    )

    # Save the images
    with tqdm(loader, desc="Saving Subset Images") as pbar:
        saved_counter = 0
        for batch in pbar:
            for prompt, uid_chosen, uid_rejected, chosen, rejected in zip(
                batch['prompt'],
                batch['uid_chosen'], batch['uid_rejected'],
                batch['chosen'], batch['rejected'],
            ):
                if prompt not in prompt2uids:
                    continue
                prompt_dir = os.path.join(
                    args.output, hashlib.sha1(prompt.encode()).hexdigest())
                os.makedirs(prompt_dir, exist_ok=True)
                # Save the prompt. Checking the file rather than the directory
                # restores a caption lost by an interrupted run.
                caption_path = os.path.join(prompt_dir, 'caption.txt')
                if not os.path.exists(caption_path):
                    # Explicit UTF-8 so non-ASCII prompts do not depend on
                    # the locale's default encoding.
                    with open(caption_path, 'w', encoding="utf-8") as f:
                        f.write(prompt)
                # Save the images
                for uid, image in ((uid_chosen, chosen), (uid_rejected, rejected)):
                    if uid not in prompt2uids[prompt]:
                        continue
                    image_path = os.path.join(prompt_dir, f"{uid}.png")
                    if not os.path.exists(image_path):
                        with io.BytesIO(image['bytes']) as image_bytes:
                            Image.open(image_bytes).save(image_path)
                    saved_counter += 1
                    # Each selected UID is counted at most once.
                    prompt2uids[prompt].remove(uid)
            # `num_images` is non-zero here because of the early return above.
            pbar.set_postfix_str(
                f"Saved images: {saved_counter}/{num_images} "
                f"({saved_counter / num_images:.2%})")


# Run the click command only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == '__main__':
    main()
