import json
from datasets import load_dataset
import random
import os
from collections import defaultdict
from utils import (
    merge_input_files,
    safe_fetch_gpt4,
    fetch_r1,
    extract_code,
    run_code,
    run_code_with_time,
    is_valid_output,
    find_equivalent_groups,
    outputs_match,
)
import pdb
import cyaron as cy
import sys
import ast
import concurrent.futures
import math
import logging

# Root directories of the pipeline. Each one is expected to contain one
# sub-directory per "method"; main() discovers methods by listing OPT_BASE and
# BRUTE_BASE, and process_method() writes its results under OUT_BASE/<method>.
OPT_BASE = "/home/user/code/code_gen/new_prob_json/opt"
BRUTE_BASE = "/home/user/code/code_gen/new_prob_json/brute"
OUT_BASE = "/home/user/code/code_gen/new_prob_json/merged_results"

# NOTE: import-time side effect -- the output root is created as soon as this
# module is loaded.
os.makedirs(OUT_BASE, exist_ok=True)

# Thread-pool cap used both for methods processed in parallel (main) and as the
# default number of questions processed in parallel within one method.
MAX_METHOD_WORKERS = min(3, (os.cpu_count() or 4))
# Thread-pool cap for running all candidate codes against a single input.
MAX_PER_INPUT_WORKERS = 20


def _extract_input_block(response):
    """Return the text between the first recognised fence pair in *response*.

    Markers are tried in order: the requested ``'''plaintext`` fence, a bare
    ``'''`` fence, then the Markdown equivalents using backticks. If the
    closing fence is missing the remainder of the response is used; if no
    fence is present at all the whole (stripped) response is returned.
    """
    for opening in ("'''plaintext", "'''", "```plaintext", "```"):
        start = response.find(opening)
        if start == -1:
            continue
        start += len(opening)
        # The closing fence is the three quote characters of the opener.
        end = response.find(opening[:3], start)
        body = response[start:end] if end != -1 else response[start:]
        return body.strip()
    return response.strip()


def _split_test_inputs(input_block):
    """Split a block of ``Test Input N:`` sections into a list of inputs.

    Blank lines and anything before the first header are discarded; the
    header lines themselves (including any text on the same line) are not
    part of the returned inputs.
    """
    inputs = []
    current = []
    seen_header = False
    for line in input_block.splitlines():
        stripped = line.strip()
        if not stripped:
            continue
        if stripped.lower().startswith("test input"):
            if current:
                inputs.append("\n".join(current).strip())
                current = []
            seen_header = True
            continue
        if seen_header:
            current.append(line.rstrip())
    if current:
        inputs.append("\n".join(current).strip())
    return inputs


def generate_direct_inputs(question_text, rounds=2, n_per_round=5):
    """Ask the LLM for small, adversarial test inputs for a problem.

    The model is queried ``rounds`` times via ``safe_fetch_gpt4``; each reply
    is expected to contain ``n_per_round`` sections headed ``Test Input N:``
    inside a fenced block. A round that fails for any reason is reported on
    stdout and skipped, so this function never raises for model/parse errors.

    Args:
        question_text: Full problem statement inserted into the prompt.
        rounds: Number of independent model queries.
        n_per_round: Number of inputs requested per query.

    Returns:
        The non-empty inputs from all rounds, de-duplicated, in first-seen
        order. Empty when ``rounds`` is 0 or every round failed.
    """
    all_inputs = []
    for round_idx in range(rounds):
        prompt = (
            f"Task: Generate {n_per_round} challenging test inputs for the algorithm problem:\n"
            f"{question_text}\n\n"
            f"Instructions:\n"
            f"- Focus on edge cases or scenarios that maximize the failure probability in faulty solutions.\n"
            f"- Generate exactly {n_per_round} test inputs.\n"
            f"- Each test input must be small-scale, complete, and valid.\n"
            f"- Output the test inputs in the following format (only the block between the delimiters):\n"
            f"'''plaintext\n"
            f"Test Input 1:\n<content>\n"
            f"Test Input 2:\n<content>\n"
            f"...\n"
            f"Test Input {n_per_round}:\n<content>\n"
            f"'''\n"
            f"Think step by step about potential failure scenarios and provide concise inputs only."
        )

        try:
            response = safe_fetch_gpt4([{"role": "user", "content": prompt}])
            if not isinstance(response, str):
                # Surface a readable reason instead of an AttributeError on None.
                raise ValueError("model returned no text response")

            inputs = _split_test_inputs(_extract_input_block(response))
            all_inputs.extend(inputs)
            print(f"generate_direct_inputs: generated {len(inputs)} inputs (round {round_idx+1})")
        except Exception as e:
            # Best effort: one bad round must not abort the remaining rounds.
            print(f"generate_direct_inputs: error on round {round_idx+1}: {e}")
            continue

    # dict preserves insertion order, giving an order-stable de-duplication.
    return list(dict.fromkeys(item for item in all_inputs if item))


