R"""Print a CSV of per-component accuracies from a perturbation-experiment JSON.

Usage:

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1

python3 em/projects/neurips2023/make_perturbation_acc_csv.py
"""
import json
import os

from em.projects.m_npeff import perturbations

###############################################################################

# Directory holding the perturbation-experiment JSON outputs.
JSON_DIR = os.path.expanduser("~/Desktop/projects_data/extract_merge1/neurips2023/perturbation_kl_jsons")

# Experiment output to summarize, relative to JSON_DIR.
JSON_NAME = "snli3og_lrm_npeff/acc_perturbations.feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.wrongs_only.expansion_001.no_full_joint.fit_to_train_skip_50k.json"

###############################################################################


def _max_kl_results(component_output):
    """Return whichever of the plus/minus results has the larger KL ratio.

    Ties go to ``minus_results`` because the comparison is a strict ``>``.
    """
    plus = component_output.plus_results
    minus = component_output.minus_results
    return plus if plus.kl_ratio() > minus.kl_ratio() else minus


def _to_csv_str(rows):
    """Join ``rows`` into newline-separated, comma-separated text.

    Cells are converted with ``str`` and are not quoted or escaped, so this is
    only suitable for cells that contain no commas or newlines.
    """
    return '\n'.join(','.join(str(cell) for cell in row) for row in rows)


output = perturbations.PmPerturbationExperimentOutput.load(os.path.join(JSON_DIR, JSON_NAME))

# One row per component: its index and the total accuracy of the perturbation
# direction (plus or minus) with the larger KL ratio. The header must list
# exactly as many columns as each data row carries.
# NOTE(review): assumes component_outputs is a sequence (the original indexed
# it by 0..len-1); confirm it is not an int-keyed mapping.
ret = [['comp', 'total_acc']]
for a in output.component_outputs:
    ret.append([a.component_index, _max_kl_results(a).total_acc])

print(_to_csv_str(ret))


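# ###############################################################################
# ###############################################################################
# Earlier variant (disabled): for a hand-picked list of components (COMP_INDS)
# from a KL-perturbation run, print the max, min, and mean of the plus/minus
# total accuracies.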

# JSON_DIR = os.path.expanduser("~/Desktop/projects_data/extract_merge1/neurips2023/perturbation_kl_jsons")


# JSON_NAME = "snli3og_lrm_npeff/kl_perturbations.feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.fit_to_train_skip_50k.json"
# COMP_INDS = [15, 26, 102, 119, 128, 129, 134, 143, 146, 162, 168, 177, 189, 200, 212, 221, 242, 277, 282, 379, 404, 426, 436, 471, 491]

# # JSON_NAME = "snli3og_lrm_npeff/kl_perturbations.feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.wrongs_only.expansion_001.no_full_joint.fit_to_train_skip_50k.json"
# # COMP_INDS = [0, 1, 2, 4, 9, 14, 17, 18, 19, 21, 22, 24, 27, 28, 30, 34, 38, 40, 42, 45, 50, 54, 55, 59, 62, 63, 79, 90, 166, 193, 198, 207, 210, 226, 232, 241, 253, 276, 285, 335, 341, 346, 443, 468, 490, 535, 547, 555]

# ###############################################################################


# output = perturbations.PmPerturbationExperimentOutput.load(os.path.join(JSON_DIR, JSON_NAME))

# # print(output.make_kl_ratios_csv_str())
# ret = [['comp', 'max', 'min', 'amean']]
# for comp in COMP_INDS:
#     a = output.component_outputs[comp]
#     ret.append([
#         a.component_index,
#         max(a.plus_results.total_acc, a.minus_results.total_acc),
#         min(a.plus_results.total_acc, a.minus_results.total_acc),
#         0.5 * sum([a.plus_results.total_acc, a.minus_results.total_acc]),
#     ])

# print('\n'.join([','.join([str(cell) for cell in row]) for row in ret]))