import numpy as np
import torch
from Network.network_utils import pytorch_model
from Network.network_utils import run_optimizer
from ActualCausal.Train.train_utils import compute_likelihood, get_done_flags, compute_likelihood_adaptive_lasso, compute_mean_adaptive_lasso, compute_mean_var_adaptive_lasso, compute_adaptive_rate
from tianshou.data import Batch

def compute_adaptive_lasso(args, params, result, batch):
    """Return the lasso rate produced by ``compute_adaptive_rate`` for this batch.

    Gathers the adaptive-lasso configuration and forwards it, together with
    ``result`` and ``batch``, to ``compute_adaptive_rate`` (non-pointwise).

    Args:
        args: config; reads ``args.masking.adaptive_lasso[0]`` (a negative value
            sets the "do not use adaptive" flag), ``args.masking.adaptive_lasso_type``
            and ``args.masking.adaptive_lasso_bias[0]`` / ``[1]`` (bias and
            flatten factor).
        params: reads ``lasso`` (base value), ``adaptive_lasso`` and, when
            present, ``converged_active_loss_value``.
        result: inference result; ``result.log_probs.shape[-1]`` scales the
            fallback baseline.
        batch: the sampled batch, forwarded unchanged.

    Returns:
        Whatever ``compute_adaptive_rate`` returns (stored as
        ``params.lasso_lambda`` by ``train_inter`` in this file).
    """
    masking_cfg = args.masking
    adaptive_disabled = masking_cfg.adaptive_lasso[0] < 0
    lasso_base = params.lasso

    # Prefer the recorded converged loss; otherwise fall back to a baseline
    # proportional to the last dimension of log_probs (3.5 per entry).
    if "converged_active_loss_value" in params:
        likelihood_baseline = params.converged_active_loss_value
    else:
        likelihood_baseline = 3.5 * result.log_probs.shape[-1]

    adaptive_cfg = params.adaptive_lasso
    bias_cfg = masking_cfg.adaptive_lasso_bias
    rate_bias = bias_cfg[0]
    rate_flatten = bias_cfg[1]

    return compute_adaptive_rate(
        masking_cfg.adaptive_lasso_type,
        adaptive_disabled,
        lasso_base,
        adaptive_cfg,
        result,
        batch,
        likelihood_baseline,
        rate_bias,
        rate_flatten,
        pointwise=False,
    )


def _mask_norm_loss(deviation, order, batch_size):
    """Per-sample mean of the ``order``-norm of ``deviation`` over its last dim.

    The norm collapses the last dimension; any remaining non-batch dimensions
    are flattened and averaged, giving a tensor of shape ``(batch_size,)``.
    """
    return deviation.norm(p=order, dim=-1).reshape(batch_size, -1).mean(dim=-1)


def evaluate_active_interaction(args, params, result, passive_mask):
    """Compute the scalar training loss for the active/interaction mask.

    Stores several per-sample mask regularizers on ``result`` (each of shape
    ``(result.log_probs.shape[0],)``). Returns the batch mean of the negative
    log-likelihood plus the weighted regularizers.

    NOTE(review): ``result.mask_logits`` is treated as values in [0, 1]
    (compared with 0, 0.5 and 1, and passed through ``log``) despite the name.
    It is presumably post-sigmoid; confirm against the model.

    Args:
        args: config; reads ``args.masking.lasso_order`` (the norm order ``p``).
        params: reads ``lasso_lambda``, ``lasso_one_lambda``,
            ``lasso_half_lambda`` and ``entropy_lambda`` (each cast to float).
        result: inference output. Reads ``mask_logits`` and ``log_probs``, and
            gains ``mask_loss``, ``zero_mask_loss``, ``one_mask_loss``,
            ``half_mask_loss`` and ``entropy_loss``.
        passive_mask: tensor broadcastable against ``result.mask_logits``.

    Returns:
        Scalar tensor: mean over the batch of the combined loss.
    """
    order = args.masking.lasso_order
    batch_size = result.log_probs.shape[0]
    mask = result.mask_logits

    # penalizes deviation from the passive mask
    result.mask_loss = _mask_norm_loss(mask - passive_mask, order, batch_size)
    # distance from the all-zero mask; stored on result but NOT part of the
    # returned loss
    result.zero_mask_loss = _mask_norm_loss(mask, order, batch_size)
    # distance from the all-one mask
    result.one_mask_loss = _mask_norm_loss(1 - mask, order, batch_size)
    # pulls mask entries towards 0.5
    result.half_mask_loss = _mask_norm_loss(0.5 - mask, order, batch_size)
    # sum of -p*log(p) over the last dim; this term is ~0 at p=0 and p=1, so
    # penalizing it pushes entries towards 0 or 1 (1e-10 guards log(0))
    result.entropy_loss = (
        torch.sum(-mask * torch.log(mask + 1e-10), dim=-1)
        .reshape(batch_size, -1)
        .mean(dim=-1)
    )

    full_loss = (
        -result.log_probs.mean(dim=-1)
        + result.mask_loss * float(params.lasso_lambda)
        + result.one_mask_loss * float(params.lasso_one_lambda)
        + result.half_mask_loss * float(params.lasso_half_lambda)
        + result.entropy_loss * float(params.entropy_lambda)
    )
    return full_loss.mean()

