from collections import deque
import random

def buffer2batch(buffer):
    """Repackage the contents of ``buffer`` as a dict of named columns.

    Args:
        buffer: An object whose ``get_all()`` returns exactly six sequences
            in the order (queries, responses, query tensors, response
            tensors, classification texts, labels), e.g. an ``EncBuffer``.
            ``get_all()`` is called with its default arguments, so for an
            ``EncBuffer`` the rows come back shuffled.

    Returns:
        dict: The six sequences under the keys ``'query'``, ``'response'``,
        ``'query_tensors'``, ``'response_tensors'``, ``'dec_in_texts'`` and
        ``'label'``, in that insertion order.
    """
    batch_keys = (
        'query',
        'response',
        'query_tensors',
        'response_tensors',
        'dec_in_texts',
        'label',
    )
    # Explicit six-way unpacking keeps the arity check: a get_all() that
    # returns a different number of values fails here with ValueError.
    q_texts, r_texts, q_tensors, r_tensors, dec_texts, targets = buffer.get_all()
    columns = (q_texts, r_texts, q_tensors, r_tensors, dec_texts, targets)
    return dict(zip(batch_keys, columns))

class DecBuffer():
    """Bounded FIFO store of four parallel columns.

    The columns are kept in ``self.buffer`` under the keys ``'texts'``,
    ``'labels'``, ``'is_sus'`` and ``'overseer_rewards'``. Each column is a
    ``deque(maxlen=max_size)``, so once a column is full, appending new items
    discards the oldest ones.
    """

    def __init__(self, max_size=1024):
        """Create an empty buffer.

        Args:
            max_size: Maximum number of items retained per column.
        """
        self.max_size = max_size
        self.buffer = {
            'texts': deque(maxlen=max_size),
            'labels': deque(maxlen=max_size),
            'is_sus': deque(maxlen=max_size),
            'overseer_rewards': deque(maxlen=max_size)
        }

    def add_data(self, batch):
        """Append the items of ``batch`` to the matching columns.

        Args:
            batch: Mapping that contains an iterable for every column key of
                this buffer. Extra keys are ignored.

        Raises:
            AssertionError: If ``batch`` lacks one of the column keys. The
                buffer is left unmodified in that case.
        """
        # Checked explicitly rather than with `assert`: under `python -O` an
        # assert is stripped, and a missing key would then only surface as a
        # KeyError after earlier columns had already been extended.
        # AssertionError is kept so callers see the same exception type.
        missing = [key for key in self.buffer if key not in batch]
        if missing:
            raise AssertionError(
                'batch is missing required keys: %r' % (missing,)
            )
        for key, column in self.buffer.items():
            column.extend(batch[key])

    def get_all(self, shuffle=True):
        """Return a snapshot of every column as a dict of lists.

        Args:
            shuffle: If true, the rows are returned in a random order; the
                same permutation is applied to every column so rows stay
                aligned. If false, insertion order is kept.

        Returns:
            dict: Column key -> new list. The lists are copies, so later
            changes to the buffer do not affect them.
        """
        # Snapshot to lists first: indexing a deque away from its ends is
        # O(n), which would make the shuffled gather below quadratic.
        batch = {key: list(column) for key, column in self.buffer.items()}
        if shuffle:
            order = list(range(self.get_size()))
            random.shuffle(order)
            batch = {
                key: [rows[i] for i in order] for key, rows in batch.items()
            }
        return batch

    def get_size(self):
        """Return the number of stored rows (length of the ``'texts'`` column)."""
        # Internal invariant check only; 'is_sus' and 'overseer_rewards' are
        # not compared here.
        assert len(self.buffer['texts']) == len(self.buffer['labels'])
        return len(self.buffer['texts'])

    def clear(self):
        """Remove all items from every column."""
        for column in self.buffer.values():
            column.clear()


class EncBuffer():
    """Bounded FIFO store of six parallel columns.

    The columns are ``queries``, ``responses``, ``query_tensors``,
    ``response_tensors``, ``classification_texts`` and ``labels``. Each is a
    ``deque(maxlen=max_size)``, so once a column is full, appending new items
    discards the oldest ones.
    """

    def __init__(self, max_size=1024):
        """Create an empty buffer.

        Args:
            max_size: Maximum number of items retained per column.
        """
        self.max_size = max_size
        self.queries = deque(maxlen=max_size)
        self.responses = deque(maxlen=max_size)
        self.query_tensors = deque(maxlen=max_size)
        self.response_tensors = deque(maxlen=max_size)
        self.classification_texts = deque(maxlen=max_size)
        self.labels = deque(maxlen=max_size)

    def _columns(self):
        """Return the six column deques in the order used by the public API."""
        return (
            self.queries,
            self.responses,
            self.query_tensors,
            self.response_tensors,
            self.classification_texts,
            self.labels,
        )

    def add_data(self, queries, responses,query_tensors, response_tensors, classification_texts, labels):
        """Append one iterable of new items to each column.

        The arguments are not checked for equal length; callers are expected
        to pass aligned iterables so rows stay in correspondence.
        """
        incoming = (
            queries,
            responses,
            query_tensors,
            response_tensors,
            classification_texts,
            labels,
        )
        for column, items in zip(self._columns(), incoming):
            column.extend(items)

    def get_all(self, shuffle=True):
        """Return a snapshot of all six columns.

        Args:
            shuffle: If true, the rows are returned in a random order; the
                same permutation is applied to every column so rows stay
                aligned. If false, insertion order is kept.

        Returns:
            tuple: Six new lists in the order (queries, responses,
            query_tensors, response_tensors, classification_texts, labels).
            Lists are returned in both modes, so a later ``clear()`` or
            ``add_data()`` cannot alter data the caller already holds.
        """
        # Snapshot to lists first: this decouples the result from the live
        # deques and avoids O(n) deque indexing in the shuffled gather below.
        snapshots = [list(column) for column in self._columns()]
        if shuffle:
            order = list(range(len(self.query_tensors)))
            random.shuffle(order)
            snapshots = [[rows[i] for i in order] for rows in snapshots]
        return tuple(snapshots)

    def get_size(self):
        """Return the number of stored rows (length of ``query_tensors``)."""
        return len(self.query_tensors)

    def clear(self):
        """Remove all items from every column."""
        for column in self._columns():
            column.clear()