import re
import os
import json
import argparse
from pathlib import Path
from datetime import datetime
import hashlib
from tqdm import tqdm
from utils import (
    check_and_reformat_input_code,
    get_input_output_code_modify_for_item,
    get_diff_str_compared_with_old,
    NO_DIFF_STR,
    extract_first_last_lineno
)

# Maps the reviewer-assigned ``ai_type`` of a *resolved* review comment to the
# ``output_bug_type`` category written into the standard format (looked up in
# convert_to_standardformat). Types mapped to "no bug" are emitted without a
# repaired code snippet, the same way rejected comments are.
BUG_TYPE_MAPPING_DICT = {
    "Logical error": "Logical relationship defects",
    "Resource leak": "Common code defects",
    "Security vulnerabilities": "Security vulnerable defects",
    "Others": "Other defects",
    "Performance issues": "Common code defects",
    "Code style": "no bug",
    "Optimization suggestions": "no bug",
}

# Value stored in ``input_code_diff`` when the origin file content is unavailable
# (its fetch returned "404 File Not Found" / "Page Not Found"), so no diff
# against it can be computed.
UNKNOWN_DIFF_STR = 'unknown'

def load_data(data_dir: str):
    """Load and concatenate the records from every ``*.json`` file in *data_dir*.

    Each file is expected to contain a JSON array of records; the arrays are
    concatenated in file-name order. Sorting makes the combined order (and
    therefore which duplicate is kept downstream) deterministic across runs
    and platforms, unlike raw ``glob`` order.

    Files that cannot be opened/decoded, contain invalid JSON, or whose top
    level is not a list are reported and skipped rather than aborting the load.

    Args:
        data_dir: directory to scan (non-recursively) for ``*.json`` files.

    Returns:
        A single list holding the records of all successfully read files.
    """
    data_list = []
    for json_file in sorted(Path(data_dir).glob("*.json")):
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Error when read {json_file}: {e}")
            continue
        if not isinstance(data, list):
            # extend() on a dict would silently add its keys as "records".
            print(f"Error when read {json_file}: expected a JSON list, got {type(data).__name__}")
            continue
        data_list.extend(data)
        print(f"Read {json_file} successfully, data amount：{len(data)}")
    return data_list

def data_analysis(data_list: list):
    """Print frequency summaries of the raw (pre-conversion) records.

    Counts records per lower-cased ``programming_language`` and, across every
    review comment in ``code_review_api.output``, the distribution of
    ``type``, ``ai_type`` and ``accepted_status``. Missing fields are counted
    under 'Unknown' (``type`` as the string "['Unknown']").
    """
    from collections import Counter

    languages = Counter()
    statuses = Counter()
    ai_types = Counter()
    comment_types = Counter()

    for record in data_list:
        languages[record.get('programming_language', 'Unknown').lower()] += 1
        for review in record['code_review_api']['output']:
            statuses[review.get('accepted_status', 'Unknown')] += 1
            ai_types[review.get('ai_type', 'Unknown')] += 1
            # ``type`` may be a list, so it is stringified to act as a key.
            comment_types[str(review.get('type', ['Unknown']))] += 1

    def render(counter):
        # "key: count" pairs joined by a full-width comma.
        return "，".join([f"{key}: {value}" for key, value in counter.items()])

    separator = "--------------------------------"
    print(separator)
    print(f"Original data amount: {len(data_list)}")
    print("programming_language:", render(languages))
    print("type:", render(comment_types))
    print("ai_type:", render(ai_types))
    print("accepted_status:", render(statuses))
    print(separator)


