import hashlib
import json
import os
import shutil
from glob import glob

import accelerate
import click
import torch
from PIL import Image, UnidentifiedImageError
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from misc.scores import get_score
from misc.utils import EasyDict


class HPDv2Benchmark(Dataset):
    """Pair the HPDv2 benchmark prompts of one style with one model's images.

    Images are the ``*.jpg`` files of ``image_dir`` in lexicographic path
    order. The i-th prompt is paired with the i-th image, and images beyond
    the number of prompts are ignored.

    Args:
        prompt_path: JSON file whose top-level value is the sequence of
            prompts, indexed by position.
        image_dir: Directory containing the ``*.jpg`` images.
        transform: Optional callable applied to each opened image.
        writer: Callable used to report corrupted images (default ``print``).

    Raises:
        ValueError: If ``image_dir`` holds fewer images than there are prompts.
    """

    def __init__(
        self,
        prompt_path: str,
        image_dir: str,
        transform=None,
        writer=print,
    ):
        self.transform = transform
        self.writer = writer

        # Use a context manager so the prompt file handle is closed promptly.
        with open(prompt_path, encoding="utf-8") as f:
            self.prompts = json.load(f)
        self.image_paths = sorted(glob(os.path.join(image_dir, "*.jpg")))

        if len(self.prompts) > len(self.image_paths):
            raise ValueError(
                f"Number of prompts ({len(self.prompts)}) is greater than the "
                f"number of images ({len(self.image_paths)}). Please check the "
                f"prompt file: {prompt_path} and the image directory: "
                f"{image_dir}")

        # Keep exactly one image per prompt.
        self.image_paths = self.image_paths[:len(self.prompts)]

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, index):
        """Return ``{"prompt", "image", "path"}`` for ``index``.

        If the image cannot be identified by PIL, the problem is reported via
        ``self.writer`` and the next index is tried, wrapping around at the
        end. The returned pair then belongs to that other index.

        Raises:
            IndexError: If ``index`` is out of range (including an empty
                dataset). Python's legacy iteration protocol relies on this.
            RuntimeError: If no image in the dataset can be identified.
        """
        num_items = len(self)
        current = index
        # Try each item at most once. The old recursive fallback could hit
        # RecursionError, or loop forever when every image was corrupt.
        # max(..., 1) keeps the IndexError for an empty dataset.
        for _ in range(max(num_items, 1)):
            prompt = self.prompts[current]
            image_path = self.image_paths[current]
            try:
                image = Image.open(image_path)
            except UnidentifiedImageError as e:
                self.writer(f"Corrupted image at index {current}: {e}")
                current = (current + 1) % num_items
                continue
            if self.transform:
                image = self.transform(image)

            return {
                "prompt": prompt,
                "image": image,
                "path": image_path,
            }
        raise RuntimeError(
            f"None of the {num_items} images could be opened; all of them "
            f"appear to be corrupted.")


_STYLES = ("anime", "concept-art", "paintings", "photo")


def _read_json(path):
    """Load and return the JSON document stored at ``path``, closing the file."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _write_json_atomic(path, obj):
    """Write ``obj`` as JSON to ``path`` through a temp file and a rename.

    Because of the rename, an interrupted run never leaves a truncated cache
    file that would make a later ``json.load`` fail.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w', encoding="utf-8") as f:
        json.dump(obj, f)
    os.replace(tmp_path, path)


def _compute_scores(accelerator, dataset, compute_score, batch_size,
                    num_workers, desc):
    """Score every item of ``dataset`` using all accelerate processes.

    Returns the gathered per-item scores as a Python list. Every process
    receives the full list, because ``gather_for_metrics`` is an all-gather.
    """
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    # The prepared loader shards batches across processes. gather_for_metrics
    # drops the duplicates accelerate adds to even out the last batch.
    loader = accelerator.prepare(loader)

    scores = []
    for batch in tqdm(
        loader,
        desc=desc,
        disable=not accelerator.is_main_process,
        leave=False,
    ):
        batch_scores = compute_score(batch['image'], batch['prompt'])
        # The gather is a collective op and already synchronizes processes,
        # so no extra per-batch barrier is needed.
        batch_scores = accelerator.gather_for_metrics(batch_scores).cpu()
        scores.extend(batch_scores.tolist())
    return scores


