import ast
import json
import re
import torch
from typing import Dict, List, Tuple, Optional, Any, Union
from dataclasses import dataclass
from transformers import StoppingCriteria, StoppingCriteriaList

# Sentinel tokens used to assemble fill-in-the-middle prompts in this module.
# Prompts are laid out as: FIM_PREFIX + <code before the gap> + FIM_SUFFIX +
# <code after the gap> + FIM_MIDDLE, after which the model generates the gap.
FIM_PREFIX = "<|fim_prefix|>"
FIM_MIDDLE = "<|fim_middle|>"
FIM_SUFFIX = "<|fim_suffix|>"


class SingleStatementStoppingCriteria(StoppingCriteria):
    """Stop generation once the first generated line is a complete call statement.

    The criterion decodes everything generated after the prompt and fires as
    soon as a newline has been produced and the text before that newline is a
    bracket-balanced statement ending in ``)``.
    """

    def __init__(self, tokenizer, prompt_length: int):
        """Remember the tokenizer and how many leading tokens belong to the prompt."""
        self.tokenizer = tokenizer
        self.prompt_length = prompt_length

    def _is_complete_statement(self, code: str) -> bool:
        """Return True when *code* has balanced brackets and ends with ``)``.

        Quotes toggle a simple "inside string" state (a quote preceded by a
        backslash is ignored); brackets inside strings are not counted.
        """
        text = code.strip()
        if not text:
            return False

        # One independent counter per bracket family: (), [], {}.
        opening = {'(': 0, '[': 1, '{': 2}
        closing = {')': 0, ']': 1, '}': 2}
        depths = [0, 0, 0]

        active_quote = None
        previous = ''
        for ch in text:
            after_backslash = previous == '\\'
            previous = ch

            if ch in ('"', "'") and not after_backslash:
                if active_quote is None:
                    active_quote = ch
                elif ch == active_quote:
                    active_quote = None
                continue

            if active_quote is not None:
                continue

            if ch in opening:
                depths[opening[ch]] += 1
            elif ch in closing:
                depths[closing[ch]] -= 1

        return text.endswith(')') and not any(depths)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when generation should stop; only the first sequence is inspected."""
        if input_ids.shape[1] <= self.prompt_length:
            return False

        new_token_ids = input_ids[0, self.prompt_length:]
        decoded = self.tokenizer.decode(new_token_ids, skip_special_tokens=True)

        # Only decide once a full line is available (a newline was produced).
        first_line, newline, _ = decoded.partition('\n')
        if not newline:
            return False
        return self._is_complete_statement(first_line)


@dataclass
class MatchedRegion:
    """A source statement selected for infilling plus the guidance item that matched it."""
    # Line number of the statement's first line (taken from the AST node's `lineno`).
    start_lineno: int
    # Line number of the statement's last line (taken from the AST node's `end_lineno`).
    end_lineno: int
    # Column where the statement starts (AST `col_offset`).
    start_col: int
    # Column where the statement ends (AST `end_col_offset`).
    end_col: int
    # Full text of source lines start_lineno..end_lineno, joined with '\n'.
    original_code: str
    # The guidance dictionary whose 'match' spec selected this statement.
    guidance_item: Dict[str, Any]


def parse_guidance(guidance_path: str) -> Dict[str, Any]:
    """Read the UTF-8 JSON guidance file at *guidance_path* and return the parsed object."""
    with open(guidance_path, 'r', encoding='utf-8') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)


def read_source_code(source_path: str) -> str:
    """Return the whole content of the UTF-8 text file at *source_path*."""
    with open(source_path, 'r', encoding='utf-8') as handle:
        text = handle.read()
    return text


def extract_run_skill_function(code: str) -> Tuple[Optional[ast.FunctionDef], Optional[str]]:
    """Locate the first ``def run_skill`` in *code*.

    Returns:
        ``(node, source)`` where *source* is the text of the lines spanned by
        the function node, or ``(None, None)`` when no such function exists.
    """
    module = ast.parse(code)
    target = next(
        (
            candidate
            for candidate in ast.walk(module)
            if isinstance(candidate, ast.FunctionDef) and candidate.name == 'run_skill'
        ),
        None,
    )
    if target is None:
        return None, None

    source_lines = code.split('\n')
    function_source = '\n'.join(source_lines[target.lineno - 1:target.end_lineno])
    return target, function_source


def match_call_node(node: ast.Call, match_spec: Dict[str, Any]) -> bool:
    """Check whether a call node satisfies a guidance match specification.

    ``match_spec['func']`` may constrain the callee: type ``'Name'`` requires a
    plain name equal to ``id``; type ``'Attribute'`` requires an attribute
    access whose attribute equals ``attr``. Any other type imposes no callee
    constraint. ``match_spec['keywords_required']`` lists keyword argument
    names that must all be present on the call.
    """
    callee_spec = match_spec.get('func', {})
    callee_kind = callee_spec.get('type')
    callee = node.func

    if callee_kind == 'Name':
        wanted = callee_spec.get('id')
        if not (isinstance(callee, ast.Name) and callee.id == wanted):
            return False
    elif callee_kind == 'Attribute':
        wanted = callee_spec.get('attr')
        if not (isinstance(callee, ast.Attribute) and callee.attr == wanted):
            return False

    required = match_spec.get('keywords_required', [])
    if not required:
        return True

    # `**kwargs` entries have arg None and never satisfy a requirement.
    present = {keyword.arg for keyword in node.keywords if keyword.arg is not None}
    for name in required:
        if name not in present:
            return False
    return True


def find_matching_regions(
    code: str,
    guidance_items: List[Dict[str, Any]]
) -> List[MatchedRegion]:
    """Collect the statements in *code* whose calls match any guidance item.

    Only guidance items whose ``match['node']`` equals ``'call'`` are
    considered. Every matching call contributes one region built from its
    containing statement. Regions are returned sorted by start line in
    descending order (bottom of the file first).
    """
    tree = ast.parse(code)
    source_lines = code.split('\n')
    call_nodes = [n for n in ast.walk(tree) if isinstance(n, ast.Call)]

    regions = []
    for item in guidance_items:
        spec = item.get('match', {})
        if spec.get('node', '') != 'call':
            continue

        for call in call_nodes:
            if not match_call_node(call, spec):
                continue

            statement = find_containing_statement(tree, call)
            if not statement:
                continue

            first_line = statement.lineno
            last_line = statement.end_lineno

            if hasattr(statement, 'end_col_offset'):
                last_col = statement.end_col_offset
            else:
                last_col = len(source_lines[last_line - 1])

            regions.append(MatchedRegion(
                start_lineno=first_line,
                end_lineno=last_line,
                start_col=statement.col_offset,
                end_col=last_col,
                original_code='\n'.join(source_lines[first_line - 1:last_line]),
                guidance_item=item
            ))

    regions.sort(key=lambda region: region.start_lineno, reverse=True)
    return regions


def find_containing_statement(tree: ast.AST, target_node: ast.AST) -> Optional[ast.stmt]:
    """Return the simple statement that contains *target_node*.

    Walks up from *target_node* and returns the first non-compound statement
    that sits directly in a statement body (of a module, function, class,
    loop, ``if``, ``with``, ``try``, ``except`` handler or ``match`` case).

    Args:
        tree: Root of the AST that contains *target_node*.
        target_node: Node (typically an ``ast.Call``) to locate.

    Returns:
        The enclosing simple statement, or ``None`` when *target_node* is not
        inside one (e.g. it is part of an ``if`` condition or a decorator) or
        does not belong to *tree*.
    """
    parent_of = {}
    for parent in ast.walk(tree):
        for child in ast.iter_child_nodes(parent):
            parent_of[child] = parent

    def _existing(*candidates):
        # Some node types only exist on newer Python versions.
        return tuple(c for c in candidates if c is not None)

    compound_stmt_types = _existing(
        ast.For, ast.AsyncFor, ast.While, ast.If,
        ast.With, ast.AsyncWith, ast.Try,
        ast.FunctionDef, ast.AsyncFunctionDef,
        ast.ClassDef, ast.Module,
        getattr(ast, 'TryStar', None),
        getattr(ast, 'Match', None),
    )
    # `except` handlers and `match` cases own statement bodies but are not
    # themselves statements, so they must be accepted as parents explicitly.
    body_owner_types = compound_stmt_types + _existing(
        ast.ExceptHandler,
        getattr(ast, 'match_case', None),
    )

    current = target_node
    while current in parent_of:
        parent = parent_of[current]
        is_simple_stmt = (
            isinstance(current, ast.stmt)
            and not isinstance(current, compound_stmt_types)
        )
        if is_simple_stmt and isinstance(parent, body_owner_types):
            return current
        current = parent

    return None


def validate_guidance_content(guidance_item: Dict[str, Any]) -> str:
    """Return the stripped ``'content'`` field of *guidance_item*.

    Raises:
        ValueError: If the field is missing (``None``), is not a string, or is
            blank after stripping.
    """
    raw = guidance_item.get('content')
    if raw is None:
        raise ValueError("Guidance item missing 'content' field")

    if isinstance(raw, str):
        cleaned = raw.strip()
        if cleaned:
            return cleaned

    raise ValueError("Guidance item 'content' field is empty")

def validate_guidance_full_content(guidance_item: Dict[str, Any]) -> str:
    """Return the stripped ``'full_content'`` field of *guidance_item*.

    Raises:
        ValueError: If the field is missing (``None``), is not a string, or is
            blank after stripping.
    """
    raw = guidance_item.get('full_content')
    if raw is None:
        raise ValueError("Guidance item missing 'full_content' field")

    if isinstance(raw, str):
        cleaned = raw.strip()
        if cleaned:
            return cleaned

    raise ValueError("Guidance item 'full_content' field is empty")


_SKILL_SIGNATURES_CACHE: Dict[str, Dict[str, List[str]]] = {}


def _load_skill_signatures(robot_name: str) -> Dict[str, List[str]]:
    global _SKILL_SIGNATURES_CACHE

    if robot_name in _SKILL_SIGNATURES_CACHE:
        return _SKILL_SIGNATURES_CACHE[robot_name]

    import os
    from pathlib import Path

    possible_paths = [
        Path(f"ontology_modules/data/{robot_name}_skills.json"),
        Path(f"{robot_name}_skills.json"),
    ]

    skills_path = None
    for p in possible_paths:
        if p.exists():
            skills_path = p
            break

    if skills_path is None:
        print(f"[WARNING] {robot_name}_skills.json not found")
        _SKILL_SIGNATURES_CACHE[robot_name] = {}
        return {}

    try:
        with open(skills_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        signatures = {}
        for skill in data.get('skills', []):
            skill_name = skill.get('name')
            if skill_name:
                params = ['env', 'task']
                for param in skill.get('parameters', []):
                    params.append(param.get('name'))
                signatures[skill_name] = params

        _SKILL_SIGNATURES_CACHE[robot_name] = signatures
        return signatures

    except Exception as e:
        print(f"[WARNING] Failed to load {robot_name}_skills.json: {e}")
        _SKILL_SIGNATURES_CACHE[robot_name] = {}
        return {}


def get_full_interface_signature(func_name: str, robot_name: str = "ur5") -> str:
    """Render a keyword-style call template such as ``skill(env=..., task=...)``.

    Unknown skills are rendered as ``name(...)``. For skills whose name
    contains ``'move_to'`` the parameters ``trials``, ``max_configs`` and
    ``timeout_s`` are left out of the template.
    """
    known_signatures = _load_skill_signatures(robot_name)
    if func_name not in known_signatures:
        return f"{func_name}(...)"

    param_names = known_signatures[func_name]
    if 'move_to' in func_name:
        hidden = ('trials', 'max_configs', 'timeout_s')
        param_names = [name for name in param_names if name not in hidden]

    rendered = ', '.join(f"{name}=..." for name in param_names)
    return f"{func_name}({rendered})"


def construct_fim_prompt(
    code: str,
    region: MatchedRegion,
    include_context: bool = True,
    grasp_guidance: Optional[str] = None,
    target_robot: str = "ur5"
) -> Tuple[str, str, str]:
    """Build the fill-in-the-middle prompt that asks the model to rewrite *region*.

    The prompt is a commented instruction block (the guidance item's
    ``full_content``, the robot rules and optionally the grasp guidance),
    followed by ``FIM_PREFIX`` + code before the region + an interface hint +
    ``FIM_SUFFIX`` + code after the region + ``FIM_MIDDLE``.

    Args:
        code: Full source text containing the region.
        region: Region to be replaced.
        include_context: Accepted for interface compatibility; not used here.
        grasp_guidance: Extra guidance appended to the instructions when the
            guidance item looks grasp related.
        target_robot: Robot name used to look up rules and skill signatures.

    Returns:
        ``(prefix_code, fim_prompt, suffix_code)``.

    Raises:
        ValueError: If the guidance item lacks a usable ``content`` or
            ``full_content`` field.
    """
    item = region.guidance_item

    # Both fields are mandatory; only 'full_content' is used in the prompt.
    validate_guidance_content(item)
    instruction_text = validate_guidance_full_content(item)

    source_lines = code.split('\n')

    # Everything above the region, always terminated by a newline when non-empty.
    prefix_code = '\n'.join(source_lines[:region.start_lineno - 1])
    if prefix_code and not prefix_code.endswith('\n'):
        prefix_code += '\n'

    # Indentation of the statement being replaced, reproduced with spaces.
    original_line = source_lines[region.start_lineno - 1]
    indent_str = ' ' * (len(original_line) - len(original_line.lstrip()))

    # Everything below the region.
    suffix_code = '\n'.join(source_lines[region.end_lineno:])

    from infilling_modules.static_rules import get_rules_for_robot

    interface = item.get('interface')
    if interface:
        interface_entries = [part.strip() for part in interface.split('\n') if part.strip()]
    else:
        interface_entries = []
    is_multi_call = len(interface_entries) > 1

    # For multi-call interfaces the first original line is shown in the instructions.
    original_code_block = f"\n{original_line}\n" if is_multi_call else ""

    robot_rules = get_rules_for_robot(target_robot)
    instruction_text += robot_rules + original_code_block

    if grasp_guidance:
        skill_name = item.get('match', {}).get('func', {}).get('id', '').lower()
        interface_lower = item.get('interface', '').lower()

        names_grasp = any(kw in skill_name for kw in ['pick', 'grasp'])
        interface_grasp = any(kw in interface_lower for kw in ['grasp', 'ur5_grasp'])
        if names_grasp or interface_grasp:
            instruction_text += f"\n\n{grasp_guidance}\n"

    # Turn every instruction line into a comment line.
    commented = '\n'.join(
        f"# {text}" if text.strip() else "#"
        for text in instruction_text.split('\n')
    )
    instruction_block = f"{commented}\n"

    if not interface:
        hint_for_FIM = f"{indent_str}"
    elif is_multi_call:
        rendered_calls = []
        for entry in interface_entries:
            name_match = re.match(r'(\w+)\s*\(', entry)
            if name_match:
                signature = get_full_interface_signature(name_match.group(1), target_robot)
                rendered_calls.append(f"    {signature}")
            else:
                rendered_calls.append(f"    {entry}")
        joined_calls = "\n".join(rendered_calls)
        hint_for_FIM = f"{indent_str}\n{joined_calls}\n{indent_str}"
    else:
        hint_for_FIM = f"{indent_str}# MUST call: {interface}\n{indent_str}"

    fim_prompt = f"{instruction_block}{FIM_PREFIX}{prefix_code}{hint_for_FIM}{FIM_SUFFIX}\n{suffix_code}{FIM_MIDDLE}"

    return prefix_code, fim_prompt, suffix_code


def prepare_fim_input(
    source_path: str,
    guidance_path: str
) -> List[Dict[str, Any]]:
    """Read a source file and its guidance file and build one FIM input per matched region.

    Returns:
        A list of dictionaries (keys: ``fim_prompt``, ``region``,
        ``prefix_code``, ``suffix_code``, ``original_code``,
        ``grasp_guidance``, ``target_robot``), ordered like the regions
        returned by ``find_matching_regions``; empty when nothing matches.
    """
    code = read_source_code(source_path)
    guidance = parse_guidance(guidance_path)

    grasp_guidance = guidance.get('grasp_guidance')
    target_robot = guidance.get('target_robot', 'ur5')

    regions = find_matching_regions(code, guidance.get('guidance', []))
    if not regions:
        return []

    prepared = []
    for region in regions:
        prefix_code, fim_prompt, suffix_code = construct_fim_prompt(
            code, region, grasp_guidance=grasp_guidance, target_robot=target_robot
        )
        entry = {
            'fim_prompt': fim_prompt,
            'region': region,
            'prefix_code': prefix_code,
            'suffix_code': suffix_code,
            'original_code': code,
            'grasp_guidance': grasp_guidance,
            'target_robot': target_robot
        }
        prepared.append(entry)

    return prepared


def load_model_and_tokenizer(
    model_name: str = "Qwen/Qwen2.5-Coder-7B",
    device: Optional[str] = None,
    torch_dtype: Optional[torch.dtype] = None
) -> Tuple[Any, Any, str]:
    """Load a causal LM and its tokenizer for infilling.

    On CUDA devices the model is loaded with 8-bit bitsandbytes quantization
    and placed through ``device_map``. On any other device the model is loaded
    unquantized and moved with ``.to(device)``: requesting 8-bit quantization
    and then calling ``.to`` on the quantized model is not a reliable
    combination outside CUDA.

    Args:
        model_name: Hugging Face model identifier or local path.
        device: Target device; defaults to ``"cuda"`` when available, else ``"cpu"``.
        torch_dtype: Weight dtype; defaults to float16 on CUDA and float32 otherwise.

    Returns:
        ``(model, tokenizer, device)`` with the model in eval mode.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Treat "cuda" and indexed forms such as "cuda:0" alike.
    on_cuda = str(device).startswith("cuda")

    if torch_dtype is None:
        torch_dtype = torch.float16 if on_cuda else torch.float32

    # NOTE: trust_remote_code executes code shipped with the checkpoint; only
    # use model names from sources you trust.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    load_kwargs = {
        'torch_dtype': torch_dtype,
        'trust_remote_code': True,
    }
    if on_cuda:
        load_kwargs['device_map'] = device
        load_kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_threshold=6.0,
        )

    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)

    if not on_cuda:
        model = model.to(device)

    model.eval()

    return model, tokenizer, device


