import numpy as np
import torch
from Network.network_utils import run_optimizer
from ActualCausal.Train.train_utils import compute_likelihood

# TODO: not refactored, so it might not work
def train_cluster_active(args, params, model, buffer, log_batch=None, wrap_function=None, additional=None, itr_num=0, intermediate_logger=None):
    """Run the "active" optimizer for ``args.active.active_steps`` steps on per-cluster likelihoods.

    Each step:
      1. samples ``args.train.batch_size`` items from ``buffer`` using
         ``params.active_weights`` and optionally passes the batch through
         ``wrap_function``;
      2. calls ``model.infer`` in ``"full_active"`` mode;
      3. reshapes both entries of ``result.all_logits`` to
         ``(shape[0], args.cluster.num_cluster, -1)`` and runs
         ``compute_likelihood`` once per cluster slice;
      4. stacks the per-cluster ``log_probs`` along dim 1, scales them by
         ``1 - params.full_mixed_schedule`` and negates to form
         ``result.mixed_loss``;
      5. steps the optimizer returned by ``model.get_model_optim("active")``
         via ``run_optimizer`` and optionally logs the result.

    Args:
        args: config namespace; reads ``active.active_steps``,
            ``active.include_gradient``, ``train.batch_size`` and
            ``cluster.num_cluster``.
        params: runtime parameters; reads ``active_weights`` and
            ``full_mixed_schedule``.
        model: object providing ``infer``, ``get_model_optim``, ``extractor``
            and ``dists.forward``.
        buffer: object providing ``sample(batch_size, weights)`` returning
            ``(batch, idxes)``.
        log_batch: forwarded to ``model.infer``; ``None`` (default) means a
            fresh empty list per call.
        wrap_function: optional callable applied to the sampled batch.
        additional: forwarded to ``model.infer``; ``None`` (default) means a
            fresh empty list per call.
        itr_num: outer iteration index, used only to compute the logging step.
        intermediate_logger: optional logger with a ``log(step, dict)`` method.

    Returns:
        The result object from the final step, or ``None`` when
        ``args.active.active_steps`` is zero (no step was run).
    """
    # Avoid mutable default arguments: a shared list would leak state between
    # calls if model.infer ever mutated it.
    log_batch = [] if log_batch is None else log_batch
    additional = [] if additional is None else additional

    # Bound up front so that zero active steps returns None instead of raising
    # UnboundLocalError at the final return.
    result = None
    num_cluster = args.cluster.num_cluster

    for j in range(args.active.active_steps):
        batch, idxes = buffer.sample(args.train.batch_size, params.active_weights)
        if wrap_function is not None:
            batch = wrap_function(batch)

        # keeps both full result and result for comparison
        # (original author's note: infer fills result.logits with the
        # selection-weighted logits)
        result = model.infer(batch, batch.valid, "full_active", log_batch=log_batch, additional=additional)
        # mean sampling weight of the drawn indices
        result.weight_rate = np.sum(params.active_weights[idxes]) / len(idxes)

        # Split each of the two logit tensors into one slice per cluster
        # along a new axis 1.
        first_logits, second_logits = result.all_logits[0], result.all_logits[1]
        split_first = first_logits.reshape(first_logits.shape[0], num_cluster, -1)
        split_second = second_logits.reshape(second_logits.shape[0], num_cluster, -1)

        # compute log prob for every cluster
        all_loss_log_probs = []
        for i in range(num_cluster):
            # compute_likelihood reads result.logits, so overwrite it with
            # the current cluster's slice before each call.
            result.logits = (split_first[:, i], split_second[:, i])
            # adds target, dist, done_flags, log_probs, loss_log_probs (per original note)
            result = compute_likelihood(args, result, model.extractor, batch, model.dists.forward)
            # NOTE(review): collects `log_probs`, not `loss_log_probs`, despite
            # the list name — confirm which one is intended.
            all_loss_log_probs.append(result.log_probs)
        result.loss_log_probs = torch.stack(all_loss_log_probs, dim=1)

        # Negative log-likelihood scaled by the complement of the mixed schedule.
        result.mixed_loss = -(result.loss_log_probs * (1 - params.full_mixed_schedule))

        grad_variables = [result.active_input, result.active_embed] if args.active.include_gradient else []
        optim, compute_model = model.get_model_optim("active")
        result.gradients = run_optimizer(optim, compute_model, result.mixed_loss, grad_variables=grad_variables)
        if intermediate_logger is not None:
            intermediate_logger.log(itr_num * args.active.active_steps + j, {"active": result})

    return result