import argparse
import json
import os
import regex as re

import openai
from concurrent.futures import ThreadPoolExecutor, as_completed
from rouge_score import rouge_scorer
from tqdm import tqdm
from utils.openai_utils import create_chat_response_by_messages, create_json_mode_chat_n_response


##############################################
# System prompt for the "Reasoning" QA type: it is sent as the system message
# of the conversation built in process_plot_to_reasoning.
System_Message_Reasoning = """You are both an expert Matplotlib plotter and a professional maths teacher. Now, you are asked to generate a mathematical reasoning question about a given chart. This chart and question will be used as a question on this year's college admissions examination. As a question writer, you need to ensure that the question is challenging yet fair, testing the students' ability to analyze data, interpret trends, and apply mathematical concepts."""

# First user turn of the "Reasoning" conversation. This is a str.format
# template: the single placeholder {code} receives the plotting script. The
# model is only asked to analyse the chart and brainstorm question ideas here;
# the actual Q&A pair is requested in a follow-up turn.
Gen_Reasoning_Thought = """First, please read the following plotting script in Python, try to visualize the figure in your mind and to understand the meaning of the chart. After you've analyzed this chart, we'll start generating the associated question and answer.

Here are some tips for you:
1. The plotting script (including the code itself, data mapping and labels) is absolutely correct and you can trust it completely. 
2. The question and answer need to be based on the chart type, chart topic, and the given data. It can relate to the chart as a whole or to localized details, so you need to look closely.
3. The question should be challenging, requiring visual observation skills and mathematical reasoning skills. So you need to have a deep understanding of the chart.
4. If there is no data annotation in the figure, try not to generate questions that require too much numerical recognition to reduce inconsistent answers due to visual errors.
5. If some numerical recognition is needed, choose distinguishable colors, lines, heights, and other features that make it easy to estimate without data annotation.

Here is the plotting script:
```python
{code}
```

In this step, you don't need to generate specific questions and answers, just analyze the figure and list some ideas that can be used as questions."""

# Second user turn of the "Reasoning" conversation, asking for the final Q&A
# pair as a JSON object with the keys "question", "detail_answer" and
# "concise_answer". The braces of the JSON example are doubled, i.e. escaped
# for str.format; render the template with .format() (it has no placeholders)
# so that the model is shown single braces.
Gen_Reasoning_QA = """You're doing great! Now, please generate a Q&A pair about the figure you've just analyzed. This question will be used as a math reasoning question on this year's college admissions examination.

Here are some tips for you to generate the Q&A:
1. First and foremost, the question and answer need to be based on the chart and answerable by reasoning or calculation.
2. You don't need to describe the content of the figure in the question text. This can be left for students to think about.
3. This question needs to explicitly involve a final answer, the type of answer can be a certain number, a noun or Yes/No etc.
4. The answer should contain multiple reasoning or calculation steps and be presented in an understandable and educational paragraph.
5. NEVER include any information relating to the Python script in the question or answer, as students will ONLY have access to the plotted figure.

Now, you can start to generate a question and answer. Your output needs to follow the following JSON format:
{{"question": "<the question you generate>", "detail_answer": "<detail analysis and step-by-step solution of the question in a string>", "concise_answer": "<concise answer with key steps in a string>"}}"""


# System prompt for the "Descriptive" QA type: it is sent as the system message
# of the conversation built in process_plot_to_descriptive.
System_Message_Descriptive = """You are both an expert Matplotlib plotter and a professional maths teacher. Now, you are asked to generate a descriptive question about a given chart. This chart and question will be used as a question on this year's elementary math examination to test students' ability to read charts."""


