import re


def make_an_extended_block(retrieved_context, tokenizer):
    """Render one retrieved code fragment as a fully commented block.

    ``retrieved_context`` is indexed positionally: element ``0`` is the
    fragment text and element ``-2`` is an iterable of path components that
    are joined with ``'/'`` to name the source file.

    The block layout is: a fixed "can be found in" line, the commented file
    path, a separator, every fragment line commented with ``'# '`` (trailing
    whitespace removed), and a closing separator.

    Returns:
        ``(block_str, token_len)`` where ``token_len`` is
        ``len(tokenizer.tokenize(block_str))``.
    """
    snippet = retrieved_context[0]
    path_parts = retrieved_context[-2]
    rule = '# ' + '-' * 50 + '\n'

    pieces = [
        '# The below code fragment can be found in:\n',
        '# ' + '/'.join(path_parts) + '\n',
        rule,
    ]
    # Comment out every fragment line so the snippet cannot be mistaken for
    # code the model is expected to continue.
    pieces.extend('# ' + ln.rstrip() + '\n' for ln in snippet.splitlines(keepends=True))
    pieces.append(rule)

    block_str = ''.join(pieces)
    return block_str, len(tokenizer.tokenize(block_str))


def make_an_import_block(import_context, tokenizer , max_import_length):
    """Summarise imported symbols into a commented block within a token budget.

    Each entry of ``import_context`` is a dict with a ``'type'`` key:

    * ``'module'`` / ``'none'`` / ``'unknown'`` -> ``# <type>:<name>``
    * ``'function'`` / ``'method'`` -> one line with the parsed signature
      (Python ``def`` or Java-style ``modifier type name(params)``) and the
      first body line (comments removed).
    * ``'class'`` -> one line with the class name, up to two base classes,
      method names with their parameters, and fields (``class_vals`` for
      Python classes, ``fields`` for Java classes).

    Entries are considered in order. An entry whose line does not fit into the
    remaining budget is skipped, but later (smaller) entries may still be added.

    Args:
        import_context: iterable of import-info dicts.
        tokenizer: object exposing ``tokenize(str) -> list``.
        max_import_length: token budget for the whole block, including the
            header and both separators.

    Returns:
        ``(block, token_len)``; ``('', 0)`` if the header alone does not fit.
    """
    header = '# File import information includes the module name, function methods (names and logic), class name, and class methods, variables, and fields.\n'
    seperator = '# ' + '-' * 50 + '\n'
    initilal_len = len(tokenizer.tokenize(header + seperator + seperator))
    if initilal_len > max_import_length:
        return '', 0
    max_content = max_import_length - initilal_len

    # Patterns are loop-invariant; compile them once.
    py_pattern = re.compile(r'^\s*@?(.*?)\s*def\s+(\w+)\s*\((.*?)\)\s*(->\s*\w+)?', re.MULTILINE)
    java_pattern = re.compile(r'(public|private|protected|static)\s+(\w+)\s+(\w+)\s*\((.*?)\)')
    class_pattern = re.compile(r'class\s+(\w+)(?:\((.*?)\))?')
    paren_pattern = re.compile(r'\((.*?)\)')
    ws_pattern = re.compile(r'\s+')

    lines = []
    used_tokens = 0

    def fits(text):
        """Append ``text`` if it fits the remaining budget; return whether it did."""
        nonlocal used_tokens
        length = len(tokenizer.tokenize(text))
        if used_tokens + length > max_content:
            return False
        lines.append(text)
        used_tokens += length
        return True

    for info in import_context:
        if used_tokens >= max_content:
            break
        info_type = info['type']

        if info_type in {'module', 'none', 'unknown'}:
            fits(f'# {info_type}:{info["name"]}\n')

        elif info_type in ('function', 'method'):
            code = ''.join(info['source_lines'])
            components = []
            py_match = py_pattern.search(code)
            java_match = None if py_match else java_pattern.search(code)
            if py_match:
                decorator, name, params, return_type = py_match.groups()
                if decorator:
                    components.append(f"@{decorator.split('.')[-1]}")
                components.append(f"Function:{name}({params})")
                if return_type:
                    components.append(f"->{return_type.split('->')[-1].strip()}")
            elif java_match:
                modifier, ret_type, name, params = java_match.groups()
                components.append(f"Method:{modifier} {ret_type} {name}({params})")

            if components:
                # Add the first non-signature body line as a short logic hint.
                body = [cl.split('#', 1)[0].strip() for cl in code.splitlines() if cl.strip()]
                for cl in body[1:]:
                    if cl and not cl.startswith(('def', 'public', 'private')):
                        code_sample = ws_pattern.sub(' ', cl)
                        components.append(f'code:{code_sample}...')
                        break
                fits('# ' + ' | '.join(components) + '\n')

        elif info_type == 'class':
            class_desc = []
            if 'class_vals' in info:  # Python class
                source_lines = info.get('source_lines') or ['']
                match = class_pattern.search(source_lines[0])
                if match:
                    class_desc.append(f"Class:{match.group(1)}")
                    bases = match.group(2)
                    if bases:
                        # Only the first two bases are kept to save tokens.
                        class_desc.append(f"Inherits:{','.join(bases.split(',')[:2])}")
            else:  # Java class
                class_desc.append(f"Class:{info['name']}")

            methods = []
            for m_name, m_data in (info.get('methods') or {}).items():
                m_source = m_data.get('source') or ['']
                param_match = paren_pattern.search(m_source[0])
                params = param_match.group(1) if param_match else ''
                methods.append(f"{m_name}({params})" if params else m_name)

            fields = []
            class_vals = info.get('class_vals')
            if class_vals:
                # class_vals is iterated as (name, data) pairs here but as a
                # mapping elsewhere in this module; accept both shapes.
                pairs = class_vals.items() if hasattr(class_vals, 'items') else class_vals
                for f_name, f_data in pairs:
                    raw_val = f_data['source'][0].split('=')[-1].strip()[:15]
                    fields.append(f"{f_name}={ws_pattern.sub(' ', raw_val)}")
            elif info.get('fields'):
                fields.extend(f"{f_data}" for f_data in info['fields'])

            components = []
            if class_desc:
                components.append(' | '.join(class_desc))
            if methods:
                components.append(f"Methods:{','.join(methods)}")
            if fields:
                components.append(f"Fields:{','.join(fields)}")
            # Skip classes we could not describe instead of emitting an empty '# ' line.
            if components:
                fits('# ' + ' | '.join(components) + '\n')

    block = f"{header}{seperator}{''.join(lines)}{seperator}"
    return block, len(tokenizer.tokenize(block))

