import openai
import concurrent.futures
import time
import pandas as pd
import json
import random
import requests
from typing import List, Dict, Any
import os
import tqdm
import re
import ast
import fcntl
import argparse
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

def parse_args():
    """Parse the command-line options for the sampling + validation run.

    Returns:
        argparse.Namespace with ``hosts`` (list[str]), ``port`` (int),
        ``num_workers`` (int), ``sample_num`` (int), ``input_files`` (list[str]),
        ``timeout`` (int, seconds) and ``max_retries`` (int).
    """
    parser = argparse.ArgumentParser(description='运行推理和验证任务')
    # (flag, add_argument keyword options), registered in this order.
    option_specs = [
        ('--hosts', dict(type=str, nargs='+', required=True,
                         help='服务器host列表，例如: --hosts 0.0.0.0 1.2.3.4 5.6.7.8')),
        ('--port', dict(type=int, default=30000,
                        help='服务端口号，默认为30000')),
        ('--num-workers', dict(type=int, default=64,
                               help='并发worker数量，默认为64')),
        ('--sample-num', dict(type=int, default=100,
                              help='每个样本的采样次数，默认为100')),
        ('--input-files', dict(type=str, nargs='+', required=True,
                               help='输入的parquet文件路径列表，例如: --input-files file1.parquet file2.parquet')),
        ('--timeout', dict(type=int, default=30,
                           help='请求超时时间（秒），默认为30秒')),
        ('--max-retries', dict(type=int, default=3,
                               help='请求失败最大重试次数，默认为3次')),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()

# Global configuration, parsed from the command line once at import time.
# NOTE(review): because parsing happens at import, importing this module from
# elsewhere also requires the CLI arguments to be present.
args = parse_args()
port = args.port
hosts = args.hosts

# Retry policy mounted on every judge session (see create_session).
# urllib3's Retry only retries "idempotent" verbs by default, and POST is not one
# of them. The judge endpoint is called with POST, so without allowed_methods the
# status_forcelist below would never trigger for it.
# Assumes re-submitting a judge batch is safe (the judge only evaluates code) — TODO confirm.
retry_strategy = Retry(
    total=args.max_retries,
    backoff_factor=1,
    status_forcelist=[429, 500, 502, 503, 504],
    allowed_methods=None,  # None = retry on any HTTP method, including POST
)

# Build a requests session whose HTTP(S) transports use the shared retry policy.
def create_session():
    """Return a new ``requests.Session`` with ``retry_strategy`` mounted.

    The same retrying ``HTTPAdapter`` is mounted on both the ``http://`` and
    ``https://`` prefixes. The caller owns the session and should close it.
    """
    retrying_adapter = HTTPAdapter(max_retries=retry_strategy)
    sess = requests.Session()
    for prefix in ("http://", "https://"):
        sess.mount(prefix, retrying_adapter)
    return sess

def get_random_client():
    """Return a new OpenAI-compatible client bound to a randomly chosen host.

    A fresh client is constructed on every call. It targets
    ``http://<host>:<port>/v1``, where host comes from ``hosts``, and uses
    ``args.timeout`` as its request timeout.
    """
    chosen_host = random.choice(hosts)
    endpoint = f"http://{chosen_host}:{port}/v1"
    return openai.Client(base_url=endpoint, api_key="None", timeout=args.timeout)

class ValidationManager:
    """Judge generated Python code against a sample's test cases on a remote judge service.

    Submissions are built from ``sample["input_output"]``. They are POSTed as one
    batch to ``http://<host>:8005/judge/long-batch``, where the host is picked at
    random from ``hosts``.
    """

    def __init__(self, run_all_cases=True):
        # Stored for callers; no method of this class currently reads it.
        self.run_all_cases = run_all_cases

    def validate_response(self, code_str: str, sample: dict) -> bool:
        """Return True only if every submission built for ``sample`` passes on the judge.

        An empty result list counts as a failure, not as a vacuous pass.
        """
        submissions = self.build_oj_submissions(code_str, sample)
        success_list = self.submit_batch(submissions)
        return bool(success_list) and all(success_list)

    @staticmethod
    def _build_submission(code_str, output_str=None, input_str=None) -> dict:
        """Build one judge submission.

        ``expected_output`` and ``input`` are included only when they are given.
        """
        submission = {
            "type": "python",
            "solution": code_str,
        }
        if output_str is not None:
            submission["expected_output"] = output_str
        if input_str is not None:
            submission["input"] = input_str
        return submission

    @staticmethod
    def _parse_input_output(raw):
        """Decode ``raw`` into a dict, or return None if that is not possible.

        ``raw`` may already be a dict, or a JSON string that may be encoded up to
        three times. As a last resort, a string that looks like ``{...}`` is parsed
        as a Python literal with ``ast.literal_eval``.
        """
        input_output = raw
        parse_attempts = 0
        max_attempts = 3  # the field is sometimes JSON-encoded more than once

        while isinstance(input_output, str) and parse_attempts < max_attempts:
            try:
                input_output = json.loads(input_output)
                parse_attempts += 1
            except json.JSONDecodeError as e:
                print(f"Warning: Failed to parse input_output as JSON (attempt {parse_attempts + 1}): {e}")
                break

        if isinstance(input_output, dict):
            return input_output

        print(f"Warning: input_output must be a dict after {parse_attempts} parse attempts, got {type(input_output)}")
        # Last resort: a Python-literal dict representation (e.g. single-quoted keys).
        if isinstance(input_output, str):
            stripped = input_output.strip()
            if stripped.startswith('{') and stripped.endswith('}'):
                try:
                    candidate = ast.literal_eval(input_output)
                except Exception:  # malformed literal: fall back to a code-only submission
                    return None
                if isinstance(candidate, dict):
                    return candidate
        return None

    def build_oj_submissions(self, code_str: str, sample: dict, sample_size: int = 10) -> list:
        """Build judge submissions for ``code_str`` from ``sample["input_output"]``.

        Two layouts are supported:
          * ``{"test_cases": [...]}``: dict entries contribute their stringified
            ``input``/``output``; any other entry is appended to the code on a new line.
          * ``{"inputs": [...], "outputs": [...]}``: paired element-wise and stringified.
        At most ``sample_size`` cases are kept, chosen with ``random.sample``.

        When the test data is missing or malformed, or anything raises, a single
        code-only submission is returned. The result is therefore never empty.

        Args:
            code_str: Python source to submit.
            sample: mapping containing ``"input_output"``. The value may be a dict,
                a JSON string (possibly encoded several times), or a Python-literal
                dict string.
            sample_size: maximum number of test cases kept (default 10).

        Returns:
            list of submission dicts for ``submit_batch``.
        """
        build_submission = self._build_submission
        submissions = []

        try:
            if sample["input_output"] is None:
                print("Warning: input_output is None")
                return [build_submission(code_str)]

            input_output = self._parse_input_output(sample["input_output"])
            if input_output is None:
                return [build_submission(code_str)]

            if "test_cases" in input_output:
                test_cases = input_output["test_cases"]
                if not isinstance(test_cases, list):
                    print(f"Warning: test_cases must be a list, got {type(test_cases)}")
                    # test_cases may itself be a JSON-encoded list
                    if isinstance(test_cases, str):
                        try:
                            test_cases = json.loads(test_cases)
                            if not isinstance(test_cases, list):
                                print("Warning: Parsed test_cases is not a list")
                                return [build_submission(code_str)]
                        except json.JSONDecodeError:
                            print("Warning: Failed to parse test_cases as JSON")
                            return [build_submission(code_str)]
                    else:
                        return [build_submission(code_str)]

                if not test_cases:
                    print("Warning: test_cases is empty")
                    return [build_submission(code_str)]

                if len(test_cases) > sample_size:
                    test_cases = random.sample(test_cases, sample_size)

                for test_case in test_cases:
                    if isinstance(test_case, dict):
                        input_str = str(test_case.get('input', ''))
                        output_str = str(test_case.get('output', ''))
                        submissions.append(build_submission(code_str, output_str, input_str))
                    else:
                        # Non-dict cases (e.g. assert statements) are appended to the code.
                        submissions.append(build_submission(code_str + "\n" + str(test_case)))

            else:
                inputs = input_output.get("inputs", [])
                outputs = input_output.get("outputs", [])

                if not isinstance(inputs, list) or not isinstance(outputs, list):
                    print("Warning: inputs or outputs is not a list")
                    return [build_submission(code_str)]

                if not inputs or not outputs or len(inputs) != len(outputs):
                    print("Warning: inputs and outputs must be non-empty and have same length")
                    return [build_submission(code_str)]

                if len(inputs) > sample_size:
                    # Sample indices (not values) so input/output pairs stay aligned.
                    indices = random.sample(range(len(inputs)), sample_size)
                    inputs = [inputs[i] for i in indices]
                    outputs = [outputs[i] for i in indices]

                for input_data, output_data in zip(inputs, outputs):
                    submissions.append(build_submission(code_str, str(output_data), str(input_data)))

        except Exception as e:
            import traceback
            print(f"Error building submissions: {e}")
            print(traceback.format_exc())
            return [build_submission(code_str)]

        if not submissions:
            print("Warning: No submissions were created")
            return [build_submission(code_str)]

        return submissions

    def submit_batch(self, submissions: list) -> list:
        """POST ``submissions`` as one batch to the judge; return per-submission success flags.

        On any transport error or malformed response, every submission is reported
        as failed: the return value is ``[False] * len(submissions)``.
        """
        data = {
            "type": "batch",
            "submissions": submissions,
        }

        try:
            # Spread load across the configured hosts; the judge listens on port 8005.
            selected_host = random.choice(hosts)
            with create_session() as session:
                response = session.post(
                    f"http://{selected_host}:8005/judge/long-batch",
                    json=data,
                    timeout=args.timeout,
                )
                response.raise_for_status()
                results = response.json()['results']

            success_list = [res['success'] for res in results]
            # Explicit check rather than assert: asserts vanish under -O, and a
            # short or empty list would make all(...) in validate_response vacuously True.
            if len(success_list) != len(submissions):
                raise ValueError(
                    f"judge returned {len(success_list)} results for {len(submissions)} submissions"
                )
            return success_list
        except requests.exceptions.RequestException as e:
            print(f"Request failed: {e}")
            return [False] * len(submissions)
        except (ValueError, KeyError, TypeError) as e:
            print(f"Failed to process response: {e}")
            return [False] * len(submissions)

# First fenced block: an optional, case-insensitive "python" tag after the opening
# fence is skipped, and whitespace next to either fence is kept out of the capture.
_FENCED_CODE_RE = re.compile(r"```(?:python)?\s*([\s\S]*?)\s*```", re.IGNORECASE)


def sanitize(text: str) -> str:
    """Extract the body of the first ``` fenced code block in ``text``.

    Returns the stripped contents of that block. If ``text`` contains no
    complete fenced block, it is returned unchanged.
    """
    found = _FENCED_CODE_RE.search(text)
    return found.group(1).strip() if found else text

def check_ce(code_str: str) -> bool:
    """Return True if ``code_str`` is a compile error ("CE"), i.e. not parseable Python.

    Non-string input (for example ``None``) is also treated as a compile error.
    """
    if not isinstance(code_str, str):
        return True
    try:
        ast.parse(code_str)
    except (SyntaxError, ValueError, MemoryError, RecursionError):
        # SyntaxError also covers IndentationError/TabError. Older Pythons raise
        # ValueError for null bytes. Deeply nested input can raise MemoryError or
        # RecursionError. Unlike the previous bare except, this no longer swallows
        # KeyboardInterrupt/SystemExit.
        return True
    return False

def _jsonl_default(obj):
    """``json.dumps`` hook: convert objects exposing ``tolist()`` (NumPy scalars/arrays) to builtins."""
    tolist = getattr(obj, "tolist", None)
    if callable(tolist):
        return tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def save_result_to_jsonl(result: Dict, output_file: str):
    """Append ``result`` as one JSON line to ``output_file``.

    Worker threads call this concurrently on the same file. The record is therefore
    serialized up front and written with a single ``write`` while holding an exclusive
    ``fcntl.flock``, so lines from different writers cannot interleave. Streaming
    ``json.dump`` issues many small writes, and large records would otherwise flush
    in several chunks. Values exposing ``tolist()``, such as NumPy scalars from
    DataFrame rows, are converted to builtins.
    """
    line = json.dumps(result, ensure_ascii=False, default=_jsonl_default) + '\n'
    with open(output_file, 'a') as f:
        # flock locks belong to the open file description, so separate open() calls
        # (one per thread or process) exclude each other.
        fcntl.flock(f.fileno(), fcntl.LOCK_EX)
        try:
            f.write(line)
            f.flush()  # push the bytes out before releasing the lock
        finally:
            fcntl.flock(f.fileno(), fcntl.LOCK_UN)

def process_single_sample(row: pd.Series, validator: ValidationManager, output_file: str, sample_num: int = 100) -> None:
    """Draw ``sample_num`` completions for one dataset row, validate each, and append the results.

    NOTE(review): nothing in this file calls this function; process_parquet_file
    uses process_remaining_samples, which also supports resuming.

    Args:
        row: dataset row; reads 'id', 'prompt', 'input_output' and 'prompter_type'.
        validator: ValidationManager used to judge generations that parse as Python.
        output_file: JSONL file each result is appended to as soon as it is ready.
        sample_num: number of completions to draw for this row.

    A failure in one sample is printed and that sample is skipped (no record is
    written). A failure outside the per-sample loop is printed and ends this row.
    """
    try:
        # The with-block closes the per-row progress bar even if the row fails.
        with tqdm.tqdm(
            range(sample_num),
            desc=f"Sample {row['id']}",
            position=1,
            leave=False,
        ) as pbar:
            # The prompt may arrive as an ndarray. Convert it to a list so it can be
            # passed as chat messages and serialized with json.
            prompt = row['prompt'].tolist() if hasattr(row['prompt'], 'tolist') else row['prompt']

            for sample_idx in pbar:
                try:
                    # A randomly chosen host/client is used for every request.
                    client = get_random_client()
                    response = client.chat.completions.create(
                        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
                        messages=prompt,
                        temperature=1.0,
                        max_tokens=4096,
                    )
                    if not response.choices:
                        continue
                    generation = response.choices[0].message.content
                    generation_code = sanitize(generation)

                    # Code that fails to parse (CE) is recorded as a failure without calling the judge.
                    is_valid = False
                    if not check_ce(generation_code):
                        is_valid = validator.validate_response(generation_code, {
                            "input_output": row['input_output'],
                            "prompter_type": row['prompter_type'],
                        })

                    result = {
                        "id": row['id'],
                        "result": is_valid,
                        "prompt": prompt,
                        "answer": generation_code,
                        "sample_index": sample_idx,
                    }

                    # Save immediately so progress survives an interruption.
                    save_result_to_jsonl(result, output_file)

                    pbar.set_description(f"Sample {row['id']} - Success: {is_valid}")

                except Exception as e:
                    import traceback
                    print(f"Error in sample {sample_idx} for id {row['id']}: {e}")
                    print("Debug - Full traceback:")
                    print(traceback.format_exc())
                    continue

    except Exception as e:
        import traceback
        print(f"Error processing sample with id {row['id']}: {e}")
        print("Debug - Full traceback:")
        print(traceback.format_exc())

def _results_path(file_path: str) -> str:
    """Return the JSONL results path that belongs to a parquet input file.

    ``x.parquet`` becomes ``x_results_14b.jsonl``. Only the final extension is
    replaced, so directory names are left alone. An input without a ``.parquet``
    extension gets the suffix appended, so the results file can never be the
    input file itself.
    """
    root, ext = os.path.splitext(file_path)
    if ext == '.parquet':
        return root + '_results_14b.jsonl'
    return file_path + '_results_14b.jsonl'


def _load_processed_records(output_file: str) -> dict:
    """Map each sample id to the set of ``sample_index`` values already present in ``output_file``.

    Returns an empty dict when the file does not exist. Lines that cannot be parsed
    or lack ``id``/``sample_index`` are reported and skipped; an interrupted run can
    leave such a partial last line.
    """
    processed_records = {}
    if not os.path.exists(output_file):
        return processed_records
    with open(output_file, 'r') as f:
        for line in f:
            try:
                record = json.loads(line)
                processed_records.setdefault(record['id'], set()).add(record['sample_index'])
            except (ValueError, KeyError, TypeError) as e:
                print(f"Error reading record: {e}")
                continue
    return processed_records


def process_parquet_file(file_path: str, num_workers: int = 100, sample_num: int = 100):
    """Sample and validate ``sample_num`` completions for every row of a parquet file.

    Results are appended to ``_results_path(file_path)``. Sample indices already
    recorded there are skipped, so an interrupted run can be resumed. Rows are
    processed concurrently by ``num_workers`` threads, each running
    ``process_remaining_samples`` for one row.
    """
    df = pd.read_parquet(file_path)
    validator = ValidationManager(run_all_cases=True)

    output_file = _results_path(file_path)
    processed_records = _load_processed_records(output_file)

    all_indices = set(range(sample_num))
    total_samples = len(df) * sample_num
    # Count only work that belongs to this run: ids present in df and indices < sample_num.
    # Counting every record in the file could push the bar past its total.
    completed_samples = sum(
        len(processed_records.get(sample_id, set()) & all_indices)
        for sample_id in df['id']
    )

    with tqdm.tqdm(
        total=total_samples,
        desc="Processing samples",
        position=0,
        initial=completed_samples,
    ) as main_pbar:
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = []
            for _, row in df.iterrows():
                completed_indices = processed_records.get(row['id'], set())
                remaining_indices = sorted(all_indices - completed_indices)
                if not remaining_indices:
                    continue
                # process_remaining_samples reads the work list from the row itself.
                row_with_info = row.copy()
                row_with_info['remaining_indices'] = remaining_indices
                futures.append(
                    executor.submit(process_remaining_samples,
                                    row_with_info,
                                    validator,
                                    output_file,
                                    main_pbar)
                )

            for future in concurrent.futures.as_completed(futures):
                try:
                    future.result()
                except Exception as e:
                    print(f"Error in future: {e}")

def process_remaining_samples(row: pd.Series, validator: ValidationManager, output_file: str, main_pbar: tqdm.tqdm) -> None:
    """Draw completions for a row's not-yet-done sample indices, validate them, and save each one.

    Args:
        row: dataset row. Reads 'id', 'prompt', 'input_output' and 'prompter_type',
            plus the 'remaining_indices' entry (list of sample indices) that
            process_parquet_file adds.
        validator: ValidationManager used to judge generations that parse as Python.
        output_file: JSONL file each result is appended to as soon as it is ready.
        main_pbar: shared overall progress bar, advanced by one per saved result.

    A failure in one sample is printed and that sample is skipped. No record is
    written, so a later resumed run retries it. A failure outside the per-sample
    loop is printed and ends this row.
    """
    try:
        remaining_indices = row['remaining_indices']

        # The with-block closes the per-row progress bar even if the row fails.
        with tqdm.tqdm(
            remaining_indices,
            desc=f"Sample {row['id']}",
            position=1,
            leave=False,
        ) as pbar:
            # The prompt may arrive as an ndarray. Convert it to a list so it can be
            # passed as chat messages and serialized with json.
            prompt = row['prompt'].tolist() if hasattr(row['prompt'], 'tolist') else row['prompt']

            for sample_idx in pbar:
                try:
                    # A randomly chosen host/client is used for every request.
                    client = get_random_client()
                    response = client.chat.completions.create(
                        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
                        messages=prompt,
                        temperature=1.0,
                        max_tokens=4096,
                    )
                    if not response.choices:
                        continue
                    generation = response.choices[0].message.content
                    generation_code = sanitize(generation)

                    # Code that fails to parse (CE) is recorded as a failure without calling the judge.
                    is_valid = False
                    if not check_ce(generation_code):
                        is_valid = validator.validate_response(generation_code, {
                            "input_output": row['input_output'],
                            "prompter_type": row['prompter_type'],
                        })

                    result = {
                        "id": row['id'],
                        "result": is_valid,
                        "prompt": prompt,
                        "answer": generation_code,
                        "sample_index": sample_idx,
                    }

                    # Save immediately so a resumed run can skip this index.
                    save_result_to_jsonl(result, output_file)

                    main_pbar.update(1)
                    pbar.set_description(f"Sample {row['id']} - Success: {is_valid}")

                except Exception as e:
                    import traceback
                    print(f"Error in sample {sample_idx} for id {row['id']}: {e}")
                    print(traceback.format_exc())
                    continue

    except Exception as e:
        import traceback
        print(f"Error processing sample with id {row['id']}: {e}")
        print(traceback.format_exc())

def main():
    """Run sampling + validation over each ``--input-files`` parquet file, one after another."""
    file_bar = tqdm.tqdm(args.input_files, desc="Processing files")
    for path in file_bar:
        print(f"\nProcessing {path}")
        process_parquet_file(
            path,
            num_workers=args.num_workers,
            sample_num=args.sample_num,
        )


if __name__ == "__main__":
    main()