def generate_fim_completion(
    model: Any,
    tokenizer: Any,
    fim_prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.0,
    top_p: float = 1.0,
    seed: Optional[int] = 42,
    past_key_values: Optional[Tuple] = None,
    use_cache: bool = True,
    single_line_only: bool = False
) -> Tuple[str, Optional[Tuple], int]:
    """Generate the completion for a FIM prompt.

    Args:
        model: Causal LM exposing ``generate`` and ``model.embed_tokens``.
        tokenizer: Tokenizer matching *model*.
        fim_prompt: Fully assembled fill-in-the-middle prompt.
        max_new_tokens: Upper bound on generated tokens.
        temperature: ``0.0`` selects greedy decoding; any other value enables
            sampling with this temperature and *top_p*.
        top_p: Nucleus sampling threshold (sampling only).
        seed: Seed applied to torch (and CUDA) before generating; ``None`` to skip.
        past_key_values: Optional cache forwarded to ``generate``.
        use_cache: Whether generation uses / returns a key-value cache.
        single_line_only: Stop as soon as one complete statement line was generated.

    Returns:
        ``(generated_text, kv_cache_or_None, number_of_generated_tokens)``.
    """
    if seed is not None:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    # Inputs must live where the embedding layer lives (matters with device_map).
    embed_device = model.model.embed_tokens.weight.device

    inputs = tokenizer(fim_prompt, return_tensors="pt").to(embed_device)
    input_length = inputs.input_ids.shape[1]

    gen_kwargs = {
        'max_new_tokens': max_new_tokens,
        'pad_token_id': tokenizer.pad_token_id,
        'eos_token_id': tokenizer.eos_token_id,
        'use_cache': use_cache,
    }

    if temperature == 0.0:
        gen_kwargs['do_sample'] = False
    else:
        gen_kwargs['do_sample'] = True
        gen_kwargs['temperature'] = temperature
        gen_kwargs['top_p'] = top_p

    if past_key_values is not None:
        gen_kwargs['past_key_values'] = past_key_values

    if single_line_only:
        gen_kwargs['stopping_criteria'] = StoppingCriteriaList([
            SingleStatementStoppingCriteria(tokenizer, input_length)
        ])

    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            return_dict_in_generate=True,
            output_hidden_states=False,
            **gen_kwargs
        )

    # Drop the prompt tokens; only the newly generated tail is decoded.
    generated_ids = outputs.sequences[0][input_length:]
    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
    num_tokens = len(generated_ids)

    new_kv_cache = getattr(outputs, 'past_key_values', None) if use_cache else None

    return generated_text, new_kv_cache, num_tokens


