import os
import sys
import argparse

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

import torch
from torch.utils.data import DataLoader

from tqdm.auto import tqdm, trange

import laft
from laft.baselines import load_winclip
from laft.prompts.industrial import get_prompts


# Global torch configuration for this evaluation-only script.
# Allow faster (reduced-precision) float32 matmul kernels where the backend supports them.
torch.set_float32_matmul_precision("high")
# Nothing here is trained, so disable autograd tracking for the current thread.
torch.set_grad_enabled(False)


def collect_scores(model, batches, device="cuda"):
    """Run ``model`` over every ``(images, labels)`` batch and gather the results.

    Args:
        model: callable mapping a batch of images (already on ``device``) to scores.
        batches: iterable of ``(images, labels)`` pairs, e.g. a DataLoader, optionally
            wrapped in a progress bar.
        device: device the images are moved to before calling ``model``.

    Returns:
        ``(scores, labels)``: all scores concatenated (moved to the CPU) and all
        labels concatenated after flattening each batch's labels to 1-D.
    """
    scores_list, labels_list = [], []
    for images, labels in batches:
        scores_list.append(model(images.to(device)).cpu())
        labels_list.append(labels.view(-1))
    return torch.cat(scores_list), torch.cat(labels_list)


def main():
    """Evaluate WinCLIP on one dataset category for an increasing number of shots.

    The results table is written to ``<results-dir>/<dataset-name>/winclip/<category>.txt``
    after every shot count, overwriting the previous version.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-r", "--results-dir", default="results")
    parser.add_argument("-m", "--model-name", default="ViT-B-16-plus-240:laion400m_e31")
    # `choices` rejects unsupported datasets with a usage error; a bare `assert`
    # would be silently stripped under `python -O`.
    parser.add_argument("-d", "--dataset-name", type=str, required=True, choices=("mvtec",))
    parser.add_argument("-c", "--category", type=str, required=True)
    parser.add_argument("-ns", "--n-seeds", type=int, default=1)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--num-workers", type=int, default=2)
    args = parser.parse_args()

    model, preprocess = load_winclip(args.model_name)

    category = args.category
    dataset_spec = f"{args.dataset_name}:{category}"
    table_path = os.path.join(args.results_dir, args.dataset_name, "winclip", f"{category}.txt")
    metrics = {category: {}}

    train_dataset = laft.build_dataset(dataset_spec, split="train", transform=preprocess)
    test_dataset = laft.build_dataset(dataset_spec, split="test", transform=preprocess)
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=False,
    )
    n_train = len(train_dataset)

    normal_prompts, anomaly_prompts = get_prompts(category)
    model.setup(normal_prompts=normal_prompts, anomaly_prompts=anomaly_prompts)

    # Shot counts: 0 (prompts only), powers of two (1..256) below n_train, then all of n_train.
    n_samples = [0] + [2 ** v for v in range(9) if 2 ** v < n_train] + [n_train]

    for step, n_sample in enumerate(n_samples):
        desc = f"{n_sample}-shot"

        if n_sample == 0:
            # Zero-shot: no reference images have been set up yet.
            scores, labels = collect_scores(model, tqdm(test_loader, ncols=80, leave=False, desc=desc))
            metrics[category][f"{n_sample}"] = laft.binary_metrics(scores, labels)
        else:
            # Strict subsets are resampled per seed; the full train set uses every
            # index regardless of seed, so a single run is enough.
            n_runs = args.n_seeds if n_sample < n_train else 1
            metrics_list = []

            for seed in trange(n_runs, ncols=80, leave=False, desc="Seeds"):
                if n_sample < n_train:
                    rng = torch.Generator().manual_seed(seed)
                    idxs = torch.randperm(n_train, generator=rng)[:n_sample].tolist()
                else:
                    idxs = list(range(n_train))

                reference_images = torch.stack([train_dataset[idx][0] for idx in idxs]).cuda()
                model.setup(reference_images=reference_images)

                scores, labels = collect_scores(model, tqdm(test_loader, ncols=80, leave=False, desc=desc))
                metrics_list.append(laft.binary_metrics(scores, labels))

            metrics[category][f"{n_sample}"] = metrics_list if len(metrics_list) > 1 else metrics_list[0]

        table = laft.utils.build_table(metrics, group_headers=("#Shot",))
        # Print the whole table once, then only its last line (presumably the newest row).
        print(table if step == 0 else table.split("\n")[-1])
        # Save after every shot count so partial results survive an interrupted run;
        # the final iteration writes the complete table.
        laft.utils.save_table(table, table_path)


if __name__ == "__main__":
    main()