def make_an_import_block1(import_context, tokenizer , max_import_length):
    """Summarise imported symbols into a commented block, keeping function bodies.

    Per entry ``'type'``:

    * ``'module'`` / ``'none'`` / ``'unknown'`` -> ``# <type>:<name>``
    * ``'function'`` / ``'method'`` -> every non-empty source line, with any
      ``#`` comment removed, prefixed by ``'# '``. Output for the current
      function stops at the first line that does not fit.
    * ``'class'`` -> ``# Class:<name> | Inherits:<bases> | Methods:... | Fields:...``
      (all base classes are listed).

    Args:
        import_context: iterable of import-info dicts.
        tokenizer: object exposing ``tokenize(str) -> list``.
        max_import_length: token budget for the whole block, including the
            header and both separators.

    Returns:
        ``(block, token_len)``; ``('', 0)`` if the header alone does not fit.
    """
    header = '# File import information includes the module name, function methods (names and logic), class name, and class methods, variables, and fields.\n'
    seperator = '# ' + '-' * 50 + '\n'
    initilal_len = len(tokenizer.tokenize(header + seperator + seperator))
    if initilal_len > max_import_length:
        return '', 0
    max_content = max_import_length - initilal_len

    class_pattern = re.compile(r'class\s+(\w+)(?:\((.*?)\))?')
    paren_pattern = re.compile(r'\((.*?)\)')
    ws_pattern = re.compile(r'\s+')

    lines = []
    used_tokens = 0

    def fits(text):
        """Append ``text`` if it fits the remaining budget; return whether it did."""
        nonlocal used_tokens
        length = len(tokenizer.tokenize(text))
        if used_tokens + length > max_content:
            return False
        lines.append(text)
        used_tokens += length
        return True

    for info in import_context:
        if used_tokens >= max_content:
            break
        info_type = info['type']

        if info_type in {'module', 'none', 'unknown'}:
            fits(f'# {info_type}:{info["name"]}\n')

        elif info_type in ('function', 'method'):
            for cl in ''.join(info['source_lines']).splitlines():
                # Note: this also cuts '#' inside string literals; acceptable for a summary.
                code_part = cl.split('#', 1)[0].strip()
                if code_part and not fits(f'# {code_part}\n'):
                    break

        elif info_type == 'class':
            class_desc = []
            if 'class_vals' in info:  # Python class
                source_lines = info.get('source_lines') or ['']
                match = class_pattern.search(source_lines[0])
                if match:
                    class_desc.append(f"Class:{match.group(1)}")
                    bases = match.group(2)
                    if bases:
                        class_desc.append(f"Inherits:{','.join(bases.split(','))}")
            else:  # Java class
                class_desc.append(f"Class:{info['name']}")

            methods = []
            for m_name, m_data in (info.get('methods') or {}).items():
                m_source = m_data.get('source') or ['']
                param_match = paren_pattern.search(m_source[0])
                params = param_match.group(1) if param_match else ''
                methods.append(f"{m_name}({params})" if params else m_name)

            fields = []
            class_vals = info.get('class_vals')
            if class_vals:
                # class_vals is iterated as (name, data) pairs here but as a
                # mapping elsewhere in this module; accept both shapes.
                pairs = class_vals.items() if hasattr(class_vals, 'items') else class_vals
                for f_name, f_data in pairs:
                    raw_val = f_data['source'][0].split('=')[-1].strip()[:15]
                    fields.append(f"{f_name}={ws_pattern.sub(' ', raw_val)}")
            elif info.get('fields'):
                fields.extend(f"{f_data}" for f_data in info['fields'])

            components = []
            if class_desc:
                components.append(' | '.join(class_desc))
            if methods:
                components.append(f"Methods:{','.join(methods)}")
            if fields:
                components.append(f"Fields:{','.join(fields)}")
            if components:
                fits('# ' + ' | '.join(components) + '\n')

    block = f"{header}{seperator}{''.join(lines)}{seperator}"
    return block, len(tokenizer.tokenize(block))

