import os
import json
import argparse
import signal
import pandas as pd
import re
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Default per-test wall-clock limit in seconds, used as the `timeout` default of
# run_tests_detailed (enforced there with a SIGALRM timer).
EXEC_TIMEOUT = 30 


def _make_json_safe(obj):
    """Recursively convert arbitrary Python objects to JSON-serializable form.
    Handles: tuple/complex/set/frozenset/bytes values AND non-string dict keys."""
    if obj is None or isinstance(obj, (bool, int, float, str)):
        return obj
    if isinstance(obj, dict):
        return {str(k): _make_json_safe(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_make_json_safe(x) for x in obj]
    if isinstance(obj, (set, frozenset)):
        return [_make_json_safe(x) for x in obj]
    if isinstance(obj, complex):
        return str(obj)
    if isinstance(obj, bytes):
        return obj.decode('utf-8', errors='replace')
    return str(obj)


class _ExecutionTimeout(Exception):
    pass


def _alarm_handler(signum, frame):
    raise _ExecutionTimeout()

def prepare_messages(user_prompt, prompt_format=None):
    """Wrap ``user_prompt`` as a single-turn chat message list.

    ``prompt_format`` is accepted for call-site compatibility but is ignored.
    A new list is returned on every call, so callers may mutate it.
    """
    user_message = {"role": "user", "content": user_prompt}
    return [user_message]

def generate_local(gen_model, gen_tokenizer, device, sys_prompt, user_prompt,
                   max_new_tokens=1024, temperature=0.7, top_p=0.9, prompt_format=None):
    """Run one chat-formatted generation with a local causal LM and return the reply text.

    A system message is prepended only when ``sys_prompt`` contains non-whitespace.
    Sampling is enabled iff ``temperature > 0``. Only the tokens generated after
    the prompt are decoded. A leading "assistant" role tag (case-insensitive) is
    stripped from the reply.
    """
    chat = prepare_messages(user_prompt, prompt_format=prompt_format)
    if sys_prompt and sys_prompt.strip():
        chat = [{"role": "system", "content": sys_prompt}] + chat

    encoded = gen_tokenizer.apply_chat_template(
        chat,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    prompt_length = model_inputs["input_ids"].shape[1]

    generation_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=gen_tokenizer.eos_token_id,
    )
    with torch.no_grad():
        output_ids = gen_model.generate(**model_inputs, **generation_kwargs)

    # Decode only the continuation, not the echoed prompt.
    reply = gen_tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True).strip()
    role_tag = "assistant"
    if reply.lower().startswith(role_tag):
        reply = reply[len(role_tag):].lstrip()
    return reply

def extract_python_code(text):
    """Extract Python source code from an LLM response.

    Strategies, in order:
      1. The first fenced block written exactly as ```` ```python\\n ... \\n``` ````.
      2. The first fence tagged ``python``/``python3``/``py`` (any case) or untagged.
         Its body runs up to the next ```` ``` ```` or to the end of the text, so
         an unterminated block (generation cut off by ``max_new_tokens``) still
         yields its code.
      3. A heuristic scan: top-level import lines that precede the first ``def``,
         plus the first contiguous run of ``def`` blocks with their indented bodies.
      4. The whole response, stripped.

    Args:
        text: raw model output (None is treated as empty).

    Returns:
        The extracted code as a string (possibly empty).
    """
    if not text:
        return ""

    # 1. Canonical fenced python block (original behaviour, checked first).
    matches = re.findall(r'```python\n(.*?)\n```', text, re.DOTALL)
    if matches:
        return matches[0].strip()

    # 2. Other python-ish / untagged fences, including unterminated ones.
    fence = re.search(r'```[ \t]*(?:python3?|py)?[ \t]*\r?\n', text, re.IGNORECASE)
    if fence:
        body = text[fence.end():].split('```', 1)[0].strip()
        if body:
            return body

    # 3. Heuristic: leading import statements + first run of top-level functions.
    import_re = re.compile(
        r'(?:import [\w.]+(?: as \w+)?(?:\s*,\s*[\w.]+(?: as \w+)?)*'
        r'|from [\w.]+ import (?:\*|\w+(?: as \w+)?(?:\s*,\s*\w+(?: as \w+)?)*))\s*$'
    )
    import_lines = []
    code_lines = []
    in_function = False
    for line in text.split('\n'):
        if line.strip().startswith('def '):
            in_function = True
            code_lines.append(line)
        elif in_function:
            # Blank or indented lines belong to the current function body;
            # the first other top-level line ends the code run.
            if line.strip() == '' or line.startswith((' ', '\t')):
                code_lines.append(line)
            else:
                break
        elif import_re.match(line):
            import_lines.append(line)

    if code_lines:
        return '\n'.join(import_lines + code_lines)

    return text.strip()

def parse_assert_test(assert_statement):
    """Parse an MBPP-style ``assert func(args) == expected`` statement.

    The statement is parsed with :mod:`ast` instead of string splitting. This
    correctly handles ``==`` or the word ``assert`` inside string literals, and
    a trailing assertion message (``assert f(x) == y, "msg"``). A bare
    ``func(args) == expected`` expression without ``assert`` is also accepted.

    Argument and expected-value expressions are evaluated with ``eval`` in this
    module's globals, as before. Only use this on trusted test data.

    Args:
        assert_statement: source text of one assert statement.

    Returns:
        ``{'function_name': str, 'inputs': list, 'expected': object}``, or None
        in either of these cases:
        * the statement is not ``name(positional args) == value``;
        * any part of it fails to parse or evaluate.
    """
    import ast

    try:
        body = ast.parse(assert_statement.strip()).body
        if len(body) != 1:
            return None
        stmt = body[0]
        if isinstance(stmt, ast.Assert):
            test = stmt.test
        elif isinstance(stmt, ast.Expr):
            test = stmt.value
        else:
            return None

        # Only a single equality comparison is supported.
        if not (isinstance(test, ast.Compare) and len(test.ops) == 1
                and isinstance(test.ops[0], ast.Eq)):
            return None
        call = test.left
        if not (isinstance(call, ast.Call) and isinstance(call.func, ast.Name)) or call.keywords:
            return None

        # Evaluate the positional args as one list literal (mirrors "[args]",
        # so starred arguments are expanded).
        args_expr = ast.fix_missing_locations(
            ast.Expression(body=ast.List(elts=call.args, ctx=ast.Load())))
        args = eval(compile(args_expr, '<assert-args>', 'eval'), globals())
        expected_result = eval(
            compile(ast.Expression(body=test.comparators[0]), '<assert-expected>', 'eval'),
            globals())
    except Exception:
        return None

    return {
        'function_name': call.func.id,
        'inputs': args,
        'expected': expected_result
    }