def load_jsonl_map(path):
    """Load a JSON-Lines file into a dict keyed by each record's ``question``.

    Blank lines, lines that are not valid JSON, JSON values that are not
    objects, and objects without a ``question`` are skipped silently. When
    several records share a question, the last one in the file wins.

    Args:
        path: Path of the JSON-Lines file to read (UTF-8).

    Returns:
        ``{question: record_dict}``; empty if *path* does not exist.
    """
    result = {}
    if not os.path.exists(path):
        return result
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                data = json.loads(line)
            except (ValueError, RecursionError):
                # json.JSONDecodeError is a ValueError subclass.
                continue
            if not isinstance(data, dict):
                # e.g. a bare list/string/number line: valid JSON, not a record.
                continue
            q = data.get('question')
            if q is None:
                continue
            result[q] = data
    return result


def load_test_inputs_map(test_inputs_path):
    """Load generated test inputs from a JSON-Lines file, keyed by question.

    For each record the inputs are taken from the first non-empty field among
    ``input_string``, ``input_strings``, ``test_inputs`` and ``inputs``; a
    non-list value is wrapped in a one-element list. Skills come from
    ``skills`` or ``skill``. Blank lines, invalid JSON, non-object JSON values
    and records without a ``question`` are skipped silently; for duplicated
    questions the last record wins.

    Args:
        test_inputs_path: Path of the JSON-Lines file to read (UTF-8).

    Returns:
        ``{question: {'inputs': list, 'skills': value-or-None, 'raw': record}}``;
        empty if the file does not exist.
    """
    mapping = {}
    if not os.path.exists(test_inputs_path):
        return mapping
    with open(test_inputs_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            try:
                d = json.loads(line)
            except (ValueError, RecursionError):
                # json.JSONDecodeError is a ValueError subclass.
                continue
            if not isinstance(d, dict):
                # Valid JSON but not a record (list, string, number, ...).
                continue
            q = d.get('question')
            if q is None:
                continue
            inputs = (
                d.get('input_string')
                or d.get('input_strings')
                or d.get('test_inputs')
                or d.get('inputs')
                or []
            )
            if not isinstance(inputs, list):
                # A single (truthy) scalar input, e.g. one input string.
                inputs = [inputs]
            mapping[q] = {
                'inputs': inputs,
                'skills': d.get('skills') or d.get('skill') or None,
                'raw': d,
            }
    return mapping


def validate_and_classify_on_inputs(question, code_entries, inputs_list,
                                    call_r1_on_case_c=True,
                                    max_workers_per_input=10):
    """Run every candidate code on every input and pick a consensus output.

    For each input, codes whose output passes ``is_valid_output`` are grouped
    into "families" by ``str(output)``. The input is then classified:

    * ``case_a`` -- the largest family covers at least half of
      ``code_entries`` and is strictly larger than every other family; its
      output is accepted as the valid output.
    * ``case_b`` -- at least two families, the largest has >= 2 members and
      leads the runner-up by at most 2; the leading families are handed to
      ``let_r1_judge_solution`` and the family of the selected code (if any)
      is accepted.
    * ``case_c`` -- anything else; the input gets no valid output.

    Args:
        question: Problem statement, forwarded to the judge.
        code_entries: Candidate solutions as accepted by ``extract_code``.
        inputs_list: Test inputs to run.
        call_r1_on_case_c: Despite the name, gates the judge call made in the
            ``case_b`` branch; when false, ``case_b`` inputs get no output.
        max_workers_per_input: Thread cap for running the codes on one input.

    Returns:
        ``(input_results, code_runtimes)``: one result dict per input (keys
        ``input``, ``output_families``, ``classification``, ``valid_output``,
        ``valid_code_indices``, ``valid_codes``) and a mapping
        ``code_idx -> [time_used, ...]`` of the recorded run times.
    """
    code_strings = []
    for ce in code_entries:
        try:
            code_strings.append(extract_code(ce))
        except Exception:
            code_strings.append(None)

    input_results = []
    code_runtimes = defaultdict(list)

    for input_str in inputs_list:
        per_input_results = run_all_codes_on_input_concurrent(
            code_entries, input_str, max_workers=max_workers_per_input
        )

        # Iterate in index order: the per-input results arrive in thread
        # completion order, which would otherwise make family ordering (and
        # everything derived from it) vary from run to run.
        output_families = defaultdict(list)
        for code_idx in sorted(per_input_results):
            res = per_input_results[code_idx]
            out = res.get("output")
            t = res.get("time_used")
            if t is not None:
                code_runtimes[code_idx].append(t)
            if is_valid_output(out):
                output_families[str(out)].append({
                    "code_idx": code_idx,
                    "code": code_strings[code_idx]
                })

        family_sizes = sorted((len(v) for v in output_families.values()), reverse=True)
        top_size = family_sizes[0] if family_sizes else 0
        runner_up_size = family_sizes[1] if len(family_sizes) > 1 else 0

        input_result = {
            "input": input_str,
            "output_families": [
                {
                    "output": output_key,
                    "count": len(code_infos),
                    "code_indices": [ci["code_idx"] for ci in code_infos],
                    "code_infos": code_infos
                }
                for output_key, code_infos in output_families.items()
            ],
            "classification": None,
            "valid_output": None,
            "valid_code_indices": [],
            "valid_codes": []
        }

        # A family only wins outright when it is unambiguous: an exact tie for
        # first place (e.g. 2 vs 2) must not be settled by dict order, so it
        # falls through to the judged / discarded cases below.
        if (family_sizes
                and top_size >= (len(code_entries) / 2)
                and top_size > runner_up_size):
            input_result["classification"] = "case_a"
            max_family_output, max_family_list = max(
                output_families.items(), key=lambda x: len(x[1])
            )
            input_result["valid_output"] = max_family_output
            input_result["valid_code_indices"] = [ci["code_idx"] for ci in max_family_list]
            input_result["valid_codes"] = [ci["code"] for ci in max_family_list]

        elif (len(family_sizes) >= 2 and top_size >= 2
              and (top_size - runner_up_size) <= 2):
            input_result["classification"] = "case_b"
            # Every family at least as large as the runner-up is a contender.
            top_families = [
                {
                    "output": output_key,
                    "count": len(code_infos),
                    "code_indices": [ci["code_idx"] for ci in code_infos],
                    "codes": [ci["code"] for ci in code_infos]
                }
                for output_key, code_infos in output_families.items()
                if len(code_infos) >= runner_up_size
            ]
            if call_r1_on_case_c and top_families:
                selected_code_idx = let_r1_judge_solution(question, top_families)
                if selected_code_idx is not None:
                    for fam in input_result["output_families"]:
                        if selected_code_idx in fam["code_indices"]:
                            input_result["valid_output"] = fam["output"]
                            input_result["valid_code_indices"] = fam["code_indices"]
                            input_result["valid_codes"] = [
                                ci["code"] for ci in fam.get("code_infos", [])
                            ]
                            break

        else:
            # No usable consensus: the input is kept in the results but
            # carries no valid output, so callers drop it.
            input_result["classification"] = "case_c"
            print('throw away')

        input_results.append(input_result)

    return input_results, code_runtimes


def let_r1_judge_solution(question, top_families):
    """Ask the LLM which of several competing solutions is most correct.

    One representative code (the first) of each family is shown as a numbered
    option. The reply must be exactly an option number, or ``None``.

    NOTE(review): despite the name, the request is sent through
    ``safe_fetch_gpt4`` -- confirm this is the intended judge model.

    Args:
        question: Problem statement shown to the model.
        top_families: List of dicts with keys ``count``, ``codes`` and
            ``code_indices`` (as built by validate_and_classify_on_inputs).

    Returns:
        The first code index of the selected family, or ``None`` when there
        are no options, the model answers ``None``, gives no text, or the
        reply is not a valid in-range option number.
    """
    if not top_families:
        # Nothing to choose from; avoid a pointless model call.
        return None

    prompt_parts = [
        f"Question: {question}\n\n",
        "Several solutions were generated for this problem. Please evaluate which solution is most correct.\n\n",
    ]
    for i, family in enumerate(top_families):
        prompt_parts.append(f"\nOption {i+1} (supported by {family['count']} variants):\n")
        prompt_parts.append("Code:\n```python\n")
        prompt_parts.append(family['codes'][0] + "\n```\n")
    prompt_parts.append(
        "\nPlease select the most correct solution by responding with just the option number (1, 2, etc.). If none of them are correct, respond with 'None' (exactly this word)."
    )

    messages = [{"role": "user", "content": "".join(prompt_parts)}]
    response = safe_fetch_gpt4(messages)

    if not isinstance(response, str):
        # A failed fetch (e.g. None) is treated as "no decision".
        return None

    answer = response.strip()
    if answer.lower() == 'none':
        return None

    try:
        selected = int(answer)
    except ValueError:
        # Free-form text rather than a bare option number.
        return None

    if not 1 <= selected <= len(top_families):
        return None

    indices = top_families[selected - 1].get('code_indices') or []
    return indices[0] if indices else None


def run_all_codes_on_input_concurrent(code_list, input_str, max_workers=MAX_PER_INPUT_WORKERS):
    """Execute every candidate in *code_list* on *input_str* using threads.

    Each entry is passed through ``extract_code`` and then run with
    ``run_code_with_time``. An entry whose extraction or execution raises is
    recorded with ``None`` for both fields instead of propagating the error.

    Args:
        code_list: Candidate solutions as accepted by ``extract_code``.
        input_str: The input handed to every candidate.
        max_workers: Upper bound on the pool size; the pool is never larger
            than the number of candidates (minimum one thread).

    Returns:
        ``{index_in_code_list: {"output": ..., "time_used": ...}}`` with one
        entry per candidate.
    """
    outcomes = {}
    pending = {}
    pool_size = min(max_workers, max(1, len(code_list)))

    with concurrent.futures.ThreadPoolExecutor(max_workers=pool_size) as pool:
        for position, entry in enumerate(code_list):
            try:
                source = extract_code(entry)
            except Exception:
                outcomes[position] = {"output": None, "time_used": None}
            else:
                job = pool.submit(run_code_with_time, source, input_str)
                pending[job] = position

        for finished in concurrent.futures.as_completed(pending):
            position = pending[finished]
            try:
                payload = finished.result()
                outcomes[position] = {
                    "output": payload.get("output"),
                    "time_used": payload.get("time_used"),
                }
            except Exception:
                outcomes[position] = {"output": None, "time_used": None}

    return outcomes


def _passes_all_cases(src, cases):
    """Return True when *src* reproduces the expected output of every case."""
    for case in cases:
        try:
            out = run_code(src, case["input"])
        except Exception:
            return False
        if not outputs_match(out, case["output"]):
            return False
    return True


def _summarize_validation(input_results, code_runtimes, code_list, max_valid_codes):
    """Condense validation results into (test_cases, valid_codes, time_limit).

    * test_cases: one ``{"input", "output"}`` per input with a valid output.
    * valid_codes: up to *max_valid_codes* distinct extracted sources of the
      codes that belonged to an accepted family, in *code_list* order.
    * time_limit: ``ceil(3 * slowest recorded run)`` over those codes, capped
      at 5; 5 when no run time was recorded.
    """
    test_cases = [
        {"input": ir["input"], "output": ir["valid_output"]}
        for ir in input_results if ir.get("valid_output") is not None
    ]

    valid_indices = set()
    for ir in input_results:
        valid_indices.update(ir.get("valid_code_indices", []))

    per_code_limits = []
    for code_idx in valid_indices:
        runtimes = code_runtimes.get(code_idx, [])
        if runtimes:
            per_code_limits.append(math.ceil(max(runtimes) * 3))
    time_limit = min(max(per_code_limits) if per_code_limits else 5, 5)

    valid_codes = []
    seen = set()
    for idx, code in enumerate(code_list):
        if idx not in valid_indices:
            continue
        try:
            src = extract_code(code)
        except Exception:
            continue
        if src not in seen:
            seen.add(src)
            valid_codes.append(src)
        if len(valid_codes) >= max_valid_codes:
            break

    return test_cases, valid_codes, time_limit


def _build_corner_test_cases(question, code_list, error_prefix):
    """Generate LLM corner-case inputs and keep those with a consensus output.

    Any failure is reported on stdout (prefixed with *error_prefix*) and
    results in an empty list, so corner cases never abort a question.
    """
    try:
        direct_inputs = generate_direct_inputs(question)
        if not direct_inputs:
            return []
        corner_input_results, _ = validate_and_classify_on_inputs(
            question, code_list, direct_inputs,
            call_r1_on_case_c=True,
            max_workers_per_input=MAX_PER_INPUT_WORKERS
        )
        return [
            {"input": cir["input"], "output": cir["valid_output"]}
            for cir in corner_input_results if cir.get("valid_output") is not None
        ]
    except Exception as e:
        print(f"{error_prefix}: {e}")
        return []


def _print_adv_summary(problem_result):
    """Print the adversarial-case count and time limit of *problem_result*."""
    print(f"adv_test_cases length{len(problem_result.get('adv_test_cases', []))} "
          f"time_limits={problem_result['time_limits']}"
          )


def process_single_question(question, opt_entry, brute_entry,
                            opt_inputs_map, brute_valid_cases_map):
    """Build the merged result record for one question.

    Three situations are handled, depending on which entries are present:

    * both: optimized codes are first filtered to those that reproduce every
      brute-force validated case (``rand_test_cases``); the survivors are
      then run on the optimized test inputs to derive ``adv_test_cases``,
      one representative in ``valid_codes``, ``time_limits`` and
      LLM-generated ``corner_test_cases``.
    * brute only: the brute-force validated cases become ``rand_test_cases``.
    * opt only: as "both" but without filtering, keeping up to three
      ``valid_codes``.

    Args:
        question: Problem statement; also the lookup key in both maps.
        opt_entry: Optimized-solutions record (candidate codes under
            ``output``) or ``None``.
        brute_entry: Brute-force record or ``None``.
        opt_inputs_map: ``question -> {'inputs', 'skills', ...}`` as produced
            by ``load_test_inputs_map``.
        brute_valid_cases_map: ``question -> list of {"input", "output"}``.

    Returns:
        The result dict (always with ``problem_id``, ``question``, ``tags``
        and ``difficulty``; other keys depend on the branch), or ``None``
        when neither entry is available.
    """
    if opt_entry is None and brute_entry is None:
        return None

    gen_entry = opt_entry or brute_entry
    problem_result = {
        "problem_id": gen_entry.get("problem_id", ""),
        "question": question,
        "tags": gen_entry.get("tags", []) or [],
        "difficulty": gen_entry.get("difficulty", "") or "",
    }

    opt_inputs_entry = opt_inputs_map.get(question, {}) if isinstance(opt_inputs_map, dict) else {}
    if isinstance(opt_inputs_entry, dict):
        opt_inputs = opt_inputs_entry.get('inputs', []) or []
        opt_skills = opt_inputs_entry.get('skills') or None
    else:
        opt_inputs = opt_inputs_entry or []
        opt_skills = None

    if opt_entry and brute_entry:
        code_list = opt_entry.get("output", []) or []
        brute_test_cases = brute_valid_cases_map.get(question, []) or []
        problem_result["rand_test_cases"] = brute_test_cases

        # Keep only the codes that agree with the brute-force ground truth;
        # without any brute-force case every code is kept.
        if brute_test_cases:
            filtered_codes = []
            for code in code_list:
                try:
                    src = extract_code(code)
                except Exception:
                    continue
                if _passes_all_cases(src, brute_test_cases):
                    filtered_codes.append(code)
        else:
            filtered_codes = list(code_list)

        if not filtered_codes:
            # No candidate (or none survived): only the random cases remain.
            if opt_skills is not None:
                problem_result["skills"] = opt_skills
            return problem_result

        input_results, code_runtimes = validate_and_classify_on_inputs(
            question, filtered_codes, opt_inputs,
            call_r1_on_case_c=True,
            max_workers_per_input=MAX_PER_INPUT_WORKERS
        )
        adv_test_cases, valid_codes_list, max_time_limit = _summarize_validation(
            input_results, code_runtimes, filtered_codes, max_valid_codes=1
        )

        problem_result["adv_test_cases"] = adv_test_cases
        problem_result["valid_codes"] = valid_codes_list
        problem_result["time_limits"] = max_time_limit
        if opt_skills is not None:
            problem_result["skills"] = opt_skills
        _print_adv_summary(problem_result)

        problem_result["corner_test_cases"] = _build_corner_test_cases(
            question, filtered_codes,
            "Error when generating/validating corner tests"
        )
        return problem_result

    if brute_entry and not opt_entry:
        vtc = (brute_entry.get("valid_test_cases") or brute_entry.get("validated_cases") or brute_entry.get("test_cases") or [])
        problem_result["rand_test_cases"] = vtc
        print(f"only brute which length {len(problem_result['rand_test_cases'])} ")
        return problem_result

    if opt_entry and not brute_entry:
        code_list = opt_entry.get("output", []) or []

        if not code_list:
            problem_result["adv_test_cases"] = []
            problem_result["valid_codes"] = []
            problem_result["time_limits"] = None
            if opt_skills is not None:
                problem_result["skills"] = opt_skills
            _print_adv_summary(problem_result)
            return problem_result

        input_results, code_runtimes = validate_and_classify_on_inputs(
            question, code_list, opt_inputs,
            call_r1_on_case_c=True,
            max_workers_per_input=MAX_PER_INPUT_WORKERS
        )
        adv_test_cases, valid_codes_list, max_time_limit = _summarize_validation(
            input_results, code_runtimes, code_list, max_valid_codes=3
        )

        problem_result["adv_test_cases"] = adv_test_cases
        problem_result["valid_codes"] = valid_codes_list
        problem_result["time_limits"] = max_time_limit
        if opt_skills is not None:
            problem_result["skills"] = opt_skills
        _print_adv_summary(problem_result)

        problem_result["corner_test_cases"] = _build_corner_test_cases(
            question, code_list,
            "Error when generating/validating corner tests (only-in-opt)"
        )
        return problem_result

    return None


# Configure the root logger at import time: INFO and above go both to stderr
# and, in append mode, to 'processing.log' (a relative path, so it is resolved
# against the current working directory). process_method() uses this logger to
# record per-question failures with their tracebacks.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stderr),
        logging.FileHandler('processing.log', mode='a', encoding='utf-8')
    ]
)


