import argparse
import json
import logging
import os
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from copy import deepcopy
from logging.handlers import RotatingFileHandler

import tqdm
from fuzzywuzzy import fuzz

from extract_json_code import extract_json_code, extract_diff_code
from prompt import PromptManager
from responser import respond_gpt as respond


class Retry:
    """Context manager + iterator that retries a block of code on failure.

    Intended usage::

        for attempt in Retry(3):
            with attempt:
                do_something_flaky()
        if attempt.allright:
            ...

    Iteration stops as soon as one ``with`` body finishes without raising,
    or after ``max_tries`` attempts. Between failed attempts the iterator
    sleeps for ``delay`` seconds.

    Args:
        max_tries: Maximum number of attempts.
        delay: Seconds to sleep between a failed attempt and the next one.

    Attributes:
        allright: True when the most recent ``with`` body completed without
            raising an exception; False otherwise.
    """

    def __init__(self, max_tries=5, delay=3):
        self.max_tries = max_tries
        self.delay = delay
        # Defined up front so reading it never raises AttributeError, even if
        # the loop body forgets to enter the context manager.
        self.allright = False

    def __iter__(self):
        for i in range(self.max_tries):
            yield self
            # Stop on success, and do not sleep after the final attempt.
            if self.allright or i == self.max_tries - 1:
                return
            time.sleep(self.delay)

    def __enter__(self):
        self.allright = False
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            self.allright = True
            return True
        # Let KeyboardInterrupt / SystemExit and other non-Exception
        # BaseExceptions propagate so the retry loop can be interrupted.
        if not issubclass(exc_type, Exception):
            return False
        print(exc_val)
        return True


def format_sft(instruction, output, history=None):
    """Build a single instruction/output training record.

    The "history" key is only added when *history* is truthy, so ``None``
    and empty containers both yield a two-key record.
    """
    record = dict(instruction=instruction, output=output)
    if not history:
        return record
    record["history"] = history
    return record


def pretty_json(data) -> str:
    """Return *data* as 2-space-indented JSON, leaving Unicode characters unescaped."""
    dump_options = {"indent": 2, "ensure_ascii": False}
    return json.dumps(data, **dump_options)


def generate_issue_metric(issue, issue_ground_truth, patch, code, message, prompts):
    """Ask the LLM to score ``issue`` (Issue_B) against a reference issue (Issue_A).

    The prompt contains, in order: the scoring instructions, the reference
    issue with a fixed example score, the issue to be scored, the runtime
    error message (only when ``message`` is truthy), the patch, and the
    scoring instructions once more.

    Args:
        issue: The issue to score; serialised with ``pretty_json``.
        issue_ground_truth: The reference issue; serialised with ``pretty_json``.
        patch: Patch text appended to the prompt.
        code: Unused; kept so existing callers keep working.
        message: Optional runtime error message text.
        prompts: Object whose ``get_prompt(name)`` returns the instruction text.

    Returns:
        The object parsed from the JSON extracted out of the model response,
        or ``None`` when no attempt produced parseable JSON.
    """
    instructions = prompts.get_prompt("评测Issue质量")
    parts = [
        instructions,
        "\nThe following is the specific Issue_A information:\n```json\n",
        pretty_json(issue_ground_truth),
        "```\nThe score of Issue_A is:\n```json",
        '{"Title": 8, "Description": 8, "Reproducibility": 8, "Relevance": 8, "Explanation": 8}',
        "```",
        "\nThe following is the specific Issue_B information which you need to score:\n```json\n",
        pretty_json(issue),
        "```",
    ]
    # The original built this with un-parenthesised `x if message else ""`
    # terms inside one `+` chain; the conditional expression binds looser
    # than `+`, so most of the prompt was silently dropped. Build the
    # optional section explicitly instead.
    if message:
        parts += [
            "\nThe following is the runtime error message:\n```\n",
            message,
            "```\n",
        ]
    else:
        # Keep the patch header on its own line after the closing fence.
        parts.append("\n")
    parts += [
        "The following is the specific patch information:\n```\n",
        patch,
        "```\n",
        instructions,
    ]
    prompt_text = "".join(parts)
    logging.info("Prompt for issue metric:\n%s", prompt_text)

    response = None
    for retry in Retry(20):
        with retry:
            raw = respond(prompt_text)
            logging.info("issue metric:\n%s", raw)
            response = json.loads(extract_json_code(raw))
    if retry.allright:
        return response

    # Every attempt failed (request error or unparseable JSON).
    logging.error("issue metric: no valid JSON response after all retries")
    return None


