import copy
import numpy as np
from ActualCausal.Utils.weighting import passive_binary
from ActualCausal.Utils.run_dataset import get_operation, compute_types

def infer_pseudo_null(args, model, buffer, params):
    """Overwrite ``buffer.valid`` with pseudo-null masks on selected states.

    Intended to replace the valid vector on states where the passive error
    is low. The selected states are those whose passive binary is not 1.
    The original valid vectors for the filled part of the buffer are cached
    once in ``params.valid``. Repeated calls therefore rebuild from the
    original mask instead of compounding earlier edits.

    Modes, selected by ``args.infer.nulls.pseudo_null``:
      * empty: do nothing and return immediately.
      * ``"zero"``: zero every entry of the selected valid vectors except the
        target object's entry.
      * ``"rand"``: zero each entry with probability
        ``args.infer.nulls.pseudo_null_rate``. The target object's entry is
        always kept.
      Any other non-empty value computes the selection but leaves
      ``buffer.valid`` unchanged.

    Args:
        args: config object. Reads ``args.infer.nulls.pseudo_null``,
            ``pseudo_null_passive_weighting`` and ``pseudo_null_rate``. If
            entry ``[2]`` of ``pseudo_null_passive_weighting`` is >= 0, the
            passive error is recomputed from the model and binarised with
            entries ``[:2]``. Otherwise ``buffer.weight_binary`` is used.
        model: provides ``target_name`` and ``extractor`` (``num_objects``,
            ``get_index``).
        buffer: buffer whose ``valid`` array is modified in place. Only the
            first ``len(buffer)`` entries are considered.
        params: attribute-style mapping. ``params.valid`` is created here if
            it is absent.

    TODO: not supported by multi-interactions.
    """
    nulls = args.infer.nulls
    if not nulls.pseudo_null:
        return
    num_filled = len(buffer)

    if "valid" not in params:  # cache the untouched mask only once
        params.valid = copy.deepcopy(buffer.valid[:num_filled])

    if nulls.pseudo_null_passive_weighting[2] >= 0:
        passive_error = - get_operation(
            model,
            buffer,
            all_compute=[compute_types.PASSIVE_LIKELIHOOD],
            reduced=True,
            normalized=False,
            object_names=[model.target_name],
        )[0]
        passive_binaries = passive_binary(
            passive_error,
            nulls.pseudo_null_passive_weighting[:2],
            None,
            buffer.done[:num_filled],
        )  # TODO: proximity not supported
    else:
        # Restrict to the filled region, like valid/done above. Otherwise
        # indices past len(buffer) would overrun params.valid.
        passive_binaries = buffer.weight_binary[:num_filled]

    # Row indices whose passive binary is not 1.
    idxes = (1 - passive_binaries).nonzero()[0]

    if nulls.pseudo_null == "zero":
        keep_target = np.zeros(model.extractor.num_objects)
        keep_target[model.extractor.get_index(model.target_name)] = 1
        buffer.valid[idxes] = params.valid[idxes] * np.expand_dims(keep_target, axis=0)
    elif nulls.pseudo_null == "rand":
        # Draw over the full cached shape, then threshold: 0 below the rate,
        # 1 otherwise.
        keep = (np.random.rand(*params.valid.shape) >= nulls.pseudo_null_rate).astype(np.float64)
        keep[:, model.extractor.get_index(model.target_name)] = 1
        buffer.valid[idxes] = keep[idxes] * params.valid[idxes]
    
