import copy
import hashlib
import io
import json
import os
from collections import defaultdict
from functools import partial
from time import time
from typing import Callable, Dict

import accelerate
import click
import torch
from datasets import load_dataset
from datasets.utils.logging import disable_progress_bar
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm

from misc.scores import get_score
from misc.utils import EasyDict


@click.command(context_settings={'show_default': True})
# General options
@click.option(
    "--cache", default="./data/cache", type=str,
    help=(
        "Directory where the per-image score cache (JSON lines) is stored. "
        "It is not passed to `load_dataset`."
    )
)
@click.option(
    "--version", default="v2", type=click.Choice(["v1", "v2"]),
    help="The version of the Pick-a-Pic dataset to load."
)
@click.option(
    "--split", default="train", type=str,
    help="The split of the dataset to load."
)
@click.option(
    "--score", default="pickscore", type=click.Choice([
        "pickscore", "hpsv2", "aesthetic", "clip", "imagereward"]),
    help="The score to filter the dataset."
)
@click.option(
    "--top", default=1000, type=int,
    help=(
        "The number of top images (sorted by score) to retrieve; counted per "
        "caption when --per_caption is set."
    )
)
@click.option(
    "--per_caption/--no-per_caption", default=False,
    help=(
        "Save the top images per caption. Default is to save the top images "
        "overall."
    )
)
@click.option(
    "--num_proc", default=8, type=int,
    help=(
        "The number of processes to use when loading and filtering the "
        "dataset."
    )
)
@click.option(
    "--batch_size", default=64, type=int,
    help="The batch size to use when filtering the dataset."
)
@click.option(
    "--output", default="./data/pickapicv2_pickscore_1k", type=str,
    help="The path where the labels should be saved."
)
def main(**kwargs):
    """Score Pick-a-Pic images and export the top-scoring subset."""
    # Wrap the parsed CLI options so they can be read as attributes
    # (args.cache, args.version, ...).
    args = EasyDict(kwargs)
    # One cache file per (version, split, score) combination, so an
    # interrupted run can resume and different scorers do not collide.
    score_cache_path = os.path.join(
        args.cache, f"pickapic{args.version}_{args.split}_{args.score}.jsonl")

    # calc_scores ends non-main accelerate processes with SystemExit, so the
    # export step below only runs on the main process.
    calc_scores(args, score_cache_path)
    save_expert(args, score_cache_path)


