import os
import sys
import argparse

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
from torch.utils.data import DataLoader

from tqdm.auto import tqdm, trange

import laft

# Nothing in the visible script calls backward(), so autograd tracking is
# switched off once for the whole process instead of wrapping each forward call.
torch.set_grad_enabled(False)
# Select PyTorch's "high" float32 matmul precision setting; per the PyTorch
# docs this permits faster reduced-precision kernels where the hardware has them.
torch.set_float32_matmul_precision("high")


# Datasets this script accepts for --dataset-name.
DATASET_CHOICES = ("celeba", "waterbirds", "color_mnist")

# Numbers of reference images ("shots") evaluated, one table row each.
SHOT_COUNTS = (1, 2, 4, 8, 16, 32, 64, 128, 256)


def parse_args(argv=None):
    """Parse and validate the command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: On an unknown dataset name, a guidance string that does
            not start with ``"guide_"``, or ``--n-seeds`` below 1. Argparse
            prints a usage message before exiting.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-r", "--results-dir", default="results")
    parser.add_argument("-d", "--dataset-name", type=str, required=True, choices=DATASET_CHOICES)
    parser.add_argument("-g", "--guidance", type=str, required=True)
    parser.add_argument("-m", "--checkpoint", type=str, default="baselines/checkpoints/inctrl_2.pt")
    parser.add_argument("-ns", "--n-seeds", type=int, default=1)
    args = parser.parse_args(argv)

    # Explicit checks instead of `assert`, which disappears under `python -O`.
    if not args.guidance.startswith("guide_"):
        parser.error("inctrl only supports guide")
    # With zero seeds the per-shot metric lists would stay empty and the
    # later `[0]` lookup would fail with an IndexError.
    if args.n_seeds < 1:
        parser.error("--n-seeds must be at least 1")
    return args


def checkpoint_stem(path):
    """Return the file name of *path* without its final extension.

    Unlike splitting on ``"."`` and dropping the last piece, this keeps the
    whole name when the file has no extension at all.
    """
    return os.path.splitext(os.path.basename(path))[0]


def collapse_single(values):
    """Return *values* unchanged, or its only element when it has just one."""
    return values if len(values) > 1 else values[0]


def main(argv=None):
    """Evaluate the InCTRL baseline for every shot count and save a table.

    For each entry of ``SHOT_COUNTS`` and each seed, reference images are
    sampled from the validation subset, the cached test features are scored
    against them, and binary metrics are computed for both the attended and
    the ignored attribute labels. The results table is rewritten after every
    shot count so partial results survive an interrupted run.

    Args:
        argv: Argument list forwarded to ``parse_args``; ``None`` means
            ``sys.argv[1:]``.
    """
    args = parse_args(argv)

    ds = laft.prompts.get_ds(args.dataset_name)
    model, preprocess = laft.baselines.load_inctrl(args.checkpoint)
    data = laft.get_dataset(args.dataset_name, ds.DATASET_CONFIG, preprocess)

    ls = ds.get(data, args.guidance)
    attend_name = ls["attend_name"]
    ignore_name = ls["ignore_name"]
    valid_subset = ls["valid_features"]
    test_subset = ls["test_features"]
    valid_attrs = ls["valid_attrs"]
    test_attend_labels = ls["test_attend_labels"]
    test_ignore_labels = ls["test_ignore_labels"]
    prompt_ind = ls["prompt_ind"]
    prompt_ood = ls["prompt_ood"]

    ckpt_name = checkpoint_stem(args.checkpoint)
    table_path = os.path.join(
        args.results_dir, ds.DATASET_NAME, "inctrl_aux", f"{args.guidance}-{ckpt_name}.txt"
    )
    test_loader = DataLoader(test_subset, batch_size=32, num_workers=4)

    normal_text_features = model.encode_text(prompt_ind)
    anomal_text_features = model.encode_text(prompt_ood)

    # Caching: the query-side outputs are computed from the test images alone,
    # so they are produced once and parked on the CPU for reuse across every
    # shot count and seed. The loader's labels are not needed here; the
    # metrics below use the label tensors taken from `ls`.
    outputs = []
    for images, _ in tqdm(test_loader, ncols=80, leave=False, desc="Caching"):
        token, token_ad, fp_list = model.query_forward(images)
        outputs.append((token.cpu(), token_ad.cpu(), fp_list.cpu()))

    metrics = {
        attend_name: {},
        ignore_name: {},
    }

    for row, n_sample in enumerate(SHOT_COUNTS):
        attend_metrics_list = []
        ignore_metrics_list = []

        for seed in trange(args.n_seeds, ncols=80, leave=False, desc="Seeds"):
            idxs = laft.sample_index(valid_attrs, n_samples=n_sample, seed=seed)
            reference_images = torch.stack([valid_subset[idx][0] for idx in idxs]).cuda()
            token_n, fp_list_n = model.fewshot_forward(reference_images, batch_size=32)

            with torch.inference_mode():
                scores = torch.cat([
                    model.score(
                        token.cuda(), token_ad.cuda(), fp_list.cuda(),
                        token_n, fp_list_n, normal_text_features, anomal_text_features,
                    ).cpu()
                    for token, token_ad, fp_list in tqdm(outputs, ncols=80, leave=False, desc="Scoring")
                ]).flatten()

            attend_metrics_list.append(laft.binary_metrics(scores, test_attend_labels))
            ignore_metrics_list.append(laft.binary_metrics(scores, test_ignore_labels))

        # A single seed is stored as a bare metrics object, several as a list.
        metrics[attend_name][f"{n_sample}"] = collapse_single(attend_metrics_list)
        metrics[ignore_name][f"{n_sample}"] = collapse_single(ignore_metrics_list)

        table = laft.utils.build_table(metrics, group_headers=("#Shot",))
        # Print the full table once, then only the newest (last) line.
        print(table if row == 0 else table.split("\n")[-1])
        laft.utils.save_table(table, table_path)

    # Print
    table = laft.utils.build_table(metrics, group_headers=("#Shot",))
    print(table)
    laft.utils.save_table(table, table_path)


if __name__ == "__main__":
    main()
