import os
import sys
import argparse

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
import laft

# Process-wide torch settings, applied once at import time before any model work.
# Autograd is disabled globally: this script only runs inference/scoring, so no
# gradient graphs need to be recorded.
torch.set_grad_enabled(False)
# Allow torch to use faster, reduced-precision internal arithmetic for float32
# matmuls (e.g. TF32 on supported GPUs), trading a little precision for speed.
torch.set_float32_matmul_precision("high")


SUPPORTED_DATASETS = ("celeba", "waterbirds", "color_mnist")


def parse_args(argv=None):
    """Parse and validate the command-line arguments of this script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace`` with ``results_dir``, ``model_name``,
        ``dataset_name`` and ``guidance``.

    Exits with status 2 (via argparse) when the dataset is not one of
    ``SUPPORTED_DATASETS`` or the guidance name does not start with ``"guide_"``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-r", "--results-dir", default="results")
    parser.add_argument("-m", "--model-name", default="ViT-B-16:dfn2b")
    # `choices` / `parser.error` instead of `assert`: asserts are stripped under
    # `python -O`, which would silently disable this validation.
    parser.add_argument("-d", "--dataset-name", type=str, required=True, choices=SUPPORTED_DATASETS)
    parser.add_argument("-g", "--guidance", type=str, required=True)
    args = parser.parse_args(argv)

    if not args.guidance.startswith("guide_"):
        parser.error("clip_score only supports guide")
    return args


def main(argv=None):
    """Evaluate CLIP-score baselines (MCM and ZOC variants) and save a results table.

    Every baseline's scores are evaluated with ``laft.binary_metrics`` against
    both the attended and the ignored test labels. The resulting table is
    printed and saved to
    ``<results_dir>/<DATASET_NAME>/clip_score/<guidance>.txt``.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = parse_args(argv)

    ds = laft.prompts.get_ds(args.dataset_name)
    model, data = laft.get_cached_features(args.model_name, ds.DATASET_NAME, ds.DATASET_CONFIG)
    print()  # blank separator line in the console output

    ls = ds.get(data, args.guidance)
    test_features = ls["test_features"]
    prompt_ind = ls["prompt_ind"]
    prompt_ood = ls["prompt_ood"]

    # Row groups of the output table, in the original order (attend, then ignore).
    labels_by_name = {
        ls["attend_name"]: ls["test_attend_labels"],
        ls["ignore_name"]: ls["test_ignore_labels"],
    }
    metrics = {name: {} for name in labels_by_name}

    # (row name, extra positional prompts, keyword arguments) for clip_score.
    # MCM uses only the in-distribution prompts; ZOC variants add the OOD prompts.
    # temperature=None is labelled "ZOC-model" (presumably the model's own logit
    # scale — confirm in laft.baselines.clip_score).
    baselines = (
        ("MCM", (), {}),
        ("ZOC", (prompt_ood,), {}),
        ("ZOC-model", (prompt_ood,), {"temperature": None}),
        ("ZOC-0.07", (prompt_ood,), {"temperature": 0.07}),
    )
    for method, extra_prompts, kwargs in baselines:
        scores = laft.baselines.clip_score(model, test_features, prompt_ind, *extra_prompts, **kwargs)
        for name, labels in labels_by_name.items():
            metrics[name][method] = laft.binary_metrics(scores, labels)

    # Print and persist the table.
    table = laft.utils.build_table(metrics, group_headers=("Score",))
    table_path = os.path.join(args.results_dir, ds.DATASET_NAME, "clip_score", f"{args.guidance}.txt")
    print(table)
    laft.utils.save_table(table, table_path)


if __name__ == "__main__":
    main()
