import numpy as np
from Network.network_utils import run_optimizer
from ActualCausal.Train.train_utils import compute_likelihood
from ActualCausal.Train.regularizers import apply_regularizers
import time
from tianshou.data import Batch
from ActualCausal.Train.Active.train_mask_active import train_masked_active
from ActualCausal.Train.Inter.train_inter import train_inter

# trains both mask and interaction models, alternating steps
def train_mask_inter(args, params, model, buffer, form="all", name="", log_batch=None, wrap_function=None, additional=None, both=False, itr_num=0, intermediate_logger=None):
    """Alternate masked-active training and interaction training for several rounds.

    The number of rounds is the caller's original ``args.active.active_steps``.
    While this function runs, ``args.active.active_steps`` and
    ``params.masking_steps`` are both temporarily set to
    ``args.active.mask_inter_steps``. Presumably the inner trainers read these
    fields as their per-call step counts; that is not visible here, so confirm
    it against ``train_masked_active`` / ``train_inter``. The original values
    are restored afterwards, even if an inner call raises.

    Args:
        args: configuration namespace; reads ``args.active.active_steps`` and
            ``args.active.mask_inter_steps`` (temporarily overwrites the former).
        params: parameter namespace; ``params.masking_steps`` is temporarily
            overwritten.
        model, buffer: forwarded unchanged to both inner trainers.
        form, name, wrap_function, both, itr_num, intermediate_logger:
            forwarded unchanged to both inner trainers.
        log_batch, additional: forwarded to both inner trainers. ``None``
            (the default) is replaced by a fresh empty list on every call.

    Returns:
        tuple: ``(mask, inter)``, the return values of the last
        ``train_masked_active`` and ``train_inter`` calls. Returns
        ``(None, None)`` if zero rounds were run.
    """
    # Give each call its own empty list rather than sharing mutable defaults.
    log_batch = [] if log_batch is None else log_batch
    additional = [] if additional is None else additional

    # Read everything before mutating, so a missing attribute leaves the
    # caller's configuration untouched.
    true_active_steps = args.active.active_steps
    true_masking_steps = params.masking_steps
    inner_steps = args.active.mask_inter_steps
    args.active.active_steps = inner_steps
    params.masking_steps = inner_steps

    shared_kwargs = dict(
        form=form,
        name=name,
        log_batch=log_batch,
        wrap_function=wrap_function,
        additional=additional,
        both=both,
        itr_num=itr_num,
        intermediate_logger=intermediate_logger,
    )

    # Pre-bind the results so a zero-round run returns cleanly instead of
    # raising UnboundLocalError.
    mask, inter = None, None
    try:
        for step in range(true_active_steps):
            mask = train_masked_active(args, params, model, buffer, add_step=step, **shared_kwargs)
            inter = train_inter(args, params, model, buffer, add_step=step, **shared_kwargs)
    finally:
        # Always restore the caller's configuration, even on error.
        args.active.active_steps = true_active_steps
        params.masking_steps = true_masking_steps
    return mask, inter