#For the second version, use the original text for classes and functions, and add the import statement
def make_an_import_block2(import_context, tokenizer , max_import_length):
    """Build a "cross-file reference snippets" block from import information.

    Each ``'function'`` / ``'method'`` entry contributes its raw import statement
    (not commented), followed by every line of ``source_lines``, stripped and
    prefixed with ``'# '``.

    Each ``'class'`` entry contributes its raw import statement, followed by the
    commented lines that start with ``def`` or look like a Java member signature
    (access/static/final modifier, type, ``name(...)``).

    Other entry types are ignored.

    Budget handling (the budget also covers the header and two separators):
    an import line that does not fit stops the whole loop, while a source line
    that does not fit stops only the current entry.

    Returns:
        ``(block, token_len)``; ``('', 0)`` if the header alone exceeds the budget.
    """
    header = f'# Cross-file reference snippets:\n'
    rule = '# ' + '-' * 50 + '\n'
    overhead = len(tokenizer.tokenize(header + rule + rule))
    if overhead > max_import_length:
        return '', 0
    capacity = max_import_length - overhead
    # Java member signature, e.g. "public int foo(int a) {"
    java_signature = r'^(\s*)(public|private|protected|static|final)(\s+)([\w<>]+\s+)(\w+\s*\(.*?\))\s*\{?'

    out = []
    spent = 0
    for info in import_context:
        if spent >= capacity:
            break
        kind = info['type']
        if kind not in ('function', 'method', 'class'):
            continue

        import_stmt = ''.join(info['import_line']) + '\n'
        cost = len(tokenizer.tokenize(import_stmt))
        if spent + cost > capacity:
            break
        out.append(import_stmt)
        spent += cost

        for stripped in (ln.strip() for ln in info['source_lines']):
            # Classes only keep method/member signature lines.
            if kind == 'class' and not (stripped[:3] == 'def' or re.match(java_signature, stripped)):
                continue
            commented = f'# {stripped}\n'
            cost = len(tokenizer.tokenize(commented))
            if spent + cost > capacity:
                break
            out.append(commented)
            spent += cost

    block = f"{header}{rule}{''.join(out)}{rule}"
    return block, len(tokenizer.tokenize(block))

#Use regular expressions for classes, original text for functions, and add the import statement
def make_an_import_block3(import_context, tokenizer , max_import_length):
    """Build a cross-file block: regex summaries for classes, source text for functions.

    * ``'function'`` / ``'method'`` -> the raw import statement (not commented),
      then every non-empty source line, with ``#`` comments removed, prefixed
      by ``'# '``. If the import line does not fit, the entry is skipped. If a
      source line does not fit, output for that entry stops.
    * ``'class'`` -> the raw import statement followed by
      ``# Class:<name> | Inherits:<bases> | Methods:... | Fields:...``, added
      only if the whole pair fits.

    Other entry types are ignored.

    Args:
        import_context: iterable of import-info dicts.
        tokenizer: object exposing ``tokenize(str) -> list``.
        max_import_length: token budget for the whole block, including the
            header and both separators.

    Returns:
        ``(block, token_len)``; ``('', 0)`` if the header alone does not fit.
    """
    header = '# Cross-file reference snippets:\n'
    seperator = '# ' + '-' * 50 + '\n'
    initilal_len = len(tokenizer.tokenize(header + seperator + seperator))
    if initilal_len > max_import_length:
        return '', 0
    max_content = max_import_length - initilal_len

    class_pattern = re.compile(r'class\s+(\w+)(?:\((.*?)\))?')
    paren_pattern = re.compile(r'\((.*?)\)')
    ws_pattern = re.compile(r'\s+')

    lines = []
    used_tokens = 0

    def fits(text):
        """Append ``text`` if it fits the remaining budget; return whether it did."""
        nonlocal used_tokens
        length = len(tokenizer.tokenize(text))
        if used_tokens + length > max_content:
            return False
        lines.append(text)
        used_tokens += length
        return True

    for info in import_context:
        if used_tokens >= max_content:
            break
        info_type = info['type']

        if info_type in ('function', 'method'):
            # The import statement goes first (previously it was used as the
            # *separator* when joining source lines, which interleaved it
            # between every pair of lines).
            if not fits(''.join(info['import_line']) + '\n'):
                continue
            for cl in ''.join(info['source_lines']).splitlines():
                code_part = cl.split('#', 1)[0].strip()
                if code_part and not fits(f'# {code_part}\n'):
                    break

        elif info_type == 'class':
            class_desc = []
            if 'class_vals' in info:  # Python class
                source_lines = info.get('source_lines') or ['']
                match = class_pattern.search(source_lines[0])
                if match:
                    class_desc.append(f"Class:{match.group(1)}")
                    bases = match.group(2)
                    if bases:
                        class_desc.append(f"Inherits:{','.join(bases.split(','))}")
            else:  # Java class
                class_desc.append(f"Class:{info['name']}")

            methods = []
            for m_name, m_data in (info.get('methods') or {}).items():
                m_source = m_data.get('source') or ['']
                param_match = paren_pattern.search(m_source[0])
                params = param_match.group(1) if param_match else ''
                methods.append(f"{m_name}({params})" if params else m_name)

            fields = []
            class_vals = info.get('class_vals')
            if class_vals:
                # class_vals is iterated as (name, data) pairs here but as a
                # mapping elsewhere in this module; accept both shapes.
                pairs = class_vals.items() if hasattr(class_vals, 'items') else class_vals
                for f_name, f_data in pairs:
                    raw_val = f_data['source'][0].split('=')[-1].strip()[:15]
                    fields.append(f"{f_name}={ws_pattern.sub(' ', raw_val)}")
            elif info.get('fields'):
                fields.extend(f"{f_data}" for f_data in info['fields'])

            components = []
            if class_desc:
                components.append(' | '.join(class_desc))
            if methods:
                components.append(f"Methods:{','.join(methods)}")
            if fields:
                components.append(f"Fields:{','.join(fields)}")

            summary = '# ' + ' | '.join(components) + '\n' if components else ''
            fits(''.join(info['import_line']) + '\n' + summary)

    block = f"{header}{seperator}{''.join(lines)}{seperator}"
    return block, len(tokenizer.tokenize(block))

