import copy
import hashlib
import json
import os
import shutil
from collections import defaultdict
from time import time

import accelerate
import click
import torch
from PIL import Image, UnidentifiedImageError
from torch.utils.data import DataLoader, Dataset, Subset
from tqdm import tqdm

from misc.scores import get_score
from misc.utils import EasyDict


class HPDv2(Dataset):
    """Map-style dataset over the HPDv2 pairwise human-preference metadata.

    Metadata is read from ``{data_dirs}/{split}.json`` and image paths are
    resolved under ``{data_dirs}/{split}/``. Each item describes one image pair
    and its caption.

    Args:
        split: Split name; selects the metadata file and the image folder.
            ``"train"`` items expose ``human_preference``; ``"test"`` items
            expose ``rank``, ``raw_annotations`` and ``user_hash``. Fields that
            do not apply to the split are returned as empty lists.
        data_dirs: Root directory of the dataset.
        transform: Optional callable applied to each loaded PIL image.
        select_column: Optional list of keys to keep in each item. Images are
            only opened when ``"image0"`` / ``"image1"`` is requested.
        writer: Callable used to report corrupted images.
    """

    def __init__(
        self,
        split="train",
        data_dirs="./data/hpdv2",
        transform=None,
        select_column=None,
        writer=print,
    ):
        self.split = split
        self.data_dirs = data_dirs
        self.transform = transform
        self.select_column = select_column
        self.writer = writer
        # JSON is UTF-8 by spec: do not depend on the locale's default encoding,
        # and close the file instead of leaking the handle.
        with open(os.path.join(data_dirs, f"{split}.json"), encoding="utf-8") as f:
            self.metadata = json.load(f)

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, index):
        """Return the item at ``index``.

        If a requested image cannot be decoded, the failure is reported through
        ``writer`` and the next index (wrapping around) is tried instead, so the
        returned item may describe a different pair than ``index``.

        Raises:
            IndexError: if ``index`` is out of range (this is what ends plain
                iteration over the dataset).
            RuntimeError: if every sample has a corrupted requested image.
        """
        current = index
        # At most one attempt per sample, and at least one attempt so that an
        # out-of-range index still raises IndexError from the metadata lookup.
        for _ in range(max(len(self), 1)):
            item = self._load_item(current)
            if item is not None:
                return item
            current = (current + 1) % len(self)
        raise RuntimeError(
            f"Every sample in the '{self.split}' split has a corrupted image")

    def _load_item(self, index):
        """Build the item for ``index``; return None if a requested image is corrupted."""
        sample = self.metadata[index]
        is_test = self.split == 'test'
        if is_test:
            raw_annotations = [d['annotation'] for d in sample['raw_annotations']]
            user_hash = [d['user_hash'] for d in sample['raw_annotations']]
        else:
            raw_annotations, user_hash = [], []

        item = {
            "caption": sample["prompt"],
            "human_preference": sample["human_preference"] if self.split == 'train' else [],
            "rank": sample['rank'] if is_test else [],
            "raw_annotations": raw_annotations,
            "user_hash": user_hash,
            "uid0": sample['image_path'][0],
            "uid1": sample['image_path'][1],
            "path0": os.path.join(self.data_dirs, self.split, sample['image_path'][0]),
            "path1": os.path.join(self.data_dirs, self.split, sample['image_path'][1]),
        }

        if not self.select_column:
            return item

        # Load image0 before image1, as requested by select_column.
        for image_key, path_key in (("image0", "path0"), ("image1", "path1")):
            if image_key not in self.select_column:
                continue
            try:
                image = Image.open(item[path_key])
            except UnidentifiedImageError as e:
                self.writer(f"Corrupted {image_key} at index {index}: {e}")
                return None
            if self.transform:
                image = self.transform(image)
            item[image_key] = image

        return {key: item[key] for key in self.select_column}


