import argparse
import json
import os
import random
import time
import requests
from concurrent.futures import ThreadPoolExecutor, as_completed
import signal
import re

from tqdm import tqdm
from openai import OpenAI

from chart_notes import get_chart_note

# Switch the working directory to the folder containing this script so that the
# relative "./prompt/..." paths read below resolve regardless of where the
# script is launched from. Note this is a process-wide side effect at import time.
os.chdir(os.path.dirname(os.path.realpath(__file__)))
##############################################
def read_prompt_from_file(file_path):
    """Return the UTF-8 text stored at *file_path* with surrounding whitespace removed.

    Args:
        file_path: Path of the prompt template file to load.

    Returns:
        The file content as a string, stripped of leading/trailing whitespace.
    """
    with open(file_path, mode='r', encoding='utf-8') as prompt_file:
        raw_text = prompt_file.read()
    return raw_text.strip()

# Prompt templates, loaded once at import time from text files under ./prompt/.
# Each one is later filled in with str.format (or used verbatim) by generate_code_data:
#   - Self_Instruct_Prompt_Data:    formatted with type, topic1, topic2, demo1, demo2, demo3
#   - Self_Instruct_Prompt_Code:    formatted with note
#   - Evol_Instruct_Prompt_Thought: formatted with code, direction
#   - Evol_Instruct_Prompt_Code:    sent as-is (no formatting)
#   - Code_Fix_Prompt:              formatted with code, error
Self_Instruct_Prompt_Data = read_prompt_from_file("./prompt/Self_Instruct_Prompt_Data.txt")
Self_Instruct_Prompt_Code = read_prompt_from_file("./prompt/Self_Instruct_Prompt_Code.txt")
Evol_Instruct_Prompt_Thought = read_prompt_from_file("./prompt/Evol_Instruct_Prompt_Thought.txt")
Evol_Instruct_Prompt_Code = read_prompt_from_file("./prompt/Evol_Instruct_Prompt_Code.txt")
Code_Fix_Prompt = read_prompt_from_file("./prompt/Code_Fix_Prompt.txt")

##############################################

# Chart taxonomy taken from ChartBench & ChartX: major chart family -> list of
# minor chart types. The upstream taxonomy is described as 10 major / 32 minor
# types; the entries commented out below are disabled, which leaves 6 major
# families and 11 minor types active. generate_code_data iterates this dict in
# insertion order and samples one minor type per request batch.
Type_of_Chart = {
    "Line Charts": ["line chart"],
    #"Pie Charts": ["pie chart", "donut pie chart", "sector pie chart", "ring chart"],
    "Bar Charts": ["bar chart","horizontal bar chart"],
    #"3D Bar Charts": ["3D bar chart", "stacked 3D bar chart", "percentage 3D bar chart"],
    #"Node Charts": ["directed node chart"],##, "undirected node chart"],
    "Radar Charts": ["radar chart", "radar chart with area filling"],
    "Area Charts": ["area chart", "stacked area chart"],
    "Box Charts": ["vertical box chart", "horizontal box chart"],
    "Scatter Charts": ["scatter chart", "scatter chart with smooth fitting"],#,"3D scatter chart (bubble chart)"],
    #"Specific Charts": ["heat map", "rose chart", "funnel chart", "waterfall chart", "histogram", "tree map"],
}

# Topic pool taken from OneChart & ChartX. The list below currently holds 38
# topics; generate_code_data samples two distinct topics for every request.
Topic_of_Chart = [
    "Business and Finance",
    "Healthcare and Health",
    "Science and Engineering",
    "Social Media and the Web",
    "Government and Public Policy",
    "Education and Academics",
    "Environment and Sustainability",
    "Retail and E-commerce",
    "Human Resources and Employee Management",
    "Agriculture and Food Production",
    "Energy and Utilities",
    "Transportation and Logistics",
    "Real Estate and Housing Market",
    "Manufacturing and Production",
    "Sports and Entertainment",
    "Social Sciences and Humanities",
    "Law and Legal Affairs",
    "Food and Beverage Industry",
    "History and Culture",
    "Society and Community",
    "Art and Design",
    "Travel and Exploration",
    "Religion and Spirituality",
    "Language and Communication",
    "Fashion and Style",
    "Music and Performance",
    "Film and Cinema",
    "Literature and Writing",
    "Architecture and Building",
    "Mathematics and Statistics",
    "Physics and Chemistry",
    "Biology and Life Sciences",
    "Astronomy and Space",
    "Computer Science and Information Technology",
    "Marketing and Advertising",
    "Futurism and Innovation",
    "Books and Publishing",
    "Artificial Intelligence and Robotics",
]

