import logging
from colorlog import ColoredFormatter

import numpy as np

import torch


class SimplePool:
    """Fixed-capacity first-in/first-out pool of items.

    New items are appended at the back; once ``pool_size`` items are held,
    each additional item evicts the oldest one. ``version`` selects which
    library is used to aggregate the stored items: ``"pt"`` (torch) or
    ``"np"`` (numpy).
    """

    def __init__(self, pool_size, version="pt"):
        """Create an empty pool.

        Args:
            pool_size: Maximum number of items kept in the pool.
            version: ``"pt"`` or ``"np"``.

        Raises:
            AssertionError: If ``version`` is neither ``"pt"`` nor ``"np"``.
        """
        if version not in ("pt", "np"):
            message = "version = %s; please choose pt or np" % (version,)
            print(message)
            # Raised explicitly (rather than `assert False`) so the check is
            # not stripped when Python runs with -O.
            raise AssertionError(message)
        self.pool_size = pool_size
        self.version = version
        # Always initialise the storage, even for a non-positive pool_size,
        # so every other method can rely on these attributes existing.
        self.num = 0
        self.items = []

    def __len__(self):
        """Return the number of items currently stored."""
        return len(self.items)

    def mean(self, min_size="none"):
        """Return the sum of all stored values divided by the item count.

        Args:
            min_size: ``"half"`` requires the pool to hold at least
                ``pool_size / 2`` items; any other value requires at least
                one item.

        Returns:
            A numpy scalar (``"np"``) or a torch tensor (``"pt"``). NaN, in
            the matching library's representation, when the pool holds fewer
            items than required.
        """
        if min_size == "half":
            pool_size_thresh = self.pool_size / 2
        else:
            pool_size_thresh = 1

        count = len(self.items)
        # An empty pool never has a mean, whatever the threshold works out to.
        enough = count > 0 and count >= pool_size_thresh

        if self.version == "np":
            if not enough:
                return np.nan
            return np.sum(self.items) / float(count)
        if self.version == "pt":
            if not enough:
                return torch.tensor(float("nan"))
            # torch.sum needs a tensor, not a Python list, so stack first.
            stacked = torch.stack([torch.as_tensor(item) for item in self.items])
            return torch.sum(stacked) / float(count)

    def sample(self):
        """Return one stored item chosen uniformly at random.

        The pool must not be empty; ``np.random.randint(0)`` raises
        ``ValueError``.
        """
        idx = np.random.randint(len(self.items))
        return self.items[idx]

    def fetch(self, num=None):
        """Return the stored items stacked into a single array/tensor.

        Args:
            num: If given, return ``num`` entries drawn at random (with
                replacement) instead of the whole pool.

        Returns:
            The stacked items, or a random selection of ``num`` of them.

        Raises:
            AssertionError: If ``num`` exceeds the number of stored items.
        """
        if self.version == "pt":
            item_array = torch.stack(self.items)
        elif self.version == "np":
            item_array = np.stack(self.items)

        if num is None:
            return item_array

        # there better be some items
        assert len(self.items) >= num

        # Only reachable when assertions are disabled (python -O): if there
        # are not that many elements just return however many there are.
        if len(self.items) < num:
            return item_array

        idxs = np.random.randint(len(self.items), size=num)
        return item_array[idxs]

    def is_full(self):
        """Return True when the pool holds exactly ``pool_size`` items."""
        return self.num == self.pool_size

    def empty(self):
        """Discard all stored items."""
        self.items = []
        self.num = 0

    def update(self, items):
        """Append ``items`` one by one, evicting the oldest entries when full.

        Args:
            items: Iterable of items to add.

        Returns:
            The pool's internal list of items (not a copy).
        """
        if self.pool_size <= 0:
            # A zero-capacity pool cannot hold anything; there is nothing to
            # evict and nothing may be stored.
            return self.items

        for item in items:
            if self.num < self.pool_size:
                # the pool is not full, so let's add this in
                self.num = self.num + 1
            else:
                # the pool is full: pop from the front
                self.items.pop(0)
            # add to the back
            self.items.append(item)
        return self.items


# Logger setup adapted from
# https://github.com/taylorwwebb/learning_representations_that_support_extrapolation/blob/master/VAEC_dataset_and_models/util.py

# The "rn" logger emits everything from DEBUG upwards.
log = logging.getLogger("rn")
log.setLevel(logging.DEBUG)
log.handlers = []  # drop any existing handlers so none are duplicated
log.propagate = False  # workaround for duplicated logs in ipython

# Colourised "[timestamp] message" layout; each level gets its own colour.
formatter = ColoredFormatter(
    "%(log_color)s[%(asctime)s] %(message)s",
    datefmt=None,
    reset=True,
    log_colors=dict(
        DEBUG="cyan",
        INFO="white,bold",
        INFOV="cyan,bold",
        WARNING="yellow",
        ERROR="red,bold",
        CRITICAL="red,bg_white",
    ),
    secondary_log_colors={},
    style="%",
)

# Single stream handler carrying the coloured formatter.
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
ch.setFormatter(formatter)
log.addHandler(ch)
