# -*- coding: utf-8 -*-
from logging import root
import os
from termcolor import cprint
import shutil
import json
from typing import Dict
import time
import sys
# Put the parent of this file's directory on sys.path so that the
# `pipeline.*` imports below can be resolved from there.
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sys.path.append(root_path)
print(root_path)
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pipeline.utils import print_colored, flat_directory, get_pdf_context, get_csv_context, is_non_text_file, get_user_input_multiline, generate_hints, gen_uid
from pipeline.utils import find_id_in_label, get_openai_api_key, print_colored, div_by_zero
from pipeline.evaluator import Llama3Evaluator, OpenAIEvaluator
from typing import Literal

class BenchmarkManager:
    """Read and maintain the benchmark data stored next to this module.

    Layout used by the methods below, relative to ``self.base_path``:
    ``task_meta.json`` holds ``{"tasks": [...]}``. ``platforms/<scenario>/``
    holds ``documents/``, ``files/<id>/`` and ``input/<id>/``. ``materials/``
    holds the reference materials.
    """

    def __init__(self, base_path = "benchmark", verbose=False):
        """
        Initialize the BenchmarkManager and pre-fetch the scenarios and task metadata.

        :param base_path: Currently ignored. Data is always resolved relative to the
            directory containing this module. The parameter is kept for backward
            compatibility. NOTE(review): confirm whether callers expect it to be honored.
        :param verbose: Stored as ``self.verbose``. The methods of this class take
            their own ``verbose`` arguments instead of reading it.
        """
        # Anchored to this module's directory rather than the current working directory.
        self.base_path = os.path.dirname(__file__)
        self.data_path = os.path.join(self.base_path, "platforms")
        self.verbose = verbose
        self.available_scenarios = self._fetch_scenarios()
        self.task_meta = self.get_task_meta()

    def _fetch_scenarios(self):
        """Return the sub-directories of ``platforms/`` whose names start with neither '.' nor '_'."""
        return [
            item
            for item in os.listdir(self.data_path)
            if os.path.isdir(os.path.join(self.data_path, item))
            and not item.startswith(('.', '_'))
        ]

    def _missing_scenario_message(self, scenario):
        """Build the message shown when ``scenario`` is not one of the available scenarios."""
        return f"Scenario '{scenario}' not found. Available scenarios are: {', '.join(self.available_scenarios)}"

    def get_scenarios(self, verbose = False):
        """
        List all available scenarios in the benchmark.

        :param verbose: If True, also print the list.
        :return: The list of scenario names.
        """
        if verbose:
            print("\n")
            cprint('====================', 'green')
            cprint('Available scenarios:', 'green')
            for scenario in self.available_scenarios:
                cprint(f"- {scenario}", 'blue')

        return self.available_scenarios

    def get_documents(self, scenario, verbose = False):
        """
        Return the full paths of the files in ``platforms/<scenario>/documents``.

        Hidden files (names starting with '.') are skipped. Returns an empty list
        when the scenario or the folder does not exist.
        """
        if scenario not in self.available_scenarios:
            print(self._missing_scenario_message(scenario))
            return []
        documents_path = os.path.join(self.data_path, scenario, 'documents')
        if not os.path.exists(documents_path):
            print("Documents folder doesn't exist.")
            return []
        files = [file for file in os.listdir(documents_path) if not file.startswith('.')]
        if verbose:
            cprint('====================', 'green')
            cprint(f'Available documents of {scenario}:', 'green')
            for document in files:
                cprint(f"- {document}", 'blue')
        return [os.path.join(documents_path, file) for file in files]

    def list_file_ids(self, scenario):
        """
        Return the task ids of a scenario, i.e. the non-hidden entries of
        ``platforms/<scenario>/files``. Returns an empty list on a missing scenario or folder.
        """
        if scenario not in self.available_scenarios:
            print(self._missing_scenario_message(scenario))
            return []

        files_path = os.path.join(self.data_path, scenario, "files")
        if not os.path.exists(files_path):
            print("Files folder doesn't exist.")
            return []
        return [file for file in os.listdir(files_path) if not file.startswith('.')]

    def get_files(self, id, verbose = False, flat = False):
        """
        Get the files of a task, stored under ``platforms/<scenario>/files/<id>``.

        Args:
            id: Task id. Its scenario is resolved with ``get_issue_scenario``.
            verbose (bool, optional): Print the file list. Defaults to False.
            flat: If True, return the list produced by ``flat_directory`` (paths of all
                files in the directory, without directories). Otherwise return the full
                paths of the direct children of the task folder. Defaults to False.

        Returns:
            A list of paths. The list is empty when the scenario or the folder is missing.
        """
        scenario = self.get_issue_scenario(id)
        if scenario not in self.available_scenarios:
            # Always return a list (never the message) so callers can iterate the result safely.
            print(self._missing_scenario_message(scenario))
            return []

        files_path = os.path.join(self.data_path, scenario, "files", id)
        if not os.path.exists(files_path):
            print("Files folder doesn't exist.")
            return []

        files = list(flat_directory(files_path)) if flat else os.listdir(files_path)

        if verbose:
            cprint('====================', 'green')
            cprint(f'Available files of {scenario}/{id}:', 'green')
            for file in files:
                cprint(f"- {file}", 'blue')
        if flat:
            return files
        return [os.path.join(files_path, file) for file in files]

    def copy_files_to_dir(self, id, dest_dir, clear_dir=False, relevant_path = True):
        """
        Copy all top-level files and directories of task ``id`` into ``dest_dir``.

        Args:
            id: Task id whose ``files/<id>`` folder is copied.
            dest_dir: Destination directory. It is created if missing.
            clear_dir: If True and ``dest_dir`` already exists, delete it first.
                Otherwise the copied entries are merged into the existing content.
            relevant_path: If True, report entries as ``os.path.join(dest_dir, name)``.
                If False, report them as ``./name``.

        Returns:
            A comma-separated string of every entry now in ``dest_dir``, or False when
            the task has nothing to copy.
        """
        files = self.get_files(id)
        if not files:
            print("No files or directories to copy.")
            return False

        if not os.path.exists(dest_dir):
            os.makedirs(dest_dir)
        elif clear_dir:
            shutil.rmtree(dest_dir)
            os.makedirs(dest_dir)
        else:
            print("Destination directory already exists. Use clear_dir=True to remove all files and directories in the directory.")

        # get_files (non-flat) returns full paths of the task folder's direct children.
        for src_path in files:
            if os.path.isdir(src_path):
                # dirs_exist_ok merges into a non-cleared destination instead of raising FileExistsError.
                shutil.copytree(src_path, os.path.join(dest_dir, os.path.basename(src_path)), dirs_exist_ok=True)
            else:
                shutil.copy(src_path, dest_dir)

        entries = os.listdir(dest_dir)
        if not relevant_path:
            return ", ".join("./" + name for name in entries)
        return ", ".join(os.path.join(dest_dir, name) for name in entries)

    def get_issue_context(self, id):
        """
        Build the issue text (title, content, involved) of task ``id``.

        ``involved`` is JSON-formatted when it is a list or dict. Literal ``\\n``
        sequences are turned into real newlines, then all remaining backslashes are
        removed. Returns None when ``id`` is unknown.
        """
        for item in self.task_meta:
            if item["id"] != id:
                continue
            title = item['title'].strip()
            content = item['content'].strip()
            involved = item['involved']
            if isinstance(involved, (list, dict)):
                involved_str = json.dumps(involved, indent=4)
            else:
                involved_str = involved.strip()

            context = f"title: {title}\ncontent: {content}\ninvolved: {involved_str}"
            return context.replace(r'\n', '\n').replace('\\', "")
        return None

    def get_issue_scenario(self, id):
        """
        Return the scenario (the task's ``platform`` field) of task ``id``, or None if unknown.
        """
        item = self.get_instance_meta(id)
        return item["platform"] if item is not None else None

    def write_input(self, scenario, id, hint_index, content: str):
        """
        Write the input to ``platforms/<scenario>/input/<id>/input_with_hint_<hint_index>.txt``.
        The input file is the formatted prompt with its <place holders> replaced by the
        actual content, ready to be fed to the model.
        """
        hints_path = os.path.join(self.data_path, scenario, "input", id)
        os.makedirs(hints_path, exist_ok=True)
        with open(os.path.join(hints_path, "input_with_hint_{}.txt".format(hint_index)), "w") as file:
            file.write(content)

    def get_written_input_ids(self, scenario):
        """
        Return the ids whose input has already been written, so it need not be written again.
        Returns None if the scenario has no ``input`` folder.
        """
        input_path = os.path.join(self.data_path, scenario, "input")
        if not os.path.exists(input_path):
            print_colored("Input folder doesn't exist.", 'red')
            return None
        return [file for file in os.listdir(input_path) if not file.startswith('.')]

    def get_written_input_of_id(self, scenario, id):
        """
        Return the paths of the written inputs of a specific id, or None if the folder is missing.
        """
        input_path = os.path.join(self.data_path, scenario, "input", id)
        if not os.path.exists(input_path):
            print_colored("Input folder doesn't exist.", 'red')
            return None
        return [os.path.join(input_path, file) for file in os.listdir(input_path) if not file.startswith('.')]

    def get_scenario_task_meta(self, scenario):
        """
        Return the task entries whose ``platform`` equals ``scenario``.
        """
        return [item for item in self.task_meta if item["platform"] == scenario]

    def get_task_meta(self):
        """
        Load the ``tasks`` list from ``task_meta.json``.
        """
        task_meta_path = os.path.join(self.base_path, "task_meta.json")
        with open(task_meta_path, encoding='utf-8') as f:
            return json.load(f)["tasks"]

    def get_instance_meta(self, id):
        """
        Return the task entry for ``id``, or None if it does not exist.
        """
        return next((item for item in self.task_meta if item["id"] == id), None)

    def get_hints(self, id):
        """
        Return ``["None", *hints]`` for task ``id``, so that index 0 means "no hint".
        Returns None when ``id`` is unknown.
        """
        item = self.get_instance_meta(id)
        if item is None:
            return None
        return ["None"] + item['hints']

    def get_hint(self, id, hint_level):
        """
        Return one hint of task ``id``.

        ``hint_level == -1`` yields the string "None". Any other valid index into the
        task's hint list yields that hint. An unknown id or an out-of-range level yields None.
        """
        if hint_level == -1:
            return "None"
        item = self.get_instance_meta(id)
        if item is None:
            return None
        hints = item['hints']
        if isinstance(hint_level, int) and 0 <= hint_level < len(hints):
            return hints[hint_level]
        return None

    def get_complete_info(self, id):
        """
        Return ``(issue_context, hints)`` of task ``id`` for the evaluator.
        """
        return self.get_issue_context(id), self.get_hints(id)

    def get_test_case_info(self, id, dest_dir, use_doc = False, relevant_path = True):
        """
        Prepare a test case: copy the task's files into ``dest_dir`` and collect its hints.

        Args:
            id: Task id.
            dest_dir: Target directory. It is cleared before copying.
            use_doc: If True, also return the paths of the scenario's documents. These
                paths need further processing by the caller.
            relevant_path: Forwarded to ``copy_files_to_dir``. If True, entries are
                reported as ``dest_dir/name``. If False, they are reported as ``./name``.

        Returns:
            ``(hints, dataset_files, doc_paths)``.
        """
        dataset_files = self.copy_files_to_dir(id, dest_dir=dest_dir, clear_dir=True, relevant_path=relevant_path)
        hints = self.get_hints(id)
        doc_paths = []

        if use_doc:
            # Resolve the scenario here; it was previously read before being assigned.
            scenario = self.get_issue_scenario(id)
            doc_paths = self.get_documents(scenario)

        return hints, dataset_files, doc_paths

    def get_ref_materials(self):
        """
        Return the paths of the non-hidden files in ``materials/``.
        """
        material_path = os.path.join(self.base_path, "materials")
        return [os.path.join(material_path, file) for file in os.listdir(material_path) if not file.startswith('.')]

    def _task_manage_helper(self, operation: str = 'add', file_id: str = None, content: Dict = None, scenario="BigBench"):
        """
        Add or update a task in ``task_meta.json`` and (re)generate its hints.

        Parameters:
        - operation (str): Either "add" or "update".
        - file_id (str, optional): The unique identifier of the task. Required for "update".
        - content (dict, optional): Task fields ("title", "content", "involved"). Required for
          "add". For "update", the given keys overwrite the stored ones. If None, only the
          hints are regenerated.
        - scenario (str): Platform of the task. It is stored as ``platform`` on new tasks and
          used for the task's ``files`` folder.

        Operations:
        - "add": Generates a new task id and appends a task with title, content, involved,
          gpt-generated hints and platform.
        - "update": Updates the fields of the existing task ``file_id``. The hints are always
          regenerated from the task's resulting title, content and involved fields.

        Raises:
            FileNotFoundError: If ``task_meta.json`` does not exist.
            ValueError: On empty content for "add", a missing or unknown ``file_id`` for
                "update", or an unknown operation.
        """
        # Same file that get_task_meta reads, independent of the current working directory.
        meta_path = os.path.join(self.base_path, "task_meta.json")
        try:
            with open(meta_path, "r", encoding="utf-8") as f:
                meta_data = json.load(f)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"{meta_path} not found") from err

        tasks = meta_data["tasks"]
        if operation == 'add':
            if not content:
                raise ValueError("Add Error: content is empty")
            file_id = gen_uid()
            print(f"Adding Task data {file_id}")
            print("Generating Hints")
            hints_lst = generate_hints(content)
            print(f"Generated Hints: \n\n {hints_lst} \n\n")
            tasks.append({
                "title": content['title'],
                "content": content['content'],
                "involved": content['involved'],
                "hints": hints_lst,
                "id": file_id,
                # Every lookup in this class (get_issue_scenario, get_files, ...) keys on "platform".
                "platform": scenario,
            })
        elif operation == 'update':
            if not file_id:
                raise ValueError("Update Error: file_id is empty")
            task = next((t for t in tasks if t["id"] == file_id), None)
            if task is None:
                raise ValueError(f"Update Error: file_id {file_id} not found")

            if content:
                print(f"Updating Task data {file_id}")
                task.update(content)
            else:
                print("No provided content, regenerating hints then")
            # Build the hint input from the stored task so partial updates still
            # give generate_hints a complete title/content/involved triple.
            content = {key: task[key] for key in ("title", "content", "involved")}

            print("Regenerating Hints")
            task["hints"] = generate_hints(content)
        else:
            raise ValueError("InValid Operation")

        print(f"Saving to task_meta.json\n {content}")
        with open(meta_path, "w", encoding="utf-8") as file:
            json.dump(meta_data, file, indent=4)
        # Keep the in-memory view consistent with what was just written.
        self.task_meta = tasks

        file_path = os.path.join(self.data_path, scenario, 'files', file_id)
        os.makedirs(file_path, exist_ok=True)
        print_colored(f"Please put the issues files in the corresponding folder: {file_path}", 'red')

    def _prompt_task_fields(self):
        """Interactively collect the title, content and involved files of a task."""
        title = get_user_input_multiline("Please enter the title of the task: ")
        content = get_user_input_multiline("Please enter the content of the task: ")

        file_number = input("Please enter the number of files involved in this issue (NOTE: The env could have multiple files, here we only need to point out the files that involved in this issue): ")
        # Re-prompt until the input is a number of at least 1.
        while not file_number.isdigit() or int(file_number) < 1:
            file_number = input("Invalid input. Please enter a number greater than 0: ")

        involved = []
        for i in range(int(file_number)):
            file_name = input(f"Please enter the name of file {i+1}: ")
            context = get_user_input_multiline(f"Please enter the context of file {i+1} that involving this issue: ")
            involved.append({"name": file_name, "context": context})

        return {
            "title": title,
            "content": content,
            "involved": involved,
        }

    def manage_task_meta(self):
        """
        Interactive loop for adding or updating tasks of one scenario in ``task_meta.json``.
        """
        available_scenarios = self.get_scenarios()
        scenario = input(f"Please select a scenario from {available_scenarios}: ")
        while scenario not in available_scenarios:
            scenario = input(f"Invalid scenario. Please select a scenario from {available_scenarios}: ")

        while True:
            option = input("Please select an operation: 'add', 'update' or 'break': ")
            if option == 'break':
                break

            while option not in ['add', 'update']:
                option = input("Invalid operation. Please select an operation: 'add' or 'update': ")

            file_id = None
            if option == 'update':
                available_file_ids = self.list_file_ids(scenario)
                file_id = input("Please enter the file ID of the task you want to update: ")
                while file_id not in available_file_ids:
                    file_id = input("This file ID doesn't exist. Please enter the file ID of the task you want to update: ")

            task_fields = self._prompt_task_fields()

            if option == 'update':
                self._task_manage_helper(operation='update', file_id=file_id, content=task_fields, scenario=scenario)
            else:
                print(task_fields)
                self._task_manage_helper(operation='add', content=task_fields, scenario=scenario)

    def _count_types_and_tags(self):
        """
        Count the tasks per file/issue type combination and per '/'-separated tag.

        Returns:
            ``(type_dict, tag_dict)``.

        Raises:
            ValueError: If a task's ``type`` matches none of the four known combinations.
        """
        combos = [
            ("single-file", "single-issue"),
            ("single-file", "multi-issue"),
            ("multi-file", "single-issue"),
            ("multi-file", "multi-issue"),
        ]
        type_dict = {f"{files_kind}_{issue_kind}": 0 for files_kind, issue_kind in combos}
        tag_dict = {}
        for item in self.task_meta:
            for files_kind, issue_kind in combos:
                if files_kind in item["type"] and issue_kind in item["type"]:
                    type_dict[f"{files_kind}_{issue_kind}"] += 1
                    break
            else:
                print(f"Unknown type: {item['type']} {item['id']} {item['platform']}")
                raise ValueError("Unknown type")
            for tag in item["tags"]:
                for split_tag in tag.split("/"):
                    tag_dict[split_tag] = tag_dict.get(split_tag, 0) + 1
        return type_dict, tag_dict

    def _measure_issue_files(self, id, scenario, progress, including_non_text_files, calculate_token_len):
        """
        Count the files of one task and the tokens of their readable text.

        How each kind of file is handled:
        - A zip archive counts as one file. Its text members are read in memory, never
          extracted to disk, and tokenized.
        - PDFs and CSVs go through the project readers.
        - Other text files are decoded with a list of fallback encodings.
        - Non-text files are counted but not tokenized.

        Names of non-text files are appended to ``including_non_text_files[id]``.

        Returns:
            ``(file_number, file_content_length)``.
        """
        import zipfile

        file_number = 0
        file_content_length = 0
        for file in self.get_files(id, flat=True):
            progress.set_description(f"Processing files: {file} at {id} / {scenario}")
            if zipfile.is_zipfile(file):
                file_number += 1
                with zipfile.ZipFile(file, 'r') as zip_ref:
                    for zip_info in zip_ref.infolist():
                        if is_non_text_file(zip_info.filename):
                            including_non_text_files.setdefault(id, []).append(zip_info.filename)
                            continue
                        with zip_ref.open(zip_info) as f:
                            file_content_length += calculate_token_len(str(f.read()))
            elif file.endswith('.pdf'):
                file_content_length += calculate_token_len(get_pdf_context(file))
                file_number += 1
            elif file.endswith('.csv'):
                file_content_length += calculate_token_len(get_csv_context(file))
                file_number += 1
            elif not is_non_text_file(file):
                # NOTE: 'iso-8859-1' (an alias of 'latin-1') decodes any byte sequence,
                # so the entries after it are effectively never reached.
                encodings_to_try = ['utf-8', 'iso-8859-1', 'windows-1252', 'latin-1']
                for encoding in encodings_to_try:
                    try:
                        with open(file, 'r', encoding=encoding) as f:
                            file_content_length += calculate_token_len(f.read())
                        file_number += 1
                        break
                    except UnicodeDecodeError:
                        continue
                else:
                    print(f"Warning: Could not decode file {file} with any of the attempted encodings.")
            else:
                # Non-text files (e.g. images, audio) are counted without adding tokens.
                including_non_text_files.setdefault(id, []).append(file)
                file_number += 1
        return file_number, file_content_length

    def _store_file_stats(self, stats_by_id):
        """Write ``file_number`` and ``file_content_length`` of each measured task back into ``task_meta.json``."""
        task_meta_path = os.path.join(self.base_path, "task_meta.json")
        with open(task_meta_path, "r", encoding="utf-8") as f:
            meta_data = json.load(f)

        for item in meta_data["tasks"]:
            stats = stats_by_id.get(item["id"])
            if stats is not None:
                item["file_number"] = stats["file_number"]
                item["file_content_length"] = stats["file_content_length"]

        # Dump the whole document so top-level keys other than "tasks" are preserved.
        with open(task_meta_path, "w", encoding="utf-8") as f:
            json.dump(meta_data, f, indent=4)

    @staticmethod
    def _plot_issue_distributions(uploaded_files_lens, uploaded_file_content_lens):
        """Show log-scale histograms and KDE plots of the files and tokens per issue."""
        log_uploaded_files_lens = np.log1p(uploaded_files_lens)
        log_uploaded_file_content_lens = np.log1p(uploaded_file_content_lens)

        plt.hist(log_uploaded_files_lens, bins=5, edgecolor='black')
        plt.xlabel('Log(Number of files in an issue example)')
        plt.ylabel('Number of examples')
        plt.title('Distribution of the number of files in an issue example (Log Scale)')
        plt.show()

        plt.hist(log_uploaded_file_content_lens, bins=50, edgecolor='black')
        plt.xlabel('Log(Number of characters in an uploaded file)')
        plt.ylabel('Number of examples')
        plt.title('Distribution of the number of characters in an uploaded file (Log Scale)')
        plt.show()

        # `fill` replaces seaborn's deprecated (and later removed) `shade` argument.
        plt.figure(figsize=(10, 5))
        sns.kdeplot(uploaded_files_lens, fill=True)
        plt.xlabel('Number of files in an issue example')
        plt.ylabel('Density')
        plt.title('KDE Plot of the number of files in an issue example')
        plt.show()

        plt.figure(figsize=(10, 5))
        sns.kdeplot(uploaded_file_content_lens, fill=True)
        plt.xlabel('Number of characters in an uploaded file')
        plt.ylabel('Density')
        plt.title('KDE Plot of the number of characters in an uploaded file')
        plt.show()

    def _report_file_statistics(self, scenarios):
        """Measure every task of ``scenarios``, print the aggregates, persist them and plot them."""
        from tqdm import tqdm
        from pipeline.utils import calculate_token_len

        uploaded_files_lens = []
        uploaded_file_content_lens = []
        stats_by_id = {}
        including_non_text_files = {}
        # Task ids to skip while measuring.
        excluded_id = []
        total_issues = 0
        for scenario in scenarios:
            print_colored('====================', 'green')
            print_colored(f'Scenario: {scenario}', 'green')
            documents = self.get_documents(scenario)
            print_colored(f"Number of documents: {len(documents)}", 'blue')
            ids = tqdm([id for id in self.list_file_ids(scenario) if id not in excluded_id], desc="Processing files")
            for id in ids:
                file_number, file_content_length = self._measure_issue_files(
                    id, scenario, ids, including_non_text_files, calculate_token_len
                )
                stats_by_id[id] = {
                    "file_number": file_number,
                    "file_content_length": file_content_length,
                }
                uploaded_files_lens.append(file_number)
                uploaded_file_content_lens.append(file_content_length)
            print_colored(f"Number of issue examples: {len(ids)}", 'blue')
            total_issues += len(ids)

        print_colored('====================', 'green')
        print_colored(f'Total number of issues in DQ-Bench: {total_issues}', 'green')
        print_colored('====================', 'green')
        print(uploaded_file_content_lens[:5])
        print(sorted(uploaded_files_lens, reverse=True)[:5])
        print(len(uploaded_files_lens))
        n_issues = len(uploaded_files_lens)
        # Report 0 for an empty benchmark instead of raising ZeroDivisionError.
        avg_files = sum(uploaded_files_lens) / n_issues if n_issues else 0
        avg_tokens = sum(uploaded_file_content_lens) / n_issues if n_issues else 0
        print_colored(f'Average number of files in an issue example: {avg_files}', 'green')
        print_colored(f'Average number of tokens in an issue example: {avg_tokens}', 'green')
        print_colored(f'Number of issues with non-text files: {len(including_non_text_files)}', 'green')
        print_colored(f'Non-text files in an issue example: {including_non_text_files.keys()}', 'green')

        self._store_file_stats(stats_by_id)
        self._plot_issue_distributions(uploaded_files_lens, uploaded_file_content_lens)

    def report_statistics(self, verbose = False):
        """
        Report the statistics of the benchmark.

        Always prints the following, and writes the tag counts to
        ``tag_distribution.json`` in the current working directory:
        - the issue-type distribution;
        - the tag distribution;
        - the number of task folders per scenario.

        With ``verbose=True`` it also does the following:
        - measures the files and tokens of every task;
        - writes ``file_number`` and ``file_content_length`` back into ``task_meta.json``;
        - shows distribution plots.

        Returns:
            A dict mapping each scenario name to its list of task ids (folder names).

        Raises:
            ValueError: If a task's ``type`` matches none of the four known combinations.
        """
        scenarios = self.get_scenarios()
        print("Task Meta Length: ", len(self.task_meta))
        type_dict, tag_dict = self._count_types_and_tags()

        print_colored("The distribution of the types of issues in the benchmark:", 'green')
        for key, value in type_dict.items():
            print_colored(f"{key}: {value}", 'blue')
        print_colored('====================', 'green')

        print_colored("The distribution of the tags of issues in the benchmark:", 'green')
        # .get: a tag that no task uses counts as 0 instead of raising KeyError.
        print("# data-problem: ", tag_dict.get("data-problem", 0))
        print("# document-problem: ", tag_dict.get("document-problem", 0))
        print("# infrastructure-problem: ", tag_dict.get("infrastructure-problem", 0))
        print("# ethical-risk-problem: ", tag_dict.get("ethical-legal-risk", 0))
        with open("tag_distribution.json", "w") as f:
            json.dump(tag_dict, f)

        files_dict = {}
        for scenario in scenarios:
            print_colored('====================', 'green')
            print_colored(f'Scenario: {scenario}', 'green')
            ids = self.list_file_ids(scenario)
            print_colored(f"Number of issue examples: {len(ids)}", 'blue')
            files_dict[scenario] = ids

        if verbose:
            self._report_file_statistics(scenarios)

        return files_dict

    def get_output_format(self):
        """
        Print and return an example of the Curator output format.
        """
        print_colored('==================== An example of Curator output format', 'green')

        example_format = """
 <example>
{
    “issue”: “Wrong name in README”, # Identified Problem in Specific File
    “evidence”: ” # My cool task This is a description of my cool task... in README”, # The specific content found in the file that supports the identified issue. This should be a direct quote of the context where you find the issue.
    “description”: “The name of the task in the README is incorrect. It should be ‘My awesome task’ instead of ‘My cool task’.“, # A detailed explanation of the issue discovered, referencing specific content found in the file. Highlight how it deviates from expected standards or instructions provided in the <hint>

},
...
</example>
        """
        print_colored(example_format, 'blue')
        print_colored('====================', 'green')

        return example_format

    def format_output(self, issue, evidence, description):
        """
        Format one finding as an indented JSON string, print it and return it.
        """
        output = json.dumps(
            {
                "issue": issue,
                "evidence": evidence,
                "description": description,
            },
            indent=4,
        )
        print_colored('====================', 'green')
        print_colored('Formatted output:', 'green')
        print(output)
        print_colored('====================', 'green')
        return output
    
