from Environment.Normalization.pad_norm import PadNormalizationModule
from ACState.pad_selector import PadSelector
from ACState.object_dict import ObjDict
from tianshou.data import Batch
import numpy as np


# TODO: derive enc_rng and enc_dyn automatically from the range and variance of the observed encodings
def regenerate(args, environment, all=False, encoding_dim=-1, enc_rng=None, enc_dyn = None):
    '''
    Build the Extractor and PadNormalizationModule for ``environment`` and record the
    resulting factor dimensions on ``args.factor``.

    args: config; reads args.inter.passive_reassign and writes args.factor.first_obj_dim,
        single_obj_dim, object_dim, all_obj_dim, named_first_obj_dim, num_objects,
        name_idxes and name_idx (always -1).
    environment: source of object_names, all_names, object_range, object_range_true,
        object_dynamics and object_instanced (plus whatever Extractor reads).
    all: forwarded unchanged to PadNormalizationModule.
    encoding_dim: when > 0, the range/dynamics entry of every non-reserved object is
        replaced by enc_rng[n] / enc_dyn[n], and Extractor uses encoding_dim as object size.
    enc_rng, enc_dyn: per-object range / dynamics mappings; only read when encoding_dim > 0.
    returns: (extractor, norm)

    The environment's range/dynamics dicts are copied here, never modified in place.
    '''
    reserved = ("Action", "Reward", "Done")  # TODO: in the future could have issues if there are other reserved keys
    passive = args.inter.passive_reassign
    # assumes that passive names can be identified with '#' as the leading character;
    # they are dropped entirely when passive objects are reassigned
    if passive:
        object_names = [n for n in environment.object_names if not n.startswith('#')]
        all_names = [n for n in environment.all_names if not n.startswith('#')]
    else:
        object_names, all_names = environment.object_names, environment.all_names

    extractor = Extractor(args, all_names, object_names, environment, encoding_dim)
    pad_size = extractor.pad_dim
    expand_size = extractor.expand_dim

    env_dicts = (environment.object_range, environment.object_range_true,
                 environment.object_dynamics, environment.object_instanced)
    if passive:
        object_range, object_range_true, object_dynamics, instanced = [
            {n: d[n] for n in object_names} for d in env_dicts]
    else:
        # shallow copies so the encoding substitution below cannot leak into the environment
        object_range, object_range_true, object_dynamics, instanced = [dict(d) for d in env_dicts]

    # replace the ranges with encoding ranges if objects are represented by encodings;
    # iterate the *kept* keys so filtered passive names are not re-inserted
    if encoding_dim > 0:
        for n in list(object_range):
            if n not in reserved:
                object_range[n], object_dynamics[n] = enc_rng[n], enc_dyn[n]

    norm = PadNormalizationModule(object_range, object_range_true, object_dynamics, instanced, object_names, pad_size, expand_size, all=all)
    # note: stored as named_first_obj_dim (singular) here, named_first_obj_dims in get_factor_params
    (args.factor.first_obj_dim, args.factor.single_obj_dim, args.factor.object_dim,
     args.factor.all_obj_dim, args.factor.named_first_obj_dim) = extractor._get_dims()
    args.factor.num_objects = len(all_names)
    args.factor.name_idxes = extractor.name_idxes
    args.factor.name_idx = -1
    return extractor, norm

def get_factor_params(extractor):
    '''
    Gather an extractor's dimension bookkeeping into a fresh ObjDict.

    The result carries first_obj_dim, single_obj_dim, object_dim, all_obj_dim and
    named_first_obj_dims (unpacked in that order from extractor._get_dims()),
    start_dim = -1 and the extractor's name_idxes mapping.
    '''
    params = ObjDict()
    (params.first_obj_dim,
     params.single_obj_dim,
     params.object_dim,
     params.all_obj_dim,
     params.named_first_obj_dims) = extractor._get_dims()
    # -1 signals that first_obj_dim should be used as the start dimension
    params.start_dim = -1
    params.name_idxes = extractor.name_idxes
    return params

