import os
import json
import argparse
import random
from tqdm import tqdm
from jinja2 import Template

# Category labels stored in item['output_category'].
# print_data_stats reports BUG_STR items as "need review" and NOBUG_STR items
# as "no review".
BUG_STR = 'yes'
NOBUG_STR = 'no'
# Assistant-side output that carries only the category tag. Used by
# format_sharegpt_data for NOBUG_STR items and for every item when the prompt
# path contains "output=cls".
OUTPUT_PROMPT_TEMP_NOBUG = Template("""<type>
{{ output_category }}
</type>""")
# Assistant-side output with category, review comment and the full repaired
# code. This is the fallback template for items that are not NOBUG_STR.
OUTPUT_PROMPT_TEMP_BUG = Template("""<type>
{{ output_category }}
</type>
<review>
{{ output_comment }}
</review>
<repaired_code>
{{ output_code }}
</repaired_code>""")
# Same layout as OUTPUT_PROMPT_TEMP_BUG, but the <repaired_code> section is
# filled from `output_code_in_diff` instead of `output_code`. Selected when the
# prompt path contains "output=all+outputdiff".
OUTPUT_PROMPT_TEMP_BUG_DIFF = Template("""<type>
{{ output_category }}
</type>
<review>
{{ output_comment }}
</review>
<repaired_code>
{{ output_code_in_diff }}
</repaired_code>""")

def print_data_stats(results):
    """Print summary counts for a list of dataset items.

    Reports the total number of items, how many are labelled BUG_STR versus
    NOBUG_STR in ``output_category``, and how many have an empty ``language``.

    Args:
        results: Sized collection of dicts, each providing the
            ``output_category`` and ``language`` keys.
    """
    print('total_count:', len(results))

    # Tally each label separately; items carrying any other label are ignored.
    need_review = sum(1 for row in results if row['output_category'] == BUG_STR)
    no_review = sum(1 for row in results if row['output_category'] == NOBUG_STR)
    print(f"need review / no review: {need_review} / {no_review}")

    missing_lang = sum(1 for row in results if row['language'] == '')
    print(f"no lang count: {missing_lang}")

def _select_bug_output_template(prompt_path):
    """Return the assistant-output template for an item that is not NOBUG_STR.

    The choice is driven by marker substrings in ``prompt_path``.
    """
    if "output=all+outputdiff" in prompt_path:
        return OUTPUT_PROMPT_TEMP_BUG_DIFF
    if "output=cls" in prompt_path:
        # Classification-only prompts emit just the category tag.
        return OUTPUT_PROMPT_TEMP_NOBUG
    return OUTPUT_PROMPT_TEMP_BUG


def format_sharegpt_data(args, data, input_prompt_template=None):
    """Convert dataset items into ShareGPT-style two-turn conversations.

    Each item yields a ``user`` turn (the input prompt template rendered with
    the item) and an ``assistant`` turn (one of the OUTPUT_PROMPT_TEMP_*
    templates rendered with the item).

    Args:
        args: Object exposing ``mode`` and ``prompt_path`` attributes.
            ``prompt_path`` is only read for items that are not NOBUG_STR.
        data: Iterable of dicts. Each must provide ``id`` and
            ``output_category`` plus whatever fields the templates reference.
        input_prompt_template: Object with a ``render(item)`` method used for
            the user turn. When omitted, the module-level ``INPUT_PROMPT_TEMP``
            is used, which is how the script entry point supplies it.

    Returns:
        A list. When ``args.mode == 'train'`` it holds the conversation dicts.
        Otherwise it holds the original items, each mutated in place to carry
        its conversation under the ``sharegpt_format`` key.

    Raises:
        NameError: If an item has to be rendered but no input template was
            passed and ``INPUT_PROMPT_TEMP`` is not defined in this module.
    """
    if input_prompt_template is None:
        # Legacy path: the template used to be reachable only as a global that
        # the __main__ block binds. Looked up lazily so that callers importing
        # this module get a clear error instead of an opaque failure.
        input_prompt_template = globals().get('INPUT_PROMPT_TEMP')

    results = []
    for item in data:
        if item['output_category'] == NOBUG_STR:
            output_template = OUTPUT_PROMPT_TEMP_NOBUG
        else:
            output_template = _select_bug_output_template(args.prompt_path)
        output = output_template.render(item)

        if input_prompt_template is None:
            raise NameError(
                "INPUT_PROMPT_TEMP is not defined; pass input_prompt_template "
                "to format_sharegpt_data"
            )

        sharegpt_conv = {
            'id': item['id'],
            'conversations': [
                {'from': 'user', 'value': input_prompt_template.render(item)},
                {'from': 'assistant', 'value': output},
            ],
        }

        if args.mode == 'train':
            results.append(sharegpt_conv)
        else:
            # Non-train modes keep every original field and attach the
            # conversation to the item itself.
            item['sharegpt_format'] = sharegpt_conv
            results.append(item)

    return results


