import numpy as np
from Network.network_utils import pytorch_model
import copy
from ActualCausal.Train.train_utils import compute_likelihood
from ActualCausal.Inference.compute_inference import compute_inference
from tianshou.data import Batch
from ActualCausal.Utils.run_utils import * # imports the naming conventions for errors, and compute_types dictionary


def compute_values(model, all_compute, part, reduced =False, normalized = False, object_names=None, given_mask=None, keep_all=True):
    """Compute the requested quantities of ``model`` on one segment of rollout data.

    @param model the model whose ``infer`` method and ``norm`` normalizer are used;
        ``model.target_name`` must be set appropriately before calling
    @param all_compute the compute types to evaluate, using the names from
        ActualCausal.Utils.run_utils compute_types
    @param part the segment of rollout data (reads ``part.valid``, ``part.done`` and,
        for proximity types, ``part.target_state``)
    @param reduced if True, outputs of the mean_names types are reduced with an L1 norm
        over the last axis; every other output (and every output when reduced is False)
        is averaged over the last axis
    @param normalized if False, mean / variance / raw outputs are mapped back to the
        unnormalized space with ``model.norm.reverse``
    @param object_names (source, target) object indices used by the proximity types
    @param given_mask optional mask of shape [num_objects], multiplied into ``part.valid``
        for every row of the batch
    @param keep_all forwarded to ``model.infer``; when True, outputs at done steps are zeroed
    @return a list with one array per entry of ``all_compute``. NOTE: PROXIMITY_ALL, TRACE
        and DONE return their raw value immediately instead of a list.
    """
    # unnormalize functions, mostly for assessment
    rv = lambda x: model.norm.reverse(x, form="dyn" if model.mp.predict_dynamics else "target", name=model.target_name)
    rv_var = lambda x: model.norm.reverse(x, form="dyn" if model.mp.predict_dynamics else "diff", name=model.target_name)
    num_batch = len(part)

    done_flags = (1-part.done).squeeze() # shape: batch_len, 0 where the step is done
    infer_dict = get_base_infer(all_compute)
    additional_vals = get_additional_infer(all_compute)
    # TODO: given mask incompatible with other get_computes
    valid = part.valid if given_mask is None else part.valid * np.broadcast_to(given_mask, (num_batch, given_mask.shape[0]))
    result = model.infer(part, valid, list(set([infer_dict[num_compute_types[v]] for v in all_compute])), additional = additional_vals, log_batch=["trace"], keep_invalid=False, keep_all=keep_all)

    def inferred(compute_type):
        # Look every type up through its numeric id, the same keys requested from model.infer above.
        return result[infer_dict[num_compute_types[compute_type]]]

    outcomes = list()
    for compute_type in all_compute:
        use_norm = False
        if compute_type in mean_names:
            inference = inferred(compute_type)
            outcome = pytorch_model.unwrap(inference.params[0] - inference.target)
            use_norm = True
            if not normalized: outcome = rv(outcome)
        elif compute_type in var_names:
            outcome = pytorch_model.unwrap(inferred(compute_type).params[1])
            if not normalized: outcome = rv_var(outcome)
        elif compute_type in raw_names:
            outcome = pytorch_model.unwrap(inferred(compute_type).params[0])
            if not normalized: outcome = rv(outcome)
        elif compute_type in like_names or compute_type in given_names: # given names by default uses the same output and is incompatible
            outcome = pytorch_model.unwrap(inferred(compute_type).log_probs)
        elif compute_type in grad_names:
            outcome = pytorch_model.unwrap(inferred(compute_type).gradient)
        elif compute_type in (inter + all_inter): # TODO: right now, all likelihood computation goes to mask_logits, but this should change
            outcome = pytorch_model.unwrap(inferred(compute_type).mask_logits)
        elif compute_type in prox_names:
            all_proximity, all_dists = get_proximity(model.extractor.pos_size, model.fp.single_obj_dim, part.target_state, model.extractor.sp.proximity_epsilon)
            if compute_type == compute_types.PROXIMITY_ALL: return all_proximity
            elif compute_type == compute_types.PROXIMITY: outcome = all_proximity[:, object_names[0], object_names[1]]
            elif compute_type == compute_types.PROXIMITY_FULL: outcome = all_proximity[:, object_names[1]]
            elif compute_type == compute_types.PROXIMITY_FLAT: outcome = all_dists[:, object_names[1]]
            # never fall through with a stale outcome from a previous iteration
            else: raise Exception("invalid proximity type: " + str(compute_type))
        elif compute_type == compute_types.TRACE: return part.trace
        elif compute_type == compute_types.DONE: return part.done
        else: raise Exception("invalid error type")
        if keep_all:
            # Zero out the dones if we are keeping them. The [batch] flags are reshaped so they
            # broadcast along the leading (batch) axis even for outcomes with more than two
            # axes (e.g. PROXIMITY_FULL is [batch, num_objects, pos_size]).
            flag_shape = (-1,) + (1,) * max(np.ndim(outcome) - 1, 1)
            outcome = outcome * np.reshape(done_flags, flag_shape)
        if reduced and use_norm: outcome = np.linalg.norm(outcome, ord=1, axis=-1)
        else: outcome = np.mean(outcome, axis=-1) # TODO: might not work for booleans
        outcomes.append(outcome)
    return outcomes
    