def process_method(method_name, max_question_workers=MAX_METHOD_WORKERS):
    """Process every question of one method and append results as JSON lines.

    Inputs are read from ``OPT_BASE/<method>`` (``optimized_solutions.json``,
    ``test_inputs.json``) and ``BRUTE_BASE/<method>``
    (``validated_cases.json``), all parsed line by line. Results are appended
    to ``OUT_BASE/<method>/problem_with_cases.jsonl``; questions already
    present in that file are skipped, so an interrupted run can be resumed.

    Args:
        method_name: Name of the method sub-directory.
        max_question_workers: Upper bound on questions processed in parallel.

    Returns:
        None. Failures of individual questions are logged and do not stop
        the remaining questions.
    """
    out_dir = os.path.join(OUT_BASE, method_name)
    os.makedirs(out_dir, exist_ok=True)
    opt_dir = os.path.join(OPT_BASE, method_name)
    brute_dir = os.path.join(BRUTE_BASE, method_name)

    opt_map = load_jsonl_map(os.path.join(opt_dir, "optimized_solutions.json"))
    brute_map = load_jsonl_map(os.path.join(brute_dir, "validated_cases.json"))
    opt_inputs_map = load_test_inputs_map(os.path.join(opt_dir, "test_inputs.json"))

    # The validated cases may live under any of three field names.
    brute_valid_cases_map = {
        q: (entry.get("valid_test_cases") or entry.get("validated_cases") or entry.get("test_cases") or [])
        for q, entry in brute_map.items()
    }

    questions = sorted(set(opt_map.keys()) | set(brute_map.keys()))
    print(f"[{method_name}] found {len(questions)} questions")

    out_path = os.path.join(out_dir, "problem_with_cases.jsonl")

    # Collect questions written by a previous run so they are not redone.
    written_questions = set()
    if os.path.exists(out_path):
        with open(out_path, "r", encoding="utf-8") as existing:
            for raw_line in existing:
                try:
                    record = json.loads(raw_line)
                    if "question" in record:
                        written_questions.add(record["question"])
                except Exception:
                    continue
        print(f"[{method_name}] already has {len(written_questions)} questions in {out_path}")

    pool_size = min(max_question_workers, max(1, len(questions)))
    with open(out_path, 'a', encoding='utf-8') as of:
        with concurrent.futures.ThreadPoolExecutor(max_workers=pool_size) as ex:
            future_to_q = {
                ex.submit(
                    process_single_question,
                    q, opt_map.get(q), brute_map.get(q),
                    opt_inputs_map, brute_valid_cases_map
                ): q
                for q in questions
                if q not in written_questions
            }

            # Results are written from this thread only, as they complete.
            for fut in concurrent.futures.as_completed(future_to_q):
                q = future_to_q[fut]
                try:
                    res = fut.result()
                    if not res:
                        print(f"[{method_name}] question skipped (no result): {q[:30]}...")
                        continue

                    to_save = {
                        "problem_id": res.get("problem_id", ""),
                        "question": res.get("question", ""),
                        "tags": res.get("tags", []),
                        "skills": res.get("skills", None),
                        "difficulty": res.get("difficulty", ""),
                        "adv_test_cases": res.get("valid_test_cases") or res.get("adv_test_cases") or [],
                        "rand_test_cases": res.get("rand_test_cases") or [],
                        "corner_test_cases": res.get("corner_test_cases") or [],
                        "valid_codes": res.get("valid_codes", []),
                        "time_limits": res.get("time_limits", 5),
                    }
                    total_cases = (
                        len(to_save['adv_test_cases'])
                        + len(to_save['rand_test_cases'])
                        + len(to_save['corner_test_cases'])
                    )
                    print('******************corner************************')
                    print(to_save['corner_test_cases'])
                    print('*************************corner*****************')
                    print('len of test cases - - - - ', total_cases)
                    of.write(json.dumps(to_save, ensure_ascii=False) + "\n")
                    # Flush per record so finished work survives an interruption.
                    of.flush()
                    print(f"[{method_name}] question processed & saved: {res.get('problem_id','')}")
                except Exception as e:
                    logging.error(f"Error processing {q}: {str(e)}", exc_info=True)

    print(f"[{method_name}] completed processing {len(questions)} questions -> appended to {out_path}")


