import hashlib
import os
from collections import defaultdict

import click
import torch
import tqdm
from torch.utils.data import DataLoader, Subset
from skimage.metrics import structural_similarity as ssim
from torchvision import transforms
from torchvision.utils import save_image

import training.dataset
from lpips import LPIPS


def get_images_by_indices(tgt_dataset, gen_dataset, tgt_indices, gen_indices):
    """Collect matched generated/target images into one interleaved batch.

    For every ``(tgt_index, gen_index)`` pair the generated image is emitted
    first, immediately followed by its target image, so the result reads
    ``[gen_0, tgt_0, gen_1, tgt_1, ...]``.

    Args:
        tgt_dataset: Indexable whose items are 3-tuples with the image first.
        gen_dataset: Indexable with the same item layout as ``tgt_dataset``.
        tgt_indices: Indices into ``tgt_dataset``.
        gen_indices: Indices into ``gen_dataset``; same length as ``tgt_indices``.

    Returns:
        Float tensor stacking ``2 * len(gen_indices)`` images, each divided by 255.

    Raises:
        AssertionError: If the two index sequences differ in length.
    """
    assert len(tgt_indices) == len(gen_indices)

    def as_unit_float(pixels):
        # Dataset images are expected to be ndarrays in [0, 255]; rescale to [0, 1].
        return torch.tensor(pixels).float() / 255

    interleaved = []
    for tgt_idx, gen_idx in zip(tgt_indices, gen_indices):
        generated, _, _ = gen_dataset[gen_idx]
        target, _, _ = tgt_dataset[tgt_idx]
        interleaved += [as_unit_float(generated), as_unit_float(target)]
    return torch.stack(interleaved)


def get_embeddings(loader, model, transform, name_for_cache, device):
    """Return the embeddings of every image served by ``loader``, cached on disk.

    Embeddings are cached under ``./.cache`` in a file named after the SHA-256
    of the resolved ``name_for_cache`` path.  NOTE: the cache key depends on
    that path only -- not on ``model`` or ``transform`` -- so the ``.cache``
    entry must be deleted by hand after changing either of them or the
    contents of the dataset.

    Args:
        loader: Iterable of ``(images, _, _)`` batches that exposes a
            ``.dataset`` attribute (its length sizes the progress bar).
            Not touched when a cache entry exists.
        model: Callable mapping a batch of transformed images to embeddings.
        transform: Callable applied to each image of a batch individually.
        name_for_cache: Path identifying the dataset; only used for the cache key.
        device: Device the model inputs are moved to and cached embeddings are
            loaded onto.

    Returns:
        Tensor with all embeddings concatenated along dim 0.
    """
    os.makedirs(".cache", exist_ok=True)
    cache_key = hashlib.sha256(os.path.realpath(name_for_cache).encode('utf-8')).hexdigest()
    cache_path = os.path.join(".cache", cache_key + ".pt")

    if os.path.exists(cache_path):
        # map_location places the tensor directly on the requested device; without
        # it torch.load tries to restore the device the cache was written from,
        # which fails when that device does not exist on the current machine.
        embeddings = torch.load(cache_path, map_location=device)
        print(f"Loaded embeddings of '{name_for_cache}' from cache {cache_path}")
        return embeddings

    embeddings = []
    with torch.no_grad(), tqdm.trange(len(loader.dataset), leave=False, desc="Calculate Embedding") as pbar:
        for images, _, _ in loader:
            images = torch.stack([transform(image) for image in images]).to(device)
            embeddings.append(model(images))
            pbar.update(len(images))
    embeddings = torch.cat(embeddings)

    # Write to a temporary file first and rename it into place, so that an
    # interrupted run never leaves a truncated cache file behind.
    tmp_path = cache_path + ".tmp"
    torch.save(embeddings, tmp_path)
    os.replace(tmp_path, cache_path)
    return embeddings