def extract_generated_code(generated_text: str) -> str:
    """Reduce raw model output to the code that should be inserted.

    Steps:
      1. Cut the text at the first FIM / file-separator / end-of-text token.
      2. Stop at the first *later* line that starts a new definition
         (``def`` / ``class``), a ``# ---`` separator or a docstring. The very
         first line is always kept, whatever it starts with.
      3. Keep comment and blank lines that precede the first code line, keep
         every code line, and drop comments and blank lines after that.

    Returns:
        The cleaned, stripped code (possibly an empty string).
    """
    for token in (FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, '<|file_sep|>', '<|endoftext|>'):
        generated_text = generated_text.split(token)[0]

    stop_prefixes = ('def ', 'class ', '# ---', '"""')

    kept_lines = []
    for i, line in enumerate(generated_text.split('\n')):
        # `i > 0` guards every stop marker, so the first line is never discarded.
        if i > 0 and line.strip().startswith(stop_prefixes):
            break
        kept_lines.append(line)

    result = '\n'.join(kept_lines).strip()

    cleaned_lines = []
    found_code = False
    for line in result.split('\n'):
        stripped = line.strip()
        if stripped and not stripped.startswith('#'):
            found_code = True
            cleaned_lines.append(line)
        elif not found_code:
            cleaned_lines.append(line)

    return '\n'.join(cleaned_lines).strip()