def calc_scores(args: EasyDict, score_cache_path: str):
    accelerator = accelerate.Accelerator()
    device = accelerator.device

    # Create the cache directory
    if accelerator.is_main_process:
        cache_dir = os.path.dirname(score_cache_path)
        os.makedirs(cache_dir, exist_ok=True)
    else:
        disable_progress_bar()

    # Load the already done captions
    done_captions = set()
    if os.path.exists(score_cache_path):
        accelerator.print(f"Loading done_captions from {score_cache_path} ... ", end="")
        start_time = time()
        with open(score_cache_path) as f:
            for line in f:
                caption_uid2score = json.loads(line)
                done_captions.add(caption_uid2score['caption'])
        elapsed_time = time() - start_time
        accelerator.print(f"{elapsed_time:.2f}s")

    # Load the dataset without images to speed up the process
    dataset = load_dataset(
        f"yuvalkirstain/pickapic_{args.version}_no_images", split=args.split)
    dataset = dataset.select_columns(['caption', 'image_0_uid', 'image_1_uid'])

    # Prepare the mapping from caption to uids
    start_time = time()
    caption2uids = defaultdict(set)
    with tqdm(
        dataset,
        desc="Collecting UIDs",
        leave=False,
        disable=not accelerator.is_main_process,
    ) as pbar:
        for example in pbar:
            caption, uid0, uid1 = example['caption'], example['image_0_uid'], example['image_1_uid']
            if caption not in done_captions:
                caption2uids[caption].add(uid0)
                caption2uids[caption].add(uid1)
            pbar.set_postfix_str(
                f"Unique: {len(caption2uids) + len(done_captions)}, "
                f"Todo: {len(caption2uids)}, "
                f"Done: {len(done_captions)}")
    elapsed_time = time() - start_time
    accelerator.print(f"Collecting UIDs ... {elapsed_time:.2f}s")

    # Add index to dataset for subset selection
    dataset = dataset.map(
        lambda examples, indices: {"index": indices},
        with_indices=True,
        batched=True,
        batch_size=args.batch_size,
        num_proc=args.num_proc,
        desc="Adding Indices",
    )

    # Filter the dataset index to read as few rows as possible
    def filter_fn(examples, caption2uids):
        keep = []
        for caption, uid0, uid1 in zip(
            examples['caption'], examples['image_0_uid'], examples['image_1_uid']
        ):
            keep_example = False
            if caption in caption2uids:
                if uid0 in caption2uids[caption]:
                    keep_example = True
                    caption2uids[caption].remove(uid0)
                if uid1 in caption2uids[caption]:
                    keep_example = True
                    caption2uids[caption].remove(uid1)
                if len(caption2uids[caption]) == 0:
                    del caption2uids[caption]
            keep.append(keep_example)
        return keep
    caption2uids_copy = copy.deepcopy(caption2uids)
    dataset = dataset.filter(
        partial(filter_fn, caption2uids=caption2uids_copy),
        batched=True,
        batch_size=args.batch_size,
        load_from_cache_file=False,
        desc="Filtering Indices",
    )
    # Collect the indices that contains the remaining captions
    indices = sorted(dataset['index'])
    # Sanity check
    assert len(caption2uids_copy) == 0, f"Remaining captions: {len(caption2uids_copy)}"

    # Exit if there are no more captions to process
    if len(indices) == 0:
        if not accelerator.is_main_process:
            accelerator.end_training()
            exit(0)
        else:
            accelerator.end_training()
            return

    # Load the score model and image processor
    compute_score, processor = get_score(args.score, device)

    # Load the full dataset
    pickapic = load_dataset(
        f"yuvalkirstain/pickapic_{args.version}",
        split=args.split,
        num_proc=args.num_proc)
    # Select only the necessary columns
    pickapic = pickapic.select_columns(
        ['caption', 'image_0_uid', 'image_1_uid', 'jpg_0', 'jpg_1'])
    # Select the remaining indices
    pickapic = pickapic.select(indices)

    # Transform jpg bytes to PIL.Image and then to torch.Tensor
    def transform(examples: Dict[str, list], processor: Callable) -> Dict[str, list]:
        examples['image_0'] = []
        examples['image_1'] = []
        for jpg_0, jpg_1 in zip(examples["jpg_0"], examples["jpg_1"]):
            with io.BytesIO(jpg_0) as jpg_0_bytes:
                image_0 = Image.open(jpg_0_bytes).copy()
            with io.BytesIO(jpg_1) as jpg_1_bytes:
                image_1 = Image.open(jpg_1_bytes).copy()
            examples["image_0"].append(processor(image_0))
            examples["image_1"].append(processor(image_1))
        return examples
    pickapic = pickapic.with_transform(partial(transform, processor=processor))

    def collat_fn(items):
        # List of dictionaries to dictionary of lists
        batch = {key: [item[key] for item in items] for key in items[0]}
        # Stack the torch.Tensor objects
        for key in batch:
            if isinstance(batch[key][0], torch.Tensor):
                batch[key] = torch.stack(batch[key])
        return batch
    loader = DataLoader(
        pickapic,
        batch_size=args.batch_size,
        num_workers=args.num_proc,
        collate_fn=collat_fn,
    )

    # Let the accelerator handle the last batch
    loader = accelerator.prepare(loader)

    # Run inference on the Pickapic dataset
    rel_err_pass = 0
    caption2uid2score = defaultdict(dict)
    progress_bar = tqdm(
        loader,
        desc=args.score,
        disable=not accelerator.is_main_process,
    )
    for batch in progress_bar:
        captions = batch['caption']
        scores_0 = compute_score(batch['image_0'], captions)
        scores_1 = compute_score(batch['image_1'], captions)

        # Gather the results
        captions = accelerator.gather_for_metrics(captions, use_gather_object=True)
        captions = captions + captions
        scores_0 = accelerator.gather_for_metrics(scores_0)
        scores_1 = accelerator.gather_for_metrics(scores_1)
        scores = torch.cat([scores_0, scores_1])
        uids0 = accelerator.gather_for_metrics(batch['image_0_uid'], use_gather_object=True)
        uids1 = accelerator.gather_for_metrics(batch['image_1_uid'], use_gather_object=True)
        uids = uids0 + uids1

        # Update the caption2uid2score mapping
        if accelerator.is_main_process:
            for caption, score, uid in zip(captions, scores, uids):
                if caption in caption2uids:
                    if uid in caption2uids[caption]:
                        caption2uid2score[caption][uid] = score.item()
                        caption2uids[caption].remove(uid)
                        # Save the caption if all the uids are processed
                        if len(caption2uids[caption]) == 0:
                            line = json.dumps({
                                'caption': caption,
                                'uid2score': caption2uid2score[caption],
                            })
                            with open(score_cache_path, 'a') as f:
                                f.write(line + '\n')
                            done_captions.add(caption)
                            caption2uids.pop(caption)
                    elif uid in caption2uid2score[caption]:
                        # Check the relative error
                        err = abs(caption2uid2score[caption][uid] - score.item())
                        rel_err = err / caption2uid2score[caption][uid]
                        if rel_err > 0.01:
                            progress_bar.write(
                                f"Large Relative Error: {rel_err:.4f}. "
                                f"UID: {uid}. Caption: {caption}")
                        else:
                            rel_err_pass += 1
                num_done = len(done_captions)
                num_total = len(caption2uids) + num_done
                progress_bar.set_postfix_str(
                    f"Finished captions: {num_done}/{num_total}"
                    f" ({num_done / num_total:.2%}). "
                    f"Relative Error Pass: {rel_err_pass}")
        accelerator.wait_for_everyone()

    progress_bar.close()

    if not accelerator.is_main_process:
        accelerator.end_training()
        exit(0)
    else:
        accelerator.end_training()
        torch.cuda.empty_cache()
        return


