from ActualCausal.Inference.General.attention import compute_attention_cause
from ActualCausal.Inference.General.counterfactual import compute_counterfactual_cause
from ActualCausal.Inference.General.dist_infer import infer_dist
from ActualCausal.Inference.General.gradient import infer_gradient
from ActualCausal.Inference.General.null import infer_null
from ActualCausal.Inference.General.granger import infer_all_granger
from tianshou.data import Batch
import numpy as np

def compute_inference(args, params, model, batch, infer_types, keep_all=False, perform_analysis=-1):
    """Run every inference method requested in ``infer_types`` on ``batch``.

    Each recognised entry of ``infer_types`` triggers the matching inference
    routine, and its output is stored on the returned ``Batch``:

    * ``"nulls"``          -> ``result.nulls``
    * ``"gradient"``       -> ``result.gradient``
    * ``"soft"`` / ``"mixed"`` / ``"hard"`` -> every key returned by
      ``infer_dist`` is copied directly onto ``result``
    * ``"counterfactual"`` -> ``result.counterfactual``
    * ``"attention"``      -> ``result.attention``
    * ``"granger"``        -> ``result.granger``

    Per the original author's note, each method's output is expected to carry
    ``inter_masks`` (the raw mask values), ``bin_error`` (the difference with
    the trace) and ``omit_flags``.

    Args:
        args, params, model, batch: forwarded unchanged to every inference
            routine.
        infer_types: container of inference-type names to evaluate.
        keep_all: forwarded to the routines that accept it; per the original
            note, it keeps all of the outputs, even dones.
        perform_analysis: forwarded to ``infer_null`` only.

    Returns:
        A ``Batch`` holding the outputs of the requested methods.
    """
    shared = (args, params, model, batch)
    result = Batch()

    if "nulls" in infer_types:
        result.nulls = infer_null(*shared, perform_analysis=perform_analysis, keep_all=keep_all)

    if "gradient" in infer_types:
        # the gradient routine is the only one that is not handed keep_all
        result.gradient = infer_gradient(*shared)

    if any(kind in infer_types for kind in ("soft", "mixed", "hard")):
        # cluster information used here
        dist_outputs = infer_dist(*shared, infer_types, keep_all=keep_all)
        # flatten the distribution outputs straight onto the result
        for key in dist_outputs.keys():
            result[key] = dist_outputs[key]

    if "counterfactual" in infer_types:
        result.counterfactual = compute_counterfactual_cause(*shared, keep_all=keep_all)

    if "attention" in infer_types:  # TODO: not bugfixed yet
        # attention is handed the results gathered so far, so it must run
        # after the methods above
        result.attention = compute_attention_cause(*shared, result, keep_all=keep_all)

    if "granger" in infer_types:
        result.granger = infer_all_granger(*shared, keep_all=keep_all)

    return result

def evaluate_inference(i, args, params, model, buffer, force=False, test=False, wrap_function=None, weights=None, max_batch_size=1024):
    """Periodically sample from ``buffer`` and run inference on the sample.

    Inference runs when ``force`` is true, or when
    ``args.infer.infer_interval`` is positive and ``i`` is a multiple of it
    (this avoids evaluating every timestep, for speed). The sampled batch is
    processed in chunks of at most ``max_batch_size`` rows and the per-chunk
    results are concatenated.

    Args:
        i: current iteration counter, compared against the inference interval.
        args: configuration; reads ``args.infer.infer_interval``,
            ``args.infer.infer_types`` and ``args.infer.perform_analysis``.
        params: reads ``params.infer_num`` (number of samples to draw) and is
            forwarded to ``compute_inference``.
        model: forwarded to ``compute_inference``.
        buffer: object whose ``sample(num, weights)`` returns ``(batch, idx)``.
        force: run regardless of the interval.
        test: accepted for interface compatibility; not used in this function.
        wrap_function: optional callable applied to the sampled batch before
            inference.
        weights: forwarded to ``buffer.sample``.
        max_batch_size: maximum number of rows passed to a single
            ``compute_inference`` call.

    Returns:
        The concatenated inference results, or ``None`` when inference is
        skipped for this iteration.

    Raises:
        ValueError: if ``max_batch_size`` is smaller than 1.
    """
    if not force:
        interval = args.infer.infer_interval
        if not (interval > 0 and i % interval == 0):
            return None

    if max_batch_size < 1:
        raise ValueError(f"max_batch_size must be at least 1, got {max_batch_size}")

    batch, _ = buffer.sample(params.infer_num, weights)
    if wrap_function is not None:
        batch = wrap_function(batch)

    # Walk the batch by start offset so the chunk size is defined in exactly
    # one place and the iteration counter ``i`` is not shadowed.
    chunk_results = [
        compute_inference(
            args,
            params,
            model,
            batch[start:start + max_batch_size],
            args.infer.infer_types,
            perform_analysis=args.infer.perform_analysis,
        )
        for start in range(0, len(batch), max_batch_size)
    ]
    return Batch.cat(chunk_results)