# Candidate "evolution" instructions. When the evol-instruct step runs, one entry
# is picked at random and substituted into Evol_Instruct_Prompt_Thought as the
# `direction` placeholder. The commented-out entry (extra subplot) is disabled.
Evol_Direction = [
    "Increase the size of the input data or the number of data groups as appropriate so that it requires a higher level of mathematical understanding. Note if there is a sum requirement.",
    "Try changing or adding some visual elements to make visual effect better. The elements you add must make sense and not be redundant.",
    "Incorporate an overlay plot of a different type on the original chart. Use related but not identical data for the added plot.",
    #"Extend an additional subplot of a different type beside the original chart (2 in total). Use related but not identical data for the added plot.",
]

# Module-level HTTP session.
# NOTE(review): nothing else in the visible code references `session` — confirm
# whether it is still needed before relying on or removing it.
session = requests.Session()

##############################################

def post_process_model_response(response):
    """Extract the single usable code block from an LLM response.

    The response is accepted only when it contains exactly one fenced code
    block whose length is between 10 and 150 lines (inclusive). Every
    rejection prints a "Drop out: ..." message naming the actual reason.

    Args:
        response: Raw text returned by the model, or None.

    Returns:
        The stripped code string, or None when the response is rejected.
    """
    if response is None:
        print("Drop out: empty response")
        return None

    code_blocks = extract_python_code_block(response)
    if not code_blocks:
        # Previously reported as "multiple code blocks", which was misleading
        # when the response contained no fenced block at all.
        print("Drop out: no code block found in the response")
        return None
    if len(code_blocks) > 1:
        print("Drop out: multiple code blocks found in the response")
        return None

    code = code_blocks[0]
    line_count = len(code.split("\n"))
    if line_count < 10 or line_count > 150:
        print("Drop out: code block length out of range")
        return None
    return code

def extract_python_code_block(s):
    """Return the contents of all fenced code blocks found in *s*.

    Blocks opened with ```python (case-insensitive) are preferred. Only when
    none exist does the function fall back to any ``` ... ``` fenced block.

    Args:
        s: Text to scan, typically a raw LLM response.

    Returns:
        A list with the whitespace-stripped body of each matched block, in
        order of appearance (empty list when nothing matches).
    """
    fenced = re.findall(r"(?i)```python(.*?)```", s, re.DOTALL)
    if not fenced:
        # No language-tagged block: accept generic fences instead.
        fenced = re.findall(r"```(.*?)```", s, re.DOTALL)
    return list(map(str.strip, fenced))

# Define timeout handling function
def handler(signum, frame):
    """SIGALRM handler: abort whatever is running by raising TimeoutError."""
    raise TimeoutError("Execution timed out")


# Improved simulate_code function, does not get execution results
def simulate_code(code_to_run):
    """Execute a code string under a 2-second alarm and report whether it ran cleanly.

    The code runs in a single fresh namespace with ``__name__`` set to
    ``"__main__"``, mirroring how the snippet behaves when run as a script.
    Using one dict for both globals and locals matters: with separate dicts,
    functions defined by the snippet cannot see its own top-level imports or
    variables and fail with a spurious NameError.

    The timeout relies on ``signal.SIGALRM``/``signal.alarm``, so this must be
    called from the main thread on a platform that provides SIGALRM.

    SECURITY NOTE: this ``exec``s arbitrary (LLM-generated) code inside the
    current process with no sandboxing; only run it in an environment where
    that is acceptable.

    Args:
        code_to_run: Source code to execute.

    Returns:
        ``(True, "Code executed successfully")`` on success, otherwise
        ``(False, message)`` where message is ``str()`` of the raised
        exception (including the TimeoutError raised by the alarm).
    """
    # Install the handler before entering `try`, so the cleanup in `finally`
    # always has a valid original_handler to restore.
    original_handler = signal.signal(signal.SIGALRM, handler)
    try:
        signal.alarm(2)

        # One shared namespace, isolated from this module's globals.
        namespace = {"__name__": "__main__"}
        exec(code_to_run, namespace)

        return True, "Code executed successfully"
    except Exception as e:
        # Covers TimeoutError from the alarm as well as any error in the snippet.
        return False, str(e)
    finally:
        # Single cleanup point: cancel any pending alarm and restore the handler.
        signal.alarm(0)
        signal.signal(signal.SIGALRM, original_handler)


