import argparse
import json
import os
from tqdm import tqdm
import re
from transformers import AutoTokenizer

def extract_code_blocks(text):
    """Return the bodies of all bare triple-backtick fenced blocks in ``text``.

    The opening fence must be followed directly by a newline (the pattern
    makes no allowance for a language tag), and the closing fence must be
    preceded by a newline. A fence that is never closed yields no match.

    Args:
        text (str): Text that may contain fenced blocks.

    Returns:
        list[str]: The text between each matched pair of fences, in order of
        appearance and without the fence lines. Empty when nothing matches.
    """
    # DOTALL lets "." cross line breaks so a block body can span several lines;
    # the lazy quantifier stops at the first closing fence.
    fence = re.compile(r"```\n(.*?)\n```", re.DOTALL)
    return [found.group(1) for found in fence.finditer(text)]

def rerank_file_localization(gt_data_item, loc_file_item):
    """Move the ground-truth files to the front of a predicted file list.

    The prediction is read from the last fenced code block of the final
    dialogue turn (or from the whole turn when it contains no fenced block),
    one file path per non-empty line. On success the final turn's ``content``
    is rewritten in place so that the ground-truth files come first, followed
    by the remaining predicted files in their original order.

    The sample is rejected -- ``[]`` is returned and ``loc_file_item`` is left
    unmodified -- when:
      * ``error_times_in_dialogue`` reports at least one error turn,
      * a ground-truth file is missing from the prediction,
      * the dialogue has no non-empty turn of type
        ``file_localization_based_on_dependecy_result``, or that turn holds
        no fenced block, or
      * any reranked file is absent from that dependency-based file list.

    Args:
        gt_data_item (dict): Ground-truth record; reads ``instance_id`` and
            ``gt_localization['modified_files' / 'removed_files']``.
        loc_file_item (dict): File-localization trajectory; reads
            ``instance_id`` and ``dialogue``. Its last turn is mutated on
            success.

    Returns:
        list[str]: The reranked file list, or ``[]`` when the sample is
        rejected. Note that an empty reranked list is indistinguishable from
        a rejection for callers that test truthiness.

    Raises:
        AssertionError: If the instance ids differ or the final turn is not of
            type ``file_localization_based_on_skeleton_result``.
    """
    dialogue = loc_file_item['dialogue']
    if error_times_in_dialogue(dialogue) > 0:
        return []
    assert gt_data_item['instance_id'] == loc_file_item['instance_id']
    final_turn = dialogue[-1]
    assert final_turn['type'] == 'file_localization_based_on_skeleton_result'

    final_content = final_turn['content']
    pred_loc_strings = extract_code_blocks(final_content)
    # Without any fenced block the whole turn is treated as the file list.
    pred_loc_string = pred_loc_strings[-1] if pred_loc_strings else final_content

    pred_loc_files = [line.strip() for line in pred_loc_string.strip().split('\n') if line.strip() != '']
    gt_localization = gt_data_item['gt_localization']
    gt_files = gt_localization['modified_files'] + gt_localization['removed_files']

    reranked_files = []
    for gt_file in gt_files:
        if gt_file not in pred_loc_files:
            return []
        # Remove one occurrence so the file is not repeated in the tail.
        pred_loc_files.remove(gt_file)
        reranked_files.append(gt_file)
    reranked_files.extend(pred_loc_files)

    # Every reranked file must also appear in the dependency-based result;
    # when several such turns exist the last one wins.
    loc_on_dependency_result = None
    for dialogue_turn in dialogue:
        if dialogue_turn.get('type') == 'file_localization_based_on_dependecy_result':
            loc_on_dependency_result = dialogue_turn['content']
    if not loc_on_dependency_result:
        return []

    loc_on_dependency_blocks = extract_code_blocks(loc_on_dependency_result)
    if not loc_on_dependency_blocks:
        # Previously an IndexError was raised here and the rejection relied on
        # the caller swallowing it; reject explicitly instead.
        return []
    dep_loc_files = {
        line.strip()
        for line in loc_on_dependency_blocks[-1].strip().split('\n')
        if line.strip() != ''
    }
    if any(file not in dep_loc_files for file in reranked_files):
        return []

    assert pred_loc_string in final_content, "Predicted localization string not in the last dialogue content!"

    # All checks passed: only now is the trajectory modified.
    reranked_files_string = '\n'.join(reranked_files)
    final_turn['content'] = final_content.replace(pred_loc_string, reranked_files_string)

    return reranked_files