# Judge models accepted by Evaluator: ids containing "Llama" use the local Llama 3
# evaluator; ids containing "gpt" use the OpenAI evaluator.
EVAL_MODEL_ID = Literal[
    "Meta-Llama-3-70B-Instruct",
    "gpt-4-0125-preview",
    "gpt-3.5-turbo",
]

class Evaluator:
    def __init__(self, test_model_id: Literal["gpt-4-0125-preview"], eval_model_id: EVAL_MODEL_ID):
        
        self.manager = BenchmarkManager()
        self.eval_model_id = eval_model_id
        if "Llama" in eval_model_id:
            eval_with_metrics_path = "prompts/evaluate_with_metrics_llama3.txt"
            self.evaluator =  Llama3Evaluator(model_id = self.eval_model_id, cache_dir = '/home/shared/Llama3/')
        elif "gpt" in eval_model_id:
            self.model_id = eval_model_id
            eval_with_metrics_path = "prompts/evaluate_with_metrics.txt"
            self.evaluator =  OpenAIEvaluator(model_id = self.eval_model_id, api_key=get_openai_api_key())
            
        with open (eval_with_metrics_path, "r") as f:
            self.input_msg = f.read()
            
        self.test_model_id = test_model_id

    def run_evaluation(self, stamp = time.strftime("%Y-%m-%d-%H-%M", time.localtime()), test_stamp = None,  label_path = False):
        """
        run the evaluation for the Curator using Evaluator. 
        The results from Curator is under pipeline/output/{model}/{scenario}/{id}/hint_level_{hint_level}/output.txt
        We use this function to get the results shown in Fig.4 in the paper.
        
        NOTE: If you want to use this function, you need to put your results at the correct path, and change the test_model_id in the __init__ function.

        Args:
            stamp: the stamp to distinguish your run. Defaults to time.strftime("%Y-%m-%d-%H-%M", time.localtime()).
            label_path: if set to true, then evaluate the performance of the evaluator with labels annotated by human. Defaults to False.
        """

        if label_path and label_path != "":
            comp_w_label = True
            with open (label_path, "r") as f:
                labels = json.load(f)
            print_colored("====================================", "green")
            print_colored("loaded labels", "green")
        else:
            print_colored("====================================", "green")
            print_colored("no labels provided", "green")
            comp_w_label = False
            
            
        scenarios = self.manager.get_scenarios()


        res = {}

        cnt = 0
        total_cost = 0
        total_time  = 0
        
        
        correct = 0
        soft_correct = 0
        dir = f"pipeline/eval_log/eval={self.eval_model_id}/test={self.test_model_id}/{stamp}"
        os.makedirs(dir, exist_ok=True)
        # used for incremental evaluation
        json_path = f"pipeline/eval_log/eval={self.eval_model_id}/test={self.test_model_id}/{stamp}/res.json"
        stats_path = f"pipeline/eval_log/eval={self.eval_model_id}/test={self.test_model_id}/{stamp}/stats.json"
        wrong_path = f"pipeline/eval_log/eval={self.eval_model_id}/test={self.test_model_id}/{stamp}/wrong.json"
        res = {}
        stats = {"correct": 0, "soft_correct": 0, "total": 0}
        wrong = {"strict": {}, "soft": {}}
        completed = set()
        if os.path.exists(json_path):
            print_colored("\n====================================", "green")
            print_colored("loaded previous results \n", "green")
            print_colored("====================================", "green")
            with open(json_path, 'r') as f:
                res = json.load(f)
            
            for scene in res:
                if scene == "stats": # skip the stats
                    continue
                for id in res[scene]:
                    for hint in res[scene][id]:
                        completed.add((scene, id, hint))
        print_colored("====================================", "green")
        print("already evaluated ", len(completed))   
        print_colored("====================================", "green")   
        if comp_w_label:
            if os.path.exists(stats_path):
                with open(stats_path, 'r') as f:
                    stats = json.load(f)
                correct = stats["correct"]
                soft_correct= stats["soft_correct"]
                cnt = stats["total"]
            
            if os.path.exists(wrong_path):
                with open(wrong_path, 'r') as f:
                    wrong = json.load(f)

        if test_stamp:
            test_output_dir = f"pipeline/output/{self.test_model_id}/{test_stamp}"
        else:
            test_output_dir = f"pipeline/output/{self.test_model_id}"
        
        for scenario in scenarios:
            test_output_scenario_dir = f"{test_output_dir}/{scenario}"
            # the output from the test model
            if not os.path.exists(test_output_scenario_dir):
                continue
            if scenario not in res:    
                res[scenario] = {}
            ids = [id for id in os.listdir(test_output_scenario_dir) if not id.startswith(".")]
            
            from tqdm import tqdm
            id_bar = tqdm(ids, total=len(ids))
                
            for id in id_bar:
                issue, hints = self.manager.get_complete_info(id)
                for hint_level in range(-1,3):
                    hint = hints[hint_level + 1]

                    if comp_w_label: # if not in the label, then skip
                        if not find_id_in_label(labels, scenario, id, str(hint_level)): # if the hint level is not annotated
                            continue
                        
                    print_colored(f"evaluating {scenario} {hint_level} {id}", "green")
                    if id not in res[scenario]:
                        res[scenario][id] = {}

                    if comp_w_label:
                        cnt += 1
                    
                    if (scenario, id, str(hint_level)) in completed: # if the hint level is already evaluated
                        # print(f"skip gen evaluation for {scenario} {hint_level} {id}")
                        pass
                    else:
                        test_model_output_path = f"{test_output_scenario_dir}/{id}/hint_level_{hint_level}/output.txt"

                        if os.path.exists(test_model_output_path):
                            with open(test_model_output_path, "r") as f:
                                agent_answer = f.read()
                        else:
                            print("not found", test_model_output_path)
                            continue
                        
                        save_path = f"pipeline/eval_log/eval={self.eval_model_id}/test={self.test_model_id}/{stamp}/{scenario}/{id}/hint_level_{hint_level}"
                        if not os.path.exists(save_path):
                            os.makedirs(save_path, exist_ok=True)
                        
                        prompt = self.input_msg.replace("<ISSUE>", issue).replace("<HINT>", hint).replace("<ANSWER>", agent_answer)
                        
                        with open(os.path.join(save_path, "prompt.txt"), "w") as f:
                            f.write(prompt)
                            
                        prompt = prompt.split("--------------------------------")
                        sys_msg = prompt[0]
                        usr_msg = "--------------------------------".join(prompt[1:])
                        
                        start_time = time.time()
                        result, cost = self.evaluator.evaluate(usr_msg = usr_msg, sys_msg = sys_msg, save_path = save_path, first_round_voters=2)
                        end_time = time.time() - start_time
                        
                        total_cost += cost
                        total_time += end_time
                        
                        res[scenario][id][str(hint_level)] = result
                        
                        print({"result": result, "total cost": total_cost, "total time": total_time, "average_cost": div_by_zero(total_cost,cnt), "average_time": div_by_zero(total_time,cnt)})
                    
                    if comp_w_label:
                        result = res[scenario][id][str(hint_level)]
                        human_label = find_id_in_label(labels, scenario, id, str(hint_level), bool=False)
                        
                        if human_label == result:
                            correct += 1
                        elif (result == "success" and human_label == "partially") or (result == "partially" and human_label == "success"):
                            soft_correct += 1
                            wrong["soft"][f"{scenario}-{id}-{hint_level}"] = {"result": result, "human_label": human_label}
                        else:
                            wrong["strict"][f"{scenario}-{id}-{hint_level}"] = {"result": result, "human_label": human_label}

                        stats = {
                            "correct": correct,
                            "soft_correct": soft_correct,
                            "total": cnt
                        }
                    
                        id_bar.set_postfix({"result": result, "correct": correct, "soft_correct": soft_correct, "total": cnt, "total cost": total_cost, "total time": total_time, "average_cost": div_by_zero(total_cost,cnt), "average_time": div_by_zero(total_time,cnt)})

                    
                
                # update the json file for each id
                with open(json_path, "w") as f:
                    json.dump(res, f)
                
                with open(stats_path, "w") as f:
                    json.dump(stats, f)
                    
                if comp_w_label:
                    with open(f"pipeline/eval_log/eval={self.eval_model_id}/test={self.test_model_id}/{stamp}/wrong.json", "w") as f:
                        json.dump(wrong, f)
                        
        if comp_w_label:
            print("correct:", correct)
            print("soft correct:", soft_correct)
            print("total", cnt)
            if cnt > 0:
                print("Accuracy:", correct/cnt)
                print("Soft Accuracy:", (correct + soft_correct)/cnt)

    def run_single_evaluation(self, id, hint_level, input_path, save_dir):
        """
        this function is used to evaluate a single test case, for the convenience of benchmark users

        Args:
            id (str): the id of the test case
            hint_level (int): the hint level used for the test case, from 0 to 3
            input_path: (str): the path of the input file from Curator
            save_dir (str): the directory to save the results
        """
        issue, hints = self.manager.get_complete_info(id)
        if hint_level not in range(0,4):
            raise ValueError("hint_level should be 0, 1, 2, or 3")
        hint = hints[hint_level + 1]
        
        if not os.path.exists(input_path):
            raise FileNotFoundError(f"input_path {input_path} not found")
        
        with open(input_path, "r") as f:
            agent_answer = f.read()
        
        prompt = self.input_msg.replace("<ISSUE>", issue).replace("<HINT>", hint).replace("<ANSWER>", agent_answer)
        
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, "prompt.txt"), "w") as f:
            f.write(prompt)
        
        prompt = prompt.split("--------------------------------")
        sys_msg = prompt[0]
        usr_msg = "--------------------------------".join(prompt[1:])
        

        result, cost = self.evaluator.evaluate(usr_msg = usr_msg, sys_msg = sys_msg, save_path = save_dir, first_round_voters=2)
        # result belong to "failed", "partially", "success"
        
        return result, cost

        
        
     
    
if __name__ == '__main__':
    # Smoke run: constructing the manager pre-fetches the available scenarios and
    # the task metadata, then the expected Curator output format is printed.
    bench = BenchmarkManager()
    bench.get_output_format()

    # More usage examples (uncomment to try):
    #
    # List all scenarios (already fetched during initialization);
    # verbose=True prints the list in a more readable format.
    # scenarios = bench.get_scenarios(verbose=True)
    # print(scenarios)
    #
    # Access documents for a specific scenario.
    # docs = bench.get_documents('GLI', True)
    # print(docs)
    #
    # List files of a specific dataset in a scenario.
    # files = bench.get_files("BigBench", "9020050d-0834-4eea-9e73-fe3bd7606fca", verbose=True)
    # print(files)


    

