import torch
import copy
import numpy as np
from ActualCausal.Utils.run_dataset import get_operation, compute_types

def select_proximity(name, extractor, proximity):
    '''
    Reduce a pairwise proximity array to one value per sample.

    Assumes proximity is indexed as proximity[batch, i, j] with i, j over objects — TODO confirm.

    Args:
        name: object name to select, or "all" to reduce over every object pair.
        extractor: provides get_index(name) and num_objects.
        proximity: pairwise proximity array (see assumption above).
    Returns:
        array with one entry per batch element: the minimum over the selected row(s)
        after subtracting 1 from the self (diagonal) entries.
    '''
    if name == "all":
        # closest distance of any object to any other, except self
        self_mask = np.eye(proximity.shape[-1])
        return np.min(np.min(proximity - np.broadcast_to(self_mask, proximity.shape), axis=-1), axis=-1)
    # only look up an index for a real object name (never for "all")
    idx = extractor.get_index(name)
    row = proximity[:, idx]
    self_mask = np.zeros(extractor.num_objects)
    self_mask[idx] = 1
    # closest distance of the named object to every other, except self.
    # The mask must be broadcast to the selected row's shape, not to the full pairwise shape.
    # NOTE(review): subtracting 1 on the self entry lowers it, so under a min it is not excluded
    # unless proximity values are such that this is intended — confirm semantics.
    return np.min(row - np.broadcast_to(self_mask, row.shape), axis=-1)

def proximity_binary(args, model, buffer):
    '''
    Compute proximity values over the buffer together with their non-proximal complement.

    When args.inter.all_proximity is set, all object pairs are used; otherwise only the
    first entry of model.train_names is selected.

    Returns:
        (proximal_inst, non_proximal): the select_proximity result, and an int array that is
        1 wherever proximal_inst != 1 and 0 elsewhere.
    '''
    buffer = buffer[:len(buffer)]
    target_name = "all" if args.inter.all_proximity else model.train_names[0]
    # NOTE(review): reads buffer.proximal, while separate_weights reads buffer.proximity — confirm which is intended
    proximal_inst = select_proximity(target_name, model.extractor, buffer.proximal)
    non_proximal = (proximal_inst != 1).astype(int)
    return proximal_inst, non_proximal

def uni_weights(buffer):
    '''
    Uniform weighting fallback over every sample in the buffer.

    Returns:
        (passive_error, weights, binaries): a torch tensor of ones, a float64 numpy array
        whose entries are all 1 / len(buffer), and a numpy array of ones, each of length len(buffer).
    '''
    count = len(buffer)
    uniform = np.ones(count).astype(np.float64) / float(count)
    return torch.ones(count), uniform, np.ones(count)

def passive_binary(passive_error, weighting, proximity, done):
    '''
    Threshold per-sample passive error into 0/1 binaries.

    Args:
        passive_error: per-sample passive errors; a deep copy is thresholded, the input is not modified.
        weighting: (passive_error_cutoff, passive_error_upper). Entries strictly above the cutoff
            and not above the upper bound become 1; all others become 0.
        proximity: optional array, truncated to int and multiplied into the binaries; or None.
        done: array aligned with passive_error; entries where done == 1 are set to 0.
    Returns:
        The binaries. When proximity is given, they are a squeezed np.longdouble array.
    '''
    passive_error_cutoff, passive_error_upper = weighting
    binaries = copy.deepcopy(passive_error)
    # compute every mask from the original values before any entry is overwritten
    greater = binaries > passive_error_cutoff
    lesser = binaries <= passive_error_cutoff
    over = binaries > passive_error_upper
    binaries[greater] = 1
    binaries[lesser] = 0
    binaries[over] = 0  # if the error is too high, this might be an anomaly
    binaries[done == 1] = 0  # if done, disregard

    # use proximity to narrow the range of weights, if proximity is not used, these should be ones TODO: replace with feasibility?
    if proximity is not None:
        # np.longdouble is the portable name; np.float128 is missing on some platforms (e.g. Windows)
        binaries = (binaries.astype(int) * proximity.astype(int)).astype(np.longdouble).squeeze()
    return binaries

