from radon.metrics import h_visit
from radon.complexity import cc_visit
import warnings
import torch
import textwrap
from huggingface_hub import snapshot_download
from transformers import RobertaTokenizer, RobertaConfig, RobertaModel
import json
import math
import ast
from collections import Counter

# Suppress every warning category for the whole process.
warnings.filterwarnings("ignore")

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Download the whole model snapshot to a local path (readable offline afterwards).
target_cache_dir = "/huggingface"
local_path = snapshot_download(repo_id="microsoft/codebert-base",cache_dir=target_cache_dir)
# 2. Load the tokenizer and model from the local snapshot (no further network access).
tokenizer = RobertaTokenizer.from_pretrained(local_path)
model = RobertaModel.from_pretrained(local_path)
model.to(device)


class HalsteadAnalyzer(ast.NodeVisitor):
    """Collect Halstead operator/operand counts by walking a Python AST.

    Keywords, assignment, arithmetic/boolean/comparison operators and the
    names of called functions are tallied in ``operator_counter``; identifiers
    and constants are tallied in ``operand_counter``. ``analyze`` turns those
    tallies into the usual Halstead figures.
    """

    def __init__(self):
        self.operator_counter = Counter()
        self.operand_counter = Counter()
        # Not read by any method of this class; kept so existing attribute
        # access from outside keeps working.
        self._in_function_call = False

    def visit_FunctionDef(self, node):
        """Count ``def`` as an operator; the function name and its positional parameters as operands."""
        self.operator_counter["def"] += 1
        self.operand_counter[node.name] += 1
        for arg in node.args.args:
            self.operand_counter[arg.arg] += 1
        self.generic_visit(node)

    def visit_ClassDef(self, node):
        """Count ``class`` as an operator and the class name as an operand."""
        self.operator_counter["class"] += 1
        self.operand_counter[node.name] += 1
        self.generic_visit(node)

    def visit_For(self, node):
        """Count ``for`` and visit target, iterable, body and ``else`` body."""
        self.operator_counter["for"] += 1
        self.visit(node.target)
        self.visit(node.iter)
        for stmt in node.body:
            self.visit(stmt)
        for stmt in node.orelse:
            self.visit(stmt)

    def visit_While(self, node):
        self.operator_counter["while"] += 1
        self.generic_visit(node)

    def visit_If(self, node):
        self.operator_counter["if"] += 1
        self.generic_visit(node)

    def visit_With(self, node):
        self.operator_counter["with"] += 1
        self.generic_visit(node)

    def visit_Try(self, node):
        self.operator_counter["try"] += 1
        self.generic_visit(node)

    def visit_ExceptHandler(self, node):
        self.operator_counter["except"] += 1
        self.generic_visit(node)

    def visit_Raise(self, node):
        self.operator_counter["raise"] += 1
        self.generic_visit(node)

    def visit_Assert(self, node):
        self.operator_counter["assert"] += 1
        self.generic_visit(node)

    def visit_Assign(self, node):
        """Count ``=`` once per assignment statement, then visit targets and value."""
        self.operator_counter["="] += 1
        for target in node.targets:
            self.visit(target)
        self.visit(node.value)

    def visit_AugAssign(self, node):
        """Count the augmented operator under its AST class name (e.g. ``Add``)."""
        self.operator_counter[type(node.op).__name__] += 1
        self.visit(node.target)
        self.visit(node.value)

    def visit_Subscript(self, node):
        # Visit both sides of ``value[index]``; the subscript itself is not
        # counted as an operator.
        self.visit(node.value)
        self.visit(node.slice)

    def visit_Index(self, node):
        # Compatibility with the older AST (Python <= 3.8), where the slice
        # is wrapped in an ``Index`` node.
        self.visit(node.value)

    def visit_BinOp(self, node):
        self.operator_counter[type(node.op).__name__] += 1
        self.visit(node.left)
        self.visit(node.right)

    def visit_BoolOp(self, node):
        """Count the boolean operator once per ``BoolOp`` node, however many values it joins."""
        self.operator_counter[type(node.op).__name__] += 1
        for value in node.values:
            self.visit(value)

    def visit_Compare(self, node):
        """Count every comparison operator in a (possibly chained) comparison."""
        for op in node.ops:
            self.operator_counter[type(op).__name__] += 1
        self.visit(node.left)
        for comparator in node.comparators:
            self.visit(comparator)

    def visit_UnaryOp(self, node):
        """Count unary operators, folding ``-<number>`` into one negative operand."""
        operand = node.operand
        is_numeric_literal = isinstance(operand, ast.Constant) and isinstance(
            operand.value, (int, float, complex)
        )
        if isinstance(node.op, ast.USub) and is_numeric_literal:
            # A negative numeric literal is treated as a single operand, so the
            # minus sign is not counted as an operator here.
            self.operand_counter[str(-operand.value)] += 1
            return
        # Every other unary operation (Not, Invert, UAdd, or USub applied to
        # anything that is not a numeric literal) is an operator on its operand.
        # Non-numeric constants take this path because they cannot be negated
        # at analysis time.
        self.operator_counter[type(node.op).__name__] += 1
        self.visit(operand)

    def visit_Call(self, node):
        """Count the called name as an operator and visit the call's arguments."""
        if isinstance(node.func, ast.Attribute):
            self.visit(node.func.value)  # visit the object left of the dot first
            self.operator_counter[node.func.attr] += 1
        elif isinstance(node.func, ast.Name):
            self.operator_counter[node.func.id] += 1
        else:
            # Callee is some other expression (subscript, nested call, lambda,
            # ...): there is no name to count, but its operands and operators
            # must still be visited rather than dropped.
            self.visit(node.func)

        # Only the arguments are traversed here, so the function name is not
        # counted a second time as an operand.
        for arg in node.args:
            self.visit(arg)
        for keyword in node.keywords:
            self.visit(keyword.value)

    def visit_Return(self, node):
        self.operator_counter["return"] += 1
        if node.value:
            self.visit(node.value)

    def visit_Name(self, node):
        self.operand_counter[node.id] += 1

    def visit_Constant(self, node):
        # Constants are keyed by their ``str()`` form.
        self.operand_counter[str(node.value)] += 1

    def analyze(self, code: str):
        """Parse ``code`` and return its Halstead metrics.

        Args:
            code: Python source text; it is passed to ``ast.parse`` as is.

        Returns:
            A dict with the keys ``η1``/``η2`` (distinct operators/operands),
            ``N1``/``N2`` (total operators/operands), ``vocabulary``,
            ``length``, ``volume``, ``difficulty`` and ``effort``.

        Raises:
            SyntaxError: if ``code`` cannot be parsed.
        """
        # Start from empty tallies so that reusing one analyzer for several
        # snippets does not accumulate counts from earlier calls.
        self.operator_counter = Counter()
        self.operand_counter = Counter()

        self.visit(ast.parse(code))

        distinct_operators = len(self.operator_counter)
        distinct_operands = len(self.operand_counter)
        total_operators = sum(self.operator_counter.values())
        total_operands = sum(self.operand_counter.values())

        vocabulary = distinct_operators + distinct_operands
        length = total_operators + total_operands

        # Guard the degenerate cases: log2(0) and division by zero operands.
        volume = length * math.log2(vocabulary) if vocabulary > 0 else 0
        if distinct_operands > 0:
            difficulty = (distinct_operators / 2) * (total_operands / distinct_operands)
        else:
            difficulty = 0
        effort = difficulty * volume

        return {
            "η1": distinct_operators,
            "η2": distinct_operands,
            "N1": total_operators,
            "N2": total_operands,
            "vocabulary": vocabulary,
            "length": length,
            "volume": volume,
            "difficulty": difficulty,
            "effort": effort
        }


