import click
import torch
import tqdm
from torchvision import transforms

import training.dataset
from metrics import get_embeddings


def _best_tpr_under_fpr(loss_t, gt, fpr_upperbound):
    """Find the best TPR of the rule ``loss <= threshold`` at a bounded FPR.

    Every distinct value of ``loss_t`` is tried as a threshold, in ascending
    order. Both rates are non-decreasing in the threshold, so the admissible
    thresholds (FPR within bounds) form a prefix of the sorted candidates and
    the last admissible one attains the maximum TPR.

    Args:
        loss_t: 1-D tensor of per-sample losses (assumed free of NaN).
        gt: 1-D boolean tensor of the same length and device; True marks a
            positive sample.
        fpr_upperbound: Largest false positive rate that is still accepted.

    Returns:
        ``(max_tpr, threshold)`` where ``threshold`` is the smallest candidate
        reaching ``max_tpr``. ``(0, 0)`` when no admissible threshold yields a
        positive TPR, including the degenerate cases with no positive or no
        negative samples.
    """
    thresholds, inverse = torch.unique(loss_t, sorted=True, return_inverse=True)
    num_thresholds = len(thresholds)

    # Cumulative counts over the sorted distinct losses give, for each
    # candidate threshold, how many positives / negatives satisfy
    # ``loss <= threshold`` (ties included) without rescanning all samples.
    tp = torch.bincount(inverse[gt], minlength=num_thresholds).cumsum(0)
    fp = torch.bincount(inverse[~gt], minlength=num_thresholds).cumsum(0)
    tpr = tp / gt.sum()
    fpr = fp / (~gt).sum()

    # fpr is monotonically increasing, so counting the admissible entries
    # yields the length of the admissible prefix.
    num_admissible = int((fpr <= fpr_upperbound).sum())
    if num_admissible == 0:
        return 0, 0

    max_tpr = tpr[num_admissible - 1]
    if not max_tpr > 0:  # also covers NaN when there are no positives
        return 0, 0

    # Report the smallest threshold that already reaches the maximum TPR.
    first_best = int(torch.nonzero(tp[:num_admissible] == tp[num_admissible - 1])[0])
    return max_tpr, thresholds[first_best]


