import os
import sys
import argparse
from collections import defaultdict

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
import laft
from sklearn.linear_model import LogisticRegression

# Script-wide torch settings, applied as soon as this module is loaded.
# This script never backpropagates, so autograd is switched off globally.
torch.set_grad_enabled(False)
# Allow float32 matmuls to use faster reduced-precision internal math
# (e.g. TF32) where the backend supports it.
torch.set_float32_matmul_precision("high")


# Datasets this linear-probe baseline accepts via --dataset-name.
SUPPORTED_DATASETS = ("celeba", "waterbirds", "color_mnist")
# Default numbers of validation samples used for the few-shot probes.
DEFAULT_SHOTS = (1, 2, 4, 8, 16, 32, 64, 128, 256)


def build_linear_train_set(train_features, subset_features, subset_labels):
    """Stack the train split with a labelled few-shot subset for probe training.

    Every train-split row is labelled 0 (presumably the nominal class; confirm
    against ``laft.prompts``). The subset rows keep their own labels.

    Args:
        train_features: Feature tensor whose rows are all labelled 0.
        subset_features: Feature tensor of the sampled validation rows.
        subset_labels: 1-D label tensor aligned with ``subset_features``.

    Returns:
        ``(features, labels)``, each concatenated along dim 0 with the train
        rows first. ``labels`` has the dtype and device of ``subset_labels``.
    """
    features = torch.cat((train_features, subset_features), dim=0)
    # Create the zeros on the labels' device so torch.cat cannot hit a device
    # mismatch if the cached features are ever loaded off-CPU.
    train_labels = torch.zeros(
        train_features.size(0), dtype=subset_labels.dtype, device=subset_labels.device
    )
    labels = torch.cat((train_labels, subset_labels), dim=0)
    return features, labels


def fit_and_score(features, labels, test_features, seed):
    """Fit a logistic-regression probe and score ``test_features`` with it.

    Args:
        features: Training feature tensor (one row per sample).
        labels: Training label tensor aligned with ``features``. sklearn
            raises ``ValueError`` if it contains only one class.
        test_features: Feature tensor to score.
        seed: Passed as ``random_state``. sklearn only uses it for the
            sag/saga/liblinear solvers, not the default lbfgs.

    Returns:
        1-D float tensor holding the predicted probability of
        ``regressor.classes_[1]`` (label 1 for 0/1 labels) per test row.
    """
    regressor = LogisticRegression(random_state=seed, max_iter=1000)
    # Convert explicitly so sklearn always receives CPU numpy arrays.
    regressor.fit(features.cpu().numpy(), labels.cpu().numpy())
    probabilities = regressor.predict_proba(test_features.cpu().numpy())
    return torch.from_numpy(probabilities[:, 1])


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Few-shot and full-shot logistic-regression probes on cached features."
    )
    parser.add_argument("-r", "--results-dir", default="results")
    parser.add_argument("-m", "--model-name", default="ViT-B-16:dfn2b")
    parser.add_argument("-d", "--dataset-name", type=str, required=True, choices=SUPPORTED_DATASETS)
    parser.add_argument("-g", "--guidance", type=str, required=True)
    parser.add_argument("-ns", "--n-seeds", type=int, default=1)
    parser.add_argument(
        "--shots",
        type=int,
        nargs="+",
        default=list(DEFAULT_SHOTS),
        help="validation sample counts for the few-shot probes (default: %(default)s)",
    )
    args = parser.parse_args()

    # Report bad input via parser.error, not assert: asserts vanish under -O.
    if not args.guidance.startswith("guide_"):
        parser.error("linear probe only supports guide")
    if args.n_seeds < 1:
        parser.error("--n-seeds must be at least 1")
    if any(n < 1 for n in args.shots):
        parser.error("--shots values must be positive")

    ds = laft.prompts.get_ds(args.dataset_name)
    _, data = laft.get_cached_features(args.model_name, ds.DATASET_NAME, ds.DATASET_CONFIG, device="cpu")
    print()

    ls = ds.get(data, args.guidance)

    attend_name = ls["attend_name"]
    ignore_name = ls["ignore_name"]
    valid_attrs = ls["valid_attrs"]
    train_features = ls["train_features"]
    valid_features = ls["valid_features"]
    test_features = ls["test_features"]
    valid_attend_labels = ls["valid_attend_labels"]
    test_attend_labels = ls["test_attend_labels"]
    test_ignore_labels = ls["test_ignore_labels"]

    # metrics[attribute][column] -> list with one laft.binary_metrics result per seed.
    metrics = {
        attend_name: defaultdict(list),
        ignore_name: defaultdict(list),
    }

    def record(column, scores):
        """Append attend/ignore metrics for ``scores`` under table column ``column``."""
        metrics[attend_name][column].append(laft.binary_metrics(scores, test_attend_labels))
        metrics[ignore_name][column].append(laft.binary_metrics(scores, test_ignore_labels))

    # N-shots: train rows (all labelled 0) plus n sampled validation rows.
    for n_samples in args.shots:
        for seed in range(args.n_seeds):
            subset_idx = laft.sample_index(valid_attrs, n_samples=n_samples, seed=seed)
            probe_features, probe_labels = build_linear_train_set(
                train_features, valid_features[subset_idx], valid_attend_labels[subset_idx]
            )
            record(f"{n_samples}", fit_and_score(probe_features, probe_labels, test_features, seed))

    # Full-shot: fit on the whole validation split (train rows not included).
    # The default lbfgs solver ignores random_state, so one fit serves every
    # seed. Record it once per seed so this column has as many entries as the
    # few-shot columns.
    full_scores = fit_and_score(valid_features, valid_attend_labels, test_features, seed=0)
    for _ in range(args.n_seeds):
        record("Full", full_scores)

    table = laft.utils.build_table(metrics, group_headers=("Samples",))
    table_path = os.path.join(args.results_dir, ds.DATASET_NAME, "linear_probe", f"{args.guidance}.txt")
    print(table)
    # Ensure the output directory exists (harmless if save_table also creates it).
    os.makedirs(os.path.dirname(table_path), exist_ok=True)
    laft.utils.save_table(table, table_path)