def _extract_func_call_with_balanced_parens(line: str, func_name: str) -> Optional[Tuple[int, int, str]]:
    func_idx = line.find(func_name)
    if func_idx == -1:
        return None

    paren_start = line.find('(', func_idx)
    if paren_start == -1:
        return None

    paren_depth = 1
    i = paren_start + 1
    while i < len(line) and paren_depth > 0:
        if line[i] == '(':
            paren_depth += 1
        elif line[i] == ')':
            paren_depth -= 1
        i += 1

    if paren_depth != 0:
        return None

    paren_end = i
    args_str = line[paren_start + 1:paren_end - 1]

    return (func_idx, paren_end, args_str)


def _split_top_level_args(args_str: str) -> List[str]:
    """Split an argument list on commas that are outside brackets and string literals."""
    parts = []
    current = []
    depth = 0
    quote = None
    escaped = False

    for char in args_str:
        if quote is not None:
            # Inside a string literal: copy verbatim until the closing quote.
            current.append(char)
            if escaped:
                escaped = False
            elif char == '\\':
                escaped = True
            elif char == quote:
                quote = None
            continue

        if char in ('"', "'"):
            quote = char
        elif char in '([{':
            depth += 1
        elif char in ')]}':
            depth -= 1
        elif char == ',' and depth == 0:
            parts.append(''.join(current).strip())
            current = []
            continue
        current.append(char)

    parts.append(''.join(current).strip())
    return [part for part in parts if part]