def generate_code_location_metric(pred, gt):
    """Compare predicted code locations with ground-truth locations.

    For every ground-truth item, each predicted item is scored on four
    fields: 'file', 'function', 'content_all' and 'content_change'. Two
    numbers are produced per field and averaged over the predictions:

    * the "es" value, computed by ``exact_match_score`` (whitespace-token
      equality, 0 or 1 per pair), and
    * the "em" value, computed by ``edit_similarity_score``
      (``fuzz.ratio`` per pair).

    A predicted item containing "location" is treated as malformed: the
    scores collected so far for the current ground-truth item are replaced
    ('file', 'function' and 'content_change' become [-1]), 'content_all' is
    scored by comparing the item's own "location" lines with its own
    'content_all' lines, and the remaining predictions are skipped.

    Args:
        pred: Iterable of predicted items (mapping-like, read with ``.get``).
        gt: Iterable of ground-truth items (mapping-like, read with ``.get``).

    Returns:
        A list with one dict per ground-truth item, holding 'avg_es_scores'
        and 'avg_em_scores'; each maps the four field names to the rounded
        mean, or -1 when nothing was scored for that field.
    """
    score_keys = ('file', 'function', 'content_all', 'content_change')
    length_error = "The length of the predicted codes and the ground truth codes should be equal."

    def _text(value):
        # None is scored as the empty string.
        return "" if value is None else value

    def _mean_or(values, fallback):
        # Rounded mean of `values`, or `fallback` when there is nothing to average.
        return round(sum(values) / len(values), 5) if values else fallback

    def exact_match_score(predictions, ground_truths):
        if len(predictions) != len(ground_truths):
            raise ValueError(length_error)
        hits = sum(
            1
            for p, g in zip(predictions, ground_truths)
            if _text(p).split() == _text(g).split()
        )
        return round(hits / len(predictions), 5)

    def edit_similarity_score(predictions, ground_truths):
        if len(predictions) != len(ground_truths):
            raise ValueError(length_error)
        total = 0.0
        for p, g in zip(predictions, ground_truths):
            total += fuzz.ratio(_text(p), _text(g))
        return round(total / len(predictions), 5)

    def calculate_line_scores(test_lines, gen_lines):
        # For each test line keep the best score reached by any generated line.
        es_per_line, em_per_line = [], []
        for test_line in test_lines:
            best_es = best_em = 0
            for gen_line in gen_lines:
                best_es = max(best_es, exact_match_score([test_line], [gen_line]))
                best_em = max(best_em, edit_similarity_score([test_line], [gen_line]))
            es_per_line.append(best_es)
            em_per_line.append(best_em)
        return _mean_or(es_per_line, 0.0), _mean_or(em_per_line, 0.0), es_per_line, em_per_line

    def _as_lines(value):
        # Normalise a field value (list / str / dict / anything else) to stripped strings.
        if isinstance(value, list):
            return [x.strip() for x in value]
        if isinstance(value, str):
            return [value.strip()]
        if isinstance(value, dict):
            return [value.get(x, "").strip() for x in value]
        return [""]

    results = []
    for test_item in gt:
        es_scores = {name: [] for name in score_keys}
        em_scores = {name: [] for name in score_keys}

        for gen_item in pred:
            if "location" in gen_item:
                key = 'content_all'
                test_lines = [line.strip() for line in gen_item["location"].split('\n')]
                gen_lines = [line.strip() for line in gen_item.get(key, "").split('\n')]
                es_score, em_score, _, _ = calculate_line_scores(test_lines, gen_lines)

                # Discard everything gathered so far for this ground-truth item.
                es_scores = {
                    'file': [-1],
                    'function': [-1],
                    'content_all': [es_score],
                    'content_change': [-1],
                }
                em_scores = {
                    'file': [-1],
                    'function': [-1],
                    'content_all': [em_score],
                    'content_change': [-1],
                }

                logging.info(f"Error JSON, so skip this item and {key}'s ES is {es_score}, EM is {em_score}")
                break

            for key in ('file', 'function'):
                expected = [test_item.get(key, "")]
                produced = [gen_item.get(key, "")]
                es_score = exact_match_score(expected, produced)
                em_score = edit_similarity_score(expected, produced)
                es_scores[key].append(es_score)
                em_scores[key].append(em_score)

            for key in ('content_all', 'content_change'):
                test_lines = _as_lines(test_item.get(key, ""))
                gen_lines = _as_lines(gen_item.get(key, ""))
                es_score, em_score, es_line_scores, em_line_scores = calculate_line_scores(test_lines, gen_lines)

                es_scores[key].append(es_score)
                em_scores[key].append(em_score)

                logging.info(f"Test item {key} line scores (ES): {es_line_scores}")
                logging.info(f"Test item {key} line scores (EM): {em_line_scores}")

        result = {
            'avg_es_scores': {key: _mean_or(scores, -1) for key, scores in es_scores.items()},
            'avg_em_scores': {key: _mean_or(scores, -1) for key, scores in em_scores.items()},
        }
        results.append(result)
        logging.info(f"Processed test item results: {result}")

    return results


