import os
import sys
import argparse
from collections import defaultdict

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

import torch
from torch.utils.data import DataLoader

from tqdm.auto import tqdm, trange

import laft
from laft.baselines.winclip_laft import load_winclip_laft
from laft.prompts.industrial import get_prompts


# Let float32 matmuls use faster reduced-precision kernels where the backend
# offers them ("high" trades a little precision for throughput).
torch.set_float32_matmul_precision("high")
# This script only evaluates cached embeddings and never calls backward, so
# autograd is switched off for the whole process.
torch.set_grad_enabled(False)


def _sweep_components(model, basis, component_counts, cached_embeds, labels, metrics, category, n_sample):
    """Score the cached test embeddings once per LAFT component count.

    For every ``n`` in ``component_counts`` the first ``n`` rows of ``basis``
    are handed to ``model.setup_laft``, the cached embeddings are scored with
    ``model.forward_cache`` and the resulting binary metrics are appended to
    ``metrics[category]["{n_sample}/{n}"]``. Progress is echoed through
    ``tqdm.write``: the whole table on the first iteration, a single line of
    it afterwards.

    Args:
        model: Model object exposing ``setup_laft`` and ``forward_cache``.
        basis: Sliceable basis produced by ``laft.pca``.
        component_counts: Iterable of component counts to evaluate.
        cached_embeds: ``(image_embeds, window_embeds, patch_embeds)`` triple
            collected from ``model.encode_image`` over the test set.
        labels: Ground-truth labels aligned with the cached embeddings.
        metrics: ``{category: defaultdict(list)}`` mapping, updated in place.
        category: Key of ``metrics`` to append to.
        n_sample: Number of reference shots; only used to build the row key.

    Returns:
        The table rendered after the last evaluated component count, or an
        empty string when ``component_counts`` is empty.
    """
    image_embeds, window_embeds, patch_embeds = cached_embeds
    table = ""

    for j, n_component in enumerate(tqdm(component_counts, ncols=80, leave=False, desc="Components")):
        model.setup_laft(basis[:n_component], guide=True)
        scores = model.forward_cache(image_embeds, window_embeds, patch_embeds, laft_image=True).cpu()
        metrics[category][f"{n_sample}/{n_component}"].append(laft.binary_metrics(scores, labels))

        table = laft.utils.build_table(metrics, group_headers=("#Shot", "#Comp."))
        # NOTE(review): the `3 + j` offset presumably skips the table header
        # lines so that only the row for this component count is echoed --
        # confirm against laft.utils.build_table.
        tqdm.write(table if j == 0 else table.split("\n")[3 + j])

    return table


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-r", "--results-dir", default="results")
    parser.add_argument("-m", "--model-name", default="ViT-B-16-plus-240:laion400m_e31")
    # `choices` (rather than an `assert`) keeps the validation active under
    # `python -O` and reports a proper usage error.
    parser.add_argument("-d", "--dataset-name", type=str, required=True, choices=("mvtec",))
    parser.add_argument("-c", "--category", type=str, required=True)
    parser.add_argument("-n", "--n-samples", type=int, default=0)
    parser.add_argument("-ns", "--n-seeds", type=int, default=1)
    parser.add_argument("-p", "--prompt", type=int, default=0)
    args = parser.parse_args()

    model, preprocess = load_winclip_laft(args.model_name)

    category = args.category
    n_sample = args.n_samples

    # One list of per-seed metric records for every "<shots>/<components>" row.
    metrics = {
        category: defaultdict(list),
    }

    train_dataset = laft.build_dataset(f"{args.dataset_name}:{args.category}", split="train", transform=preprocess)
    test_dataset = laft.build_dataset(f"{args.dataset_name}:{args.category}", split="test", transform=preprocess)
    test_loader = DataLoader(test_dataset, batch_size=16, num_workers=2)
    n_train = len(train_dataset)

    # Caching: encode the test set once; every sweep below reuses these embeddings.
    labels_list, image_embeds_list, window_embeds_list, patch_embeds_list = [], [], [], []

    for images, _labels in tqdm(test_loader, ncols=80, leave=False, desc="Caching"):
        _image_embeds, _window_embeds, _patch_embeds = model.encode_image(images)
        labels_list.append(_labels.view(-1))
        image_embeds_list.append(_image_embeds)
        window_embeds_list.append(_window_embeds)
        patch_embeds_list.append(_patch_embeds)

    labels = torch.cat(labels_list)
    image_embeds = torch.cat(image_embeds_list)
    # Each batch yields a sequence of window embeddings; concatenate them
    # position-wise across batches.
    window_embeds = [torch.cat(per_position) for per_position in zip(*window_embeds_list)]
    patch_embeds = torch.cat(patch_embeds_list)
    cached_embeds = (image_embeds, window_embeds, patch_embeds)

    normal_prompts, anomaly_prompts = get_prompts(category)
    # NOTE(review): the model is always set up with the default prompt set;
    # `--prompt` below only changes the prompts used to build the LAFT basis.
    # Confirm this asymmetry is intended.
    model.setup(normal_prompts=normal_prompts, anomaly_prompts=anomaly_prompts)

    table_name = f"{n_sample}.txt"
    if args.prompt > 0:
        # Resolve the alternative prompt module by name instead of exec()-ing
        # a generated import statement.
        import importlib

        prompt_module = importlib.import_module(f"laft.prompts.industrial{args.prompt}")
        normal_prompts, anomaly_prompts = prompt_module.get_prompts(category)
        table_name = f"{n_sample}-p{args.prompt}.txt"

    table_path = os.path.join(args.results_dir, args.dataset_name, "winclip_laft", args.category, table_name)

    features = model.clip.encode_text(normal_prompts + anomaly_prompts).float()
    pairs = laft.prompt_pair(features)
    basis = laft.pca(pairs)

    # Component counts to sweep: every value in 2..31, then every other value in 32..300.
    component_counts = (*range(2, 32), *range(32, 301, 2))

    with torch.inference_mode():
        if n_sample > 0:
            # Few-shot: pick reference images from the train split. When the
            # request covers the whole split there is nothing to sample, so a
            # single run over all training images is enough.
            use_subset = n_sample < n_train
            n_runs = args.n_seeds if use_subset else 1

            for seed in trange(n_runs, ncols=80, leave=False, desc="Seeds"):
                if use_subset:
                    rng = torch.Generator().manual_seed(seed)
                    idxs = torch.randperm(n_train, generator=rng)[:n_sample].tolist()
                else:
                    idxs = list(range(n_train))

                reference_images = torch.stack([train_dataset[i][0] for i in idxs]).cuda()
                model.setup(reference_images=reference_images)

                table = _sweep_components(
                    model, basis, component_counts, cached_embeds, labels, metrics, category, n_sample
                )
                # Checkpoint the results after every seed.
                laft.utils.save_table(table, table_path)

        else:
            # Zero-shot: no reference images, a single sweep.
            _sweep_components(
                model, basis, component_counts, cached_embeds, labels, metrics, category, n_sample
            )

    table = laft.utils.build_table(metrics, group_headers=("#Shot", "#Comp."))
    print(table)
    laft.utils.save_table(table, table_path)
