import numpy as np
from ActualCausal.Inference.inference_utils import get_valid, compute_distributional_distance
from ActualCausal.Train.train_utils import compute_adaptive_rate, filter_batch_names, aggregate_result
from Network.network_utils import pytorch_model
from tianshou.data import Batch
import torch

def evaluate_distance(params1, params2):
    """Return the L2 distance between the first entries of two parameter tuples.

    Only index 0 of each argument is read (the means, going by the original
    comment). The second-entry term of the Wasserstein-2 distance that the
    original comment refers to is NOT included, so this is the mean-only part.
    TODO: could use KL distance.

    Args:
        params1: indexable whose element 0 is a tensor.
        params2: indexable whose element 0 is a tensor broadcastable to params1[0].

    Returns:
        Tensor of norms taken over the last dimension.
    """
    mean_gap = params1[0] - params2[0]
    return torch.norm(mean_gap, p=2, dim=-1)

def evaluate_weights(weights, clip_value=-1):
    """Summarise attention weights, optionally binarising them at a threshold.

    The code treats dim 1 as layers and dim 2 as heads and sums over the last
    dim, so the input is presumably (batch, layers, heads, keys, queries).
    NOTE(review): the original comment described a 4-D
    (batch, heads, keys, queries) layout -- confirm which is current.

    Args:
        weights: torch tensor of attention weights. It is never modified;
            clipping is applied to a copy.
        clip_value: when > 0, values strictly above it become 1 and values
            strictly below it become 0 (values equal to it are left as-is).
            Non-positive values disable clipping.

    Returns:
        Tuple of three values, each passed through ``pytorch_model.unwrap``:
        (mean over all entries of the sum along the last dim,
         weights averaged over dims 1 and 2,
         that average summed over its dim 1 and then averaged over dim 0).
    """
    if clip_value > 0:
        # Build both masks from the untouched values so the second threshold
        # is not re-applied to entries that were just set to 1.
        above = weights > clip_value
        below = weights < clip_value
        # Work on a copy: the caller's tensor (e.g. a stored inference
        # result) must not be overwritten by this summary.
        weights = weights.clone()
        weights[above] = 1
        weights[below] = 0
    num_layers, num_heads = weights.shape[1], weights.shape[2]
    live_rate = weights.sum(dim=-1).reshape(-1).mean()
    per_weights = weights.sum(dim=1).sum(dim=1) / (num_heads * num_layers)
    return pytorch_model.unwrap(live_rate), pytorch_model.unwrap(per_weights), pytorch_model.unwrap(per_weights.sum(dim=1).mean(dim=0))

def compare_base(args, infer_result, masked_result, batch):
    """Compare the base inference against a masked inference.

    The comparison mode is ``args.infer.nulls.use_vals``:

    * ``"dist"``: distance from ``compute_distributional_distance`` using
      ``args.infer.nulls.distance_form``, thresholded by ``dist_epsilon``.
    * ``"weights"``: live rate of the masked result's attention weights
      (key ``args.infer.nulls.weight_form``), thresholded by
      ``null_weight_epsilon``.

    Args:
        args: config namespace; only ``args.infer.nulls`` is read.
        infer_result: result of inference with the unmodified valid mask.
        masked_result: result of inference with some factors nulled out.
        batch: unused here; kept for interface compatibility.

    Returns:
        ``(binaries, dists)``: the thresholded outcome (promoted through
        ``sum`` exactly as before) and the raw value that was thresholded.

    Raises:
        NotImplementedError: if ``use_vals`` is neither "dist" nor "weights".
            Previously this surfaced as an UnboundLocalError on ``dists``.
    """
    nulls = args.infer.nulls
    mode = nulls.use_vals
    outcome_binaries = list()
    if mode == "dist":
        # distances as batch, target factors
        dists = pytorch_model.unwrap(compute_distributional_distance(nulls.distance_form, infer_result, masked_result))
        outcome_binaries.append((dists > nulls.dist_epsilon))
    elif mode == "weights":
        # NOTE(review): CLIP_VALUE is not defined or imported in the visible
        # part of this file -- confirm where it is supposed to come from.
        # The base-result summary is not used below; the call is kept because
        # evaluate_weights may binarise its input in place.
        evaluate_weights(infer_result[nulls.weight_form], clip_value=CLIP_VALUE)
        mask_live, _, _ = evaluate_weights(masked_result[nulls.weight_form], clip_value=CLIP_VALUE)
        outcome_binaries.append(mask_live > nulls.null_weight_epsilon)
        dists = mask_live
    else:
        raise NotImplementedError("Invalid use_vals '" + str(mode) + "', available are: dist, weights")
    return sum(outcome_binaries), dists