def if_a_is_contained_in_b(a_unit, b_unit):
    """Tell whether localization unit ``a_unit`` is covered by ``b_unit``.

    Each unit is a two-line string: a file path on the first line and a
    location on the second. ``a_unit`` is covered when both name the same file
    and either the locations are equal, or ``a_unit`` is ``function: X.y`` and
    ``b_unit`` is ``class: X`` (a method of that class).

    Args:
        a_unit (str): Candidate unit.
        b_unit (str): Unit that may contain ``a_unit``.

    Returns:
        bool: True when ``a_unit`` is covered by ``b_unit``.

    Raises:
        ValueError: If either unit, once stripped, does not consist of exactly
            two newline-separated parts.
    """
    path_a, loc_a = a_unit.strip().split('\n')
    path_b, loc_b = b_unit.strip().split('\n')

    if path_a != path_b:
        return False
    if loc_a == loc_b:
        return True

    func_prefix = 'function: '
    class_prefix = 'class: '
    if not (loc_a.startswith(func_prefix) and loc_b.startswith(class_prefix)):
        return False

    method_name = loc_a[len(func_prefix):].strip()
    owner_class = loc_b[len(class_prefix):].strip()
    # The trailing dot prevents class "Foo" from matching "Foobar.method".
    return method_name.startswith(owner_class + '.')

def rerank_func_localization_results(gt_data_item, loc_func_item):
    """Move the ground-truth locations to the front of a predicted location list.

    The prediction is the single fenced code block of the final dialogue turn,
    split into units on blank lines. Every ground-truth unit
    (``"<file>\\n<loc>"``) must be covered by at least one predicted unit, as
    decided by ``if_a_is_contained_in_b``; the covering predicted units are
    dropped and the ground-truth units are placed first, followed by the
    remaining predictions in their original order. On success the final
    turn's ``content`` is rewritten in place.

    Args:
        gt_data_item (dict): Ground-truth record; reads ``instance_id`` and
            ``gt_localization['related_locs']`` (mapping file -> locations).
        loc_func_item (dict): Function-localization trajectory; reads
            ``instance_id`` and ``dialogue``. Its last turn is mutated on
            success.

    Returns:
        list[str]: The reranked units, or ``[]`` (with ``loc_func_item`` left
        unmodified) when some ground-truth unit is not covered.

    Raises:
        AssertionError: If the instance ids differ, the final turn is not of
            type ``summary_localization_request``, or it does not contain
            exactly one fenced block.
        ValueError: Propagated from ``if_a_is_contained_in_b`` when a unit is
            not a two-line string.
    """
    assert gt_data_item['instance_id'] == loc_func_item['instance_id']
    final_turn = loc_func_item['dialogue'][-1]
    assert final_turn['type'] == 'summary_localization_request'
    final_content = final_turn['content']
    pred_loc_block = extract_code_blocks(final_content)
    assert len(pred_loc_block) == 1, f"Function localization instance {loc_func_item['instance_id']} should have exactly one code block!"
    pred_loc_str = pred_loc_block[0]
    pred_loc_units = pred_loc_str.strip().split('\n\n')

    gt_localization_map = gt_data_item['gt_localization']['related_locs']
    gt_loc_units = [
        f'{file}\n{loc}'
        for file, locs in gt_localization_map.items()
        for loc in locs
    ]

    # Rerank: ground-truth units first, then whatever predictions remain.
    reranked_units = []
    for unit in gt_loc_units:
        # Build a filtered list instead of calling list.remove() while
        # iterating the same list: that skipped the element following each
        # removed match, so a second covering unit could survive.
        remaining_units = []
        is_contained = False
        for pred_unit in pred_loc_units:
            if if_a_is_contained_in_b(unit, pred_unit):
                is_contained = True
            else:
                remaining_units.append(pred_unit)
        if not is_contained:
            return []
        # NOTE(review): coverage is tested against the shrinking list, so a
        # predicted "class: X" consumed by one ground-truth method no longer
        # covers a second ground-truth method of X -- confirm this is intended.
        pred_loc_units = remaining_units
        reranked_units.append(unit)

    reranked_units.extend(pred_loc_units)

    assert pred_loc_str in final_content, "Predicted localization block not in the last dialogue content!"
    reranked_loc_str = '\n\n'.join(reranked_units)
    final_turn['content'] = final_content.replace(pred_loc_str, reranked_loc_str)

    return reranked_units

