import numpy as np
import torch


class Batch:
    """Lightweight container that stores named fields as instance attributes.

    Fields live directly in ``__dict__``, so ``batch.obs`` and
    ``batch["obs"]`` refer to the same object. Non-string subscripts
    (ints, slices, index arrays) are applied only to fields that are
    ``np.ndarray`` or ``torch.Tensor``; every other field is passed through
    unchanged.
    """

    def __init__(self, batch_dict=None, **kwargs):
        """Populate the batch from a mapping and/or keyword arguments.

        Args:
            batch_dict: optional object exposing ``items()``; each pair
                becomes a field. Values are stored by reference, not copied.
            **kwargs: additional fields; they override ``batch_dict`` entries
                with the same name.
        """
        if batch_dict is not None:
            for batch_key, obj in batch_dict.items():
                self.__dict__[batch_key] = obj
        self.__dict__.update(kwargs)

    def __getitem__(self, index):
        """Return a field (string key) or an indexed view of the batch.

        With a string, the stored field is returned and ``KeyError`` is raised
        if it does not exist. With any other index, a new ``Batch`` is built in
        which array/tensor fields are indexed with ``index`` and all other
        fields are shared with this batch.
        """
        if isinstance(index, str):
            return self.__dict__[index]
        new_batch = Batch()
        for batch_key, obj in self.__dict__.items():
            if isinstance(obj, (np.ndarray, torch.Tensor)):
                new_batch.__dict__[batch_key] = obj[index]
            else:
                new_batch.__dict__[batch_key] = obj
        return new_batch

    def __setitem__(self, index, value):
        """Assign a field (string key) or write ``value`` into slot ``index``.

        With a string, the field is replaced. Otherwise ``value`` must support
        lookup by field name; for every field of this batch the matching entry
        of ``value`` is written at ``index``. Fields missing from ``value``
        receive a placeholder: an empty ``Batch`` for nested batches, ``0`` for
        tensors and numeric/boolean arrays, ``None`` for anything else.
        """
        if isinstance(index, str):
            self.__dict__[index] = value
            return
        for key, val in self.items():
            # Only the lookup is guarded, so a KeyError raised by the
            # destination container is not mistaken for a missing field.
            try:
                item = value[key]
            except KeyError:
                if isinstance(val, Batch):
                    item = Batch()
                elif isinstance(val, torch.Tensor) or (
                    isinstance(val, np.ndarray)
                    and issubclass(val.dtype.type, (np.bool_, np.number))
                ):
                    item = 0
                else:
                    item = None
            self.__dict__[key][index] = item

    def __contains__(self, key):
        """Return True if ``key`` names a field of this batch.

        Without this method ``in`` would fall back to calling
        ``__getitem__`` with 0, 1, 2, ..., which never terminates for a batch
        that holds no arrays.
        """
        return key in self.__dict__

    def __iter__(self):
        """Yield ``self[i]`` for each ``i`` in ``range(len(self))``.

        Defined explicitly because the implicit ``__getitem__`` iteration
        protocol only stops on ``IndexError``, which a batch without array
        fields never raises.
        """
        for i in range(len(self)):
            yield self[i]

    def update(self, batch=None, **kwargs):
        """Set fields from ``batch`` (anything with ``items()``) and kwargs.

        Keyword arguments are applied last and therefore take precedence.
        """
        if batch is not None:
            for batch_key, obj in batch.items():
                self[batch_key] = obj
        for batch_key, obj in kwargs.items():
            self[batch_key] = obj

    def __len__(self):
        """Return the first dimension of the first array/tensor field.

        Returns 0 when the batch contains no array or tensor field.
        """
        for obj in self.__dict__.values():
            if isinstance(obj, (np.ndarray, torch.Tensor)):
                return obj.shape[0]
        return 0

    def keys(self):
        """Return a view of the field names."""
        return self.__dict__.keys()

    def items(self):
        """Return a view of ``(name, value)`` pairs."""
        return self.__dict__.items()

    def values(self):
        """Return a view of the field values."""
        return self.__dict__.values()

    def split(self, size, shuffle=True, merge_last=False):
        """Yield consecutive sub-batches of at most ``size`` elements.

        Args:
            size: chunk length, must be >= 1. It may exceed ``len(self)``, in
                which case a single chunk with the whole batch is produced.
            shuffle: draw the element order from ``np.random.permutation``
                instead of using the stored order.
            merge_last: when the length is not a multiple of ``size``, fold
                the trailing remainder into the preceding chunk instead of
                yielding it separately.

        Yields:
            ``Batch`` objects obtained by index-array subscripting.
        """
        length = len(self)
        assert 1 <= size, "size must be at least 1"
        if shuffle:
            indices = np.random.permutation(length)
        else:
            indices = np.arange(length)
        merge_last = merge_last and length % size > 0
        for idx in range(0, length, size):
            # Fewer than two full chunks remain: emit them as one.
            if merge_last and idx + size + size >= length:
                yield self[indices[idx:]]
                break
            yield self[indices[idx:idx + size]]

