from pyexpat import model
import re
from sys import exc_info
from .models import OpenAIModel, OpenAIModelRephraser
import time
from pprint import pprint
from benchmark.api import BenchmarkManager
import os
from tqdm import tqdm
import time
from argparse import ArgumentParser
from .get_input import get_input, load_human_labels

# Extensions (plus the bare file name "LICENSE") treated as supported inputs;
# samples containing a file with any other extension are reported in the run log.
support_file_format = [
    "c", "cpp", "css", "csv", "docx", "gif", "html", "java", "jpeg", "jpg",
    "js", "json", "md", "pdf", "php", "png", "pptx", "py", "rb", "tar",
    "tex", "ts", "txt", "xlsx", "xml", "zip", "yaml", "LICENSE", "ipynb",
]

# NOTE: set to True to also save a rephrased copy of every output next to the
# normal one, so you can check whether the evaluator has a length bias.
rephrase_comparision = False
rephraser = (
    OpenAIModelRephraser(model_id="gpt-4o-2024-08-06") if rephrase_comparision else None
)

def _parse_args():
    """Build the command-line parser and return the parsed arguments namespace."""
    parser = ArgumentParser(description=__doc__)
    parser.add_argument(
        "--max_files",
        type=int,
        help="The maximum number of files that can be processed at once",
        default=20,
    )
    parser.add_argument(
        "--scenarios",
        nargs='+',
        help="The scenarios to process, usage ex. --scenarios Kaggle TensorFlow HuggingFace",
        default=[],
    )
    parser.add_argument(
        "--model_name",
        type=str,
        help="The model name to use",
        default="gpt-4-0125-preview",  # gpt-4-0125-preview, gpt-3.5-turbo (0125), gpt-4o-mini (2024-07-18)
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        help="The output directory",
        default="pipeline/output",
    )
    parser.add_argument(
        "--stamp",
        type=str,
        help="The stamp to distinguish your run",
        default=None,
    )
    parser.add_argument(
        "--label_path",
        type=str,
        help="The path to the label file. If not provided, then we won't compare the results with the labels",
        default=None,
    )
    return parser.parse_args()


def _build_model(model_name):
    """Instantiate the inference backend selected by ``--model_name``.

    Raises:
        NotImplementedError: for ``"langchain_agent"``, which is not wired up.
        ValueError: for any other name without a backend. Previously such names
            fell through and crashed later with an UnboundLocalError on ``model``.
    """
    if "gpt" in model_name:
        return OpenAIModel(model_name=model_name)
    if model_name == "langchain_agent":
        raise NotImplementedError("langchain_agent is not implemented")
    raise ValueError(f"Unsupported model_name: {model_name!r}")


def _hint_dir(root, scenario, sample_id, hint_level):
    """Return ``<root>/<scenario>/<sample_id>/hint_level_<hint_level>``."""
    return os.path.join(root, scenario, sample_id, f"hint_level_{hint_level}")


def _write_output(root, scenario, sample_id, hint_level, text):
    """Write ``text`` to ``output.txt`` in the sample's hint-level dir, creating it if needed."""
    out_dir = _hint_dir(root, scenario, sample_id, hint_level)
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "output.txt"), "w") as f:
        f.write(text)


def _pending_ids(scenario_dir, ids, excluded_ids):
    """Return the ids in ``ids`` that still need to be run for one scenario.

    Ids in ``excluded_ids`` are always dropped. Outside rephrase-comparison
    mode, an id whose directory under ``scenario_dir`` holds exactly 4 entries
    is treated as finished and dropped too (presumably one entry per hint
    level -- TODO confirm). The order of ``ids`` is preserved.
    """
    excluded = set(excluded_ids)
    if not os.path.exists(scenario_dir):
        return [i for i in ids if i not in excluded]

    processed = set()
    if not rephrase_comparision:  # rephrase mode revisits every id
        for entry in os.listdir(scenario_dir):
            entry_path = os.path.join(scenario_dir, entry)
            # Skip hidden entries (e.g. .DS_Store) and stray files; listing a
            # plain file would raise NotADirectoryError.
            if entry.startswith(".") or not os.path.isdir(entry_path):
                continue
            if len(os.listdir(entry_path)) == 4:
                processed.add(entry)
    pprint(f"Number of processed ids: {len(processed)}")
    pending = [i for i in ids if i not in processed and i not in excluded]
    pprint(f"Number of inputs_id: {len(pending)}")
    print(pending[:5])
    return pending


def _rephrase_existing(output_dir, output_rephrase_dir, scenario, sample_id, hint_level, rephrase_prompt):
    """Rephrase an already-saved model output and save it under the rephrase tree."""
    src = os.path.join(_hint_dir(output_dir, scenario, sample_id, hint_level), "output.txt")
    with open(src, "r") as f:
        output = f.read()
    print("rephrase_prompt: ", rephrase_prompt)
    print("output: ", output)
    output_rephrase, _ = rephraser.run(system_msg=rephrase_prompt, user_msg=output)
    print("output_rephrase: ", output_rephrase)
    _write_output(output_rephrase_dir, scenario, sample_id, hint_level, output_rephrase)