@click.command()
@click.option('--target', 'target_path',       help='Path to the dataset directory', metavar='PATH',                              type=str, required=True)
@click.option('--generated', 'generated_path', help='Path to the generated images', metavar='PATH',                               type=str, required=True)
@click.option('--loss', 'loss_path',           help='Path to the loss file', metavar='PATH',                                      type=str, required=True)
@click.option('--sscd-model',                  help='Path to the torchscript of SSCD model', metavar='PATH',                      type=str, default="./sscd/sscd_disc_mixup.torchscript.pt", show_default=True)
@click.option('--sscd-thresholds',             help='SSCD threshold for LPIPS, SSIM, L2 and dissimilar images', metavar='FLOAT+', type=float, default=[0.5, 0.7], multiple=True, show_default=True)
@click.option('--batch', 'batch_size',         help='Batch size', metavar='INT',                                                  type=int, default=1024, show_default=True)
@click.option('--fpt', 'fpt_upperbound',       help='False positive rate upperbound', metavar='FLOAT',                            type=float, default=0.01, show_default=True)
# @click.option('--unique_labels',               help='Generate with unique labels',                             type=parse_int_list, default=None, show_default=True)
def main(target_path, generated_path, loss_path, sscd_model, sscd_thresholds, batch_size, fpt_upperbound, device=torch.device('cuda:0')):
    """Report the best TPR of a loss-threshold detector at a bounded FPR.

    A generated image counts as positive when its highest SSCD similarity to
    any target image reaches the SSCD threshold. For every SSCD threshold and
    every column of the loss file, the rule "loss <= threshold" is tuned for
    the highest true positive rate whose false positive rate stays within the
    requested bound.
    \f
    Everything after the form feed above is hidden from ``--help`` by click.

    Args:
        target_path: Directory handed to ``ImageFolderDataset`` for the
            target images.
        generated_path: Directory handed to ``ImageFolderDataset`` for the
            generated images.
        loss_path: File read with ``torch.load``; must hold a 2-D tensor with
            one row per generated image.
        sscd_model: Path of the TorchScript SSCD model.
        sscd_thresholds: Similarity cut-offs; each one is evaluated separately.
        batch_size: Batch size for both data loaders and for the chunked
            similarity computation.
        fpt_upperbound: Largest accepted false positive rate.
        device: Device for the SSCD model and embeddings. Not exposed as a
            CLI option; falls back to CPU when CUDA is requested but missing.

    Raises:
        click.BadParameter: If the loss file is not a 2-D tensor or its row
            count differs from the number of generated embeddings.
    """
    device = torch.device(device)
    if device.type == 'cuda' and not torch.cuda.is_available():
        click.echo("CUDA is not available; falling back to CPU.", err=True)
        device = torch.device('cpu')

    tgt_dataset = training.dataset.ImageFolderDataset(path=target_path)
    gen_dataset = training.dataset.ImageFolderDataset(path=generated_path)

    tgt_loader = torch.utils.data.DataLoader(tgt_dataset, batch_size=batch_size, num_workers=4)
    gen_loader = torch.utils.data.DataLoader(gen_dataset, batch_size=batch_size, num_workers=4)

    # https://github.com/facebookresearch/sscd-copy-detection?tab=readme-ov-file#inference-using-sscd-models
    model = torch.jit.load(sscd_model).to(device)
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(288),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ])

    # NOTE(review): get_embeddings is defined elsewhere; the einsum below
    # requires both results to be 2-D tensors sharing their last dimension.
    tgt_embeddings = get_embeddings(tgt_loader, model, transform, target_path, device)
    gen_embeddings = get_embeddings(gen_loader, model, transform, generated_path, device)

    # Top-1 SSCD similarity of every generated image against all targets,
    # computed in chunks so the full similarity matrix is never materialized.
    similarities = []
    with tqdm.trange(len(gen_embeddings), leave=False) as pbar:
        for batch_embeddings in gen_embeddings.split(batch_size):
            batch_similarities = torch.einsum('ij,kj->ik', batch_embeddings, tgt_embeddings).cpu()
            similarities.append(batch_similarities.max(dim=1).values)
            pbar.update(len(batch_embeddings))
    similarities = torch.cat(similarities)

    # map_location keeps the loss on the CPU next to `similarities`, even when
    # the file was saved from a GPU tensor.
    # NOTE: torch.load unpickles the file; only use it on trusted loss files.
    loss = torch.load(loss_path, map_location='cpu')
    if not torch.is_tensor(loss) or loss.ndim != 2:
        raise click.BadParameter("loss file must contain a 2-D tensor", param_hint='--loss')
    if len(loss) != len(gen_embeddings):
        raise click.BadParameter(
            f"loss has {len(loss)} rows but there are {len(gen_embeddings)} generated images",
            param_hint='--loss',
        )

    num_columns = loss.shape[1]
    for sscd_threshold in sscd_thresholds:
        print(f"SSCD > {sscd_threshold:.2f}")

        gt = similarities >= sscd_threshold
        num_positive = int(gt.sum())
        num_negative = gt.numel() - num_positive
        if num_positive == 0 or num_negative == 0:
            # TPR or FPR would be 0/0 here; say so instead of printing zeros.
            print(f"skipped: {num_positive} positive and {num_negative} negative samples, TPR/FPR undefined")
            continue

        for t in range(num_columns):
            max_tpr, max_tpr_loss_threshold = _best_tpr_under_fpr(loss[:, t], gt, fpt_upperbound)
            print(f"t = {t}: TPR = {max_tpr:.4f}, Loss Threshold = {max_tpr_loss_threshold:.4f}")


# Script entry point: invoke the click command `main`, which takes its
# options from the command line. Skipped when the module is imported.
if __name__ == "__main__":
    main()