def standard_data_analysis(data_list: list):
    """Print a summary of records already converted to the standard format.

    Reports the record count, per-``programming_language`` and
    per-``output_bug_type`` counts, how many records carry a non-blank
    ``output_repaired_code``, and how many have ``input_code_diff`` equal to
    ``UNKNOWN_DIFF_STR`` and to ``NO_DIFF_STR``.
    """
    from collections import Counter

    def render(counter):
        # "key: count" pairs joined by a full-width comma.
        return "，".join([f"{key}: {count}" for key, count in counter.items()])

    by_language = Counter(item.get('programming_language', 'Unknown') for item in data_list)
    by_bug_type = Counter(item.get('output_bug_type', 'Unknown') for item in data_list)
    by_repair = Counter(
        'yes' if item.get('output_repaired_code', '').strip() else 'no'
        for item in data_list
    )
    total = len(data_list)
    separator = "--------------------------------"

    print(separator)
    print(f"Data amount: {total}")
    print("programming_language:", render(by_language))
    print("output_bug_type:", render(by_bug_type))
    print("has output_repaired_code:", render(by_repair))
    # The diff counts are computed only after the lines above are printed, matching the original output order.
    unknown_count = sum(1 for item in data_list if item['input_code_diff'] == UNKNOWN_DIFF_STR)
    print(f"input_code_diff = {UNKNOWN_DIFF_STR}: {unknown_count} (out of {total})")
    no_diff_count = sum(1 for item in data_list if item['input_code_diff'] == NO_DIFF_STR)
    print(f"input_code_diff = {NO_DIFF_STR}: {no_diff_count} (out of {total})")
    print(separator)


def get_dataid(data, data_source_tag):
    """Build a source-prefixed identifier for a raw record.

    Uses the record's own ``id`` when it is present and truthy. Otherwise it
    falls back to the MD5 hex digest of ``str(data['code_review_api']['input'])``,
    so records without an id still get a key that is deterministic for identical
    input. This also covers records with no ``id`` key at all, which previously
    raised KeyError. MD5 is used only as a content fingerprint, not for security.

    Args:
        data: raw record dict.
        data_source_tag: prefix identifying the dataset the record came from.

    Returns:
        ``"<data_source_tag>_<id or md5>"``.
    """
    record_id = data.get('id')
    if record_id:
        return f"{data_source_tag}_{record_id}"
    payload = str(data['code_review_api']['input']).encode('utf-8')
    return f"{data_source_tag}_{hashlib.md5(payload).hexdigest()}"

def get_context(data):
    """Render a record's code-graph context as plain text.

    ``data['code_review_api']['input']['context']`` is parsed as a JSON string
    holding ``code_graph.entities``. Each entity contributes a
    ``"# <file_path>"`` line followed by its ``sign`` line. If the context
    cannot be read that way (not JSON, wrong shape, missing keys, non-string
    fields), the raw context value is returned unchanged.

    Args:
        data: raw record dict.

    Returns:
        The rendered context string, or the raw context on parse failure.
    """
    raw_context = data['code_review_api']['input']['context']
    try:
        entities = json.loads(raw_context)['code_graph']['entities']
        lines = []
        for entity in entities:
            lines.append('# ' + entity['file_path'] + '\n')
            lines.append(entity['sign'] + '\n')
    except (ValueError, TypeError, KeyError):
        # ValueError: invalid JSON; TypeError: non-string input or unexpected
        # container types; KeyError: missing code_graph/entities/fields.
        return raw_context
    return ''.join(lines)


def _select_review_comment(review_comments: list) -> dict:
    """Pick the review comment used to label a record.

    Starting from the first comment, each later comment whose
    ``accepted_status`` is 'resolved' or 'rejected' replaces the current
    choice; the scan stops as soon as the current choice is 'resolved'.
    """
    selected = review_comments[0]
    for comment in review_comments[1:]:
        if selected['accepted_status'] == 'resolved':
            break
        if comment['accepted_status'] in ('resolved', 'rejected'):
            selected = comment
    return selected


