from ActualCausal.Train.Inter.train_trace import train_binaries
from ActualCausal.Inference.General.null import infer_null_values
import numpy as np

def train_null_assign(args, params, model, buffer, form="full", log_batch=None, additional=None, name="", itr_num=0, intermediate_logger=None):
    """Infer null traces/weights with ``model`` and store them in ``buffer``.

    Gets the target binaries to train for (the traces, or a proxy like
    gradient, proximity, etc.) from ``infer_null_values`` (called with
    ``keep_all=True``), then assigns the null traces and null weights to the
    buffer and recomputes ``buffer.norm_confidence``.

    Args:
        args: passed through to ``infer_null_values``.
        params: passed through to ``infer_null_values``; its
            ``trace_weights`` attribute is overwritten with the inferred
            weights divided by their sum.
        model: passed through to ``infer_null_values``.
        buffer: object supporting ``len()`` with array attributes
            ``eval_binary``, ``confidence`` and ``norm_confidence``; only the
            first ``len(buffer)`` rows of each are written.
        form: ``"full"`` selects the ``name`` entry from the values returned
            by ``infer_null_values``; any other value uses the returned
            traces/weights as-is.
        log_batch, additional, itr_num, intermediate_logger: not used by this
            function; kept for interface compatibility.
        name: when non-empty, passed as ``infer_names=[name]``; also the key
            used to select results when ``form == "full"``.

    Returns:
        The ``null_results`` value returned by ``infer_null_values``.
    """
    null_traces, null_weights, null_results = infer_null_values(
        args, params, model, buffer, keep_all=True,
        infer_names=[name] if name else [],
    )
    if form == "full":
        # NOTE(review): assumes infer_null_values returns name-indexed
        # containers in this mode; with name == "" this lookup presumably
        # fails -- confirm with callers.
        null_traces = null_traces[name]
        null_weights = null_weights[name]
    params.trace_weights = null_weights / np.sum(null_weights)

    num_filled = len(buffer)
    buffer.eval_binary[:num_filled] = null_traces
    buffer.confidence[:num_filled] = null_weights
    if num_filled > 0:
        # Work only on the filled rows. The buffer arrays are sliced to
        # len(buffer) above, which suggests they may be preallocated beyond
        # it; averaging the whole array would mismatch the slice shape and
        # let unfilled rows skew the min/sum.
        # assumes confidence is at least 2-D (mean over axis 1) -- TODO confirm
        mean_conf = np.expand_dims(np.mean(buffer.confidence[:num_filled], axis=1), -1)
        # Shift so the smallest entry becomes 1 (keeps every weight positive
        # and the sum non-zero), then normalize to sum to 1.
        shifted = mean_conf - np.min(mean_conf) + 1
        buffer.norm_confidence[:num_filled] = shifted / np.sum(shifted)
    return null_results