def main():
    """Discover all method directories and process them in parallel.

    A method is any sub-directory found directly under ``OPT_BASE`` or
    ``BRUTE_BASE`` (a missing base directory contributes nothing). Each
    method is handed to ``process_method`` on a thread pool; a method that
    raises is reported on stdout without stopping the others.
    """
    def _method_dirs(base):
        # Names of the immediate sub-directories of *base*, if it exists.
        if not os.path.exists(base):
            return []
        return [
            name for name in os.listdir(base)
            if os.path.isdir(os.path.join(base, name))
        ]

    opt_dirs = _method_dirs(OPT_BASE)
    brute_dirs = _method_dirs(BRUTE_BASE)
    methods = sorted(set(opt_dirs) | set(brute_dirs))
    print(f"Found {len(methods)} method dirs to process.")

    pool_size = min(MAX_METHOD_WORKERS, max(1, len(methods)))
    with concurrent.futures.ThreadPoolExecutor(max_workers=pool_size) as ex:
        futures = {}
        for method in methods:
            futures[ex.submit(process_method, method)] = method

        for fut in concurrent.futures.as_completed(futures):
            method = futures[fut]
            try:
                fut.result()
            except Exception as e:
                print(f"[error] method {method} failed: {e}")

# Run the whole merge pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
