from collections import deque
import random
import os
import torch

from torch.utils.data import Sampler

from dataset.baseDataset import baseDataset


class ExperienceReplay():
    """Bounded replay buffer of (head, relation, tail, time) quadruples.

    Entries are kept in a ``deque`` created with ``maxlen=max_size``. New
    entries are appended on the right, and once the buffer is full the oldest
    entries (on the left) are evicted.

    The buffer can be filled from a whitespace-separated text file or from a
    random sample of training quadruples, and it can be written back to a
    tab-separated text file. ``__len__`` and ``__getitem__`` give it a
    map-style dataset interface.

    Args:
        max_size: Maximum number of quadruples kept in the buffer.
        data_dir: Default directory used by ``load_buffer``/``save_buffer``
            when no directory is passed to them explicitly.
        buffer_fname: Default file name used by ``load_buffer``/
            ``save_buffer`` when no file name is passed to them explicitly.
    """

    def __init__(self, max_size, data_dir, buffer_fname):
        self.max_size = max_size
        self.buffer = deque(maxlen=max_size)
        # Remember the default buffer location so load/save can be called
        # without repeating it. Nothing is read from disk here.
        self.data_dir = data_dir
        self.buffer_fname = buffer_fname

    def _resolve_path(self, data_dir, buffer_fname):
        """Join the given directory and file name, falling back to the
        values passed to the constructor for any argument that is None."""
        if data_dir is None:
            data_dir = self.data_dir
        if buffer_fname is None:
            buffer_fname = self.buffer_fname
        return os.path.join(data_dir, buffer_fname)

    def load_buffer(self, data_dir=None, buffer_fname=None):
        """Append the quadruples stored in a text buffer file to the buffer.

        Args:
            data_dir: Directory of the file. Defaults to the constructor's
                ``data_dir``.
            buffer_fname: Name of the file. Defaults to the constructor's
                ``buffer_fname``.

        Raises:
            OSError: If the file cannot be opened.
        """
        path = self._resolve_path(data_dir, buffer_fname)
        # extend() honours maxlen: the oldest entries are evicted when full.
        self.buffer.extend(self._load_quadruples(path))

    def update_buffer(self, train, sample_ratio=0.1):
        """Append a random subset of ``train`` to the buffer.

        ``int(len(train) * sample_ratio)`` items are drawn without
        replacement using the module-level ``random`` generator, so the
        global ``random`` seed controls which items are chosen.

        Args:
            train: Training quadruples. Any sequence is used as-is. Other
                iterables (sets, generators, numpy arrays, ...) are first
                materialized into a list.
            sample_ratio: Fraction of ``train`` to sample.

        Raises:
            ValueError: If the computed sample size is negative or larger
                than ``len(train)`` (raised by ``random.sample``).
        """
        from collections.abc import Sequence

        if not isinstance(train, Sequence):
            # random.sample only accepts sequences (sets raise TypeError
            # since Python 3.11), so convert other iterables first.
            train = list(train)
        sample_size = int(len(train) * sample_ratio)
        self.buffer.extend(random.sample(train, sample_size))

    def save_buffer(self, data_dir=None, buffer_fname=None):
        """Write the whole buffer to a text buffer file, overwriting it.

        Args:
            data_dir: Directory of the file. Defaults to the constructor's
                ``data_dir``.
            buffer_fname: Name of the file. Defaults to the constructor's
                ``buffer_fname``.
        """
        self._save_quadruples(self.buffer, self._resolve_path(data_dir, buffer_fname))

    def _save_quadruples(self, quadruples, buffer_path):
        """Write one tab-separated line per quadruple to ``buffer_path``.

        Any existing file at ``buffer_path`` is overwritten.
        """
        with open(buffer_path, 'w', encoding='utf-8') as f:
            f.writelines('\t'.join(map(str, quadruple)) + '\n' for quadruple in quadruples)

    def _load_quadruples(self, buffer_path):
        """Parse a whitespace-separated quadruple file.

        Each non-blank line must start with four integer fields
        ``head rel tail time``; any additional fields are ignored. Blank lines
        are skipped. Malformed lines are reported on stdout and skipped, so
        one bad line does not abort the whole load.

        Returns:
            list[list[int]]: ``[head, rel, tail, time]`` for every valid
            line, in file order.
        """
        quadruple_list = []
        with open(buffer_path, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                fields = line.split()
                if not fields:
                    continue
                try:
                    # ValueError covers both non-integer fields and lines
                    # with fewer than four fields (unpacking failure).
                    head, rel, tail, time = (int(field) for field in fields[:4])
                except ValueError:
                    print(f"Skipping malformed line {line_no} in {buffer_path}: {line.rstrip()!r}")
                    continue
                quadruple_list.append([head, rel, tail, time])
        return quadruple_list

    def __len__(self):
        """Return the number of quadruples currently in the buffer."""
        return len(self.buffer)

    def __getitem__(self, idx):
        """Return the quadruple at position ``idx`` (index 0 is the oldest).

        Note: deque indexing is O(1) near either end but O(n) toward the
        middle.
        """
        return self.buffer[idx]