def _sample_split_batch(args, params, buffer):
    """Sample one batch of ``args.train.batch_size`` items from two weightings.

    One part is drawn with ``params.sample_interaction_weights`` ("high") and
    the other with ``params.sample_low_interaction_weights`` ("low"). The halves
    always sum to the batch size; for an odd size the low half gets the extra
    sample. Returns the concatenated ``Batch`` (low part first) and the
    matching buffer indices.
    """
    batch_size = int(args.train.batch_size)
    high_size = batch_size // 2
    low_size = batch_size - high_size
    high_batch, high_idxes = buffer.sample(high_size, params.sample_interaction_weights)
    low_batch, low_idxes = buffer.sample(low_size, params.sample_low_interaction_weights)
    batch = Batch.cat([low_batch, high_batch])
    idxes = np.concatenate([low_idxes, high_idxes])
    return batch, idxes


def train_inter(args, params, model, buffer, form="all", name="", log_batch=None, wrap_function=None, additional=None, both=False, itr_num=0, intermediate_logger=None, add_step=0):
    """Run ``params.masking_steps`` optimizer steps on the interaction/mask model.

    Each step:

    1. Samples a batch split between the two sampling weightings.
    2. Runs ``model.infer`` for the selected mask form.
    3. Updates ``params.lasso_lambda`` via ``compute_adaptive_lasso`` (so
       ``params`` is mutated).
    4. Computes the interaction loss against the passive mask.
    5. Applies one optimizer step.

    Args:
        args: config (reads ``args.train.batch_size``,
            ``args.active.include_gradient`` and the masking options used
            downstream).
        params: training state (reads ``masking_steps`` and the sampling
            weights; writes ``lasso_lambda``).
        model: provides ``infer``, ``passive_mask`` and ``get_model_optim``.
        buffer: replay buffer with ``sample(n, weights)`` returning
            ``(batch, idxes)``.
        form: ``"all"`` selects the ``"all_mask"`` output; anything else
            selects ``"mask"``. Also prefixes the optimizer key.
        name: forwarded to ``model.passive_mask``.
        log_batch: forwarded to ``model.infer`` (defaults to a fresh empty list).
        wrap_function: optional transform applied to each sampled batch.
        additional: forwarded to ``model.infer`` (defaults to a fresh empty list).
        both: if True, optimize ``form + "_both"``; otherwise ``form + "_inter"``.
        itr_num, add_step: offset the step index passed to the logger.
        intermediate_logger: optional logger called after every step.

    Returns:
        The ``result`` of the last step, or ``None`` if
        ``params.masking_steps`` is 0.
    """
    # fresh lists per call instead of shared mutable defaults
    log_batch = [] if log_batch is None else log_batch
    additional = [] if additional is None else additional

    mask_form = "all_mask" if form == "all" else "mask"
    optim_key = (form + "_both") if both else (form + "_inter")
    result = None
    for step in range(params.masking_steps):
        batch, idxes = _sample_split_batch(args, params, buffer)
        if wrap_function is not None:
            batch = wrap_function(batch)

        # only the selected mask form is requested from the model
        result = model.infer(batch, batch.valid, [mask_form], log_batch=log_batch, additional=additional)
        if params.sample_interaction_weights is not None:
            result.weight_rate = np.sum(params.sample_interaction_weights[idxes]) / len(idxes)
        else:
            result.weight_rate = 1

        params.lasso_lambda = compute_adaptive_lasso(args, params, result[mask_form], batch)
        result.lasso_lambda = params.lasso_lambda

        mask_result = result[mask_form]
        passive_mask = pytorch_model.wrap(
            model.passive_mask(len(mask_result.omit_flags[0]), name, form),
            cuda=mask_result.params.is_cuda,
        )
        result.interaction_loss = evaluate_active_interaction(args, params, mask_result, passive_mask)

        grad_variables = [result.active_input, result.active_embed] if args.active.include_gradient else []
        compute_models, optims = model.get_model_optim([optim_key])
        optim, compute_model = optims[0], compute_models[0]
        result.gradients = run_optimizer(optim, compute_model, result.interaction_loss, grad_variables=grad_variables)

        if intermediate_logger is not None:
            intermediate_logger.log(
                itr_num * params.masking_steps + step + add_step,
                {"mask": result},
                intermediate_name="_inter",
            )
    return result