def convert_mbpp_tests(test_list):
    """Parse each raw MBPP assert into a test dict, dropping any that cannot be parsed.

    Each kept entry has the keys 'function_name', 'inputs' and 'expected'.
    """
    parsed_tests = (parse_assert_test(statement) for statement in test_list)
    return [
        {
            'function_name': parsed['function_name'],
            'inputs': parsed['inputs'],
            'expected': parsed['expected'],
        }
        for parsed in parsed_tests
        if parsed
    ]


def infer_function_name(test_cases, test_list=None):
    """Infer the name of the function under test.

    The first non-empty ``function_name`` in ``test_cases`` wins. Otherwise the raw
    assert strings in ``test_list`` are scanned for the first free-standing call
    (not an attribute/method call such as ``math.isclose(`` or ``"s".upper(``)
    whose name is not a common builtin or helper.

    Args:
        test_cases: parsed test dicts (may be empty).
        test_list: optional raw assert statement strings.

    Returns:
        The inferred function name, or None if nothing suitable is found.
    """
    for tc in test_cases:
        if tc.get('function_name'):
            return tc['function_name']
    skip = {'set', 'math', 'True', 'False', 'all', 'any', 'sorted', 'list', 'tuple',
            'isclose', 'isinstance', 'len', 'int', 'float', 'str', 'abs', 'round',
            'max', 'min', 'sum', 'map', 'filter', 'zip', 'range', 'enumerate',
            'type', 'print', 'dict', 'frozenset', 'ord', 'chr', 'hex', 'bin', 'oct'}
    if test_list:
        # The lookbehind rejects names preceded by '.' (attribute/method calls)
        # or by a word character (partial matches starting mid-identifier).
        call_re = re.compile(r'(?<![\w.])(\w+)\s*\(')
        for stmt in test_list:
            for m in call_re.finditer(stmt):
                name = m.group(1)
                if name not in skip and not name.startswith('assert'):
                    return name
    return None

def run_tests_detailed(code, test_cases, test_list_raw=None, test_imports=None, timeout=EXEC_TIMEOUT):
    """Execute generated ``code`` against a problem's tests and report per-test results.

    Two modes:

    * Raw mode (``test_list_raw`` non-empty). For each assert statement, a
      fresh namespace runs ``test_imports``, then ``code``, then the assert.
    * Parsed mode (otherwise). ``test_cases`` holds dicts with keys
      ``inputs``, ``expected`` and optional ``function_name``. For each dict,
      a fresh namespace runs ``test_imports`` and ``code``. The target
      function is then called with ``inputs`` (unpacked when it is a list) and
      the result is compared to ``expected`` with ``==``.

    Each test runs under a SIGALRM wall-clock limit of ``timeout`` seconds;
    fractional values are allowed. SIGALRM is Unix-only, and signal handlers
    can only be installed from the main thread.

    WARNING: ``code`` runs via ``exec`` in this process without any sandbox.

    Args:
        code: Python source under test.
        test_cases: parsed test dicts (used only when ``test_list_raw`` is empty).
        test_list_raw: raw assert statements.
        test_imports: import statements executed before ``code`` in each namespace.
        timeout: per-test limit in seconds.

    Returns:
        ``(pass_rate, passed, total, errors, test_results)``.
        * ``errors`` is a list of strings.
        * ``test_results`` has one dict per test with keys ``test_id``,
          ``inputs``, ``expected``, ``actual``, ``passed`` and ``error``.
        With no tests at all, ``(0.0, 0, 0, [], [])`` is returned.
    """
    use_raw = test_list_raw is not None and len(test_list_raw) > 0
    if use_raw:
        total = len(test_list_raw)
    elif test_cases:
        total = len(test_cases)
    else:
        return 0.0, 0, 0, [], []

    passed = 0
    errors = []
    test_results = []

    def record(test_id, inputs, expected, actual, ok, error):
        # One entry per test; the schema is identical in both modes.
        test_results.append({
            'test_id': test_id,
            'inputs': inputs,
            'expected': expected,
            'actual': actual,
            'passed': ok,
            'error': error,
        })

    def describe(exc):
        # str(SystemExit()) is usually empty, so name the exit explicitly.
        if isinstance(exc, SystemExit):
            return f"SystemExit({exc.code!r})"
        return str(exc)

    def fresh_namespace():
        # New globals for every test so state cannot leak between tests.
        namespace = {}
        for imp in (test_imports or []):
            exec(imp, namespace)
        exec(code, namespace)
        return namespace

    if use_raw:
        # Raw assert execution (sanitized MBPP format)
        for test_id, assert_stmt in enumerate(test_list_raw, start=1):
            prev_handler = signal.signal(signal.SIGALRM, _alarm_handler)
            # setitimer (unlike alarm) accepts float timeouts.
            signal.setitimer(signal.ITIMER_REAL, timeout)
            try:
                namespace = fresh_namespace()
                exec(assert_stmt, namespace)
                passed += 1
                record(test_id, None, None, None, True, None)
            except AssertionError:
                errors.append(f"Test {test_id}: Assertion failed")
                record(test_id, None, None, None, False, 'Assertion failed')
            except _ExecutionTimeout:
                error_msg = f"Test {test_id}: Timed out ({timeout}s)"
                errors.append(error_msg)
                record(test_id, None, None, None, False, error_msg)
            except (Exception, SystemExit) as e:
                # SystemExit: generated code calling exit()/sys.exit() must fail
                # this test rather than terminate the whole evaluation run.
                detail = describe(e)
                errors.append(f"Test {test_id}: Runtime error - {detail}")
                record(test_id, None, None, None, False, detail)
            finally:
                signal.setitimer(signal.ITIMER_REAL, 0)
                signal.signal(signal.SIGALRM, prev_handler)
    else:
        # Parsed test_cases execution
        expected_func = next(
            (tc['function_name'] for tc in test_cases if tc.get('function_name')), None)

        for test_id, test_case in enumerate(test_cases, start=1):
            inputs = test_case.get('inputs', [])
            expected = test_case.get('expected')
            prev_handler = signal.signal(signal.SIGALRM, _alarm_handler)
            signal.setitimer(signal.ITIMER_REAL, timeout)
            try:
                namespace = fresh_namespace()

                if expected_func and expected_func in namespace:
                    func_name = expected_func
                else:
                    # Fall back to the first public callable found in the namespace.
                    func_name = next(
                        (name for name, obj in namespace.items()
                         if callable(obj) and not name.startswith('_')),
                        None)

                if func_name is None:
                    error_msg = f"Test {test_id}: No callable function found"
                    errors.append(error_msg)
                    record(test_id, inputs, expected, None, False, error_msg)
                    continue

                func = namespace[func_name]
                result = func(*inputs) if isinstance(inputs, list) else func(inputs)

                test_passed = result == expected
                if test_passed:
                    passed += 1

                record(test_id, inputs, expected, result, test_passed, None)

                if not test_passed:
                    errors.append(f"Test {test_id}: Expected {expected}, got {result}")

            except _ExecutionTimeout:
                error_msg = f"Test {test_id}: Timed out ({timeout}s)"
                errors.append(error_msg)
                record(test_id, inputs, expected, None, False, error_msg)
            except (Exception, SystemExit) as e:
                detail = describe(e)
                errors.append(f"Test {test_id}: Runtime error - {detail}")
                record(test_id, inputs, expected, None, False, detail)
            finally:
                signal.setitimer(signal.ITIMER_REAL, 0)
                signal.signal(signal.SIGALRM, prev_handler)

    pass_rate = passed / total if total > 0 else 0.0
    return pass_rate, passed, total, errors, test_results