def analyze_custom_halstead(code: str):
    """Compute Halstead metrics for ``code`` with ``HalsteadAnalyzer``.

    The snippet is dedented first, matching the radon-based helpers in this
    module, so a uniformly indented snippet is analyzed instead of failing
    with an ``IndentationError``.

    Args:
        code: Python source text.

    Returns:
        The metrics dict from ``HalsteadAnalyzer.analyze``. On any failure
        the error is printed and ``{"volume": -1, "effort": -1}`` is returned
        as a sentinel.
    """
    try:
        analyzer = HalsteadAnalyzer()
        return analyzer.analyze(textwrap.dedent(code))
    except Exception as e:
        # Best-effort boundary: callers rely on the -1 sentinel, not on an exception.
        print("❌ Custom Halstead error:", e)
        return {"volume": -1, "effort": -1}


def compute_codebert_similarity(code1: str, code2: str) -> float:
    """Return the cosine similarity of the two snippets' CodeBERT embeddings.

    Both snippets are tokenized as one padded batch, run through the
    module-level ``model``, and each snippet's token states are averaged into
    a single vector. The average is weighted by the attention mask so that
    padding added to the shorter snippet does not leak into its embedding.

    Args:
        code1: First code snippet.
        code2: Second code snippet.

    Returns:
        Cosine similarity of the two pooled embeddings as a Python float.
    """
    inputs = tokenizer([code1, code2], return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}  # move inputs to the model's device
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state

    # Masked mean over the sequence axis: padded positions get weight 0.
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    token_counts = mask.sum(dim=1).clamp(min=1)  # avoid division by zero
    embeddings = (hidden * mask).sum(dim=1) / token_counts

    return torch.cosine_similarity(embeddings[0:1], embeddings[1:2]).item()

def compute_cyclomatic_complexity(code: str) -> float:
    """Return the mean cyclomatic complexity of the blocks radon finds in ``code``.

    The snippet is dedented before it is handed to ``cc_visit``. Returns 0
    when radon reports no blocks; on any failure the error is printed and -1
    is returned.
    """
    try:
        blocks = cc_visit(textwrap.dedent(code))
        if blocks:
            scores = [entry.complexity for entry in blocks]
            return sum(scores) / len(blocks)
        return 0
    except Exception as e:
        print("❌ Cyclomatic Complexity Error:", e)
        return -1