# Single user turn of the "Descriptive" conversation. This is a str.format
# template: {code} receives the plotting script, and the doubled braces around
# the JSON example are str.format escapes that render as single braces. The
# expected reply is a JSON object with the keys "question", "detail_answer"
# and "concise_answer".
Gen_Descriptive_QA = """First, please read the following plotting script in Python, try to visualize the figure in your mind and to understand the meaning of the chart. Then, you are asked to generate a descriptive question about a given chart. 

Here are some tips for you:
1. The plotting script (including the code itself, data mapping and labels) is absolutely correct and you can trust it completely. 
2. Descriptive questions are questions that can be answered based on basic chart information, such as titles, labels, tick marks, colors, etc.
3. The generated Q&A need to be based on the chart type and data. It should be answerable through visual observation.
4. If there is no data annotation in the figure, try not to generate questions that require too much numerical recognition to reduce inconsistent answers due to visual errors.
5. If some numerical recognition is needed, choose distinguishable colors, lines, heights, and other features that make it easy to estimate without data annotation.
6. You don't need to describe the content of the figure in the question text. This can be left for students to think about.
7. This question needs to explicitly involve a final answer, the type of answer can be a certain number, a noun or Yes/No etc.
8. NEVER include any information relating to the Python script in the question or answer, as students will ONLY have access to the plotted figure.

Here are some examples of descriptive questions:
- How many colors are used in the chart? How many city categories are in the chart?
- What's the leftmost value of bar in China? And what is the value of the bar next to it?
- For the subplot at row 2 and column 1, what is the minimum value of the solid line?
- Which name does the second largest sector represent? What is its value?
- Does the blue triangle in the chart represent a higher value than the red circle?

Here is the plotting script:
```python
{code}
```

Now, you can start to generate a question and answer. Use your imagination and creativity to generate more interesting questions. Your output needs to follow the following JSON format:
{{"question": "<the question you generate>", "detail_answer": "<understandable answer in one or more sentences>", "concise_answer": "<short answer of some words>"}}"""

##############################################

def _first_balanced_braces(text):
    """Return the leftmost brace-balanced ``{...}`` substring of *text*, or None.

    Braces are counted literally (string contents are not interpreted). An
    opening brace whose group is never closed is skipped and the search resumes
    at the next opening brace.
    """
    start = text.find('{')
    while start != -1:
        depth = 0
        for pos in range(start, len(text)):
            char = text[pos]
            if char == '{':
                depth += 1
            elif char == '}':
                depth -= 1
                if depth == 0:
                    return text[start:pos + 1]
        # Reached the end with an unclosed group: retry from the next '{'.
        start = text.find('{', start + 1)
    return None


def extract_and_validate_json(input_str):
    """Extract the first JSON object from an LLM response and validate it.

    The object must contain the keys ``question``, ``detail_answer`` and
    ``concise_answer``. A ``detail_answer`` given as a list of strings is
    joined with single spaces into one string.

    Args:
        input_str: Raw response text, possibly with extra text around the JSON.

    Returns:
        The parsed dictionary, or ``None`` (after printing the reason) when no
        JSON object is found, it cannot be decoded, a required key is missing,
        or ``detail_answer`` is neither a string nor a list of strings.
    """
    json_str = _first_balanced_braces(input_str) if isinstance(input_str, str) else None
    if json_str is None:
        print("No JSON found in the input string.")
        return None

    # Raw line breaks are not allowed inside JSON strings; flatten them.
    flattened = json_str.replace('\n', ' ').replace('\r', ' ')

    # First attempt: double every backslash so that unescaped sequences such as
    # LaTeX commands (\frac, \times, ...) survive decoding. That repair breaks
    # already-valid escapes like \" (the string would end early), so fall back
    # to decoding the text unmodified instead of discarding the response.
    candidates = (flattened.replace('\\', '\\\\'), flattened)
    temp_dict = None
    for candidate in candidates:
        try:
            temp_dict = json.loads(candidate)
            break
        except json.JSONDecodeError:
            continue
    if temp_dict is None:
        print("Failed to decode JSON.")
        return None

    required_keys = ('question', 'detail_answer', 'concise_answer')
    if not all(key in temp_dict for key in required_keys):
        print("Invalid response format. The response does not contain all 3 required keys.")
        return None

    # Normalise detail_answer to a single string.
    detail_answer = temp_dict['detail_answer']
    if isinstance(detail_answer, list) and all(isinstance(item, str) for item in detail_answer):
        temp_dict['detail_answer'] = ' '.join(detail_answer)
    elif not isinstance(detail_answer, str):
        print("Invalid response format. The 'detail_answer' must be a string or a list of strings.")
        return None

    return temp_dict


