from Network.network_utils import pytorch_model
import numpy as np
import torch
from tianshou.data import Batch

def compute_attention_cause(args, params, model, batch, infer_types, keep_all=False):
    """Compute attention-based interaction results for the configured targets.

    If ``args.infer.infer_names`` is empty, a single result is computed with
    form ``"all"`` and returned directly. Otherwise, for every name the model's
    target is switched via ``model.set_target_name`` and a result is computed
    with form ``"full"``. These per-name results are collected into a ``Batch``
    keyed by name.

    ``infer_types`` is accepted for interface compatibility but is not read.
    """
    target_names = args.infer.infer_names
    if len(target_names) == 0:
        # no explicit targets: evaluate against the whole trace in one pass
        return compute_attention_cause_single(args, params, model, batch, "all", keep_all=keep_all)

    per_target = Batch()
    for target in target_names:
        model.set_target_name(target)  # must precede inference for this target
        per_target[target] = compute_attention_cause_single(
            args, params, model, batch, "full", name=target, keep_all=keep_all
        )
    return per_target


def _trace_conditioned_rate(mask_logits, utrace, trace_value):
    """Mean of each last-axis column of ``mask_logits`` where ``utrace == trace_value``, capped at 1.

    Returns an array of shape ``(1, mask_logits.shape[-1])``. A column with no
    matching trace entries yields ``1``. That is the value the former inline
    ``min(1, np.mean(<empty>))`` produced, because ``min(1, nan)`` returns 1.
    It is now made explicit instead of going through numpy's empty-slice
    RuntimeWarning.
    """
    rates = []
    for i in range(mask_logits.shape[-1]):
        hits = utrace[..., i] == trace_value
        if not hits.any():
            rates.append(1)
        else:
            rates.append(min(1, np.mean(pytorch_model.unwrap(mask_logits[..., i][hits]))))
    return np.expand_dims(np.array(rates), axis=0)


def compute_attention_cause_single(args, params, model, batch, form, name="", keep_all=False, attn_result=None):
    """Binarise averaged attention weights into interaction predictions and score them against the trace.

    Args:
        args: config; reads ``args.infer.attention.select_ideal`` and, when that is
            falsy, ``args.infer.attention.attention_threshold``.
        params: not read here; kept for interface compatibility.
        model: queried with ``model.infer(...)`` when ``attn_result`` is None, and
            with ``model.extractor.get_index`` when ``form != "all"``.
        batch: must expose ``trace`` and ``valid``.
        form: inference form key passed to ``model.infer``. ``"all"`` uses the
            whole trace; any other value selects the trace column(s) for ``name``.
        name: target name used to index the trace when ``form != "all"``.
        keep_all: forwarded to ``model.infer``; otherwise unused (see TODO).
        attn_result: optional precomputed result, used in place of
            ``model.infer(...)[form]``. When given, ``model.infer`` is not called.

    Returns:
        A new ``Batch`` with ``trace``, ``mask_logits``, ``omit_flags``, ``utrace``,
        ``inter_one_trace_rate``, ``inter_zero_trace_rate``, ``inter_masks``,
        ``inter_variance``, ``bin_error``, ``total_error`` and ``logit_error``.
    """
    # Compare against None explicitly. The truthiness of a tianshou Batch goes
    # through len(), which can be 0 or raise TypeError depending on its contents.
    if attn_result is not None:
        infer_attn_result = attn_result
    else:
        infer_attn_result = model.infer(batch, batch.valid, [form], keep_all=keep_all, additional=['attn'])[form]
    # TODO: implement keep_all logic
    attn_result = Batch()

    attn_result.trace = batch.trace[infer_attn_result.omit_flags[0]]

    # Average the attention weights over the leading axes. The original author
    # describes these as the stacked mean/std, the layers and the heads, with a
    # layout of (stat, batch, layers, heads, keys, queries) -- TODO confirm.
    # Averaging over layers is only principled when there is a single layer.
    input_weights = pytorch_model.unwrap(infer_attn_result.attn.mean(dim=0).mean(dim=1).mean(dim=1))
    attn_result.mask_logits = input_weights[:, 0]  # assumes only one key

    attn_result.omit_flags = infer_attn_result.omit_flags
    target_trace = batch.trace if form == "all" else batch.trace[:, model.extractor.get_index([name])]
    attn_result.utrace = target_trace[attn_result.omit_flags[0], 0]

    # Mean (capped at 1) attention logit on entries whose trace is 1 / is 0.
    attn_result.inter_one_trace_rate = _trace_conditioned_rate(attn_result.mask_logits, attn_result.utrace, 1)
    attn_result.inter_zero_trace_rate = _trace_conditioned_rate(attn_result.mask_logits, attn_result.utrace, 0)

    logits = pytorch_model.unwrap(attn_result.mask_logits)
    if args.infer.attention.select_ideal:
        # Oracle threshold: the midpoint between the traced and untraced mean logits.
        threshold = (attn_result.inter_one_trace_rate + attn_result.inter_zero_trace_rate) / 2
    else:
        threshold = args.infer.attention.attention_threshold
    attn_result.inter_masks = (logits > threshold).astype(int)

    masks = pytorch_model.unwrap(attn_result.inter_masks)
    # NOTE(review): the masks are 0/1, so this equals the masks themselves.
    attn_result.inter_variance = np.abs(masks)
    # The error terms below assume a single target column in utrace.
    attn_result.bin_error = masks - attn_result.utrace
    attn_result.total_error = np.abs(masks - attn_result.utrace)
    attn_result.logit_error = pytorch_model.unwrap(attn_result.mask_logits) - attn_result.utrace
    return attn_result

def compute_attention_loss(model, batch, args, params, results, keep_all=False):
    """Entropy regulariser on the averaged attention weights, scaled by ``attn_reg_lambda``.

    When ``results`` is exactly a ``tuple``, only its first element is used.
    ``results.attn`` is reduced by averaging axis 0 and then axis 1 three times.
    The entropy is summed over the last remaining axis and averaged over the rest.

    ``model``, ``batch``, ``params`` and ``keep_all`` are accepted for interface
    compatibility but are not read.
    """
    # With multiple outputs, the first one feeds the embedding losses.
    primary = results[0] if type(results) is tuple else results
    weights = primary.attn.mean(dim=0)
    for _ in range(3):
        weights = weights.mean(dim=1)
    log_weights = torch.log(weights + 1e-6)  # epsilon keeps the log finite for zero weights
    entropy_loss = (-weights * log_weights).sum(axis=-1).mean()
    return entropy_loss * args.inter.regularization.attention.attn_reg_lambda