def get_proximity(pos_size, pad_size, flattened_state, proximity_epsilon):
    """Compute pairwise per-coordinate distances and proximity flags between objects.

    @param pos_size number of leading features of each object treated as its position
    @param pad_size number of features per object in the flattened state
    @param flattened_state array of shape [..., num_objects * pad_size], e.g.
        [batch, num_objects * pad_size] or [num_objects * pad_size]
    @param proximity_epsilon threshold; a coordinate counts as proximal when its absolute
        distance is strictly below it
    @return (proximity, dists), both of shape [..., num_objects, num_objects, pos_size],
        where dists[..., i, j, :] = |pos_i - pos_j| and proximity = dists < proximity_epsilon
        (boolean, per coordinate)
    """
    pad_size = int(pad_size)
    flattened_state = np.asarray(flattened_state)
    num_objects = int(flattened_state.shape[-1] // pad_size)
    unflat_state = flattened_state.reshape(tuple(flattened_state.shape[:-1]) + (num_objects, pad_size))
    positions = unflat_state[..., :pos_size] # [..., num_objects, pos_size]
    # [..., N, 1, pos] - [..., 1, N, pos]: the object axes broadcast against each other while
    # any leading batch axes stay aligned (indexing a single object with [..., i, :] would
    # instead misalign the batch axis with the object axis).
    total_dists = np.abs(positions[..., :, None, :] - positions[..., None, :, :])
    total_prox = total_dists < proximity_epsilon
    return total_prox, total_dists

def get_operation(full_model, buffer, all_compute=None, reduced=True, normalized=False, object_names=None, given_mask=None, keep_all=True, wrap_function=None, batch=None):
    """Evaluate compute_values over a whole rollout, in chunks of CUTSIZE rows.

    Data is processed in buffer order, so the outputs line up with the buffer (for assignment).
    full_model.set_target_name must be set appropriately prior to use.

    @param full_model the model passed to compute_values
    @param buffer a tianshou.data.Batch (evaluated in a single pass), otherwise assumed to be
        a subclass of tianshou.data.ReplayBuffer (sliced CUTSIZE rows at a time)
    @param all_compute the compute types to evaluate (defaults to none)
    @param reduced, normalized, object_names, given_mask, keep_all forwarded to compute_values
    @param wrap_function wraps each buffer slice so it is usable by compute_values
        (not applied when buffer is already a Batch)
    @param batch when buffer is a Batch, an explicit batch to evaluate instead of buffer
        (kept for backward compatibility)
    @return one array per compute type, concatenated along axis 0 across chunks; an empty
        buffer yields one empty array per compute type
    """
    all_compute = [] if all_compute is None else all_compute
    CUTSIZE = 500 # rows per chunk, so that we don't overload the GPU

    if isinstance(buffer, Batch):
        parts = [buffer if batch is None else batch]
    else:
        num_chunks = int(np.ceil(len(buffer) / CUTSIZE))
        parts = (buffer[i*CUTSIZE:(i+1)*CUTSIZE] for i in range(num_chunks))
        if wrap_function is not None:
            parts = (wrap_function(part) for part in parts)

    model_errors = [
        compute_values(full_model, all_compute, part, normalized=normalized, reduced=reduced, object_names=object_names, given_mask=given_mask, keep_all=keep_all)
        for part in parts
    ]
    if not model_errors: # empty buffer: nothing was computed
        return [np.zeros(0) for _ in all_compute]
    return [np.concatenate([v[i] for v in model_errors], axis=0) for i in range(len(model_errors[0]))]