@click.command()
@click.option('--target', 'target_path',       help='Path to the dataset directory', metavar='PATH',                              type=str, required=True)
@click.option('--generated', 'generated_path', help='Path to the generated images', metavar='PATH',                               type=str, required=True)
@click.option('--sscd-model',                  help='Path to the torchscript of SSCD model', metavar='PATH',                      type=str, default="./sscd/sscd_disc_mixup.torchscript.pt", show_default=True)
@click.option('--sscd-thresholds',             help='SSCD threshold for LPIPS, SSIM, L2 and dissimilar images', metavar='FLOAT+', type=float, default=[0.5, 0.7], multiple=True, show_default=True)
@click.option('--batch', 'batch_size',         help='Batch size', metavar='INT',                                                  type=int, default=1024, show_default=True)
@click.option('--save', 'save_path',           help='Save directory for the most similar and dissimilar images', metavar='PATH',  type=str, default=None, show_default=True)
@click.option('--save-indices', 'indices_path', help='Save the matched (generated, target) indices with SSCD > 0.5 to this file and exit', metavar='PATH', type=str, default=None, show_default=True)
def main(target_path, generated_path, sscd_model, sscd_thresholds, batch_size, save_path, device=torch.device('cuda:0'), indices_path=None):
    """Measure how closely generated images copy images of a target dataset.

    Each generated image is matched to its most similar target image by SSCD
    embedding similarity. By default the script prints precision/recall/F1
    over several SSCD thresholds and the mean SSIM, LPIPS and L2 of the
    matched pairs. With --save-indices or --save it writes the requested
    files and exits instead.
    \f

    Everything after the form feed above is hidden from ``--help``.

    Args:
        target_path: Directory of the target (reference) dataset.
        generated_path: Directory of the generated images.
        sscd_model: Path to the TorchScript SSCD model.
        sscd_thresholds: SSCD thresholds used by the ``--save`` mode.
        batch_size: Batch size for embedding, matching and the metrics.
        save_path: If set, directory receiving one ``dissimilar<threshold>.png``
            grid per threshold; the function returns after writing them.
        device: Device used for the SSCD and LPIPS models (not a CLI option).
        indices_path: If set, file receiving the ``(gen_indices, tgt_indices)``
            tuple of all matches with SSCD similarity above 0.5; the function
            returns after writing it.  This used to happen unconditionally
            with a hard-coded file name, which made the rest of the script
            unreachable.
    """
    tgt_dataset = training.dataset.ImageFolderDataset(path=target_path)
    gen_dataset = training.dataset.ImageFolderDataset(path=generated_path)

    tgt_loader = DataLoader(tgt_dataset, batch_size=batch_size, num_workers=4)
    gen_loader = DataLoader(gen_dataset, batch_size=batch_size, num_workers=4)

    lpips_fn = LPIPS(net='vgg', version="0.1", verbose=False)
    lpips_fn = lpips_fn.to(device)
    lpips_fn.requires_grad_(False)

    # https://github.com/facebookresearch/sscd-copy-detection?tab=readme-ov-file#inference-using-sscd-models
    model = torch.jit.load(sscd_model).to(device)
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(288),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ])

    tgt_embeddings = get_embeddings(tgt_loader, model, transform, target_path, device)
    gen_embeddings = get_embeddings(gen_loader, model, transform, generated_path, device)

    # Calculate SSCD and find the top-1 most similar mapping:
    # gen2tgt[i] is the index of the target image closest to generated image i
    # and similarities[i] is the corresponding embedding inner product.
    gen2tgt = []
    similarities = []
    with tqdm.trange(len(gen_embeddings), leave=False) as pbar:
        for batch_embeddings in gen_embeddings.split(batch_size):
            batch_similarities = torch.einsum('ij,kj->ik', batch_embeddings, tgt_embeddings).cpu()
            batch_max = torch.max(batch_similarities, dim=1)
            gen2tgt.append(batch_max.indices)
            similarities.append(batch_max.values)
            pbar.update(len(batch_embeddings))
    gen2tgt = torch.cat(gen2tgt)
    similarities = torch.cat(similarities)

    if indices_path:
        # Dump the index pairs of all matches above the fixed 0.5 threshold.
        gen_indices = (similarities > 0.5).nonzero().flatten()
        tgt_indices = gen2tgt[gen_indices]
        torch.save((gen_indices, tgt_indices), indices_path)
        return

    if save_path:
        # save_image does not create missing directories.
        os.makedirs(save_path, exist_ok=True)

        # Masking out the low similarity
        for threshold in sscd_thresholds:
            mask = similarities > threshold
            masked_gen2tgt = gen2tgt[mask]
            masked_gen2gen = torch.nonzero(mask).flatten()
            masked_similarities = similarities[mask]

            k = min(32, len(masked_gen2tgt))
            if k == 0:
                print(f"SSCD > {threshold:.2f}: No images found")
                continue

            # Saving the most similar images is currently disabled:
            # topk_similar = torch.topk(masked_similarities, k=k, largest=True).indices
            # similar_images = get_images_by_indices(
            #     tgt_dataset,
            #     gen_dataset,
            #     tgt_indices=masked_gen2tgt[topk_similar],
            #     gen_indices=masked_gen2gen[topk_similar],
            # )
            # save_image(similar_images, os.path.join(save_path, 'similar' + f'{threshold:g}'.replace(".", "") + '.png'))

            # Save the k least similar pairs that still pass the threshold.
            topk_dissimilar = torch.topk(masked_similarities, k=k, largest=False).indices
            dissimilar_images = get_images_by_indices(
                tgt_dataset,
                gen_dataset,
                tgt_indices=masked_gen2tgt[topk_dissimilar],
                gen_indices=masked_gen2gen[topk_dissimilar],
            )
            save_image(dissimilar_images, os.path.join(save_path, 'dissimilar' + f'{threshold:g}'.replace(".", "") + '.png'))

        return

    # Calculate precision, recall, F1
    for threshold in [0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99]:
        above = similarities > threshold
        # precision: fraction of generated images whose best match passes the threshold.
        precision = above.float().mean().item()
        # recall: fraction of target images that are the best match of such an image.
        recall = gen2tgt[above].unique().size(0) / len(tgt_embeddings)
        denominator = precision + recall
        # Report 0 instead of nan when nothing passes the threshold.
        f1 = 2 * precision * recall / denominator if denominator > 0 else 0.0
        print(f"Threshold: {threshold:.2f}, Precision: {precision:.5f}, Recall: {recall:.5f}, F1: {f1:.5f}")
    print(f"SSCD : {similarities.mean():.5f} ({similarities.std():.5f})")

    gen_loader = DataLoader(gen_dataset, batch_size=batch_size, num_workers=4)
    tgt_loader = DataLoader(
        # Reorder the target dataset to match the generated dataset; plain
        # Python ints are used so the dataset is never indexed with a tensor.
        Subset(tgt_dataset, gen2tgt.tolist()),
        batch_size=batch_size,
        num_workers=4)
    metric_results = defaultdict(list)
    with tqdm.tqdm(total=len(gen_dataset), desc="Calculating LPIPS, SSIM and L2", leave=False) as pbar:
        for (gen_images, _, _), (tgt_images, _, _) in zip(gen_loader, tgt_loader):
            # gen_images, tgt_images: presumably uint8 in [0, 255], channels first

            # SSIM
            for gen_image, tgt_image in zip(gen_images, tgt_images):
                val_ssim = ssim(tgt_image.numpy(), gen_image.numpy(), data_range=255., channel_axis=0)
                metric_results['ssim'].append(val_ssim)

            # LPIPS
            gen_scaled = gen_images.float().to(device) / 127.5 - 1  # float32, [-1, 1]
            tgt_scaled = tgt_images.float().to(device) / 127.5 - 1  # float32, [-1, 1]
            with torch.no_grad():
                val_lpips = lpips_fn(gen_scaled, tgt_scaled).cpu()
            val_lpips = val_lpips[:, 0, 0, 0]
            metric_results['lpips'].append(val_lpips)

            # normalized L2-norm
            gen_unit = gen_images.float() / 255  # float32, [0, 1]
            tgt_unit = tgt_images.float() / 255  # float32, [0, 1]
            val_l2 = (gen_unit - tgt_unit).pow(2).mean(dim=[1, 2, 3]).sqrt()
            metric_results['l2'].append(val_l2)

            pbar.update(len(gen_images))

    metric_results['ssim'] = torch.tensor(metric_results['ssim'])
    metric_results['lpips'] = torch.cat(metric_results['lpips'])
    metric_results['l2'] = torch.cat(metric_results['l2'])

    for metric_name in ["ssim", "lpips", "l2"]:
        print(f"{metric_name.upper():5s}: {metric_results[metric_name].mean():.5f}")


# Run the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()