def check_search_replace(edit_data_item):
    """Validate the SEARCH/REPLACE edit blocks in the last dialogue turn.

    The answer (last turn) is accepted when it contains at least one complete
    ``<<<<<<< SEARCH ... ======= ... >>>>>>> REPLACE`` block, the numbers of
    SEARCH and REPLACE markers agree, and the stripped search text of every
    block occurs verbatim in the query (first turn).

    Args:
        edit_data_item (dict): Item whose ``dialogue`` is a list of turns with
            a ``content`` string.

    Returns:
        bool: True when all checks pass, False otherwise.
    """
    turns = edit_data_item['dialogue']
    query = turns[0]['content']
    answer = turns[-1]['content']

    search_marker = '<<<<<<< SEARCH'
    replace_marker = '>>>>>>> REPLACE'
    block_pattern = r"<<<<<<< SEARCH\n(.*?)\n=======\n(.*?)>>>>>>> REPLACE"

    blocks = re.findall(block_pattern, answer, re.DOTALL)
    if not blocks:
        return False
    # A successful match guarantees both markers are present, so only the
    # balance between them still needs checking.
    if answer.count(search_marker) != answer.count(replace_marker):
        return False

    # Each search snippet must be locatable in the code shown in the query.
    return all(search_part.strip() in query for search_part, _ in blocks)

def replace_instruction_for_completion(completion_data_item):
    """Reword the completion instruction in the first dialogue turn, in place.

    Every occurrence of the phrase "...complete the code of the function at
    the position of" is replaced by "...complete the code at the position of".

    Args:
        completion_data_item (dict): Item whose ``dialogue[0]['content']`` is
            the instruction text; it is modified in place.

    Returns:
        None
    """
    old_phrase = "Please help me to complete the code of the function at the position of"
    new_phrase = "Please help me to complete the code at the position of"
    first_turn = completion_data_item['dialogue'][0]
    first_turn['content'] = first_turn['content'].replace(old_phrase, new_phrase)

def load_converstation_standard(trajactory):
    """Reduce each trajectory turn to its ``role`` and ``content`` fields.

    Args:
        trajactory (iterable[dict]): Turns that each carry at least ``role``
            and ``content``; any other keys are dropped.

    Returns:
        list[dict]: New ``{"role": ..., "content": ...}`` dicts, in order.
    """
    return [
        {"role": turn["role"], "content": turn["content"]}
        for turn in trajactory
    ]

def construct_stardard_format(data, source, file=None):
    """Convert raw trajectory items into the standard output record layout.

    Each record holds ``instance_id``, the dialogue reduced by
    ``load_converstation_standard``, ``source``, ``file`` and ``language``.
    The optional fields ``found_related_locs``, ``found_files`` and ``patch``
    are copied over, in that order, only when present and not ``None``.

    Args:
        data (iterable[dict]): Raw items; must carry ``dialogue``,
            ``instance_id`` and ``language``.
        source (str): Label stored in every record's ``source`` field.
        file: Value stored in every record's ``file`` field (default None).

    Returns:
        list[dict]: One standard record per input item.
    """
    optional_fields = ('found_related_locs', 'found_files', 'patch')
    standard_data = []
    for item in data:
        dialogue = item['dialogue']
        record = {
            "instance_id": item["instance_id"],
            "dialogue": load_converstation_standard(dialogue),
            "source": source,
            "file": file,
            "language": item["language"],
        }
        for field_name in optional_fields:
            field_value = item.get(field_name, None)
            if field_value is not None:
                record[field_name] = field_value
        standard_data.append(record)
    return standard_data
    

