import ast
import difflib
from collections import Counter, deque
from typing import Any, Dict, Iterator, List, Tuple

import networkx as nx
import numpy as np

class SemanticCodeAnalyzer:
    """Derive simple static metrics from Python source and combine them into scores.

    Every metric is computed by walking the AST produced by ``ast.parse``; the
    analysed code is never executed.
    """

    # Definition nodes that get their own entry in the per-function metrics.
    _FUNCTION_NODES = (ast.FunctionDef, ast.AsyncFunctionDef)
    # Nodes that each add one decision point to a function's cyclomatic complexity.
    _BRANCH_NODES = (
        ast.If,
        ast.IfExp,
        ast.For,
        ast.AsyncFor,
        ast.While,
        ast.ExceptHandler,
    )

    def __init__(self, code: str):
        """Parse *code* and create the (initially empty) graphs.

        Raises:
            SyntaxError: if *code* is not valid Python source.
        """
        self.tree = ast.parse(code)
        # Caller -> callee edges, filled in by ``_analyze_function_calls``.
        self.call_graph = nx.DiGraph()
        # NOTE(review): not populated by any method of this class.
        self.control_flow_graph = nx.DiGraph()
        # Upper bound applied to every individual score.
        self.max_score = 100

    def analyze(self) -> Dict[str, Any]:
        """Run every analysis pass and attach the derived scores.

        Returns:
            A dict with the raw results under ``function_complexity``,
            ``data_flow``, ``control_flow``, ``function_calls``,
            ``variable_usage`` and ``code_patterns``, plus ``scores``
            (see ``_calculate_scores``).
        """
        analysis = {
            "function_complexity": self._analyze_function_complexity(),
            "data_flow": self._analyze_data_flow(),
            "control_flow": self._analyze_control_flow(),
            "function_calls": self._analyze_function_calls(),
            "variable_usage": self._analyze_variable_usage(),
            "code_patterns": self._identify_code_patterns(),
        }
        analysis["scores"] = self._calculate_scores(analysis)
        return analysis

    @classmethod
    def _iter_function_scope(cls, func: ast.AST) -> Iterator[ast.AST]:
        """Yield the descendants of *func* breadth-first, without entering nested defs.

        A nested ``def``/``async def`` node is yielded itself but its subtree is
        skipped: it gets its own entry in the per-function metrics, so walking
        into it would count its contents twice. Lambdas are walked into.
        """
        pending = deque(ast.iter_child_nodes(func))
        while pending:
            node = pending.popleft()
            yield node
            if not isinstance(node, cls._FUNCTION_NODES):
                pending.extend(ast.iter_child_nodes(node))

    def _analyze_function_complexity(self) -> Dict[str, int]:
        """Map every ``def``/``async def`` name (at any nesting level) to its complexity.

        NOTE(review): functions sharing a name (e.g. methods of different
        classes) share one key and the last one walked wins.
        """
        return {
            node.name: self._calculate_cyclomatic_complexity(node)
            for node in ast.walk(self.tree)
            if isinstance(node, self._FUNCTION_NODES)
        }

    def _calculate_cyclomatic_complexity(self, node: ast.AST) -> int:
        """Return the McCabe-style cyclomatic complexity of a function node.

        Starts at 1 and adds one for every ``if``/``elif``, conditional
        expression, ``for``/``async for``, ``while`` and ``except`` clause, plus
        one per ``and``/``or`` operator. Nested function definitions are
        excluded because they are scored separately.
        """
        complexity = 1
        for child in self._iter_function_scope(node):
            if isinstance(child, self._BRANCH_NODES):
                complexity += 1
            elif isinstance(child, ast.BoolOp):
                # A single BoolOp holds every operand of a chained and/or:
                # ``a and b and c`` has two operators, hence len(values) - 1.
                complexity += len(child.values) - 1
        return complexity

    def _analyze_data_flow(self) -> List[str]:
        """List writes to plain names in walk order as ``Assign to X`` / ``Modify X``.

        Only ``ast.Name`` targets are recorded; attribute, subscript and
        tuple-unpacking targets are skipped.
        """
        data_flow = []
        for node in ast.walk(self.tree):
            if isinstance(node, ast.Assign):
                data_flow.extend(
                    f"Assign to {target.id}"
                    for target in node.targets
                    if isinstance(target, ast.Name)
                )
            elif isinstance(node, ast.AugAssign) and isinstance(node.target, ast.Name):
                data_flow.append(f"Modify {node.target.id}")
        return data_flow

    def _analyze_control_flow(self) -> List[str]:
        """Return one label per ``if``, ``for``/``async for``, ``while`` and ``try`` statement."""
        control_flow = []
        for node in ast.walk(self.tree):
            if isinstance(node, ast.If):
                control_flow.append("If statement")
            elif isinstance(node, (ast.For, ast.AsyncFor)):
                control_flow.append("For loop")
            elif isinstance(node, ast.While):
                control_flow.append("While loop")
            elif isinstance(node, ast.Try):
                control_flow.append("Try-except block")
        return control_flow

    def _analyze_function_calls(self) -> Dict[str, List[str]]:
        """Map each function name to the plain names it calls directly.

        Only calls whose callee is a bare ``ast.Name`` are recorded (method and
        attribute calls are skipped). Each recorded call also adds a
        ``caller -> callee`` edge to ``self.call_graph``. Calls made inside a
        nested function are attributed to that nested function only.
        """
        calls: Dict[str, List[str]] = {}
        for node in ast.walk(self.tree):
            if not isinstance(node, self._FUNCTION_NODES):
                continue
            # setdefault: same-named functions (e.g. ``__init__`` in two
            # classes) accumulate into one list instead of the last one
            # discarding the others' calls while their graph edges remain.
            callees = calls.setdefault(node.name, [])
            for child in self._iter_function_scope(node):
                if isinstance(child, ast.Call) and isinstance(child.func, ast.Name):
                    callees.append(child.func.id)
                    self.call_graph.add_edge(node.name, child.func.id)
        return calls

    def _analyze_variable_usage(self) -> Dict[str, int]:
        """Count occurrences of every ``ast.Name`` identifier (reads and writes alike)."""
        return dict(
            Counter(node.id for node in ast.walk(self.tree) if isinstance(node, ast.Name))
        )

    def _identify_code_patterns(self) -> List[str]:
        """Return one label per list comprehension, lambda and ``with``/``async with``."""
        patterns = []
        for node in ast.walk(self.tree):
            if isinstance(node, ast.ListComp):
                patterns.append("List comprehension")
            elif isinstance(node, ast.Lambda):
                patterns.append("Lambda function")
            elif isinstance(node, (ast.With, ast.AsyncWith)):
                patterns.append("Context manager (with statement)")
        return patterns

    def _calculate_scores(self, analysis: Dict[str, Any]) -> Dict[str, float]:
        """Turn the raw analysis results into scores, each capped at ``max_score``.

        ``overall_score`` is the weighted sum of the six component scores; the
        weights add up to 1.0, so it is also bounded by ``max_score``.
        """
        scores = {}

        # Complexity: mean per-function complexity, 10 maps to the cap.
        complexities = list(analysis["function_complexity"].values())
        avg_complexity = np.mean(complexities) if complexities else 0
        scores["complexity_score"] = min(avg_complexity / 10 * 100, self.max_score)

        scores["data_flow_score"] = min(len(analysis["data_flow"]) * 5, self.max_score)
        scores["control_flow_score"] = min(len(analysis["control_flow"]) * 10, self.max_score)

        total_calls = sum(len(callees) for callees in analysis["function_calls"].values())
        scores["function_call_score"] = min(total_calls * 5, self.max_score)

        total_usage = sum(analysis["variable_usage"].values())
        scores["variable_usage_score"] = min(total_usage * 2, self.max_score)

        scores["code_pattern_score"] = min(len(analysis["code_patterns"]) * 15, self.max_score)

        weights = {
            "complexity_score": 0.2,
            "data_flow_score": 0.15,
            "control_flow_score": 0.2,
            "function_call_score": 0.15,
            "variable_usage_score": 0.15,
            "code_pattern_score": 0.15,
        }
        # The sum is fully evaluated before "overall_score" is inserted, so it
        # only ever sees the six component scores above.
        scores["overall_score"] = sum(score * weights[key] for key, score in scores.items())

        return scores

