import random
import numpy as np
import pickle
import torch
from QLearning.config import Config

class ReplayBuffer(object):
    """Fixed-capacity FIFO store of ``(s, a, r, t, s2)`` transitions.

    Once ``buffer_size`` transitions are stored, every new ``add`` evicts the
    oldest one. ``s`` and ``s2`` are indexed as 8-field sequences by
    ``sample_batch`` (see its docstring for the field order).
    """

    def __init__(self, buffer_size, random_seed=123):
        """Create an empty buffer.

        Args:
            buffer_size: maximum number of transitions kept.
            random_seed: seed passed to the *global* ``random`` module, so it
                also affects any other code in the process that uses
                ``random``.
        """
        self.buffer_size = buffer_size
        self.count = 0
        self.buffer = []
        random.seed(random_seed)

    def add(self, s, a, r, t, s2):
        """Store one transition, evicting the oldest one when at capacity."""
        experience = (s, a, r, t, s2)
        if self.count < self.buffer_size:
            self.buffer.append(experience)
            self.count += 1
        elif self.buffer:
            # Full: drop the oldest entry. A plain list is kept (rather than a
            # deque) because ``self.buffer`` is pickled by ``save`` and is
            # sampled by index in ``sample_batch``.
            self.buffer.pop(0)
            self.buffer.append(experience)
        # Otherwise buffer_size <= 0: nothing can be stored.

    def size(self):
        """Return the number of stored transitions."""
        return self.count

    @staticmethod
    def _unpack_states(states, z_dim):
        """Split 8-field state sequences into seven lists plus a stacked Z tensor.

        Returns ``(A, B, ASig, BSig, Z, currentState, u0, prevS_A)`` where
        ``Z`` has shape ``(len(states), z_dim)`` and row ``i`` is taken from
        ``states[i][4]``; every other entry is a list of the matching field.
        """
        A, B, ASig, BSig = [], [], [], []
        currentState, u0, prevS_A = [], [], []
        Z = torch.zeros((len(states), z_dim))
        for i, state in enumerate(states):
            A.append(state[0])
            B.append(state[1])
            ASig.append(state[2])
            BSig.append(state[3])
            Z[i, :] = state[4]
            currentState.append(state[5])
            u0.append(state[6])
            prevS_A.append(state[7])
        return (A, B, ASig, BSig, Z, currentState, u0, prevS_A)

    def sample_batch(self, batch_size, z_dim=252):
        """Draw a random batch of transitions (without replacement) and unpack it.

        Args:
            batch_size: requested number of transitions. If fewer are stored,
                all stored transitions are returned (in random order).
            z_dim: width of each Z row (field 4 of a state); defaults to 252.

        Returns:
            ``(states1, a_batch, r_batch, t_batch, states2)`` where each
            ``states`` is ``(A, B, ASig, BSig, Z, currentState, u0, prevS_A)``
            as built by ``_unpack_states``, and ``a_batch``/``r_batch``/
            ``t_batch`` are ``np.array``s of the corresponding fields.
        """
        # Never iterate past what was actually sampled; the previous code
        # looped over ``batch_size`` and raised IndexError on a short buffer.
        n = min(batch_size, self.count)
        batch = random.sample(self.buffer, n)

        states1 = self._unpack_states([exp[0] for exp in batch], z_dim)
        states2 = self._unpack_states([exp[4] for exp in batch], z_dim)

        a_batch = np.array([exp[1] for exp in batch])
        r_batch = np.array([exp[2] for exp in batch])
        t_batch = np.array([exp[3] for exp in batch])

        return states1, a_batch, r_batch, t_batch, states2

    def clear(self):
        """Remove all stored transitions."""
        self.buffer = []
        self.count = 0

    def save(self, path='replay_buffer.obj'):
        """Pickle the stored transitions to ``path``."""
        with open(path, 'wb') as fh:
            pickle.dump(self.buffer, fh)

    def load(self, path='replay_buffer.obj'):
        """Replace the buffer with transitions previously written by ``save``.

        Best-effort: a missing or unreadable file leaves the buffer unchanged
        and prints a message instead of raising.

        SECURITY: ``pickle.load`` can execute arbitrary code. Only load files
        this program wrote itself, never untrusted input.
        """
        try:
            with open(path, 'rb') as fh:
                loaded = pickle.load(fh)
        except FileNotFoundError:
            print('there was no file to load')
            return
        except (OSError, EOFError, pickle.UnpicklingError) as err:
            print(f'could not load replay buffer from {path}: {err}')
            return

        loaded = list(loaded)
        # Keep only the newest ``buffer_size`` entries so the capacity
        # invariant relied on by ``add`` still holds.
        excess = len(loaded) - max(self.buffer_size, 0)
        if excess > 0:
            loaded = loaded[excess:]
        self.buffer = loaded
        self.count = len(self.buffer)