def make_an_import_block4(import_context, tokenizer, max_import_length):
    """Build an "enhanced" cross-file block with structured class/function sections.

    Class entries are rendered with ``build_class_section(parse_class_info(info), ...)``.
    Function/method entries are rendered with
    ``build_function_section(parse_function_info(info), ...)``. Both section
    builders prepend the entry's import line to their output. Other entry
    types are skipped.

    Entries are appended in order until the next one would exceed
    ``max_import_length``. The header and the closing separator are counted
    against the budget.

    Args:
        import_context: iterable of import-info dicts.
        tokenizer: object exposing ``tokenize(str) -> list``.
        max_import_length: token budget for the whole block.

    Returns:
        ``(block, token_len)`` where ``token_len`` is the token count of the
        returned block. Returns ``('', 0)`` if the header and separators alone
        exceed the budget, matching the other ``make_an_import_block*`` variants.
    """
    header = "# Cross-file reference snippets (enhanced):\n"
    seperator = '# ' + '-' * 50 + '\n'
    if len(tokenizer.tokenize(header + seperator + seperator)) > max_import_length:
        return '', 0

    # Reserve room for the closing separator so the finished block stays in budget.
    closing_tokens = len(tokenizer.tokenize(seperator))
    block = [header, seperator]
    current_tokens = len(tokenizer.tokenize(''.join(block)))

    for info in import_context:
        if current_tokens + closing_tokens >= max_import_length:
            break

        # A missing/None import_line would make ''.join raise; treat it as empty.
        import_line = ''.join(info.get("import_line") or '') + '\n'

        if info['type'] == 'class':
            content = build_class_section(parse_class_info(info), import_line)
        elif info['type'] in ('function', 'method'):
            content = build_function_section(parse_function_info(info), import_line)
        else:
            continue

        # `content` already starts with import_line, so count it only once.
        content_tokens = len(tokenizer.tokenize(content))
        if current_tokens + content_tokens + closing_tokens > max_import_length:
            break

        block.append(content)
        current_tokens += content_tokens

    block.append(seperator)
    result = ''.join(block)
    return result, len(tokenizer.tokenize(result))