class Extractor():
    '''
    Extracts observations from factored states, or vice versa.
    Also can be used to convert flattened collections of states to pick out particular objects.
    obs_selector selects from factored states, with ID padding controlled by args.state.append_id.
    target_selector uses args.state.key_append_id instead (keeping IDs off limits the number
    of targets to predict, which would otherwise artificially raise the log likelihood).
    unappend_selector never appends IDs.
    '''
    def __init__(self, args, all_names, object_names,  environment, encoding_dim=-1):
        '''
        args: config; reads args.state (append_id, key_append_id).
        all_names: kept names, one entry per instance (e.g. "Block0", "Block1").
        object_names: kept object names (one per object type).
        environment: provides all_names, object_names, object_sizes, object_instanced and,
            when no encoding is used, pos_size.
        encoding_dim: when > 0, every non-reserved object is treated as having size
            encoding_dim (same threshold as regenerate uses).
        '''
        reserved = ("Action", "Reward", "Done")  # TODO: in the future could have issues if there are other reserved keys
        use_encoding = encoding_dim > 0

        # get the environment specific components
        self.names = all_names
        self.object_names = object_names
        # positions of the kept names inside the environment's full name lists
        self.kept_nidx = [environment.all_names.index(n) for n in all_names]
        self.kept_noidx = [environment.object_names.index(n) for n in object_names]

        object_sizes = {n: environment.object_sizes[n] for n in self.object_names}
        if use_encoding:
            # reserved keys keep their native size and stay first; all others use the encoding size
            self.sizes = {**{n: object_sizes[n] for n in reserved if n in object_sizes},
                          **{n: encoding_dim for n in object_sizes if n not in reserved}}
        else:
            self.sizes = object_sizes
        self.instanced = {n: environment.object_instanced[n] for n in self.object_names}

        # store proximity values here
        self.pos_size = encoding_dim if use_encoding else environment.pos_size
        self.sp = args.state

        # initialize the main selectors, where most of the logic is
        self.obs_selector = PadSelector(self.sizes, self.instanced, self.object_names, args.state.append_id)
        self.target_selector = PadSelector(self.sizes, self.instanced, self.object_names, args.state.key_append_id)
        self.unappend_selector = PadSelector(self.sizes, self.instanced, self.object_names, False)

        # compute the important dimensions: every object is padded to the largest size,
        # plus append_dim extra entries per object when IDs are appended
        self.pad_dim = max(self.sizes.values())
        self.append_dim = len(self.sizes)
        self.expand_dim = self.pad_dim + self.append_dim * int(self.sp.append_id)
        self.key_expand_dim = self.pad_dim + self.append_dim * int(self.sp.key_append_id)
        self.first_obj_dim, self.target_dim, self.object_dim, self.all_obj_dim, self.named_first_obj_dims = self._get_dims()
        self.name_idxes = {name: self.get_index_single(name) for name in self.object_names}
        self.num_objects = len(self.names)

    def get_index_single(self, name):
        '''Return the list of positions of object ``name`` in self.names; a multi-instanced
        object expands to the positions of name0 .. name{k-1}.'''
        count = self.instanced.get(name, 0)
        if count > 1:
            return [self.names.index(name + str(i)) for i in range(count)]
        return [self.names.index(name)]

    def get_index(self, name):
        '''Return positions in self.names. A list of object names (multi-instanced names MUST
        be sent as a list) yields a flat list of indices; a single name yields one int.'''
        if isinstance(name, list):
            return [idx for n in name for idx in self.get_index_single(n)]
        return self.names.index(name)

    def get_name(self, idxes):
        '''Inverse of get_index: a list/array of indices gives a list of names, else one name.'''
        if isinstance(idxes, (list, np.ndarray)):
            return [self.names[idx] for idx in idxes]
        return self.names[idxes]

    def _get_dims(self):
        '''Return (first_obj_dim, target_dim, object_dim, all_obj_dim, named_first_obj_dims);
        named_first_obj_dims maps each object name to target_dim * its instance count.'''
        first_obj_dim = self.key_expand_dim * len(self.names)
        target_dim = self.key_expand_dim
        all_obj_dim = self.expand_dim * len(self.names)
        named_first_obj_dims = Batch()
        for k in self.sizes.keys():
            named_first_obj_dims[k] = int(target_dim * self.instanced[k])
        return int(first_obj_dim), int(target_dim), int(self.expand_dim), int(all_obj_dim), named_first_obj_dims

    def get_selectors(self):
        return self.obs_selector, self.target_selector

    def get_obs(self, factored_state, names=None):
        '''Select ``names`` (all when empty/None) from a factored state, with ID padding.'''
        return self.obs_selector(factored_state, [] if names is None else names)

    def get_target(self, factored_state, names=None):
        '''Select ``names`` (all when empty/None) from a factored state for targets.'''
        return self.target_selector(factored_state, [] if names is None else names)

    def _select_named(self, flat_state, names, dim):
        '''Reshape flat_state to (flat_state.shape[0], -1, dim) and keep the objects in
        ``names``; an empty/None ``names`` returns flat_state unchanged.'''
        if names is None or len(names) == 0:
            return flat_state
        flat_state = flat_state.reshape(flat_state.shape[0], -1, dim)
        return flat_state[:, self.get_index(names)]

    def get_named_target(self, flat_state, names=None):
        return self._select_named(flat_state, names, self.target_dim)

    def get_named_obs(self, flat_state, names=None):
        return self._select_named(flat_state, names, self.object_dim)

    def get_factored_from_obs(self, flat_state):
        return self.obs_selector.reverse(flat_state)