def combine_masks(num_objects, all_dists):
    """OR together per-combination binaries into a per-object mask.

    Args:
        num_objects: number of columns in the output.
        all_dists: dict mapping a tuple of object indices (the indices that
            were nulled out) to an array whose first dimension is the batch
            and which broadcasts against (batch, len(tuple)); truthy entries
            mean nulling that combination had an effect.

    Returns:
        Float array of shape (batch, num_objects) holding 1.0 wherever any
        combination containing that object index was flagged, else 0.0.
    """
    first_key = list(all_dists)[0]
    batch_size = all_dists[first_key].shape[0]
    necessary = np.zeros((batch_size, num_objects), dtype=bool)
    for combo, flagged in all_dists.items():
        columns = np.array(combo)
        # logical OR: an object stays marked once any combination flags it
        necessary[:, columns] |= flagged.astype(bool)
    return necessary.astype(float)

def infer_null(args, params, model,batch, keep_all=False, perform_analysis=-1, use_extractor_names = False, infer_names=[]):
    '''
    Runs infer_null_single for every requested target name and collects the
    results in a Batch keyed by name.

    When infer_names is empty, the names come from args.infer.infer_names if
    that list is non-empty and use_extractor_names is False; otherwise from
    model.extractor.names. The model's target is switched with
    model.set_target_name before each single inference, so the names must
    match the ones actually trained.
    '''
    if len(infer_names) == 0:
        configured = args.infer.infer_names
        prefer_configured = len(configured) > 0 and not use_extractor_names
        infer_names = configured if prefer_configured else model.extractor.names

    per_target = Batch()
    for target_name in infer_names:
        model.set_target_name(target_name)
        per_target[target_name] = infer_null_single(
            target_name, args, params, model, batch,
            keep_all=keep_all, perform_analysis=perform_analysis,
        )
    return per_target

