from radon.metrics import h_visit
from radon.complexity import cc_visit
import warnings
import torch
import textwrap
from huggingface_hub import snapshot_download
from transformers import RobertaTokenizer, RobertaConfig, RobertaModel

import math
import ast
from collections import Counter

# Silence every Python warning for the whole process.
warnings.filterwarnings("ignore")

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE: everything below runs at import time, so importing this module
# triggers the snapshot download/lookup and loads the model.
# 1. Download the whole model snapshot to a local path (readable offline afterwards).
target_cache_dir = "/huggingface/hub"
local_path = snapshot_download(repo_id="microsoft/codebert-base",cache_dir=target_cache_dir)
# 2. Load the tokenizer and model from that local path (no further network access needed).
tokenizer = RobertaTokenizer.from_pretrained(local_path)
model = RobertaModel.from_pretrained(local_path)
# Move the model onto the device selected above.
model.to(device)


class HalsteadAnalyzer(ast.NodeVisitor):
    """Collect Halstead operator/operand counts by walking a Python AST.

    Operators are: ``def``, ``for``, ``if``, ``while``, ``=``, the class names
    of augmented-assignment, binary and comparison operator nodes, and the
    names of called functions/methods. Operands are: function definition
    names, every ``ast.Name`` and the ``str()`` form of every ``ast.Constant``.

    Each AST node contributes exactly once: operands are only recorded in
    ``visit_Name`` / ``visit_Constant`` (plus the function name in
    ``visit_FunctionDef``), never additionally by the parent visitor.

    The counters live on the instance and are not reset by ``analyze``, so
    use a fresh instance per snippet unless accumulation is wanted.
    """

    def __init__(self):
        self.operator_counter = Counter()
        self.operand_counter = Counter()

    def visit_FunctionDef(self, node):
        self.operator_counter["def"] += 1
        self.operand_counter[node.name] += 1
        self.generic_visit(node)

    def visit_For(self, node):
        self.operator_counter["for"] += 1
        self.generic_visit(node)

    def visit_If(self, node):
        self.operator_counter["if"] += 1
        self.generic_visit(node)

    def visit_While(self, node):
        self.operator_counter["while"] += 1
        self.generic_visit(node)

    def visit_Assign(self, node):
        self.operator_counter["="] += 1
        # Targets and the value are reached through generic_visit, where
        # visit_Name / visit_Constant count them; counting Name targets here
        # as well would record every assigned variable twice.
        self.generic_visit(node)

    def visit_AugAssign(self, node):
        self.operator_counter[type(node.op).__name__] += 1
        self.generic_visit(node)

    def visit_BinOp(self, node):
        self.operator_counter[type(node.op).__name__] += 1
        self.generic_visit(node)

    def visit_Compare(self, node):
        for op in node.ops:
            self.operator_counter[type(op).__name__] += 1
        self.generic_visit(node)

    def visit_Call(self, node):
        # The called function/method name is an operator.
        func = node.func
        if isinstance(func, ast.Attribute):
            self.operator_counter[func.attr] += 1
            # The receiver (``obj`` in ``obj.method()``) still holds operands.
            self.visit(func.value)
        elif isinstance(func, ast.Name):
            # Deliberately not visited as a Name: the callee is an operator,
            # so it must not also be recorded as an operand.
            self.operator_counter[func.id] += 1
        else:
            # Computed callee such as ``f()()`` or ``table[key]()``.
            self.visit(func)

        # Arguments are operands; visit_Name / visit_Constant count each one
        # exactly once, and nested expressions are traversed as usual.
        for arg in node.args:
            self.visit(arg)
        for keyword in node.keywords:
            self.visit(keyword)

    def visit_Name(self, node):
        self.operand_counter[node.id] += 1

    def visit_Constant(self, node):
        self.operand_counter[str(node.value)] += 1

    def analyze(self, code: str):
        """Parse ``code`` and return Halstead metrics for it.

        Args:
            code: Python source accepted by ``ast.parse`` (it must not carry
                a uniform leading indent).

        Returns:
            A dict with the keys ``vocabulary`` (distinct operators plus
            distinct operands), ``length`` (total operator plus operand
            occurrences), ``volume``, ``difficulty`` and ``effort``.

        Raises:
            SyntaxError: if ``code`` cannot be parsed.
        """
        tree = ast.parse(code)
        self.visit(tree)

        distinct_operators = len(self.operator_counter)
        distinct_operands = len(self.operand_counter)
        total_operators = sum(self.operator_counter.values())
        total_operands = sum(self.operand_counter.values())

        vocabulary = distinct_operators + distinct_operands
        length = total_operators + total_operands

        # Guard the degenerate cases: log2(0) is undefined and an operand-free
        # snippet would divide by zero.
        volume = length * math.log2(vocabulary) if vocabulary > 0 else 0
        difficulty = (
            (distinct_operators / 2) * (total_operands / distinct_operands)
            if distinct_operands > 0
            else 0
        )
        effort = difficulty * volume

        return {
            "vocabulary": vocabulary,
            "length": length,
            "volume": volume,
            "difficulty": difficulty,
            "effort": effort,
        }

