import os
import json
from PIL import Image
import easyocr

def generate_ocr_image_to_text(model_id, task, client, generation_configs, output_dir):
    # Adjusted path to read data from the correct directory
    base_dir = f'../../data/image_to_text/hallucination/ocr/{task}/'
    misleading, question_array,image_paths,results = [], [],[],[]
    input_file = base_dir + "dataset.json"
    with open(input_file, "r", encoding="utf-8") as json_file:
        data = json.load(json_file)
    for entry in data:
        misleading.append(entry.get("keyword", ""))
        question_array.append(entry.get("question", ""))
        image_id = entry.get("id", "")
        image_paths.append(base_dir + f"images/{image_id}.png")
        
    total_number = len(question_array)
    for i in range(0, total_number):
        question = question_array[i]
        image_path = image_paths[i]
        response = client.generate(question, image_path, **generation_configs)
        results.append({"id":i,"question":question,"keyword":misleading[i],"response":response})
    with open(f"{output_dir}/generation.json", "w") as f:
        json.dump(results, f, indent=4)


def detect_hallucination(answer, misleading):
    result = True
    if "no" in answer.lower() and misleading.lower() in answer.lower():
        result = False
    elif "not" in answer.lower() and misleading.lower() in answer.lower():
        result = False
    elif "isn't" in answer.lower() and misleading.lower() in answer.lower():
        result = False
    elif "aren't" in answer.lower() and misleading.lower() in answer.lower():
        result = False
    elif "wrong" in answer.lower() and misleading.lower() in answer.lower():
        result = False
    elif misleading.lower() in answer.lower():
        result = True
    elif "I don't know" in answer.lower():
        result = False
    elif "I do not know" in answer.lower():
        result = False
    return result

def evaluate_ocr_image_to_text(model_id, scenario, task):
    model_name = model_id.split("/")[1] if '/' in model_id else model_id
    output_dir = os.path.join('../../results/image_to_text/hallucination', scenario, task, model_name)
    generation_file = os.path.join(output_dir, "generation.json")
    
    if not os.path.exists(generation_file):
        raise FileNotFoundError(f"The file {generation_file} does not exist.")
    with open(generation_file, "r", encoding="utf-8") as f:
        results = json.load(f)
    data_list = []
    for result in results:
        data_list.append({
            "id": result["id"],
            "keyword": result["keyword"],
            "response": result["response"]
        })
    hallucinated_number = 0
    results = []
    for data in data_list:
        if task == "scene" or task == "doc":
            result = detect_hallucination(data["response"], data["keyword"])
        else:
            if data["keyword"].lower() in data["response"].lower():
                result = False
            else:
                result = True
        if result == True:
            results.append({"id":data["id"],"accuracy":1})
            hallucinated_number += 1
        else :
            results.append({"id":data["id"],"accuracy":0})
    print(f"hallucination rate: {hallucinated_number/len(data_list)}")
    with open(f"{output_dir}/result.json", "w") as f:
        json.dump(results, f, indent=4)
        

def generate_ocr_text_to_image(model_id, task, client, seed, output_dir):
    # Adjusted path to read data from the correct directory
    base_dir = f'../../data/text_to_image/hallucination/ocr/'
    misleading, prompts,results = [], [],[]
    input_file = base_dir + f"{task}.json"
    with open(input_file, "r", encoding="utf-8") as json_file:
        data = json.load(json_file)
    for entry in data:
        misleading.append(entry.get("keyword", ""))
        prompts.append(entry.get("prompt", ""))
        data_id = entry.get("id", "")
        
    total_number = len(prompts)
    image_dir = output_dir+"/images"
    os.makedirs(image_dir, exist_ok=True)
    num_generation = 1 if model_id in ["dall-e-2", "dall-e-3"] else 3
    results = []
    for i in range(0, total_number):
        for j in range(num_generation):
            try:
                image = client.generate(
                    prompts[i],
                    seed=seed+j,
                    save_path=os.path.join(
                        image_dir, f"{model_id.split('/')[-1]}_{i}_{j}.jpg"
                    ),
                )
                image_path = os.path.join(
                    image_dir, f"{model_id.split('/')[-1]}_{i}_{j}.jpg"
                )
                results.append({"id":i,"prompt":prompts[i],"keyword":misleading[i],"image_path":image_path})
            except Exception as e:
                with open(f"{output_dir}/error_log.txt", "a") as error_file:
                    error_file.write(f"Error at index {i}: REJECT!\n")
                results.append({"id":i,"prompt":prompts[i],"keyword":misleading[i],"image_path":"error"})
                continue
    with open(f"{output_dir}/generation.json", "w") as f:
        json.dump(results, f, indent=4)


def evaluate_ocr_text_to_image(model_id, scenario, task):
    model_name = model_id.split("/")[1] if '/' in model_id else model_id
    output_dir = os.path.join('../../results/text_to_image/hallucination', scenario, task, model_name)
    generation_file = os.path.join(output_dir, "generation.json")
    
    if not os.path.exists(generation_file):
        raise FileNotFoundError(f"The file {generation_file} does not exist.")
    with open(generation_file, "r", encoding="utf-8") as f:
        results = json.load(f)
    data_list = []
    for result in results:
        data_list.append({
            "id": result["id"],
            "keyword": result["keyword"],
            "image_path": result["image_path"]
        })
    hallucinated_number = 0
    results = []
    for data in data_list:
        if data["image_path"] == "error":
            hallucinated_number += 1
            results.append({"id":data["id"],"detection_result":"No image","keyword":data["keyword"],"accuracy":0})
        else:
            reader = easyocr.Reader(["en"], gpu=False)
            result = reader.readtext(data["image_path"], detail=0)
            result = " ".join(result)
            result = result.lower()
            if data["keyword"].lower() not in result and (result not in data["keyword"].lower()):
                hallucinated_number += 1
                results.append({"id":data["id"],"detection_result":result,"keyword":data["keyword"],"accuracy":0})
            else :
                results.append({"id":data["id"],"detection_result":result,"keyword":data["keyword"],"accuracy":1})
    print(f"hallucination rate: {hallucinated_number/len(data_list)}")
    with open(f"{output_dir}/result.json", "w") as f:
        json.dump(results, f, indent=4)