def filter_invalid_params(generated_code: str, robot_name: str = "ur5") -> str:
    """Remove keyword arguments that a skill does not declare from generated calls.

    For every line and every known skill, the first call of that skill on the
    line is rewritten so that only positional arguments and keyword arguments
    named in the skill's signature remain. For skills whose name contains
    ``'move_to'`` the keywords ``trials``, ``max_configs`` and ``timeout_s``
    are always removed.

    An argument counts as a keyword argument only when it starts with
    ``name=`` at the top level; positional expressions that merely contain
    ``=`` (``make_pose(x=1)``, ``a == b``) are kept.

    Returns:
        The rewritten code; *generated_code* unchanged when no signatures are
        available for *robot_name*.
    """
    signatures = _load_skill_signatures(robot_name)

    if not signatures:
        return generated_code

    keyword_head = re.compile(r'([A-Za-z_]\w*)\s*=(?!=)')

    result_lines = []
    for line in generated_code.split('\n'):
        for func_name, declared_params in signatures.items():
            if func_name not in line:
                continue

            call_info = _extract_func_call_with_balanced_parens(line, func_name)
            if call_info is None:
                continue

            func_start, func_end, args_str = call_info
            valid_params = set(declared_params)

            if 'move_to' in func_name:
                valid_params -= {'trials', 'max_configs', 'timeout_s'}

            filtered_args = []
            for arg in _split_top_level_args(args_str):
                keyword = keyword_head.match(arg)
                if keyword is None or keyword.group(1) in valid_params:
                    filtered_args.append(arg)

            new_call = f"{func_name}({', '.join(filtered_args)})"
            line = line[:func_start] + new_call + line[func_end:]

        result_lines.append(line)

    return '\n'.join(result_lines)


