import json
import os
import threading
from argparse import ArgumentParser
from concurrent.futures import ThreadPoolExecutor, as_completed

from tqdm import tqdm

from get_res import mixtral_8x22b, GPT4, clean_and_convert_to_dict, find_answer, llama3_70b

def get_free_res(question, golden_answer, response, model='gpt-4'):
    """Ask an LLM judge to score a free-form response against a gold answer.

    The judge is instructed (see ``system_prompt``) to reply in JSON with an
    'Evaluation' text and a 1-5 'Score'.

    Args:
        question: The question posed to the model under test.
        golden_answer: The reference ("gold standard") answer.
        response: The answer produced by the model under test.
        model: Judge backend; one of 'gpt-4', 'mixtral', 'llama3'.

    Returns:
        Whatever the selected ``get_res`` backend function returns for this
        prompt (``process_item`` treats ``None`` as "no verdict").

    Raises:
        ValueError: If ``model`` is not a supported backend.
    """
    system_prompt = """You are an impartial judge. I will provide you with a question, a 'gold standard' answer, and a response that needs evaluation. Your task is to assess the quality of the response in comparison to the 'gold standard' answer. Please adhere to the following guidelines:

1. Start your evaluation by comparing the response to the 'gold standard' answer. Offer a brief explanation highlighting similarities and differences, focusing on relevance, accuracy, depth, and level of detail.
2. Conclude your evaluation with a score from 1 to 5, where 1 indicates the response is mostly irrelevant to the 'gold standard' answer, and 5 indicates it is very similar or equivalent.
3. Present your findings in JSON format, using 'Evaluation' for your textual analysis and 'Score' for the numerical assessment.
4. Ensure objectivity in your evaluation. Avoid biases and strive for an even distribution of scores across the spectrum of quality. Your scoring must be as rigorous as possible and adhere to the following rules:
- Overall, the higher the quality of the model's response, the higher the score, with factual accuracy and meeting user needs being the most critical dimensions. These two factors largely dictate the final composite score.
- If the model's response is irrelevant to the question, contains fundamental factual errors, or generates harmful content, the total score must be 1.
- If the model's response has no severe errors and is essentially harmless, but of low quality and does not meet user needs, the total score should be 2.
- If the model's response generally meets user requirements but performs poorly in certain aspects with medium quality, the total score should be 3.
- If the model's response is close in quality to the reference answer and performs well in all dimensions, the total score should be between 4.
- Only when the model's response significantly surpasses the reference answer, fully addresses the user's problem and all needs, and nearly achieves a perfect score in all dimensions, can it receive a score between 5.
- As an example, the golden answer could receive an 4-5.
"""

    prompt = f"""
Here is the response for you to judge:
Question: {question}
Golden Answer: {golden_answer}
Response: {response}

Now, directly output your response in json format.
"""
    # NOTE(review): the third argument is True here and False for the MCQA
    # judge; presumably it requests JSON-formatted output from get_res — confirm.
    if model == 'gpt-4':
        result = GPT4(system_prompt, prompt, True)
    elif model == 'mixtral':
        result = mixtral_8x22b(system_prompt, prompt, True)
    elif model == 'llama3':
        result = llama3_70b(system_prompt, prompt, True)
    else:
        # Without this branch an unknown model leaves `result` unbound and the
        # return below fails with an opaque UnboundLocalError.
        raise ValueError(f"Unsupported judge model: {model!r}")
    return result
    