class CodeScorer:
    """Score Python source on structure, nesting/branching complexity and behaviour.

    Every component score is capped at ``max_score``; ``calculate_score``
    averages the three components.
    """

    # Nodes that open a new nesting level in ``_get_max_depth``.
    _NESTING_NODES = (
        ast.If,
        ast.For,
        ast.AsyncFor,
        ast.While,
        ast.With,
        ast.AsyncWith,
        ast.Try,
        ast.FunctionDef,
        ast.AsyncFunctionDef,
        ast.ClassDef,
    )
    # Nodes counted as branches by ``_calculate_complexity_score``.
    _BRANCH_NODES = (ast.If, ast.For, ast.AsyncFor, ast.While)

    def __init__(self, code: str):
        """Parse *code* and set up the per-node-type structure weights.

        Raises:
            SyntaxError: if *code* is not valid Python source.
        """
        self.tree = ast.parse(code)
        # Points each occurrence of a node type adds to the structure score.
        # Async variants weigh the same as their synchronous counterparts.
        self.node_weights = {
            ast.FunctionDef: 10,
            ast.AsyncFunctionDef: 10,
            ast.ClassDef: 15,
            ast.If: 5,
            ast.For: 7,
            ast.AsyncFor: 7,
            ast.While: 7,
            ast.Try: 6,
            ast.ExceptHandler: 4,
            ast.With: 5,
            ast.AsyncWith: 5,
            ast.Assert: 3,
            ast.Import: 2,
            ast.ImportFrom: 2,
            ast.Assign: 1,
            ast.AugAssign: 1,
            ast.Return: 2,
            ast.Call: 3,
            ast.Lambda: 8,
            ast.ListComp: 6,
            ast.DictComp: 7,
            ast.GeneratorExp: 7,
        }
        self.max_score = 100  # Normalize scores to be between 0 and 100

    def calculate_score(self) -> Dict[str, Any]:
        """Compute every component score for the parsed code.

        Returns:
            A dict with ``total_score`` (mean of the three components, capped
            at ``max_score``), ``structure_score``, ``complexity_score``,
            ``behavior_score``, ``depth`` (deepest nesting level) and
            ``details`` (occurrences per weighted node type name).
        """
        structure_score = self._calculate_structure_score()
        complexity_score = self._calculate_complexity_score()
        behavior_score = self._calculate_behavior_score()
        depth = self._get_max_depth(self.tree)

        total_score = (structure_score + complexity_score + behavior_score) / 3

        return {
            "total_score": min(total_score, self.max_score),
            "structure_score": structure_score,
            "complexity_score": complexity_score,
            "behavior_score": behavior_score,
            "depth": depth,
            "details": self._get_details(),
        }

    def _count_node_types(self) -> Counter:
        """Return how many nodes of each concrete AST type occur in the tree (one walk)."""
        return Counter(type(node) for node in ast.walk(self.tree))

    def _calculate_structure_score(self) -> float:
        """Sum ``node_weights`` over all matching nodes (exact type match), capped."""
        counts = self._count_node_types()
        score = sum(counts[node_type] * weight for node_type, weight in self.node_weights.items())
        return min(score, self.max_score)

    def _calculate_complexity_score(self) -> float:
        """Score ``5 * max nesting depth + 3 * number of loops/ifs``, capped."""
        counts = self._count_node_types()
        num_branches = sum(counts[node_type] for node_type in self._BRANCH_NODES)
        return min(self._get_max_depth(self.tree) * 5 + num_branches * 3, self.max_score)

    def _calculate_behavior_score(self) -> float:
        """Score ``3 * calls + assignments + 2 * returns``, capped."""
        counts = self._count_node_types()
        num_func_calls = counts[ast.Call]
        num_assignments = counts[ast.Assign] + counts[ast.AugAssign]
        num_returns = counts[ast.Return]
        return min(num_func_calls * 3 + num_assignments + num_returns * 2, self.max_score)

    def _get_max_depth(self, node: ast.AST, current_depth: int = 0) -> int:
        """Return the deepest nesting level reached at or below *node*.

        Each ``_NESTING_NODES`` node entered adds one level. Recursive, so the
        depth of the AST is bounded by the interpreter's recursion limit.
        """
        if isinstance(node, self._NESTING_NODES):
            current_depth += 1
        # Every child reports at least ``current_depth``, so it is the leaf default.
        return max(
            (self._get_max_depth(child, current_depth) for child in ast.iter_child_nodes(node)),
            default=current_depth,
        )

    def _get_details(self) -> Dict[str, int]:
        """Map each weighted node type's name to its (non-zero) occurrence count."""
        counts = self._count_node_types()
        return {
            node_type.__name__: counts[node_type]
            for node_type in self.node_weights
            if counts[node_type]
        }
    



