import json
import re
import random 
import sys
# from load_prompt import read_prefix
from ast import literal_eval
import traceback
import os
import signal
import time
import importlib
import inspect
from typing import List, Dict, Any, Optional
import subprocess
import uuid
import tempfile

class DataProcessor:
    def __init__(self, file_path, save_path, seed=None):
        if seed is not None:
            random.seed(seed)
        self.file_path = file_path
        self.save_path = save_path
        self.filter_raw = []
        self.raw_dataset = []
        self.processed_data = []
        self.save_temp_code_dir = "./python_code/"
        # self.prefix = read_prefix(prefix_path)
    
    def save_data(self, filter_raw_path=None):
        """Write the processed records and the accepted raw questions to disk.

        ``self.processed_data`` is written to ``self.save_path`` as a single
        indented JSON document, replacing any previous content.
        ``self.filter_raw`` is then appended to ``filter_raw_path`` as an
        indented JSON array.

        Args:
            filter_raw_path: Destination of the raw-question dump.  When
                omitted, the historical hard-coded location is used so that
                existing callers keep their behaviour; pass an explicit path
                on any machine where that location does not exist.

        Raises:
            OSError: If either file cannot be opened for writing.
        """
        if filter_raw_path is None:
            # Legacy, machine-specific default kept only for backward
            # compatibility with callers that do not pass a path.
            filter_raw_path = "/Users/zhoujt/Desktop/Encyclobench/datasets/infer_data/filter_raw.jsonl"

        def default_skip(o):
            # Invoked by json for objects it cannot serialise: store a
            # readable placeholder instead of aborting the whole dump.
            return f"<Non-serializable: {type(o).__name__}>"

        with open(self.save_path, "w", encoding="utf-8") as f:
            json.dump(
                self.processed_data,
                f,
                ensure_ascii=False,
                indent=2,
                default=default_skip,
            )

        # NOTE(review): append mode combined with an indented array means
        # repeated runs leave several concatenated arrays in the file, which
        # is neither valid JSON nor JSON Lines despite the default ".jsonl"
        # suffix -- confirm the intended format before relying on this file.
        with open(filter_raw_path, "a", encoding="utf-8") as f:
            json.dump(self.filter_raw, f, ensure_ascii=False, indent=2, default=default_skip)
    
    def load_data(self):
        """Load the JSON Lines file at ``self.file_path`` into ``self.raw_dataset``.

        Every non-blank line is parsed as one JSON value and appended to
        ``self.raw_dataset``.  Lines that are not valid JSON are reported on
        stdout and skipped, so one bad record does not abort the load.

        Raises:
            OSError: If the input file cannot be opened.
        """
        # Read as UTF-8 explicitly, matching the encoding save_data writes
        # with, instead of depending on the platform default.
        with open(self.file_path, "r", encoding="utf-8") as f:
            for line in f:
                record = line.strip()
                if not record:
                    # Blank separator / trailing lines are not parse errors.
                    continue
                try:
                    self.raw_dataset.append(json.loads(record))
                except json.JSONDecodeError:
                    print(f"Error parsing line: {line}")
    
    def replace_variables(self, text, variables):
        """Fill the ``$name$`` placeholders of ``text`` from ``variables``.

        Args:
            text: Template containing placeholders delimited by ``$``.
            variables: Mapping from placeholder name to its value.

        Returns:
            ``text`` with every known placeholder replaced by ``str(value)``;
            placeholders whose name is not in ``variables`` are left as-is.
        """
        def substitute(match):
            # group(0) is the whole "$name$" token, group(1) the bare name;
            # falling back to the token keeps unknown placeholders intact.
            token, name = match.group(0), match.group(1)
            return str(variables.get(name, token))

        return re.sub(r'\$(.*?)\$', substitute, text)
    def process_init_nums(self, init_nums_str):
        """Turn the textual list of initial inputs into Python values.

        The string is stripped of surrounding spaces and square brackets,
        split on commas, and every non-empty piece is evaluated as a Python
        expression (so entries such as ``3*2`` or ``1e-3`` are accepted).

        Args:
            init_nums_str: Text such as ``"[1, 2.5, 3*2]"``.

        Returns:
            The evaluated values in their original order; an empty list for
            ``"[]"`` or a blank string.

        Raises:
            Exception: Whatever ``eval`` raises for a malformed expression.
        """
        stripped = init_nums_str.strip(" []")
        nums = []
        for expr in stripped.split(","):
            expr = expr.strip()
            if not expr:
                # "[]", a blank string or a trailing comma leaves an empty
                # piece behind; eval("") would raise SyntaxError.
                continue
            # SECURITY: eval executes arbitrary code taken from the dataset
            # file.  Only run this on trusted datasets.
            nums.append(eval(expr))
        return nums
    
    def get_question_answer(self, result_python_code: str, initialized_variables: Dict[str, Any]):
        """Run the question's solver function and return its rounded results.

        The code is written to a uniquely named module inside
        ``self.save_temp_code_dir``, imported, and the single function it
        defines is called with the matching entries of
        ``initialized_variables`` under a 10 second ``SIGALRM`` limit.  The
        temporary module, its ``sys.path`` entry and its file are cleaned up
        afterwards whether or not the call succeeded.

        Args:
            result_python_code: Python source defining exactly one function.
            initialized_variables: Mapping from parameter name to value;
                entries the function does not take are ignored.

        Returns:
            A list holding every numeric result greater than ``1e-4``,
            rounded to four decimals.  A scalar return value is treated as a
            one-element sequence.

        Raises:
            TimeoutError: If the function runs longer than the time limit.
            ValueError: If the module does not define exactly one function
                or a parameter has no value in ``initialized_variables``.
            Exception: ``"init type error"`` when a result is not a number;
                any error raised while importing or running the code is
                propagated unchanged.

        Note:
            Relies on ``signal.SIGALRM``, so it must run in the main thread
            of a Unix process.
        """
        time_limit = 10  # seconds granted to the solver call

        # The scratch directory is not created anywhere else.
        os.makedirs(self.save_temp_code_dir, exist_ok=True)

        module_name = f"script_{uuid.uuid4().hex}"
        script_path = os.path.join(self.save_temp_code_dir, f"{module_name}.py")
        with open(script_path, 'w', encoding='utf-8') as f:
            f.write(result_python_code)

        def handler(signum, frame):
            raise TimeoutError("Time Limit Error when processing the function")

        # Bound before the try so the finally clause can always reference it,
        # even when the import itself fails.
        module = None
        sys.path.insert(0, self.save_temp_code_dir)
        try:
            # The file was created after the directory may already have been
            # scanned by the import system; drop the cached listing so the
            # new module is guaranteed to be found.
            importlib.invalidate_caches()
            module = importlib.import_module(module_name)
            func = self._find_single_function(module)
            args = self._prepare_arguments(func, initialized_variables)

            previous_handler = signal.signal(signal.SIGALRM, handler)
            signal.alarm(time_limit)
            try:
                result = func(**args)
            finally:
                # Always disarm the timer and give the signal back to
                # whoever owned it before this call.
                signal.alarm(0)
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)

            processed = list(result) if isinstance(result, (list, tuple)) else [result]
            round_result = []
            for item in processed:
                if not isinstance(item, (int, float)):
                    raise Exception("init type error")
                # NOTE(review): values <= 1e-4 (including zero and every
                # negative result) are silently left out of the answer.
                # Behaviour kept as-is -- confirm whether abs(item) or an
                # unrounded fallback was intended.
                if item > 1e-4:
                    round_result.append(round(item, 4))
            return round_result
        except Exception:
            # The script file is deleted by the cleanup below, so the
            # message only names it instead of claiming it was kept.
            print(f"An error occurred while running the generated script: {script_path}")
            raise
        finally:
            self._cleanup_resources(module, script_path)

    def _find_single_function(self, module) -> callable:
        """Return the one function that ``module`` itself defines.

        Functions merely imported into the module are ignored: only those
        whose ``__module__`` equals the module's own name are counted.

        Args:
            module: An imported module object.

        Returns:
            The single function defined in ``module``.

        Raises:
            ValueError: If the module defines zero or more than one function.
        """
        own_functions = []
        for _name, member in inspect.getmembers(module, inspect.isfunction):
            if member.__module__ == module.__name__:
                own_functions.append(member)

        if len(own_functions) == 1:
            return own_functions[0]
        raise ValueError("Only one function at most can be included in question")
    
    def _prepare_arguments(self, func: callable, variables: Dict[str, Any]) -> Dict[str, Any]:
        """Select from ``variables`` the values for ``func``'s parameters.

        Args:
            func: The function that is going to be called.
            variables: Candidate values keyed by parameter name; entries
                that do not correspond to a parameter are ignored.

        Returns:
            A keyword-argument dict with one entry per parameter of ``func``.

        Raises:
            ValueError: If a parameter of ``func`` (including one that has a
                default) has no entry in ``variables``.
        """
        call_kwargs = {}
        # The signature's parameter mapping is keyed by parameter name.
        for name in inspect.signature(func).parameters:
            if name in variables:
                call_kwargs[name] = variables[name]
            else:
                raise ValueError(f"Missing required parameters: {name}")
        return call_kwargs

    def _cleanup_resources(self, module, script_path):
        """Undo the side effects of importing a temporary solver script.

        Removes the module from ``sys.modules``, the scratch directory from
        ``sys.path`` and the script file from disk.  Each step tolerates
        state that is already clean: this method runs from a ``finally``
        clause, where an exception raised here would hide the original error
        and leave the remaining steps undone.

        Args:
            module: The imported module, or ``None`` if the import failed.
            script_path: Path of the temporary script file.
        """
        if module:
            # pop() instead of del: the entry may already be gone.
            sys.modules.pop(module.__name__, None)

        try:
            sys.path.remove(self.save_temp_code_dir)
        except ValueError:
            # The directory was never added or has already been removed.
            pass

        try:
            os.remove(script_path)
        except FileNotFoundError:
            # Nothing to delete; the script was never written or is gone.
            pass
    
    def remove_code_block_markers(self, text: str):
        """Drop Markdown code-fence lines from ``text``.

        Args:
            text: Source text that may be wrapped in triple-backtick fences.

        Returns:
            ``text`` without any line whose first non-blank characters are
            three backticks; all other lines are kept unchanged and re-joined
            with newlines.
        """
        kept = []
        for line in text.split("\n"):
            if line.strip().startswith("```"):
                continue
            kept.append(line)
        return "\n".join(kept)
    
    
    
    def generate_questions(self):
        """Build a processed record for every loaded question, then save.

        For each entry of ``self.raw_dataset`` the prompt template is filled
        twice -- once with the question's own initial inputs and once with
        randomly generated values -- and the solver code is run on both
        variable sets.  Successful questions are appended to
        ``self.processed_data`` (and their raw form to ``self.filter_raw``);
        a question that fails at any step is reported on stdout with its
        traceback and skipped.  ``save_data`` is called once at the end.
        """
        for question in self.raw_dataset:
            # Pre-bound so the error report below works even when the record
            # has no usable "instruction" field.
            base_prompt = None
            try:
                # Field lookups live inside the try so that one malformed
                # record is skipped instead of aborting the whole run before
                # anything has been saved.
                base_prompt = question["instruction"]
                difficulty = question["Difficulty"]
                arguments = self.parse_arguments(question["Arguments"].split('\n'))
                python_code = question["function"]
                init_nums = self.process_init_nums(question["input"])

                # Prompt filled with the question's own initial inputs.
                init_vars = self.generate_variables(arguments, init_nums)
                init_prompt = self.replace_variables(base_prompt, init_vars)
                # Prompt filled with freshly sampled values.
                random_vars = self.generate_variables(arguments)
                random_prompt = self.replace_variables(base_prompt, random_vars)

                # Strip the Markdown fences once and reuse the result.
                runnable_code = self.remove_code_block_markers(python_code)
                init_answer = self.get_question_answer(runnable_code, init_vars)
                random_answer = self.get_question_answer(runnable_code, random_vars)
                if init_answer is None or random_answer is None:
                    # A missing answer makes the record unusable; the original
                    # check here was a no-op.
                    continue

                self.processed_data.append({
                    "metadata": {
                        "python_code": python_code,
                        "arguments": arguments,
                        "difficulty": difficulty,
                    },
                    "prompts": {
                        "random_vars_prompt": random_prompt,
                        "init_vars_prompt": init_prompt,
                    },
                    "variables": {
                        "random_vars": random_vars,
                        "init_vars": init_vars
                    },
                    "answers": {
                        "init_answer": init_answer,
                        "random_answer": random_answer
                    }
                })
                self.filter_raw.append(question)
            except Exception:
                # Exception rather than a bare except, so Ctrl-C and
                # SystemExit still stop the run.
                print(f"Error processing question: {base_prompt}")
                print("Error details:")
                print(traceback.format_exc())
                print("##########################################")

        self.save_data()
    def generate_variables(self, arguments: list, init_input = None) -> dict:
        """Assign a value to every argument of a question.

        Args:
            arguments: Argument descriptions as produced by
                ``parse_arguments``; each needs a ``"name"`` entry.
            init_input: Optional sequence of predefined values.  When given,
                values are paired with ``arguments`` by position and surplus
                items on either side are ignored.  When ``None``, a random
                value is drawn for every argument.

        Returns:
            Mapping from argument name to its value.
        """
        if init_input is not None:
            # Positional pairing; zip stops at the shorter of the two.
            return {spec["name"]: value for spec, value in zip(arguments, init_input)}

        generated = {}
        for spec in arguments:
            # Draw the value before touching the name, as one statement per
            # step, so the RNG is consumed in argument order.
            value = self.generate_variable_value(self.parse_variable_info(spec))
            generated[spec["name"]] = value
        return generated
            
    def parse_variable_info(self, arg):
        """Extract the type and numeric bounds of one argument description.

        Args:
            arg: Mapping that may contain ``"type"`` (defaults to
                ``"float"``) and ``"range"`` (defaults to ``"1,1000"``).  The
                range is two comma-separated numbers, optionally wrapped in
                round or square brackets, e.g. ``"(0, inf)"`` or ``"[1,10]"``.

        Returns:
            Dict with the lower-cased ``'type'`` and the float bounds
            ``'lower'`` and ``'upper'``.

        Raises:
            ValueError: If the range does not consist of exactly two numbers.
        """
        var_type = arg.get("type", "float").lower()
        raw_range = arg.get("range", "1,1000").replace(' ', '')

        # Unbounded ends are capped so sampling stays finite: "inf" becomes
        # 1000 and therefore "-inf" becomes -1000.
        raw_range = raw_range.replace("inf", "1000")

        # Accept both "(a,b)" and "[a,b]" interval notation; whether an end
        # is open or closed is not distinguished.
        bounds = raw_range.strip('()[]').split(',')
        if len(bounds) != 2:
            raise ValueError(f"Range must contain exactly two bounds: {arg.get('range')!r}")
        lower, upper = map(float, bounds)

        return {
            'type': var_type,
            'lower': lower,
            'upper': upper
        }
    def generate_variable_value(self, var_info):
        """Draw one random value inside the bounds described by ``var_info``.

        Args:
            var_info: Dict with ``'type'``, ``'lower'`` and ``'upper'`` as
                returned by ``parse_variable_info``.

        Returns:
            For type ``'int'`` a random integer between the bounds (both
            truncated with ``int``, both inclusive); for any other type a
            uniform float rounded to four decimals.
        """
        wants_int = var_info["type"] == 'int'
        low, high = var_info['lower'], var_info['upper']
        if wants_int:
            return random.randint(int(low), int(high))
        return round(random.uniform(low, high), 4)
        
        
    def parse_arguments(self, arguments_per_question):
        """Parse the argument description lines of one question.

        Each line holds ``key: value`` pairs separated by ``;``, for example
        ``"name: a; type: int; range: (1,10)"``.  Pieces without a colon are
        ignored, and only the first colon of a piece separates key from
        value.

        Args:
            arguments_per_question: Iterable of description lines.

        Returns:
            One dict per line that contains at least one ``key: value`` pair,
            in input order.  Lines without any pair (blank lines, a trailing
            newline) are skipped: an empty description has no ``"name"`` and
            would otherwise break variable generation and shift the
            positional pairing with the initial inputs.
        """
        parsed = []
        for line in arguments_per_question:
            info = {}
            for item in line.strip(' ').split(';'):
                if ":" in item:
                    key, value = item.split(":", 1)
                    info[key.strip()] = value.strip()
            if info:
                parsed.append(info)
        return parsed
            
                
                
if __name__ == "__main__":
    # Input dataset: JSON Lines, one question per line (see load_data).
    file_path = "./datasets/SciDA_v1.jsonl"
    # The output must not share the input's path: save_data rewrites
    # save_path as one indented JSON document, which would destroy the
    # source dataset and make it unreadable for load_data on the next run.
    save_path = "./datasets/SciDA_v1_processed.json"

    processor = DataProcessor(file_path, save_path)
    processor.load_data()
    processor.generate_questions()

    
    