import os
import sys
import argparse

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

import torch
from torch.utils.data import DataLoader

from tqdm.auto import tqdm, trange

import laft
from laft.baselines import load_inctrl
from laft.prompts.industrial import get_prompts

torch.set_grad_enabled(False)
torch.set_float32_matmul_precision("high")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-r", "--results-dir", default="results")
    parser.add_argument("-m", "--checkpoint", type=str, default="data/checkpoints/inctrl_8.pt")
    parser.add_argument("-d", "--dataset-name", type=str, required=True)
    parser.add_argument("-c", "--category", type=str, required=True)
    parser.add_argument("-ns", "--n-seeds", type=int, default=1)
    args = parser.parse_args()

    assert args.dataset_name in ("mvtec",)

    model, preprocess = load_inctrl(args.checkpoint)

    category = args.category

    metrics = {
        category: {},
    }
    checkpoint_name = os.path.basename(args.checkpoint).replace(".pt", "")
    table_path = os.path.join(args.results_dir, args.dataset_name, "inctrl", f"{args.category}-{checkpoint_name}.txt")

    train_dataset = laft.build_dataset(f"{args.dataset_name}:{args.category}", split="train", transform=preprocess)
    test_dataset = laft.build_dataset(f"{args.dataset_name}:{args.category}", split="test", transform=preprocess)
    test_loader = DataLoader(test_dataset, batch_size=8, num_workers=2, shuffle=False)
    n_train = len(train_dataset)

    normal_prompts, anomaly_prompts = get_prompts(category)

    normal_text_features = model.encode_text(normal_prompts)
    anomal_text_features = model.encode_text(anomaly_prompts)

    # n_samples = [2 ** v for v in range(0, 9) if 2 ** v < n_train] + [n_train]
    n_samples = [1, 2, 4, 8, 16]

    # Caching
    outputs = []
    labels_list = []
    for images, labels in tqdm(test_loader, ncols=80, leave=False, desc="Caching"):
        token, token_ad, fp_list = model.query_forward(images)
        outputs.append((token, token_ad, fp_list))
        labels_list.append(labels.view(-1))
    labels = torch.cat(labels_list)

    for i, n_sample in enumerate(n_samples):
        metrics_list = []

        for seed in trange(args.n_seeds if n_sample < n_train else 1, ncols=80, leave=False, desc="Seeds"):
            if n_sample < n_train:
                rng = torch.Generator().manual_seed(seed)
                idxs = torch.randperm(n_train, generator=rng)[:n_sample].tolist()
            else:
                idxs = list(range(n_train))

            reference_images = torch.stack([train_dataset[i][0] for i in idxs])
            token_n, fp_list_n = model.fewshot_forward(reference_images, batch_size=8)

            with torch.inference_mode():
                scores = torch.cat([
                    model.score(token, token_ad, fp_list, token_n, fp_list_n, normal_text_features, anomal_text_features)
                    for token, token_ad, fp_list in tqdm(outputs, ncols=80, leave=False, desc="Scoring")
                ]).flatten()

            metrics_list.append(laft.binary_metrics(scores, labels))

        metrics[category][f"{n_sample}"] = metrics_list if len(metrics_list) > 1 else metrics_list[0]

        table = laft.utils.build_table(metrics, group_headers=("#Shot",))
        print(table if i == 0 else table.split("\n")[-1])
        laft.utils.save_table(table, table_path)

    table = laft.utils.build_table(metrics, group_headers=("#Shot",))
    laft.utils.save_table(table, table_path)
