from abc import ABC, abstractmethod
from typing import List, Tuple

import numpy as np


class BackgroundGenerator(ABC):
    """Abstract interface for objects that produce background images.

    Subclasses must override :meth:`get_background`; this base class cannot
    be instantiated while that method is still abstract.
    """

    @abstractmethod
    def get_background(self):
        """Return a background image (its exact format is defined by the subclass)."""


class UniformBackgroundGenerator(BackgroundGenerator):
    """Generate a background image filled with a single constant colour.

    Args:
        channels: Number of channels of the produced image; must be 1 or 3.
        base_color: Fill value per channel, converted to ``uint8``. With
            3 channels the whole tuple is repeated over every pixel. With
            1 channel only the first element is used as the grey level (an
            empty tuple yields black).
        size: ``(height, width)`` of the produced image; defaults to 224x224.
    """

    def __init__(
        self,
        channels: int,
        base_color: Tuple[int, ...] = (0, 0, 0),
        size: Tuple[int, int] = (224, 224),
    ) -> None:
        # NOTE: `assert` is skipped under `python -O`; kept so callers still see AssertionError.
        assert channels == 1 or channels == 3, "wrong channel size"
        self.channels = channels
        self.base_color = base_color
        self.size = size

    def get_background(self) -> np.ndarray:
        """Return a new ``uint8`` image of shape ``(h, w)`` or ``(h, w, len(base_color))``."""
        height, width = self.size
        color = np.asarray(self.base_color, dtype=np.uint8)
        if self.channels == 1:
            # Previously base_color was ignored here and the image was always black;
            # honour the first component instead (the default (0, 0, 0) still gives black).
            grey = color.flat[0] if color.size else 0
            return np.full((height, width), grey, dtype=np.uint8)
        # Repeat the colour vector over every pixel -> (height, width, len(base_color)).
        return np.tile(color, (height, width, 1))


class RandomBackgroundGenerator(BackgroundGenerator):
    """Generate a background of uniformly random ``uint8`` RGB-style pixels.

    Pixels whose colour exactly matches any entry of ``exclude_colors`` are
    replaced with black ``(0, 0, 0)``.

    Args:
        channels: Must not be 1. NOTE(review): any other value is accepted,
            but the output always has exactly 3 channels.
        exclude_colors: Colours (3-tuples) that must not appear in the output.
            NOTE(review): if black itself is listed, it can still appear,
            because excluded pixels are replaced with black.
        size: ``(height, width)`` of the produced image; defaults to 224x224.
    """

    def __init__(
        self,
        channels: int,
        exclude_colors: List[Tuple[int, ...]],
        size: Tuple[int, int] = (224, 224),
    ) -> None:
        assert channels != 1, "single channel random background generation not supported yet!"
        self.channels = channels
        self.exclude_colors = exclude_colors
        self.size = size

    def get_background(self) -> np.ndarray:
        """Return a new random ``uint8`` image of shape ``(height, width, 3)``."""
        height, width = self.size
        img = np.random.randint(256, size=(height, width, 3), dtype=np.uint8)
        # Build one mask for all excluded colours, then blacken them in a single assignment
        # (equivalent to replacing them one by one, since the target is always black).
        excluded = np.zeros((height, width), dtype=bool)
        for color in self.exclude_colors:
            excluded |= np.all(img == color, axis=2)
        img[excluded] = 0
        return img
