from Network.network_utils import run_optimizer
from ActualCausal.Train.train_utils import compute_likelihood
from ActualCausal.Train.regularizers import apply_regularizers

def train_full_active(args, params, model, buffer, form="all", name="", log_batch=None, wrap_function=None, additional=None, itr_num=0, intermediate_logger=None):
    """Run ``args.active.full_steps`` optimizer updates on the model selected by ``form``.

    Each step samples a batch from ``buffer`` and optionally wraps it. It runs
    ``model.infer`` on the batch and builds the loss by passing the negative
    ``result[form].log_probs`` through ``apply_regularizers``. It then applies one
    update with ``run_optimizer`` using the optimizer returned by
    ``model.get_model_optim([form])``.

    Args:
        args: config namespace. Reads ``args.active.full_steps``,
            ``args.active.include_gradient`` and ``args.train.batch_size``; it is
            also forwarded to ``apply_regularizers``.
        params: forwarded to ``apply_regularizers``. Its
            ``sample_active_full_weights`` attribute is passed to ``buffer.sample``.
        model: object providing ``infer`` and ``get_model_optim``.
        buffer: object whose ``sample(batch_size, weights)`` returns ``(batch, idxes)``.
        form: key selecting the inference output and the model/optimizer pair to train.
        name: unused in this function; kept for interface compatibility.
        log_batch: forwarded to ``model.infer``. ``None`` (default) means a fresh empty list.
        wrap_function: optional callable applied to each sampled batch.
        additional: forwarded to ``model.infer``. ``None`` (default) means a fresh empty list.
        itr_num: outer iteration index, used to compute the logging step.
        intermediate_logger: optional logger. ``log`` is called once before and once
            after each optimizer update.

    Returns:
        The inference result of the last step, with ``reg_loss`` and ``gradients``
        attached, or ``None`` when ``args.active.full_steps`` is 0.
    """
    # Build fresh lists per call. The former ``[]`` defaults were a single object
    # shared by every call, so any mutation inside ``model.infer`` would persist.
    log_batch = [] if log_batch is None else log_batch
    additional = [] if additional is None else additional

    num_steps = args.active.full_steps
    # Without this initialisation, a zero-step run raised UnboundLocalError at the return.
    result = None
    for i in range(num_steps):
        step = itr_num * num_steps + i  # global logging step for this update
        batch, _idxes = buffer.sample(args.train.batch_size, params.sample_active_full_weights)
        if wrap_function is not None:
            batch = wrap_function(batch)

        result = model.infer(batch, batch.valid, [form], log_batch=log_batch, additional=additional) # adds logits [batch, num_obj*obj_dim]
        if intermediate_logger is not None:
            intermediate_logger.log(step, {"full": result})
        grad_variables = [result.full_active_input] if args.active.include_gradient else list() # TODO:change full_active_input to just active_input, include mask_input, keys, queries
        compute_models, optims = model.get_model_optim([form])
        optim, compute_model = optims[0], compute_models[0]
        # Loss is the negated log-probabilities of the selected form, passed through the regularizers.
        loss = apply_regularizers(- result[form].log_probs, args, params, model, batch, results=result[form])
        result.reg_loss = loss
        result.gradients = run_optimizer(optim, compute_model, loss, grad_variables=grad_variables)
        if intermediate_logger is not None:
            intermediate_logger.log(step, {"full": result}, intermediate_name = "_full")
    return result