def parse_class_info(info):
    """Extract a structured description of a class from an import-info dict.

    The class name and bases are parsed from the first source line with a
    Python ``class Name(Base, ...):`` pattern. If that pattern does not match
    (e.g. a Java class), ``info['name']`` is used and no bases are reported.

    For each entry of ``info['methods']`` (name -> ``{'source': [...]}``), the
    first source line supplies:

    * ``params``: text of the first ``(...)`` group;
    * ``return_type``: the ``-> T:`` annotation, or ``'None'``;
    * ``is_static``: whether the line contains ``@staticmethod``.

    Fields come from ``class_vals`` (a mapping name -> ``{'source': [...]}``,
    or a sequence of such pairs) as ``"name: Type"``, with ``Any`` when no
    annotation is found. Otherwise they come from ``fields`` (stringified as-is).

    Returns:
        dict with keys ``'name'``, ``'bases'``, ``'methods'``, ``'fields'``.
    """
    source_lines = info.get('source_lines') or ['']
    class_def = source_lines[0].strip()
    match = re.match(r'class\s+(\w+)(?:\((.*?)\))?\s*:', class_def)
    class_name = match.group(1) if match else info['name']
    # Strip each base so that build-time joins with ', ' do not double the spaces.
    bases = ([b.strip() for b in match.group(2).split(',') if b.strip()]
             if match and match.group(2) else [])

    methods = []
    for m_name, m_data in (info.get('methods') or {}).items():
        src = (m_data.get('source') or [''])[0].strip()
        params_match = re.search(r'\((.*?)\)', src)
        params = params_match.group(1) if params_match else ''
        return_match = re.search(r'->\s*([\w\[\],\s]+):', src)
        return_type = return_match.group(1).strip() if return_match else 'None'
        methods.append({
            'name': m_name,
            'params': params,
            'return_type': return_type,
            'is_static': '@staticmethod' in src,
        })

    fields = []
    class_vals = info.get('class_vals')
    if class_vals:
        pairs = class_vals.items() if hasattr(class_vals, 'items') else class_vals
        for f_name, f_data in pairs:
            field_def = f_data['source'][0].strip()
            # Look for an annotation only left of '=', so that colons inside the
            # value (dict literals, slices, lambdas) are not taken as a type.
            target = field_def.split('=', 1)[0]
            type_match = re.search(r':\s*([\w\[\],\s]+)', target)
            field_type = (type_match.group(1).strip() if type_match else '') or 'Any'
            fields.append(f"{f_name}: {field_type}")
    elif info.get('fields'):
        fields.extend(f"{f_data}" for f_data in info['fields'])

    return {
        'name': class_name,
        'bases': bases,
        'methods': methods,
        'fields': fields,
    }

def build_class_section(class_info, import_line):
    """Render a parsed class description as commented lines after its import.

    ``class_info`` has the shape returned by ``parse_class_info``: ``name``,
    ``bases``, ``fields`` and ``methods`` (dicts with ``name``, ``params``,
    ``return_type`` and ``is_static``). The result is ``import_line``, then a
    ``# Class:`` title, optional ``# Fields:`` / ``# Methods:`` listings, and a
    trailing blank line.
    """
    title = f"# Class: {class_info['name']}"
    if class_info['bases']:
        title += "(" + ", ".join(class_info['bases']) + ")"
    parts = [title + "\n"]

    field_list = class_info['fields']
    if field_list:
        parts.append("# Fields:\n")
        parts.extend(f"#   - {fld}\n" for fld in field_list)

    method_list = class_info['methods']
    if method_list:
        parts.append("# Methods:\n")
        for meth in method_list:
            tag = "@static " if meth['is_static'] else ""
            parts.append(f"#   {tag}{meth['name']}({meth['params']}) -> {meth['return_type']}\n")

    return import_line + "".join(parts) + "\n"


def parse_function_info(info):
    """Extract name, parameters, return type and a short snippet of a function.

    The signature line is the first non-blank source line that does not start
    with ``@``. This skips Python decorators and Java annotations, which would
    otherwise be parsed as the signature (e.g. ``@dec(1)`` yielding params
    ``'1'``). If every line is a decorator, the first line is used.

    Returns:
        dict with ``'name'`` (``info['name']``), ``'params'`` (text of the first
        ``(...)`` group), ``'return_type'`` (from ``-> T:``, else ``'None'``) and
        ``'source'`` (the first three raw source lines).
    """
    source_lines = info['source_lines']
    stripped = [ln.strip() for ln in source_lines]
    func_def = next((s for s in stripped if s and not s.startswith('@')),
                    stripped[0] if stripped else '')

    params_match = re.search(r'\((.*?)\)', func_def)
    params = params_match.group(1) if params_match else ''
    return_match = re.search(r'->\s*([\w\[\],\s]+):', func_def)
    return_type = return_match.group(1).strip() if return_match else 'None'

    return {
        'name': info['name'],
        'params': params,
        'return_type': return_type,
        'source': source_lines[:3],
    }

def build_function_section(func_info, import_line):
    """Render a parsed function as commented lines after its import statement.

    Layout: ``import_line``; a ``# Function: name(params) -> return_type``
    line; an ``# Implementation snippet:`` marker; each non-empty line of
    ``func_info['source']`` with ``#`` comments removed, prefixed by ``# > ``;
    and a trailing blank line.
    """
    signature = (f"# Function: {func_info['name']}"
                 f"({func_info['params']}) -> {func_info['return_type']}\n")
    snippet = []
    for raw in func_info['source']:
        code_only = raw.split('#', 1)[0].strip()
        if code_only:
            snippet.append(f"# > {code_only}\n")
    return import_line + signature + "# Implementation snippet:\n" + "".join(snippet) + "\n"