def main():
    """Run the model over every pending benchmark sample and hint level.

    For each scenario (all of them unless ``--scenarios`` is given), this reads
    each written input prompt and calls the model with the sample's files. The
    response goes to
    ``<output_dir>/<model_name>[/<stamp>]/<scenario>/<id>/hint_level_<n>/output.txt``.
    Hint levels whose output directory already exists are skipped, so an
    interrupted run can be resumed. Samples with more than ``--max_files``
    files are skipped. Those samples, plus samples containing unsupported
    file types, are appended to ``log.txt`` in the output directory.

    When the module-level ``rephrase_comparision`` flag is set, every output
    is also rephrased with ``rephraser`` and saved under a sibling
    ``<output_dir>_succinct`` tree, to check the evaluator for length bias.

    Raises:
        NotImplementedError, ValueError: for unsupported ``--model_name``
            values (see ``_build_model``).
    """
    args = _parse_args()

    rephrase_prompts = ["""You are provided with a description text, and your task is to make it more detailed and elaborate. Ensure that you do not add any new information or content that was not present in the original context. Your goal is to expand upon the existing meaning without altering it. Additionally, maintain the same format as the original answer. Don't mention this prompt in your response.
    =================== Start:""",
    """You are provided with a description text. Your task is to rephrase it concisely while ensuring all essential information is retained. Do not add new information or omit any key details from the original context. Remember to maintain and keep the same format as the original answer. When answer is already short, you don't have to condense it too much. Do not mention this prompt in your response.
    =================== Start:"""
    ]
    # Index into rephrase_prompts: 0 = elaborate, 1 = condense.
    order = 1

    model_name = args.model_name
    if args.stamp:
        output_dir = os.path.join(args.output_dir, model_name, args.stamp)
    else:
        output_dir = os.path.join(args.output_dir, model_name)
    # Only used when rephrase_comparision is set.
    output_rephrase_dir = output_dir + "_succinct"

    # Fail fast on an unsupported model before touching the filesystem.
    model = _build_model(model_name)
    os.makedirs(output_dir, exist_ok=True)

    too_much_file = []     # "<scenario>-<id>" skipped for exceeding --max_files
    not_supported_id = []  # "<scenario>-<id>" with at least one unsupported file type

    manager = BenchmarkManager()
    scenarios = args.scenarios or manager.get_scenarios()

    excluded_ids = []  # ids never to run; empty by default

    # With --label_path, only (id, hint_level) pairs present in the human
    # labels are run, to cut down the cost. Keyed on the flag rather than on
    # len(labels), so an empty label file does not silently run everything.
    labels = None
    if args.label_path is not None:
        pprint("Only process the labeled data")
        labels = load_human_labels(path=args.label_path)

    total_cost = 0
    total_model_time = 0
    process_sample_num = 0
    for scenario in scenarios:
        print(f"Processing {scenario}")
        ids = manager.get_written_input_ids(scenario)
        if not ids:
            # Inputs not generated yet: build them, then query again.
            get_input()
            ids = manager.get_written_input_ids(scenario)
        inputs_id = _pending_ids(os.path.join(output_dir, scenario), ids, excluded_ids)

        id_bar = tqdm(inputs_id, total=len(inputs_id))
        for sample_id in id_bar:
            files = manager.get_files(sample_id, flat=True)
            # Unsupported types are only logged, not skipped. Record each
            # sample once (the old per-file loop appended duplicates).
            if any(path.split(".")[-1] not in support_file_format for path in files):
                not_supported_id.append(f"{scenario}-{sample_id}")
            if len(files) > args.max_files:
                too_much_file.append(f"{scenario}-{sample_id}")
                continue

            for input_path in manager.get_written_input_of_id(scenario, sample_id):
                # Assumes prompt files are named "..._<level>.<ext>".
                hint_level = input_path.split("_")[-1].split(".")[0]
                if labels is not None and (sample_id, hint_level) not in labels:
                    pprint(f"skip {sample_id}, {hint_level} not in labels")
                    continue

                tag = f"{scenario} {sample_id} hint_level_{hint_level}"
                if rephrase_comparision:
                    pprint(f"Processing {tag} in rephrase mode")
                if os.path.exists(_hint_dir(output_dir, scenario, sample_id, hint_level)):
                    if rephrase_comparision and not os.path.exists(
                        _hint_dir(output_rephrase_dir, scenario, sample_id, hint_level)
                    ):
                        # Only the rephrased copy is missing: rephrase the
                        # saved output. Previously execution fell through and
                        # re-queried the model, overwriting the original output.
                        _rephrase_existing(
                            output_dir, output_rephrase_dir, scenario,
                            sample_id, hint_level, rephrase_prompts[order],
                        )
                    else:
                        pprint(f"skip {tag}")
                    continue

                id_bar.set_description(f"Processing {tag}")
                with open(input_path, "r") as f:
                    prompt = f.read()
                start = time.time()
                try:
                    res = model.run(input=prompt, file_paths=files)
                except Exception as e:  # best effort: report and move on
                    # `res` used to be left unbound, or to hold the previous
                    # sample's result, which was then saved under this sample.
                    print(f"Error in {tag}")
                    print(e)
                    continue
                if res is None:
                    pprint(f"skip {tag}, since the model return None")
                    continue
                output, cost = res

                if rephrase_comparision:
                    output_rephrase, _ = rephraser.run(
                        system_msg=rephrase_prompts[order], user_msg=output
                    )

                t = time.time() - start
                total_model_time += t
                _write_output(output_dir, scenario, sample_id, hint_level, output)
                if rephrase_comparision:
                    _write_output(output_rephrase_dir, scenario, sample_id, hint_level, output_rephrase)

                total_cost += cost
                process_sample_num += 1
                id_bar.set_postfix(
                    cost=total_cost,
                    average_cost=total_cost / process_sample_num,
                    time=t,
                    average_time=total_model_time / process_sample_num,
                )

    # Append this run's skipped / flagged samples to the run log. Each line
    # ends with a newline so consecutive runs do not merge onto one line.
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "log.txt"), "a") as f:
        f.write(f"too much file: {too_much_file}\n")
        f.write(f"file-format not support:{not_supported_id}\n")

        
if __name__ == "__main__":
    # Example invocation:
    #   python -m pipeline.run_inference --label_path="pipeline/label_more.json"
    main()