import os
import re
import json
import copy
import tempfile
import subprocess
import concurrent.futures
from tqdm import tqdm

class Evaluator:
    """Score candidate code snippets.

    Scores combine weighted test results, an optional learned pass-rate
    predictor, and static-analysis metrics (pylint rating and radon
    maintainability index).
    """

    def __init__(self, pass_rate_predictor=None):
        # Optional predictor. When provided, it must expose a ``model``
        # attribute and a ``predict_score(list_of_code_strings)`` method.
        self.pass_rate_predictor = pass_rate_predictor

    def calculate_pass_rate_score(self, test_results, test_weights):
        """Return the weighted fraction of executed tests that passed.

        Args:
            test_results: mapping ``test_id -> result dict``; a result counts
                as passed when its ``"success"`` entry is truthy.
            test_weights: mapping ``test_id -> weight``. Only tests that also
                appear in ``test_results`` contribute to the denominator.

        Returns:
            ``passed_weight / total_weight``, or ``0.0`` when no weighted test
            was executed.
        """
        # Membership test directly on the mapping (O(1) for a dict) instead of
        # materialising its keys into a list for every lookup.
        total_weight = sum(
            weight for test_id, weight in test_weights.items() if test_id in test_results
        )
        if total_weight == 0:
            return 0.0

        passed_weight = sum(
            weight
            for test_id, weight in test_weights.items()
            if test_results.get(test_id, {}).get("success", False)
        )
        return passed_weight / total_weight

    def calculate_pass_rate(self, test_results):
        """Return the unweighted fraction of tests in ``test_results`` that passed.

        Returns ``0.0`` for an empty ``test_results``.
        """
        test_number = len(test_results)
        if test_number == 0:
            return 0.0

        passed_number = sum(
            1 for result in test_results.values() if result.get("success", False)
        )
        return passed_number / test_number

    def calculate_batch_scores(self, code_data, use_irl=False):
        """Compute a combined score plus a per-component breakdown for each snippet.

        Args:
            code_data: mapping ``code_id -> entry``. Each entry provides
                ``"code"`` (source string), ``"test_results"`` and
                ``"test_weights"``.
            use_irl: when true, also mix in scores from
                ``self.irl_adjust_scores`` with weights
                0.6 / 0.25 / 0.05 / 0.05 / 0.05 (pass-rate score, IRL,
                prediction, pylint, radon). Otherwise the weights are
                0.7 / 0.1 / 0.1 / 0.1 (pass-rate score, prediction, pylint,
                radon).

        Returns:
            ``(final_scores, full_score_dict)``, both keyed by ``code_id``.
        """
        code_ids = list(code_data)
        code_strs = [code_data[code_id]["code"] for code_id in code_ids]
        full_score_dict = {}

        # Weighted pass-rate score (cheap; no need to parallelise).
        pass_rate_scores = {
            code_id: self.calculate_pass_rate_score(entry["test_results"], entry["test_weights"])
            for code_id, entry in code_data.items()
        }

        # Plain pass rate (cheap; no need to parallelise).
        pass_rate = {
            code_id: self.calculate_pass_rate(entry["test_results"])
            for code_id, entry in code_data.items()
        }

        # Batch-predict scores; defaults to 0.0 when no usable predictor exists.
        prediction_scores = [0.0] * len(code_strs)
        if (
            code_strs
            and self.pass_rate_predictor is not None
            and self.pass_rate_predictor.model is not None
        ):
            try:
                prediction_scores = self.pass_rate_predictor.predict_score(code_strs)
                print("###############################################################")
                print(f"Prediction scores: {prediction_scores}")
                print("###############################################################")
            except Exception:
                # Dump the offending batch for debugging, then propagate.
                print(code_strs)
                raise

        # Static analysis launches external processes, so run it in threads.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            static_scores = list(tqdm(
                executor.map(self._compute_static_scores, code_strs),
                total=len(code_strs),
                desc="Analyzing codes"
            ))

        # Combine the components into the final scores.
        final_scores = {}
        irl_scores = None
        if use_irl:
            # NOTE(review): irl_adjust_scores is not defined in the visible part
            # of this file, and it receives final_scores while that dict is
            # still empty — confirm this is the intended argument.
            irl_scores = self.irl_adjust_scores(final_scores)

        for i, code_id in enumerate(code_ids):
            # NOTE(review): the static tools return -1 on failure, and that
            # sentinel is weighted into the final score as-is.
            pylint_score, radon_score = static_scores[i]
            breakdown = {
                "pass_rate": pass_rate[code_id],
                "pass_rate_score": pass_rate_scores[code_id],
                "prediction_score": prediction_scores[i],
                "pylint_score": pylint_score,
                "radon_score": radon_score,
            }
            if use_irl:
                breakdown["irl_score"] = irl_scores[i]
                final_scores[code_id] = (
                    0.6 * pass_rate_scores[code_id] +
                    0.25 * irl_scores[i] +
                    0.05 * prediction_scores[i] +
                    0.05 * pylint_score +
                    0.05 * radon_score
                )
            else:
                final_scores[code_id] = (
                    0.7 * pass_rate_scores[code_id] +
                    0.1 * prediction_scores[i] +
                    0.1 * pylint_score +
                    0.1 * radon_score
                )
            full_score_dict[code_id] = breakdown

        return final_scores, full_score_dict

    def _compute_static_scores(self, code_str):
        """Return ``(pylint_score, radon_score)`` for one code string."""
        return (
            self.pylint_code_score(code_str),
            self.radon_mi_code_score(code_str)
        )

    @staticmethod
    def _run_tool_on_code(code, command):
        """Write ``code`` to a temporary ``.py`` file and run ``command`` on it.

        The file path is appended as the last argument of ``command``. The
        temporary file is always deleted, even if writing it or launching the
        tool raises.

        Returns:
            The ``subprocess.CompletedProcess`` with text-mode stdout/stderr.
        """
        # UTF-8 is the default Python source encoding, so write the file that
        # way regardless of the platform locale.
        tmp = tempfile.NamedTemporaryFile(
            mode='w', suffix='.py', delete=False, encoding='utf-8'
        )
        try:
            with tmp:
                tmp.write(code)
            return subprocess.run(
                [*command, tmp.name],
                capture_output=True,
                text=True,
                check=False
            )
        finally:
            os.unlink(tmp.name)

    def pylint_code_score(self, code):
        """Return the pylint rating of ``code`` scaled to ``[0, 1]``, or -1 on failure.

        Returns -1 when pylint cannot be run or no ``rated at X/10`` line is
        found in its output.
        """
        try:
            result = self._run_tool_on_code(code, ["pylint", "--output-format=text"])
            match = re.search(r"rated at (\d+\.?\d*)/10", result.stdout)
            return float(match.group(1)) / 10 if match else -1
        except Exception as e:
            print(f"Pylint analysis failed: {e}")
            return -1

    def radon_mi_code_score(self, code):
        """Return radon's maintainability index of ``code`` divided by 100, or -1 on failure.

        Returns -1 when radon cannot be run, its JSON output cannot be parsed,
        or the output lacks an ``"mi"`` value.
        """
        try:
            result = self._run_tool_on_code(code, ["radon", "mi", "--json"])
            data = json.loads(result.stdout)
            if data and isinstance(data, dict):
                # Only one file was analysed, so take the first (only) entry.
                return next(iter(data.values()))["mi"] / 100
            return -1
        except Exception as e:
            print(f"Radon analysis failed: {e}")
            return -1