def normalize_indent(spaces: int, indent_unit: int = 4) -> int:
    """Round *spaces* up to the next multiple of *indent_unit* (zero stays zero)."""
    if spaces == 0:
        return 0
    # Number of whole indent units needed to cover `spaces` columns.
    units_needed = (spaces - 1) // indent_unit + 1
    return units_needed * indent_unit

def _linecol_to_offset(code: str, lineno: int, col: int) -> int:
    lines = code.splitlines(keepends=True)
    return sum(len(lines[i]) for i in range(lineno - 1)) + col


def reconstruct_code(original_code: str, region: MatchedRegion, generated_code: str) -> str:
    """Return *original_code* with the span covered by *region* replaced by *generated_code*.

    The span runs from (start_lineno, start_col) up to (end_lineno, end_col);
    text before and after it is preserved unchanged.
    """
    span_start = _linecol_to_offset(original_code, region.start_lineno, region.start_col)
    span_end = _linecol_to_offset(original_code, region.end_lineno, region.end_col)

    head = original_code[:span_start]
    tail = original_code[span_end:]
    return ''.join((head, generated_code, tail))

def run_fim_infilling(
    fim_input: Dict[str, Any],
    model: Any = None,
    tokenizer: Any = None,
    model_name: str = "Qwen/Qwen2.5-Coder-7B",
    max_new_tokens: int = 256,
    temperature: float = 0.0,
    top_p: float = 1.0,
    seed: int = 42,
    device: Optional[str] = None,
    past_key_values: Optional[Tuple] = None
) -> Tuple[str, Optional[Tuple], int]:
    """Generate replacement code for one region and splice it into the source.

    When the guidance item's ``interface`` lists several calls, the first
    completion is extended one call at a time with follow-up single-statement
    prompts until as many code lines as interface entries exist, or until a
    follow-up produces nothing new.

    Args:
        fim_input: Entry produced by ``prepare_fim_input`` (needs ``region``,
            ``original_code`` and ``fim_prompt``; ``target_robot`` optional).
        model, tokenizer: Loaded model and tokenizer; loaded from *model_name*
            when either is ``None``.
        model_name, device: Used only when the model has to be loaded.
        max_new_tokens, temperature, top_p, seed: Generation settings.
        past_key_values: Optional cache for the first generation call.

    Returns:
        ``(final_code, kv_cache_or_None, total_generated_tokens)``.
    """
    if model is None or tokenizer is None:
        model, tokenizer, device = load_model_and_tokenizer(model_name, device)

    region = fim_input['region']
    original_code = fim_input['original_code']
    # Single default so that hints and parameter filtering use the same robot.
    target_robot = fim_input.get('target_robot', 'ur5')

    interface = region.guidance_item.get('interface', '')
    interface_lines = [line.strip() for line in interface.split('\n') if line.strip()]

    fim_prompt = fim_input['fim_prompt']
    is_single_line = len(interface_lines) <= 1

    print(f"[FIM Infilling] Generating completion for region at lines {region.start_lineno}-{region.end_lineno}")
    print(f"[FIM Infilling] Single-line mode: {is_single_line}")

    generated_text, new_kv_cache, num_tokens = generate_fim_completion(
        model=model,
        tokenizer=tokenizer,
        fim_prompt=fim_prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        seed=seed,
        past_key_values=past_key_values,
        use_cache=True,
        single_line_only=is_single_line
    )

    generated_code = extract_generated_code(generated_text)

    def _code_lines(text: str) -> List[str]:
        # Non-blank lines that are not pure comments.
        return [line for line in text.split('\n') if line.strip() and not line.strip().startswith('#')]

    if not is_single_line:
        expected_lines = len(interface_lines)
        generated_lines = _code_lines(generated_code)

        # Indentation of the replaced statement (loop invariant).
        original_line = original_code.split('\n')[region.start_lineno - 1]
        indent_str = ' ' * (len(original_line) - len(original_line.lstrip()))

        # Without at least one generated code line there is no anchor to
        # continue from, so the extension loop is skipped entirely.
        while generated_lines and len(generated_lines) < expected_lines:
            temp_code = reconstruct_code(original_code, region, generated_code)
            temp_lines = temp_code.split('\n')

            # Continue right after the first line containing the last generated call.
            anchor = generated_lines[-1].strip()
            last_gen_line_idx = next(
                (idx for idx, line in enumerate(temp_lines) if anchor in line),
                None,
            )
            if last_gen_line_idx is None:
                break

            prefix_code = '\n'.join(temp_lines[:last_gen_line_idx + 1])
            if prefix_code and not prefix_code.endswith('\n'):
                prefix_code += '\n'
            prefix_code += indent_str

            suffix_code = '\n'.join(temp_lines[last_gen_line_idx + 1:])

            # The loop condition guarantees this index is within interface_lines.
            next_func = interface_lines[len(generated_lines)]
            func_name_match = re.match(r'(\w+)\s*\(', next_func)
            if func_name_match:
                full_sig = get_full_interface_signature(func_name_match.group(1), target_robot)
                hint = f"# MUST call: obs, reward, done = {full_sig}\n{indent_str}"
            else:
                hint = f"# MUST call: {next_func}\n{indent_str}"

            next_fim_prompt = f"{FIM_PREFIX}{prefix_code}{hint}{FIM_SUFFIX}\n{suffix_code}{FIM_MIDDLE}"

            next_text, _, next_tokens = generate_fim_completion(
                model=model,
                tokenizer=tokenizer,
                fim_prompt=next_fim_prompt,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                seed=seed,
                past_key_values=None,
                use_cache=False,
                single_line_only=True
            )
            num_tokens += next_tokens

            next_code = extract_generated_code(next_text)
            if not next_code.strip():
                break

            candidate_code = generated_code.rstrip() + '\n' + indent_str + next_code.strip()
            candidate_lines = _code_lines(candidate_code)
            if len(candidate_lines) <= len(generated_lines):
                # Only comments were produced: no progress, and repeating the
                # same deterministic request would never terminate.
                break

            generated_code = candidate_code
            generated_lines = candidate_lines

    generated_code = filter_invalid_params(generated_code, target_robot)

    final_code = reconstruct_code(original_code, region, generated_code)

    return final_code, new_kv_cache, num_tokens


