import numpy as np
from Network.network_utils import run_optimizer
from ActualCausal.Train.train_utils import compute_likelihood
from ActualCausal.Train.regularizers import apply_regularizers
import time

def train_masked_active(args, params, model, buffer, form="all", name="", log_batch=None, wrap_function=None, additional=None, both=False, itr_num=0, intermediate_logger = None, add_step=0):
    """Run the active-model training steps for one outer iteration.

    Each step samples a batch from ``buffer`` (weighted by
    ``params.sample_active_weights``) and runs ``model.infer`` twice: once in
    ``form`` mode and once in the matching masked mode. It then mixes the two
    ``log_probs`` into ``result.active_loss``, passes the negated loss through
    ``apply_regularizers`` and takes one ``run_optimizer`` step.

    Args:
        args: config object. Reads ``args.active.active_steps``,
            ``args.active.min_mixing``, ``args.active.include_gradient`` and
            ``args.train.batch_size``.
        params: reads ``params.sample_active_weights`` (may be None) and
            ``params.active_full_weight``.
        model: must provide ``infer`` and ``get_model_optim``.
        buffer: must provide ``sample(batch_size, weights) -> (batch, idxes)``.
        form: inference mode for the full pass. ``"all"`` pairs with the
            ``"all_mask"`` masked mode; anything else pairs with ``"mask"``.
        name: currently unused by this function.
        log_batch: forwarded to ``model.infer``. Defaults to a new empty list
            per call.
        wrap_function: optional transform applied to every sampled batch.
        additional: forwarded to ``model.infer``. Defaults to a new empty list
            per call.
        both: if True, optimize the ``form + "_both"`` model/optimizer pair
            instead of ``form``.
        itr_num: outer iteration index, used to compute the logging step.
        intermediate_logger: optional logger invoked after every step.
        add_step: offset added to the logging step.

    Returns:
        The ``result`` object of the last step. It carries ``result[form]``,
        ``result[mask_form]``, ``weight_rate``, ``active_loss`` and
        ``gradients``.
    """
    # Fresh lists per call: shared ``[]`` defaults would leak any mutation
    # made downstream (e.g. inside ``model.infer``) into later calls.
    log_batch = [] if log_batch is None else log_batch
    additional = [] if additional is None else additional

    mask_form = "all_mask" if form == "all" else "mask"
    # At least one step always runs, even when active_steps is 0.
    num_steps = max(1, args.active.active_steps)
    for j in range(num_steps):
        batch, idxes = buffer.sample(args.train.batch_size, params.sample_active_weights)
        if wrap_function is not None:
            batch = wrap_function(batch)

        # Full pass and masked pass. log_batch goes to both calls, so it may be stored twice.
        full_result = model.infer(batch, batch.valid, [form], log_batch=log_batch, additional=additional)[form]
        result = model.infer(batch, batch.valid, [mask_form], log_batch=log_batch, additional=additional)
        result[form] = full_result
        # Mean sampling weight of the drawn indices (1 when sampling is unweighted).
        result.weight_rate = np.sum(params.sample_active_weights[idxes]) / len(idxes) if params.sample_active_weights is not None else 1

        # NOTE(review): if ``result[form]`` is the very object ``full_result``, this
        # merge is a no-op. It is kept in case assignment into ``result`` copies or
        # filters keys -- confirm against the result container type.
        for k in full_result.keys():
            if k not in result[form]:
                result[form][k] = full_result[k]

        # Convex mix of the masked and full log-likelihoods. The full weight is at
        # least min_mixing, and the masked weight is its complement, because
        # min(1 - m, 1 - w) == 1 - max(m, w).
        result.active_loss = (result[mask_form].log_probs * min(1-args.active.min_mixing, (1-params.active_full_weight))
                             + result[form].log_probs * float(max(args.active.min_mixing, params.active_full_weight)))

        grad_variables = [result[form].active_input, result[mask_form].active_input] if args.active.include_gradient else list()
        compute_models, optims = model.get_model_optim([form + "_both" if both else form])
        optim, compute_model = optims[0], compute_models[0]
        # The log-likelihood is negated; presumably run_optimizer minimizes the loss it receives.
        loss = apply_regularizers(- result.active_loss, args, params, model, batch, results=(result[mask_form], result[form]))
        result.gradients = run_optimizer(optim, compute_model, loss, grad_variables=grad_variables)
        if intermediate_logger is not None:
            # Use the real per-call step count so logged steps stay distinct
            # across outer iterations even when active_steps is 0.
            intermediate_logger.log(itr_num * num_steps + j + add_step, {"mask": result}, intermediate_name = "_mask")
    return result