R"""Makes CSV of TCAV p values."""
import os

from absl import app
from absl import flags

import numpy as np
from scipy import stats

from em.analysis.tcav import bert_tcav

FLAGS = flags.FLAGS


# Input file read by bert_tcav.load_run_scores in main().
flags.DEFINE_string(
    "tcav_filepath", None,
    "Path to the TCAV run scores, loaded with bert_tcav.load_run_scores.")
# Family-wise thresholds; main() divides each by the total number of tests.
flags.DEFINE_list(
    "p_thresholds", None,
    "Comma-separated p-value thresholds. Each is Bonferroni-corrected by "
    "dividing it by the total number of (comp, class) tests.")


def compute_p_values(scores: np.ndarray, mu0: float = 0.5) -> np.ndarray:
    """Two-sided one-sample t-test p-values of per-class scores against mu0.

    Args:
      scores: Array of shape [n_runs, n_classes]; column j holds the scores
        of class j across n_runs runs.
      mu0: Mean under the null hypothesis.

    Returns:
      Array of shape [n_classes] with the two-sided p-value of each column.
    """
    scores = np.asarray(scores, dtype=np.float64)
    n_runs = scores.shape[0]
    means = scores.mean(axis=0)
    # Standard error of the mean: unbiased (ddof=1) sample standard deviation
    # divided by sqrt(n). This is the textbook one-sample t statistic, the
    # same one scipy.stats.ttest_1samp computes.
    std_err = scores.std(axis=0, ddof=1) / np.sqrt(n_runs)
    t = (means - mu0) / std_err
    # Two-sided p-value from a t distribution with n_runs - 1 degrees of freedom.
    return 2.0 * stats.t.sf(np.abs(t), n_runs - 1)


def compute_p_values_for_all_comps(tcav_run_scores):
    """Stacks the per-key p-value vectors into a single 2-D array.

    Args:
      tcav_run_scores: Mapping whose values are [n_runs, n_classes] score
        arrays. Keys are presumably component ids; confirm against
        bert_tcav.load_run_scores.

    Returns:
      Array of shape [n_keys, n_classes]. Row i holds compute_p_values of the
      i-th key in sorted key order.
    """
    ordered_keys = sorted(tcav_run_scores.keys())
    per_key_p_values = [
        compute_p_values(tcav_run_scores[key]) for key in ordered_keys
    ]
    return np.stack(per_key_p_values, axis=0)


def main(_):
    """Prints, for each p threshold, how many comps have no significant class.

    Loads the TCAV run scores from --tcav_filepath and computes a two-sided
    t-test p-value for every (comp, class) pair. Each --p_thresholds value is
    Bonferroni-corrected by the total number of tests. For each corrected
    threshold, the script prints one line: the number of comps whose smallest
    p-value is >= that threshold.

    Raises:
      app.UsageError: If a required flag is missing.
    """
    if FLAGS.tcav_filepath is None:
        raise app.UsageError("--tcav_filepath is required.")
    if FLAGS.p_thresholds is None:
        raise app.UsageError("--p_thresholds is required.")

    tcav_run_scores = bert_tcav.load_run_scores(FLAGS.tcav_filepath)
    # One row per key of tcav_run_scores (sorted), one column per class.
    ps = compute_p_values_for_all_comps(tcav_run_scores)

    # Bonferroni correction over every (comp, class) test. Note that
    # `ndarray.size` is an int attribute, not a method.
    n_tests = ps.size
    p_thresholds = [float(p) for p in FLAGS.p_thresholds]
    corr_p_thresholds = [p / n_tests for p in p_thresholds]

    # A comp is "not significant" at a threshold when even its smallest
    # per-class p-value does not fall below it.
    min_per_comp = ps.min(axis=-1)

    for p in corr_p_thresholds:
        n_not_sig = np.count_nonzero(min_per_comp >= p)
        print(n_not_sig)
    # Optional debugging output: dump the full p-value matrix as CSV, or the
    # largest per-comp minimum p-value.
    # print('\n'.join([','.join([str(cell) for cell in row]) for row in ps]))
    # print(ps.min(axis=-1).max())


if __name__ == "__main__":
    # Script entry point: absl's app.run parses the flags defined above and
    # then calls main.
    app.run(main)