def _generate_grasp_prefix_lines(
    current_code: str,
    grasp_prefix_templates: List[str],
    grasp_guidance: Optional[str],
    model: Any,
    tokenizer: Any,
    max_new_tokens: int = 256,
    temperature: float = 0.0,
    top_p: float = 1.0,
    seed: int = 42,
    device: Optional[str] = None
) -> Tuple[str, int]:
    """Insert generated lines in front of the first ``ur5_grasp_at`` call.

    One line is generated and inserted per entry of *grasp_prefix_templates*.
    Each new line is placed directly above the grasp call, below the lines
    inserted in earlier rounds.

    NOTE(review): only the number of templates influences the result; their
    text is not included in the prompt — confirm whether that is intended.

    Returns:
        ``(updated_code, total_generated_tokens)``; the code is returned
        unchanged with ``0`` tokens when no uncommented ``ur5_grasp_at`` line
        exists.
    """
    code_lines = current_code.split('\n')

    grasp_idx = next(
        (
            idx
            for idx, text in enumerate(code_lines)
            if 'ur5_grasp_at' in text and not text.strip().startswith('#')
        ),
        None,
    )
    if grasp_idx is None:
        return current_code, 0

    grasp_line = code_lines[grasp_idx]
    # Reuse the grasp line's own leading whitespace for the inserted lines.
    indent_str = grasp_line[:len(grasp_line) - len(grasp_line.lstrip())]

    total_tokens = 0
    for _ in reversed(grasp_prefix_templates):
        prefix_code = '\n'.join(code_lines[:grasp_idx])
        if prefix_code and not prefix_code.endswith('\n'):
            prefix_code += '\n'
        prefix_code += indent_str

        suffix_code = '\n'.join(code_lines[grasp_idx:])

        if grasp_guidance:
            context = f"# Grasp guidance:\n# {grasp_guidance.replace(chr(10), chr(10) + '# ')}\n\n"
        else:
            context = ""
        fim_prompt = f"<|fim_prefix|>{context}{prefix_code}<|fim_suffix|>\n{suffix_code}<|fim_middle|>"

        generated_text, _, num_tokens = generate_fim_completion(
            model=model,
            tokenizer=tokenizer,
            fim_prompt=fim_prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            seed=seed,
            past_key_values=None,
            use_cache=False
        )
        total_tokens += num_tokens

        inserted = indent_str + extract_generated_code(generated_text).strip()
        code_lines.insert(grasp_idx, inserted)
        # The grasp call moved down by one line.
        grasp_idx += 1

        current_code = '\n'.join(code_lines)

    return current_code, total_tokens


