import os
import json
from enum import Enum, auto
from typing import Optional, List, Dict, Any
from collections import defaultdict

# --- CodeLocation and ComponentType classes (unchanged from the previous version) ---

class ComponentType(Enum):
    """Kinds of code component that a location can point at."""

    # Values are spelled out; they are the same 1..4 sequence auto() yields.
    GLOBAL = 1
    FUNCTION = 2
    CLASS = 3
    METHOD = 4

class CodeLocation:
    """
    Uniform, hashable description of where a code component lives.

    A location is a file path plus a component type and, depending on the
    type, a class name and/or a member name:

    - GLOBAL:   file path only.
    - FUNCTION: ``member_name`` (the function name), no ``class_name``.
    - CLASS:    ``class_name`` only.
    - METHOD:   both ``class_name`` and ``member_name``.

    Instances are validated on construction, compare and hash by value, and
    can be converted to and from JSON-friendly dicts with ``to_dict`` and
    ``from_dict``.
    """

    def __init__(self,
                 file_path: str,
                 component_type: ComponentType,
                 *,
                 class_name: Optional[str] = None,
                 member_name: Optional[str] = None):
        """
        Build and validate a location.

        Prefer the ``for_global`` / ``for_function`` / ``for_class`` /
        ``for_method`` factories, which supply the right fields for each type.

        Raises:
            ValueError: if ``file_path`` is empty or the name fields do not
                match ``component_type``.
            TypeError: if ``component_type`` is not a ``ComponentType`` member.
        """
        if not file_path:
            raise ValueError("file_path 不能为空")

        self.file_path = file_path
        self.component_type = component_type
        self.class_name = class_name
        self.member_name = member_name
        self._validate()

    def _validate(self):
        """
        Check that ``class_name`` / ``member_name`` fit ``component_type``.

        Raises:
            TypeError: if ``component_type`` is not a ``ComponentType`` member.
            ValueError: if a required name is missing or a forbidden one is set.
        """
        kind = self.component_type
        # Without this check a stray value (e.g. the string "GLOBAL") would
        # match none of the branches below and yield an unusable object that
        # only fails later, in __str__.
        if not isinstance(kind, ComponentType):
            raise TypeError(
                f"component_type 必须是 ComponentType 成员, 而不是 {kind!r}"
            )

        if kind == ComponentType.GLOBAL:
            if self.class_name or self.member_name:
                raise ValueError("全局（GLOBAL）类型的成分不能有 class_name 或 member_name")
        elif kind == ComponentType.FUNCTION:
            if not self.member_name or self.class_name:
                raise ValueError("函数（FUNCTION）类型的成分必须有 member_name 且不能有 class_name")
        elif kind == ComponentType.CLASS:
            if not self.class_name or self.member_name:
                raise ValueError("类（CLASS）类型的成分必须有 class_name 且不能有 member_name")
        elif kind == ComponentType.METHOD:
            if not self.class_name or not self.member_name:
                raise ValueError("方法（METHOD）类型的成分必须同时拥有 class_name 和 member_name")

    # --- Factory methods ---
    @classmethod
    def for_global(cls, file_path: str) -> 'CodeLocation':
        """Return the location of the file-level (global) scope of ``file_path``."""
        return cls(file_path, ComponentType.GLOBAL)

    @classmethod
    def for_function(cls, file_path: str, function_name: str) -> 'CodeLocation':
        """Return the location of the function ``function_name`` in ``file_path``."""
        return cls(file_path, ComponentType.FUNCTION, member_name=function_name)

    @classmethod
    def for_class(cls, file_path: str, class_name: str) -> 'CodeLocation':
        """Return the location of the class ``class_name`` in ``file_path``."""
        return cls(file_path, ComponentType.CLASS, class_name=class_name)

    @classmethod
    def for_method(cls, file_path: str, class_name: str, method_name: str) -> 'CodeLocation':
        """Return the location of ``class_name.method_name`` in ``file_path``."""
        return cls(file_path, ComponentType.METHOD, class_name=class_name, member_name=method_name)

    def get_parent(self) -> 'CodeLocation':
        """
        Return the location one level above this one.

        - For a METHOD the parent is its class (CLASS).
        - For GLOBAL, FUNCTION and CLASS the parent is the file's global
          scope (GLOBAL); a GLOBAL location is therefore its own parent.

        :return: a new CodeLocation instance describing the parent.
        """
        if self.component_type == ComponentType.METHOD:
            return CodeLocation.for_class(self.file_path, self.class_name)
        return CodeLocation.for_global(self.file_path)

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this location into a JSON-serialisable dict.

        ``component_type`` is stored by enum member name; ``class_name`` and
        ``member_name`` are included only when set.
        """
        data = {
            "component_type": self.component_type.name,
            "file_path": self.file_path,
        }
        if self.class_name:
            data["class_name"] = self.class_name
        if self.member_name:
            data["member_name"] = self.member_name
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CodeLocation':
        """
        Rebuild a location from a dict in the format produced by ``to_dict``.

        Raises:
            ValueError: if ``component_type`` is missing or not a valid member
                name, or if the remaining fields fail constructor validation.
        """
        component_type_str = data.get("component_type")
        if not component_type_str:
            raise ValueError("数据字典中缺少 'component_type' 键")

        try:
            component_type = ComponentType[component_type_str]
        except KeyError as err:
            # Chain explicitly so the traceback shows the lookup that failed.
            raise ValueError(f"无效的 component_type 值: '{component_type_str}'") from err

        file_path = data.get("file_path")
        class_name = data.get("class_name")
        member_name = data.get("member_name")

        # Go through the factories so the result is validated, and so fields
        # irrelevant to the type are ignored rather than rejected.
        if component_type == ComponentType.GLOBAL:
            return cls.for_global(file_path)
        if component_type == ComponentType.FUNCTION:
            return cls.for_function(file_path, member_name)
        if component_type == ComponentType.CLASS:
            return cls.for_class(file_path, class_name)
        if component_type == ComponentType.METHOD:
            return cls.for_method(file_path, class_name, member_name)

        # Only reachable if ComponentType gains a member not handled above.
        raise RuntimeError("反序列化过程中遇到未知的组件类型")

    def _normalized_path(self) -> str:
        """Return ``file_path`` with the OS path separator replaced by '/'."""
        return self.file_path.replace(os.path.sep, '/')

    def _key(self):
        """Return the tuple of fields that defines equality and hashing."""
        return (self.file_path, self.component_type, self.class_name, self.member_name)

    def get_output_format(self) -> str:
        """
        Return the two-line textual form: the normalised path, a newline, then
        ``global``, ``function: <name>``, ``class: <name>`` or
        ``function: <class>.<method>``.

        Raises:
            TypeError: if ``component_type`` is not a known member (previously
                this case fell through and returned ``None``).
        """
        path = self._normalized_path()
        kind = self.component_type
        if kind == ComponentType.GLOBAL:
            return f"{path}\nglobal"
        if kind == ComponentType.FUNCTION:
            return f"{path}\nfunction: {self.member_name}"
        if kind == ComponentType.CLASS:
            return f"{path}\nclass: {self.class_name}"
        if kind == ComponentType.METHOD:
            return f"{path}\nfunction: {self.class_name}.{self.member_name}"
        raise TypeError("未知的 CodeLocation 类型")

    # --- Special methods ---
    def __str__(self) -> str:
        """Return ``path``, ``path::name`` or ``path::Class.method``."""
        path = self._normalized_path()
        kind = self.component_type
        if kind == ComponentType.GLOBAL:
            return path
        if kind == ComponentType.FUNCTION:
            return f"{path}::{self.member_name}"
        if kind == ComponentType.CLASS:
            return f"{path}::{self.class_name}"
        if kind == ComponentType.METHOD:
            return f"{path}::{self.class_name}.{self.member_name}"
        raise TypeError("未知的 CodeLocation 类型")

    def __repr__(self) -> str:
        """Return a debugging representation listing only the fields that are set."""
        parts = [f"file_path='{self.file_path}'", f"type={self.component_type.name}"]
        if self.class_name:
            parts.append(f"class_name='{self.class_name}'")
        if self.member_name:
            parts.append(f"member_name='{self.member_name}'")
        return f"CodeLocation({', '.join(parts)})"

    def __eq__(self, other):
        """Compare by file path, component type, class name and member name."""
        if not isinstance(other, CodeLocation):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        """Hash the same fields that ``__eq__`` compares."""
        return hash(self._key())

class CodeLocationGroup:
    """
    A de-duplicated, sorted collection of CodeLocation objects.

    On construction duplicates are removed, and any METHOD whose class (same
    file path and class name) is also present as a CLASS entry is dropped,
    since the class entry already covers it. The ``in`` operator applies the
    same rule: a method counts as contained when its class is in the group.
    """

    # Defining __eq__ without __hash__ already makes instances unhashable;
    # state it explicitly.
    __hash__ = None

    def __init__(self, locations: List[CodeLocation]):
        """Store the merged form of ``locations`` (see ``_merge_locations``)."""
        self._locations = self._merge_locations(locations)
        # Set mirror of the list, used for fast membership and equality checks.
        self._locations_set = set(self._locations)

    def _merge_locations(self, locations: List[CodeLocation]) -> List[CodeLocation]:
        """
        Return ``locations`` de-duplicated, sorted and with covered methods removed.

        Sorting is by ``str(loc)``; the component type name breaks ties so that
        a class and a function with the same name in the same file always come
        out in the same order, independent of set iteration order.
        """
        unique_locations = sorted(
            set(locations),
            key=lambda loc: (str(loc), loc.component_type.name),
        )
        covered_classes = {
            (loc.file_path, loc.class_name)
            for loc in unique_locations
            if loc.component_type == ComponentType.CLASS
        }
        return [
            loc
            for loc in unique_locations
            if not (
                loc.component_type == ComponentType.METHOD
                and (loc.file_path, loc.class_name) in covered_classes
            )
        ]

    @classmethod
    def from_file_location_string_map(cls, string_map: Dict[str, List[str]]) -> 'CodeLocationGroup':
        """
        Build a group from a mapping of file path to descriptive strings.

        Recognised entries are ``"global"``, ``"class: <Name>"`` and
        ``"function: <name>"``; a function name containing ``'.'`` is treated
        as a method, split at the last dot into class name and method name.
        Entries without a ``':'`` and entries with any other prefix are
        skipped silently.

        Args:
            string_map (Dict[str, List[str]]):
                e.g. {"file.py": ["global", "class: MyClass", "function: MyClass.my_method"]}

        Returns:
            CodeLocationGroup: a new group holding the parsed locations.

        Raises:
            ValueError: propagated from CodeLocation when a recognised entry
                has an empty name (e.g. ``"class:"``).
        """
        locations: List[CodeLocation] = []
        for file_path, components in string_map.items():
            for raw in components:
                component_str = raw.strip()
                if component_str == "global":
                    locations.append(CodeLocation.for_global(file_path))
                    continue

                prefix, sep, name = component_str.partition(':')
                if not sep:
                    # Not "<type>: <name>"; best-effort parsing skips it.
                    continue

                comp_type, comp_name = prefix.strip(), name.strip()
                if comp_type == "class":
                    locations.append(CodeLocation.for_class(file_path, comp_name))
                elif comp_type == "function":
                    # A dotted name denotes a method, otherwise a top-level function.
                    if '.' in comp_name:
                        class_name, method_name = comp_name.rsplit('.', 1)
                        locations.append(CodeLocation.for_method(file_path, class_name, method_name))
                    else:
                        locations.append(CodeLocation.for_function(file_path, comp_name))

        return cls(locations)

    def to_file_location_string_map(self) -> Dict[str, List[str]]:
        """
        Convert the group back into a mapping of file path to descriptive strings.

        This is the inverse of ``from_file_location_string_map`` and walks the
        merged, sorted locations.

        Returns:
            Dict[str, List[str]]: file path -> list of component descriptions.
        """
        string_map = defaultdict(list)
        for loc in self._locations:
            kind = loc.component_type
            if kind == ComponentType.GLOBAL:
                component_str = "global"
            elif kind == ComponentType.CLASS:
                component_str = f"class: {loc.class_name}"
            elif kind == ComponentType.FUNCTION:
                component_str = f"function: {loc.member_name}"
            elif kind == ComponentType.METHOD:
                component_str = f"function: {loc.class_name}.{loc.member_name}"
            else:
                # Unknown type: nothing to emit.
                continue
            string_map[loc.file_path].append(component_str)

        return dict(string_map)

    def group_by_file(self) -> Dict[str, List[CodeLocation]]:
        """Return the locations grouped by ``file_path``, keeping group order."""
        grouped = defaultdict(list)
        for loc in self._locations:
            grouped[loc.file_path].append(loc)
        return dict(grouped)

    def save(self, filepath: str):
        """Write the group to ``filepath`` as UTF-8 JSON: file path -> list of location dicts."""
        data_to_save = {
            file_path: [loc.to_dict() for loc in locs]
            for file_path, locs in self.group_by_file().items()
        }
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data_to_save, f, indent=4, ensure_ascii=False)

    @classmethod
    def load(cls, filepath: str) -> 'CodeLocationGroup':
        """Read a group from a JSON file in the format written by ``save``."""
        with open(filepath, 'r', encoding='utf-8') as f:
            grouped_data = json.load(f)
        reconstructed = [
            CodeLocation.from_dict(loc_dict)
            for loc_dicts in grouped_data.values()
            for loc_dict in loc_dicts
        ]
        return cls(reconstructed)

    def union(self, other: 'CodeLocationGroup') -> 'CodeLocationGroup':
        """
        Return a new group holding the locations of both groups (re-merged).

        Raises:
            TypeError: if ``other`` is not a CodeLocationGroup. (Returning
                ``NotImplemented`` is reserved for the operator methods.)
        """
        if not isinstance(other, CodeLocationGroup):
            raise TypeError(
                f"union() 的参数必须是 CodeLocationGroup, 实际类型为 {type(other).__name__}"
            )
        return CodeLocationGroup(self._locations + other._locations)

    def intersection(self, other: 'CodeLocationGroup') -> 'CodeLocationGroup':
        """
        Return a new group holding the locations present in both groups.

        NOTE(review): matching is exact. A method in one group and its class
        in the other do not intersect, unlike the ``in`` operator; confirm
        whether that is intended.

        Raises:
            TypeError: if ``other`` is not a CodeLocationGroup.
        """
        if not isinstance(other, CodeLocationGroup):
            raise TypeError(
                f"intersection() 的参数必须是 CodeLocationGroup, 实际类型为 {type(other).__name__}"
            )
        return CodeLocationGroup(list(self._locations_set & other._locations_set))

    @property
    def locations(self) -> List[CodeLocation]:
        """
        The merged, sorted locations.

        A copy is returned so that callers cannot mutate the internal list and
        leave it out of sync with the membership set.
        """
        return list(self._locations)

    def __add__(self, other: 'CodeLocationGroup') -> 'CodeLocationGroup':
        """``a + b`` is the union of the two groups."""
        if not isinstance(other, CodeLocationGroup):
            return NotImplemented
        return self.union(other)

    def __or__(self, other: 'CodeLocationGroup') -> 'CodeLocationGroup':
        """``a | b`` is the union of the two groups."""
        if not isinstance(other, CodeLocationGroup):
            return NotImplemented
        return self.union(other)

    def __and__(self, other: 'CodeLocationGroup') -> 'CodeLocationGroup':
        """``a & b`` is the intersection of the two groups."""
        if not isinstance(other, CodeLocationGroup):
            return NotImplemented
        return self.intersection(other)

    def __len__(self) -> int:
        """Return the number of locations after merging."""
        return len(self._locations)

    def __iter__(self):
        """Iterate over the merged locations in sorted order."""
        return iter(self._locations)

    def __eq__(self, other):
        """Two groups are equal when they hold the same set of merged locations."""
        if not isinstance(other, CodeLocationGroup):
            return NotImplemented
        return self._locations_set == other._locations_set

    def __repr__(self) -> str:
        """Return a short summary with the number of locations."""
        return f"<CodeLocationGroup with {len(self)} locations>"

    def __str__(self):
        """Return one ``str(location)`` per line."""
        return '\n'.join(str(loc) for loc in self._locations)

    def __contains__(self, location: object) -> bool:
        """
        Decide whether ``location`` belongs to this group.

        Rules:
        1. If ``location`` itself is in the group, return True.
        2. If ``location`` is a method and its class is in the group, also
           return True.

        Anything that is not a CodeLocation is never contained.
        """
        if not isinstance(location, CodeLocation):
            return False

        # Rule 1: direct match against the precomputed set.
        if location in self._locations_set:
            return True

        # Rule 2: a method is covered by its enclosing class.
        if location.component_type == ComponentType.METHOD:
            parent_class_location = CodeLocation.for_class(
                location.file_path,
                location.class_name,
            )
            return parent_class_location in self._locations_set

        return False
    

# --- Usage example ---
if __name__ == '__main__':
    # Demo: build a group, round-trip it through a JSON file, then exercise
    # the 'in' operator.

    # 1. Create the original CodeLocationGroup instance.
    original_locations = [
        CodeLocation.for_class('app/main.py', 'WebApp'),
        CodeLocation.for_method('app/main.py', 'WebApp', 'run'),  # merged into the class entry
        CodeLocation.for_function('app/utils.py', 'calculate_sum'),
        CodeLocation.for_global('app/config.py'),
        CodeLocation.for_method('app/db.py', 'Database', 'query'),  # kept: its class is not in the group
    ]
    original_group = CodeLocationGroup(original_locations)

    print("--- 原始的 Group ---")
    for loc in original_group:
        print(f"  - {loc}")

    file_to_save = 'locations.json'
    try:
        # 2. Save the group to a file.
        original_group.save(file_to_save)
        print(f"\nGroup 已保存到 '{file_to_save}'")

        # Print the JSON file content to show the format.
        with open(file_to_save, 'r', encoding='utf-8') as f:
            print("\n--- JSON 文件内容 ---")
            print(f.read())

        # 3. Load a new group instance from the file.
        loaded_group = CodeLocationGroup.load(file_to_save)
        print("\n--- 从文件加载的 Group ---")
        for loc in loaded_group:
            print(f"  - {loc}")

        # 4. Verify that the original and the loaded group are equal.
        print(f"\n原始 Group 和加载的 Group 是否相等? {original_group == loaded_group}")
    finally:
        # 5. Remove the file even when a step above failed, so the demo
        #    never leaves it behind.
        if os.path.exists(file_to_save):
            os.remove(file_to_save)

    # Test 1: a method whose class is in the group (expected: True).
    method_in_class = CodeLocation.for_method('app/main.py', 'WebApp', 'start_server')
    print(f"检查 '{method_in_class}' 是否在组里: {method_in_class in original_group}")

    # Test 2: a function that is directly in the group (expected: True).
    func_in_group = CodeLocation.for_function('app/utils.py', 'calculate_sum')
    print(f"检查 '{func_in_group}' 是否在组里: {func_in_group in original_group}")

    # Test 3: a class that is not in the group at all (expected: False).
    absent_class = CodeLocation.for_class('app/main.py', 'NonExistentClass')
    print(f"检查 '{absent_class}' 是否在组里: {absent_class in original_group}")

    # Test 4: a method whose class is not in the group either (expected: False).
    absent_method = CodeLocation.for_method('app/services.py', 'CacheService', 'get_item')
    print(f"检查 '{absent_method}' 是否在组里: {absent_method in original_group}")
