import os
import sys
import argparse

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
from torch.utils.data import DataLoader

from tqdm.auto import tqdm, trange

import laft

# Configure torch once, at import time, for this inference-only script.
# Float32 matmuls use the "high" precision setting (faster, slightly less exact).
torch.set_float32_matmul_precision("high")
# No training happens here, so autograd is switched off globally.
torch.set_grad_enabled(False)


def sample_reference_indices(n_train, n_sample, seed):
    """Draw ``n_sample`` distinct indices from ``range(n_train)`` reproducibly.

    A fresh CPU ``torch.Generator`` seeded with ``seed`` drives
    ``torch.randperm``, so identical arguments always yield identical indices.
    If ``n_sample`` exceeds ``n_train``, only ``n_train`` indices are returned;
    callers must enforce that bound themselves.

    Args:
        n_train: Size of the index pool.
        n_sample: Number of indices to return.
        seed: Seed for the private generator.

    Returns:
        A Python list of ints.
    """
    rng = torch.Generator().manual_seed(seed)
    return torch.randperm(n_train, generator=rng)[:n_sample].tolist()


def score_test_set(model, loader, desc, device="cuda"):
    """Score every batch in ``loader`` with ``model`` and return one CPU tensor.

    Args:
        model: Callable mapping a batch of images to a score tensor.
        loader: Iterable of ``(images, labels)`` pairs; labels are ignored.
        desc: Label shown on the progress bar.
        device: Device the images are moved to before calling ``model``.

    Returns:
        Scores from all batches, concatenated along dim 0, on CPU.
    """
    scores = [
        model(images.to(device)).cpu()
        for images, _ in tqdm(loader, ncols=80, leave=False, desc=desc)
    ]
    return torch.cat(scores)


if __name__ == "__main__":
    # Evaluate WinCLIP zero-/few-shot on one dataset + guidance, writing a
    # results table after every shot count so partial runs are preserved.
    parser = argparse.ArgumentParser()
    parser.add_argument("-r", "--results-dir", default="results")
    parser.add_argument(
        "-d",
        "--dataset-name",
        type=str,
        required=True,
        choices=("celeba", "waterbirds", "color_mnist"),
    )
    parser.add_argument("-g", "--guidance", type=str, required=True)
    parser.add_argument("-m", "--model-name", default="ViT-B-16:dfn2b")
    parser.add_argument("-ns", "--n-seeds", type=int, default=1)
    args = parser.parse_args()

    # parser.error (not assert) so validation still runs under `python -O`.
    if not args.guidance.startswith("guide_"):
        parser.error("winclip only supports guide")
    if args.n_seeds < 1:
        # Zero seeds would store empty metric lists for every few-shot row.
        parser.error("--n-seeds must be at least 1")

    device = "cuda"

    ds = laft.prompts.get_ds(args.dataset_name)
    model, preprocess = laft.baselines.load_winclip(args.model_name)
    data = laft.get_dataset(args.dataset_name, ds.DATASET_CONFIG, preprocess)

    ls = ds.get(data, args.guidance)
    attend_name = ls["attend_name"]
    ignore_name = ls["ignore_name"]
    train_subset = ls["train_features"]
    test_subset = ls["test_features"]
    test_attend_labels = ls["test_attend_labels"]
    test_ignore_labels = ls["test_ignore_labels"]
    prompt_ind = ls["prompt_ind"]
    prompt_ood = ls["prompt_ood"]

    metrics = {
        attend_name: {},
        ignore_name: {},
    }
    n_train = len(train_subset)

    model.setup(normal_prompts=prompt_ind, anomaly_prompts=prompt_ood)

    table_path = os.path.join(args.results_dir, ds.DATASET_NAME, "winclip", f"{args.guidance}.txt")
    test_loader = DataLoader(test_subset, batch_size=64, num_workers=4, persistent_workers=True)

    shot_counts = (0, 1, 2, 4, 8, 16, 32, 64, 128, 256)
    for step, n_sample in enumerate(shot_counts):
        shot_key = f"{n_sample}"
        shot_desc = f"{n_sample}-shot"

        if n_sample > n_train:
            # randperm(n_train)[:n_sample] would silently return fewer images
            # and mislabel the row; shot counts only grow, so stop here.
            tqdm.write(f"Stopping at {shot_desc}: only {n_train} training samples available.")
            break

        if n_sample == 0:
            # Zero-shot is deterministic: one evaluation, stored as a single result.
            scores = score_test_set(model, test_loader, shot_desc, device)
            metrics[attend_name][shot_key] = laft.binary_metrics(scores, test_attend_labels)
            metrics[ignore_name][shot_key] = laft.binary_metrics(scores, test_ignore_labels)
        else:
            # Few-shot: one result per seed, each with its own reference images.
            attend_runs = []
            ignore_runs = []
            for seed in trange(args.n_seeds, ncols=80, leave=False, desc="Seeds"):
                idxs = sample_reference_indices(n_train, n_sample, seed)
                reference_images = torch.stack([train_subset[j][0] for j in idxs]).to(device)
                model.setup(reference_images=reference_images)

                scores = score_test_set(model, test_loader, shot_desc, device)
                attend_runs.append(laft.binary_metrics(scores, test_attend_labels))
                ignore_runs.append(laft.binary_metrics(scores, test_ignore_labels))

            metrics[attend_name][shot_key] = attend_runs
            metrics[ignore_name][shot_key] = ignore_runs

        # Print the full table once, then only the newest row; save every step.
        table = laft.utils.build_table(metrics, group_headers=("#Shot",))
        print(table if step == 0 else table.split("\n")[-1])
        laft.utils.save_table(table, table_path)

    # Print
    table = laft.utils.build_table(metrics, group_headers=("#Shot",))
    print(table)
    laft.utils.save_table(table, table_path)