def get_mcqa_res(question, answer, response, model):
    """Ask an LLM judge whether a multiple-choice response matches the answer.

    The judge is instructed to output only "Yes" or "No", with two few-shot
    examples embedded in the prompt.

    Args:
        question: The multiple-choice question, options included.
        answer: The correct option (possibly with extra text, e.g. "[[B]] ...").
        response: The answer produced by the model under test.
        model: Judge backend; one of 'gpt-4', 'mixtral', 'llama3'.

    Returns:
        Whatever the selected ``get_res`` backend function returns for this
        prompt (``process_item`` treats ``None`` as "no verdict").

    Raises:
        ValueError: If ``model`` is not a supported backend.
    """
    system_prompt = """You are a helpful assistant tasked with judging a Multiple Choice Question Answering exercise. 
    I will provide a correct answer with only one option, and a response that requires evaluation. 
    If the response matches the correct answer, simply output "Yes"; If it does not, output "No" 
    Please avoid including any irrelevant information."""

    # Unused few-shot examples, kept for reference:
    # Example 3: 
    # Question: If the user wants to resume the group video call after checking messages, what action should they take? A. Turn their head to the right. B. Close the messaging app interface. C. Say a voice command to switch applications. D. Turn their head to the left.
    # Answer: A
    # Response: B
    # Output: No
    
    # Example 4:
    # Question: What action does the user take to start playing music in the video? A. Closed the music player application B. Moved the music player to a new position C. Clicked the play button D. Adjusted the system volume
    # Answer: [[B]]
    # Response: C
    # Output: No
    prompt = f"""
    Here are some examples:
    Example 1:
    Question: What action is attempted after the error message 'Unfortunately, djay FREE has stopped' appears? A. The device's battery is low. B. The 'djay FREE' app is malfunctioning or has a bug. C. The motivational wallpaper is causing the app to crash. D. There's an incoming phone call that's interrupting the app.
    Answer: [[B]] The message is dismissed with the 'OK' button.
    Response: **Troubleshooting Steps for 'djay FREE has stopped' Error:**
    Output: No
    
    Example 2:
    Question: Based on the GUI video, why might the 'Loading' animation continue without reaching the next stage? A. The user has not yet entered their login credentials. B. There is a system update being installed. C. The server is taking time to authenticate the login credentials. D. The 'Log In' button is malfunctioning.
    Answer: C
    Response: C. The server is taking time to authenticate the login credentials.
    Output: Yes

    Here is the question, answer, and response for you to judge:
    Question: {question}
    Answer: {answer}
    Response: {response}
    
    Now, directly output "Yes" or "No".
    """

    # The third argument is False for every backend (a plain-text Yes/No is
    # expected); llama3 previously relied on llama3_70b's own default instead.
    if model == 'gpt-4':
        result = GPT4(system_prompt, prompt, False)
    elif model == 'mixtral':
        result = mixtral_8x22b(system_prompt, prompt, False)
    elif model == 'llama3':
        result = llama3_70b(system_prompt, prompt, False)
    else:
        # Without this branch an unknown model leaves `result` unbound and the
        # return below fails with an opaque UnboundLocalError.
        raise ValueError(f"Unsupported judge model: {model!r}")
    return result

def split_path(path):
    """Split ``path`` at its last '/' into ``(directory, name)``.

    If ``path`` contains no '/', the whole path is returned as the *first*
    element and the second element is the empty string.
    """
    head, separator, tail = path.rpartition('/')
    if not separator:
        # rpartition puts everything in `tail` when no '/' exists; this
        # function's contract instead returns the path in the first slot.
        return path, ''
    return head, tail

# Serializes appends to the shared output file: process_item runs on several
# worker threads and every call appends to the same path.
_OUTPUT_LOCK = threading.Lock()


def _normalize_mllm_result(value, directly):
    """Coerce ``value['MLLM_result']`` into a dict; return True if it is judgeable.

    String results are parsed with ``clean_and_convert_to_dict``. With
    ``directly`` set, a failed parse falls back to ``find_answer`` and then to
    the raw string, each wrapped as ``{"Answer": ...}``. Without ``directly``,
    an unparsable string means the entry is skipped.
    """
    raw = value['MLLM_result']
    if directly:
        if isinstance(raw, str):
            try:
                parsed = clean_and_convert_to_dict(raw)
                value['MLLM_result'] = {"Answer": parsed["Answer"]}
            except Exception as e:
                print(f"Error parsing JSON: {e}")
                extracted = find_answer(raw)
                answer = extracted if extracted is not None else raw
                value['MLLM_result'] = {"Answer": answer}
        print(value['MLLM_result'])
    elif isinstance(raw, str):
        try:
            parsed = clean_and_convert_to_dict(raw)
        except Exception as e:
            print(f"Error parsing JSON: {e}")
            return False
        # clean_and_convert_to_dict appears to signal failure with this string.
        if parsed == "Error parsing JSON after attempting to fix.":
            return False
        value['MLLM_result'] = parsed
    normalized = value['MLLM_result']
    # A non-dict result, or a dict without "Answer", would otherwise raise in
    # the worker thread and drop the whole item from the output file.
    return isinstance(normalized, dict) and "Answer" in normalized