def evaluate_code_detailed(code, test_cases, test_list_raw=None, test_imports=None):
    """Score ``code`` on a 0-10 scale: 10 times the fraction of tests passed.

    Empty or whitespace-only code is not executed. In that case the result is
    ``(0.0, 0, n, [], [])``, where ``n`` counts ``test_list_raw`` if it is
    non-empty and ``test_cases`` otherwise.

    Returns:
        ``(test_score, passed, total, test_results, errors)``. Note that
        ``test_results`` comes before ``errors``, the reverse of
        run_tests_detailed's order.
    """
    reference_total = len(test_list_raw) if test_list_raw else len(test_cases)
    if code and code.strip():
        rate, n_passed, n_total, error_list, per_test = run_tests_detailed(
            code, test_cases, test_list_raw=test_list_raw, test_imports=test_imports
        )
        return rate * 10, n_passed, n_total, per_test, error_list
    return 0.0, 0, reference_total, [], []

class TMPCGenerator:
    """Proposes candidate programs for one problem, conditioned on a buffer of prior attempts.

    Stores the generation model, tokenizer and device, and calls `generate_local`
    for every LLM request. Buffer entries are dicts with 'code', 'score',
    'test_passed' and 'test_total'.
    """

    def __init__(self, gen_model, gen_tokenizer, device):
        self.gen_model = gen_model
        self.gen_tokenizer = gen_tokenizer
        self.device = device

    @staticmethod
    def _buf_key(entry):
        """Sort key for buffer: (test_passed, score) descending."""
        return (entry.get('test_passed', 0), entry.get('score', 0.0))

    def generate_exploration_candidates(self, problem_description, buffer, iteration,
                                        num_candidates=3, func_name=None,
                                        stagnant_iters=0):
        """Generate candidates using buffer-based subgoal conditioning.

        Aligned with TMPC Algorithm 1: all K candidates condition on the buffer.
        When stagnation is detected (buffer hasn't improved), escalate temperature
        and vary the refinement angle to escape local optima.

        Args:
            problem_description: natural-language task statement.
            buffer: best-first list of buffer entries. If empty, every slot is
                filled by unconditioned (diverse) generation.
            iteration: refinement iteration number; selects the angle and the
                diverse-approach index.
            num_candidates: maximum number of candidates to return.
            func_name: required function name to mention in prompts, if known.
            stagnant_iters: consecutive iterations without improvement. Raises
                temperatures and shifts slots from refinement to diverse generation.

        Returns:
            Up to ``num_candidates`` extracted code strings. Failed generations
            (None) and buffer entries without code are skipped.
        """
        candidates = []
        total = buffer[0].get('test_total', 0) if buffer else 0

        # Escalate temperature when stagnant
        if stagnant_iters == 0:
            temps = [0.4, 0.5, 0.6]
        elif stagnant_iters == 1:
            temps = [0.6, 0.7, 0.8]
        elif stagnant_iters == 2:
            temps = [0.8, 0.9, 1.0]
        else:
            temps = [0.9, 1.0, 1.0]

        # Refinement angles: rotate through different diagnostic hints per iteration
        angles = [
            None,
            "edge_cases",
            "algorithm_rethink",
            "simplify",
            "reinterpret",
        ]
        angle = angles[iteration % len(angles)] if stagnant_iters >= 1 else None

        if stagnant_iters >= 3:
            n_diverse_override = num_candidates
        elif stagnant_iters >= 2:
            n_diverse_override = max(1, num_candidates - 1)
        elif stagnant_iters >= 1:
            n_diverse_override = 1
        else:
            n_diverse_override = 0

        # Each slot costs a full LLM generation, so never schedule more diverse
        # slots than there are candidates to return. With an empty buffer there is
        # nothing to condition on, so every slot becomes a diverse one.
        num_slots = max(0, num_candidates)
        n_diverse_override = min(n_diverse_override, num_slots)
        if not buffer:
            n_diverse_override = num_slots
        n_conditioned_actual = num_slots - n_diverse_override

        for i in range(n_conditioned_actual):
            # With fewer buffer entries than slots, the last entry is reused.
            buf_entry = buffer[min(i, len(buffer) - 1)]
            code = buf_entry.get('code', '')
            passed = buf_entry.get('test_passed', 0)
            temp = temps[i] if i < len(temps) else temps[-1]
            if not code:
                continue

            if passed > 0:
                candidate = self._generate_from_subgoal(
                    problem_description, code, passed, total, func_name,
                    temperature=temp, angle=angle)
            else:
                candidate = self._generate_from_failed(
                    problem_description, code, total, func_name,
                    temperature=temp, angle=angle)
            if candidate:
                candidates.append(candidate)

        # Fill remaining slots with diverse exploration
        for i in range(n_diverse_override):
            diverse_idx = iteration * num_candidates + i + stagnant_iters * 7
            temp = temps[n_conditioned_actual + i] if (n_conditioned_actual + i) < len(temps) else temps[-1]
            candidate = self._generate_diverse(
                problem_description, diverse_idx, func_name=func_name,
                temperature=temp)
            if candidate:
                candidates.append(candidate)

        return candidates[:num_candidates]

    def _generate_from_subgoal(self, problem_description, base_code, passed, total,
                               func_name=None, temperature=0.4, angle=None):
        """Ask the model to improve ``base_code``, which passes ``passed``/``total`` tests.

        ``angle`` selects an optional diagnostic hint. Returns the extracted code,
        or None if generation raised an exception.
        """
        func_hint = f"\nThe solution must define a function named `{func_name}`." if func_name else ""

        angle_hint = ""
        if angle == "edge_cases":
            angle_hint = ("\nThe remaining failing tests likely involve edge cases: "
                          "zero, negative numbers, empty inputs, single elements, very large values, "
                          "or boundary conditions. Pay special attention to these.")
        elif angle == "algorithm_rethink":
            angle_hint = ("\nThe current algorithm may be fundamentally flawed for some inputs. "
                          "Consider whether a completely different approach or mathematical property would work better.")
        elif angle == "simplify":
            angle_hint = ("\nThe current solution may be over-complicated. "
                          "Consider if there is a much simpler formula, pattern, or one-liner that solves this correctly.")
        elif angle == "reinterpret":
            angle_hint = ("\nRe-read the problem statement very carefully. "
                          "The failing tests may require a different interpretation of what the problem is asking.")

        sys_prompt = f"""You are a Python expert.
You have a program that passes {passed} out of {total} test cases.
Analyze what could cause some tests to fail and produce an improved version.{angle_hint}
{MBPP_STYLE}{func_hint}
You must only output a single, complete Python code block.
Do not include any explanations or surrounding text."""

        user_prompt = f"""Problem: {problem_description}

CURRENT BEST ({passed}/{total} tests passing):
```python
{base_code}
```

Write an improved version that passes all {total} test cases:"""

        try:
            response = generate_local(self.gen_model, self.gen_tokenizer, self.device,
                                      sys_prompt, user_prompt, max_new_tokens=512, temperature=temperature)
            return extract_python_code(response)
        except Exception as e:
            print(f"      Error generating from subgoal: {e}")
            return None

    def _generate_from_failed(self, problem_description, base_code, total,
                              func_name=None, temperature=0.5, angle=None):
        """Ask the model to rethink ``base_code``, which passed no tests.

        ``angle`` selects an optional diagnostic hint. Returns the extracted code,
        or None if generation raised an exception.
        """
        func_hint = f"\nThe solution must define a function named `{func_name}`." if func_name else ""

        angle_hint = ""
        if angle == "edge_cases":
            angle_hint = ("\nThe code likely has a subtle bug: wrong variable name, off-by-one, "
                          "wrong operator, or incorrect return type. Trace through a simple example mentally.")
        elif angle == "algorithm_rethink":
            angle_hint = ("\nThe entire approach may be wrong. "
                          "Consider a completely different algorithm or formula.")
        elif angle == "simplify":
            angle_hint = ("\nTry the simplest possible implementation. "
                          "A direct, naive solution often works better than a clever one.")
        elif angle == "reinterpret":
            angle_hint = ("\nRe-read the problem carefully. "
                          "You may be misunderstanding what the function should compute or return.")

        sys_prompt = f"""You are a Python expert.
A previous attempt passed 0 out of {total} test cases.
The code is likely wrong in its core logic — do not just tweak it, rethink the solution.{angle_hint}
{MBPP_STYLE}{func_hint}
You must only output a single, complete Python code block.
Do not include any explanations or surrounding text."""

        user_prompt = f"""Problem: {problem_description}

FAILED ATTEMPT (0/{total} tests passed):
```python
{base_code}
```

Write an improved version that passes all {total} test cases:"""

        try:
            response = generate_local(self.gen_model, self.gen_tokenizer, self.device,
                                      sys_prompt, user_prompt, max_new_tokens=512, temperature=temperature)
            return extract_python_code(response)
        except Exception as e:
            print(f"      Error generating from failed attempt: {e}")
            return None

    def _generate_diverse(self, problem_description, iteration, func_name=None,
                          temperature=0.7):
        """Generate a fresh solution guided by one of several fixed problem-solving approaches.

        ``iteration`` picks the approach (modulo the number of approaches). Returns
        the extracted code, or None if generation raised an exception.
        """
        approaches = [
            "Break the problem into smaller sub-problems. Solve each independently, then combine.",
            "Identify the core mathematical or logical pattern and implement the simplest correct solution.",
            "Choose the most suitable data structure (list, dict, set, etc.) and build around it.",
            "Start with a straightforward brute-force approach that prioritizes correctness over efficiency.",
            "Handle boundary conditions and special cases first, then generalize.",
            "Read the problem statement very literally. Make sure every word is accounted for in your solution.",
            "Think about what the expected output type and range should be, then work backwards from that.",
            "Consider if there is a well-known algorithm, formula, or sequence definition that exactly matches this problem.",
            "Think step by step: what does each input map to? Trace through 2-3 small examples mentally before coding.",
            "Look up the exact mathematical definition implied by the problem name. Implement it directly with memoization or dynamic programming if recursive.",
            "If the problem involves comparing two sequences, consider building a bijective mapping between their elements.",
            "If the problem mentions a named mathematical sequence, write out the recurrence relation explicitly, then implement it with a lookup table.",
            "Use a dictionary to track relationships or mappings between elements. Check consistency at each step.",
            "If your first approach seems natural but might be wrong, try the opposite interpretation of the problem.",
            "Implement using dynamic programming with an array. Fill values iteratively from base cases.",
        ]
        approach = approaches[iteration % len(approaches)]
        func_hint = f"\nThe solution must define a function named `{func_name}`." if func_name else ""

        sys_prompt = f"""You are a Python expert.
{approach}
{MBPP_STYLE}{func_hint}
You must only output a single, complete Python code block.
Do not include any explanations or surrounding text."""

        user_prompt = f"""Problem: {problem_description}

Write a complete, correct Python solution:"""

        try:
            response = generate_local(self.gen_model, self.gen_tokenizer, self.device,
                                      sys_prompt, user_prompt, max_new_tokens=512, temperature=temperature)
            return extract_python_code(response)
        except Exception as e:
            print(f"      Error generating diverse solution: {e}")
            return None

