import nltk
from nltk.corpus import wordnet as wn
import json
import random

# Module-level side effect: requests the WordNet corpus via NLTK every time
# this module is imported. The `wn.synsets` lookups in SynonymGroup need it.
nltk.download('wordnet')


class SynonymGroup:
    """Partition a list of words into groups of WordNet synonyms.

    Two words count as synonyms when they share at least one WordNet synset.

    Attributes:
        words: The words the groups were built from.
        word_to_group: Maps each word to the ``set`` of words in its group.
        synonym_groups: The groups, in the order they were created.

    Note:
        Grouping is greedy and seeded: a word joins a group only if it shares
        a synset with the group's *first* word. The relation is therefore not
        transitive, and the result depends on the order of ``words``.
    """

    def __init__(self, words):
        """Build the synonym groups for ``words``.

        Args:
            words: A re-iterable collection of words (it is iterated several
                times while grouping).
        """
        self.words = words
        self.word_to_group = self.compute_synonym_groups(words)

    def compute_synonym_groups(self, words):
        """Group ``words`` by shared synsets and return the word -> group map.

        Also stores the list of groups on ``self.synonym_groups``.

        Args:
            words: A re-iterable collection of words.

        Returns:
            A dict mapping every word to the set holding it and its synonyms.
        """
        # Look up each word's synsets once instead of once per comparison.
        word_synsets = {word: set(wn.synsets(word)) for word in words}
        synonym_groups = []
        word_to_group = {}

        for word in words:
            if word in word_to_group:
                continue

            current_group = {word}
            word_to_group[word] = current_group
            seed_synsets = word_synsets[word]

            for other_word in words:
                # The seed itself is already registered, so this single check
                # also skips the `other_word == word` case.
                if other_word in word_to_group:
                    continue

                if self.get_common_synsets(seed_synsets,
                                           word_synsets[other_word]):
                    current_group.add(other_word)
                    word_to_group[other_word] = current_group

            synonym_groups.append(current_group)

        self.synonym_groups = synonym_groups
        return word_to_group

    def get_common_synsets(self, synsets1, synsets2):
        """Return the synsets that occur in both collections.

        Args:
            synsets1: Iterable of hashable synsets.
            synsets2: Iterable of hashable synsets.

        Returns:
            A set with the members of ``synsets1`` that also occur in
            ``synsets2``.
        """
        # A hashed membership test replaces the pairwise comparison. Synsets
        # that compare equal are the same synset and hence share a part of
        # speech, so no separate pos() check is required.
        if isinstance(synsets2, (set, frozenset)):
            lookup = synsets2
        else:
            lookup = set(synsets2)
        return {synset for synset in synsets1 if synset in lookup}

    def groups_to_json(self):
        """Return the groups as a JSON-serialisable list of word lists.

        Groups with identical members are emitted once. The words inside each
        group are sorted so the output is stable from run to run (plain set
        iteration order is not).
        """
        group_list = []
        processed_groups = set()

        for group in self.synonym_groups:
            group_tuple = tuple(sorted(group))
            if group_tuple not in processed_groups:
                processed_groups.add(group_tuple)
                group_list.append(list(group_tuple))

        return group_list

    @classmethod
    def json_to_groups(cls, group_list):
        """Rebuild an instance from the output of ``groups_to_json``.

        No WordNet lookup is performed; the groups are taken as given.

        Args:
            group_list: Iterable of word collections, one per group. If a
                word appears in several groups, the last one wins in
                ``word_to_group``.

        Returns:
            A new instance whose ``words``, ``word_to_group`` and
            ``synonym_groups`` reflect ``group_list``.
        """
        # Construct with no words so __init__ does no grouping work.
        instance = cls([])
        word_to_group = {}
        synonym_groups = []

        for group_words in group_list:
            current_group = set(group_words)
            synonym_groups.append(current_group)
            for word in group_words:
                word_to_group[word] = current_group

        instance.word_to_group = word_to_group
        instance.synonym_groups = synonym_groups
        instance.words = list(word_to_group)

        return instance

    def are_synonyms(self, word1, word2):
        """Return True if both words are known and belong to equal groups.

        A word that is not in ``word_to_group`` is never a synonym of
        anything, including itself.
        """
        group1 = self.word_to_group.get(word1)
        if group1 is None:
            return False
        return group1 == self.word_to_group.get(word2)


class PoolFactory:
    """Load entity, attribute and relation pools from JSON and sample them.

    The pool file must contain a JSON object with the keys ``entity_pool``,
    ``attribute_pool`` and ``relation_pool`` (presumably lists of strings --
    confirm against the data files). The attribute and relation pools each
    get a ``SynonymGroup`` so samples drawn from them can avoid synonyms.
    """

    def __init__(self, pool_file):
        """Read ``pool_file`` and build the synonym groups.

        Args:
            pool_file: Path to the JSON pool file.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If one of the three pool keys is missing.
        """
        # Read as UTF-8 explicitly rather than relying on the platform's
        # locale-dependent default encoding.
        with open(pool_file, 'r', encoding='utf-8') as f:
            pools = json.load(f)
        self.entity_pool = pools['entity_pool']
        self.attribute_pool = pools['attribute_pool']
        self.relation_pool = pools['relation_pool']
        self.attribute_synonym = SynonymGroup(self.attribute_pool)
        self.relation_synonym = SynonymGroup(self.relation_pool)

    def get_entity_pool(self, entity_num):
        """Return ``entity_num`` distinct random picks from the entity pool.

        Raises:
            ValueError: If ``entity_num`` is negative or larger than the pool
                (raised by ``random.sample``).
        """
        return random.sample(self.entity_pool, entity_num)

    @staticmethod
    def get_pool_no_synonyms(pool, synonym_group, n):
        """Randomly pick up to ``n`` elements, no two of which are synonyms.

        Args:
            pool: Iterable of candidate elements; it is copied, not mutated.
            synonym_group: Object providing ``are_synonyms(a, b)``.
            n: Maximum number of elements to return.

        Returns:
            A list of at most ``n`` elements. It is shorter than ``n`` when
            the pool runs out of mutually non-synonymous elements, and empty
            when ``n <= 0``.
        """
        if n <= 0:
            return []

        # Shuffle a copy so the caller's pool keeps its order.
        candidates = list(pool)
        random.shuffle(candidates)
        selected_elements = []

        for element in candidates:
            if any(synonym_group.are_synonyms(element, selected)
                   for selected in selected_elements):
                continue
            selected_elements.append(element)
            if len(selected_elements) == n:
                break

        return selected_elements

    def get_attribute_pool(self, attribute_num):
        """Return up to ``attribute_num`` mutually non-synonymous attributes."""
        return self.get_pool_no_synonyms(
            self.attribute_pool, self.attribute_synonym, attribute_num)

    def get_relation_pool(self, relation_num):
        """Return up to ``relation_num`` mutually non-synonymous relations."""
        return self.get_pool_no_synonyms(
            self.relation_pool, self.relation_synonym, relation_num)