def _export_best_images(output_root, prompts, model_names, model_paths,
                        model_scores):
    """Copy the highest-scoring model's image for every prompt.

    Layout: ``output_root/<sha1(prompt)>/`` holds ``{index}{ext}`` (the copied
    image), ``caption.txt`` (the prompt) and ``{index}.json`` (the model name
    and its score). Existing files are left untouched, so a re-run resumes an
    interrupted export.

    Args:
        output_root: Destination directory for one style.
        prompts: Prompt strings, indexed like each per-model image list.
        model_names: Model names, indexed like ``model_paths``.
        model_paths: One list of image paths per model.
        model_scores: Tensor of shape (num_models, num_prompts).
    """
    # Plain ints, so they index Python lists directly.
    best_model_index = model_scores.argmax(dim=0).tolist()
    pbar = tqdm(total=len(prompts), desc="Saving")
    for image_index, (model_index, prompt) in enumerate(
            zip(best_model_index, prompts)):
        source_path = model_paths[model_index][image_index]
        ext = os.path.splitext(source_path)[1]
        dirname = hashlib.sha1(prompt.encode()).hexdigest()
        output_dir = os.path.join(output_root, dirname)
        os.makedirs(output_dir, exist_ok=True)
        # Copy the image if it doesn't exist
        target_path = os.path.join(output_dir, f"{image_index}{ext}")
        if not os.path.exists(target_path):
            shutil.copy(source_path, target_path)
        # Write the caption if it doesn't exist
        caption_path = os.path.join(output_dir, "caption.txt")
        if not os.path.exists(caption_path):
            with open(caption_path, 'w', encoding="utf-8") as f:
                f.write(prompt)
        # Write the metadata if it doesn't exist
        metadata_path = os.path.join(output_dir, f"{image_index}.json")
        if not os.path.exists(metadata_path):
            with open(metadata_path, 'w', encoding="utf-8") as f:
                json.dump({
                    "model": model_names[model_index],
                    "score": model_scores[model_index, image_index].item(),
                }, f)
        pbar.update(1)
    pbar.close()


@click.command(context_settings={'show_default': True})
@click.option(
    "--cache", default="./data/cache", type=str,
    help="Directory where per-model score files (JSON) are cached."
)
@click.option(
    "--hpdv2_root", default="./data/hpdv2", type=str,
    help="The root directory of the HPDv2 dataset."
)
@click.option(
    "--output", default="./data/hpdv2_{style}", type=str,
    help=(
        "Output directory template for the selected images, captions and "
        "metadata; `{style}` is replaced by the style name."
    )
)
@click.option(
    "--num_proc", default=8, type=int,
    help="The number of DataLoader worker processes used to load images."
)
@click.option(
    "--batch_size", default=64, type=int,
    help="The batch size to use when scoring images."
)
@click.option(
    "--score", required=True, type=str,
    help="Name of the scoring model, passed to `misc.scores.get_score`."
)
def main(**kwargs):
    """Select, for every HPDv2 benchmark prompt, the best model's image.

    For each style, every model's images are scored. Scores are cached per
    (score, model, style). The highest-scoring image for each prompt is then
    exported by the main process.
    """
    args = EasyDict(kwargs)

    accelerator = accelerate.Accelerator()
    device = accelerator.device

    # Create the cache directory (only the main process ever writes to it).
    if accelerator.is_main_process:
        os.makedirs(args.cache, exist_ok=True)

    # Each sub-directory of benchmark_imgs holds one model's images.
    images_root = os.path.join(args.hpdv2_root, "benchmark/benchmark_imgs")
    model_names = sorted(
        os.path.basename(os.path.dirname(model_dir))
        for model_dir in glob(os.path.join(images_root, "*/"))
    )
    if not model_names:
        raise FileNotFoundError(f"No model directories found under {images_root}")

    # Load the scoring model and its preprocessor. The options previously
    # read here (`score`, `hps_version`) were never declared, so the command
    # failed on attribute lookup.
    compute_score, processor = get_score(args.score, device)

    for style in _STYLES:
        prompt_path = os.path.join(args.hpdv2_root, f"benchmark/{style}.json")
        prompts = _read_json(prompt_path)
        pbar = tqdm(
            model_names,
            desc=f"Style {style}",
            disable=not accelerator.is_main_process
        )
        model_scores = []
        model_paths = []
        for model_name in pbar:
            dataset = HPDv2Benchmark(
                prompt_path=prompt_path,
                image_dir=os.path.join(images_root, f"{model_name}/{style}"),
                transform=processor,
            )
            model_paths.append(dataset.image_paths)

            # Key the cache on the score name so that different scorers never
            # reuse each other's results.
            cache_path = os.path.join(
                args.cache,
                f"hpdv2_{args.score}_benchmark_{model_name}_{style}.json")
            if os.path.exists(cache_path):
                scores = _read_json(cache_path)
                if accelerator.is_main_process:
                    pbar.write(f"Loaded {model_name}/{style} from cache")
            else:
                scores = _compute_scores(
                    accelerator, dataset, compute_score,
                    args.batch_size, args.num_proc,
                    desc=f"Model {model_name}",
                )
                # Every process holds identical gathered scores. Only one may
                # write, to avoid concurrent writes to the same file.
                if accelerator.is_main_process:
                    pbar.write(f"Saving {model_name}/{style} to cache")
                    _write_json_atomic(cache_path, scores)
            if len(scores) != len(dataset):
                raise RuntimeError(
                    f"Got {len(scores)} scores for {model_name}/{style} but "
                    f"the dataset has {len(dataset)} items; delete "
                    f"{cache_path} if it is stale.")
            model_scores.append(scores)
        pbar.close()

        if accelerator.is_main_process:
            _export_best_images(
                args.output.format(style=style),
                prompts,
                model_names,
                model_paths,
                torch.tensor(model_scores),
            )
        accelerator.wait_for_everyone()


if __name__ == "__main__":
    # Script entry point: click parses sys.argv and dispatches to `main`.
    main()
