
import torch
import torch.nn as nn
from queue import Queue
from torch.autograd import Variable
from einops import rearrange


class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_size: number of features in the last dimension of the input.
        hidden_size: width of the intermediate (hidden) layer.
        output_size: number of features produced in the last dimension.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Registration order (fc1, relu, fc2) and attribute names are kept
        # fixed so seeded parameter initialization and state_dict keys
        # ("fc1.weight", "fc2.bias", ...) stay stable for checkpoints.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)


class MyQueue:
    """FIFO buffer that evicts its oldest item when a new one arrives while full.

    Items live in a ``queue.Queue`` exposed as ``self.pool``. If ``maxsize`` is
    <= 0 the underlying Queue is unbounded, so ``full()`` never reports True
    and nothing is evicted.

    Args:
        maxsize: maximum number of items retained.
    """

    def __init__(self, maxsize):
        super().__init__()
        self.pool = Queue(maxsize)

    def put_item(self, item):
        """Append ``item``, first dropping the oldest item if the buffer is full."""
        if self.pool.full():
            # Free a slot so the (blocking) put below does not wait for one.
            # NOTE: full()/get()/put() are separate calls, not one atomic step.
            self.pool.get()
        self.pool.put(item)

    def get_item_by_idx(self, idx):
        """Return the item at position ``idx`` (0 = oldest) without removing it.

        Raises:
            IndexError: if ``idx`` is not in ``[0, current length)``.
        """
        size = self.pool.qsize()
        # Explicit check rather than ``assert`` so validation still happens
        # under ``python -O``; negative indices are rejected instead of being
        # interpreted as counting from the end.
        if not 0 <= idx < size:
            raise IndexError(f"index {idx} out of range for queue of size {size}")
        return self.pool.queue[idx]

    def get_all_items(self):
        """Return a list snapshot of all items, oldest first."""
        return list(self.pool.queue)

    def len(self):
        """Return the number of items currently stored."""
        return self.pool.qsize()

    def init(self):
        """Remove every item, leaving the buffer empty."""
        while not self.pool.empty():
            self.pool.get()