def infer_null_single(name, args, params, model, batch, keep_all=False, perform_analysis=-1):
    '''
    Null-out test for a single target name.

    Runs model.infer once with the unmodified batch.valid, then once for every
    combination of factor indices with those indices zeroed in the valid mask,
    and compares each masked result against the base one with compare_base.

    name: target name; used to index batch.trace through model.extractor.get_index
    args: config; reads args.infer.nulls.max_combination and args.infer.nulls.use_vals here
    params: params.mask_mode is both the mode passed to model.infer and the key of the result that is used
    keep_all: forwarded to model.infer; also selects how batch.trace is sliced for utrace
    perform_analysis: index of the combination for which analyze_errors is called (negative: never)

    Returns a Batch with the base outputs (log_probs, params, valid, omit_flags, trace),
    the nulled aggregates (null_log_probs, null_params, null_dists), the inferred mask
    (inter_masks) and statistics of that mask against the trace (utrace, bin_error,
    total_error, inter_fp, inter_fn, inter_variance, inter_one_trace_rate,
    inter_zero_trace_rate, null_positive_dists, null_negative_dists, null_fp_dists,
    null_fn_dists, fp_log_probs, fn_log_probs).
    '''
    all_inters = list() # not used below
    # for dv in [0.2 * i for i in range(10)]:
    # presumably the last two extractor names are not nullable factors (see the TODO about action) -- confirm
    num_factors = len(model.extractor.names) - 2
    all_combinations = [np.arange(num_factors) for i in range(args.infer.nulls.max_combination)] # TODO: remove action from possible factors
    # every ordered tuple of length max_combination over the factor indices; repeated indices and permutations are included
    all_combinations = np.array(np.meshgrid(*all_combinations)).T.reshape(-1,args.infer.nulls.max_combination)
    # TODO: make general to not just full_name
    # valid = get_valid(batch.valid, model.extractor.get_index([name]))
    all_dists = dict() # maps tuple(comb) -> binaries returned by compare_base
    #  compute the base prediction
    additional = ["attention"] if args.infer.nulls.use_vals == "weights" else []
    # print(params.mask_mode)
    infer_result = model.infer(batch, batch.valid, params.mask_mode, additional = additional, keep_all=keep_all)[params.mask_mode] # TODO: all evaluation not tested, but might be more efficient
    # print("null input", batch[:5].obs.reshape(5, num_factors + 2, -1), batch[:5].target_diff, infer_result.log_probs[:5], infer_result.target[:5])
    total_mask = dict() # not used below
    total_log_likelihood = list() # not used below
    # per-factor accumulators, allocated on the device of the base log probs
    # comb_totals: how many combinations touched each factor index (the denominator for the averages below)
    comb_totals = torch.zeros(num_factors).to(infer_result.log_probs.device)
    log_prob_comb_totals = torch.zeros(infer_result.log_probs.shape[0], num_factors).to(infer_result.log_probs.device)
    dist_totals = torch.zeros(infer_result.log_probs.shape[0], num_factors).to(infer_result.log_probs.device)
    mean_var_totals = (torch.zeros(infer_result.log_probs.shape[0], num_factors, infer_result.params[0].shape[-1]).to(infer_result.log_probs.device), torch.zeros(infer_result.log_probs.shape[0], num_factors, infer_result.params[0].shape[-1]).to(infer_result.log_probs.device))
    
    for cidx, comb in enumerate(all_combinations):
        # always drops the invalid values (where the target object is not present in the scene), TODO: not sure why I had code saying keep_invalid = model.extractor.get_index(name) in comb
        keep_invalid = False 
        # assign a mask with zero at each combination and one every else
        given_mask = np.ones(batch.valid.shape)
        given_mask[...,comb] = 0
        # compute the masked prediction parameters
        given_valid = batch.valid * given_mask
        masked_result = model.infer(batch, given_valid, params.mask_mode, additional=additional, keep_invalid=keep_invalid, keep_all=keep_all)[params.mask_mode]
        gv_idxes = get_valid(batch.valid, comb) # result is not used below
        # compute the binaries for this combination
        binaries, dists = compare_base(args, infer_result, masked_result, batch)
        if perform_analysis >= 0 and cidx == perform_analysis: analyze_errors(model, name, comb, binaries, batch, dists, infer_result, masked_result, target_names=model.extractor.get_name(comb))
        comb_totals[comb] += 1
        # print(masked_result.log_probs.sum(dim=-1).unsqueeze(-1).shape, log_prob_comb_totals.shape)
        # log_prob_comb_totals[...,comb] += torch.broadcast_to(masked_result.log_probs.sum(dim=-1).unsqueeze(-1), (masked_result.log_probs.shape[0], masked_result.log_probs.shape[1], len(comb)))
        # take the total log prob prediction for each index
        log_prob_comb_totals[...,comb] += masked_result.log_probs.sum(dim=-1).unsqueeze(-1)
        # take the total distance for each index
        dist_totals[...,comb] = dist_totals[...,comb] + pytorch_model.wrap(dists).to(infer_result.log_probs.device)
        # print(log_prob_comb_totals.shape, mean_var_totals[0].shape)
        mean_var_totals[0][...,comb,:] += masked_result.params[0].unsqueeze(1)
        # print(mean_var_totals[0].shape, mean_var_totals[1].shape, masked_result.params[1].shape, masked_result.params[0].shape)
        mean_var_totals[1][...,comb,:] += masked_result.params[1].unsqueeze(1)
        # combine the binaries 
        all_dists[tuple(comb)] = binaries
    result = Batch()
    # log prob statistics
    result.null_log_probs = log_prob_comb_totals / comb_totals.unsqueeze(0)
    # NOTE(review): unlike null_log_probs and null_dists, these are the summed params, not divided by comb_totals -- confirm intent
    result.null_params = mean_var_totals
    result.log_probs = infer_result.log_probs
    result.params = infer_result.params
    # one column per entry of the last dimension of batch.valid, 1.0 where any combination containing that index was flagged
    result.inter_masks = combine_masks(batch.valid.shape[-1], all_dists)
    # print(name, model.extractor.get_index([name]), infer_result.omit_flags[0], keep_all, infer_result.omit_flags)
    # NOTE(review): the two branches call get_index with name and with [name] respectively -- confirm both produce the intended slice
    if keep_all: result.utrace = batch.trace[:, model.extractor.get_index(name)]
    else: result.utrace = batch.trace[infer_result.omit_flags[0], model.extractor.get_index([name])]
    # print(np.concatenate([pytorch_model.unwrap(infer_result.log_probs.mean(axis=-1)[:20].unsqueeze(-1)), batch.target_diff[:20, 8:12],batch.obs[:20], result.utrace[:20,1:2]],axis=-1))
    # print(result.inter_masks.shape, result.utrace.shape)
    # Trace statistics
    result.bin_error = result.inter_masks - result.utrace
    result.total_error = np.abs(result.inter_masks - result.utrace)
    # inter_fp keeps only the positive part of (mask - trace)
    result.inter_fp = (pytorch_model.unwrap(result.inter_masks) - result.utrace) # must have log_batch contain trace
    result.inter_fp[result.inter_fp<0] = 0
    # inter_fn keeps only the negative part of (mask - trace), so its nonzero entries are negative
    result.inter_fn = (pytorch_model.unwrap(result.inter_masks) - result.utrace) # must have log_batch contain trace
    result.inter_fn[result.inter_fn>0] = 0
    result.inter_variance = np.expand_dims(np.std(pytorch_model.unwrap(result.inter_masks), axis=0), axis=0)
    # per-column mean of the inferred mask over rows where the trace is 1 (resp. 0), capped at 1
    # if a selection is empty and the mean comes out as nan, min(1, nan) evaluates to 1
    result.inter_one_trace_rate = np.expand_dims(np.array([min(1, np.mean(pytorch_model.unwrap(result.inter_masks[...,i][result.utrace[...,i] == 1]))) for i in range(result.inter_masks.shape[-1])]), axis=0)
    result.inter_zero_trace_rate = np.expand_dims(np.array([min(1, np.mean(pytorch_model.unwrap(result.inter_masks[...,i][result.utrace[...,i] == 0]))) for i in range(result.inter_masks.shape[-1])]), axis=0)
    # distance statistics
    result.null_dists = dist_totals / comb_totals.unsqueeze(0)
    result.null_positive_dists = torch.zeros(num_factors).to(infer_result.log_probs.device)
    result.null_negative_dists = torch.zeros(num_factors).to(infer_result.log_probs.device)
    result.null_fp_dists = torch.zeros(num_factors).to(infer_result.log_probs.device)
    result.null_fn_dists = torch.zeros(num_factors).to(infer_result.log_probs.device)
    result.fp_log_probs = torch.zeros(num_factors).to(infer_result.log_probs.device)
    result.fn_log_probs = torch.zeros(num_factors).to(infer_result.log_probs.device)
    for i in range(num_factors):
        # rows are selected by the nonzero entries of column i of utrace / (1 - utrace) / inter_fp / inter_fn
        # the dist statistics are means (nan when no row is selected); the log prob statistics are sums
        result.null_positive_dists[i] = result.null_dists[result.utrace[:,i].nonzero()[0]][:,i].mean()
        result.null_negative_dists[i] = result.null_dists[(1-result.utrace)[:,i].nonzero()[0]][:,i].mean()
        result.null_fp_dists[i] = result.null_dists[result.inter_fp[:,i].nonzero()[0]][:,i].mean()
        result.null_fn_dists[i] = result.null_dists[(result.inter_fn)[:,i].nonzero()[0]][:,i].mean()
        result.fp_log_probs[i] = result.log_probs[result.inter_fp[:,i].nonzero()[0]].sum()
        result.fn_log_probs[i] = result.log_probs[result.inter_fn[:,i].nonzero()[0]].sum()
    # add a leading dimension and replace nan: 1.0 for the positive/fp statistics, 0.0 for the negative/fn statistics
    result.null_positive_dists = torch.nan_to_num(result.null_positive_dists.unsqueeze(0), nan=1.0)
    result.null_negative_dists = torch.nan_to_num(result.null_negative_dists.unsqueeze(0), nan=0.0)
    result.null_fp_dists = torch.nan_to_num(result.null_fp_dists.unsqueeze(0), nan=1.0)
    result.null_fn_dists = torch.nan_to_num(result.null_fn_dists.unsqueeze(0), nan=0.0)
    result.fp_log_probs = torch.nan_to_num(result.fp_log_probs.unsqueeze(0), nan=1.0)
    result.fn_log_probs = torch.nan_to_num(result.fn_log_probs.unsqueeze(0), nan=0.0)

    # base statistics
    result.valid = batch.valid
    result.omit_flags =  infer_result.omit_flags
    result.trace = batch.trace[infer_result.omit_flags[0]]
    # print(result.utrace.shape, result.inter_masks.shape)
    # print("inter_masks", np.concatenate([result.inter_masks[:5], result.utrace[:5]], axis=-1))
    # print("at distance", args.infer.nulls.dist_epsilon)
    # print("false positive:", np.sum((result.bin_error > 0).astype(int), axis=0) / np.sum(np.abs(result.bin_error), axis=0) )
    # print("false negative:", np.sum((result.bin_error < 0).astype(int), axis=0) / np.sum(np.abs(result.bin_error), axis=0) )
    # print(len(infer_result.omit_flags), result.inter_masks.shape[-2], (result.inter_masks.shape[-1]-1))
    # print("nulls total_error:", np.sum(np.abs(result.bin_error), axis=0) / (len(result.bin_error)))
    return result