# Instruction appended to every system prompt built in this file. It tells the model
# to output only the function and to match the expected return type/format, because
# results are compared for equality.
MBPP_STYLE = "Output only the function. Its return value is checked with exact equality—match the type and format the problem expects."
# max_new_tokens budget for the initial zero-shot attempt in run_tmpc_mbpp.
MBPP_INITIAL_MAX_TOKENS = 512
# Download source used by load_sanitized_mbpp when no local JSON copy is found.
SANITIZED_MBPP_URL = "https://raw.githubusercontent.com/google-research/google-research/master/mbpp/sanitized-mbpp.json"


def build_mbpp_zero_shot_prompt(problem_description, func_name=None):
    """Build the (system, user) prompt pair for the initial zero-shot MBPP attempt.

    When ``func_name`` is given, the system prompt requires the solution to define
    a function with that name.
    """
    func_hint = (f"\nThe solution must define a function named `{func_name}`."
                 if func_name else "")
    sys_prompt = (
        "You are a Python expert.\n"
        "Generate a working solution for the problem. Focus on correctness over complexity.\n"
        f"{MBPP_STYLE}{func_hint}\n"
        "Output only clean Python code. Do not include any explanations or surrounding text."
    )
    user_prompt = (
        f"Problem: {problem_description}\n"
        "\n"
        "Write a complete, correct Python solution:"
    )
    return sys_prompt, user_prompt


