import argparse
import os
import pickle
from typing import Callable, Tuple

import numpy as np
import torch
from sg_utils import get_dataloader
from torchvision import transforms
from tqdm import tqdm


def record_similarity_feature(
    imagenet_path: str,
    get_similarity_features_fn: Callable[[torch.Tensor], torch.Tensor],
    preprocess_fn: Callable,
    use_validation_set: bool = False,
    batch_size: int = 32,
    device: str = "cpu",
) -> dict[str, np.ndarray]:
    """Compute a similarity-feature vector for every image of an ImageNet split.

    Each image is resized to 256 px and center-cropped to 224 px (as in the
    psychophysics experiments) before ``preprocess_fn`` is applied.

    Args:
        imagenet_path: Either an ImageNet root directory (the ``train`` or
            ``val`` sub-folder is appended), or a path ending in ``.tar``,
            which is loaded as a webdataset as-is (``use_validation_set`` then
            does not affect the path).
        get_similarity_features_fn: Maps a batch of preprocessed images (already
            on ``device``) to a tensor with one feature row per image.
        preprocess_fn: Model-specific transform applied after the center crop.
        use_validation_set: Read ``val`` instead of ``train`` (directory
            datasets only).
        batch_size: Number of images per forward pass.
        device: Torch device the batches are moved to.

    Returns:
        Dict keyed by the third item the dataloader yields per sample
        (presumably the image path — confirm against ``get_dataloader``),
        mapping to that image's features as a float32 NumPy array. If a key
        repeats, the last feature seen wins.
    """
    # Get center crop as done in psychophysics experiments.
    transform = transforms.Compose(
        [transforms.Resize(256), transforms.CenterCrop(224), preprocess_fn]
    )

    source = "val" if use_validation_set else "train"
    use_webdataset = imagenet_path.endswith(".tar")
    ds_path = imagenet_path if use_webdataset else os.path.join(imagenet_path, source)
    dataloader = get_dataloader(
        ds_path,
        model_name=None,
        batch_size=batch_size,
        return_indices=False,
        transform=transform,
        use_webdataset=use_webdataset,
    )

    results = {}
    with torch.no_grad():
        # Hand the loader to tqdm directly: tqdm uses len(dataloader) for the
        # progress total when it is defined and otherwise shows an open-ended
        # bar. An explicit len() call can raise TypeError for iterable-style
        # (e.g. webdataset) loaders that do not define a length.
        for batch, _labels, paths in tqdm(dataloader):
            # NOTE(review): a 5-D batch presumably comes from a preprocess_fn that
            # already adds a leading axis per image, i.e. (B, 1, C, H, W);
            # squeeze(1) only drops that axis when it has size 1.
            if batch.ndim == 5:
                batch = batch.squeeze(1)
            batch = batch.to(device)
            features = get_similarity_features_fn(batch).cpu().numpy()

            for feature, path in zip(features, paths):
                results[path] = feature.astype(np.float32)
    return results


def setup_dreamsim_model(
    device: str = "cpu",
) -> Tuple[Callable[[torch.Tensor], torch.Tensor], Callable]:
    """Load the pretrained DreamSim model and expose its embedding function.

    Weights are cached under ``$TORCH_HOME/dreamsim`` (``~/.torch/dreamsim``
    when ``TORCH_HOME`` is unset).

    Args:
        device: Torch device the model is created on and moved to.

    Returns:
        ``(get_features_fn, preprocess_fn)``. ``get_features_fn`` is the
        model's ``embed`` method, mapping a batch of preprocessed images to
        DreamSim embeddings. ``preprocess_fn`` is the transform returned by
        ``dreamsim``.
    """
    # Imported lazily so the dependency is only required when DreamSim is used.
    from dreamsim import dreamsim

    cache_dir = os.path.join(
        os.path.expanduser(os.getenv("TORCH_HOME", "~/.torch")), "dreamsim"
    )
    # Pass ``device`` explicitly: dreamsim() defaults to device="cuda", so
    # without it model construction fails on CPU-only machines before the
    # .to(device) below is ever reached.
    dreamsim_model, dreamsim_preprocess = dreamsim(
        pretrained=True,
        device=device,
        cache_dir=cache_dir,
    )
    # Feature extraction only: make sure the model is in eval mode.
    dreamsim_model = dreamsim_model.to(device).eval()

    return dreamsim_model.embed, dreamsim_preprocess


def main():
    """Command-line entry point: extract similarity features and pickle them.

    Parses the arguments. If ``--output-filename`` already exists, it prints a
    message and returns without doing any work. Otherwise it builds the chosen
    similarity model and computes features for the selected ImageNet split.
    The resulting dict is written to ``--output-filename``.
    """
    available_similarity_functions = {
        "dreamsim": setup_dreamsim_model,
    }

    device = "cuda" if torch.cuda.is_available() else "cpu"

    parser = argparse.ArgumentParser()
    parser.add_argument("--output-filename", type=str, required=True)
    parser.add_argument(
        "--similarity-function",
        type=str,
        required=True,
        choices=sorted(available_similarity_functions),
    )
    parser.add_argument("--use-validation-set", action="store_true")
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument(
        "--imagenet-path",
        type=str,
        required=True,
        help="Path to imagenet dataset. If it ends with .tar, it is "
        "assumed to be a webdataset.",
    )
    args = parser.parse_args()

    if os.path.exists(args.output_filename):
        print("Output file already exists. Aborting.")
        return

    # Create the output directory up front so that a missing folder does not
    # surface only after the (long) feature extraction has finished.
    output_dir = os.path.dirname(os.path.abspath(args.output_filename))
    os.makedirs(output_dir, exist_ok=True)

    setup_fn = available_similarity_functions[args.similarity_function]
    get_features_fn, preprocess_fn = setup_fn(device)
    results = record_similarity_feature(
        args.imagenet_path,
        get_features_fn,
        preprocess_fn,
        use_validation_set=args.use_validation_set,
        batch_size=args.batch_size,
        device=device,
    )

    # Write to a temporary file and rename it atomically into place. Otherwise
    # an interrupted dump would leave a truncated pickle, and the existence
    # check above would then refuse to ever recompute it.
    tmp_filename = args.output_filename + ".partial"
    try:
        with open(tmp_filename, "wb") as f:
            pickle.dump(results, f)
        os.replace(tmp_filename, args.output_filename)
    except BaseException:
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)
        raise


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