##############################################

def read_files_with_pattern(data_path, pattern=r"\d{3}\.py$"):
    """Read every file in *data_path* whose name matches *pattern*.

    File names are visited in sorted order and tested with ``re.match`` (so
    the pattern is anchored at the start of the name).

    Args:
        data_path: Directory to scan (not recursive).
        pattern: Regular expression applied to each file name.

    Returns:
        A list with the UTF-8 text content of each matching file, in sorted
        file-name order.
    """
    selected_names = [name for name in sorted(os.listdir(data_path)) if re.match(pattern, name)]

    contents = []
    for name in selected_names:
        with open(os.path.join(data_path, name), "r", encoding="utf-8") as handle:
            contents.append(handle.read())
    return contents


def save_files_with_pattern(data_path, data, meta, pattern=r"\d{5}\.py$"):
    """Write code strings as sequentially numbered ``NNNNN.py`` files plus metadata.

    Numbering continues after the highest number already present among the
    files in *data_path* whose names match *pattern*. For each saved file one
    JSON line is appended to ``all_info.jsonl``, located in
    ``os.path.dirname(data_path)``.

    Args:
        data_path: Directory receiving the code files (created if missing).
        data: Sequence of code strings to save.
        meta: Sequence aligned with *data*; each item must provide the keys
            ``level``, ``plot_model``, ``major_chart_type`` and
            ``minor_chart_type``.
        pattern: Regular expression identifying already-saved code files.
    """
    os.makedirs(data_path, exist_ok=True)
    meta_file = os.path.join(os.path.dirname(data_path), "all_info.jsonl")

    # Numbers already taken by previously saved files; start after the largest.
    used_numbers = [
        int(re.search(r"\d+", name).group())
        for name in os.listdir(data_path)
        if re.match(pattern, name)
    ]
    base_number = max(used_numbers, default=0)

    for offset, file_data in enumerate(data, start=1):
        save_number = str(base_number + offset).zfill(5)
        file_name = f"{save_number}.py"
        with open(os.path.join(data_path, file_name), "w", encoding="utf-8") as code_file:
            code_file.write(file_data)

        entry = meta[offset - 1]
        meta_info = {
            "id": f"reachqa-train-plot-{save_number}",
            "code": f"all_code/{save_number}.py",
            "image": None,
        }
        for key in ("level", "plot_model", "major_chart_type", "minor_chart_type"):
            meta_info[key] = entry[key]

        with open(meta_file, "a", encoding="utf-8") as meta_handle:
            meta_handle.write(json.dumps(meta_info, ensure_ascii=False) + "\n")

        print(f"Saved to {file_name}")


def create_chat_response_by_messages(
        model,
        client,
        messages,
        max_tokens,
        temperature,
        top_p, ):
    """Send a chat-completion request and return the first choice's text.

    The wall-clock duration of the request is printed for monitoring.

    NOTE(review): ``max_tokens``, ``temperature`` and ``top_p`` are accepted
    but are not forwarded to the API call; only ``model`` and ``messages`` are
    sent, so the backend's default sampling settings apply. Confirm whether
    that is intentional.

    Args:
        model: Model name passed to the completion endpoint.
        client: Client exposing ``chat.completions.create``.
        messages: Chat messages (list of role/content dicts).
        max_tokens: Currently unused.
        temperature: Currently unused.
        top_p: Currently unused.

    Returns:
        ``choices[0].message.content`` of the completion response.
    """
    started_at = time.time()
    completion = client.chat.completions.create(model=model, messages=messages)
    elapsed = time.time() - started_at

    print('########################### result, time:', elapsed)
    return completion.choices[0].message.content