def analyze_custom_halstead(code: str):
    """Run the custom AST-based Halstead analysis on a code snippet.

    Args:
        code: Python source; a common leading indent (as in an indented
            triple-quoted string) is removed before parsing.

    Returns:
        The metrics dict produced by ``HalsteadAnalyzer.analyze`` on success.
        On any failure the error is printed and the sentinel
        ``{"volume": -1, "effort": -1}`` is returned instead.
    """
    try:
        # ast.parse rejects uniformly indented source with an IndentationError,
        # so strip the common indent first, matching what the radon-based
        # helpers in this file do before analysing a snippet.
        analyzer = HalsteadAnalyzer()
        return analyzer.analyze(textwrap.dedent(code))
    except Exception as e:
        print("❌ Custom Halstead error:", e)
        return {"volume": -1, "effort": -1}


def compute_codebert_similarity(code1: str, code2: str) -> float:
    """Return the cosine similarity between the embeddings of two snippets.

    Both snippets are tokenized together as one padded, truncated batch and
    encoded with the module-level ``tokenizer`` / ``model`` on ``device``.
    Each snippet's embedding is the mean of its last-layer hidden states over
    its real tokens only.

    Args:
        code1: First code snippet.
        code2: Second code snippet.

    Returns:
        The cosine similarity of the two pooled embeddings as a Python float.
    """
    inputs = tokenizer([code1, code2], return_tensors="pt", padding=True, truncation=True)
    # Move every input tensor to the same device as the model.
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state

    # Weight by the attention mask so padding positions do not enter the mean;
    # otherwise the shorter snippet's embedding would depend on how much
    # padding the longer snippet forces onto it.
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    token_counts = mask.sum(dim=1).clamp(min=1)
    embeddings = (hidden * mask).sum(dim=1) / token_counts

    return torch.cosine_similarity(embeddings[0:1], embeddings[1:2]).item()

def compute_cyclomatic_complexity(code: str) -> float:
    """Return the mean cyclomatic complexity of the blocks radon finds.

    The snippet is dedented and handed to ``cc_visit``. The result is the
    average ``complexity`` over the returned blocks, ``0`` when no block is
    returned, or ``-1`` (after printing the error) when anything raises.
    """
    try:
        blocks = cc_visit(textwrap.dedent(code))
        if blocks:
            total = 0
            for block in blocks:
                total += block.complexity
            return total / len(blocks)
        return 0
    except Exception as err:
        print("❌ Cyclomatic Complexity Error:", err)
        return -1

def analyze_halstead_complexity(code: str):
    """Return radon's Halstead report for a snippet.

    The snippet is dedented and passed to ``h_visit``; the ``total``
    attribute of the result is returned unchanged. If anything raises, the
    error is printed and ``None`` is returned.
    """
    try:
        report = h_visit(textwrap.dedent(code))
        # Hand back the report's ``total`` attribute as-is.
        return report.total
    except Exception as err:
        print("❌ Halstead error:", err)
        return None