@click.command(context_settings={'show_default': True})
@click.option(
    "--cache", default="./data/cache", type=str,
    help="Directory where the per-caption score cache (JSON lines) is stored."
)
@click.option(
    "--split", default="train", type=str,
    help="The split of the dataset to load."
)
@click.option(
    "--score", default="pickscore", type=click.Choice(["pickscore", "hpsv2"]),
    help="The score to filter the dataset."
)
@click.option(
    "--top", default=1000, type=int,
    help="The number of top images (sorted by score) to retrieve."
)
@click.option(
    "--per_caption/--no-per_caption", default=False,
    help=(
        "Save the top images per caption. Default is to save the top images "
        "overall."
    )
)
@click.option(
    "--preferred/--no-preferred", default=True,
    help=(
        "Only consider preferred images, i.e. images that won at least one "
        "human comparison and never lost one for the same caption."
    )
)
@click.option(
    "--num_proc", default=8, type=int,
    help="The number of DataLoader worker processes used to read the dataset."
)
@click.option(
    "--batch_size", default=64, type=int,
    help="The batch size to use when filtering the dataset."
)
@click.option(
    "--output", default="./data/hpdv2_hpsv2_1k", type=str,
    help="The path where the labels should be saved."
)
def main(**kwargs):
    """Score HPDv2 images and export the top-scoring subset.

    Scores are computed (resuming from the cache when possible) and the
    selected images are then copied into --output, one directory per caption.
    """
    args = EasyDict(kwargs)
    # One cache file per (split, score) combination.
    score_cache_path = os.path.join(
        args.cache, f"hpdv2_{args.split}_{args.score}.jsonl")

    calc_scores(args, score_cache_path)
    save_expert(args, score_cache_path)