def crfeedback_data_special_preprocess(io_format_data):
    """Add the generic field names used by the formatter to CRFeedback items.

    Copies each source field to the key the templates and statistics read, and
    derives ``output_category`` from ``output_bug_type``. Items are modified in
    place; the original keys are kept.

    Args:
        io_format_data: Iterable of dicts providing ``programming_language``,
            ``input_source_code``, ``output_bug_explanation``,
            ``output_repaired_code``, ``output_repaired_code_in_diff`` and
            ``output_bug_type``.

    Returns:
        The same ``io_format_data`` object that was passed in.

    Raises:
        KeyError: If an item lacks one of the source fields listed above.
    """
    # target key -> source key, applied in this order to every item
    field_aliases = (
        ('language', 'programming_language'),
        ('input_code', 'input_source_code'),
        ('output_comment', 'output_bug_explanation'),
        ('output_code', 'output_repaired_code'),
        ('output_code_in_diff', 'output_repaired_code_in_diff'),
    )
    for item in io_format_data:
        for target_key, source_key in field_aliases:
            item[target_key] = item[source_key]
        # Use the shared label constants so this stays in sync with the
        # comparisons made in print_data_stats and format_sharegpt_data.
        if item['output_bug_type'] == 'no bug':
            item['output_category'] = NOBUG_STR
        else:
            item['output_category'] = BUG_STR

    return io_format_data

if __name__ == '__main__':
    # Script entry point: load an "ioformat" JSON dataset, normalise its field
    # names, render every item into a ShareGPT conversation with the prompt
    # template read from --prompt_path, and write the result as JSON.
    parser = argparse.ArgumentParser(description='Data preprocess')
    parser.add_argument('--mode', type=str, default='train', help='train or test')
    parser.add_argument('--input_ioformat_path', type=str, default=r"D:\WorkDocs\8-AICoding\code_data_flywheel\OUTPUT\CRFeedback_prepare\review_feedback_0724+0725_ioformat.json")
    # Raw string like the other defaults: the backslashes here are not valid
    # escape sequences and trigger a warning in a normal string literal.
    # NOTE(review): --save_dir is parsed but not used below; the output is
    # written next to the input file.
    parser.add_argument('--save_dir', type=str, default=r"D:\WorkDocs\8-AICoding\code_data_flywheel\OUTPUT\CRFeedback_prepare")
    parser.add_argument('--prompt_path', type=str, default=r"D:\WorkDocs\8-AICoding\code_data_flywheel\src\prompts\input=code_output=all_context=none_v0905.txt")
    args = parser.parse_args()

    pname = os.path.basename(args.prompt_path).replace(".txt", '')
    save_suffix = f"_prompt={pname}_{args.mode}_sharegptformat.json"
    args.save_sharegpt_filename = args.input_ioformat_path.replace('_ioformat.json', save_suffix)
    if args.save_sharegpt_filename == args.input_ioformat_path:
        # The input name lacks the '_ioformat.json' marker, so the replace
        # above was a no-op and writing would overwrite the input dataset.
        # Derive the output name from the extension-less input path instead.
        args.save_sharegpt_filename = os.path.splitext(args.input_ioformat_path)[0] + save_suffix

    with open(args.input_ioformat_path, 'r', encoding='utf-8') as f:
        io_format_data = json.load(f)
    io_format_data = crfeedback_data_special_preprocess(io_format_data)
    print_data_stats(io_format_data)
    print(f"Load ioformat data from {args.input_ioformat_path}")

    with open(args.prompt_path, 'r', encoding='utf-8') as f:
        # Bound at module level because format_sharegpt_data reads this name.
        INPUT_PROMPT_TEMP = Template(f.read())
    print(f"Read prompt template from {args.prompt_path}")

    sharegpt_data = format_sharegpt_data(args, io_format_data)

    with open(args.save_sharegpt_filename, 'w', encoding='utf-8') as f:
        json.dump(sharegpt_data, f, ensure_ascii=False, indent=4)
    print(f"saved to {args.save_sharegpt_filename}")