def generate_code_patch(issue, code, prompts, localization=None):
    """Ask the model for a patch that resolves ``issue`` in ``code``.

    The prompt is the "解决Issue" instruction text followed by the JSON dumps
    of the issue and the code; when ``localization`` is truthy its JSON dump
    is appended as well. The prompt, the raw response and the extracted
    result are all logged.

    Returns:
        Whatever ``extract_diff_code`` extracts from the model response.
    """
    sections = [
        prompts.get_prompt("解决Issue"),
        "\nThe following is the specific Issue information:\n```json\n",
        pretty_json(issue),
        "```\nThe following is the specific Code information:\n```json\n",
        pretty_json(code),
        "```",
    ]
    if localization:
        sections.extend(
            [
                "\nThe following is the specific code localization which may cause the issue information:\n```json\n",
                pretty_json(localization),
                "```",
            ]
        )
    prompt_text = "".join(sections)
    logging.info("Prompt for generate_code_patch:\n%s", prompt_text)

    reply = respond(prompt_text)
    logging.info("Response for generate_code_patch:\n%s", reply)

    extracted = extract_diff_code(reply)
    logging.info("Extracted JSON for generate_code_patch:\n%s", extracted)
    return extracted




def add_json(json_data):
    """Return the 4-space-indented JSON dump of *json_data* wrapped in a ```json fence."""
    dumped = json.dumps(json_data, ensure_ascii=False, indent=4)
    return f"```json{dumped}```"







def convert_sets_to_lists(data):
    """Return *data* with every set reachable through dicts and lists turned into a list.

    Dicts and lists are rebuilt recursively; any other value (including
    tuples) is returned unchanged.
    """
    if isinstance(data, set):
        return list(data)
    if isinstance(data, dict):
        return {name: convert_sets_to_lists(item) for name, item in data.items()}
    if isinstance(data, list):
        return list(map(convert_sets_to_lists, data))
    return data



# Serialises root-logger configuration across worker threads.
_LOG_SETUP_LOCK = threading.Lock()


def _ensure_file_logging(log_path="test_metric.log"):
    """Install one rotating file handler on the root logger, at most once per log file."""
    target = os.path.abspath(log_path)
    root = logging.getLogger()
    with _LOG_SETUP_LOCK:
        for existing in root.handlers:
            if isinstance(existing, RotatingFileHandler) and existing.baseFilename == target:
                return  # already configured by an earlier call
        handler = RotatingFileHandler(log_path, maxBytes=1024 * 1024 * 5, backupCount=5)
        handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
        root.setLevel(logging.INFO)
        # The file handler becomes the only root handler.
        root.handlers = [handler]