class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions with episode bookkeeping.

    Each field (``obs``, ``act``, ``obs_next``, ``rew``, ``done``, ``info``,
    ``feature``) is stored as a dict mapping slot index to the value written
    at that slot. ``ptr`` is the next slot to be written and wraps to 0 once
    it reaches ``size``.
    """

    def __init__(self, size):
        """Create an empty buffer with ``size`` slots."""
        # Template listing the stored fields. Its inner dicts are never
        # written to; live storage is always built from fresh dicts.
        self.batch_dict = dict(obs={}, act={}, obs_next={}, rew={}, done={}, info={}, feature={})
        self.size = size
        self._data = self._empty_storage()
        self._indices = np.arange(size)
        self.ptr = 0
        # Running reward sum and step count of the episode in progress.
        self._ep_rew = 0
        self._ep_len = 0

    def _empty_storage(self):
        """Return a Batch with one new, empty dict per field."""
        return Batch({key: {} for key in self.batch_dict})

    def reset(self):
        """Drop all stored transitions and rewind the write pointer.

        Storage is rebuilt from new dicts; reusing the dicts held by the
        previous storage would leave every earlier transition in place.
        The running episode reward/length counters are left untouched.
        """
        self._data = self._empty_storage()
        self.ptr = 0

    def add(self, data):
        """Write one transition at the current slot and advance the pointer.

        Args:
            data: object supporting lookup by field name and exposing ``rew``
                and ``done`` attributes.

        Returns:
            ``(episode_reward, episode_length)`` of the episode that just
            ended when ``data.done`` is truthy, otherwise ``(0, 0)``.
        """
        self._data[self.ptr] = data
        self._ep_rew += data.rew
        self._ep_len += 1
        self.ptr += 1
        if self.ptr == self.size:
            self.ptr = 0
        if not data.done:
            return 0, 0
        episode_reward, episode_length = self._ep_rew, self._ep_len
        self._ep_rew = 0
        self._ep_len = 0
        return episode_reward, episode_length

    def sample(self):
        """Return the underlying storage Batch (not a copy)."""
        return self._data

    def unfinished_index(self):
        """Return the slot just before ``ptr``, wrapping to ``size - 1``."""
        if self.ptr == 0:
            return self.size - 1
        return self.ptr - 1

class VectorBuffer:
    """A group of equally sized ``ReplayBuffer`` objects addressed by id."""

    def __init__(self, total_size, buffer_num):
        """Create ``buffer_num`` buffers of ``ceil(total_size / buffer_num)`` slots."""
        self.buffer_num = buffer_num
        size = int(np.ceil(total_size / buffer_num))
        self.buffer_list = [ReplayBuffer(size) for _ in range(self.buffer_num)]

    def reset(self):
        """Reset every sub-buffer."""
        for buffer in self.buffer_list:
            buffer.reset()

    def add(self, data, buffer_id=0):
        """Distribute a batched transition over the selected sub-buffers.

        Args:
            data: object exposing ``items()``; the ``i``-th entry of every
                field is written to the ``i``-th buffer listed in
                ``buffer_id``.
            buffer_id: iterable of buffer indices, or a single scalar index
                (treated as a one-element list).

        Returns:
            Two arrays ``(episode_rewards, episode_lengths)`` with one entry
            per id in ``buffer_id``, as returned by each sub-buffer's ``add``.
        """
        # A bare scalar (including the default) is not iterable.
        if np.isscalar(buffer_id):
            buffer_id = [buffer_id]
        ep_rews = []
        ep_lens = []
        for i, idx in enumerate(buffer_id):
            data_slice = Batch()
            for k, v in data.items():
                data_slice[k] = v[i]
            ep_rew, ep_len = self.buffer_list[idx].add(data_slice)
            ep_rews.append(ep_rew)
            ep_lens.append(ep_len)
        return np.array(ep_rews), np.array(ep_lens)

    def sample(self):
        """Concatenate the contents of all sub-buffers, buffer by buffer.

        Returns:
            A ``Batch`` whose fields are ``np.array`` objects built from the
            stored values of buffer 0, then buffer 1, and so on.
        """
        merged = {}
        for buffer in self.buffer_list:
            for key, stored in buffer.sample().items():
                merged.setdefault(key, []).extend(stored.values())
        data = Batch()
        for key, values in merged.items():
            data[key] = np.array(values)
        return data

    @staticmethod
    def _stored_count(buffer):
        """Return how many entries ``buffer`` contributes to ``sample()``."""
        return len(next(iter(buffer.sample().values()), ()))

    def unfinished_index(self):
        """Return each sub-buffer's ``unfinished_index`` in ``sample()`` layout.

        Every local index is shifted by the number of entries that the
        preceding buffers contribute to ``sample()``, so the result stays
        aligned with the concatenated arrays even after a ring has wrapped.
        """
        index = []
        offset = 0
        for buffer in self.buffer_list:
            index.append(buffer.unfinished_index() + offset)
            offset += self._stored_count(buffer)
        return np.array(index)
