import os
import sys
import argparse
from collections import defaultdict

# Prepend the directory two levels above this file to the import path so the
# imports below can resolve from there when this file is run directly.
# NOTE(review): presumably that directory is the repository root containing
# the `laft` package — confirm against the project layout.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

import torch

from tqdm.auto import tqdm, trange

import laft
from laft.prompts.industrial import get_prompts


# Process-wide torch configuration, applied once at import time.

# Select the "high" float32 matmul precision setting (see the PyTorch docs for
# the exact speed/accuracy trade-off of each setting).
torch.set_float32_matmul_precision("high")

# This script only runs inference-style computations, so gradient tracking is
# switched off globally instead of wrapping each call in `torch.no_grad()`.
torch.set_grad_enabled(False)


def main() -> None:
    """Run a few-shot kNN sweep for one category and save one table per shot count.

    For every training-set size ("shot" count: powers of two below the number
    of training samples, plus the full training set) the script scores the test
    features with ``laft.knn`` twice over:

    * on the cached features as they are (table row ``"KNN"``), and
    * on features passed through ``laft.inner`` with the first ``k`` rows of
      the concept basis, for every ``k`` from 2 up to the basis size (table row
      ``"<k>"``).

    Subsampled runs are repeated for ``--n-seeds`` seeds; the full training set
    is evaluated once. Each shot count produces its own table, which is printed
    and written to
    ``<results-dir>/<dataset-name>/laft/<category>/<n_sample>-<n_neighbors>.txt``.

    An unsupported ``--dataset-name`` is rejected by argparse (usage message,
    exit status 2).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-r", "--results-dir", default="results")
    parser.add_argument("-m", "--model-name", default="ViT-B-16-plus-240:laion400m_e31")
    # Validate through `choices` rather than a bare `assert`: assertions are
    # removed under `python -O`, and argparse reports a proper usage error.
    parser.add_argument("-d", "--dataset-name", type=str, required=True, choices=("mvtec",))
    parser.add_argument("-k", "--n-neighbors", type=int, default=30)
    parser.add_argument("-c", "--category", type=str, required=True)
    parser.add_argument("-ns", "--n-seeds", type=int, default=5)
    args = parser.parse_args()

    category = args.category

    model, data = laft.get_cached_features(
        args.model_name,
        f"{args.dataset_name}:{category}",
        {},
        splits=("train", "test"),
    )
    print()

    # Each split is a (features, attributes) pair; the training attributes are
    # not needed here.
    train_features, _ = data["train"]
    test_features, test_attrs = data["test"]
    # Column 0 of the test attributes is used as the binary target for the metrics.
    test_attend_labels = test_attrs[:, 0]

    # Build the concept basis from the text embeddings of all prompts.
    # NOTE(review): presumably `prompt_pair` forms pairwise differences of the
    # embeddings and `pca` returns components ordered by importance — confirm
    # in `laft`.
    normal_prompts, anomaly_prompts = get_prompts(category)
    prompts = normal_prompts + anomaly_prompts

    features = model.encode_text(prompts).float()
    pairs = laft.prompt_pair(features)
    concept_basis = laft.pca(pairs)

    n_train = train_features.size(0)
    # Powers of two (1 .. 256) strictly below the training-set size, followed
    # by the full training set itself.
    n_samples = [2 ** v for v in range(0, 9) if 2 ** v < n_train] + [n_train]

    for n_sample in tqdm(n_samples, ncols=80, leave=False, desc="Shots"):
        metrics = {
            category: defaultdict(list),
        }
        results = metrics[category]

        # Values that stay fixed for the whole shot count.
        use_subset = n_sample < n_train
        # The full training set is deterministic, so one run is enough.
        n_seeds = args.n_seeds if use_subset else 1
        # kNN cannot use more neighbours than there are training samples.
        n_neighbors = min(args.n_neighbors, n_sample)

        for seed in trange(n_seeds, ncols=80, leave=False, desc="Seeds"):
            if use_subset:
                # Seeded generator keeps the subsample reproducible per seed.
                rng = torch.Generator().manual_seed(seed)
                idxs = torch.randperm(n_train, generator=rng)[:n_sample]
                train_sub_features = train_features[idxs]
            else:
                train_sub_features = train_features

            # Baseline: kNN directly on the cached features.
            scores = laft.knn(train_sub_features, test_features, n_neighbors=n_neighbors)
            results["KNN"].append(laft.binary_metrics(scores, test_attend_labels))

            # Same scoring after `laft.inner` with a growing prefix of the basis.
            for n_components in range(2, concept_basis.size(0) + 1):
                basis = concept_basis[:n_components]
                train_laft_features = laft.inner(train_sub_features, basis)
                test_laft_features = laft.inner(test_features, basis)

                scores = laft.knn(train_laft_features, test_laft_features, n_neighbors=n_neighbors)
                results[f"{n_components}"].append(laft.binary_metrics(scores, test_attend_labels))

        # A single run is stored as the bare metrics object instead of a
        # one-element list. Only existing keys are reassigned, so the dict
        # size does not change during iteration.
        for k, v in results.items():
            if len(v) == 1:
                results[k] = v[0]

        table = laft.utils.build_table(metrics, group_headers=("Comp.",))
        table_path = os.path.join(
            args.results_dir,
            args.dataset_name,
            "laft",
            f"{category}",
            f"{n_sample}-{args.n_neighbors}.txt",
        )
        print(table)
        laft.utils.save_table(table, table_path)


if __name__ == "__main__":
    main()