def make_str_block_with_max_token_length(tokenizer, max_token_num: int, context_str: str, with_comment=False):
    """Keep the trailing lines of ``context_str`` that fit in the token budget.

    Lines are taken from the end backwards. A line is kept only while the
    running total stays strictly below ``max_token_num``; the first line that
    would reach or exceed it stops the scan. Kept lines are returned in their
    original order. With ``with_comment=True``, each line is prefixed with
    ``'# '`` before being counted.
    """
    lines = context_str.splitlines(keepends=True)
    if with_comment:
        lines = ['# ' + ln for ln in lines]

    kept = []
    used = 0
    for ln in reversed(lines):
        cost = len(tokenizer.tokenize(ln))
        if cost + used >= max_token_num:
            break
        kept.append(ln)
        used += cost
    return ''.join(reversed(kept))

def make_str_block_with_max_token_length1(tokenizer, max_token_num: int, context_str: str, with_comment=False):
    """Keep the leading lines of ``context_str`` that fit in the token budget.

    This is the head-truncating counterpart of
    ``make_str_block_with_max_token_length``. Lines are taken from the start,
    and a line is kept only while the running total stays strictly below
    ``max_token_num``. The first line that would reach or exceed the budget
    stops the scan. Kept lines are returned in their original order.

    The previous implementation skipped the first line (``range(1, ...)``) and
    prepended each kept line, which returned the lines reversed.

    With ``with_comment=True``, each line is prefixed with ``'# '`` before
    being counted.
    """
    lines = context_str.splitlines(keepends=True)
    if with_comment:
        lines = [f'# {line}' for line in lines]

    kept = []
    used = 0
    for line in lines:
        line_len = len(tokenizer.tokenize(line))
        if line_len + used >= max_token_num:
            break
        kept.append(line)
        used += line_len
    return ''.join(kept)

def build_infile_prompt(case, tokenizer, max_num_tokens):
    """Build an in-file-only prompt: an instruction line plus the tail of the context.

    Only half of ``max_num_tokens`` is spent. The instruction's tokens are
    subtracted from that half, and the remainder holds the trailing lines of
    ``"".join(case['context'])``.
    """
    instruction = "# Complete the next statement of the following codes:\n"
    budget = max_num_tokens // 2 - len(tokenizer.tokenize(instruction))
    tail = make_str_block_with_max_token_length(tokenizer, budget, "".join(case['context']))
    return instruction + tail

# ours
def build_retrieval_prompt(case, tokenizer, max_num_tokens, max_top_k):
    """Build the full prompt: retrieved snippets + import summary + in-file context.

    Budget split:
      * in-file context (instruction line + trailing context lines):
        ``max_num_tokens // 2``;
      * import block (``make_an_import_block2``): ``max_num_tokens // 4``,
        only when ``case['import_info']`` is non-empty;
      * retrieved snippets: whatever remains. Candidates are taken from the end
        of ``case['top_k_context']`` (at most ``max_top_k``). Each must fit
        strictly within the remaining budget, and chosen snippets keep their
        original relative order.

    Returns:
        ``retrieval_prompt + import_prompt + context_prompt``.
    """
    context_max_tokens = max_num_tokens // 2
    # No trailing space after '\n': it would be glued onto the first context
    # line and shift its indentation.
    comment = "# Based on above, complete the next statement of the following codes:\n"
    comment_length = len(tokenizer.tokenize(comment))
    context = make_str_block_with_max_token_length(
        tokenizer, context_max_tokens - comment_length, "".join(case['context']))
    context_prompt = comment + context
    context_length = len(tokenizer.tokenize(context_prompt))

    import_prompt = ''
    # import_info may be missing or None; treat both as "no imports".
    import_context = case.get('import_info')
    if import_context:
        max_import_length = max_num_tokens // 4
        import_prompt, _ = make_an_import_block2(import_context, tokenizer, max_import_length)
    import_tokens = len(tokenizer.tokenize(import_prompt))

    max_retrieval_length = max_num_tokens - context_length - import_tokens
    seperator = '# ' + '-' * 50
    retrieval_prompt = "# Here are some relevant code fragments from other files of the repo:\n"
    retrieval_prompt += seperator + '\n'
    current_token_length = len(tokenizer.tokenize(retrieval_prompt))

    chosen = []
    for retrieval_context in reversed(case['top_k_context']):
        if len(chosen) >= max_top_k:
            break
        block_str, token_len = make_an_extended_block(retrieval_context, tokenizer)
        if current_token_length + token_len < max_retrieval_length:
            chosen.append(block_str)
            current_token_length += token_len
    # Candidates were visited back-to-front; restore their original order.
    retrieval_prompt += ''.join(reversed(chosen))

    return retrieval_prompt + import_prompt + context_prompt