class ASTComparer:
    """Measure how different two Python sources are, structurally and textually.

    ``compare`` returns a value in [0, 1]: 0 for identical inputs, larger for
    more different ones.
    """

    def __init__(self, code1: str, code2: str):
        """Parse both sources.

        Raises:
            SyntaxError: if either source is not valid Python.
        """
        self.tree1 = ast.parse(code1)
        self.tree2 = ast.parse(code2)
        # Only these node types take part in the structural comparison; the
        # value is how much a count mismatch of that type weighs.
        self.node_weights = {
            ast.FunctionDef: 10,
            ast.AsyncFunctionDef: 10,
            ast.ClassDef: 10,
            ast.If: 5,
            ast.For: 5,
            ast.AsyncFor: 5,
            ast.While: 5,
            ast.Try: 5,
            ast.ExceptHandler: 3,
            ast.With: 3,
            ast.AsyncWith: 3,
            ast.Assert: 2,
            ast.Import: 2,
            ast.ImportFrom: 2,
            ast.Assign: 1,
            ast.AugAssign: 1,
            ast.Return: 1,
        }

    def compare(self) -> float:
        """Return ``0.6 * structural difference + 0.4 * content difference``."""
        structure_diff = self._compare_structure()
        content_diff = self._compare_content()
        return 0.6 * structure_diff + 0.4 * content_diff

    def _compare_structure(self) -> float:
        """Return the weighted fraction of tracked nodes that do not match up.

        For each tracked node type, the count difference between the trees is
        weighted and summed, then divided by the total weight of all tracked
        nodes in both trees. Returns 0 when neither tree has a tracked node.
        """
        counts1 = Counter(self._get_weighted_node_types(self.tree1))
        counts2 = Counter(self._get_weighted_node_types(self.tree2))

        total_weight = 0
        diff_weight = 0
        # Keys are (node_type, weight) pairs; Counter returns 0 for absent keys.
        for key in counts1.keys() | counts2.keys():
            weight = key[1]
            total_weight += (counts1[key] + counts2[key]) * weight
            diff_weight += abs(counts1[key] - counts2[key]) * weight

        return diff_weight / total_weight if total_weight > 0 else 0

    def _compare_content(self) -> float:
        """Return ``1 - SequenceMatcher.ratio()`` over the two token lists."""
        content1 = self._get_content(self.tree1)
        content2 = self._get_content(self.tree2)

        matcher = difflib.SequenceMatcher(None, content1, content2)
        return 1 - matcher.ratio()

    def _get_weighted_node_types(self, tree: ast.AST) -> List[Tuple[type, int]]:
        """Return ``(node_type, weight)`` for every tracked node in *tree*, in walk order."""
        tracked = tuple(self.node_weights)  # hoisted: isinstance needs a tuple
        return [
            (type(node), self.node_weights.get(type(node), 1))
            for node in ast.walk(tree)
            if isinstance(node, tracked)
        ]

    def _get_content(self, tree: ast.AST) -> List[str]:
        """Return definition headers, names, string and numeric literals, in walk order.

        Boolean, ``None``, bytes and ``...`` constants are skipped.
        """
        content = []
        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef):
                content.append(f"def {node.name}")
            elif isinstance(node, ast.AsyncFunctionDef):
                content.append(f"async def {node.name}")
            elif isinstance(node, ast.ClassDef):
                content.append(f"class {node.name}")
            elif isinstance(node, ast.Name):
                content.append(node.id)
            elif isinstance(node, ast.Constant):
                # ast.Str / ast.Num are deprecated (removed in Python 3.14);
                # these checks reproduce exactly what they used to match.
                value = node.value
                if isinstance(value, str):
                    content.append(value)
                elif isinstance(value, (int, float, complex)) and not isinstance(value, bool):
                    content.append(str(value))
        return content

def ast_difference_metric(code1: str, code2: str) -> float:
    """Return how different *code1* and *code2* are, via ``ASTComparer.compare``.

    Raises:
        SyntaxError: if either source fails to parse.
    """
    return ASTComparer(code1, code2).compare()

def code_behavior_score(code: str) -> float:
    """Return the behaviour component of ``CodeScorer`` for *code*.

    This is the value ``CodeScorer(code).calculate_score()["behavior_score"]``
    would report, computed without also building the structure, complexity,
    depth and details results that would be thrown away.

    Raises:
        SyntaxError: if *code* fails to parse.
    """
    return CodeScorer(code)._calculate_behavior_score()

def semantic_code_analysis(code: str) -> float:
    """Return the weighted ``overall_score`` from ``SemanticCodeAnalyzer`` for *code*.

    Raises:
        SyntaxError: if *code* fails to parse.
    """
    analysis = SemanticCodeAnalyzer(code).analyze()
    return analysis["scores"]["overall_score"]