from tianshou.data import Batch
from Model.InferenceModule.module_utils import apply_mask
from Network.network_utils import pytorch_model
import numpy as np


def infer_dist(args, params, model, batch, infer_types, keep_all=False):
    """Run interaction inference on ``batch``, optionally once per target name.

    If ``args.infer.infer_names`` is empty, a single ``"all_probs"`` inference
    is performed and its result is returned directly. Otherwise, for every
    configured name the model's target is switched with
    ``model.set_target_name`` and a ``"probs"`` inference is run. The
    per-name results are collected in a ``Batch`` keyed by name.

    Note: the model is left with the last name in ``infer_names`` as its
    target name.

    Args:
        args: configuration; ``args.infer.infer_names`` must match the names
            the model was actually trained on.
        params: forwarded unchanged to ``infer_dist_single``.
        model: inference model.
        batch: input data.
        infer_types: which mask variants to evaluate ("soft", "mixed",
            "hard", "flat").
        keep_all: forwarded to ``model.infer``.

    Returns:
        The single inference result, or a ``Batch`` mapping each target name
        to its result.
    """
    target_names = args.infer.infer_names
    if len(target_names) == 0:
        return infer_dist_single(args, params, model, batch, "all_probs", infer_types, keep_all=keep_all)

    per_name = Batch()
    for target_name in target_names:
        model.set_target_name(target_name)
        per_name[target_name] = infer_dist_single(
            args, params, model, batch, "probs", infer_types,
            name=target_name, keep_all=keep_all,
        )
    return per_name

def infer_dist_single(args, params, model, batch, infer_form, infer_types, name="", keep_all=False):
    """Run one inference form and attach per-mask-variant interaction statistics.

    Calls ``model.infer`` for ``infer_form`` on the valid entries of
    ``batch``. It stores ``batch.trace`` indexed by ``omit_flags[0]`` as the
    result's ``trace``. Then, for each requested mask variant, it builds
    interaction masks with ``apply_mask`` and fills a sub-``Batch`` with
    statistics via ``inter_stats``.

    Args:
        args: configuration; ``args.inter.masking`` is passed to ``apply_mask``.
        params: not used here.
        model: provides ``infer`` and ``dists``.
        batch: input data with ``valid`` and ``trace`` fields.
        infer_form: key of the inference output to use (e.g. "probs").
        infer_types: collection checked for "soft", "mixed", "hard", "flat".
        name: target name forwarded to ``inter_stats``.
        keep_all: forwarded to ``model.infer``.

    Returns:
        The inference result, with one sub-``Batch`` attribute per requested
        mask variant.
    """
    result = model.infer(batch, batch.valid, [infer_form], keep_all=keep_all)[infer_form]
    result.trace = batch.trace[result.omit_flags[0]]

    # (variant key, soft, flat, mixed), processed in this fixed order.
    mask_variants = (
        ("soft", True, False, False),
        ("mixed", True, False, True),
        ("hard", False, False, False),
        ("flat", False, True, False),
    )
    for variant, use_soft, use_flat, use_mixed in mask_variants:
        if variant not in infer_types:
            continue
        setattr(result, variant, Batch())
        variant_masks = apply_mask(
            args.inter.masking, model.dists, result.mask_logits,
            soft=use_soft, flat=use_flat, mixed=use_mixed, test=False,
        )
        variant_stats = getattr(result, variant)
        variant_stats.inter_masks = variant_masks
        inter_stats(model, result, variant_stats, infer_form, name=name)
    return result

def _capped_trace_rate(masks, utrace, trace_value):
    """Return a ``(1, num_columns)`` array of per-column mean mask values.

    For each index ``i`` of the last axis, ``masks[..., i]`` is averaged over
    the entries where ``utrace[..., i] == trace_value``. Each mean is capped
    at 1 with ``min(1, ...)``.

    NOTE(review): if no entry matches, ``np.mean`` of the empty slice returns
    nan (with a RuntimeWarning), and ``min(1, nan)`` evaluates to 1. The rate
    is therefore reported as 1 rather than nan. This preserves the existing
    behaviour; confirm whether downstream logging relies on it.
    """
    rates = [
        min(1, np.mean(masks[..., i][utrace[..., i] == trace_value]))
        for i in range(masks.shape[-1])
    ]
    return np.expand_dims(np.array(rates), axis=0)


def inter_stats(model, infer_result, target_result, infer_form, name=""):
    """Compare interaction masks against the ground-truth trace and store the stats.

    Reads ``target_result.inter_masks`` (set by the caller), plus
    ``infer_result.trace`` and ``infer_result.mask_logits``. Writes these
    fields onto ``target_result``: ``mask_logits``, ``utrace``,
    ``inter_variance``, ``inter_one_trace_rate``, ``inter_zero_trace_rate``,
    ``bin_error``, ``total_error``, ``logit_error``, ``trace``, ``inter_fp``
    and ``inter_fn``.

    Args:
        model: provides ``extractor.get_index``, used to select the trace
            columns for ``name`` when ``infer_form`` is not ``"all_probs"``.
        infer_result: inference output holding ``trace`` and ``mask_logits``.
        target_result: per-mask-variant container that receives the stats.
        infer_form: ``"all_probs"`` uses the full trace; any other value
            selects the columns for ``name``.
        name: target name passed to ``model.extractor.get_index``.
    """
    target_result.mask_logits = infer_result.mask_logits
    if infer_form == "all_probs":
        target_result.utrace = infer_result.trace
    else:
        target_result.utrace = infer_result.trace[:, model.extractor.get_index([name])]
    # A 2-D trace gets a singleton axis inserted at position 1, presumably so
    # it lines up with the mask rank -- TODO confirm the expected mask shape.
    if len(target_result.utrace.shape) == 2:
        target_result.utrace = np.expand_dims(target_result.utrace, axis=1)
    utrace = target_result.utrace

    # Unwrap once and reuse: unwrap may involve a device-to-host copy, so it
    # should not be repeated for every statistic and every column.
    masks = pytorch_model.unwrap(target_result.inter_masks)
    target_result.inter_variance = np.std(masks, axis=0)
    target_result.inter_one_trace_rate = _capped_trace_rate(masks, utrace, 1)
    target_result.inter_zero_trace_rate = _capped_trace_rate(masks, utrace, 0)

    # Signed and absolute residuals against the trace (these assume only one target).
    signed_error = masks - utrace
    target_result.bin_error = signed_error
    target_result.total_error = np.abs(signed_error)
    target_result.logit_error = pytorch_model.unwrap(target_result.mask_logits) - utrace
    target_result.trace = infer_result.trace  # redundant and somewhat expensive

    # False positives keep only positive residuals; false negatives keep only
    # negative ones. Each uses a fresh array so the clamping does not alias
    # bin_error.
    false_pos = masks - utrace
    false_pos[false_pos < 0] = 0
    target_result.inter_fp = false_pos
    false_neg = masks - utrace
    false_neg[false_neg > 0] = 0
    target_result.inter_fn = false_neg
