import ast
import concurrent.futures
import json
import os
import pdb
import random
import re
import sys
from collections import defaultdict

import cyaron as cy
from datasets import load_dataset
from openai import OpenAI

from utils import parse_ans, gen_adv_input_generator, safe_fetch_gpt4, fetch_r1, extract_code, run_code, is_valid_output, find_equivalent_groups, outputs_match


def get_processed_problems(output_file):
    """Collect the questions already recorded in a JSONL output file.

    Args:
        output_file: Path to a file holding one JSON object per line.

    Returns:
        A set with the whitespace-stripped ``question`` value of every
        well-formed record. The set is empty when the file does not exist.
    """
    processed = set()
    if not os.path.exists(output_file):
        return processed

    with open(output_file, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            try:
                processed.add(json.loads(line)['question'].strip())
            except (json.JSONDecodeError, KeyError, TypeError, AttributeError):
                # Best effort: a line that is not valid JSON, lacks a
                # 'question' key, or holds a non-string question is ignored.
                # Only these errors are caught so Ctrl-C still interrupts.
                continue
    return processed

def _extract_problem_description(content):
    """Return the problem statement found in *content*, or a placeholder."""
    end_marker = "## Part 3:"
    # Markers are tried in priority order; each one is skipped by its own
    # length so the colon-less fallback does not swallow an extra character.
    for marker in ("New_problem:", "New_problem", "## Part 2:"):
        start = content.find(marker)
        if start == -1:
            continue
        start += len(marker)
        end = content.find(end_marker, start)
        if end == -1:
            # No closing marker: keep everything up to the end of the text
            # rather than slicing with -1, which would drop the last char.
            end = len(content)
        return content[start:end].strip()
    return "Problem description not found"


def _find_metadata_value(section, key):
    """Return the text after the first case-insensitive ``key:`` in *section*."""
    pattern = re.compile(re.escape(key) + ":", re.IGNORECASE)
    for raw_line in section.split('\n'):
        match = pattern.search(raw_line)
        if match:
            return raw_line[match.end():].strip()
    return ""


def _split_list_value(value):
    """Turn a ``['a', 'b']``-style string into a list of non-empty items."""
    cleaned = value.strip().strip("[]").replace("'", "").replace('"', "")
    items = (item.strip() for item in cleaned.split(","))
    return [item for item in items if item]


def sample_problems(input_file):
    """Parse generated problems from a JSONL file.

    Each non-blank line must be a JSON object whose ``new_prob`` field holds
    the generated text. The problem statement is taken from the text after
    ``New_problem:`` (falling back to ``New_problem`` and then ``## Part 2:``)
    up to ``## Part 3:``. Difficulty, tags and skills are read from the
    section that starts at ``## Part 4:``, matching the field names
    case-insensitively.

    Args:
        input_file: Path to the JSONL file to read.

    Returns:
        A list of dicts with the keys ``problem_id``, ``question``, ``tags``,
        ``skills`` and ``difficulty``. ``problem_id`` embeds the zero-based
        line index of the record. Lines that cannot be parsed are reported on
        stdout and skipped.
    """
    problem_data = []

    with open(input_file, 'r', encoding='utf-8') as file:
        for i, line in enumerate(file):
            if not line.strip():
                continue

            try:
                entry = json.loads(line)
                new_prob_content = entry["new_prob"]

                problem_description = _extract_problem_description(new_prob_content)

                difficulty = ""
                tags = []
                skills = []
                part4_start = new_prob_content.find("## Part 4:")
                if part4_start != -1:
                    part4_content = new_prob_content[part4_start:]
                    difficulty = _find_metadata_value(part4_content, "difficulty")
                    tags = _split_list_value(_find_metadata_value(part4_content, "tags"))
                    skills = _split_list_value(_find_metadata_value(part4_content, "skills"))

                problem_data.append({
                    "problem_id": f"new_prob_{i}",
                    "question": problem_description,
                    "tags": tags,
                    "skills": skills,
                    "difficulty": difficulty if difficulty else "UNKNOWN"
                })

            except Exception as e:
                # Per-record boundary: report the bad line and keep going.
                print(f"Error parsing line {i}: {str(e)}")
                continue

    return problem_data


def generate_input_generators(subset, output_file):
    """Generate an input generator and validator for each new problem.

    For every problem whose stripped question is not yet present in
    ``output_file``, ``gen_adv_input_generator`` is called and its answer is
    split by ``parse_ans`` into generator and validator code. One JSON record
    per problem is appended to ``output_file``.

    Args:
        subset: Iterable of problem dicts with at least ``problem_id``,
            ``question``, ``tags`` and ``difficulty`` (``skills`` optional).
        output_file: JSONL file that results are appended to. It is created
            empty when missing.
    """
    processed = get_processed_problems(output_file)

    # Make sure the file exists even if nothing new gets generated.
    if not os.path.exists(output_file):
        with open(output_file, 'w', encoding='utf-8'):
            pass

    for i, problem in enumerate(subset):
        question = problem['question'].strip()

        if question in processed:
            print(f"Skipping processed problem: {question[:50]}...")
            continue

        print(f"Processing problem {i+1}/{len(subset)}: {question[:50]}...")

        try:
            ans = gen_adv_input_generator(problem['question'])
            input_generator, input_validation = parse_ans(ans)

            entry = {
                'problem_id': problem['problem_id'],
                'question': problem['question'],
                'tags': problem['tags'],
                'skills': problem.get('skills', ''),
                'difficulty': problem['difficulty'],
                'input_generator': input_generator,
                'input_validation': input_validation
            }

            # Serialize first so a failure cannot leave a half-written line.
            serialized = json.dumps(entry, ensure_ascii=False)
            with open(output_file, "a", encoding="utf-8") as file:
                file.write(serialized + "\n")

        except Exception as e:
            print(f"Failed to process problem: {str(e)}")
            continue

        # Remember the question so a duplicate later in `subset` is skipped
        # instead of being generated and written a second time.
        processed.add(question)



def generate_test_inputs(input_file, output_file, n_inputs=20, max_attempts=30):
    """Run each stored input generator and keep the inputs that validate.

    For every record in ``input_file`` whose stripped question is not yet in
    ``output_file``, the record's ``input_generator`` and ``input_validation``
    code is executed. ``generate_test_input()`` is then called up to
    ``max_attempts`` times, and each result accepted by
    ``validate_test_input`` is collected until ``n_inputs`` are available.
    Records with at least one valid input are appended to ``output_file``
    with the inputs stored under ``input_string``.

    Args:
        input_file: JSONL file produced by the generator step.
        output_file: JSONL file that results are appended to. It is created
            empty when missing.
        n_inputs: Number of valid inputs to collect per problem.
        max_attempts: Upper bound on generator calls per problem.
    """
    processed = get_processed_problems(output_file)

    input_data = []
    with open(input_file, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip():
                continue
            try:
                input_data.append(json.loads(line))
            except json.JSONDecodeError:
                continue

    if not os.path.exists(output_file):
        with open(output_file, 'w', encoding='utf-8'):
            pass

    for entry in input_data:
        try:
            question = entry['question'].strip()

            if question in processed:
                print(f"Skipping processed problem: {question[:50]}...")
                continue

            print(f"Processing problem: {entry['problem_id']}...")

            # SECURITY: this executes code stored in the input file with full
            # interpreter privileges. Only run it on trusted files or inside
            # a sandbox.
            #
            # A single namespace is used on purpose. With separate globals
            # and locals, exec() runs the code as if it were a class body, so
            # imports and helpers defined at the top level of the generated
            # code would be invisible inside its functions. Copying this
            # module's globals keeps the module-level imports available.
            namespace = globals().copy()
            exec(entry["input_generator"], namespace)
            exec(entry["input_validation"], namespace)

            generate_test_input = namespace.get("generate_test_input")
            validate_test_input = namespace.get("validate_test_input")

            if not callable(generate_test_input) or not callable(validate_test_input):
                raise ValueError("Failed to extract generator or validator function")

            valid_data = []
            for _ in range(max_attempts):
                if len(valid_data) >= n_inputs:
                    break
                try:
                    input_string = generate_test_input()
                    if validate_test_input(input_string):
                        valid_data.append(input_string)
                except (Exception, SystemExit):
                    # A failing attempt, including generated code calling
                    # exit(), only costs one attempt. KeyboardInterrupt still
                    # propagates.
                    continue

            entry['input_string'] = valid_data
            print(len(valid_data))

            if valid_data:
                # Serialize first so a failure cannot leave a partial line.
                serialized = json.dumps(entry, ensure_ascii=False)
                with open(output_file, "a", encoding="utf-8") as out_file:
                    out_file.write(serialized + "\n")
                processed.add(question)

        except Exception as e:
            print(f"Failed to generate input: {str(e)}")
            continue



_RECORD_ERRORS = (json.JSONDecodeError, KeyError, TypeError, AttributeError)


def _load_solution_entries(output_file):
    """Map each stripped question in *output_file* to its parsed record."""
    entries = {}
    if not os.path.exists(output_file):
        return entries
    with open(output_file, 'r', encoding='utf-8') as f:
        for raw in f:
            if not raw.strip():
                continue
            try:
                data = json.loads(raw)
                entries[data['question'].strip()] = data
            except _RECORD_ERRORS:
                continue
    return entries


def _store_solution_entry(output_file, results, replace_existing):
    """Append *results* to *output_file*, or replace its existing record."""
    serialized = json.dumps(results, ensure_ascii=False) + "\n"

    if not replace_existing:
        with open(output_file, 'a', encoding='utf-8') as out_f:
            out_f.write(serialized)
        return

    # Rewrite through a temporary file and swap it in, so an interruption
    # cannot leave the output truncated. Lines that are not valid records
    # are copied unchanged.
    tmp_path = output_file + ".tmp"
    with open(output_file, 'r', encoding='utf-8') as src, \
         open(tmp_path, 'w', encoding='utf-8') as dst:
        for raw in src:
            try:
                is_match = json.loads(raw)['question'].strip() == results['question']
            except _RECORD_ERRORS:
                is_match = False
            dst.write(serialized if is_match else raw)
    os.replace(tmp_path, output_file)


def generate_optimized_solutions(input_file, output_file, n_samples=10, max_workers=20):
    """Sample candidate solutions for every problem in ``input_file``.

    Each problem is sent to ``safe_fetch_gpt4`` until ``n_samples`` outputs
    are stored for it in ``output_file``. A problem that already has enough
    outputs is skipped. A problem with fewer outputs (for example after
    ``n_samples`` was raised) is topped up and its record is replaced. A
    failed request is stored as an ``"Error: ..."`` string and counts as a
    sample.

    Args:
        input_file: JSONL file with ``question`` (and optionally
            ``problem_id``, ``tags``, ``difficulty``, ``input_string``).
        output_file: JSONL file holding one result record per question. It is
            created empty when missing.
        n_samples: Number of outputs wanted per problem.
        max_workers: Thread-pool size for concurrent requests.
    """
    if not os.path.exists(output_file):
        with open(output_file, 'w', encoding='utf-8'):
            pass

    existing_entries = _load_solution_entries(output_file)

    with open(input_file, 'r', encoding='utf-8') as file:
        for raw in file:
            if not raw.strip():
                continue

            try:
                entry = json.loads(raw)
                question = entry['question'].strip()
            except _RECORD_ERRORS as e:
                print(f"Skipping malformed input line: {e}")
                continue

            existing_data = existing_entries.get(question)
            has_existing = existing_data is not None
            previous_outputs = list(existing_data.get('output') or []) if has_existing else []

            if has_existing and len(previous_outputs) >= n_samples:
                print(f"Skipping processed problem: {question[:50]}...")
                continue

            print(f"Processing problem: {question[:50]}...")
            if has_existing:
                print(f"  Found partial generation results: existing {len(previous_outputs)} samples")

            samples_needed = n_samples - len(previous_outputs)
            if samples_needed <= 0:
                print(f"  Samples already sufficient ({n_samples}), skipping generation")
                continue

            prompt = "You are an expert Python programmer. You will be given a question (problem specification) and will generate a correct Python program that matches the specification and passes all tests."
            prompt += "\nQUESTION:\n" + question
            prompt += '\nEnclose your code within delimiters as follows:\n```python\n# YOUR CODE HERE\n```\n'

            messages = [{"role": "user", "content": prompt}]

            results = {
                "problem_id": entry.get('problem_id', ''),
                "question": question,
                "tags": entry.get('tags', []),
                "difficulty": entry.get('difficulty', ''),
                "input_strings": entry.get('input_string', []),
                "output": previous_outputs
            }

            print(f"  Need to generate {samples_needed} new samples...")

            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = [
                    executor.submit(safe_fetch_gpt4, messages)
                    for _ in range(samples_needed)
                ]
                # Collect in submission order so the stored sample order is
                # stable regardless of which request finishes first.
                for future in futures:
                    try:
                        results['output'].append(future.result())
                    except Exception as e:
                        print(f"Sample generation failed: {str(e)}")
                        results['output'].append(f"Error: {str(e)}")

            _store_solution_entry(output_file, results, replace_existing=has_existing)
            # Track the new record so a duplicate question later in the input
            # is recognised as done.
            existing_entries[question] = results
            print(f"  Completed problem processing, total {len(results['output'])} samples")


def let_r1_judge_solution(question, top_families):
    """Ask a model to pick the most correct solution family.

    Args:
        question: Problem statement shown to the judge.
        top_families: Sequence of dicts with ``count``, ``codes`` (the first
            code is shown) and ``code_indices``.

    Returns:
        The first ``code_indices`` entry of the chosen family, or ``None``
        when the reply is ``none``, is not a bare option number, is out of
        range, or is not a string at all.
    """
    parts = [
        f"Question: {question}\n\n",
        "Several solutions were generated for this problem. Please evaluate which solution is most correct.\n\n",
    ]
    for i, family in enumerate(top_families):
        parts.append(f"\nOption {i+1} (supported by {family['count']} variants):\n")
        parts.append("Code:\n```python\n")
        parts.append(family['codes'][0] + "\n```\n")
    parts.append("\nPlease select the most correct solution by responding with just the option number (1, 2, etc.). If none of them are correct, respond with 'None' (exactly this word).")
    prompt = "".join(parts)

    messages = [{"role": "user", "content": prompt}]
    # NOTE(review): despite the function name, the request goes through
    # safe_fetch_gpt4 rather than fetch_r1 -- confirm this is intended.
    response = safe_fetch_gpt4(messages)

    # A failed request may not yield a string; treat that as "no verdict"
    # instead of crashing on .strip().
    if not isinstance(response, str):
        return None

    answer = response.strip()
    if answer.lower() == 'none':
        return None

    try:
        selected = int(answer)
        if 1 <= selected <= len(top_families):
            return top_families[selected - 1]['code_indices'][0]
    except (ValueError, KeyError, IndexError, TypeError):
        # Non-numeric reply or a family without usable code_indices.
        pass

    return None



if __name__ == "__main__":
    # Pipeline driver: for every ../new_prob/new_prob_<method>.json file,
    # parse the problems, build input generators, materialize test inputs and
    # sample solutions into BASE_PATH/<method>/.
    BASE_PATH = "/home/user/code/code_gen/new_prob_json/opt"
    NEW_PROB_DIR = "../new_prob"
    FILE_PREFIX = "new_prob_"
    FILE_SUFFIX = ".json"

    # os.listdir order is unspecified; sort for reproducible runs.
    prob_files = sorted(
        f for f in os.listdir(NEW_PROB_DIR)
        if f.startswith(FILE_PREFIX) and f.endswith(FILE_SUFFIX)
    )

    for prob_file in prob_files:
        # Strip only the leading prefix and trailing suffix, so a method name
        # that itself contains one of those substrings stays intact.
        method_type = prob_file[len(FILE_PREFIX):-len(FILE_SUFFIX)]
        print(f"\n{'='*50}")
        print(f"Starting method: {method_type}")
        print(f"{'='*50}")

        method_dir = os.path.join(BASE_PATH, method_type)
        os.makedirs(method_dir, exist_ok=True)

        input_file_path = os.path.join(NEW_PROB_DIR, prob_file)
        generators_path = os.path.join(method_dir, "input_generators.json")
        test_inputs_path = os.path.join(method_dir, "test_inputs.json")
        solutions_path = os.path.join(method_dir, "optimized_solutions.json")

        print(f"Step 1: Parsing problem file {prob_file}")
        subset = sample_problems(input_file_path)

        print("Step 2.1: Generating input generators")
        generate_input_generators(subset, generators_path)

        print("Step 2.2: Generating test inputs")
        generate_test_inputs(generators_path, test_inputs_path)

        print("Step 3: Generating opt solutions")
        generate_optimized_solutions(test_inputs_path, solutions_path)

        print(f"\n{'='*50}")
        print(f"Completed method: {method_type}")
        print(f"{'='*50}\n")