def get_trace_weights(trace, weighting_ratio, trace_idxes=None):
    '''
    Derive binaries and sampling weights from interaction traces.

    Assumes trace is indexed as trace[batch, i, j] with a square last-two-axes block — TODO confirm.

    Args:
        trace: interaction trace array (see assumption above).
        weighting_ratio: forwarded to get_weights as ratio_lambda.
        trace_idxes: indices along axis 1 to keep; None (the default) keeps all of them.
    Returns:
        (binaries, binaries, weights). binaries is the max over the last axis of the
        selected trace rows, after subtracting the identity; it is returned twice.
    '''
    # lower the diagonal (self) entries by 1 before taking the max
    trace_np = trace - np.expand_dims(np.eye(trace.shape[-1]), axis=0)
    if trace_idxes is not None:
        # np.array(None) is not a valid index, so only subselect when indices are given
        trace_np = trace_np[:, np.array(trace_idxes)]
    binaries = np.max(trace_np, axis=-1).squeeze()
    weights = get_weights(weighting_ratio, binaries)
    return binaries, binaries, weights

def _print_passive_debug(args, model, buffer, passive_error, binaries, chunk=250, num_chunks=10):
    '''
    Print passive-error diagnostics for the first train object (debugging aid for separate_weights).

    Prints up to num_chunks chunks of chunk rows and stops early when the data runs out.
    Each chunk is reshaped with its actual length, so shorter buffers do not break the reshape.
    '''
    obj_idx = model.extractor.get_index(args.inter.train_names[0])
    print("passive_error", "binaries", "trace_sum", "target_diff")
    for i in range(num_chunks):
        rows = slice(i * chunk, (i + 1) * chunk)
        count = len(passive_error[rows])
        if count == 0:
            break
        print("passive_error", np.concatenate([np.expand_dims(passive_error[rows], axis=-1),
                                               np.expand_dims(binaries[rows], axis=-1),
                                               np.expand_dims(np.sum(buffer.trace[rows, obj_idx, :4], axis=-1), axis=-1),
                                               buffer.done[rows],
                                               buffer.target_diff[rows].reshape(count, model.extractor.num_objects, -1)[:, obj_idx]], axis=-1))
    # row indices of the nonzero binaries (same as the original for 1D binaries, also valid for (N, 1))
    bin_rows = np.nonzero(binaries)[0]
    print("vals", np.concatenate([buffer.done[bin_rows], buffer.trace[bin_rows, obj_idx], buffer.target_diff[bin_rows, 24:32]], axis=-1)[:100])
    print("trace_rate", np.mean(np.sum(buffer.trace[bin_rows, obj_idx], axis=-1)), np.mean(np.sum(buffer.trace[:, obj_idx], axis=-1)))