#ours but don't have EAID
def build_retrieval_prompt1(case, tokenizer, max_num_tokens, max_top_k):
    """Build a prompt of retrieved snippets followed by the in-file context.

    Both halves get ``max_num_tokens // 2`` tokens.

    The context half is an instruction line plus the trailing lines of
    ``"".join(case['context'])``.

    The snippet half starts with a header. Candidates are then taken from the
    end of ``case['top_k_context']``, at most ``max_top_k`` of them. Each must
    keep the running total strictly below the half budget, and chosen snippets
    keep their original relative order.
    """
    half_budget = max_num_tokens // 2

    instruction = "# Based on above, complete the next statement of the following codes:\n"
    instruction_len = len(tokenizer.tokenize(instruction))
    tail = make_str_block_with_max_token_length(
        tokenizer, half_budget - instruction_len, "".join(case['context']))
    context_prompt = instruction + tail

    rule = '# ' + '-' * 50
    retrieval_prompt = "# Here are some relevant code fragments from other files of the repo:\n" + rule + '\n'
    used = len(tokenizer.tokenize(retrieval_prompt))

    picked = []
    for candidate in reversed(case['top_k_context']):
        if len(picked) >= max_top_k:
            break
        block_str, block_len = make_an_extended_block(candidate, tokenizer)
        if used + block_len < half_budget:
            picked.append(block_str)
            used += block_len
    picked.reverse()

    return retrieval_prompt + ''.join(picked) + context_prompt

# Draco + ours
def build_retrieval_prompt2(case, tokenizer, max_num_tokens, max_top_k):
    """Build a prompt of retrieved snippets followed by the pre-built cross-file prompt.

    The cross-file part is the instruction line plus ``case['cross_file_line']``
    (concatenated, so it is expected to be a str). It is used in full.
    Retrieved snippets may use the tokens it leaves over, under the same
    selection rule as the other builders: from the end of
    ``case['top_k_context']``, at most ``max_top_k``, with a strict budget
    check, in their original order.

    NOTE(review): the cross-file part is not truncated. If it alone exceeds
    ``max_num_tokens``, no snippets are added and the prompt is over budget.
    """
    cross_prompt = ("# Based on above, complete the next statement of the following codes:\n"
                    + case['cross_file_line'])
    remaining = max_num_tokens - len(tokenizer.tokenize(cross_prompt))

    header = ("# Here are some relevant code fragments from other files of the repo:\n"
              + '# ' + '-' * 50 + '\n')
    used = len(tokenizer.tokenize(header))

    picked = []
    for candidate in reversed(case['top_k_context']):
        if len(picked) >= max_top_k:
            break
        block_str, block_len = make_an_extended_block(candidate, tokenizer)
        if used + block_len < remaining:
            picked.append(block_str)
            used += block_len

    return header + ''.join(reversed(picked)) + cross_prompt


#only Draco
def build_retrieval_prompt3(case, tokenizer):
    """Return the pre-built cross-file prompt ``case['cross_file_line']`` unchanged.

    ``tokenizer`` is accepted only for signature parity with the other prompt
    builders and is not used (no truncation is applied here).
    """
    return case['cross_file_line']


def truncate(prompt: str, max_num_tokens: int, side: str, tokenizer) -> str:
    """Truncate ``prompt`` from ``side`` so that it fits in ``max_num_tokens`` tokens.

    Args:
        prompt: text to truncate.
        max_num_tokens: token budget.
        side: ``'left'`` drops tokens from the start (keeps the tail);
            ``'right'`` drops tokens from the end (keeps the head).
        tokenizer: must provide ``tokenize``. The kept tokens are turned back
            into text with ``convert_tokens_to_string`` when available (the
            inverse of ``tokenize`` for Hugging Face tokenizers). Otherwise
            ``decode(..., skip_special_tokens=True)`` is used.

    Returns:
        ``prompt`` unchanged if it already fits, otherwise the truncated text.

    Raises:
        ValueError: if ``side`` is neither ``'left'`` nor ``'right'``.
    """
    if side not in ('left', 'right'):
        raise ValueError(f"side must be 'left' or 'right', got {side!r}")

    tokens = tokenizer.tokenize(prompt)
    num_tokens = len(tokens)
    if num_tokens <= max_num_tokens:
        return prompt

    if side == 'left':
        kept = tokens[num_tokens - max_num_tokens:]
    else:
        kept = tokens[:max_num_tokens]

    # `tokenize` yields token strings, not ids; `decode` expects ids on
    # Hugging Face tokenizers, so prefer the string-based inverse.
    to_string = getattr(tokenizer, 'convert_tokens_to_string', None)
    if to_string is not None:
        return to_string(kept)
    return tokenizer.decode(kept, skip_special_tokens=True)