def load_sanitized_mbpp(file_path=None):
    """Load the sanitized MBPP dataset (a JSON array).

    Lookup order:
      1. ``file_path``, if given and it exists.
      2. ``./sanitized-mbpp.json``, the local cache written by a previous download.
      3. Download from ``SANITIZED_MBPP_URL``, then try to cache it at
         ``./sanitized-mbpp.json``. A failed cache write only prints a warning;
         the downloaded data is still returned.

    Format: [{task_id, prompt, code, test_list, test_imports, source_file}, ...]

    Raises:
        FileNotFoundError: no local copy exists and the download or JSON decode failed.
    """
    import urllib.request
    default_path = "sanitized-mbpp.json"
    path = file_path or default_path
    # dict.fromkeys de-duplicates while keeping the preferred order.
    for candidate in dict.fromkeys((path, default_path)):
        if os.path.isfile(candidate):
            with open(candidate, 'r', encoding='utf-8') as f:
                return json.load(f)
    try:
        with urllib.request.urlopen(SANITIZED_MBPP_URL, timeout=30) as resp:
            data = json.loads(resp.read().decode('utf-8'))
    except Exception as e:
        raise FileNotFoundError(
            f"Could not load sanitized MBPP. Tried {path} and {SANITIZED_MBPP_URL}. Error: {e}"
        ) from e
    # Cache locally for next run; failing to cache must not discard the download.
    try:
        with open(default_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=None)
    except OSError as e:
        print(f"Warning: could not cache sanitized MBPP to {default_path}: {e}")
    return data

def _load_generation_model(args, device):
    """Load the Llama-3.1-8B-Instruct tokenizer and model onto ``device``, 4-bit if ``args.load_in_4bit``."""
    gen_model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    print(f"Loading model: {gen_model_name}")

    gen_tokenizer = AutoTokenizer.from_pretrained(gen_model_name)
    if gen_tokenizer.pad_token is None:
        gen_tokenizer.pad_token = gen_tokenizer.eos_token

    if getattr(args, 'load_in_4bit', False):
        precision_kwargs = {
            "quantization_config": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
            )
        }
    else:
        precision_kwargs = {"torch_dtype": torch.bfloat16}

    # device_map={"": device} places every module on `device`, so a follow-up
    # .to(device) (previously applied to the bf16 path) is redundant.
    gen_model = AutoModelForCausalLM.from_pretrained(
        gen_model_name,
        device_map={"": device},
        trust_remote_code=True,
        **precision_kwargs,
    )
    gen_model.eval()
    return gen_tokenizer, gen_model