def process_item(item, selected_keys, model, directly, output_file):
    """Judge each answered sub-question of ``item`` and append it to ``output_file``.

    For every entry of ``item['result']`` that has an "MLLM_result" with an
    "Answer", the judge's return value (if not None) is stored under
    "LLM_Judge". ``item`` is mutated in place.

    Args:
        item: One JSONL record; ``item['result']`` maps question types to
            dicts with 'q', 'a' and optionally 'MLLM_result'.
        selected_keys: Question types scored with the free-form 1-5 judge;
            all other types use the Yes/No multiple-choice judge.
        model: Judge backend name passed through to the judge functions.
        directly: If True, unparsable string results are still judged, using
            ``find_answer`` or the raw string as the answer.
        output_file: Path of the JSONL file the item is appended to.

    Returns:
        The same (mutated) ``item``.
    """
    for key, value in item['result'].items():
        if "MLLM_result" not in value:
            continue
        if not _normalize_mllm_result(value, directly):
            continue
        judge = get_free_res if key in selected_keys else get_mcqa_res
        verdict = judge(value['q'], value['a'], value['MLLM_result']['Answer'], model)
        if verdict is not None:
            value['LLM_Judge'] = verdict
    line = json.dumps(item) + '\n'
    with _OUTPUT_LOCK, open(output_file, 'a') as file:
        file.write(line)
    return item

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--input", type=str, required=True)
    # A file path, or 'auto' to derive one next to the input file.
    parser.add_argument("--output", type=str, default=None)
    parser.add_argument("--start", type=int, default=0)
    # 0 means "through the last record".
    parser.add_argument("--end", type=int, default=0)
    # type=str.lower keeps the old case-insensitive behaviour; choices replaces
    # an assert that would be stripped under `python -O`.
    parser.add_argument("--model", type=str.lower, default='gpt-4',
                        choices=['gpt-4', 'mixtral', 'llama3'])
    parser.add_argument("--directly", action='store_true', default=False)
    parser.add_argument("--workers", type=int, default=8)
    args = parser.parse_args()
    if args.output is None:
        # Fail up front instead of every worker failing on open(None).
        parser.error("--output is required (a file path, or 'auto')")

    with open(args.input, 'r', encoding='utf-8') as file:
        # Skip blank lines (e.g. a trailing newline) that json.loads rejects.
        data = [json.loads(line) for line in file if line.strip()]
    if args.end == 0:
        args.end = len(data)
    if args.output == 'auto':
        # <dir>/llm_judge_<model>_<stem>_<start>_<end><ext>; splitext handles
        # both .json and .jsonl, and an input with no directory part.
        in_dir, in_name = os.path.split(args.input)
        stem, ext = os.path.splitext(in_name)
        args.output = os.path.join(
            in_dir, f"llm_judge_{args.model}_{stem}_{args.start}_{args.end}{ext}")

    selected_keys = ["Sequential-QA", "Prediction", "Description1", "Description2", "Caption", "static QA", "Conversation1", "Conversation2"]

    failures = 0
    with ThreadPoolExecutor(max_workers=args.workers) as executor:
        futures = {
            executor.submit(process_item, item, selected_keys, args.model, args.directly, args.output): index
            for index, item in enumerate(data[args.start:args.end], start=args.start)
        }
        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing data"):
            try:
                # result() re-raises worker exceptions, which were otherwise
                # silently discarded along with the item.
                future.result()
            except Exception as e:
                failures += 1
                print(f"Item {futures[future]} failed: {e!r}")
    if failures:
        print(f"{failures} of {len(futures)} items failed and may be missing from {args.output}")
