import os
import sys
import argparse

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
import laft

# Matmul precision "high" lets float32 matrix multiplications use faster
# internal representations (e.g. TF32) where the hardware supports them.
torch.set_float32_matmul_precision("high")
# Nothing in this script needs gradients, so turn autograd off for the current
# thread up front instead of wrapping each model/feature call in `torch.no_grad()`.
torch.set_grad_enabled(False)


_DATASET_NAMES = ("celeba", "waterbirds", "color_mnist")
_TABLE_GROUP_HEADERS = ("Guide", "Attr.", "Comp.")


def main() -> None:
    """Evaluate kNN scoring on LAFT-transformed features for one dataset/guidance pair.

    Steps:
      1. Load cached features for ``--model-name`` on ``--dataset-name`` via
         ``laft.get_cached_features``.
      2. Score the test features with plain kNN (``--n-neighbors``) as a baseline.
      3. Encode the dataset's guidance prompts with ``model.encode_text``, build
         prompt pairs, and run PCA on them to obtain a concept basis.
      4. For every component count from 2 up to the basis size, transform the
         train/test features (``laft.inner`` when the guidance starts with
         "guide", ``laft.orthogonal`` otherwise), re-score with kNN, and record
         binary metrics against both the attended and the ignored labels.
      5. Print the metrics table progressively and save it to
         ``<results-dir>/<DATASET_NAME>/laft/<guidance>-<n_neighbors>.txt``.

    Invalid command-line values are reported through ``argparse`` (exit status 2)
    before any features are loaded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-r", "--results-dir", default="results")
    parser.add_argument("-m", "--model-name", default="ViT-B-16:dfn2b")
    # `choices` replaces a post-hoc `assert`, which `python -O` would strip.
    parser.add_argument("-d", "--dataset-name", type=str, required=True, choices=_DATASET_NAMES)
    parser.add_argument("-k", "--n-neighbors", type=int, default=30)
    parser.add_argument("-g", "--guidance", type=str, required=True)
    args = parser.parse_args()

    if args.n_neighbors < 1:
        parser.error(f"--n-neighbors must be a positive integer, got {args.n_neighbors}")

    # Parse "<guide>_<attribute>" up front so a malformed value fails before the
    # expensive feature loading and kNN work, not after it. `partition` keeps
    # everything after the first "_" as the attribute, so an attribute name that
    # itself contains "_" no longer breaks tuple unpacking.
    guide, sep, attr = args.guidance.partition("_")
    if not sep:
        parser.error(f"--guidance must have the form '<guide>_<attribute>', got {args.guidance!r}")
    metric_name = f"{guide.capitalize()}/{attr.capitalize()}"

    ds = laft.prompts.get_ds(args.dataset_name)
    model, data = laft.get_cached_features(args.model_name, ds.DATASET_NAME, ds.DATASET_CONFIG)
    print()

    ls = ds.get(data, args.guidance)
    attend_name = ls["attend_name"]
    ignore_name = ls["ignore_name"]
    train_features = ls["train_features"]
    test_features = ls["test_features"]
    test_attend_labels = ls["test_attend_labels"]
    test_ignore_labels = ls["test_ignore_labels"]
    prompt = ls["prompt"]

    # NOTE(review): presumably `inner` keeps and `orthogonal` removes the concept
    # subspace spanned by the basis — confirm against the laft implementation.
    transform = laft.inner if args.guidance.startswith("guide") else laft.orthogonal

    metrics = {
        attend_name: {},
        ignore_name: {},
    }

    def record(method: str, scores) -> None:
        # Evaluate one set of scores against both label sets under the same column name.
        metrics[attend_name][method] = laft.binary_metrics(scores, test_attend_labels)
        metrics[ignore_name][method] = laft.binary_metrics(scores, test_ignore_labels)

    # Baseline: kNN on the untransformed features.
    record("KNN", laft.knn(train_features, test_features, n_neighbors=args.n_neighbors))

    features = model.encode_text(prompt).float()
    pairs = laft.prompt_pair(features)
    concept_basis = laft.pca(pairs)
    n_basis = concept_basis.size(0)

    for i, n_components in enumerate(range(2, n_basis + 1)):
        basis = concept_basis[:n_components]
        train_laft_features = transform(train_features, basis)
        test_laft_features = transform(test_features, basis)

        scores = laft.knn(train_laft_features, test_laft_features, n_neighbors=args.n_neighbors)
        record(f"{metric_name}/{n_components}", scores)

        # Print the full table (with headers) once, then only its last line
        # (presumably the row just added) on later iterations.
        table = laft.utils.build_table(metrics, group_headers=_TABLE_GROUP_HEADERS)
        print(table if i == 0 else table.split("\n")[-1])

    table = laft.utils.build_table(metrics, group_headers=_TABLE_GROUP_HEADERS)
    if n_basis < 2:
        # The component loop never ran, so nothing was printed; show the baseline table.
        print(table)
    table_path = os.path.join(args.results_dir, ds.DATASET_NAME, "laft", f"{args.guidance}-{args.n_neighbors}.txt")
    laft.utils.save_table(table, table_path)


if __name__ == "__main__":
    main()
