import os
import json
import spacy
import nltk
from nltk.stem import WordNetLemmatizer
from tqdm import tqdm
# Shared spaCy pipeline, loaded once when this module is imported; used for
# word-vector similarity when matching response nouns to annotated objects.
nlp = spacy.load(name="en_core_web_lg")

class AmberBenchmarkDataset:
    """AMBER query loader and generative-task (CHAIR-style) evaluator.

    The query list (image ids and paths) is loaded at construction time.
    :meth:`evaluate` scores model responses against per-sample annotations
    using exact word matches plus spaCy similarity above a threshold.
    """

    def __init__(
        self,
        image_path: str,
        query_path: str,
        annotation_path: str,
        metric_path: str,
        relation_path: str,
        safe_words_path: str,
        similarity_score: float = 0.8,
    ):
        """
        Args:
            image_path: Directory joined with each query's ``image`` file name.
            query_path: JSON file holding a list of queries with ``id`` and ``image`` keys.
            annotation_path: JSON annotation file read by :meth:`evaluate`.
            metric_path: Text file of ``name = value`` lines that seeds the metric counters.
            relation_path: JSON object mapping a word to a list of related words.
            safe_words_path: Text file with one globally ignored word per line.
            similarity_score: Two words count as synonyms when their spaCy similarity
                is strictly greater than this value.
        """
        self.image_path = image_path
        self.query_path = query_path
        self.annotation_path = annotation_path
        self.metric_path = metric_path
        self.relation_path = relation_path
        self.safe_words_path = safe_words_path
        self.similarity_score = similarity_score
        self.image_id_key = "id"
        self.caption_key = "response"
        # Memo of nlp(word); evaluate() compares the same words many times.
        self._doc_cache = {}
        # Needs self.image_path and self.query_path, which are set above.
        self.image = self._load_amber_data()

    def _load_amber_data(
        self
    ):
        """Return ``[{"image_id": ..., "image_path": ...}, ...]`` built from the query file."""
        with open(self.query_path, "r", encoding="utf-8") as file:
            queries = json.load(file)
        return [
            {
                "image_id": query["id"],
                "image_path": os.path.join(self.image_path, query["image"]),
            }
            for query in queries
        ]

    def get_test_dataset(
        self
    ):
        """Return the list of image records loaded at construction time."""
        return self.image

    def _nlp_doc(self, word):
        """Return ``nlp(word)``, memoised per instance."""
        # setdefault keeps this working for instances created without __init__.
        cache = self.__dict__.setdefault("_doc_cache", {})
        if word not in cache:
            cache[word] = nlp(word)
        return cache[word]

    def check_synonyms_word(
        self,
        word1,
        word2,
        ):
        """Return True if the spaCy similarity of the two words exceeds ``self.similarity_score``."""
        similarity = self._nlp_doc(word1).similarity(self._nlp_doc(word2))
        return similarity > self.similarity_score

    def extract_nouns(
        self,
        text
        ):
        """Return the WordNet-lemmatised tokens of ``text`` whose NLTK POS tag starts with 'NN'."""
        lemmatizer = WordNetLemmatizer()
        tokens = nltk.word_tokenize(text)
        tagged = nltk.pos_tag(tokens)
        return [lemmatizer.lemmatize(word) for word, pos in tagged if pos.startswith('NN')]

    def init(
        self
    ):
        """Read ``name = value`` lines from ``self.metric_path`` into a dict of counters.

        Lines that do not split into exactly two parts on ``=`` are ignored.
        """
        metrics = {}
        with open(self.metric_path, "r") as file:
            for line in file:
                parts = line.strip().split('=')
                if len(parts) == 2:
                    # NOTE(security): eval runs arbitrary code from the metric file;
                    # only use trusted files (ast.literal_eval suffices for plain numbers).
                    metrics[parts[0].strip()] = eval(parts[1].strip())
        return metrics

    def dump_generations(
        self,
        results,
        results_path
    ):
        """Append each result to ``results_path`` as one JSON object per line."""
        # Open once instead of once per result; append mode keeps existing lines.
        with open(results_path, "a") as f:
            for result in results:
                json.dump(result, f)
                f.write('\n')

    @staticmethod
    def _build_targets(objects, association):
        """Return ``(words, flags)`` used to match nouns against annotated objects.

        ``words`` holds every related word of each object (from ``association``)
        followed by the objects themselves. ``flags`` is parallel to ``words``.
        An association entry stores the index of its source object. Each of the
        trailing ``len(objects)`` entries is a 0/1 "mentioned" slot, one per object.
        """
        words, flags = [], []
        for idx, obj in enumerate(objects):
            related = association[obj]
            words.extend(related)
            flags.extend([idx] * len(related))
        words.extend(objects)
        flags.extend([0] * len(objects))
        return words, flags

    @staticmethod
    def _mark_covered(flags, tail_len, j):
        """Set to 1 the "mentioned" slot of the object matched by ``words[j]``.

        An association entry (``j`` before the trailing slots) locates the slot
        through the object index stored in ``flags[j]``. A direct object match
        sets ``flags[j]`` itself.
        """
        tail_start = len(flags) - tail_len
        if j < tail_start:
            flags[flags[j] + tail_start] = 1
        else:
            flags[j] = 1

    def _first_similar(self, noun, candidates):
        """Return the index of the first candidate similar to ``noun``, or None."""
        for j, candidate in enumerate(candidates):
            if self.check_synonyms_word(noun, candidate):
                return j
        return None

    @staticmethod
    def _percent(numerator, denominator):
        """Return ``numerator / denominator * 100`` (unrounded), or 0.0 if ``denominator`` is 0."""
        return numerator / denominator * 100 if denominator else 0.0

    def evaluate(
        self,
        generation_path,
        dump_results=False
    ):
        """Score generative-task responses and return the AMBER generative metrics.

        Args:
            generation_path: JSONL file with one object per line carrying ``id`` and
                ``response``. Blank lines are skipped.
            dump_results: If True, also write the metrics dict to ``chair_results.jsonl``
                in the directory of ``generation_path``.

        Returns:
            dict with ``CHAIR``, ``Cover``, ``Hal`` and ``Cog`` percentages rounded to
            one decimal. A metric whose denominator is 0 (e.g. no generative samples)
            is reported as 0.0 instead of raising ZeroDivisionError.
        """
        metrics = self.init()
        with open(generation_path, "r", encoding="utf-8") as file:
            generations = [json.loads(line) for line in file if line.strip()]
        with open(self.annotation_path, 'r', encoding='utf-8') as file:
            annotations = json.load(file)
        with open(self.relation_path, 'r', encoding='utf-8') as file:
            association = json.load(file)

        # Sets: only used for membership tests inside the per-noun loop.
        hallucination_words = set(association)
        for related in association.values():
            hallucination_words.update(related)

        with open(self.safe_words_path, 'r', encoding='utf-8') as safe_file:
            global_safe_words = {line.split('\n')[0] for line in safe_file}

        for generation in tqdm(generations):
            sample_id = int(generation['id'])
            # NOTE(review): assumes annotations[k - 1] is the entry for id k — confirm.
            annotation = annotations[sample_id - 1]
            if annotation['type'] != 'generative':
                continue

            truth = annotation['truth']
            hallu = annotation['hallu']
            # Only nouns known to the relation file are candidates.
            candidate_nouns = [
                noun for noun in self.extract_nouns(generation['response'])
                if noun in hallucination_words
            ]
            safe_words, safe_list = self._build_targets(truth, association)
            ha_words, ha_list = self._build_targets(hallu, association)
            safe_len, ha_len = len(truth), len(hallu)
            # 1 marks a candidate noun not matched to any ground-truth object.
            safe_flag_list = [0] * len(candidate_nouns)

            for idx, noun in enumerate(candidate_nouns):
                if noun in global_safe_words:
                    continue
                if noun in safe_words:
                    self._mark_covered(safe_list, safe_len, safe_words.index(noun))
                    continue
                if noun in ha_words:
                    self._mark_covered(ha_list, ha_len, ha_words.index(noun))
                # The similarity scan over hallucination targets always runs and only
                # records coverage; only a ground-truth synonym clears the noun.
                j = self._first_similar(noun, ha_words)
                if j is not None:
                    self._mark_covered(ha_list, ha_len, j)
                j = self._first_similar(noun, safe_words)
                if j is not None:
                    self._mark_covered(safe_list, safe_len, j)
                    continue
                safe_flag_list[idx] = 1

            hallucinated = sum(safe_flag_list)
            metrics['chair_score'] += hallucinated
            metrics['chair_num'] += len(safe_flag_list)
            # The trailing slots (one per annotated object) record coverage.
            metrics['safe_cover_score'] += sum(safe_list[len(safe_list) - safe_len:])
            metrics['safe_cover_num'] += safe_len
            metrics['hallu_cover_score'] += sum(ha_list[len(ha_list) - ha_len:])
            metrics['hallu_cover_num'] += ha_len
            if hallucinated == 0:
                metrics['non_hallu_score'] += 1
            metrics['non_hallu_num'] += 1

        chair = round(self._percent(metrics['chair_score'], metrics['chair_num']), 1)
        cover = round(self._percent(metrics['safe_cover_score'], metrics['safe_cover_num']), 1)
        cog = round(self._percent(metrics['hallu_cover_score'], metrics['hallu_cover_num']), 1)
        if metrics['non_hallu_num']:
            hal = round(100 - self._percent(metrics['non_hallu_score'], metrics['non_hallu_num']), 1)
        else:
            hal = 0.0
        print("Generative Task:")
        print("CHAIR:\t\t", chair)
        print("Cover:\t\t", cover)
        print("Hal:\t\t", hal)
        print("Cog:\t\t", cog, "\n")
        results = {
            "CHAIR": chair,
            "Cover": cover,
            "Hal": hal,
            "Cog": cog
        }

        if dump_results:
            with open(os.path.join(os.path.dirname(generation_path), 'chair_results.jsonl'), 'w') as f:
                json.dump(results, f)

        return results
            
        