def _push_to_buffer(buffer, entry, buffer_size):
    """Add ``entry`` if there is room or it beats the worst entry; return the buffer best-first, capped at ``buffer_size``."""
    key = TMPCGenerator._buf_key
    if len(buffer) < buffer_size or key(entry) > key(min(buffer, key=key)):
        buffer = sorted(buffer + [entry], key=key, reverse=True)[:buffer_size]
    return buffer


def run_tmpc_mbpp(args):
    """Run TMPC-style iterative refinement over a slice of sanitized MBPP.

    For every dataset index in ``[args.start, args.end)``:
      1. Generate a zero-shot solution.
      2. For up to ``args.max_iterations`` iterations, request
         ``args.num_candidates`` candidates from ``TMPCGenerator`` and keep the
         best 3 programs (by tests passed, then score).
      3. Stop early once all tests pass, but only from iteration 2 onwards.
      4. Write the per-iteration best program to
         ``<args.output_folder>/problem_<task_id>.json``.
    Running pass-rate statistics are printed after each problem.

    ``args`` attributes read: cuda_num, load_in_4bit (optional), input_file,
    output_folder, start, end, max_iterations, num_candidates.
    """
    device = torch.device(f"cuda:{args.cuda_num}" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    gen_tokenizer, gen_model = _load_generation_model(args, device)

    data = load_sanitized_mbpp(args.input_file)
    os.makedirs(args.output_folder, exist_ok=True)

    print(f"Loaded {len(data)} problems from sanitized MBPP")
    print(f"Processing problems index {args.start} to {args.end - 1}")

    tmpc_generator = TMPCGenerator(gen_model, gen_tokenizer, device)

    completed = 0
    sum_pass_rate = 0.0
    perfect_count = 0
    buffer_size = 3

    for idx, item in enumerate(data):
        if idx < args.start or idx >= args.end:
            continue

        task_id = item.get('task_id', idx)
        problem_description = item.get('prompt', item.get('text', ''))
        test_list = item.get('test_list', [])
        test_imports = item.get('test_imports', [])

        test_cases = convert_mbpp_tests(test_list)
        func_name = infer_function_name(test_cases, test_list)

        print(f"\nProblem {task_id}: {problem_description[:80]}...")
        print(f"  Converted {len(test_cases)} test cases (function: {func_name})")

        sys_prompt, user_prompt = build_mbpp_zero_shot_prompt(problem_description, func_name)

        initial_code = extract_python_code(generate_local(
            gen_model, gen_tokenizer, device, sys_prompt, user_prompt,
            temperature=0.6, max_new_tokens=MBPP_INITIAL_MAX_TOKENS
        ))

        initial_score, initial_passed, initial_total, _, _ = evaluate_code_detailed(
            initial_code, test_cases, test_list_raw=test_list, test_imports=test_imports
        )

        print(f"  Initial: {initial_score:.2f} (tests: {initial_passed}/{initial_total})")

        buffer = [{'code': initial_code, 'score': initial_score,
                   'test_passed': initial_passed, 'test_total': initial_total}]

        history = {
            0: {
                "task_id": task_id,
                "problem": problem_description,
                "original_tests": test_list,
                "code": initial_code,
                "score": initial_score,
                "test_passed": initial_passed,
                "test_total": initial_total,
            }
        }

        # With zero tests, 0 == 0 would be reported as a spurious "all passed".
        if initial_total > 0 and initial_passed == initial_total:
            print("  All tests passed initially!")

        prev_best_passed = initial_passed
        # Consecutive iterations without an increase in the best test_passed count.
        stagnant_iters = 0

        for iteration in range(1, args.max_iterations + 1):
            print(f"  Iteration {iteration} (buffer: {len(buffer)} entries, "
                  f"best={buffer[0]['test_passed']}/{initial_total}, stagnant={stagnant_iters})")

            candidates = tmpc_generator.generate_exploration_candidates(
                problem_description, buffer, iteration,
                num_candidates=args.num_candidates, func_name=func_name,
                stagnant_iters=stagnant_iters
            )

            print(f"    Generated {len(candidates)} candidates")

            for j, candidate in enumerate(candidates):
                if not candidate:
                    continue
                score, passed, total, _, _ = evaluate_code_detailed(
                    candidate, test_cases, test_list_raw=test_list, test_imports=test_imports
                )
                print(f"      Candidate {j+1}: {score:.2f} (tests: {passed}/{total})")

                entry = {'code': candidate, 'score': score,
                         'test_passed': passed, 'test_total': total}
                buffer = _push_to_buffer(buffer, entry, buffer_size)

            best = buffer[0]
            history[iteration] = {
                "task_id": task_id,
                "problem": problem_description,
                "original_tests": test_list,
                "code": best['code'],
                "score": best['score'],
                "test_passed": best['test_passed'],
                "test_total": initial_total,
            }

            if best['test_passed'] > prev_best_passed:
                stagnant_iters = 0
                prev_best_passed = best['test_passed']
            else:
                stagnant_iters += 1

            print(f"    Best: {best['score']:.2f} (tests: {best['test_passed']}/{initial_total})")

            if initial_total > 0 and best['test_passed'] == initial_total and iteration >= 2:
                print(f"    All tests passed! Moving to next problem.")
                break

        # Save results
        output_path = os.path.join(args.output_folder, f"problem_{task_id}.json")
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(_make_json_safe(history), f, ensure_ascii=False, indent=2)

        best_passed = buffer[0]['test_passed']
        final_pass_rate = best_passed / initial_total if initial_total > 0 else 0.0
        completed += 1
        sum_pass_rate += final_pass_rate
        if final_pass_rate >= 1.0 - 1e-9:
            perfect_count += 1

        avg_pass_rate = sum_pass_rate / completed
        print(f"  ── Problem {task_id} done | pass_rate={final_pass_rate:.2f} | "
              f"running avg={avg_pass_rate:.4f} | perfect={perfect_count}/{completed} "
              f"({perfect_count/completed*100:.1f}%)")

    print(f"\nTMPC MBPP completed! "
          f"Problems: {completed} | Avg pass rate: {sum_pass_rate/max(completed,1):.4f} | "
          f"Perfect: {perfect_count}/{completed} ({perfect_count/max(completed,1)*100:.1f}%)")

def _eval_zero_record(task_id, error, problem=''):
    """Build an all-zero evaluation row for a problem that could not be scored."""
    return {
        'task_id': task_id, 'problem': problem, 'initial_pass_rate': 0.0, 'initial_perfect': 0,
        'best_iteration': None, 'score': 0.0, 'test_passed': 0, 'test_total': 0,
        'pass_rate': 0.0, 'perfect': 0, 'code': None, 'error': error,
    }


def _eval_to_int(value, default=0):
    """Coerce *value* to int; falsy or unparseable values map to *default*."""
    if not value:
        return default
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return default


def _eval_to_float(value, default=0.0):
    """Coerce *value* to float; None or unparseable values map to *default*."""
    if value is None:
        return default
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        return default


def _eval_pass_stats(result):
    """Return (test_passed, test_total, pass_rate, perfect) for one iteration record."""
    passed = _eval_to_int(result.get('test_passed'))
    total = _eval_to_int(result.get('test_total'))
    pass_rate = (passed / total) if total > 0 else 0.0
    # Exact integer comparison (a float threshold such as 0.999 would count 999/1000
    # as perfect); agrees with the "all tests passed" check used during generation.
    perfect = 1 if total > 0 and passed >= total else 0
    return passed, total, pass_rate, perfect


def _eval_discover_ids(folder_path, filename_prefix):
    """Return sorted task ids of files named exactly `{filename_prefix}<digits>.json`."""
    import glob
    suffix = ".json"
    ids = set()
    for path in glob.glob(os.path.join(folder_path, f"{filename_prefix}*{suffix}")):
        base = os.path.basename(path)
        # Strip the exact prefix/suffix (not every occurrence), and accept only plain
        # ASCII digits: int() would also accept e.g. "1_2" -> 12, mapping the file to
        # an id whose reconstructed path does not exist.
        stem = base[len(filename_prefix):len(base) - len(suffix)]
        if stem.isascii() and stem.isdigit():
            ids.add(int(stem))
    return sorted(ids)


def _eval_valid_iters(data, max_iteration):
    """Map iteration index -> iteration dict for integer keys <= max_iteration.

    Non-dict JSON payloads, non-integer keys (e.g. metadata) and non-dict
    iteration entries are ignored, so malformed files yield an empty mapping
    instead of raising.
    """
    if not isinstance(data, dict):
        return {}
    valid = {}
    for key, value in data.items():
        try:
            idx = int(key)
        except (TypeError, ValueError):
            continue
        if idx <= max_iteration and isinstance(value, dict):
            valid[idx] = value
    return valid


def evaluate_results_strict(
    folder_path: str,
    max_iteration: int,
    output_file: str,
    max_id: int = 809,
    id_start: int = 2,
    filename_prefix: str = "problem_",
    discover_ids: bool = False,
    dataset_size: int = None
):
    """
    Evaluate results stored as {filename_prefix}{task_id}.json and write a CSV summary.

    When discover_ids=True: find all {filename_prefix}<digits>.json files in folder_path
    and evaluate only those whose id lies in [id_start, max_id].
    Otherwise: evaluate ids id_start..max_id (missing files => zero rows).

    Rules:
      - Missing file => zeros, error='missing_file' (only possible when not discover_ids)
      - Load/parse error => zeros, error='load_error: ...'
      - No dict iterations with integer key <= max_iteration (including a JSON
        payload that is not an object) => zeros, error='no_valid_iters'
      - Otherwise: pick the iteration (<= max_iteration) with the highest 'score'
        (first one on ties); iteration 0 supplies the "initial" metrics.

    Args:
        folder_path: Directory holding the per-problem JSON files.
        max_iteration: Highest iteration index considered.
        output_file: CSV path to write.
        max_id, id_start: Inclusive task-id range.
        filename_prefix: File name prefix before the task id.
        discover_ids: Discover ids from the folder instead of using the full range.
        dataset_size: Optional dataset size, only used in the console summary.

    Output:
      - CSV with columns:
        task_id, problem, initial_pass_rate, initial_perfect, best_iteration, score,
        test_passed, test_total, pass_rate, perfect, code, error
      - Console summary: Initial vs Final (TMPC) with improvement,
        averaged over rows without an error.
    """
    if discover_ids:
        expected_ids = [x for x in _eval_discover_ids(folder_path, filename_prefix)
                        if id_start <= x <= max_id]
    else:
        expected_ids = list(range(id_start, max_id + 1))

    records = []
    for task_id in expected_ids:
        file_path = os.path.join(folder_path, f"{filename_prefix}{task_id}.json")

        if not os.path.exists(file_path):
            records.append(_eval_zero_record(task_id, 'missing_file'))
            continue

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except Exception as e:  # any read/decode failure is recorded per file, not raised
            records.append(_eval_zero_record(task_id, f'load_error: {e}'))
            continue

        valid_iters = _eval_valid_iters(data, max_iteration)
        if not valid_iters:
            if isinstance(data, dict):
                rec_id, problem = data.get('task_id', task_id), (data.get('problem') or '')
            else:
                rec_id, problem = task_id, ''
            records.append(_eval_zero_record(rec_id, 'no_valid_iters', problem))
            continue

        # Initial (iteration 0) metrics
        init_result = valid_iters.get(0)
        if init_result is not None:
            _, _, initial_pass_rate, initial_perfect = _eval_pass_stats(init_result)
        else:
            initial_pass_rate, initial_perfect = 0.0, 0

        # Final (best-scoring iteration)
        best_iter = max(valid_iters, key=lambda it: _eval_to_float(valid_iters[it].get('score')))
        best_result = valid_iters[best_iter]
        test_passed, test_total, pass_rate, perfect = _eval_pass_stats(best_result)

        records.append({
            'task_id': best_result.get('task_id', task_id),
            'problem': (best_result.get('problem') or ''),
            'initial_pass_rate': initial_pass_rate,
            'initial_perfect': initial_perfect,
            'best_iteration': best_iter,
            'score': _eval_to_float(best_result.get('score')),
            'test_passed': test_passed,
            'test_total': test_total,
            'pass_rate': pass_rate,
            'perfect': perfect,
            'code': best_result.get('code'),
            'error': None
        })

    df = pd.DataFrame(records)
    df.to_csv(output_file, index=False)

    if df.empty:
        print("No results produced (empty DataFrame).")
        return

    valid_df = df[df['error'].isna()]
    errors = int(df['error'].notna().sum())
    n_eval = len(valid_df)

    print(f"\nResults saved to {output_file}")
    if discover_ids:
        size_info = f" (of {dataset_size} in dataset)" if dataset_size else ""
        print(f"Problems evaluated: {n_eval}{size_info} (discovered from folder)")
    else:
        print(f"Problems considered: {len(df)} (IDs {id_start}..{max_id})")
    if dataset_size and discover_ids:
        print(f"Dataset size (sanitized MBPP): {dataset_size}")
    print(f"Errors (missing/load/no_valid_iters): {errors}")

    if n_eval == 0:
        return

    init_avg_pass = valid_df['initial_pass_rate'].mean()
    init_perfect = int(valid_df['initial_perfect'].sum())
    final_avg_pass = valid_df['pass_rate'].mean()
    final_perfect = int(valid_df['perfect'].sum())

    print(f"\n--- Initial (zero-shot, iter 0) ---")
    print(f"  Average pass rate: {init_avg_pass:.3f}")
    print(f"  Perfect pass: {init_perfect}/{n_eval} ({init_perfect / n_eval * 100:.1f}%)")

    print(f"\n--- Final (TMPC, best across iters) ---")
    print(f"  Average pass rate: {final_avg_pass:.3f}")
    print(f"  Perfect pass: {final_perfect}/{n_eval} ({final_perfect / n_eval * 100:.1f}%)")

    print(f"\n--- Improvement ---")
    print(f"  Pass rate: {final_avg_pass - init_avg_pass:+.3f}")
    print(f"  Perfect count: {final_perfect - init_perfect:+d}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="TMPC for MBPP (sanitized)")

    # --- Generation options ---
    parser.add_argument("--input_file", type=str, default=None,
                        help="Sanitized MBPP JSON path (default: sanitized-mbpp.json or download)")
    parser.add_argument("--output_folder", type=str, help="Output folder")
    parser.add_argument("--max_iterations", type=int, default=5, help="Max iterations")
    parser.add_argument("--num_candidates", type=int, default=3, help="Candidates per iteration (2 subgoal + 1 diverse)")
    parser.add_argument("--cuda_num", type=int, default=0, help="CUDA device index")
    parser.add_argument("--load_in_4bit", action="store_true",
                        help="Load model in 4-bit quantization (saves ~12GB VRAM)")
    parser.add_argument("--start", type=int, default=0, help="Start index (0-based)")
    parser.add_argument("--end", type=int, default=427, help="End index (sanitized has 427 problems)")

    # --- Evaluation options ---
    parser.add_argument("--evaluate", action="store_true", help="Evaluation mode")
    parser.add_argument("--eval_folder", type=str,
                        help="Evaluation folder (default: --output_folder)")
    parser.add_argument("--eval_max_iter", type=int, default=0, help="Max iteration to evaluate")
    parser.add_argument("--eval_output", type=str, default="tmpc_results.csv", help="Evaluation output file")
    parser.add_argument("--eval_range", type=int, default=809, help="Max task_id to evaluate (sanitized)")
    parser.add_argument("--eval_start", type=int, default=2, help="Start task_id to evaluate (sanitized)")
    # Folder discovery is already the default; this flag is kept so existing command
    # lines keep working.
    parser.add_argument("--eval_from_folder", action="store_true",
                        help="Discover task_ids from folder (default when --evaluate)")
    parser.add_argument("--eval_full_range", action="store_true",
                        help="Evaluate full range (id_start..max_id) instead of folder discovery")
    parser.add_argument("--eval_from_dataset", action="store_true",
                        help="Load sanitized MBPP to report dataset size (N of M in dataset)")

    args = parser.parse_args()

    if args.evaluate:
        # Without a folder, evaluation would crash inside os.path.join(None, ...).
        eval_folder = args.eval_folder or args.output_folder
        if not eval_folder:
            parser.error("--evaluate requires --eval_folder (or --output_folder)")

        dataset_size = None
        if args.eval_from_dataset:
            # Dataset size is informational only, so a load failure is reported, not fatal.
            try:
                dataset_size = len(load_sanitized_mbpp(args.input_file))
            except Exception as e:
                print(f"Warning: could not load dataset to report its size ({e}); continuing without it.")

        # Default: discover from folder (only evaluate existing files). Use --eval_full_range to override.
        evaluate_results_strict(
            eval_folder, args.eval_max_iter, args.eval_output,
            max_id=args.eval_range, id_start=args.eval_start,
            discover_ids=not args.eval_full_range,
            dataset_size=dataset_size
        )
    else:
        run_tmpc_mbpp(args)