def generate_code_data(
        model,
        client,
        seed_data_path="./seed_tasks.jsonl",
        output_dir="./",
        num_data_to_generate=1000,
        num_demo_data=3,
        save_easy_data=True,
        request_batch_size=2,
        num_workers=1,
        chart_type="Bar Charts",
):
    """Generate chart-plotting scripts with an LLM, keep the ones that execute, and save them.

    For every major chart family in ``Type_of_Chart`` the function repeatedly:
      1. builds ``request_batch_size`` self-instruct prompts (random minor
         chart type, two random topics, sampled demonstration codes),
      2. queries the model in a thread pool (data step, then code step),
      3. executes each returned script with ``simulate_code``,
      4. gives failing scripts up to two LLM "code fix" rounds,
      5. saves the scripts that run successfully via ``save_files_with_pattern``
         and adds them to the demonstration pool.

    Args:
        model: Model name forwarded to the chat-completion calls.
        client: Chat client passed to ``create_chat_response_by_messages``.
        seed_data_path: Directory containing seed ``<digits>.py`` scripts.
        output_dir: Directory where generated scripts are written.
        num_data_to_generate: Target number of scripts per major chart family
            (doubled for "Bar Charts" and "Line Charts").
        num_demo_data: Number of demonstration codes sampled per prompt. The
            prompt template consumes demo1..demo3, so this must be at least 3
            and no larger than the number of available codes.
        save_easy_data: Whether self-instruct ("Easy") results are kept.
        request_batch_size: Number of prompts issued per loop iteration.
        num_workers: Thread-pool size for the LLM calls.
        chart_type: NOTE(review): accepted but not used by this function; all
            families in ``Type_of_Chart`` are processed.

    NOTE(review): the progress bar total is ``num_data_to_generate`` while the
    loop produces that many items per chart family, and previously generated
    files only seed the demonstration pool / progress bar (the per-family
    counter restarts from zero). Confirm whether that is intended.
    """
    # load the seed data
    seed_data = read_files_with_pattern(data_path=seed_data_path, pattern=r"\d+\.py$")
    print(f"Loaded {len(seed_data)} collected seed code")

    os.makedirs(output_dir, exist_ok=True)
    random.seed(42)

    # load existing generated data (if any) so it can serve as demonstrations too
    if os.path.exists(os.path.join(output_dir, "00001.py")):
        llm_generated_data = read_files_with_pattern(data_path=output_dir, pattern=r"\d+\.py$")
        print(f"Loaded {len(llm_generated_data)} llm-generated data")
    else:
        llm_generated_data = []

    # now let's generate new data!
    all_code = seed_data + llm_generated_data
    progress_bar = tqdm(total=num_data_to_generate)
    if llm_generated_data:
        progress_bar.update(len(llm_generated_data))

    def process_code_generation(meta_info):
        """Run the self-instruct (and optional evol-instruct) LLM steps for one sample.

        Returns the updated meta_info, or None when no usable code came back.
        """
        ### Step 1: Self-Instruct Generation
        print("\nCalling LLM for Self-Instruct-Data...")
        self_instruct_data_output = create_chat_response_by_messages(
            model=model,
            client=client,
            messages=meta_info["self_message"],
            max_tokens=8192,
            temperature=1.0,
            top_p=0.95,
        )
        meta_info["self_message"].extend([
            {"role": "assistant", "content": self_instruct_data_output},
            {"role": "user",
             "content": Self_Instruct_Prompt_Code.format(note=get_chart_note(meta_info["minor_chart_type"]))}
        ])
        print("\nCalling LLM for Self-Instruct-Code...")
        self_instruct_code_output = create_chat_response_by_messages(
            model=model,
            client=client,
            messages=meta_info["self_message"],
            max_tokens=8192,
            temperature=1.0,
            top_p=0.95,
        )
        self_code = post_process_model_response(self_instruct_code_output)
        if self_code is None:
            # Drop the sample here. Storing None would later be handed to
            # exec() and then to the code-fix prompt as "code=None", wasting
            # LLM calls on a sample that never contained code.
            return None
        meta_info["self_code"] = self_code

        # The evol-instruct stage is currently switched off.
        evol_code_choice = False

        if evol_code_choice:
            ## Step 2: Evolve the generated content
            meta_info["evol_message"] = [
                {"role": "system", "content": "You are a skilled MatplotLib expert."},
                {"role": "user", "content": Evol_Instruct_Prompt_Thought.format(code=self_code, direction=random.choice(
                    Evol_Direction))},
            ]

            print("\nCalling LLM for Evol-Instruct-Thought...")
            evol_instruct_thought_output = create_chat_response_by_messages(
                model=model,
                client=client,
                messages=meta_info["evol_message"],
                max_tokens=8192,
                temperature=1.0,
                top_p=0.95,
            )
            meta_info["evol_message"].extend([
                {"role": "assistant", "content": evol_instruct_thought_output},
                {"role": "user", "content": Evol_Instruct_Prompt_Code}
            ])
            print("\nCalling LLM for Evol-Instruct-Code...")
            evol_instruct_code_output = create_chat_response_by_messages(
                model=model,
                client=client,
                messages=meta_info["evol_message"],
                max_tokens=8192,
                temperature=1.0,
                top_p=0.95,
            )
            evol_code = post_process_model_response(evol_instruct_code_output)
            if evol_code is not None:
                meta_info["evol_code"] = evol_code

        return meta_info

    # multithreaded code fix
    def fix_code(meta_info):
        """Ask the LLM to repair meta_info["final_code"] given its recorded error."""
        ### Step 3: Filter with Execution
        print("\nCalling LLM for Code Fix...")
        fix_output = create_chat_response_by_messages(
            model=model,
            client=client,
            messages=[
                {"role": "system", "content": "You are a skilled Python and MatplotLib expert."},
                {"role": "user",
                 "content": Code_Fix_Prompt.format(code=meta_info["final_code"], error=meta_info["error"])},
            ],
            max_tokens=8192,
            temperature=1.0,
            top_p=0.95,
        )

        fixed_code = post_process_model_response(fix_output)
        if fixed_code is not None:
            # Keep the previous code when the fix response was unusable.
            meta_info["final_code"] = fixed_code

        return meta_info

    # Main loop for generating data
    for major_chart_type, minor_list in Type_of_Chart.items():
        # Per-family counter of scripts generated in this run.
        llm_generated_data = []

        num = num_data_to_generate
        if major_chart_type == 'Bar Charts' or major_chart_type == 'Line Charts':
            num = int(num_data_to_generate * 2)

        while len(llm_generated_data) < num:
            # One minor chart type is shared by the whole request batch.
            minor_chart_type = random.choice(minor_list)

            meta_info_list = []
            # construct the meta information for each data
            for _ in range(request_batch_size):
                demo_codes = random.sample(all_code, num_demo_data)  # sampling from the seed data + generated data

                select_topics = random.sample(Topic_of_Chart, 2)

                messages = [
                    {"role": "system", "content": "You are a skilled MatplotLib expert."},
                    {
                        "role": "user",
                        "content": Self_Instruct_Prompt_Data.format(
                            type=minor_chart_type,
                            topic1=select_topics[0],
                            topic2=select_topics[1],
                            demo1=demo_codes[0],
                            demo2=demo_codes[1],
                            demo3=demo_codes[2],
                        ),
                    },
                ]
                meta_info_list.append(
                    {"major_chart_type": major_chart_type, "minor_chart_type": minor_chart_type, "self_message": messages,
                    "plot_model": model}
                )

            # Start the parallel processing
            meta_to_check_list = []
            with ThreadPoolExecutor(max_workers=num_workers) as executor:
                futures = {executor.submit(process_code_generation, meta_info): meta_info for meta_info in meta_info_list}
                for future in tqdm(as_completed(futures), total=len(futures)):
                    result = future.result()
                    if result is not None:
                        if save_easy_data and "self_code" in result:
                            easy_item = result.copy()
                            easy_item["level"] = "Easy"
                            easy_item["final_code"] = easy_item["self_code"]
                            meta_to_check_list.append(easy_item)

                        if "evol_code" in result:
                            hard_item = result.copy()
                            hard_item["level"] = "Hard"
                            hard_item["final_code"] = hard_item["evol_code"]
                            meta_to_check_list.append(hard_item)

            # Check the code with execution (sequentially, in the main thread,
            # because simulate_code relies on SIGALRM)
            for meta in meta_to_check_list:
                success, error = simulate_code(meta["final_code"])
                print('########################### simulate_code', success, error)
                meta["success"] = success
                if not success:
                    meta["error"] = error

            meta_to_save_list = [meta for meta in meta_to_check_list if meta["success"]]
            meta_to_check_list = [meta for meta in meta_to_check_list if not meta["success"]]

            # Give failing scripts a limited number of LLM repair rounds
            max_attempts = 2
            for _ in range(max_attempts):
                if len(meta_to_check_list) == 0:
                    break

                meta_after_fix = []
                with ThreadPoolExecutor(max_workers=num_workers) as executor:
                    futures = {executor.submit(fix_code, meta_info): meta_info for meta_info in meta_to_check_list}
                    for future in tqdm(as_completed(futures), total=len(futures)):
                        result = future.result()
                        meta_after_fix.append(result)

                for meta in meta_after_fix:
                    success, error = simulate_code(meta["final_code"])
                    meta["success"] = success
                    if not success:
                        meta["error"] = error

                meta_to_save_list.extend([meta for meta in meta_after_fix if meta["success"]])
                meta_to_check_list = [meta for meta in meta_after_fix if not meta["success"]]

            code_to_save = [meta["final_code"] for meta in meta_to_save_list]
            llm_generated_data.extend(code_to_save)
            all_code.extend(code_to_save)
            progress_bar.update(len(code_to_save))

            print(f"For the batch_size of {request_batch_size}, kept {len(code_to_save)} code!")
            save_files_with_pattern(data_path=output_dir, data=code_to_save, meta=meta_to_save_list, pattern=r"\d+\.py$")

    progress_bar.close()