def is_followup_more_complex(original_code: str, followup_code: str, similarity_threshold=0.95):
    """Decide whether a follow-up snippet is more complex than the original.

    The follow-up is accepted when all of the following hold:
    every complexity metric was computed successfully for both snippets,
    its custom Halstead volume and effort are strictly greater, its mean
    cyclomatic complexity is not lower, and its CodeBERT similarity to the
    original is below ``similarity_threshold``.

    Args:
        original_code: Source of the original snippet.
        followup_code: Source of the follow-up snippet.
        similarity_threshold: Similarity at or above which the follow-up is
            treated as a duplicate.

    Returns:
        A ``(keep, metrics)`` tuple. On the normal path ``metrics`` holds
        ``halstead_o``, ``halstead_f``, ``cc_o``, ``cc_f`` and ``similarity``.
        If an exception escapes, ``keep`` is ``False`` and ``metrics`` holds
        -1 placeholders plus an ``error`` message.
    """
    try:
        # Use the custom Halstead analysis rather than radon's.
        halstead_o = analyze_custom_halstead(original_code)
        halstead_f = analyze_custom_halstead(followup_code)

        volume_o = halstead_o.get("volume", -1)
        volume_f = halstead_f.get("volume", -1)
        effort_o = halstead_o.get("effort", -1)
        effort_f = halstead_f.get("effort", -1)

        cc_o = compute_cyclomatic_complexity(original_code)
        cc_f = compute_cyclomatic_complexity(followup_code)

        similarity = compute_codebert_similarity(original_code, followup_code)

        # The metric helpers signal failure with -1. Without this guard a
        # failed analysis of the original (-1) would make any successfully
        # analysed follow-up look "more complex".
        metrics_valid = all(
            value >= 0
            for value in (volume_o, volume_f, effort_o, effort_f, cc_o, cc_f)
        )
        is_more_complex = (
            metrics_valid
            and volume_f > volume_o
            and effort_f > effort_o
            and cc_f >= cc_o
        )
        is_not_duplicate = similarity < similarity_threshold

        return is_more_complex and is_not_duplicate, {
            "halstead_o": halstead_o,
            "halstead_f": halstead_f,
            "cc_o": cc_o,
            "cc_f": cc_f,
            "similarity": similarity
        }
    except Exception as e:
        return False, {
            "volume_o": -1, "volume_f": -1,
            "effort_o": -1, "effort_f": -1,
            "cc_o": -1, "cc_f": -1,
            "similarity": -1, "error": str(e)
        }

if __name__ == "__main__":
    # Demo: compare an original snippet with a follow-up snippet.
    original_code = """
    def execute_command(image):
        total_objects = 0
        for img in image:
            patch = ImagePatch(img)
            table_patches = patch.find("table")
            if table_patches:
                table_patch = table_patches[0]
                objects_on_table = table_patch.find("object")
                total_objects += len(objects_on_table)
                print("Intermediate variables--Found objects on table:", len(objects_on_table))
        return total_objects
    """

    followup_code = """
    def execute_command(image):
        largest_object = None
        largest_area = 0
        for img in image:
            patch = ImagePatch(img)
            table_patches = patch.find("table")
            if table_patches:
                table_patch = table_patches[0]
                objects_on_table = table_patch.find("object")
                for obj in objects_on_table:
                    obj_area = obj.width * obj.height
                    print("Intermediate variables--Object area:", obj_area)
                    if obj_area > largest_area:
                        largest_area = obj_area
                        largest_object = obj
        return largest_object
    """

    # Decide whether the follow-up should be kept, then print the verdict
    # followed by one line per metric (-1 when a metric is missing).
    keep, metrics = is_followup_more_complex(original_code, followup_code)
    verdict = "✔️ Keep" if keep else "❌ Reject"
    print("\n🧠 Complexity comparison result:", verdict)

    report_rows = (
        ("📊 Halstead - Original:", "halstead_o"),
        ("📊 Halstead - Followup:", "halstead_f"),
        ("📊 Cyclomatic Complexity - Original:", "cc_o"),
        ("📊 Cyclomatic Complexity - Followup:", "cc_f"),
        ("🤝 CodeBERT Similarity:", "similarity"),
    )
    for label, key in report_rows:
        print(label, metrics.get(key, -1))
