# -*- coding: utf-8 -*-

import numpy as np
import random



class memory(object):
    """Fixed-capacity ring buffers of transitions, one buffer per slot.

    When ``mltpro_flag`` is true the object pre-allocates NumPy arrays for
    ``mlt_num`` independent slots; each slot stores up to ``memory_size``
    rows of ``(s, a, r, s_next, done)`` and is addressed by an integer
    ``mark``.  Writing past the capacity overwrites the oldest rows.

    When ``mltpro_flag`` is false no storage is allocated: ``reset`` only
    maintains scalar ``pointer`` / ``size_now`` counters and ``save_exp``
    does nothing.
    """

    def __init__(self, memory_size, state_dim = 2, action_dim = 1, CNN_FLAG = False,
                 ins_size = 24, mltpro_flag = False, mlt_num = None,):
        """Allocate the buffers.

        Args:
            memory_size: maximum number of rows kept per slot.
            state_dim: length of a flat state vector, or, when ``CNN_FLAG``
                is true, a sequence whose first three entries give the
                state shape.
            action_dim: accepted for interface compatibility; the action
                column is always allocated with width 1.
            CNN_FLAG: selects the three-dimensional state layout.
            ins_size: stored on the instance as ``self.ins_size``.
            mltpro_flag: enables the per-slot storage.
            mlt_num: number of slots; required when ``mltpro_flag`` is true.

        Raises:
            ValueError: if ``mltpro_flag`` is true and ``mlt_num`` is None.
        """
        self.mlt_flag = mltpro_flag
        self.ins_size = ins_size
        # These are set in every mode so that later attribute reads (for
        # example the ``assert self.initialized`` in get_batch) never fail
        # with AttributeError.
        self.mlt_num = mlt_num
        self.CNN_FLAG = CNN_FLAG
        self.memory_size = memory_size  # max size of memory
        self.initialized = False

        if not self.mlt_flag:
            # Same scalar counters that reset() installs in this mode.
            self.pointer = -1
            self.size_now = 0
            return

        if mlt_num is None:
            raise ValueError("mlt_num must be given when mltpro_flag is True")

        if self.CNN_FLAG:
            self.state_shape = (self.mlt_num, self.memory_size,
                                state_dim[0], state_dim[1], state_dim[2])
        else:
            self.state_shape = (self.mlt_num, self.memory_size, state_dim)

        # NOTE: the action width is fixed at 1; action_dim is not used.
        self.act_shape = (self.mlt_num, self.memory_size, 1)
        self.s = np.zeros(self.state_shape, dtype=np.float32)
        self.a = np.zeros(self.act_shape, dtype=np.float32)
        self.r = np.zeros((self.mlt_num, self.memory_size, 1), dtype=np.float32)
        self.done = np.zeros((self.mlt_num, self.memory_size), dtype=np.float32)
        self.s_next = np.zeros(self.state_shape, dtype=np.float32)

        # 64-bit counters: 16-bit ones overflow once memory_size > 32767.
        self.pointer = np.full(self.mlt_num, -1, dtype=np.int64)
        self.size_now = np.zeros(self.mlt_num, dtype=np.int64)
        self.flag = np.zeros(self.mlt_num, dtype=np.int16)

    def reset(self, mtl_mark = None):
        """Mark the memory as initialized and rewind the write position.

        Args:
            mtl_mark: in multi-slot mode either the string ``'all'`` (rewind
                every slot) or the index of the slot to rewind.  Ignored
                when ``mltpro_flag`` was false.

        The stored arrays are not cleared; only ``pointer`` and ``size_now``
        are reset.
        """
        self.initialized = True
        if not self.mlt_flag:
            self.pointer = -1
            self.size_now = 0
            return

        assert mtl_mark is not None
        # The isinstance check avoids an element-wise comparison when the
        # mark is a NumPy value rather than a plain int.
        if isinstance(mtl_mark, str) and mtl_mark == 'all':
            self.pointer = np.full(self.mlt_num, -1, dtype=np.int64)
            self.size_now = np.zeros(self.mlt_num, dtype=np.int64)
        else:
            self.pointer[mtl_mark] = -1
            self.size_now[mtl_mark] = 0

    def get_batch(self, mark):
        """Return everything currently stored in slot ``mark``.

        Returns:
            A tuple ``(flag, s, a, r, s_next)`` where ``flag`` is the value
            recorded by the latest ``save_exp`` call for this slot and the
            arrays are views of the first ``n`` rows of the slot's storage.
            ``n`` is the number of stored rows, but never less than 1: an
            empty slot still yields row 0 so the result always has a leading
            axis of length >= 1.

        Rows come back in storage order.  After the ring has wrapped that
        is no longer chronological order, but all ``memory_size`` rows are
        returned (not just the part overwritten since the wrap).

        Only usable in multi-slot mode; ``reset`` must have been called.
        """
        assert self.initialized
        rows = max(self.size_now[mark], 1)
        return (self.flag[mark],
                self.s[mark, :rows],
                self.a[mark, :rows],
                self.r[mark, :rows],
                self.s_next[mark, :rows])

    def save_exp(self,s, a, r, s_next, done, mark = None, effect = True):
        """Append one transition to slot ``mark``, overwriting the oldest row when full.

        Args:
            s, a, r, s_next, done: the transition; each value is assigned
                into the corresponding pre-allocated float32 array.
            mark: slot index; required in multi-slot mode.
            effect: stored as the slot's ``flag`` and handed back by
                ``get_batch``.

        Does nothing when the memory was built with ``mltpro_flag`` false.
        """
        if not self.mlt_flag:
            return

        assert mark is not None
        self.pointer[mark] = (self.pointer[mark] + 1) % self.memory_size
        row = self.pointer[mark]
        self.s[mark, row] = s
        self.a[mark, row] = a
        self.r[mark, row] = r
        self.s_next[mark, row] = s_next
        self.done[mark, row] = done
        self.flag[mark] = effect
        self.size_now[mark] = min(self.memory_size, self.size_now[mark] + 1)