def _read_jsonl(path):
    """Load a JSON Lines file into a list of objects, ignoring blank lines."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _parse_and_dedupe_responses(response_list, scorer, threshold=0.7):
    """Parse raw LLM responses and drop questions that are near-duplicates.

    A response is kept when it passes ``extract_and_validate_json`` and its
    question has a ROUGE-L F-measure of at most *threshold* against every
    question already kept.
    """
    kept = []
    for response in response_list:
        new_dict = extract_and_validate_json(response)
        if new_dict is None:
            print(f"Warning: extract_and_validate_json returned None for response: {response}")
            continue

        is_duplicate = any(
            scorer.score(new_dict["question"], existing["question"])["rougeL"].fmeasure > threshold
            for existing in kept
        )
        if is_duplicate:
            print(f"Duplicate instruction found: {new_dict['question']}")
        else:
            kept.append(new_dict)
    return kept


def generate_instruction_data(
    model,
    client,
    data_path,
    QA_type="Reasoning",
    num_workers=5,
    num_instruction_per_plot=3,
):
    """Generate Q&A instruction data for every plot listed in ``plot_info.jsonl``.

    For each plot the plotting script is read, the LLM is asked for
    ``num_instruction_per_plot`` Q&A candidates, invalid and near-duplicate
    candidates are filtered out, and the remainder is appended to
    ``all_instruction_data.jsonl`` inside ``data_path``.

    The run is resumable: plots whose id already appears in the output file
    are skipped. Results are written in completion order, so the set of
    finished ids (not the largest id) decides what is skipped. Plots that
    failed or produced no valid instruction are retried on the next run.

    Args:
        model: Model name forwarded to the chat helpers and stored as ``QA_model``.
        client: Client object forwarded to the chat helpers.
        data_path: Directory holding ``plot_info.jsonl`` and the code files.
        QA_type: ``"Reasoning"`` or ``"Descriptive"``.
        num_workers: Number of worker threads issuing LLM requests.
        num_instruction_per_plot: Number of candidates requested per plot.

    Raises:
        ValueError: If ``QA_type`` is not one of the supported values.
    """
    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False)
    output_file_path = os.path.join(data_path, "all_instruction_data.jsonl")
    sampling_kwargs = {
        "model": model,
        "client": client,
        "max_tokens": 8192,
        "temperature": 1.0,
        "top_p": 0.95,
    }

    def read_plot_code(plot):
        with open(os.path.join(data_path, plot["code"]), "r", encoding="utf-8") as f:
            return f.read()

    def process_plot_to_reasoning(plot):
        gen_inst_messages = [
            {"role": "system", "content": System_Message_Reasoning},
            {"role": "user", "content": Gen_Reasoning_Thought.format(code=read_plot_code(plot))},
        ]

        ### Call LLM to analyse the chart first, then ask for the Q&A pair
        print("\nCalling OpenAI for Generate Thought...")
        reasoning_thought_output = create_chat_response_by_messages(
            messages=gen_inst_messages,
            **sampling_kwargs,
        )
        gen_inst_messages.extend([
            {"role": "assistant", "content": reasoning_thought_output},
            # The template escapes its JSON braces for str.format; render it so
            # the model sees {"question": ...} rather than {{"question": ...}}.
            {"role": "user", "content": Gen_Reasoning_QA.format()},
        ])

        response_list = create_json_mode_chat_n_response(
            messages=gen_inst_messages,
            n=num_instruction_per_plot,
            **sampling_kwargs,
        )
        return _parse_and_dedupe_responses(response_list, scorer)

    def process_plot_to_descriptive(plot):
        gen_inst_messages = [
            {"role": "system", "content": System_Message_Descriptive},
            {"role": "user", "content": Gen_Descriptive_QA.format(code=read_plot_code(plot))},
        ]

        ### Call LLM to generate instruction
        print("\nCalling OpenAI for Generate QA...")
        response_list = create_json_mode_chat_n_response(
            messages=gen_inst_messages,
            n=num_instruction_per_plot,
            **sampling_kwargs,
        )
        return _parse_and_dedupe_responses(response_list, scorer)

    # Validate the QA type before touching the file system or the API.
    processors = {
        "Reasoning": process_plot_to_reasoning,
        "Descriptive": process_plot_to_descriptive,
    }
    if QA_type not in processors:
        raise ValueError("Invalid QA_type")
    process_plot = processors[QA_type]

    # Collect the ids of plots that already have output.
    processed_plot_ids = set()
    if os.path.exists(output_file_path) and os.path.getsize(output_file_path) > 0:
        processed_plot_ids = {record["plot_id"] for record in _read_jsonl(output_file_path)}

    # load meta data
    meta_data = _read_jsonl(os.path.join(data_path, "plot_info.jsonl"))
    print(f"Loaded {len(meta_data)} collected plot(s)")

    # Skip already processed plots
    pending_plots = [plot for plot in meta_data if plot["id"] not in processed_plot_ids]
    print(f"Skipped {len(meta_data) - len(pending_plots)} processed plot(s)")

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = {executor.submit(process_plot, plot): plot for plot in pending_plots}

        # Open the file in append mode
        with open(output_file_path, "a", encoding="utf-8") as f:
            for future in tqdm(as_completed(futures), total=len(futures)):
                plot = futures[future]
                try:
                    new_instructions = future.result()
                except Exception as exc:
                    # One failed request must not abort the whole batch; the
                    # plot has no output, so the next run picks it up again.
                    print(f"Failed to process plot {plot['id']}: {exc!r}")
                    continue

                for instruction in new_instructions:
                    sample = {
                        "plot_id": plot["id"],
                        "image": plot["image"],
                        "code": plot["code"],
                        "plot_level": plot["level"],
                        "plot_model": plot["plot_model"],
                        "major_chart_type": plot["major_chart_type"],
                        "minor_chart_type": plot["minor_chart_type"],
                        "QA_type": QA_type,
                        "QA_model": model,
                        "question": instruction["question"],
                        "detail_answer": instruction["detail_answer"],
                        "concise_answer": instruction["concise_answer"],
                    }
                    f.write(json.dumps(sample) + "\n")
                # Persist each finished plot so an interrupted run can resume.
                f.flush()
                

def arg_parser():
    """Parse the command-line options of the instruction-generation script.

    Returns:
        The ``argparse.Namespace`` with ``model_name``, ``data_path``,
        ``num_instruction_per_plot``, ``QA_type`` and ``num_workers``.
    """
    # (flag, add_argument keyword arguments), in the order they are registered.
    option_specs = (
        ("--model_name", {"type": str, "default": "gpt-4-turbo-2024-04-09"}),
        ("--data_path", {"type": str, "default": "./data/reachqa_train"}),
        ("--num_instruction_per_plot", {"type": int, "default": 3}),
        ("--QA_type", {"type": str, "default": "Reasoning", "choices": ["Reasoning", "Descriptive"]}),
        ("--num_workers", {"type": int, "default": 5}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()


if __name__ == "__main__":
    # Script entry point: parse the CLI options, build the API client and run
    # the generation pipeline.
    args = arg_parser()
    print(args)

    # Read the credentials from the environment instead of keeping them in the
    # source file; the literals are only the original placeholder fallbacks.
    openai_key = os.environ.get("OPENAI_API_KEY", "key")
    openai_base_url = os.environ.get("OPENAI_BASE_URL", "url")
    openai_client = openai.OpenAI(api_key=openai_key, base_url=openai_base_url)

    generate_instruction_data(
        model=args.model_name,
        client=openai_client,
        data_path=args.data_path,
        QA_type=args.QA_type,
        num_workers=args.num_workers,
        num_instruction_per_plot=args.num_instruction_per_plot,
    )