def analyze_halstead_complexity(code: str):
    """Run radon's ``h_visit`` on the dedented snippet and return its ``total``.

    Returns ``None`` (after printing the error) if anything goes wrong.
    """
    try:
        dedented = textwrap.dedent(code)
        # ``total`` is the whole-snippet report (a HalsteadReport, per the
        # original author's note); it is handed back unchanged.
        return h_visit(dedented).total
    except Exception as e:
        print("❌ Halstead error:", e)
        return None

def is_followup_more_complex(original_code: str, followup_code: str, similarity_threshold=0.95):
    """Decide whether ``followup_code`` is strictly more complex than ``original_code``.

    The follow-up qualifies when its custom Halstead volume and effort are
    both greater than the original's and its cyclomatic complexity is not
    lower. A metric that failed to compute (the -1 sentinel returned by the
    helpers) disqualifies the pair, so a broken original can never make the
    follow-up look "more complex".

    Args:
        original_code: Source of the original snippet.
        followup_code: Source of the follow-up snippet.
        similarity_threshold: Upper bound (exclusive) on similarity for the
            pair to count as non-duplicate.

    Returns:
        ``(keep, metrics)`` where ``keep`` is a bool and ``metrics`` is a dict
        with the Halstead dicts, the volume/effort/cyclomatic values for both
        snippets and the similarity. If an unexpected exception occurs,
        ``keep`` is False and ``metrics`` carries -1 values plus an ``error``
        message.
    """
    try:
        # The custom AST-based Halstead analysis is used here instead of radon's.
        halstead_o = analyze_custom_halstead(original_code)
        halstead_f = analyze_custom_halstead(followup_code)

        volume_o = halstead_o.get("volume", -1)
        volume_f = halstead_f.get("volume", -1)
        effort_o = halstead_o.get("effort", -1)
        effort_f = halstead_f.get("effort", -1)

        cc_o = compute_cyclomatic_complexity(original_code)
        cc_f = compute_cyclomatic_complexity(followup_code)

        # NOTE: the CodeBERT similarity check is currently disabled; with the
        # similarity pinned to 0 the duplicate filter passes for any positive
        # threshold.
        similarity = 0

        # Negative values are error sentinels, not measurements.
        measurements = (volume_o, volume_f, effort_o, effort_f, cc_o, cc_f)
        all_measured = all(value >= 0 for value in measurements)

        is_more_complex = (
            all_measured
            and volume_f > volume_o
            and effort_f > effort_o
            and cc_f >= cc_o
        )
        is_not_duplicate = similarity < similarity_threshold

        return is_more_complex and is_not_duplicate, {
            "halstead_o": halstead_o,
            "halstead_f": halstead_f,
            "volume_o": volume_o,
            "volume_f": volume_f,
            "effort_o": effort_o,
            "effort_f": effort_f,
            "cc_o": cc_o,
            "cc_f": cc_f,
            "similarity": similarity
        }
    except Exception as e:
        # Best-effort boundary: report the failure in the metrics instead of raising.
        return False, {
            "volume_o": -1, "volume_f": -1,
            "effort_o": -1, "effort_f": -1,
            "cc_o": -1, "cc_f": -1,
            "similarity": -1, "error": str(e)
        }

if __name__ == "__main__":
    # Demo: compare an original snippet against a follow-up snippet.
    original_code = """
def execute_command(image):
    image_patch = ImagePatch(image)
    apple_patches = image_patch.find("apple")
    answer = len(apple_patches)
    return answer
"""

    followup_code ="""
def execute_command(image1, image2):
    image_patch1 = ImagePatch(image1)
    image_patch2 = ImagePatch(image2)
    apple_patches = image_patch1.find("apple")
    apple_patches = apple_patches.append(image_patch2.find("apple"))
    red_apple_patches = []
    for apple_patch in apple_patches:
        if apple_patch.verify_property("apple", "red"):
            red_apple_patches.append(apple_patch)
    answer = len(red_apple_patches)
    return answer
"""
    # Decide whether the follow-up is more complex, then print the verdict
    # and the individual metrics (-1 is shown for any missing entry).
    keep, metrics = is_followup_more_complex(original_code, followup_code)
    print("\n🧠 Complexity comparison result:", "✔️ Keep" if keep else "❌ Reject")
    print("📊 Halstead - Original:", json.dumps(metrics.get("halstead_o", -1), indent=4, ensure_ascii=False))
    print("📊 Halstead - Followup:", json.dumps(metrics.get("halstead_f", -1), indent=4, ensure_ascii=False))
    print("📊 Cyclomatic Complexity - Original:", metrics.get("cc_o", -1))
    print("📊 Cyclomatic Complexity - Followup:", metrics.get("cc_f", -1))
    print("🤝 CodeBERT Similarity:", metrics.get("similarity", -1))
