from Network.network_utils import run_optimizer
from ActualCausal.Train.train_utils import compute_likelihood
import torch
import numpy as np

# TODO: not implemented
# def selection_penalty(args, params, result, passive_mask, done_flags):
#     # penalizes the selection based on the mask magnitude and the performance of that mask
#     params.lasso_lambda = compute_adaptive_lasso(args, params, result, batch)
#     mask_loss = (result.mask.logits - passive_mask).norm(p=args.masking.lasso_order, dim=-1).unsqueeze(-1) # penalize for deviating from the passive mask
#     selection_penalty = (result.mask.log_probs + params.lasso_lambda * mask_loss) * done_flags
#     return selection_penalty

def selection_entropy(args, params, result, done_flags):
    """Return the entropy of the cluster-selection distribution, one value per row.

    ``result.mask.select_logits`` is read as a probability distribution over its last
    dimension. Each entry ``p`` contributes ``-p * log(p + 1e-10)``; the small epsilon
    keeps the log finite when ``p == 0``. Contributions are summed over the last
    dimension, and that dimension is kept with size 1 in the returned tensor.

    ``args``, ``params`` and ``done_flags`` are accepted for interface compatibility but
    are not read here.
    """
    # TODO: we want some entropy overall (selects different clusters) while low entropy for a single state
    probs = result.mask.select_logits
    log_term = torch.log(probs + 1e-10)
    return (-probs * log_term).sum(dim=-1, keepdim=True)

def _cluster_loss_log_probs(args, model, result, batch):
    """Evaluate ``compute_likelihood`` once per cluster and stack the per-cluster losses.

    ``result.all_logits[0]`` and ``result.all_logits[1]`` are each reshaped to
    ``(batch, num_cluster, -1)``. Cluster ``i``'s slice is assigned to ``result.logits``
    before calling ``compute_likelihood``. ``result.logits`` is left pointing at the last
    cluster's slice, so the caller must restore it.

    Returns:
        ``(result, stacked)``, where ``result`` is the object returned by the last
        ``compute_likelihood`` call and ``stacked`` is the list of each cluster's
        ``loss_log_probs`` passed through ``torch.stack(..., dim=1)``.
    """
    num_cluster = args.cluster.num_cluster
    first, second = result.all_logits[0], result.all_logits[1]
    split_first = first.reshape(first.shape[0], num_cluster, -1)
    split_second = second.reshape(second.shape[0], num_cluster, -1)
    per_cluster = []
    for i in range(num_cluster):
        result.logits = (split_first[:, i], split_second[:, i])
        # adds target, dist, done_flags, log_probs, loss_log_probs
        result = compute_likelihood(args, model.extractor, result, batch, model.dists.forward)
        per_cluster.append(result.loss_log_probs)
    return result, torch.stack(per_cluster, dim=1)


def train_cluster_inter(args, params, model, buffer, log_batch=None, wrap_function=None, additional=None, itr_num=0, intermediate_logger=None):
    """Run ``args.active.active_steps`` optimizer steps for the clustered interaction model.

    Each step does the following:

    1. Sample a batch from ``buffer`` weighted by ``params.active_weights``, and
       optionally transform it with ``wrap_function``.
    2. Run ``model.infer`` in "active" mode.
    3. Compute the likelihood loss of every cluster.
    4. Weight each cluster's interaction loss by that cluster's selection probability.
    5. Add the scaled selection entropy.
    6. Step the "inter" optimizer.

    Args:
        log_batch, additional: forwarded to ``model.infer``. ``None`` means an empty list.
        wrap_function: optional callable applied to each sampled batch.
        itr_num: outer iteration index, used only to compute the logging step.
        intermediate_logger: if given, receives ``{"inter": result}`` after every step.

    Returns:
        The ``result`` object of the final step, or ``None`` when
        ``args.active.active_steps`` is 0.
    """
    # Fresh lists per call; mutable defaults would be shared across calls.
    log_batch = [] if log_batch is None else log_batch
    additional = [] if additional is None else additional
    num_cluster = args.cluster.num_cluster
    result = None  # avoids UnboundLocalError when active_steps == 0
    for j in range(args.active.active_steps):
        batch, idxes = buffer.sample(args.train.batch_size, params.active_weights)
        if wrap_function is not None:
            batch = wrap_function(batch)

        # keeps both full result and result for comparison
        result = model.infer(params, batch, batch.valid, "active", log_batch=log_batch, addition=additional)
        result.weight_rate = np.sum(params.active_weights[idxes]) / len(idxes)

        # still need all losses to compute the per-cluster losses
        logits = result.logits
        result, all_loss_log_probs = _cluster_loss_log_probs(args, model, result, batch)
        result.all_loss_log_probs = all_loss_log_probs
        result.logits = logits
        # The compute_likelihood call site states that it attaches done_flags to result.
        # NOTE(review): this replaces an undefined get_done_flags(batch, iscuda); confirm equivalence.
        done_flags = result.done_flags

        passive_mask = model.check_passive_mask(result.mask.logits)

        # Save the flat tensors so result can be restored after the per-cluster views are used.
        mask_logits, loss_log_probs = result.mask.logits, result.loss_log_probs
        batch_size = mask_logits.shape[0]
        result.mask.logits = mask_logits.reshape(batch_size, num_cluster, -1)
        # NOTE(review): the original reshaped only the LAST cluster's loss_log_probs into
        # num_cluster pieces. The per-cluster losses computed above are used instead; confirm intent.
        result.loss_log_probs = all_loss_log_probs.reshape(batch_size, num_cluster, -1)

        # For each cluster, penalize using the interaction penalty weighted by the selection
        # of that cluster. The cluster distribution lives on the last dim of select_logits
        # (see selection_entropy), so index that dim rather than the batch dim.
        select_probs = result.mask.select_logits
        # NOTE(review): evaluate_active_interaction is not imported in this file's visible
        # imports; confirm it is in scope. Every iteration currently passes it identical inputs.
        interaction_losses = []
        for i in range(num_cluster):
            weight = select_probs[..., i].unsqueeze(-1)
            interaction_losses.append(
                evaluate_active_interaction(args, params, result, passive_mask, done_flags) * weight
            )
        result.mask.logits, result.loss_log_probs = mask_logits, loss_log_probs

        # Sum over the cluster axis (selection-weighted total) so the shape matches the
        # per-row entropy term instead of mis-broadcasting against it.
        interaction_loss = torch.stack(interaction_losses, dim=1).sum(dim=1)
        # Named entropy_loss so the module-level selection_entropy function is not shadowed.
        entropy_loss = selection_entropy(args, params, result, done_flags) * params.selection_entropy

        grad_variables = [result.active_input, result.active_embed] if args.active.include_gradient else []
        optim, compute_model = model.get_model_optim("inter")
        result.gradients = run_optimizer(optim, compute_model, interaction_loss + entropy_loss, grad_variables=grad_variables)
        if intermediate_logger is not None:
            intermediate_logger.log(itr_num * args.active.active_steps + j, {"inter": result})
    return result