import os
import pickle
from tqdm import tqdm

class PreEmbDataset:
    """
    Dataset of per-sample graph indexings prepared ahead of embedding.

    Each raw sample (a dict with 'id', 'question', 'graph', 'q_entity' and
    'a_entity') is turned into a dict holding sorted entity/relation lists,
    entity/relation <-> integer id mappings and the triples expressed as id
    lists. The processed samples are cached on disk with pickle; filtering
    by topic/answer presence is applied after loading, so one cache file
    serves every combination of the skip flags.
    """

    def __init__(
        self,
        raw_set,
        entity_identifiers,
        save_path,
        skip_no_topic=True,
        skip_no_ans=True
    ):
        """
        Parameters
        ----------
        raw_set : list
            List of raw samples to process.
        entity_identifiers : set
            Set of entity identifiers for which we cannot get meaningful text embeddings.
        save_path : str
            Path of the pickle cache. It is loaded if it exists, otherwise
            it is created from ``raw_set``.
        skip_no_topic : bool, optional
            Whether to skip samples without topic entities in the graph. Default is True.
        skip_no_ans : bool, optional
            Whether to skip samples without answer entities in the graph. Default is True.
        """
        self.skip_no_topic = skip_no_topic
        self.skip_no_ans = skip_no_ans

        # The cache holds *unfiltered* samples; filtering happens afterwards.
        unfiltered = self._load_or_process(
            raw_set,
            entity_identifiers,
            save_path
        )
        self.processed_dict_list = self._filter_samples(unfiltered)

        print(f'# raw samples: {len(raw_set)} | # processed samples: {len(self.processed_dict_list)}')

    def _load_or_process(self, raw_set, entity_identifiers, save_path):
        """
        Load processed data from save_path if it exists, otherwise process the raw data.

        An existing cache file is returned as-is: it is not checked against
        ``raw_set`` or ``entity_identifiers``, so delete the file to force
        re-processing when either of them changes.

        Parameters
        ----------
        raw_set : list
            List of raw samples to process.
        entity_identifiers : set
            Set of entity identifiers for which we cannot get meaningful text embeddings.
        save_path : str
            Path to save the processed data.

        Returns
        -------
        list
            List of processed samples.
        """
        if os.path.exists(save_path):
            # NOTE(security): pickle.load can execute arbitrary code embedded
            # in the file. Only point save_path at cache files you created
            # yourself or otherwise trust.
            with open(save_path, 'rb') as f:
                return pickle.load(f)

        processed_dict_list = self._process_samples(raw_set, entity_identifiers)
        self._save_cache(processed_dict_list, save_path)
        return processed_dict_list

    def _save_cache(self, processed_dict_list, save_path):
        """
        Pickle the processed samples to save_path without leaving partial files.

        The data is first written to a temporary file next to ``save_path``
        and then moved into place, so an interrupted run never leaves a
        truncated cache that a later run would try to load. A missing parent
        directory is created.

        Parameters
        ----------
        processed_dict_list : list
            List of processed samples.
        save_path : str
            Path to save the processed data.
        """
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # Same directory as the target so that os.replace stays a rename
        # on one filesystem; the pid keeps concurrent writers apart.
        tmp_path = f'{os.fspath(save_path)}.tmp.{os.getpid()}'
        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(processed_dict_list, f)
            os.replace(tmp_path, save_path)
        finally:
            # Only still present if dumping or replacing failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def _process_samples(self, raw_set, entity_identifiers):
        """
        Process each sample in the raw set.

        Parameters
        ----------
        raw_set : list
            List of raw samples to process.
        entity_identifiers : set
            Set of entity identifiers for which we cannot get meaningful text embeddings.

        Returns
        -------
        list
            List of processed samples.
        """
        # Index-based access on purpose: only len() and [] are required of
        # raw_set, so list-like containers without __iter__ keep working.
        return [
            self._process_sample(raw_set[i], entity_identifiers)
            for i in tqdm(range(len(raw_set)))
        ]

    def _process_sample(self, sample, entity_identifiers):
        """
        Process a single sample.

        Parameters
        ----------
        sample : dict
            A single raw sample with keys 'id', 'question', 'graph',
            'q_entity' and 'a_entity'.
        entity_identifiers : set
            Set of entity identifiers for which we cannot get meaningful text embeddings.

        Returns
        -------
        dict
            Processed sample.
        """
        triples = sample['graph']

        all_entities, all_relations = self._extract_entities_and_relations(triples)

        # Sorting makes the id assignment deterministic across runs.
        entity_list = sorted(all_entities)
        relation_list = sorted(all_relations)

        text_entity_list, non_text_entity_list = self._partition_entities(
            entity_list, entity_identifiers
        )
        entity2id = self._create_entity_id_mapping(
            entity_list, text_entity_list, non_text_entity_list
        )
        rel2id = self._create_relation_id_mapping(relation_list)

        h_id_list, r_id_list, t_id_list = self._convert_triples_to_ids(
            triples, entity2id, rel2id
        )

        return {
            'id': sample['id'],
            'question': sample['question'],
            'q_entity': sample['q_entity'],
            'q_entity_id_list': self._get_question_entity_ids(sample['q_entity'], entity2id),
            'text_entity_list': text_entity_list,
            'non_text_entity_list': non_text_entity_list,
            'relation_list': relation_list,
            'h_id_list': h_id_list,
            'r_id_list': r_id_list,
            't_id_list': t_id_list,
            'a_entity': sample['a_entity'],
            'a_entity_id_list': self._get_answer_entity_ids(sample['a_entity'], entity2id),
            'entity2id': entity2id,
            'id2entity': {idx: entity for entity, idx in entity2id.items()},
            'rel2id': rel2id,
            'id2rel': {idx: rel for rel, idx in rel2id.items()},
        }

    def _extract_entities_and_relations(self, triples):
        """
        Extract all entities and relations from triples.

        Parameters
        ----------
        triples : list of tuples
            List of triples in the form (head, relation, tail).

        Returns
        -------
        tuple
            A tuple containing sets of all entities and all relations.
        """
        all_entities = set()
        all_relations = set()
        for h, r, t in triples:
            all_entities.update((h, t))
            all_relations.add(r)
        return all_entities, all_relations

    def _partition_entities(self, entity_list, entity_identifiers):
        """
        Partition entities into text entities and non-text entities.

        The relative order of ``entity_list`` is kept inside each partition.

        Parameters
        ----------
        entity_list : list
            List of all entities.
        entity_identifiers : set
            Set of entity identifiers for which we cannot get meaningful text embeddings.

        Returns
        -------
        tuple
            A tuple containing lists of text entities and non-text entities.
        """
        text_entity_list = []
        non_text_entity_list = []
        for entity in entity_list:
            if entity in entity_identifiers:
                non_text_entity_list.append(entity)
            else:
                text_entity_list.append(entity)
        return text_entity_list, non_text_entity_list

    def _create_entity_id_mapping(self, entity_list, text_entity_list, non_text_entity_list):
        """
        Create a mapping from entities to IDs.

        Text entities receive the ids ``0 .. len(text_entity_list) - 1`` and
        non-text entities continue the numbering right after them.

        Parameters
        ----------
        entity_list : list
            List of all entities. Not used; kept so the signature stays
            compatible with existing callers.
        text_entity_list : list
            List of text entities.
        non_text_entity_list : list
            List of non-text entities.

        Returns
        -------
        dict
            Dictionary mapping entities to IDs.
        """
        ordered_entities = [*text_entity_list, *non_text_entity_list]
        return {entity: idx for idx, entity in enumerate(ordered_entities)}

    def _create_relation_id_mapping(self, relation_list):
        """
        Create a mapping from relations to IDs.

        Parameters
        ----------
        relation_list : list
            List of all relations.

        Returns
        -------
        dict
            Dictionary mapping each relation to its position in ``relation_list``.
        """
        return {rel: idx for idx, rel in enumerate(relation_list)}

    def _convert_triples_to_ids(self, triples, entity2id, rel2id):
        """
        Convert triples to entity and relation IDs.

        Parameters
        ----------
        triples : list of tuples
            List of triples in the form (head, relation, tail).
        entity2id : dict
            Dictionary mapping entities to IDs.
        rel2id : dict
            Dictionary mapping relations to IDs.

        Returns
        -------
        tuple
            A tuple containing lists of head IDs, relation IDs, and tail IDs,
            aligned position by position with ``triples``.

        Raises
        ------
        KeyError
            If a triple mentions an entity or relation missing from the mappings.
        """
        h_id_list = []
        r_id_list = []
        t_id_list = []
        for h, r, t in triples:
            h_id_list.append(entity2id[h])
            r_id_list.append(rel2id[r])
            t_id_list.append(entity2id[t])
        return h_id_list, r_id_list, t_id_list

    def _get_question_entity_ids(self, q_entity, entity2id):
        """
        Get IDs for question entities.

        Entities that are not present in ``entity2id`` are silently dropped.

        Parameters
        ----------
        q_entity : list
            List of question entities.
        entity2id : dict
            Dictionary mapping entities to IDs.

        Returns
        -------
        list
            List of IDs for question entities.
        """
        return [entity2id[entity] for entity in q_entity if entity in entity2id]

    def _get_answer_entity_ids(self, a_entity, entity2id):
        """
        Get IDs for answer entities.

        Entities that are not present in ``entity2id`` are silently dropped.

        Parameters
        ----------
        a_entity : list
            List of answer entities.
        entity2id : dict
            Dictionary mapping entities to IDs.

        Returns
        -------
        list
            List of IDs for answer entities.
        """
        looked_up = (entity2id.get(entity) for entity in a_entity)
        return [entity_id for entity_id in looked_up if entity_id is not None]

    def _filter_samples(self, processed_dict_list):
        """
        Filter samples based on the presence of topic and answer entities.

        A sample is dropped when ``skip_no_topic`` is set and its
        'q_entity_id_list' is empty, or when ``skip_no_ans`` is set and its
        'a_entity_id_list' is empty.

        Parameters
        ----------
        processed_dict_list : list
            List of processed samples.

        Returns
        -------
        list
            List of filtered samples.
        """
        filtered_list = []
        for processed_dict_i in processed_dict_list:
            if not processed_dict_i['q_entity_id_list'] and self.skip_no_topic:
                continue

            if not processed_dict_i['a_entity_id_list'] and self.skip_no_ans:
                continue

            filtered_list.append(processed_dict_i)
        return filtered_list

    def __len__(self):
        """Return the number of samples left after filtering."""
        return len(self.processed_dict_list)

    def __getitem__(self, i):
        """
        Return the fields of sample ``i`` as a tuple.

        Returns
        -------
        tuple
            (id, question, text_entity_list, relation_list, entity2id, rel2id)
        """
        sample = self.processed_dict_list[i]
        return (
            sample['id'],
            sample['question'],
            sample['text_entity_list'],
            sample['relation_list'],
            sample['entity2id'],
            sample['rel2id'],
        )