def load_jsonl(jsonl_file: str) -> list:
    """
    Read a JSONL file, silently skipping lines that are not valid JSON.

    Args:
        jsonl_file (str): Path to the JSONL file (read as UTF-8).

    Returns:
        list: Every successfully parsed JSON value, in file order. Blank
        lines fail to parse and are therefore skipped as well.
    """
    records = []
    with open(jsonl_file, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            try:
                # Strip surrounding whitespace (including the newline) first.
                parsed = json.loads(raw_line.strip())
            except json.JSONDecodeError:
                # Unparseable line: ignore it and move on.
                continue
            records.append(parsed)
    return records

def error_times_in_dialogue(dialogue, error_str="Error: Model generation failed") -> int:
    """
    Count the assistant turns of a dialogue that contain an error marker.

    Args:
        dialogue (list): Dialogue turns; each is a dict with 'role' and
            'content' keys.
        error_str (str): Substring that identifies an error turn.

    Returns:
        int: Number of turns whose role is 'assistant' and whose content
        contains ``error_str``.
    """
    return sum(
        1
        for turn in dialogue
        if turn['role'] == 'assistant' and error_str in turn['content']
    )

def filter_file_loc_data(file_loc_data_list, max_error_times=0, languages=None):
    """Keep file-localization items that found files and have few error turns.

    An item is kept when its language is allowed, the concatenation of its
    ``found_files`` entries is not blank, and ``error_times_in_dialogue``
    reports at most ``max_error_times`` errors for its dialogue.

    Args:
        file_loc_data_list (iterable[dict]): Items carrying ``language``,
            ``found_files`` and ``dialogue``.
        max_error_times (int): Largest tolerated number of error turns.
        languages (collection[str] | None): Allowed languages; ``None`` or an
            empty collection disables the language filter. The default is
            ``None`` rather than a shared mutable list.

    Returns:
        list[dict]: The retained items, in input order.
    """
    final_data = []
    for item in tqdm(file_loc_data_list, desc="Filtering file_loc_data"):
        if languages and item["language"] not in languages:
            continue
        # Items whose found_files join to nothing but whitespace are dropped.
        if not ''.join(item['found_files']).strip():
            continue
        if error_times_in_dialogue(item["dialogue"]) <= max_error_times:
            final_data.append(item)
    return final_data

def filter_func_loc_data(func_loc_data_list, max_error_times=0, languages=None):
    """Keep function-localization items with usable locations and few error turns.

    An item is kept when its language is allowed, its ``found_related_locs``
    values join to non-blank text that does not contain ``'anonymous'``, and
    ``error_times_in_dialogue`` reports at most ``max_error_times`` errors.

    Args:
        func_loc_data_list (iterable[dict]): Items carrying ``language``,
            ``found_related_locs`` (a mapping whose values are iterables of
            strings) and ``dialogue``.
        max_error_times (int): Largest tolerated number of error turns.
        languages (collection[str] | None): Allowed languages; ``None`` or an
            empty collection disables the language filter. The default is
            ``None`` rather than a shared mutable list.

    Returns:
        list[dict]: The retained items, in input order.
    """
    final_data = []
    for item in tqdm(func_loc_data_list, desc="Filtering func_loc_data"):
        if languages and item["language"] not in languages:
            continue
        # One space between the per-file location strings, as before.
        all_locs = ' '.join(''.join(value) for value in item["found_related_locs"].values())
        if not all_locs.strip() or 'anonymous' in all_locs:
            continue
        if error_times_in_dialogue(item["dialogue"]) <= max_error_times:
            final_data.append(item)
    return final_data

def filter_task_data(task_data_list, max_error_times=0, languages=None):
    """Keep edit-task items with usable locations and few error turns.

    An item is kept when its language is allowed, its ``found_related_locs``
    values join to non-blank text that does not contain ``'anonymous'``, and
    ``error_times_in_dialogue`` reports at most ``max_error_times`` errors.

    Args:
        task_data_list (iterable[dict]): Items carrying ``language``,
            ``found_related_locs`` (a mapping whose values are iterables of
            strings) and ``dialogue``.
        max_error_times (int): Largest tolerated number of error turns.
        languages (collection[str] | None): Allowed languages; ``None`` or an
            empty collection disables the language filter. The default is
            ``None`` rather than a shared mutable list.

    Returns:
        list[dict]: The retained items, in input order.
    """
    final_data = []
    for item in tqdm(task_data_list, desc="Filtering task_data"):
        if languages and item["language"] not in languages:
            continue
        # One space between the per-file location strings, as before.
        all_locs = ' '.join(''.join(value) for value in item["found_related_locs"].values())
        if not all_locs.strip() or 'anonymous' in all_locs:
            continue
        if error_times_in_dialogue(item["dialogue"]) <= max_error_times:
            final_data.append(item)
    return final_data

def filter_completion_data(completion_data_list, max_error_times=0, languages=None):
    """Keep valid completion items without anonymous locations and few error turns.

    An item is kept when its ``valid`` flag does not equal ``False``, its
    language is allowed, the joined ``found_related_locs`` values do not
    contain ``'anonymous'``, and ``error_times_in_dialogue`` reports at most
    ``max_error_times`` errors. Unlike the other filters in this file, empty
    locations are accepted here.

    Args:
        completion_data_list (iterable[dict]): Items carrying ``valid``,
            ``language``, ``found_related_locs`` (a mapping whose values are
            iterables of strings) and ``dialogue``.
        max_error_times (int): Largest tolerated number of error turns.
        languages (collection[str] | None): Allowed languages; ``None`` or an
            empty collection disables the language filter. The default is
            ``None`` rather than a shared mutable list.

    Returns:
        list[dict]: The retained items, in input order.
    """
    final_data = []
    for item in tqdm(completion_data_list, desc="Filtering completion_data"):
        # Deliberately an equality test: a ``None`` flag is not rejected,
        # which ``not item['valid']`` would change.
        if item['valid'] == False:
            continue
        if languages and item["language"] not in languages:
            continue
        all_locs = ' '.join(''.join(value) for value in item["found_related_locs"].values())
        if 'anonymous' in all_locs:
            continue
        if error_times_in_dialogue(item["dialogue"]) <= max_error_times:
            final_data.append(item)
    return final_data

def get_blacklist(raw_list, filtered_list):
    """Return the instance ids that are in ``raw_list`` but not in ``filtered_list``.

    Args:
        raw_list (iterable[dict]): Items before filtering, each with an
            ``instance_id``.
        filtered_list (iterable[dict]): Items that survived filtering.

    Returns:
        set: Instance ids that were filtered out.
    """
    surviving_ids = {entry["instance_id"] for entry in filtered_list}
    all_ids = {entry["instance_id"] for entry in raw_list}
    return all_ids - surviving_ids

def save_jsonl(data, path):
    """Write ``data`` to ``path`` in JSONL format (one JSON value per line).

    Missing parent directories are created first. Non-ASCII characters are
    written as-is (``ensure_ascii=False``) in a UTF-8 encoded file.

    Args:
        data (iterable): JSON-serialisable items.
        path (str): Destination file; overwritten if it exists.
    """
    parent_dir = os.path.dirname(path)
    # A bare filename has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError, so only create directories when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(item, ensure_ascii=False) + '\n' for item in data)

def filter_by_token_length(data_list, tokenizer, max_tokens):
    """
    Drop samples whose rendered dialogue is ``max_tokens`` tokens or longer.

    Each sample's ``dialogue`` is rendered with
    ``tokenizer.apply_chat_template(..., tokenize=False)``; if that raises,
    the turn contents joined by newlines are used instead. All prompts are
    then tokenized in one call without special tokens and measured through
    the returned ``input_ids``.

    Args:
        data_list: List of samples, each with a ``dialogue`` list of turns.
        tokenizer: Object offering ``apply_chat_template`` and callable on a
            list of strings, returning a mapping with ``input_ids``.
        max_tokens: Exclusive upper bound on the token count.

    Returns:
        list: Samples with fewer than ``max_tokens`` tokens, in input order;
        ``[]`` when ``data_list`` is empty.
    """
    if not data_list:
        return []

    print(f"Applying tokenizer to {len(data_list)} samples...")
    prompts = []
    for sample in tqdm(data_list, desc="Applying chat_template"):
        try:
            rendered = tokenizer.apply_chat_template(sample["dialogue"], tokenize=False)
        except Exception:
            # Best-effort fallback when the chat template cannot be applied.
            rendered = '\n'.join([turn["content"] for turn in sample["dialogue"]])
        prompts.append(rendered)

    print(f"Tokenizing {len(prompts)} prompts...")
    encoded = tokenizer(prompts, add_special_tokens=False)
    token_counts = [len(ids) for ids in encoded["input_ids"]]

    # Pair every sample with its token count and keep the short ones.
    filtered_data = [
        sample
        for sample, count in zip(data_list, token_counts)
        if count < max_tokens
    ]

    print(f"Filtered out {len(data_list) - len(filtered_data)} samples exceeding {max_tokens} tokens.")

    return filtered_data

if __name__ == '__main__':
    # Script entry point. Stages, in order:
    #   1. load the ground truth and the four trajectory files;
    #   2. filter/validate file-localization, function-localization, edit-task
    #      and completion samples;
    #   3. drop samples whose rendered dialogue reaches --max_tokens;
    #   4. optionally keep only instance ids shared by the first three sets;
    #   5. write one JSONL file per data type into --intermediate_dir.
    parser = argparse.ArgumentParser(description="Filter and save intermediate jsonl data.")
    parser.add_argument("--gt_file_path", type=str, required=True, help="Path to ground truth JSONL file")
    parser.add_argument("--file_traj_path", type=str, required=True, help="Path to loc_file_outputs.jsonl")
    parser.add_argument("--func_traj_path", type=str, required=True, help="Path to loc_func_outputs.jsonl")
    parser.add_argument("--task_traj_path", type=str, required=True, help="Path to edit_task.jsonl")
    parser.add_argument("--completion_traj_path", type=str, required=True, help="Path to completion_data.jsonl")
    parser.add_argument("--intermediate_dir", type=str, default="intermediate_data", help="Directory to save filtered intermediate jsonl files")
    parser.add_argument("--tokenizer_name", type=str, default="Qwen/Qwen2.5-Coder-7B-Instruct", help="Tokenizer model name or path.")
    parser.add_argument("--max_tokens", type=int, default=32768, help="Max tokens for dialogue, data >= this will be filtered.")
    parser.add_argument("--filter_common_instances", action="store_true", help="Only save instances that exist in all three data types (file_loc, func_loc, task).")
    args = parser.parse_args()

    print("=== 数据筛选与保存中间文件阶段 ===")

    # Load the ground-truth data and index it by instance_id
    gt_data = load_jsonl(args.gt_file_path)
    gt_data_map = {item['instance_id']: item for item in gt_data}
    print(f"Loaded {len(gt_data)} ground truth items.")


    ## Process the file localization data
    print("\n--- Processing File Localization Data ---")
    file_traj = load_jsonl(args.file_traj_path)
    print(f"Raw file_loc_data: {len(file_traj)}")
    original_file_loc_data_num = len(file_traj)
    init_file_loc_data = filter_file_loc_data(file_traj, max_error_times=0, languages=[])
    file_loc_data = []
    for item in tqdm(init_file_loc_data, desc="Reranking file_loc_data"):
        # The lookup sits outside the try, so an instance without a
        # ground-truth entry raises KeyError and aborts the script.
        gt_item = gt_data_map[item['instance_id']]
        try:
            # On success the rerank rewrites the item's last dialogue turn in
            # place; a falsy result means the item is dropped.
            check_res = rerank_file_localization(gt_item, item)
            if check_res:
                file_loc_data.append(item)
        except Exception as e:
            # NOTE(review): every exception is swallowed silently, so a
            # malformed item and a deliberate rejection look the same here.
            pass

    # Instance ids that did not survive this stage; used below to exclude the
    # matching function-localization trajectories.
    file_blacklist = get_blacklist(file_traj, file_loc_data)
    file_loc_data = construct_stardard_format(file_loc_data, 'file_localization', args.file_traj_path)
    print(f"Filtered File localization data: {len(file_loc_data)}")


    ## Process the function localization data
    print("\n--- Processing Function Localization Data ---")
    func_traj_raw = load_jsonl(args.func_traj_path)
    print(f"Raw func_loc_data: {len(func_traj_raw)}")
    original_func_loc_data_num = len(func_traj_raw)
    func_traj = [item for item in func_traj_raw if item["instance_id"] not in file_blacklist]
    init_func_loc_data = filter_func_loc_data(func_traj, max_error_times=0, languages=[])
    func_loc_data = []
    for item in tqdm(init_func_loc_data, desc="Reranking func_loc_data"):
        gt_item = gt_data_map[item['instance_id']]
        try:
            check_res = rerank_func_localization_results(gt_item, item)
            if check_res:
                func_loc_data.append(item)
        except Exception as e:
            # Failed assertions and parse errors inside the rerank end up
            # here and silently drop the item.
            pass

    # The blacklist accumulates: ids rejected at the file stage plus those
    # rejected at the function stage.
    func_blacklist = file_blacklist | get_blacklist(func_traj, func_loc_data)
    func_loc_data = construct_stardard_format(func_loc_data, 'func_localization', args.func_traj_path)
    print(f"Filtered Function localization data: {len(func_loc_data)}")


    ## Process the edit-task data
    print("\n--- Processing Task Data ---")
    task_traj_raw = load_jsonl(args.task_traj_path)
    print(f"Raw task_data: {len(task_traj_raw)}")
    original_task_data_num = len(task_traj_raw)
    task_traj = [item for item in task_traj_raw if item["instance_id"] not in func_blacklist]
    init_task_data = filter_task_data(task_traj, max_error_times=0, languages=[])

    task_data = []
    for item in tqdm(init_task_data, desc="Checking task_data"):
        try:
            check_res = check_search_replace(item)
            if check_res:
                task_data.append(item)
        except Exception as e:
            pass

    # Computed for symmetry with the earlier stages; not read again below.
    task_blacklist = func_blacklist | get_blacklist(task_traj, task_data)
    task_data = construct_stardard_format(task_data, 'code_edit', args.task_traj_path)
    print(f"Filtered Task data: {len(task_data)}")


    ## Process the completion data (no blacklist is applied to this set)
    print("\n--- Processing Completion Data ---")
    completion_traj = load_jsonl(args.completion_traj_path)
    print(f"Raw completion_data: {len(completion_traj)}")
    original_completion_data_num = len(completion_traj)
    init_completion_data = filter_completion_data(completion_traj, max_error_times=0, languages=[])
    completion_data = []
    for item in tqdm(init_completion_data, desc="Checking completion_data"):
        try:
            check_res = check_search_replace(item)
            if check_res:
                # Reword the instruction in place before keeping the item.
                replace_instruction_for_completion(item)
                completion_data.append(item)
        except Exception as e:
            pass
    # Count the completion samples whose found_files attribute is non-empty
    num_with_found_files = sum(1 for item in completion_data if item['found_files'])
    print(f"Completion data with found_files: {num_with_found_files}/{len(completion_data)}")

    completion_data = construct_stardard_format(completion_data, 'completion', args.completion_traj_path)
    print(f"Filtered Completion data: {len(completion_data)}")

    print(f"\n=== 应用Token长度筛选 ===")
    print(f"加载分词器: {args.tokenizer_name}")
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, trust_remote_code=True)

    # Apply the token-length filter to each data type. This runs after the
    # blacklists were built, so samples dropped here do not propagate to the
    # other data types unless --filter_common_instances is set.
    print(f"\n--- 筛选文件定位数据的Token长度 ---")
    file_loc_data = filter_by_token_length(file_loc_data, tokenizer, args.max_tokens)

    print(f"\n--- 筛选函数定位数据的Token长度 ---")
    func_loc_data = filter_by_token_length(func_loc_data, tokenizer, args.max_tokens)

    print(f"\n--- 筛选任务数据的Token长度 ---")
    task_data = filter_by_token_length(task_data, tokenizer, args.max_tokens)

    print(f"\n--- 筛选Completion数据的Token长度 ---")
    completion_data = filter_by_token_length(completion_data, tokenizer, args.max_tokens)

    # When common-instance filtering is enabled, keep only the instance_ids
    # that exist in all three datasets (completion data is not part of this)
    if args.filter_common_instances:
        print(f"\n=== 应用共同实例ID筛选 ===")

        # Collect the instance_id set of each dataset
        file_loc_ids = {item["instance_id"] for item in file_loc_data}
        func_loc_ids = {item["instance_id"] for item in func_loc_data}
        task_ids = {item["instance_id"] for item in task_data}

        # Intersection of the three sets
        common_ids = file_loc_ids & func_loc_ids & task_ids

        print(f"文件定位数据实例ID数量: {len(file_loc_ids)}")
        print(f"函数定位数据实例ID数量: {len(func_loc_ids)}")
        print(f"任务数据实例ID数量: {len(task_ids)}")
        print(f"共同实例ID数量: {len(common_ids)}")

        # Filter each dataset down to the common instance_ids
        original_file_loc_len = len(file_loc_data)
        original_func_loc_len = len(func_loc_data)
        original_task_len = len(task_data)

        file_loc_data = [item for item in file_loc_data if item["instance_id"] in common_ids]
        func_loc_data = [item for item in func_loc_data if item["instance_id"] in common_ids]
        task_data = [item for item in task_data if item["instance_id"] in common_ids]

        print(f"筛选后文件定位数据: {len(file_loc_data)}/{original_file_loc_len}")
        print(f"筛选后函数定位数据: {len(func_loc_data)}/{original_func_loc_len}")
        print(f"筛选后任务数据: {len(task_data)}/{original_task_len}")

    # Save the data that passed the token-length filter
    save_jsonl(file_loc_data, os.path.join(args.intermediate_dir, "filtered_file_loc_data.jsonl"))
    save_jsonl(func_loc_data, os.path.join(args.intermediate_dir, "filtered_func_loc_data.jsonl"))
    save_jsonl(task_data, os.path.join(args.intermediate_dir, "filtered_task_data.jsonl"))
    save_jsonl(completion_data, os.path.join(args.intermediate_dir, "filtered_completion_data.jsonl"))

    # Final summary: kept / raw counts per data type and the overall total.
    print(f"\n=== 数据筛选完成, 中间文件保存在: {args.intermediate_dir} ===")
    print(f"文件定位数据: {len(file_loc_data)}/{original_file_loc_data_num} 条")
    print(f"函数定位数据: {len(func_loc_data)}/{original_func_loc_data_num} 条")
    print(f"任务数据: {len(task_data)}/{original_task_data_num} 条")
    print(f"Completion数据: {len(completion_data)}/{original_completion_data_num} 条")
    print(f"总计: {len(file_loc_data) + len(func_loc_data) + len(task_data) + len(completion_data)} 条")