def run_test(data, output_dir, file_name):
    """Score the three generated issues in ``data`` and save the result as JSON.

    Nothing is done (apart from making sure logging is configured) when
    ``output_dir/file_name`` already exists; the existence of that file is
    what marks an input as processed.

    Args:
        data: Parsed input record. Reads ``data["Difficulty"]`` and, under
            ``data["Results"]``, the keys "issue_ground_truth", "message",
            "patch_ground_truth", "location_ground_truth", "issue_origin",
            "issue_message" and "issue_ground". A "Score" entry is added to
            it in place.
        output_dir: Directory the result file is written to.
        file_name: Name of the result file inside ``output_dir``.

    Raises:
        KeyError: If one of the keys listed above is missing from ``data``.
    """
    # Configure logging once instead of creating (and leaking) a new file
    # handler on every call from every worker thread.
    _ensure_file_logging()

    output_path = os.path.join(output_dir, file_name)
    if os.path.exists(output_path):
        return

    prompts = PromptManager()

    generated = data['Results']
    issue = generated["issue_ground_truth"]
    message = generated["message"]
    patch = generated["patch_ground_truth"]
    location = generated["location_ground_truth"]

    # "Difficulty" is read before any model call so a malformed record fails fast.
    results = {"Difficulty": data["Difficulty"]}
    for key in ("issue_origin", "issue_message", "issue_ground"):
        # The ground-truth location is passed in the `code` position.
        results[key] = generate_issue_metric(generated[key], issue, patch, location, message, prompts)
    results["issue_ground_truth"] = issue

    data['Score'] = results

    # Write to a temporary file and rename it into place: a crash mid-write
    # must not leave a truncated file that later runs would treat as done.
    tmp_path = output_path + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        os.replace(tmp_path, output_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print("Save to file:", output_path)



def old(model="deepseek-coder", total_lengths=(1500, 3000, 6500, 13000, 30000, 61000, 124000)):
    """Legacy sequential driver: score every file of every context-length bucket.

    For each length in ``total_lengths`` the files in
    ``test_gen_new/<length>_<model>`` are loaded and passed to ``run_test``,
    which writes its result into ``test_result_new/<length>_<model>/``.

    Args:
        model: Model name used in the directory names and the progress label.
        total_lengths: Context-length buckets to process, in order.
    """
    for total_length in total_lengths:
        input_dir = f'test_gen_new/{total_length}_{model}'
        output_dir = f"test_result_new/{total_length}_{model}/"
        os.makedirs(output_dir, exist_ok=True)
        for file in tqdm.tqdm(os.listdir(input_dir), desc=f"Run Metric for {model}"):
            if file == '.DS_Store':
                continue
            file_path = os.path.join(input_dir, file)
            print(file_path)
            with open(file_path, "r", encoding="utf-8-sig") as f:
                json_data = json.load(f)
            # run_test takes (data, output_dir, file_name); the previous
            # `model=` / `idd=` keywords raised TypeError on every call.
            run_test(json_data, output_dir, file)

























                    


                


                



def process_file(input_base_dir, output_base_dir, folder, file):
    """Load one input JSON file and hand it to ``run_test``.

    The file is skipped when it is not a ``.json`` file, or when a file of
    the same name already exists under ``output_base_dir/folder``.
    """
    source_dir = os.path.join(input_base_dir, folder)
    target_dir = os.path.join(output_base_dir, folder)

    if not (file.endswith('.json') and file != '.DS_Store'):
        return

    # An existing result file means this input was already handled.
    if os.path.exists(os.path.join(target_dir, file)):
        return

    file_path = os.path.join(source_dir, file)
    print(f"Processing file: {file_path}")

    with open(file_path, "r", encoding="utf-8-sig") as handle:
        payload = json.load(handle)

    run_test(payload, target_dir, file)

if __name__ == "__main__":
    # Submit every file of every sub-folder of the input directory to a
    # thread pool, creating the matching output sub-folder first.
    input_base_dir = 'test_gen_new'
    output_base_dir = 'test_result_new'

    with ThreadPoolExecutor(max_workers=10) as executor:
        futures = []
        for folder in os.listdir(input_base_dir):
            source_dir = os.path.join(input_base_dir, folder)
            target_dir = os.path.join(output_base_dir, folder)

            if not os.path.isdir(source_dir):
                continue
            if not os.path.exists(target_dir):
                os.makedirs(target_dir)

            futures.extend(
                executor.submit(process_file, input_base_dir, output_base_dir, folder, name)
                for name in os.listdir(source_dir)
            )

        progress = tqdm.tqdm(as_completed(futures), total=len(futures), desc="Processing folders")
        for finished in progress:
            # Re-raises any exception raised inside the worker.
            finished.result()

    print("完成! OVER!")