def _str_to_bool(value):
    """Convert a command-line string such as "true"/"False"/"0" into a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True
    because any non-empty string is truthy.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string is kept as False for compatibility with the old
    # bool("") behaviour.
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def arg_parser():
    """Parse the command-line options of the generation script.

    Returns:
        An ``argparse.Namespace`` with ``model_name``, ``output_dir``,
        ``seed_data_path``, ``num_data_to_generate``, ``num_demo_data``,
        ``save_easy_data``, ``request_batch_size``, ``num_workers`` and
        ``chart_type``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_name", type=str)

    parser.add_argument("--output_dir", type=str)
    parser.add_argument("--seed_data_path", type=str)
    parser.add_argument("--num_data_to_generate", type=int, default=1000)
    parser.add_argument("--num_demo_data", type=int, default=3)
    # Explicit converter so that "--save_easy_data False" really disables it.
    parser.add_argument("--save_easy_data", type=_str_to_bool, default=True)

    parser.add_argument("--request_batch_size", type=int, default=50)
    parser.add_argument("--num_workers", type=int, default=30)
    parser.add_argument("--chart_type", type=str, default="Bar Charts")

    return parser.parse_args()


if __name__ == "__main__":
    # Script entry point: parse the CLI options, build the chat client and run
    # the generation pipeline with those options.
    args = arg_parser()
    print(args)

    # Placeholder key; the client is created with this literal value.
    openai_api_key = "EMPTY"
    openai_client = OpenAI(api_key=openai_api_key)

    generate_code_data(
        model=args.model_name,
        client=openai_client,
        seed_data_path=args.seed_data_path,
        output_dir=args.output_dir,
        num_data_to_generate=args.num_data_to_generate,
        num_demo_data=args.num_demo_data,
        save_easy_data=args.save_easy_data,
        request_batch_size=args.request_batch_size,
        num_workers=args.num_workers,
        chart_type=args.chart_type,
    )
