from scorer_utils import *
import glob
import ast
import numpy as np
# import nltk
import ast
import os
import json
from openai import OpenAI

class ReferenceScorer:
    """Attach per-prompt hallucination scores to a file of model responses.

    The response file is loaded as JSON and iterated as a sequence of dicts,
    each mapping a prompt to an entry dict. The entry keys read here are
    ``atomic_units`` and ``hallucinated_atomic_units``; a ``scores`` dict is
    written back into every entry.
    """

    # Chat model used by the optional LLM-based abstention judge.
    JUDGE_MODEL = "gpt-4.1-mini"

    def __init__(self, response_filepath, client):
        """Load the responses from *response_filepath* and keep *client*.

        Args:
            response_filepath: Path to the JSON file holding the responses.
            client: Object exposing ``chat.completions.create`` (only used by
                ``_is_abstained``).
        """
        self.responses = self.read_json_file(response_filepath)
        self.client = client

    def read_json_file(self, filepath):
        """Return the parsed JSON content of *filepath*.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(filepath, 'r', encoding='utf-8') as f:
            return json.load(f)

    def compute_scores(self):
        """Score every entry in place and return the loaded responses."""
        for data in self.responses:
            for prompt, entry in data.items():
                self._add_entry_scores(prompt, entry)
        return self.responses

    def _add_entry_scores(self, prompt, entry):
        """Write the ``scores`` dict for a single prompt entry.

        An entry without atomic units is treated as an abstention: its
        response ratio is 0.0 and its utility is 1.0. ``_is_abstained`` offers
        an LLM-based abstention check but is not used by this scoring path.

        Args:
            prompt: The prompt the entry belongs to (currently unused).
            entry: Dict that is updated in place with a ``scores`` key.
        """
        # ``or []`` also covers keys that are present but hold JSON null.
        atomic_units = entry.get('atomic_units') or []
        hallucinated = entry.get('hallucinated_atomic_units') or []

        entry['scores'] = {
            'hallucination_score': self._calc_hallucination(hallucinated, atomic_units),
            'response_ratio': 0.0 if not atomic_units else 1.0,
            'utility_score': self._calc_utility(hallucinated, atomic_units),
        }

    def _is_abstained(self, prompt, response, atomic_units):
        """Ask the judge model whether *response* abstains from answering.

        Args:
            prompt: The original prompt.
            response: The model response to judge.
            atomic_units: Unused; kept for interface compatibility.

        Returns:
            True if the judge answers "yes", False if it answers "no" or
            returns anything that cannot be interpreted.
        """
        judge_prompt = f"""
        You are given a prompt and its response . Your task is to judge whether the response to the prompt is an abstention from answering or not? Just answer with 'yes' or 'no'. 'yes' if it is an abstention, 'no' if it is not an abstention and it seems like an answer.
        prompt: {prompt}
        response: {response}

        Return your generation in the following format. Do not include any other text:

        abstention: [your judgement here]

        """

        completion = self.client.chat.completions.create(
            model=self.JUDGE_MODEL,
            messages=[
                {"role": "system", "content": "You are a helpful assistant that judges a pair of prompts and responses and checks whether the response is an absention from answering the prompt. Always respond with just yes or no, nothing else."},
                {"role": "user", "content": judge_prompt},
            ],
        )

        # The message content can be None (e.g. refusals); treat it as empty.
        content = completion.choices[0].message.content or ""
        return self._parse_abstention(content)

    @staticmethod
    def _parse_abstention(text):
        """Interpret the judge's raw reply as a boolean abstention verdict."""
        verdict = (text or "").strip().lower()
        # Accept both the requested "abstention: yes" format and a bare
        # "yes"/"no"; matching is done on the lower-cased text so that
        # "Abstention: Yes" is understood as well.
        _, marker, tail = verdict.partition("abstention:")
        if marker:
            verdict = tail
        # Drop whitespace plus brackets, quotes and a trailing period that
        # models sometimes copy from the format template.
        verdict = verdict.strip(" \t\r\n[]'\".")

        if verdict == "yes":
            return True
        if verdict == "no":
            return False
        print(f"Unexpected response from abstention check: {verdict}")
        return False

    def _calc_hallucination(self, hallucinated, atomic_units):
        """Return the fraction of atomic units that are hallucinated.

        The result is rounded to four decimals; it is 0.0 when there are no
        atomic units at all.
        """
        if not atomic_units:
            return 0.0
        # ``hallucinated`` may be None when the JSON field is null.
        return round(len(hallucinated or ()) / len(atomic_units), 4)

    def _calc_utility(self, hallucinated, atomic_units):
        """Return 1.0 when there are no atomic units, otherwise 0.0.

        *hallucinated* is accepted for signature symmetry with
        ``_calc_hallucination`` but does not influence the result.
        """
        return 1.0 if not atomic_units else 0.0

if __name__ == "__main__":
    # Command-line driver: score every JSON response file found in
    # --input_dir and write the scored copy, under the same file name, to
    # --output_dir.
    import argparse
    import yaml

    parser = argparse.ArgumentParser(
        description="Process files in a folder using various evaluators.")

    parser.add_argument(
        "--input_dir", help="Path to the folder containing the JSON response files to score.", default="./")

    parser.add_argument(
        "--output_dir", help="Path to the folder where output files will be saved.", default="./res")

    def read_api_keys(config_file="config.yml"):
        """Return (openai, together, s2) API keys read from *config_file*.

        Only ``openai_api_key`` is required by this script; the other two are
        returned as None when the config does not define them.
        """
        with open(config_file, 'r') as file:
            config = yaml.safe_load(file)
        return config['openai_api_key'], config.get('together_api_key'), config.get('s2_api_key')

    # Parse the arguments first so that --help and usage errors work even
    # when config.yml is missing.
    args = parser.parse_args()
    input_dir = args.input_dir
    output_dir = args.output_dir

    openai_api_key, *_ = read_api_keys()
    client = OpenAI(api_key=openai_api_key)

    # Creating the output folder does not depend on the file being processed.
    os.makedirs(output_dir, exist_ok=True)

    for filename in sorted(os.listdir(input_dir)):
        file_path = os.path.join(input_dir, filename)

        # With the default arguments the input folder also contains the
        # output folder, config.yml and the scripts themselves; none of these
        # can be loaded as a response file.
        if not os.path.isfile(file_path) or not filename.lower().endswith('.json'):
            print(f"Skipping non-JSON entry {file_path}")
            continue

        scorer = ReferenceScorer(file_path, client)
        scored_data = scorer.compute_scores()

        output_path = os.path.join(output_dir, filename)

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(scored_data, f, indent=4)

        print(f"Successfully added per-prompt scores to {output_path}")
