from tianshou.data import Batch
import numpy as np
from Network.network_utils import pytorch_model

def evaluate_difference(args, pair_result, single_result):
    """Compare pair-model and passive-model log-likelihoods per sample.

    Args:
        args: config; ``args.infer.granger_threshold`` is unpacked as
            ``(active_threshold, passive_threshold, difference_threshold)``.
        pair_result: inference result whose ``log_probs`` supports
            ``.sum(dim=-1)`` (presumably a torch tensor — TODO confirm).
        single_result: passive inference result with the same ``log_probs`` layout.

    Returns:
        ``(binary, diff, active)`` where ``active`` / the passive total are
        ``log_probs`` summed over the last dimension, ``diff = active - passive``,
        and ``binary`` is a boolean mask that is True where the pair total exceeds
        ``active_threshold``, the passive total is below ``passive_threshold``, and
        ``diff`` exceeds ``difference_threshold``.
    """
    active_threshold, passive_threshold, difference_threshold = args.infer.granger_threshold
    # Reduce each result once instead of re-summing for every comparison.
    active = pair_result.log_probs.sum(dim=-1)
    passive = single_result.log_probs.sum(dim=-1)
    diff = active - passive
    # All three conditions must hold; `&` on bool tensors is the logical AND.
    binary = (active > active_threshold) & (passive < passive_threshold) & (diff > difference_threshold)
    return binary, diff, active

def infer_granger(pair_name, args, params, model, batch, single_result, keep_all=False):
    """Infer a per-sample interaction mask for a single ``"src1|src2->target"`` pair.

    Sets the model's target name to ``pair_name``, runs the ``"pair"`` inference,
    and compares it against the precomputed ``single_result`` via
    ``evaluate_difference``.

    Args:
        pair_name: string of the form ``"<source>[|<source>...]-><target>"``.
        args: config; reads ``args.factor.num_objects`` (and thresholds via
            ``evaluate_difference``).
        params: unused here; kept for interface compatibility.
        model: object exposing ``extractor.get_index``, ``set_target_name`` and ``infer``.
        batch: data batch passed to ``model.infer``; ``batch.valid`` is read.
        single_result: passive inference result; its ``log_probs.shape[0]`` sets
            the number of rows.
        keep_all: forwarded to ``model.infer``.

    Returns:
        Float array of shape ``(num_rows, num_objects)``. Where the test passes, a
        row marks the source and target indices with 1. Where it fails, only the
        target index is marked.
    """
    source_part, target = pair_name.split("->")
    sources = source_part.split("|")
    num_rows = single_result.log_probs.shape[0]
    num_objects = args.factor.num_objects

    # Mask used when the pair test passes: sources and target set.
    binary = np.zeros((1, num_objects))
    binary[:, model.extractor.get_index(sources + [target])] = 1
    binary = np.broadcast_to(binary, (num_rows, num_objects))
    # Mask used otherwise: only the target set.
    passive = np.zeros((num_rows, num_objects))
    passive[:, model.extractor.get_index([target])] = 1

    model.set_target_name(pair_name)
    pair_result = model.infer(batch, batch.valid, ["pair"], log_batch=[], keep_all=keep_all).pair
    granger_bin, _, _ = evaluate_difference(args, pair_result, single_result)
    # Unwrap/convert once; the trailing axis of size 1 broadcasts across objects.
    mask = pytorch_model.unwrap(granger_bin.unsqueeze(-1)).astype(float)
    return mask * binary + (1 - mask) * passive

def infer_all_granger(args, params, model, batch, keep_all=False):
    """Combine per-pair interaction masks for every name in ``args.inter.pair_names``.

    Runs one ``"single_passive"`` inference and then ``infer_granger`` for each
    pair. The resulting masks are OR-combined: any positive entry becomes 1.

    Args:
        args: config; reads ``args.infer.infer_names``, ``args.inter.pair_names``
            and ``args.factor.num_objects``.
        params: forwarded to ``infer_granger``.
        model: inference model (see ``infer_granger``).
        batch: data batch; ``batch.valid`` and ``batch.trace`` are read.
        keep_all: forwarded to ``model.infer``.

    Returns:
        A ``Batch`` with one sub-batch per target (holding ``utrace``), plus
        ``inter_masks``, ``bin_error``, ``total_error``, ``valid``, ``omit_flags``
        and ``trace``.

    Raises:
        ValueError: if ``args.inter.pair_names`` is empty (previously this surfaced
            as a ``NameError`` on the undefined loop variable).
    """
    model.set_target_name(args.infer.infer_names[0])  # TODO: assumes only one name
    single_result = model.infer(batch, batch.valid, ["single_passive"], log_batch=[], keep_all=keep_all).single_passive
    # Loop-invariant selector into batch.trace (semantics of omit_flags[0] not visible here — TODO confirm).
    first_omit_flags = single_result.omit_flags[0]
    binaries = np.zeros((single_result.log_probs.shape[0], args.factor.num_objects))
    result = Batch()
    target = None
    for pair_name in args.inter.pair_names:
        _, target = pair_name.split("->")
        result[target] = Batch()
        binaries += infer_granger(pair_name, args, params, model, batch, single_result, keep_all=keep_all)
        result[target].utrace = batch.trace[first_omit_flags, model.extractor.get_index([target])]
    if target is None:
        raise ValueError("infer_all_granger requires at least one entry in args.inter.pair_names")
    # OR-combine: any pair that flagged an object marks it.
    binaries[binaries > 0] = 1
    result.inter_masks = binaries
    # NOTE(review): errors are measured against the trace of the LAST target visited;
    # with several distinct targets the earlier ones are not compared.
    last_utrace = result[target].utrace
    result.bin_error = result.inter_masks - last_utrace
    result.total_error = np.abs(result.inter_masks - last_utrace)
    result.valid = batch.valid
    result.omit_flags = single_result.omit_flags
    result.trace = batch.trace[first_omit_flags]
    return result