def calc_scores(args: EasyDict, score_cache_path: str):
    """Score every HPDv2 image with the selected model and cache the results.

    Scores are appended to ``score_cache_path`` as JSON lines of the form
    ``{"caption": ..., "uid2score": {uid: score, ...}}``. A caption's line is
    written only once all of its unique images have been scored, so an
    interrupted run can be resumed: captions already in the cache are skipped.

    Work is distributed with ``accelerate``; results are gathered and only the
    main process writes the cache. Non-main processes terminate the program
    with ``sys.exit(0)`` when done, so only the main process returns.

    Args:
        args: Parsed CLI options; reads ``split``, ``score``, ``batch_size``
            and ``num_proc``.
        score_cache_path: JSON-lines file holding per-caption scores.
    """
    import sys  # site's `exit` is meant for the REPL; programs use sys.exit

    accelerator = accelerate.Accelerator()
    device = accelerator.device

    # Create the cache directory (dirname is "" for a bare filename)
    cache_dir = os.path.dirname(score_cache_path)
    if accelerator.is_main_process and cache_dir:
        os.makedirs(cache_dir, exist_ok=True)

    # Load the captions already scored by a previous run
    done_captions = set()
    if os.path.exists(score_cache_path):
        accelerator.print(f"Loading done_captions from {score_cache_path} ... ", end="")
        start_time = time()
        with open(score_cache_path) as f:
            for line in f:
                caption_uid2score = json.loads(line)
                done_captions.add(caption_uid2score['caption'])
        elapsed_time = time() - start_time
        accelerator.print(f"{elapsed_time:.2f}s")

    # Load the dataset without images to speed up the process
    dataset = HPDv2(args.split, select_column=['caption', 'uid0', 'uid1'])

    # Map every not-yet-scored caption to the set of its unique image uids
    start_time = time()
    caption2uids = defaultdict(set)
    with tqdm(
        dataset,
        desc="Collecting UIDs",
        disable=not accelerator.is_main_process,
    ) as pbar:
        for example in pbar:
            caption, uid0, uid1 = example['caption'], example['uid0'], example['uid1']
            if caption not in done_captions:
                caption2uids[caption].add(uid0)
                caption2uids[caption].add(uid1)
            pbar.set_postfix_str(
                f"Unique: {len(caption2uids) + len(done_captions)}, "
                f"Todo: {len(caption2uids)}, "
                f"Done: {len(done_captions)}")
    elapsed_time = time() - start_time
    accelerator.print(f"Collecting UIDs ... {elapsed_time:.2f}s")

    # Keep only the pairs that contain an image not covered by an earlier kept
    # pair, so every unique image is covered by at least one kept pair.
    caption2uids_copy = copy.deepcopy(caption2uids)
    indices = []
    with tqdm(
        dataset,
        desc="Filtering Indices",
        disable=not accelerator.is_main_process,
    ) as pbar:
        for i, example in enumerate(pbar):
            caption, uid0, uid1 = example['caption'], example['uid0'], example['uid1']
            keep_example = False
            if caption in caption2uids_copy:
                if uid0 in caption2uids_copy[caption]:
                    keep_example = True
                    caption2uids_copy[caption].remove(uid0)
                if uid1 in caption2uids_copy[caption]:
                    keep_example = True
                    caption2uids_copy[caption].remove(uid1)
                if len(caption2uids_copy[caption]) == 0:
                    del caption2uids_copy[caption]
            if keep_example:
                indices.append(i)
    # Sanity check
    assert len(caption2uids_copy) == 0, f"Remaining captions: {len(caption2uids_copy)}"

    # Exit if there are no more captions to process
    if not indices:
        accelerator.end_training()
        if not accelerator.is_main_process:
            sys.exit(0)
        return

    # Load the score model and image processor
    compute_score, processor = get_score(args.score, device)

    # Load the full dataset and select the remaining indices
    dataset = HPDv2(
        args.split,
        transform=processor,
        select_column=['caption', 'image0', 'image1', 'uid0', 'uid1'])
    subset = Subset(dataset, indices)

    loader = DataLoader(
        subset,
        batch_size=args.batch_size,
        num_workers=args.num_proc,
    )

    # Let the accelerator handle the last batch
    loader = accelerator.prepare(loader)

    # Run inference on the HPDv2 dataset
    rel_err_pass = 0
    caption2uid2score = defaultdict(dict)
    progress_bar = tqdm(
        loader,
        desc="HPDv2",
        disable=not accelerator.is_main_process,
    )
    subset.dataset.writer = progress_bar.write
    for batch in progress_bar:
        captions = batch['caption']
        scores_0 = compute_score(batch['image0'], captions)
        scores_1 = compute_score(batch['image1'], captions)

        # Gather the results; image0 entries come first, then image1 entries
        captions = accelerator.gather_for_metrics(captions, use_gather_object=True)
        captions = captions + captions
        scores0 = accelerator.gather_for_metrics(scores_0)
        scores1 = accelerator.gather_for_metrics(scores_1)
        scores = torch.cat([scores0, scores1])
        uids0 = accelerator.gather_for_metrics(batch['uid0'], use_gather_object=True)
        uids1 = accelerator.gather_for_metrics(batch['uid1'], use_gather_object=True)
        uids = uids0 + uids1

        # Update the caption2uid2score mapping
        if accelerator.is_main_process:
            for caption, score, uid in zip(captions, scores, uids):
                if caption not in caption2uids:
                    continue  # caption already written, or not scheduled
                if uid in caption2uids[caption]:
                    caption2uid2score[caption][uid] = score.item()
                    caption2uids[caption].remove(uid)
                    if len(caption2uids[caption]) == 0:
                        # All images scored: persist the caption and free its entry
                        line = json.dumps({
                            'caption': caption,
                            'uid2score': caption2uid2score.pop(caption),
                        })
                        with open(score_cache_path, 'a') as f:
                            f.write(line + '\n')
                        done_captions.add(caption)
                        caption2uids.pop(caption)
                elif uid in caption2uid2score[caption]:
                    # The image was scored before (it appears in several pairs):
                    # re-scoring it should give (almost) the same value.
                    prev_score = caption2uid2score[caption][uid]
                    err = abs(prev_score - score.item())
                    # Normalize by |prev| (guarded against 0) so a zero or
                    # negative reference can neither crash nor mask an error.
                    rel_err = err / max(abs(prev_score), 1e-12)
                    if rel_err > 0.01:
                        progress_bar.write(
                            f"Large Relative Error: {rel_err:.4f}. "
                            f"UID: {uid}. Caption: {caption}")
                    else:
                        rel_err_pass += 1
            num_done = len(done_captions)
            num_total = len(caption2uids) + num_done
            progress_bar.set_postfix_str(
                f"Finished captions: {num_done}/{num_total}"
                f" ({num_done / num_total:.2%}). "
                f"Relative Error Pass: {rel_err_pass}")
        accelerator.wait_for_everyone()

    progress_bar.close()

    accelerator.end_training()
    if not accelerator.is_main_process:
        sys.exit(0)
    torch.cuda.empty_cache()


