import numpy as np
import os, cv2, time, copy, itertools
import torch
from collections import OrderedDict
from tianshou.data import Batch
from Network.network_utils import pytorch_model

def broadcast(arr, size, cat=True, axis=0):
    '''
    Repeats arr size times along the given axis (axis 0 unless specified).
    cat=True joins the copies along an existing axis (np.concatenate),
    cat=False inserts a new axis of length size (np.stack).
    Every repetition is an independent copy of arr.
    '''
    copies = [arr.copy() for _ in range(size)]
    combine = np.concatenate if cat else np.stack
    return combine(copies, axis=axis)

def flatten(factored_state, names):
    '''
    Looks up each entry of names in factored_state, in the order given,
    and joins the resulting arrays along their last axis.
    '''
    return np.concatenate([factored_state[key] for key in names], axis=-1)

def add_id(add_state, append_id, num_objects):
    '''
    Appends a one-hot identifier of width num_objects to the last axis of add_state,
    with the 1 placed at position append_id.
    A negative append_id means "no identifier": add_state is returned untouched.
    '''
    if append_id < 0:
        return add_state
    # the one-hot block shares every leading (batch) dimension with the state
    one_hot = np.zeros(add_state.shape[:-1] + (num_objects, ))
    one_hot[..., append_id] = 1
    return np.concatenate((add_state, one_hot), axis=-1)

def add_pad_flat(object_state, pad_size, append_id, num_objects):
    '''
    Zero-pads the last axis of object_state up to pad_size, then appends the
    one-hot identifier through add_id (skipped by add_id when append_id is negative).
    A state that is already at least pad_size wide is not padded.
    '''
    missing = pad_size - object_state.shape[-1]
    padded = object_state
    if missing > 0:
        zeros = np.zeros(object_state.shape[:-1] + (missing, ))
        padded = np.concatenate((object_state, zeros), axis=-1)
    return add_id(padded, append_id, num_objects)

def add_pad(states, name, pad_size, append_id, num_objects):
    '''
    Dictionary variant of add_pad_flat: takes states[name], zero-pads its last
    axis up to pad_size, then appends the one-hot identifier through add_id
    (skipped by add_id when append_id is negative).
    A state that is already at least pad_size wide is not padded.
    '''
    object_state = states[name]
    missing = pad_size - object_state.shape[-1]
    padded = object_state
    if missing > 0:
        zeros = np.zeros(object_state.shape[:-1] + (missing, ))
        padded = np.concatenate((object_state, zeros), axis=-1)
    return add_id(padded, append_id, num_objects)