def run_full_infilling_pipeline(
    source_path: str,
    guidance_path: str,
    model_name: str = "Qwen/Qwen2.5-Coder-7B",
    max_new_tokens: int = 256,
    temperature: float = 0.0,
    top_p: float = 1.0,
    seed: int = 42,
    device: Optional[str] = None,
    output_path: Optional[str] = None,
    model: Any = None,
    tokenizer: Any = None,
) -> Tuple[str, int]:
    """Infill every guidance-matched region of a source file, one region per round.

    The first round uses the region prepared from the original file. Every
    later round re-matches the guidance against the code produced so far and
    processes the bottom-most matching region, so line numbers stay valid
    after earlier replacements. After each round optional grasp prefix lines
    are generated when the guidance item carries ``grasp_prefix_templates``.

    NOTE(review): if an infilled statement still satisfies its own match spec
    it can be selected again in a later round — confirm against the guidance
    files in use.

    Args:
        source_path: Path of the source file to transform.
        guidance_path: Path of the JSON guidance file.
        model_name, device: Used only when the model has to be loaded.
        max_new_tokens, temperature, top_p, seed: Generation settings.
        output_path: When given, the final code is written there (only when
            at least one region matched).
        model, tokenizer: Pre-loaded model and tokenizer, if available.

    Returns:
        ``(final_code, total_generated_tokens)``. When nothing matches, the
        unchanged source and ``0`` are returned, so callers can always unpack
        two values.
    """
    fim_inputs = prepare_fim_input(source_path, guidance_path)

    if not fim_inputs:
        return read_source_code(source_path), 0

    if model is None or tokenizer is None:
        model, tokenizer, device = load_model_and_tokenizer(model_name, device)

    current_code = read_source_code(source_path)
    total_tokens = 0
    guidance = None  # parsed lazily, once, for the re-matching rounds

    for i, fim_input in enumerate(fim_inputs):
        if i > 0:
            if guidance is None:
                guidance = parse_guidance(guidance_path)
            guidance_items = guidance.get('guidance', [])
            grasp_guidance = guidance.get('grasp_guidance')
            target_robot = guidance.get('target_robot', 'ur5')
            matched_regions = find_matching_regions(current_code, guidance_items)

            if not matched_regions:
                break

            region = matched_regions[0]
            prefix_code, fim_prompt, suffix_code = construct_fim_prompt(
                current_code, region, grasp_guidance=grasp_guidance, target_robot=target_robot
            )
            fim_input = {
                'fim_prompt': fim_prompt,
                'region': region,
                'prefix_code': prefix_code,
                'suffix_code': suffix_code,
                'original_code': current_code,
                'grasp_guidance': grasp_guidance,
                'target_robot': target_robot
            }
        else:
            fim_input['original_code'] = current_code

        # Every round starts from a fresh prompt, so no cache is carried over.
        current_code, _, num_tokens = run_fim_infilling(
            fim_input=fim_input,
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            seed=seed,
            device=device,
            past_key_values=None
        )
        total_tokens += num_tokens

        grasp_prefix_templates = fim_input['region'].guidance_item.get('grasp_prefix_templates')
        if grasp_prefix_templates:
            current_code, prefix_tokens = _generate_grasp_prefix_lines(
                current_code=current_code,
                grasp_prefix_templates=grasp_prefix_templates,
                grasp_guidance=fim_input.get('grasp_guidance'),
                model=model,
                tokenizer=tokenizer,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                seed=seed,
                device=device
            )
            total_tokens += prefix_tokens

    if output_path:
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(current_code)

    return current_code, total_tokens


def prepare_fim_masked_code(source_path: str, guidance_path: str) -> List[Dict[str, Any]]:
    """Alias of ``prepare_fim_input``: build the FIM inputs for a source/guidance pair."""
    fim_inputs = prepare_fim_input(source_path, guidance_path)
    return fim_inputs


def infill_code(
    fim_input: Dict[str, Any],
    model: Any = None,
    tokenizer: Any = None,
    **kwargs
) -> str:
    """Run FIM infilling for one prepared input and return only the resulting code.

    Args:
        fim_input: Entry produced by ``prepare_fim_input``.
        model, tokenizer: Optional pre-loaded model and tokenizer.
        **kwargs: Forwarded to ``run_fim_infilling``.

    Returns:
        The source code with the region replaced.
    """
    # run_fim_infilling returns (code, kv_cache, num_tokens); only the code is
    # wanted here, so index instead of unpacking a fixed number of values.
    result = run_fim_infilling(fim_input, model, tokenizer, **kwargs)
    return result[0]