# repocoder Variant+ours
def build_retrieval_prompt4(case, tokenizer, tokenizer_raw,max_num_tokens, max_top_k):
    """Build a prompt of retrieved snippets, cross-file context, then in-file context.

    Budget split:
      * cross-file part: ``max_num_tokens // 2``. It holds a header, the
        trailing lines of ``case['cross_file_line']`` that fit, and a separator
        line.
      * in-file part: ``max_num_tokens // 2``. It holds an instruction line
        plus the trailing lines of ``case['context']``.
      * retrieved snippets: whatever the two parts leave over. Selection is
        from the end of ``case['top_k_context']``, at most ``max_top_k``, with
        a strict budget check, in original order.

    ``tokenizer_raw`` is unused; it is kept for interface compatibility.
    """
    seperator = '# ' + '-' * 50 + '\n'

    cross_max_tokens = max_num_tokens // 2
    comment1 = "# Cross_file information:\n"
    comment_length1 = len(tokenizer.tokenize(comment1))
    seperator_length = len(tokenizer.tokenize(seperator))
    cross_file_context = make_str_block_with_max_token_length(
        tokenizer, cross_max_tokens - comment_length1 - seperator_length,
        "".join(case['cross_file_line']))
    # The separator must sit on its own line and end with a newline. Otherwise
    # it is glued onto the last context line, and the following instruction is
    # glued onto the separator.
    if cross_file_context and not cross_file_context.endswith('\n'):
        cross_file_context += '\n'
    cross_prompt = comment1 + cross_file_context + seperator
    cross_file_length = len(tokenizer.tokenize(cross_prompt))

    context_max_tokens = max_num_tokens // 2
    comment = "# Based on above, complete the next statement of the following codes:\n"
    comment_length = len(tokenizer.tokenize(comment))
    context = make_str_block_with_max_token_length(
        tokenizer, context_max_tokens - comment_length, "".join(case['context']))
    context_prompt = comment + context
    context_length = len(tokenizer.tokenize(context_prompt))

    retrieval_prompt = "# Here are some relevant code fragments from other files of the repo:\n" + seperator
    max_retrieval_length = max_num_tokens - cross_file_length - context_length
    current_token_length = len(tokenizer.tokenize(retrieval_prompt))

    chosen = []
    for retrieval_context in reversed(case['top_k_context']):
        if len(chosen) >= max_top_k:
            break
        block_str, token_len = make_an_extended_block(retrieval_context, tokenizer)
        if current_token_length + token_len < max_retrieval_length:
            chosen.append(block_str)
            current_token_length += token_len
    retrieval_prompt += ''.join(reversed(chosen))

    return retrieval_prompt + cross_prompt + context_prompt


# only repocoder Variant
def build_retrieval_prompt5(case, tokenizer, tokenizer_raw,max_num_tokens, max_top_k):
    """Build a prompt of cross-file context followed by in-file context (no retrieval).

    Each part gets ``max_num_tokens // 2`` tokens:
      * cross-file: a header, the trailing lines of ``case['cross_file_line']``
        that fit, and a separator line;
      * in-file: an instruction line plus the trailing lines of
        ``case['context']``.

    ``tokenizer_raw`` and ``max_top_k`` are unused; they are kept for interface
    compatibility with the other builders.
    """
    half_budget = max_num_tokens // 2
    seperator = '# ' + '-' * 50 + '\n'

    comment1 = "# Cross_file information:\n"
    cross_budget = (half_budget - len(tokenizer.tokenize(comment1))
                    - len(tokenizer.tokenize(seperator)))
    cross_file_context = make_str_block_with_max_token_length(
        tokenizer, cross_budget, "".join(case['cross_file_line']))
    # The separator must sit on its own line and end with a newline. Otherwise
    # the following instruction line is glued onto it.
    if cross_file_context and not cross_file_context.endswith('\n'):
        cross_file_context += '\n'
    cross_prompt = comment1 + cross_file_context + seperator

    comment = "# Based on above, complete the next statement of the following codes:\n"
    context = make_str_block_with_max_token_length(
        tokenizer, half_budget - len(tokenizer.tokenize(comment)), "".join(case['context']))

    return cross_prompt + comment + context

def build_prompt(case, tokenizer, tokenizer_raw,max_num_tokens, max_top_k=10, mode='retrieval'):
    """Build a prompt using the builder selected by ``mode``.

    Supported modes: ``'infile'``, ``'retrieval'``, ``'pure_crossfile'``,
    ``'crossfile'``, ``'crossfile_java'``, ``'crossfile_test'`` and
    ``'crosscode1'``. Any other value yields an empty string.
    """
    # Ordered (mode, builder) table; builders are lazy so only the chosen one runs.
    table = (
        ('infile', lambda: build_infile_prompt(case, tokenizer, max_num_tokens)),
        ('retrieval', lambda: build_retrieval_prompt1(case, tokenizer, max_num_tokens, max_top_k)),
        ('pure_crossfile', lambda: build_retrieval_prompt3(case, tokenizer)),
        ('crossfile', lambda: build_retrieval_prompt2(case, tokenizer, max_num_tokens, max_top_k)),
        ('crossfile_java', lambda: build_retrieval_prompt4(case, tokenizer, tokenizer_raw, max_num_tokens, max_top_k)),
        ('crossfile_test', lambda: build_retrieval_prompt(case, tokenizer, max_num_tokens, max_top_k)),
        ('crosscode1', lambda: build_retrieval_prompt5(case, tokenizer, tokenizer_raw, max_num_tokens, max_top_k)),
    )
    for name, builder in table:
        if mode == name:
            return builder()
    return ""