def convert_to_standardformat(data_list: list, data_source_tag: str, max_input_lines: int = 200):
    """Convert raw review-feedback records into standard IO-format dicts.

    Per record:
      * the id comes from ``get_dataid``; a later record with an id that was
        already emitted is skipped (first one wins);
      * one review comment is chosen (see ``_select_review_comment``);
      * 'resolved' comments get ``output_bug_type`` from
        BUG_TYPE_MAPPING_DICT[ai_type]; 'rejected' comments become "no bug";
        any other status is skipped;
      * "no bug" records re-extract the input code from the issue file via
        ``check_and_reformat_input_code`` and carry no repaired code;
      * bug records extract input / repaired / diff code by comparing the issue
        file with the final (fixed) file via
        ``get_input_output_code_modify_for_item``. If no repaired code is
        returned, the record is downgraded to "no bug";
      * ``input_code_diff`` compares against the origin file, or is
        UNKNOWN_DIFF_STR when the origin file is unavailable.

    Records are skipped, with a printed reason, when they have no review
    comments, when the issue/final file content is a "404 File Not Found" /
    "Page Not Found" page, when the input code exceeds *max_input_lines*
    lines, when the ai_type is unmapped, or when code extraction raises.

    Args:
        data_list: raw records (e.g. from ``load_data``).
        data_source_tag: id prefix and value of the ``source`` field.
        max_input_lines: records whose input code has more lines are skipped.

    Returns:
        List of standard-format dicts in input order.
    """
    standard_data_list = []
    seen_ids = set()  # ids already emitted; a set keeps the duplicate check O(1)
    for data in tqdm(data_list):
        record_id = data.get('id')  # only used in log messages
        code_unique_id = get_dataid(data, data_source_tag)
        if code_unique_id in seen_ids:
            print(f"Unique id {code_unique_id} exists, skip")
            continue

        review_comments = data['code_review_api']['output']
        if not review_comments:
            continue
        review_comment = _select_review_comment(review_comments)

        # origin -> issue (file under review) -> final (file after the fix)
        old_file = review_comment['origin_file_content']
        current_file = review_comment['issue_file_content']
        fixed_file = review_comment['final_file_content']
        issue_and_final = f"{current_file}\n{fixed_file}"
        if '404 File Not Found' in issue_and_final:
            print(f"Error (issue_file_content/final_file_content 404 File Not Found): {record_id}")
            continue
        if 'Page Not Found' in issue_and_final:
            print(f"Error (issue_file_content/final_file_content Page Not Found): {record_id}")
            continue

        input_code = data['code_review_api']['input']['code']
        if len(input_code.splitlines()) > max_input_lines:
            print(f"Error (data['code_review_api']['input']['code'] too long): {record_id}")
            continue

        status = review_comment['accepted_status']
        if status == 'resolved':
            ai_type = review_comment.get('ai_type')
            if ai_type not in BUG_TYPE_MAPPING_DICT:
                # Previously a KeyError here aborted the entire conversion.
                print(f"Error (unmapped ai_type {ai_type!r}): {record_id}")
                continue
            mapped_output_bug_type = BUG_TYPE_MAPPING_DICT[ai_type]
        elif status == 'rejected':
            mapped_output_bug_type = "no bug"
        else:
            print("Other status, skip")
            continue

        try:
            if mapped_output_bug_type == "no bug":
                _, input_review_code, _, _ = check_and_reformat_input_code(
                    old_file_content=current_file,
                    input_code=input_code
                )
                output_repaired_code = ''
                output_code_in_diff = ''
            else:
                # The reviewed code lives in the issue file; the repair lives in
                # the final file. These contents were read from review_comment
                # above; the record itself was previously (wrongly) indexed
                # with 'fixed_file_content' / 'origin_file_content'.
                input_review_code, output_repaired_code, output_code_in_diff = get_input_output_code_modify_for_item(
                    old_file_content=current_file,
                    new_file_content=fixed_file,
                    input_code=input_code
                )
                if output_repaired_code is None:
                    mapped_output_bug_type = "no bug"
                    output_repaired_code = ''
                    output_code_in_diff = ''
        except Exception as e:
            # Best-effort: skip this record, but surface what actually failed.
            print(f"Error (data['code_review_api']['input']['code'] not in issue_file_content): {record_id} ({type(e).__name__}: {e})")
            continue

        if old_file is None or '404 File Not Found' in old_file or 'Page Not Found' in old_file:
            print(f"Error (origin_file_content Not Found): {record_id}")
            input_code_diff = UNKNOWN_DIFF_STR
        else:
            input_code_diff = get_diff_str_compared_with_old(old_file, current_file, input_review_code)

        source_code = input_review_code.strip("\n")
        start_lineno, end_lineno = extract_first_last_lineno(source_code)
        lineno_range = [] if mapped_output_bug_type == 'no bug' else [start_lineno, end_lineno]
        standard_data = {
            "id": code_unique_id,
            "source": data_source_tag,
            "input_source_code": source_code,
            "input_code_diff": input_code_diff,
            "input_source_code_comment": data['code_review_api']['input']['description'],
            "input_source_code_context": get_context(data),
            "lineno_range": lineno_range,
            "output_bug_type": mapped_output_bug_type,
            # Raw string: "\d" in a normal literal is an invalid escape sequence.
            "output_bug_explanation": re.sub(r"Programming Rules \d+:", '', review_comment['explanation']).strip(),
            "output_repaired_code": output_repaired_code.strip("\n"),
            "output_repaired_code_in_diff": output_code_in_diff.strip("\n"),
            "programming_language": data['programming_language'].lower(),
            "timestamp": review_comment['created_at']
        }
        standard_data_list.append(standard_data)
        seen_ids.add(code_unique_id)

    return standard_data_list

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="CodeReviewFeedback to standard format")
    # Raw strings: these Windows paths contain backslashes that are not valid escapes.
    parser.add_argument('--input_path', type=str, default=r"D:\WorkDocs\8-AICoding\DATA\CodeReviewFeedback\review_feedback_0911_timerange202501-09")
    parser.add_argument('--output_dir', type=str, default=r"D:\WorkDocs\8-AICoding\code_data_flywheel\OUTPUT\CRFeedback_prepare")
    # Records created strictly before this instant go to TRAIN.
    parser.add_argument('--train_test_split_time', type=str, default="2025-08-01T00:00:00+00:00")
    # Records in [split_time, test_end_time) go to TEST; later ones are dropped.
    parser.add_argument('--test_end_time', type=str, default="2025-09-01T00:00:00+00:00")
    args = parser.parse_args()

    # normpath drops trailing separators of either style, so the tag is never empty.
    args.data_source_tag = os.path.basename(os.path.normpath(args.input_path))

    data_list = load_data(args.input_path)
    data_analysis(data_list)

    standard_data_list = convert_to_standardformat(data_list, args.data_source_tag)
    print(f"IO format data={len(standard_data_list)}")
    standard_data_analysis(standard_data_list)

    # The CLI split time was previously ignored in favour of a hard-coded value,
    # and records exactly at the split instant fell into neither set.
    split_time = datetime.fromisoformat(args.train_test_split_time)
    test_end_time = datetime.fromisoformat(args.test_end_time)
    train_split, test_split = [], []
    for item in standard_data_list:
        created_at = datetime.fromisoformat(item['timestamp'])
        if created_at < split_time:
            train_split.append(item)
        elif created_at < test_end_time:
            test_split.append(item)

    os.makedirs(args.output_dir, exist_ok=True)
    for prefix, label, split in (("TRAIN", "Train", train_split), ("TEST", "Test", test_split)):
        output_path = f"{args.output_dir}/{prefix}_{args.data_source_tag}_ioformat.json"
        with open(output_path, "w", encoding="utf-8") as f:
            print(f"{label} data={len(split)}, exported as {output_path}")
            standard_data_analysis(split)
            json.dump(split, f, ensure_ascii=False, indent=4)
