import json
from datasets import load_dataset
import random
import os
from collections import defaultdict
import pdb
import cyaron as cy
import sys
import ast
import concurrent.futures


def get_processed_problems(output_file):
    """Collect the questions already recorded in a JSON-lines output file.

    Args:
        output_file: Path to a JSONL file whose records carry a ``question``
            string. The file does not have to exist.

    Returns:
        A set holding the whitespace-stripped ``question`` of every record
        that could be parsed; an empty set when the file is missing.
    """
    processed = set()
    if not os.path.exists(output_file):
        return processed

    with open(output_file, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            try:
                processed.add(json.loads(line)['question'].strip())
            except (ValueError, KeyError, TypeError, AttributeError):
                # Best-effort resume: a truncated or malformed record is
                # ignored. Only the errors such a record can raise are
                # caught, so Ctrl-C / SystemExit still propagate.
                continue
    return processed

def _extract_problem_description(content):
    """Return the problem statement embedded in one ``new_prob`` text."""
    end_marker = "## Part 3:"
    # Markers are tried in priority order; the offset added is the length of
    # the marker that actually matched.
    for marker in ("New_problem:", "New_problem", "## Part 2:"):
        begin = content.find(marker)
        if begin == -1:
            continue
        begin += len(marker)
        end = content.find(end_marker, begin)
        if end == -1:
            # No closing section: keep everything up to the end of the text
            # instead of slicing with -1 (which would drop the last char).
            end = len(content)
        return content[begin:end].strip()
    return "Problem description not found"


def _extract_metadata_field(section, field):
    """Return the text after ``<field>:`` on the first line mentioning it."""
    needle = f"{field}:"
    for row in section.split('\n'):
        # Detection and extraction are both case-insensitive, so a line such
        # as "Difficulty: EASY" is parsed rather than raising.
        pos = row.lower().find(needle)
        if pos != -1:
            return row[pos + len(needle):].strip()
    return ""


def _split_list_field(raw):
    """Turn a ``['a', 'b']``-style string into a list of clean items."""
    cleaned = raw.strip().strip("[]").replace("'", "")
    items = (item.strip().strip('"').strip() for item in cleaned.split(","))
    return [item for item in items if item]


def sample_problems(input_file):
    """Parse generated problems from a JSON-lines file.

    Each non-blank line must be a JSON object with a ``new_prob`` string.
    The problem statement is taken from the text following ``New_problem:``
    (falling back to ``New_problem`` and then ``## Part 2:``) up to
    ``## Part 3:``. Difficulty, tags and skills are read from the section
    starting at ``## Part 4:`` when it exists.

    Args:
        input_file: Path to the JSONL file to read.

    Returns:
        A list of dicts with the keys ``problem_id`` (``new_prob_<line
        index>``), ``question``, ``tags``, ``skills`` and ``difficulty``
        (``"UNKNOWN"`` when absent). Lines that cannot be parsed are
        reported on stdout and skipped.
    """
    problem_data = []

    with open(input_file, 'r', encoding='utf-8') as file:
        for i, line in enumerate(file):
            if not line.strip():
                continue

            try:
                content = json.loads(line)["new_prob"]
                description = _extract_problem_description(content)
                part4_start = content.find("## Part 4:")
                metadata = content[part4_start:] if part4_start != -1 else ""

                difficulty = _extract_metadata_field(metadata, "difficulty")
                tags = _split_list_field(_extract_metadata_field(metadata, "tags"))
                skills = _split_list_field(_extract_metadata_field(metadata, "skills"))
            except (ValueError, KeyError, TypeError, AttributeError) as e:
                print(f"Error parsing line {i}: {str(e)}")
                continue

            problem_data.append({
                "problem_id": f"new_prob_{i}",
                "question": description,
                "tags": tags,
                "skills": skills,
                "difficulty": difficulty if difficulty else "UNKNOWN",
            })

    return problem_data



def generate_input_generators(subset, output_file):
    """Generate an input generator/validator pair for each new problem.

    Problems whose stripped question is already present in ``output_file``
    are skipped, which makes the step resumable. Every successful result is
    appended to ``output_file`` as one JSON line.

    Args:
        subset: Iterable of problem dicts with the keys ``problem_id``,
            ``question``, ``tags``, ``difficulty`` and optionally ``skills``.
        output_file: Path of the JSONL file to append to; it is created
            when missing.

    NOTE(review): ``gen_random_input_generator`` and ``parse_ans`` are not
    defined in the visible part of this file; they must be in scope at
    call time, otherwise every problem is reported as failed.
    """
    # One pass over the output file is enough: any question found there is
    # "processed", so a separate partial-result lookup could never trigger.
    processed = get_processed_problems(output_file)

    if not os.path.exists(output_file):
        with open(output_file, 'a', encoding='utf-8'):
            pass

    total = len(subset)
    for i, problem in enumerate(subset):
        question = problem['question'].strip()

        if question in processed:
            print(f"skip processed: {question[:50]}...")
            continue

        print(f"processing {i+1}/{total}: {question[:50]}...")

        try:
            ans = gen_random_input_generator(problem['question'])
            input_generator, input_validation = parse_ans(ans)

            entry = {
                'problem_id': problem['problem_id'],
                'question': problem['question'],
                'tags': problem['tags'],
                'skills': problem.get('skills', ''),
                'difficulty': problem['difficulty'],
                'input_generator': input_generator,
                'input_validation': input_validation
            }

            # Serialize first so a failure cannot leave a partial line behind.
            serialized = json.dumps(entry, ensure_ascii=False)
            with open(output_file, "a", encoding="utf-8") as file:
                file.write(serialized + "\n")
        except Exception as e:
            # Per-problem boundary around an external generation call:
            # report the failure and move on to the next problem.
            print(f"processing failed: {str(e)}")
            continue

        # Remember the question so a duplicate later in `subset` is not
        # generated (and written) a second time in the same run.
        processed.add(question)



def generate_test_inputs(input_file, output_file, n_inputs=20, max_attempts=30):
    """Run each problem's generated generator/validator to build test inputs.

    For every record of ``input_file`` that is not yet present in
    ``output_file``, the ``input_generator`` and ``input_validation`` source
    strings are executed, ``generate_test_input()`` is called up to
    ``max_attempts`` times, and each result accepted by
    ``validate_test_input()`` is kept until ``n_inputs`` are collected.
    Records with at least one valid input are appended to ``output_file``
    with the inputs stored under ``input_string``.

    Args:
        input_file: JSONL file produced by the generator step.
        output_file: JSONL file to append to; created when missing.
        n_inputs: Number of valid inputs to collect per problem.
        max_attempts: Upper bound on generator calls per problem.

    SECURITY NOTE: the generator/validator sources are executed with
    ``exec`` in this process without any sandboxing. Only run this on code
    you are prepared to execute with full privileges.
    """
    processed = get_processed_problems(output_file)

    input_data = []
    with open(input_file, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip():
                continue
            try:
                input_data.append(json.loads(line))
            except ValueError:
                continue

    if not os.path.exists(output_file):
        with open(output_file, 'a', encoding='utf-8'):
            pass

    for entry in input_data:
        question = entry['question'].strip()

        if question in processed:
            print(f"skip processed: {question[:50]}...")
            continue

        print(f"processing: {entry['problem_id']}...")

        try:
            # Use ONE namespace for globals and locals. With two separate
            # dicts, top-level imports/helpers of the generated code land in
            # the locals dict, while the generated functions resolve names
            # through the globals dict and fail with NameError when called.
            namespace = globals().copy()
            for name in ("generate_test_input", "validate_test_input"):
                # Only accept functions defined by the generated code itself.
                namespace.pop(name, None)

            exec(entry["input_generator"], namespace)
            exec(entry["input_validation"], namespace)

            generate_test_input = namespace.get("generate_test_input")
            validate_test_input = namespace.get("validate_test_input")

            if not generate_test_input or not validate_test_input:
                raise ValueError("cannot extract generator or validator")

            valid_data = []
            for _ in range(max_attempts):
                if len(valid_data) >= n_inputs:
                    break
                try:
                    input_string = generate_test_input()
                    if validate_test_input(input_string):
                        valid_data.append(input_string)
                except (Exception, SystemExit):
                    # Generated code may fail or call sys.exit(); count it as
                    # a failed attempt. KeyboardInterrupt still propagates.
                    continue

            entry['input_string'] = valid_data
            print(len(valid_data))

            if valid_data:
                serialized = json.dumps(entry, ensure_ascii=False)
                with open(output_file, "a", encoding="utf-8") as out_file:
                    out_file.write(serialized + "\n")
                # Avoid regenerating a duplicate question in the same run.
                processed.add(question)

        except (Exception, SystemExit) as e:
            print(f"generate input failed: {str(e)}")
            continue


def _build_brute_force_prompt(question):
    """Return the fixed brute-force-solution prompt for one question."""
    prompt = "You are an expert Python programmer. You will be given a question (problem specification) and will generate a correct Python program that can get an absolutely correct result."
    prompt += "\nQUESTION:\n" + question
    prompt += "\nUse Standard Input format"
    prompt += '\n\nEnsure that when the python program runs, it reads the inputs, runs the algorithm and writes output to STDOUT.\nEnclose your code within delimiters as follows.\n```python\n# YOUR CODE HERE\n```\n\n### Answer:\n\n'
    prompt += '\n\nNote: You do not need to consider time or space complexity. You can generate a brute-force search code to ensure correctness!!!\n\n'
    return prompt


def _load_solution_records(output_file):
    """Map stripped question -> parsed record for an existing solutions file."""
    records = {}
    if not os.path.exists(output_file):
        return records
    with open(output_file, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            try:
                data = json.loads(line)
                records[data['question'].strip()] = data
            except (ValueError, KeyError, TypeError, AttributeError):
                continue
    return records


def _request_samples(messages, count, max_workers):
    """Request ``count`` completions concurrently, in submission order."""
    samples = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(safe_fetch_gpt4, messages) for _ in range(count)]
        for future in futures:
            try:
                samples.append(future.result())
            except Exception as e:
                # A failed request is recorded as an error string so the
                # sample count stays aligned with what was requested.
                print(f"sample generate failed: {str(e)}")
                samples.append(f"Error: {str(e)}")
    return samples


def _replace_solution_record(output_file, question, record):
    """Rewrite ``output_file`` with the record(s) of ``question`` replaced."""
    with open(output_file, 'r', encoding='utf-8') as src:
        lines = src.readlines()

    replacement = json.dumps(record, ensure_ascii=False) + "\n"
    # Write a sibling temp file and swap it in, so an interruption during
    # the rewrite cannot leave the solutions file truncated.
    tmp_path = f"{output_file}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as dst:
        for line in lines:
            try:
                matches = json.loads(line)['question'].strip() == question
            except (ValueError, KeyError, TypeError, AttributeError):
                matches = False
            if matches:
                dst.write(replacement)
            else:
                # Unparseable or unrelated lines are preserved verbatim.
                dst.write(line if line.endswith("\n") else line + "\n")
    os.replace(tmp_path, output_file)


def generate_brute_force_solutions(input_file, output_file, n_samples=5, max_workers=20):
    """Sample brute-force reference solutions for every problem in a file.

    For each record of ``input_file`` the function makes sure that
    ``output_file`` holds ``n_samples`` generated answers for the question.
    A question that already has a record with fewer answers is topped up
    and its line is replaced; a question with enough answers is skipped;
    a new question is appended.

    Args:
        input_file: JSONL file with ``question``, ``tags`` and
            ``input_string`` (as written by the test-input step).
        output_file: JSONL file of solution records; created when missing.
        n_samples: Number of answers wanted per question.
        max_workers: Thread-pool size used for concurrent requests.

    NOTE(review): ``safe_fetch_gpt4`` is not defined in the visible part of
    this file and must be in scope at call time.
    """
    # Keyed by the stripped question, the same key used for lookups below.
    existing_entries = _load_solution_records(output_file)

    if not os.path.exists(output_file):
        with open(output_file, 'a', encoding='utf-8'):
            pass

    with open(input_file, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip():
                continue

            try:
                entry = json.loads(line)
                question = entry['question'].strip()
            except (ValueError, KeyError, TypeError, AttributeError) as e:
                print(f"parse input file line error: {e}")
                continue

            print(f"processing: {question[:50]}...")

            previous = existing_entries.get(question)
            outputs = list(previous.get('output') or []) if previous else []
            if previous:
                print(f"  found partial generated results: existing {len(outputs)} samples")

            samples_needed = n_samples - len(outputs)
            if samples_needed <= 0:
                print(f"  samples enough ({n_samples}), skip")
                continue

            print(f"  need generate {samples_needed} new samples...")

            messages = [{"role": "user", "content": _build_brute_force_prompt(question)}]
            outputs.extend(_request_samples(messages, samples_needed, max_workers))

            results = {
                "problem_id": entry.get('problem_id', ''),
                "question": question,
                "tags": entry.get('tags', []),
                "difficulty": entry.get('difficulty', ''),
                "input_strings": entry.get('input_string', []),
                "output": outputs
            }

            if previous is not None:
                _replace_solution_record(output_file, question, results)
            else:
                with open(output_file, 'a', encoding='utf-8') as out_f:
                    out_f.write(json.dumps(results, ensure_ascii=False) + "\n")

            # Keep the in-memory view current so a duplicate question later
            # in the input is recognised as already complete.
            existing_entries[question] = results
            print(f"  complete problem processing, total {len(results['output'])} samples")

def merge_input_files(file_paths):
    """Merge the ``input_string`` lists of several JSONL files per question.

    Args:
        file_paths: Iterable of paths to JSONL files whose records carry a
            ``question`` and an ``input_string`` list.

    Returns:
        A dict mapping each question to the de-duplicated list of its input
        strings, in first-seen order across the files (so the result is
        deterministic between runs). Records without a question, and lines
        that are not valid JSON objects, are skipped.
    """
    merged_map = {}

    for file_path in file_paths:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    data = json.loads(line)
                except ValueError:
                    print(f"corrupted line in {file_path}, skip")
                    continue
                if not isinstance(data, dict):
                    continue

                key = data.get('question')
                if not key:
                    continue

                # A dict is used as an insertion-ordered set.
                bucket = merged_map.setdefault(key, {})
                for input_string in data.get('input_string') or []:
                    bucket.setdefault(input_string, None)

    return {key: list(bucket) for key, bucket in merged_map.items()}


def _load_input_strings_map(input_files):
    """Build a ``question -> input strings`` mapping from one or more files."""
    if isinstance(input_files, (list, tuple)):
        if len(input_files) > 1:
            return merge_input_files(input_files)
        if not input_files:
            return {}
        # A one-element list is read like a plain path instead of being
        # handed to open() as a list.
        input_files = input_files[0]

    input_strings_map = {}
    with open(input_files, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            try:
                data = json.loads(line)
            except ValueError:
                print('corrupted，skip')
                # Without this the previous line's record would be reused.
                continue
            if not isinstance(data, dict):
                continue
            key = data.get('question')
            if key:
                input_strings_map[key] = data.get('input_string', [])
    return input_strings_map


def _load_validated_questions(output_file):
    """Return the questions already present in an existing result file."""
    processed_questions = set()
    if not os.path.exists(output_file):
        print("no existing output file detected，will start from beginning")
        return processed_questions

    print(f"detected existing output file {output_file}，start resume processing...")
    try:
        with open(output_file, 'r', encoding='utf-8') as out_f:
            for line in out_f:
                if not line.strip():
                    continue
                try:
                    processed_questions.add(json.loads(line)["question"])
                except (ValueError, KeyError, TypeError) as e:
                    print(f"parse existing result line error: {str(e)}")
        print(f"successfully recovered {len(processed_questions)} problem processing progress")
    except (OSError, UnicodeError) as e:
        print(f"read existing output file failed: {str(e)}，will start from beginning")
    return processed_questions


def validate_and_classify_solutions(solution_file, input_files, output_file):
    """Cross-check sampled solutions on every test input and keep agreed outputs.

    For each problem in ``solution_file`` every candidate program is run on
    every input string, and candidates are grouped into "families" by their
    output. Each input is then classified:

    * ``case_a`` - the largest family holds at least half of the
      candidates; its output becomes the expected output.
    * ``case_b`` - there are at least two families, the largest has two or
      more members, and the two largest differ in size by at most two; the
      judge (``let_r1_judge_solution``) picks a family, and its answer is
      cached per combination of competing families within the problem.
    * ``case_c`` - anything else; no expected output is recorded.

    One JSON line per problem is appended to ``output_file`` with the
    inputs that received an expected output (``valid_test_cases``).
    Problems already present in ``output_file`` are skipped, and summary
    statistics are printed at the end.

    Args:
        solution_file: JSONL file with ``question``, ``tags`` and ``output``
            (list of raw model answers) per problem.
        input_files: A path, or a list of paths, to JSONL files holding
            ``question`` and ``input_string``. Several files are merged.
        output_file: JSONL result file, opened in append mode.

    NOTE(review): ``run_code``, ``extract_code`` and ``is_valid_output``
    are not defined in the visible part of this file and must be in scope
    at call time.
    """
    stats = {
        "total_problems": 0,
        "case_a": 0,
        "case_b": 0,
        "case_c": 0,
        "invalid_outputs": 0,
        "total_inputs": 0,
        "valid_inputs_case_a": 0,
        "valid_inputs_case_b": 0
    }

    input_strings_map = _load_input_strings_map(input_files)
    processed_questions = _load_validated_questions(output_file)

    with open(solution_file, 'r', encoding='utf-8') as f, \
         open(output_file, 'a', encoding='utf-8') as out_f:

        for line in f:
            if not line.strip():
                continue

            try:
                gen_entry = json.loads(line)
            except ValueError as e:
                print(f"parse solution file line error: {str(e)}")
                continue
            question = gen_entry["question"]

            if question in processed_questions:
                print(f"skip processed problem: {question[:20]}...")
                continue

            stats["total_problems"] += 1

            if question not in input_strings_map:
                print(f"problem no corresponding input，skip: {question[:50]}...")
                continue

            input_strings = input_strings_map[question]
            code_list = gen_entry["output"]
            stats["total_inputs"] += len(input_strings)

            input_results = []
            case_a_inputs = 0
            case_b_inputs = 0
            problem_invalid_outputs = 0

            # Judge verdicts, keyed by the set of competing families.
            problem_judgment_cache = {}

            for input_idx, input_str in enumerate(input_strings):
                output_families = defaultdict(list)

                with concurrent.futures.ThreadPoolExecutor(max_workers=60) as executor:
                    futures = {}
                    for code_idx, code in enumerate(code_list):
                        # Extract once and carry the result with the future
                        # instead of re-extracting after completion.
                        extracted_code = extract_code(code)
                        future = executor.submit(run_code, extracted_code, input_str, timeout=30)
                        futures[future] = (code_idx, extracted_code)

                    for future in concurrent.futures.as_completed(futures):
                        code_idx, extracted_code = futures[future]
                        try:
                            result = future.result()
                            if is_valid_output(result):
                                output_families[result].append({
                                    "code_idx": code_idx,
                                    "code": extracted_code
                                })
                            else:
                                stats["invalid_outputs"] += 1
                                problem_invalid_outputs += 1
                        except Exception:
                            # A crashing candidate counts as an invalid output.
                            stats["invalid_outputs"] += 1
                            problem_invalid_outputs += 1

                family_sizes = sorted((len(v) for v in output_families.values()), reverse=True)
                input_result = {
                    "input": input_str,
                    "output_families": [
                        {
                            "output": output,
                            "count": len(code_infos),
                            "code_indices": [ci["code_idx"] for ci in code_infos]
                        }
                        for output, code_infos in output_families.items()
                    ],
                    "classification": None,
                    "valid_output": None,
                    "valid_code_indices": []
                }

                if family_sizes and family_sizes[0] >= (len(code_list) / 2):
                    case_a_inputs += 1
                    input_result["classification"] = "case_a"
                    max_family = max(output_families.items(), key=lambda x: len(x[1]))
                    input_result["valid_output"] = max_family[0]
                    input_result["valid_code_indices"] = [ci["code_idx"] for ci in max_family[1]]
                    input_result["valid_codes"] = [ci["code"] for ci in max_family[1]]
                    print('majority')

                elif (len(family_sizes) >= 2 and family_sizes[0] >= 2 and
                      (family_sizes[0] - family_sizes[1]) <= 2):
                    case_b_inputs += 1
                    input_result["classification"] = "case_b"

                    # Every family at least as large as the runner-up competes.
                    top_families = [
                        {
                            "output": output,
                            "count": len(code_infos),
                            "code_indices": [ci["code_idx"] for ci in code_infos],
                            "codes": [ci["code"] for ci in code_infos]
                        }
                        for output, code_infos in output_families.items()
                        if len(code_infos) >= family_sizes[1]
                    ]

                    cache_key = frozenset(tuple(fam['code_indices']) for fam in top_families)

                    if cache_key in problem_judgment_cache:
                        selected_code_idx = problem_judgment_cache[cache_key]
                        print(f"  input {input_idx+1}: use cached R1 judgment (Case B)")
                    else:
                        print(f"  input {input_idx+1}: call R1 for judgment (Case B)")
                        selected_code_idx = let_r1_judge_solution(question, top_families)
                        problem_judgment_cache[cache_key] = selected_code_idx

                    if selected_code_idx is not None:
                        for family in input_result["output_families"]:
                            if selected_code_idx in family["code_indices"]:
                                input_result["valid_output"] = family["output"]
                                input_result["valid_code_indices"] = family["code_indices"]
                                break

                else:
                    input_result["classification"] = "case_c"
                    print(f"  input {input_idx+1}: cannot determine valid output (Case C)")

                input_results.append(input_result)

            if case_a_inputs > 0:
                stats["case_a"] += 1
                stats["valid_inputs_case_a"] += case_a_inputs
                print(f"problem classification: Case A (majority consistent input {case_a_inputs}/{len(input_strings)})")
            elif case_b_inputs > 0:
                stats["case_b"] += 1
                stats["valid_inputs_case_b"] += case_b_inputs
                print(f"problem classification: Case B (need R1 judgment input {case_b_inputs}/{len(input_strings)})")
            else:
                stats["case_c"] += 1
                print(f"problem classification: Case C (no valid input {len(input_strings)}/{len(input_strings)})")

            problem_result = {
                "problem_id": gen_entry.get("problem_id", ""),
                "question": question,
                "tags": gen_entry["tags"],
                "difficulty": gen_entry.get("difficulty", ""),
                "valid_test_cases": [{
                    "input": ir["input"],
                    "output": ir["valid_output"]
                } for ir in input_results if ir["valid_output"] is not None],
            }

            out_f.write(json.dumps(problem_result, ensure_ascii=False) + "\n")
            out_f.flush()
            print(f"completed problem processing and saved result")

    print(json.dumps(stats, indent=2))


def let_r1_judge_solution(question, top_families):
    """Ask the judge model which competing solution family is correct.

    Args:
        question: The problem statement shown to the judge.
        top_families: List of dicts with ``count``, ``codes`` and
            ``code_indices``; the first code of each family is shown as
            that family's option.

    Returns:
        The first code index of the chosen family, or ``None`` when the
        judge answers ``None``, gives no usable option number, or returns
        something that is not text.

    NOTE(review): ``safe_fetch_gpt4`` is not defined in the visible part of
    this file and must be in scope at call time.
    """
    prompt = f"Question: {question}\n\n"
    prompt += "Several solutions were generated for this problem. Please evaluate which solution is most correct.\n\n"

    for i, family in enumerate(top_families):
        prompt += f"\nOption {i+1} (supported by {family['count']} variants):\n"
        prompt += "Code:\n```python\n"
        prompt += family['codes'][0] + "\n```\n"

    prompt += "\nPlease select the most correct solution by responding with just the option number (1, 2, etc.). If none of them are correct, respond with 'None' (exactly this word)."

    messages = [{"role": "user", "content": prompt}]
    response = safe_fetch_gpt4(messages)

    # The fetch helper may hand back a non-string (e.g. None on failure);
    # treat that as "no verdict" instead of crashing on .strip().
    if not isinstance(response, str):
        return None

    verdict = response.strip()
    if verdict.lower() == 'none':
        return None

    try:
        selected = int(verdict)
        if 1 <= selected <= len(top_families):
            return top_families[selected - 1]['code_indices'][0]
    except (ValueError, KeyError, IndexError):
        # Not a bare option number, or the chosen family has no indices.
        pass

    return None

if __name__ == "__main__":
    # Pipeline driver: for each selected problem file run
    # parse -> input generators -> test inputs -> brute-force solutions ->
    # validation, writing every artifact under BASE_PATH/<method_type>/.
    BASE_PATH = "/home/user/code/code_gen/new_prob_json/brute"
    NEW_PROB_DIR = "../new_prob"
    FILE_PREFIX = "new_prob_"
    FILE_SUFFIX = ".json"

    # os.listdir() returns entries in arbitrary order; sort so the slice
    # below always selects the same file.
    prob_files = sorted(
        f for f in os.listdir(NEW_PROB_DIR)
        if f.startswith(FILE_PREFIX) and f.endswith(FILE_SUFFIX)
    )

    # NOTE(review): only the 10th problem file (index 9) is processed; this
    # looks like a manual selection - widen the slice to run every method.
    for prob_file in prob_files[9:10]:
        # Strip exactly the leading prefix and trailing suffix.
        method_type = prob_file[len(FILE_PREFIX):-len(FILE_SUFFIX)]
        print(f"\n{'='*50}")
        print(f"start processing method: {method_type}")
        print(f"{'='*50}")

        method_dir = os.path.join(BASE_PATH, method_type)
        os.makedirs(method_dir, exist_ok=True)

        input_file_path = os.path.join(NEW_PROB_DIR, prob_file)
        generators_path = os.path.join(method_dir, "input_generators.json")
        test_inputs_path = os.path.join(method_dir, "test_inputs.json")
        solutions_path = os.path.join(method_dir, "brute_solutions.json")
        validated_path = os.path.join(method_dir, "validated_cases.json")

        print(f"step1: parse problem file {prob_file}")
        subset = sample_problems(input_file_path)

        print("step2.1: generate input generators")
        generate_input_generators(subset, generators_path)

        print("step2.2: generate test inputs")
        generate_test_inputs(generators_path, test_inputs_path)

        print("step3: generate brute force solutions")
        generate_brute_force_solutions(test_inputs_path, solutions_path)

        print("step4: validate solutions")
        validate_and_classify_solutions(solutions_path, test_inputs_path, validated_path)

        print(f"\n{'='*50}")
        print(f"completed processing method: {method_type}")
        print(f"{'='*50}\n")