def analyze_errors(model, name, comb, binaries, batch, dists, infer_result, masked_result, target_names=[]):
    '''
    Debug printout comparing the inferred binaries with the trace for one
    nulled combination.

    Always prints the combination header. The detailed tables are printed
    only when at least one entry of target_names is among the names of the
    indices in comb. Each printed row holds, in order: summed base log prob,
    summed masked log prob, first 4 entries of the base mean, the masked mean
    and the target, then valid, trace, binaries and dists.
    Returns nothing.
    '''
    print("used_inputs", comb, name, target_names, )
    nulled_names = [model.extractor.get_name(c) for c in comb]
    if not any([n in nulled_names for n in target_names]):
        return

    kept_rows = masked_result.omit_flags[0]
    base_log_prob = infer_result.log_probs.sum(axis=-1).unsqueeze(-1)
    masked_log_prob = masked_result.log_probs.sum(axis=-1).unsqueeze(-1)
    base_mean = infer_result.params[0]
    masked_mean = masked_result.params[0]
    target = infer_result.target

    valid = pytorch_model.wrap(
        batch.valid[kept_rows, model.extractor.get_index(target_names)],
        cuda=model.iscuda,
    ).unsqueeze(-1)
    trace = pytorch_model.wrap(
        batch.trace[kept_rows, model.extractor.get_index(name), model.extractor.get_index(target_names)],
        cuda=model.iscuda,
    ).unsqueeze(-1)
    # several indexed columns are collapsed into one by summing
    if valid.dim() > 2:
        valid = valid.sum(dim=1)
    if trace.dim() > 2:
        trace = trace.sum(dim=1)
    binaries = pytorch_model.wrap(binaries, cuda=model.iscuda)
    dists = pytorch_model.wrap(dists, cuda=model.iscuda)
    # TODO: add the actual counterfactual result

    columns = [
        base_log_prob, masked_log_prob,
        base_mean[:,:4], masked_mean[:,:4], target[:,:4],
        valid, trace, binaries, dists,
    ]
    all_vals = torch.cat(columns, axis=-1)

    flat_binaries, flat_trace = binaries.squeeze(), trace.squeeze()
    value_idxes = (flat_binaries + flat_trace) > 0
    mismatch = flat_binaries - flat_trace
    fp_rows = mismatch > 0
    fn_rows = mismatch < 0

    print(batch.trace[kept_rows, model.extractor.get_index(name)][value_idxes.cpu().numpy()][:10])
    print("printing out: ", "mean_infer_log_prob", "mean_mask_log_prob", "infer_mean(4)", "mask_mean(4)", "target(4)", "valid", "trace", "binaries", "dists")
    print(name, "nonzero", all_vals[value_idxes][:10])
    print(name, "fp", all_vals[fp_rows][:10])
    print(name, "fn", all_vals[fn_rows][:10])