def save_expert(args, score_cache_path):
    dataset = load_dataset(
        f"yuvalkirstain/pickapic_{args.version}_no_images", split=args.split)
    dataset = dataset.select_columns([
        'caption', 'has_label', 'best_image_uid', 'image_0_uid', 'image_1_uid'
    ])

    print(f"Loading caption2uid2score from {score_cache_path} ... ", end="")
    start_time = time()
    num_images = 0
    scores = 0
    caption2uids = defaultdict(set)  # Target caption and its uids
    if args.per_caption:
        with open(score_cache_path) as f:
            for line in f:
                caption_uid2score = json.loads(line)
                caption = caption_uid2score['caption']
                uid2score = caption_uid2score['uid2score']
                assert caption not in caption2uids, f"Duplicate caption: {caption}"
                score_uid = [(score, uid) for uid, score in uid2score.items()]
                score_uid = sorted(score_uid, reverse=True)
                size = max(min(args.top, len(score_uid)), 1)
                caption2uids[caption] = {uid for _, uid in score_uid[:size]}
                scores += sum(score for score, _ in score_uid[:size])
                num_images += size
    else:
        all_items = []
        with open(score_cache_path) as f:
            for line in f:
                caption_uid2score = json.loads(line)
                caption = caption_uid2score['caption']
                uid2score = caption_uid2score['uid2score']
                score_uid_caption = [(score, uid, caption) for uid, score in uid2score.items()]
                all_items.extend(score_uid_caption)
        all_items = sorted(all_items, reverse=True)

        # Prepare the highest scoring images
        size = max(min(args.top, len(all_items)), 1)
        with tqdm(all_items[:size], desc=f"Highest {size} Image UIDs", leave=False) as pbar:
            for score, uid, caption in pbar:
                caption2uids[caption].add(uid)
                scores += score
        num_images = size
    elapsed_time = time() - start_time
    print(f"{elapsed_time:.2f}s")
    print(f"Average score: {scores / num_images:.4f}")

    # Prepare the remaining indices
    # Load the dataset without images to speed up the process
    dataset = load_dataset(
        f"yuvalkirstain/pickapic_{args.version}_no_images", split=args.split)
    dataset = dataset.select_columns(['caption', 'image_0_uid', 'image_1_uid'])
    dataset = dataset.map(
        lambda examples, indices: {"index": indices},
        with_indices=True,
        batched=True,
        batch_size=args.batch_size,
        num_proc=args.num_proc,
        desc="Adding Indices",
    )

    # Filter the dataset index to read as few rows as possible
    def filter_fn(examples, caption2uids):
        keep = []
        for caption, uid0, uid1 in zip(
            examples['caption'], examples['image_0_uid'], examples['image_1_uid']
        ):
            keep_example = False
            if caption in caption2uids:
                if uid0 in caption2uids[caption]:
                    keep_example = True
                    caption2uids[caption].remove(uid0)
                if uid1 in caption2uids[caption]:
                    keep_example = True
                    caption2uids[caption].remove(uid1)
                if len(caption2uids[caption]) == 0:
                    del caption2uids[caption]
            keep.append(keep_example)
        return keep
    caption2uids_copy = copy.deepcopy(caption2uids)
    dataset = dataset.filter(
        partial(filter_fn, caption2uids=caption2uids_copy),
        batched=True,
        batch_size=args.batch_size,
        load_from_cache_file=False,
        desc="Filtering Indices",
    )
    # Collect the indices that contains the remaining captions
    indices = sorted(dataset['index'])
    # Sanity check
    assert len(caption2uids_copy) == 0, f"Remaining captions: {len(caption2uids_copy)}"

    # Load the full dataset
    pickapic = load_dataset(
        f"yuvalkirstain/pickapic_{args.version}",
        split=args.split,
        num_proc=args.num_proc)
    pickapic = pickapic.select(indices)
    pickapic = pickapic.select_columns(
        ['caption', 'image_0_uid', 'image_1_uid', 'jpg_0', 'jpg_1'])
    loader = DataLoader(
        pickapic,
        batch_size=args.batch_size,
        num_workers=args.num_proc,
    )

    # Save the images
    with tqdm(loader, desc="Saving Subset Images") as pbar:
        saved_counter = 0
        for batch in pbar:
            for caption, uid0, uid1, jpg0, jpg1 in zip(
                batch['caption'],
                batch['image_0_uid'], batch['image_1_uid'],
                batch['jpg_0'], batch['jpg_1'],
            ):
                if caption in caption2uids:
                    caption_dir = os.path.join(
                        args.output, hashlib.sha1(caption.encode()).hexdigest())
                    # Save the caption
                    if not os.path.exists(caption_dir):
                        os.makedirs(caption_dir, exist_ok=True)
                        with open(os.path.join(caption_dir, 'caption.txt'), 'w') as f:
                            f.write(caption)
                    # Save the images
                    if uid0 in caption2uids[caption]:
                        image_path = os.path.join(caption_dir, f"{uid0}.png")
                        if not os.path.exists(image_path):
                            with io.BytesIO(jpg0) as jpg_bytes:
                                Image.open(jpg_bytes).save(image_path)
                        saved_counter += 1
                        caption2uids[caption].remove(uid0)
                    if uid1 in caption2uids[caption]:
                        image_path = os.path.join(caption_dir, f"{uid1}.png")
                        if not os.path.exists(image_path):
                            with io.BytesIO(jpg1) as jpg_bytes:
                                Image.open(jpg_bytes).save(image_path)
                        saved_counter += 1
                        caption2uids[caption].remove(uid1)
                pbar.set_postfix_str(
                    f"Saved images: {saved_counter}/{num_images} "
                    f"({saved_counter / num_images:.2%})")


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