class PadSelector():
    '''
    Converts between factored states (dict[name] -> ndarray) and flat states in which
    every object instance occupies one slot of append_pad_size values:
    the object state, zero padding up to pad_size, and optionally a one-hot object id.

    An object whose instance count is greater than 1 is stored under the keys
    name + "0", name + "1", ...; any other object is stored under its bare name.
    '''
    def __init__(self, sizes, instanced, names, append_id=False, apply_pad = True):
        # sizes is the size of each object
        # instanced is the number of instances of each object
        # names is the list of all object names, used for ordering all_extraction
        # append_id appends the identification 
        # if factored does not contain the full state, this extractor is just for extracting from padded states

        self.instanced = instanced
        self.names = names
        self.name_id = {n: self.names.index(n) for n in self.names}
        self.sizes = sizes 
        self.pad_size = np.max(list(sizes.values()))
        self.num_objects = len(sizes)
        self.append_id = append_id
        self.apply_pad = apply_pad
        # width of one slot in the flat state: padded object state plus the optional one-hot id
        self.append_pad_size = int(apply_pad) * self.pad_size + int(self.append_id) * self.num_objects

    def _instance_names(self, name):
        # keys under which the instances of name are stored, following the convention of __call__:
        # numbered keys only when there is more than one instance
        count = self.instanced[name]
        if count > 1:
            return [name + str(i) for i in range(count)]
        return [name]

    def _base_name(self, full_name):
        # maps an instance key ("Block3") back to its object name ("Block");
        # a key that is itself an object name is returned unchanged
        if full_name in self.sizes:
            return full_name
        return full_name.rstrip("0123456789")

    def _object_idxes(self, name):
        # indices inside one slot that belong to the object.
        # __init__ does not define self.factored; if one has been attached externally
        # (dict[name] -> index array) it is honored, otherwise every feature of the object is used
        factored = getattr(self, "factored", None)
        if factored is not None and name in factored:
            return np.asarray(factored[name])
        return np.arange(self.sizes[name])

    def get_mask(self, mask, names):
        '''Takes in a mask of shape [...,num_objects, num_objects]
            and gets the rows of the mask corresponding to the names
        '''
        return np.stack([mask[...,self.name_id[n],:] for n in names], axis=-2)

    def __call__(self, states, names=[]):
        '''
        states are dict[name] -> ndarray: [batchlen, object state]
        returns [batchlen, flattened state], where the flattened state selects objects in names
        (all of self.names when names is empty)
        it does not select only the masked values 
        when apply_pad is False the object states are concatenated as they are (no padding, no id)
        '''
        if len(names) == 0: names = self.names
        flattened = list()
        for name in names:
            id_append = self.name_id[name] if self.append_id else -1
            for full_name in self._instance_names(name):
                if self.apply_pad:
                    flattened.append(add_pad(states, full_name, self.pad_size, id_append, self.num_objects))
                else:
                    flattened.append(states[full_name])
        return np.concatenate(flattened, axis=-1)

    def add_padding(self, flat_state):
        '''
        adds padding (and the id, if using ids) to a flat state without padding
        flat_state is expected to hold the unpadded object states back to back,
        ordered by self.names with the instances of an object adjacent
        '''
        # TODO: would be more efficient in parallel
        # TODO: add functionality to handle adding just the ids
        pad_flat = list()
        at = 0
        for n in self.names:
            id_append = self.name_id[n] if self.append_id else -1
            for _ in self._instance_names(n):
                obj_state = flat_state[...,at:at + self.sizes[n]]
                pad_flat.append(add_pad_flat(obj_state, self.pad_size, id_append, self.num_objects))
                at += self.sizes[n] # move on to the next unpadded object
        return np.concatenate(pad_flat, axis=-1)

    def get_padding(self, states):
        '''
        states are dict[name] -> ndarray: [batchlen, object state]
        returns [batchlen, padding binaries], where the padding binary is 1 where padding is used
        does not mask out append padding, if used
        '''
        pad_vector = list()
        for name in self.names:
            for full_name in self._instance_names(name):
                obj_state = states[full_name]
                obj_size = obj_state.shape[-1]
                # one slot per instance, sharing the leading (batch) dimensions of the object state
                pv = np.zeros(obj_state.shape[:-1] + (self.append_pad_size, ))
                pad = self.pad_size - obj_size if self.apply_pad else 0
                if pad > 0:
                    pv[...,obj_size:obj_size + pad] = 1
                pad_vector.append(pv)
        return np.concatenate(pad_vector, axis=-1)

    def output_size(self, names=[]):
        '''
        total width of the flat state covering names (all of self.names when names is empty):
        one slot of append_pad_size per object instance
        '''
        names = self.names if len(names) == 0 else names
        return sum([self.append_pad_size * len(self._instance_names(n)) for n in names])

    def reverse(self, flat_state, prev_factored=None, names=[]):
        '''
        unflattens a flat state [batch, output_size]
        pretty much the same logic as the full selector, but goes by pad size
        names gives the objects laid out in flat_state, in order (self.names when empty)
        the entries are written into prev_factored when it is given (it is modified in place),
        and are slices of flat_state with padding and id removed
        '''
        factored = dict() if prev_factored is None else prev_factored
        names = self.names if len(names) == 0 else names
        at = 0
        for name in names:
            for full_name in self._instance_names(name):
                factored[full_name] = flat_state[...,at:at + self.sizes[name]]
                at = at + self.append_pad_size # skip any padding
        return factored

    def get_idxes(self, names):
        '''
        returns the integer indices into a full flat state (laid out by self.names)
        that hold the values of the instances listed in names
        the indices follow the order of self.names, not the order of names
        '''
        name_check = set(names)
        idxes = list()
        at = 0
        for name in self.names:
            for full_name in self._instance_names(name):
                if full_name in name_check:
                    idxes += (at + self._object_idxes(name)).tolist()
                at += self.append_pad_size
        # explicit integer dtype so that an empty selection is still usable as an index
        return np.array(idxes, dtype=int)

    def assign(self, state, insert_state, names = None):
        '''
        assigns only the factored indices, names should overlap with self.names
        state and insert_state can each be a flat ndarray or a dict[name] -> ndarray
        state is modified in place and returned
        '''
        if names is None:
            names = [full_name for n in self.names for full_name in self._instance_names(n)]
        if isinstance(insert_state, np.ndarray):
            if isinstance(state, np.ndarray):
                idxes = self.get_idxes(names)
                state[...,idxes] = insert_state
            else:
                at = 0
                for name in names:
                    obj_idxes = self._object_idxes(self._base_name(name))
                    # features live on the last axis, leading axes are batch dimensions
                    state[name][...,obj_idxes] = insert_state[...,at:at + len(obj_idxes)]
                    at += self.append_pad_size
        else: # assume that insert state is a dict
            if isinstance(state, np.ndarray):
                idxes = self.get_idxes(names)
                state[...,idxes] = flatten(insert_state, names)
            else:
                for name in insert_state.keys():
                    state[name] = insert_state[name]
        return state