def save_expert(args, score_cache_path):
    """Copy the top-scoring HPDv2 images, grouped by caption, into ``args.output``.

    Reads the per-caption scores from ``score_cache_path``. It keeps one of two
    selections:
    - by default, the ``args.top`` highest-scoring images overall;
    - with ``args.per_caption``, the ``args.top`` highest-scoring images of
      each caption.
    At least one image is requested even when ``args.top < 1``.

    With ``args.preferred``, only images that won at least one human comparison
    and never lost one for the same caption are candidates.

    Each caption gets a directory named by the SHA-1 of its text. The directory
    holds ``caption.txt`` and the copied image files.

    Args:
        args: Parsed CLI options; reads ``split``, ``preferred``,
            ``per_caption``, ``top``, ``batch_size``, ``num_proc``, ``output``.
        score_cache_path: JSON-lines file with ``caption`` and ``uid2score``
            fields on each line.

    Raises:
        ValueError: if ``args.preferred`` is set and the preference annotations
            are missing or malformed, or if ``args.per_caption`` is set and a
            caption appears twice in the cache.
    """
    dataset = HPDv2(args.split, select_column=['caption', 'uid0', 'uid1', 'human_preference'])

    # Consider the preference annotations (single pass over the pairs)
    caption2preferred_uids = {}
    if args.preferred:
        caption2won = defaultdict(set)
        caption2lost = defaultdict(set)
        with tqdm(dataset, desc="Collecting Preferred Names") as pbar:
            for example in pbar:
                caption, uid0, uid1 = example['caption'], example['uid0'], example['uid1']
                preference = example['human_preference']
                if not preference:
                    raise ValueError(
                        f"--preferred needs 'human_preference' annotations, which "
                        f"the '{args.split}' split does not provide; "
                        f"rerun with --no-preferred.")
                if preference[0] == 1 and preference[1] == 0:
                    winner, loser = uid0, uid1
                elif preference[0] == 0 and preference[1] == 1:
                    winner, loser = uid1, uid0
                else:
                    raise ValueError(f"Invalid preference: {preference}")
                caption2won[caption].add(winner)
                caption2lost[caption].add(loser)
        # An image that lost any comparison for its caption is not preferred
        caption2preferred_uids = {
            caption: won - caption2lost[caption]
            for caption, won in caption2won.items()
        }

    def is_candidate(caption, uid):
        # Without --preferred every scored image is a candidate
        return not args.preferred or uid in caption2preferred_uids.get(caption, ())

    # Load the score cache and select the target images
    print(f"Loading caption2uid2score from {score_cache_path} ... ", end="")
    start_time = time()
    num_images = 0
    total_score = 0
    caption2uids = defaultdict(set)  # Target caption and its uids
    all_items = []  # (score, uid, caption) candidates; only without --per_caption
    with open(score_cache_path) as f:
        for line in f:
            caption_uid2score = json.loads(line)
            caption = caption_uid2score['caption']
            score_uid = [
                (score, uid)
                for uid, score in caption_uid2score['uid2score'].items()
                if is_candidate(caption, uid)
            ]
            if not args.per_caption:
                all_items.extend((score, uid, caption) for score, uid in score_uid)
                continue
            if caption in caption2uids:
                raise ValueError(f"Duplicate caption: {caption}")
            score_uid.sort(reverse=True)
            selected = score_uid[:max(min(args.top, len(score_uid)), 1)]
            caption2uids[caption] = {uid for _, uid in selected}
            total_score += sum(score for score, _ in selected)
            # Count what was actually selected: a caption may have no candidates
            num_images += len(selected)

    if not args.per_caption:
        all_items.sort(reverse=True)
        selected_items = all_items[:max(min(args.top, len(all_items)), 1)]
        with tqdm(selected_items, desc=f"Highest {len(selected_items)} Image Names") as pbar:
            for score, uid, caption in pbar:
                caption2uids[caption].add(uid)
                total_score += score
        num_images = len(selected_items)
    elapsed_time = time() - start_time
    print(f"{elapsed_time:.2f}s")

    if num_images == 0:
        print("No candidate images found; nothing to save.")
        return
    print(f"Average score: {total_score / num_images:.4f}")

    # Keep only the pairs that contain a still-unclaimed target image
    caption2uids_copy = copy.deepcopy(caption2uids)
    indices = []
    with tqdm(dataset, desc="Filtering Indices") as pbar:
        for i, example in enumerate(pbar):
            caption, uid0, uid1 = example['caption'], example['uid0'], example['uid1']
            keep_example = False
            if caption in caption2uids_copy:
                if uid0 in caption2uids_copy[caption]:
                    keep_example = True
                    caption2uids_copy[caption].remove(uid0)
                if uid1 in caption2uids_copy[caption]:
                    keep_example = True
                    caption2uids_copy[caption].remove(uid1)
                if len(caption2uids_copy[caption]) == 0:
                    del caption2uids_copy[caption]
            if keep_example:
                indices.append(i)
    assert len(caption2uids_copy) == 0, f"Remaining captions: {len(caption2uids_copy)}"

    # Load the file paths of the selected pairs
    dataset = HPDv2(args.split, select_column=['caption', 'path0', 'path1', 'uid0', 'uid1'])
    subset = Subset(dataset, indices)
    loader = DataLoader(
        subset,
        batch_size=args.batch_size,
        num_workers=args.num_proc,
    )

    # Save the images
    with tqdm(loader, desc="Saving Subset Images") as pbar:
        saved_counter = 0
        for batch in pbar:
            for caption, path0, path1, uid0, uid1 in zip(
                batch['caption'],
                batch['path0'], batch['path1'], batch['uid0'], batch['uid1']
            ):
                if caption not in caption2uids:
                    continue
                caption_dir = os.path.join(
                    args.output, hashlib.sha1(caption.encode()).hexdigest())
                # Save the caption (UTF-8, independent of the locale)
                if not os.path.exists(caption_dir):
                    os.makedirs(caption_dir, exist_ok=True)
                    with open(os.path.join(caption_dir, 'caption.txt'), 'w',
                              encoding='utf-8') as f:
                        f.write(caption)
                # Save each target image once (image0 first, then image1)
                for uid, src_path in ((uid0, path0), (uid1, path1)):
                    if uid in caption2uids[caption]:
                        dst_path = os.path.join(caption_dir, uid)
                        if not os.path.exists(dst_path):
                            shutil.copy(src_path, dst_path)
                        caption2uids[caption].remove(uid)
                        saved_counter += 1
            pbar.set_postfix_str(
                f"Saved images: {saved_counter}/{num_images} "
                f"({saved_counter / num_images:.2%})")


if __name__ == '__main__':
    # Invoke the click command `main` when this file is run as a script.
    main()