# Null consistency: regularize the size of representations that could be nulled out without loss of performance.
# Adaptive strategies (each can be applied at a pointwise level or at a general level):
#   - log-likelihood of the model's prediction,
#   - log-likelihood of the nulled prediction,
#   - log-likelihood of the single passive model,
#   - performance of the single passive model relative to the interaction model.
def compute_embed_reg(args, null_result, mask, params, model, batch, result):
    """Return the regularisation rate for the null-consistency embedding loss.

    The behaviour is chosen by the first entry of
    ``args.inter.regularization.null_consistency.null_embed_reg_type``:

    * ``"flat"``: return ``params.null_embed_reg`` unchanged.
    * ``"base"``: adapt the rate using ``null_result`` itself.
    * ``"null"``: adapt using the nulled log probs / params of ``null_result``.
    * ``"single_passive"``: adapt using
      ``result[args.inter.train_names[0]]["single_passive"]``.

    For the adaptive types the rate comes from ``compute_adaptive_rate``, with
    the second entry of ``null_embed_reg_type`` as the pointwise setting and
    ``null_adaptive`` supplying the bias and flatten factor.

    Args:
        args: config namespace.
        null_result: result object exposing ``log_probs`` and, for "null",
            ``null_log_probs`` / ``null_params``.
        mask, model: unused here; kept for interface compatibility.
        params: provides ``null_embed_reg`` and optionally
            ``converged_active_loss_value``.
        batch: forwarded to ``compute_adaptive_rate``.
        result: only read for the "single_passive" type.

    Raises:
        NotImplementedError: for an unknown regularisation type.
    """
    null_cfg = args.inter.regularization.null_consistency
    reg_type = null_cfg.null_embed_reg_type[0]
    if reg_type == "flat":
        return params.null_embed_reg
    if reg_type == "base":
        use_result = null_result
    elif reg_type == "null":
        use_result = Batch()
        use_result.log_probs = null_result.null_log_probs
        use_result.params = null_result.null_params
    elif reg_type == "single_passive":
        use_result = result[args.inter.train_names[0]]["single_passive"] # TODO: if single passive not available, uses a model
    else:
        raise NotImplementedError("Invalid Reg Type")

    # Read the second entry only after the "flat" early return, as before.
    # Fix: this used to be read from args.null_embed_reg_type, a path used
    # nowhere else in this file.
    pointwise = null_cfg.null_embed_reg_type[1]
    base_value = params.null_embed_reg
    if "converged_active_loss_value" in params:
        baseline_likelihood = params.converged_active_loss_value
    else:
        baseline_likelihood = 3.5 * null_result.log_probs.shape[-1]
    adaptive_lasso = params.null_embed_reg # null_embed_reg overloaded for both base and adaptive
    bias = null_cfg.null_adaptive[0]
    flatten_factor = null_cfg.null_adaptive[1]
    # NOTE(review): the first positional argument and the pointwise keyword
    # receive the same value, as in the original call -- confirm against the
    # signature of compute_adaptive_rate.
    return compute_adaptive_rate(pointwise, False, base_value, adaptive_lasso, use_result, batch, baseline_likelihood, bias, flatten_factor, pointwise=pointwise)

