import matplotlib.pyplot as plt
from wordcloud import WordCloud, STOPWORDS
# Extra words to exclude from the word cloud in addition to wordcloud's
# built-in STOPWORDS (positional words, common verbs/adjectives, number
# words, ...).
add_to_stopwords = {
    'next', 'front', 'rear', 'besides', 'below', 'under',
    'near', 'back', 'side', 'background', 'foreground',
    'behind', 'along', 'top', 'small',
    'large', 'sitting', 'driving', 'riding', 'laying', 'standing',
    'looking', 'holding', 'outside',
    'inside', 'another', 'together', 'old', 'open', 'close',
    'new', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight',
    'nine', 'ten', 'like',
    'looks', 'owner', 'cute', 'home', 'day', 'love', 'little', 'around',
    'time', 'world', 'happy', 'big', 'get', 'look', 'head',
    'eating', 'likes', 'got', 'sleeping', 'go', 'today', 'think', 'put',
    'really',
    'area', 'body', 'double',
}

# set.update() works in place and returns None, so its result must not be
# assigned.  The in-place update of the shared STOPWORDS set is kept because
# a WordCloud built without a `stopwords` argument reads that set; after the
# update, `stop_words` is bound to a copy of the combined set.
STOPWORDS.update(add_to_stopwords)
stop_words = set(STOPWORDS)

# Read the whole text. The context manager closes the file handle as soon as
# the contents have been read instead of leaving it open.
with open('/Users/asgaris/Downloads/blip_texplain_captions/bus/blip.txt') as text_file:
    text = text_file.read()

# Generate a word cloud image
# wordcloud = WordCloud().generate(text)
# plt.imshow(wordcloud, interpolation='bilinear')
# plt.axis("off")
# wordcloud = WordCloud(max_font_size=100, max_words=5, collocations=False,
#                       min_font_size=20, colormap='turbo', relative_scaling=0.7, width=300, height=300,
#                       background_color='white').generate(text)
# fig = plt.figure()
# plt.imshow(wordcloud, interpolation="bilinear")
# plt.axis("off")
# fig.savefig('foo.png', dpi=fig.dpi, bbox_inches='tight')

from wordcloud import (WordCloud, get_single_color_func)


class SimpleGroupedColorFunc(object):
    """Color function object that gives each listed word one EXACT color.

       Parameters
       ----------
       color_to_words : dict(str -> list(str))
         Maps a color to the list of words that should be drawn in it.

       default_color : str
         Color returned for any word that does not appear in one of the
         lists of color_to_words.
    """

    def __init__(self, color_to_words, default_color):
        # Invert the color -> words mapping into a flat word -> color lookup.
        # A word listed under several colors keeps the last color iterated.
        lookup = {}
        for color, words in color_to_words.items():
            for word in words:
                lookup[word] = color

        self.word_to_color = lookup
        self.default_color = default_color

    def __call__(self, word, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        if word in self.word_to_color:
            return self.word_to_color[word]
        return self.default_color


class GroupedColorFunc(object):
    """Color function object that delegates each word to a per-color
       function built with wordcloud.get_single_color_func, so words of a
       group can receive DIFFERENT SHADES of that group's color.

       Parameters
       ----------
       color_to_words : dict(str -> list(str))
         Maps a color to the list of words that belong to that color group.

       default_color : str
         Color whose single-color function is used for any word that is
         not listed in one of the values of color_to_words.
    """

    def __init__(self, color_to_words, default_color):
        # One (color function, word set) pair per color, in mapping order.
        pairs = []
        for color, words in color_to_words.items():
            pairs.append((get_single_color_func(color), set(words)))
        self.color_func_to_words = pairs

        self.default_color_func = get_single_color_func(default_color)

    def get_color_func(self, word):
        """Returns the color function of the first group containing the
        word, or the default color function when no group contains it"""
        for candidate, members in self.color_func_to_words:
            if word in members:
                return candidate
        return self.default_color_func

    def __call__(self, word, **kwargs):
        # Forward the word and all keyword arguments to the chosen function.
        chosen = self.get_color_func(word)
        return chosen(word, **kwargs)


# Since the text is small collocations are turned off and text is lower-cased
# wc = WordCloud(collocations=False, max_font_size=100, max_words=20, relative_scaling=0.7, width=400, height=300,
#                background_color='white', min_font_size=10).generate(text.lower())

# Pass the stop-word set explicitly instead of relying on the implicit
# default; WordCloud uses its built-in STOPWORDS set when `stopwords` is None.
wc = WordCloud(collocations=False, background_color='white',
               stopwords=stop_words).generate(text.lower())

# color_to_words = {
#     # will be colored with a red single color function
#     '#ff0000': ['bird', 'dog', 'cat', 'zebra', 'bus', 'train', 'kitchen', 'bathroom']
# }#ADD8E6

color_to_words = {
    # will be colored with a blue (#00B3FF) single color function
    '#00B3FF': ['bird', 'dog', 'cat', 'zebra', 'bus', 'train', 'kitchen', 'bathroom']}

# Words that are not in any of the color_to_words values
# will be colored with the single color function built from this color
default_color = 'black'

# Create a color function with single tone
# grouped_color_func = SimpleGroupedColorFunc(color_to_words, default_color)

# Create a color function with multiple tones
grouped_color_func = GroupedColorFunc(color_to_words, default_color)

# Apply our color function
wc.recolor(color_func=grouped_color_func)

# Plot the recolored cloud without axes and save it to a PNG file
fig = plt.figure()
plt.imshow(wc, interpolation="bilinear")
plt.axis("off")
fig.savefig('wc_blip_bus_20_.png', dpi=fig.dpi, bbox_inches='tight')


