import os
import shutil
from argument import args 
from dotenv import load_dotenv
from gpt_utils import get_gpt_response, get_llama_response

def post_process(path=None):
    """Count how often each label appears in a labeling-result file.

    Only lines containing ``"Answer:"`` are considered. The label is the text
    after the first ``"Answer:"`` marker, so both ``"Answer: {label}"`` and
    ``"Image file-... ; Answer: {label}"`` lines are accepted regardless of
    the whitespace around the ``;`` separator or the marker.

    Labels are normalised by lower-casing and stripping surrounding
    whitespace and then periods, so ``"Cat."`` and ``"cat"`` are counted as
    the same label.

    Args:
        path: File to read. Defaults to ``args.labeling_result_path``.

    Returns:
        A plain ``dict`` mapping each normalised label to its occurrence
        count, in first-seen order. (A plain dict is returned on purpose: its
        repr is embedded verbatim in the clustering prompt.)
    """
    if path is None:
        path = args.labeling_result_path

    answer_list = {}
    with open(path, 'r') as answers:
        for answer in answers:
            if "Answer:" not in answer:
                continue
            print(answer)
            # Everything after the marker is the label. Splitting on the
            # marker itself (rather than dropping the first space-separated
            # word) avoids keeping "answer:" in the label when the line has a
            # space after ';' or leading whitespace, and handles "Answer:x".
            _, _, raw_label = answer.partition("Answer:")
            real_label = raw_label.lower().strip().strip(".")
            answer_list[real_label] = answer_list.get(real_label, 0) + 1

    print("Number of distinct labels: ", len(answer_list))

    # sanity check
    print(answer_list)
    print(sum(answer_list.values()))

    return answer_list


if __name__ == "__main__":
    load_dotenv()
    gpt_url = os.getenv("URL")
    api_key = os.getenv("API_KEY")
    user = os.getenv("USER")
    model = os.getenv("MODEL")

    # Count the distinct labels produced by the labeling step.
    answer_list = post_process()

    # Read the clustering system prompt and fill in its placeholders. The
    # prompt file is closed before any network request is made.
    with open(args.clustering_system_prompt_path, 'r') as prompt_file:
        system_prompt = prompt_file.read()
    system_prompt = system_prompt.replace("[__NUM_CLASSES_CLUSTER__]", str(args.num_classes))
    system_prompt = system_prompt.replace("[__LEN__]", str(len(answer_list)))

    # The user prompt carries the label -> count mapping and the target
    # number of clusters.
    user_prompt = f"list of labels: {answer_list}\n"
    user_prompt += f"num_classes: {args.num_classes}"

    if args.llama:
        # The original passed an undefined name `url` here (NameError). The
        # only endpoint this script loads is the URL environment variable.
        # NOTE(review): confirm the llama server is reached via the same URL.
        response = get_llama_response(system_prompt, user_prompt, gpt_url)
    else:
        response = get_gpt_response(system_prompt, user_prompt, gpt_url, api_key, user, model)
        if response == "ERROR_CONTEXT_LENGTH":
            # Prompt too long for the default model: retry once with the
            # larger-context credentials (gpt-4-32k per the original note).
            # Only the GPT path has alternative credentials, so the llama
            # path is not retried with identical arguments.
            response = get_gpt_response(
                system_prompt,
                user_prompt,
                gpt_url,
                os.getenv("API_KEY_32K"),
                os.getenv("USER_32K"),
                os.getenv("MODEL_32K"),
            )

    # Save the raw model response.
    with open(args.clustering_result_path, 'w') as result_file:
        result_file.write(response)