from ACState.object_dict import ObjDict

## TODO: make the records not hardcoded?
# record_names = Batch(mean=["active_loss", "mask_loss", "interaction_loss", "log_probs"],
#                 complete=["log_probs", "mask_logits", "trace", "valid"],
#                 rates=["mask_logits"])
# infer_names = Batch(mean=["active_loss", "mask_loss", "interaction_loss", "log_probs"],
#                 complete=["bin_error", "utrace", "inter_masks", "valid"],
#                 rates=["bin_error"]) # TODO should be different, but I don't know what yet
# Metric names for the "record" path (counterpart of infer_names), grouped
# under the keys "mean", "complete" and "rates".
# NOTE(review): the key names presumably select how a consumer aggregates
# each metric (averaged / kept in full / reported as a rate) — confirm
# against the code that reads record_names.
record_names = ObjDict(dict(
    mean=[
        "active_loss",
        "log_probs",
        "trace_log_probs",
        "lasso_lambda",
        "trace_loss",
    ],
    complete=[
        "trace_true_diff",
        "trace_diff",
    ],
    rates=[],
))
# Metric names for the "infer" path, grouped under the same keys as
# record_names ("mean", "complete", "rates").
# NOTE(review): the key names presumably select how a consumer aggregates
# each metric — confirm against the code that reads infer_names.
# Each key must appear exactly once: in a dict literal a repeated key silently
# replaces the earlier value, which previously hid a dead `"mean": []` entry.
infer_names = ObjDict({
    "mean": ["log_probs", "trace_loss"],
    "complete": [
        "bin_error",
        "total_error",
        "utrace",
        "inter_masks",
        "inter_variance",
        "inter_one_trace_rate",
        "inter_zero_trace_rate",
        "trace_loss",
        "inter_fp",
        "inter_fn",
        "null_dists",
        "null_positive_dists",
        "null_negative_dists",
        "null_fp_dists",
        "null_fn_dists",
        "fp_log_probs",
        "fn_log_probs",
    ],
    "rates": [],  # TODO: should be different, but it is not yet clear what
})