def compute_null_consistency_losses(args, params, model,batch, results, perform_analysis=-1):
    """Regularise the embedding magnitude of inputs inferred to have no effect.

    Runs null inference for the first name in ``args.inter.train_names`` and
    multiplies the mean (over the last dim) of ``results.pre_embeddings_query``
    by the rate from ``compute_embed_reg`` and by the inferred interaction mask.

    Args:
        args: config namespace.
        params: forwarded to ``infer_null`` and ``compute_embed_reg``.
        model: the model; ``model.iscuda`` selects where the mask is placed.
        batch: data batch to run inference on.
        results: forward result exposing ``pre_embeddings_query``; if a tuple
            is given, only its first element is used.
        perform_analysis: forwarded to ``infer_null``.

    Returns:
        The elementwise product described above.
    """
    # TODO only uses the first name in train names
    target_name = args.inter.train_names[0]
    null_result = infer_null(args, params, model, batch, perform_analysis=perform_analysis, use_extractor_names=True, infer_names=[target_name])
    # infer_null returns a Batch keyed by name; every field read below and in
    # compute_embed_reg (inter_masks, log_probs, null_log_probs, null_params)
    # lives on the per-target entry. Fix: the name-keyed Batch used to be
    # passed to compute_embed_reg.
    target_null = null_result[target_name]
    mask = target_null.inter_masks
    if isinstance(results, tuple):
        results = results[0] # just use the first one for embedding losses, if multiple
    embed_reg = compute_embed_reg(args, target_null, mask, params, model, batch, results)
    return results.pre_embeddings_query.mean(dim=-1) * float(embed_reg) * pytorch_model.wrap(mask, cuda=model.iscuda)