def separate_weights(args, weighting_type, model, buffer, wrap_function=None, debug=False):
    '''
    Generates weights based on the passive error, proximity, or the trace, or returns uniform weights.
    This value is shared across the training computation. This should work for all cases because
    passive_likelihood reduces; the only difference is the passive error threshold.

    Args:
        args: config. Reads args.inter.train_names, args.inter.passive_weighting (cutoff, upper) and
            args.inter.weight_proximity. Proximity weighting also reads args.inter.all_proximity
            (via proximity_binary).
        weighting_type: "passive_error", "proximity", "non_prox" or "trace"; any other value gives uniform weights.
        model: reads model.extractor, and model.trace_idx for trace weighting.
        buffer: sample buffer. Reads done, proximity, trace, target_diff and weight_binary.
        wrap_function: forwarded to get_operation when no train_names are given.
        debug: if True, print passive-error diagnostics (passive_error weighting with train_names only).
    Returns:
        (passive_error, weights, binaries). binaries gains a trailing axis when buffer.weight_binary
        is 2D and binaries is 1D.
    '''
    if weighting_type == "passive_error":
        train_names = args.inter.train_names
        if len(train_names) > 0:
            passive_error = - get_operation(model, buffer, all_compute=[compute_types.PASSIVE_LIKELIHOOD], reduced=True, normalized=False, object_names=train_names)[0]
        else:
            passive_error = - np.sum(get_operation(model, buffer, all_compute=[compute_types.ALL_PASSIVE_LIKELIHOOD], reduced=True, normalized=False, object_names=train_names, wrap_function=wrap_function)[0], axis=-1)
        # ensures that dones have "low" passive error so they don't get sampled
        passive_error = (buffer.done.astype(float) * -100).squeeze() + passive_error
        # NOTE(review): reads buffer.proximity while proximity_binary reads buffer.proximal — confirm which is intended
        proximity = select_proximity(train_names[0], model.extractor, buffer.proximity) if args.inter.weight_proximity else None
        # weighting hyperparameters, if passive_error_cutoff > 0 then using passive weighting
        binaries = passive_binary(passive_error, args.inter.passive_weighting, proximity, buffer.done.squeeze())
        if len(binaries.shape) > 1 and binaries.shape[-1] > 1:
            # collapse multiple columns into one: sum, then cap at 1
            binaries = np.sum(binaries, axis=-1)
            binaries[binaries > 1] = 1
        weights = get_weights(1, binaries)
        if debug and len(train_names) > 0:
            _print_passive_debug(args, model, buffer, passive_error, binaries)
        if np.sum(binaries) == 0:
            print("NO PASSIVE FOUND, USING UNIFORM WEIGHTS")
            passive_error, weights, binaries = uni_weights(buffer)
    elif weighting_type in ("proximity", "non_prox"):
        prox_binaries, np_binaries = proximity_binary(args, model, buffer)
        binaries = np_binaries if weighting_type == "non_prox" else prox_binaries
        # mirror the trace branch: the binaries double as the returned passive_error
        passive_error = binaries
        weights = get_weights(1, binaries)
    elif weighting_type == "trace":
        passive_error, binaries, weights = get_trace_weights(buffer.trace[:len(buffer)], 1, model.trace_idx)
    else:  # no special weighting on the samples
        passive_error, weights, binaries = uni_weights(buffer)
    if len(buffer.weight_binary.shape) == 2 and len(binaries.shape) == 1:
        binaries = np.expand_dims(binaries, -1)
    return passive_error, weights, binaries


def get_weights(ratio_lambda, binaries):
    '''
    Convert 0-1 sample binaries into normalized sampling weights.

    binaries are 0-1 values: either the trace values (supervised interactions), or a mask of where
    the passive error exceeds a threshold, possibly combined with proximal states.

    Args:
        ratio_lambda: scales how strongly "live" (nonzero) samples are up-weighted relative to
            "dead" (zero) samples. If <= 0, the binaries themselves are normalized.
        binaries: numpy array of shape (N,) or (N, k). For 2D input, only column 0 of the
            weights (normalized over all entries) is returned.
    Returns:
        float64 numpy array of length N. If every binary is 0, the weights are uniform.
    '''
    weights = binaries.copy()
    num_weighted = binaries.copy().astype(bool).astype(int)

    # pass-through: normalize the binaries directly (uniform if there are no live samples)
    if ratio_lambda <= 0 or np.sum(weights) == 0:
        if np.sum(weights) == 0:
            weights = np.ones(weights.shape)
        weights = (weights.astype(np.float64) / np.sum(weights).astype(np.float64))
        weights[weights < 0] = 0
        if len(weights.shape) == 2: weights = weights[:,0] # squeeze the last dimension
        # previously this branch returned None, discarding the weights computed above
        return weights

    # generate a ratio based on the number of live versus dead
    total_live = np.sum(weights)
    # number of zero ("dead") entries
    total_dead = num_weighted.size - np.sum(num_weighted)

    # for a ratio lambda of 1, will get 50-50 probability of sampling a "live" (high passive error) versus "dead" (low passive error)
    live_factor = np.float64(np.round((total_dead - total_live) / max(1, total_live) * ratio_lambda))
    print("live factor", ratio_lambda, np.sum((weights + 1)), total_dead, total_live,(total_dead - total_live),np.round((total_dead - total_live) / max(1, total_live) * ratio_lambda), live_factor)
    if live_factor < 0: live_factor = 0
    weights = (weights * live_factor) + 1.0
    weights = (weights.astype(np.float64) / np.sum(weights).astype(np.float64))
    if len(weights.shape) == 2: weights = weights[:,0] # squeeze the last dimension
    return weights