def infer_null_values(args, params, model, buffer, keep_all=True, use_extractor_names=True, infer_names = [], result_names=[], batch_size = 1024):
    '''
    Runs null inference over the data returned by buffer.sample(0), in slices
    of batch_size, and gathers the outcome per name.

    infer_names: names to run null inference for; when empty, falls back to
        args.infer.infer_names (if non-empty and use_extractor_names is False)
        or to model.extractor.names
    result_names: keys of the per-batch results to keep, passed to
        filter_batch_names (to prevent using too much RAM)
    batch_size: number of rows per inference call

    Returns (null_binaries, null_weights, results_batch): two dicts keyed by
    name holding the concatenated inter_masks and the concatenated outputs of
    compute_null_weights, plus the filtered results aggregated over batches.
    '''
    full_data, _ = buffer.sample(0)
    if len(infer_names) == 0:
        configured = args.infer.infer_names
        if len(configured) > 0 and not use_extractor_names:
            infer_names = configured
        else:
            infer_names = model.extractor.names

    results_batch = Batch()
    weight_chunks = {n: [] for n in infer_names}
    binary_chunks = {n: [] for n in infer_names}
    num_batches = int(np.ceil(len(full_data) / batch_size))
    for batch_idx in range(num_batches):
        start = batch_idx * batch_size
        batch = full_data[start:start + batch_size]
        named_results = infer_null(args, params, model, batch, keep_all=keep_all,  infer_names=infer_names)
        for n in infer_names:
            target_result = named_results[n]
            binary_chunks[n].append(pytorch_model.unwrap(target_result.inter_masks))
            weight_chunks[n].append(compute_null_weights(args, target_result))
        filtered_batch, _ = filter_batch_names(named_results, result_names)
        results_batch = aggregate_result(results_batch, filtered_batch, batch_idx, combine_type="cat0")

    # stitch the per-batch pieces together along the batch dimension
    null_weights = dict()
    null_binaries = dict()
    for n in weight_chunks:
        null_weights[n] = np.concatenate(weight_chunks[n], axis=0)
        null_binaries[n] = np.concatenate(binary_chunks[n], axis=0)
    return null_binaries, null_weights, results_batch


def compute_null_weights(args, results):
    '''
    Computes a confidence weight for a set of null inference results, chosen
    by args.inter.null_em.null_bin_weight_type:

    "uni": a vector of ones with one entry per row of results.log_probs
    "likelihood": (mean base log prob - null log prob) / bin_weight[2],
        clipped to [0, bin_weight[3]], and zeroed wherever the null log prob
        exceeds bin_weight[0] or the mean base log prob is below bin_weight[1]

    results must expose log_probs and, for "likelihood", null_log_probs.
    Raises NotImplementedError for any other type.
    TODO: brainstorm and implement better methods
    '''
    em_cfg = args.inter.null_em
    weight_type = em_cfg.null_bin_weight_type
    if weight_type == "uni":
        return np.ones(results.log_probs.shape[0])
    if weight_type != "likelihood":
        raise NotImplementedError("Invalid null_bin_weight_type, avaliable are: likelihood, uni")

    thresholds = em_cfg.bin_weight
    base_log_prob = results.log_probs.mean(dim=-1).unsqueeze(-1)
    rejected = (results.null_log_probs > thresholds[0]) | (base_log_prob < thresholds[1])
    # despite its name, drop is 1 for the entries that are kept and 0 for the rejected ones
    drop = 1 - pytorch_model.unwrap(rejected)
    print("dropped", np.sum(drop), drop.shape)
    margin = pytorch_model.unwrap((base_log_prob - results.null_log_probs) / thresholds[2])
    return np.clip(margin, 0, thresholds[3]) * drop
