import numpy as np
import time

import torch

from stork.models import RecurrentSpikingModel

import stork.nodes.base
from stork import generators
from stork import loss_stacks
from stork import monitors

import operator
import warnings
from itertools import chain
from typing import Any, Dict, Generic, List, Optional, Sequence, Tuple, TypeVar, Union
from torch._utils import (
    _get_all_device_indices,
    _get_available_device_type,
    _get_device_index,
    _get_devices_properties,
    ExceptionWrapper
)
import torch.nn as nn
from torch.nn.modules import Module
from torch.nn.parallel.replicate import _replicatable_module, _broadcast_coalesced_reshape
from collections import OrderedDict
from torch.nn.parallel.scatter_gather import gather, scatter_kwargs

from torch.cuda._utils import _get_device_index
import threading
from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Union
from typing import Set, Optional
import copy
import inspect

from challenge.custom.readout import CustomReadoutGroup
from collections import defaultdict





import numpy as np
import time

import torch

from stork.models import RecurrentSpikingModel

import stork.nodes.base
from stork import generators
from stork import loss_stacks
from stork import monitors

import operator
import warnings
from itertools import chain
from typing import Any, Dict, Generic, List, Optional, Sequence, Tuple, TypeVar, Union
from torch._utils import (
    _get_all_device_indices,
    _get_available_device_type,
    _get_device_index,
    _get_devices_properties,
    ExceptionWrapper
)
import torch.nn as nn
from torch.nn.modules import Module
from torch.nn.parallel.replicate import _replicatable_module, _broadcast_coalesced_reshape
from collections import OrderedDict
from torch.nn.parallel.scatter_gather import gather, scatter_kwargs

from torch.cuda._utils import _get_device_index
import threading
from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Union
from typing import Set, Optional
import copy
import inspect

from challenge.custom.readout import CustomReadoutGroup
from collections import defaultdict

import torch
import gc  # 导入 gc 模块

def find_shallow_copies(module: torch.nn.Module, max_depth=1000) -> Dict[int, List[str]]:
    """
    Detect objects that are reachable from more than one attribute path of *module*.

    Tracked objects:
      * child modules (every path through which a child is reached is recorded),
      * parameters and buffers, recorded as ``path.name(parameters)`` /
        ``path.name(buffers)``,
      * other attributes listed by ``dir()`` whose value is a list, dict or tensor.

    Immutable scalars (bool/int/float/str) are deliberately ignored: CPython caches
    and interns such values, so unrelated attributes equal to e.g. ``1`` share an
    ``id`` without being a shallow copy.  Names already registered as parameters,
    buffers or children are not recorded a second time through ``dir()``
    (``nn.Module.__dir__`` lists them too).

    Args:
        module: root module to inspect; paths start with ``"module"``.
        max_depth: maximum recursion depth into child modules.

    Returns:
        ``{id(obj): [path1, path2, ...]}`` for every object seen at two or more paths.
    """
    # Modules already walked, to stop on cycles / shared submodules.
    visited_modules = set()
    # id -> every path at which that object was seen.
    object_registry = defaultdict(list)
    # Only these attribute types have a meaningful identity for shallow-copy checks.
    tracked_types = (list, dict, torch.Tensor)
    skipped_names = {"named_parameters", "named_children", "named_buffers"}

    def _track_objects(current_module, current_path="module", depth=0):
        if depth > max_depth or id(current_module) in visited_modules:
            return
        visited_modules.add(id(current_module))

        # Child modules: record the path even when the child was already walked,
        # so a submodule shared between parents lists every parent path.
        for name, child in current_module.named_children():
            child_path = f"{current_path}.{name}"
            object_registry[id(child)].append(child_path)
            _track_objects(child, child_path, depth + 1)

        # Direct (non-recursive) parameters and buffers of this module.
        for param_type in ("named_parameters", "named_buffers"):
            for name, param in getattr(current_module, param_type)(recurse=False):
                param_path = f"{current_path}.{name}({param_type.split('_')[-1]})"
                object_registry[id(param)].append(param_path)

        # Names recorded above; dir() reports them again and would make every
        # parameter look like a shallow copy of itself.
        registered = (
            set(current_module._parameters)
            | set(current_module._buffers)
            | set(current_module._modules)
        )

        # Remaining user-visible attributes.
        for name in dir(current_module):
            if name.startswith("__") or name in skipped_names or name in registered:
                continue
            try:
                attr = getattr(current_module, name)
            except Exception:
                # Best-effort inspection: a property that raises must not abort the scan.
                continue
            if not isinstance(attr, tracked_types):
                continue
            object_registry[id(attr)].append(f"{current_path}.{name}")

    _track_objects(module)
    # Keep only objects referenced from more than one path.
    return {k: v for k, v in object_registry.items() if len(v) > 1}

def print_shallow_copies_report(shared_objects: Dict[int, List[str]]):
    """Print a human-readable report of objects shared inside one module.

    Args:
        shared_objects: mapping ``id -> [attribute paths]``; an empty mapping
            prints a single "nothing found" line.
    """
    if shared_objects:
        print("⚠️ 检测到以下浅拷贝关系:")
        for object_id, owner_paths in shared_objects.items():
            print(f"对象ID {hex(object_id)} 被以下路径共享:")
            for owner_path in owner_paths:
                print(f"  - {owner_path}")
            print("---")
    else:
        print("✅ 未检测到任何浅拷贝关系")

def compare_modules_shallow_copy(module1, module2):
    """Find objects that *module1* and *module2* share by identity at the same path.

    The two object graphs are walked in lockstep starting from ``"root"``:
    child modules, parameters and buffers are paired by name (parameters and
    buffers per module, recursing through children), list/tuple items by
    position, and dict values by keys present in both.  Whenever the paired
    objects are the very same instance, ``(path, hex(id))`` is recorded; the walk
    continues below a shared object, so its shared descendants are listed too.

    Immutable scalars (bool/int/float/complex/str/bytes/None) are skipped:
    CPython caches or interns such values, so equal ids there do not indicate a
    shallow copy.

    Returns:
        List of ``(path, hex_id)`` tuples in discovery order.
    """
    shared_objects = []
    # Keyed by the *pair* of ids: an object already compared against one partner
    # must still be compared when it appears paired with a different object,
    # otherwise genuine sharing found later in the walk is missed.
    visited_pairs = set()
    immutable_scalars = (bool, int, float, complex, str, bytes, type(None))

    def _compare_named(pairs1, pairs2, path):
        # Pair items by name instead of position so differing structures don't mispair.
        lookup = dict(pairs2)
        for name, item1 in pairs1:
            if name in lookup:
                _compare(item1, lookup[name], f"{path}.{name}")

    def _compare(obj1, obj2, path):
        if isinstance(obj1, immutable_scalars) or isinstance(obj2, immutable_scalars):
            return
        pair = (id(obj1), id(obj2))
        if pair in visited_pairs:
            return
        visited_pairs.add(pair)

        if obj1 is obj2:
            shared_objects.append((path, hex(id(obj1))))

        if isinstance(obj1, torch.nn.Module) and isinstance(obj2, torch.nn.Module):
            _compare_named(obj1.named_children(), obj2.named_children(), path)
            # recurse=False: descendants' tensors are reached through the child recursion.
            _compare_named(obj1.named_parameters(recurse=False),
                           obj2.named_parameters(recurse=False), path)
            _compare_named(obj1.named_buffers(recurse=False),
                           obj2.named_buffers(recurse=False), path)
        elif isinstance(obj1, (list, tuple)) and isinstance(obj2, (list, tuple)):
            for i, (item1, item2) in enumerate(zip(obj1, obj2)):
                _compare(item1, item2, f"{path}[{i}]")
        elif isinstance(obj1, dict) and isinstance(obj2, dict):
            for k in obj1:
                if k in obj2:
                    _compare(obj1[k], obj2[k], f"{path}['{k}']")

    _compare(module1, module2, "root")
    return shared_objects

def collect_objects(
        module: torch.nn.Module,
        module_name: str = "module",
        max_depth: int = 1000,
) -> Dict[int, List[str]]:
    """Build an ``id -> [paths]`` registry of the objects reachable from *module*.

    Recorded objects:
      * the module itself and every submodule (each path through which a
        submodule is reached is recorded),
      * parameters and buffers, as ``path.name(parameters)`` / ``path.name(buffers)``,
      * other attributes listed by ``dir()`` whose value is a list, dict or tensor.

    Immutable scalars (bool/int/float/str) are excluded because CPython caches
    and interns them: two unrelated modules would otherwise appear to share
    e.g. the int ``1``.  Names already registered as parameters, buffers or
    children are not recorded a second time through ``dir()``.

    Args:
        module: root module.
        module_name: prefix used for every recorded path.
        max_depth: maximum recursion depth into child modules.

    Returns:
        A ``defaultdict(list)`` mapping object id to the list of its paths.
    """
    visited_modules = set()
    object_registry = defaultdict(list)
    tracked_types = (list, dict, torch.Tensor)
    skipped_names = {"named_parameters", "named_children", "named_buffers"}

    def _traverse(current_module, current_path, depth=0):
        if depth > max_depth:
            return
        # Record the path before the visited check so a submodule reachable
        # through several parents lists every path, not only the first one.
        object_registry[id(current_module)].append(current_path)
        if id(current_module) in visited_modules:
            return
        visited_modules.add(id(current_module))

        for name, child in current_module.named_children():
            _traverse(child, f"{current_path}.{name}", depth + 1)

        for param_type in ("named_parameters", "named_buffers"):
            for name, param in getattr(current_module, param_type)(recurse=False):
                param_path = f"{current_path}.{name}({param_type.split('_')[-1]})"
                object_registry[id(param)].append(param_path)

        # nn.Module.__dir__ also lists these names; recording them again would
        # duplicate every parameter/buffer path.
        registered = (
            set(current_module._parameters)
            | set(current_module._buffers)
            | set(current_module._modules)
        )

        for name in dir(current_module):
            if name.startswith("__") or name in skipped_names or name in registered:
                continue
            try:
                attr = getattr(current_module, name)
            except Exception:
                # Best-effort inspection: a property that raises must not abort the scan.
                continue
            if not isinstance(attr, tracked_types):
                continue
            object_registry[id(attr)].append(f"{current_path}.{name}")

    _traverse(module, module_name)
    return object_registry

def compare_shallow_copies(
        module1: torch.nn.Module,
        module2: torch.nn.Module,
        name1: str = "module1",
        name2: str = "module2"
) -> Dict[int, Tuple[List[str], List[str]]]:
    """Find objects reachable from both modules.

    Builds an ``id -> paths`` registry for each module with ``collect_objects``
    (paths rooted at *name1* / *name2*) and keeps the ids present in both.

    Returns:
        ``{id: (paths_in_module1, paths_in_module2)}`` for every shared object.
    """
    first = collect_objects(module1, name1)
    second = collect_objects(module2, name2)
    common_ids = set(first) & set(second)
    return {oid: (first[oid], second[oid]) for oid in common_ids}

def print_comparison_report(shared_objects: Dict[int, Tuple[List[str], List[str]]]):
    """Print the cross-module sharing report.

    Args:
        shared_objects: ``{id: (paths_in_first, paths_in_second)}``.  The header
            for each side is the root name taken from that side's first path
            (text before the first ``.``).
    """
    if not shared_objects:
        print("✅ 两个模块间未检测到任何浅拷贝关系")
        return

    print("⚠️ 检测到跨模块浅拷贝关系:")
    for object_id, (first_paths, second_paths) in shared_objects.items():
        print(f"对象 {hex(object_id)} 在以下位置共享:")
        for side_paths in (first_paths, second_paths):
            root_name = side_paths[0].partition('.')[0]
            print(f"  {root_name} 中的路径:")
            for side_path in side_paths:
                print(f"    - {side_path}")
        print("---")


def get_model_gpu_memory(model: torch.nn.Module, include_gradients: bool = True) -> Dict[str, float]:
    """Return the CUDA memory (MB per device) held by *model*'s tensors.

    For *model* and every submodule this counts:
      * parameters, plus their ``.grad`` when *include_gradients* is true,
      * registered buffers,
      * CUDA tensors reachable from public (non-underscore), non-callable
        instance attributes, recursing through dicts, lists/tuples/sets and
        other objects' ``__dict__``.

    Every tensor object is counted exactly once, even if it is shared between
    submodules (e.g. weight tying) or referenced from several attributes.
    Distinct tensor objects that share storage (views) are still counted separately.

    Args:
        model: module to measure.
        include_gradients: also count ``param.grad`` tensors.

    Returns:
        Mapping ``str(device)`` -> megabytes rounded to 2 decimals; devices with
        zero usage are omitted and CPU tensors are ignored.
    """
    device_memory = defaultdict(int)
    # ids of every tensor/container already counted or scanned.
    seen: Set[int] = set()

    def count_tensor(tensor: torch.Tensor) -> None:
        """Add *tensor*'s size to its device total once."""
        if id(tensor) in seen:
            return
        seen.add(id(tensor))
        if tensor.is_cuda:
            device_memory[str(tensor.device)] += tensor.numel() * tensor.element_size()

    def scan_tensors(obj: Any) -> None:
        """Recursively count tensors inside containers / plain objects."""
        if isinstance(obj, torch.Tensor):
            count_tensor(obj)
            return
        obj_id = id(obj)
        if obj_id in seen:
            return
        seen.add(obj_id)
        if isinstance(obj, dict):
            for value in obj.values():
                scan_tensors(value)
        elif isinstance(obj, (list, tuple, set, frozenset)):
            for item in obj:
                scan_tensors(item)
        elif hasattr(obj, '__dict__'):
            for attr in vars(obj).values():
                scan_tensors(attr)

    # modules() yields every submodule once (shared submodules are de-duplicated).
    modules = list(model.modules())

    # Registered tensors first: recurse=False avoids re-counting descendants'
    # tensors, which modules() already visits individually.
    for module in modules:
        for param in module.parameters(recurse=False):
            count_tensor(param)
            if include_gradients and param.grad is not None:
                count_tensor(param.grad)
        for buffer in module.buffers(recurse=False):
            count_tensor(buffer)

    # Then custom attributes; aliases of parameters/buffers are already in `seen`.
    for module in modules:
        child_names = set(module._modules)
        managed_attrs = set(module._parameters) | set(module._buffers)
        for name, value in vars(module).items():
            if name in child_names or name in managed_attrs or name.startswith('_') or callable(value):
                continue
            scan_tensors(value)

    # Convert bytes to MB.
    return {
        dev: round(mem_bytes / (1024 ** 2), 2)
        for dev, mem_bytes in device_memory.items()
        if mem_bytes > 0
    }


def get_container_gpu_memory(
        container: Any,
        verbose: bool = False,
        visited: Optional[Set[int]] = None
) -> Dict[str, float]:
    """Sum the CUDA memory of every tensor reachable from *container*.

    Walks dicts, lists/tuples/sets/frozensets and the public (non-underscore)
    attributes of plain objects.  ``nn.Module`` instances are not entered.

    Args:
        container: root object to scan.
        verbose: print one line per CUDA tensor found and per skipped module.
        visited: optional set of ``id()`` values to skip.  It is updated in
            place with every object scanned, so it can be shared across calls
            to avoid double counting.

    Returns:
        Mapping ``str(device)`` -> megabytes rounded to 2 decimals; devices with
        zero usage are omitted.
    """
    seen = set() if visited is None else visited
    bytes_per_device = defaultdict(int)

    def _walk(node: Any, level: int = 0, label: str = "") -> None:
        if id(node) in seen:
            return
        seen.add(id(node))
        pad = "  " * level

        if isinstance(node, torch.Tensor):
            if not node.is_cuda:
                return
            device_name = str(node.device)
            nbytes = node.numel() * node.element_size()
            bytes_per_device[device_name] += nbytes
            if verbose:
                print(f"{pad}[{label}] {tuple(node.shape)} {node.dtype} on {device_name} | {nbytes / 1024 ** 2:.2f} MB")
            return

        # Modules are skipped here on purpose; they are not entered by this scan.
        if isinstance(node, torch.nn.Module):
            if verbose:
                print(f"{pad}[{label}] Skip nn.Module: {type(node).__name__}")
            return

        if isinstance(node, dict):
            children = ((value, f"{label}.{key}" if label else str(key)) for key, value in node.items())
        elif isinstance(node, (list, tuple, set, frozenset)):
            children = ((item, f"{label}[{index}]") for index, item in enumerate(node))
        elif hasattr(node, '__dict__'):
            children = (
                (attr_value, f"{label}.{attr_name}")
                for attr_name, attr_value in vars(node).items()
                if not attr_name.startswith('_')
            )
        else:
            return

        for child, child_label in children:
            _walk(child, level + 1, child_label)

    _walk(container)
    return {
        device_name: round(total / (1024 ** 2), 2)
        for device_name, total in bytes_per_device.items()
        if total > 0
    }


def get_all_gpu_memory(scope: Optional[dict] = None, verbose: bool = False) -> Dict[str, float]:
    """Sum the CUDA memory of every tensor reachable from a variable scope.

    Every variable in *scope* (by default the caller's module globals) is scanned,
    skipping dunder names and imported modules.  The scan recurses into dicts,
    lists/tuples/sets, public (non-underscore) attributes of plain objects, and
    ``nn.Module`` instances.  For a module, each submodule's parameters (and
    their ``.grad``), buffers and public non-callable attributes are counted.

    All counting is done in bytes with a single shared visited set, so a tensor
    reachable both from a global and from a module (or from several modules) is
    counted once.

    Args:
        scope: mapping of variable names to values; defaults to the caller's globals.
        verbose: print one line per CUDA tensor found, with its access path.

    Returns:
        Mapping ``str(device)`` -> megabytes rounded to 2 decimals.
    """
    if scope is None:
        frame = inspect.currentframe()
        caller_frame = frame.f_back if frame is not None else None
        scope = caller_frame.f_globals if caller_frame is not None else globals()
        # Release frame references promptly to avoid reference cycles (see the inspect docs).
        del frame, caller_frame

    total_bytes: Dict[str, int] = defaultdict(int)
    visited: Set[int] = set()  # ids of every tensor/module/container already handled

    def count_tensor(tensor: torch.Tensor, depth: int, path: str) -> None:
        """Add one tensor's size (bytes) to its device total, once."""
        if id(tensor) in visited:
            return
        visited.add(id(tensor))
        if not tensor.is_cuda:
            return
        dev = str(tensor.device)
        tensor_mem = tensor.numel() * tensor.element_size()
        total_bytes[dev] += tensor_mem
        if verbose:
            indent = "  " * depth
            print(f"{indent}[{path}] {tuple(tensor.shape)} {tensor.dtype} on {dev} | {tensor_mem / 1024 ** 2:.2f} MB")

    def process_module(module: torch.nn.Module, depth: int, path: str) -> None:
        """Count the tensors owned by *module* and all of its submodules."""
        for sub_path, sub in module.named_modules(prefix=path):
            if sub is not module:
                # A submodule already handled elsewhere had its whole subtree counted.
                if id(sub) in visited:
                    continue
                visited.add(id(sub))
            for param_path, param in sub.named_parameters(prefix=sub_path, recurse=False):
                count_tensor(param, depth + 1, param_path)
                if param.grad is not None:
                    count_tensor(param.grad, depth + 1, f"{param_path}.grad")
            for buffer_path, buffer in sub.named_buffers(prefix=sub_path, recurse=False):
                count_tensor(buffer, depth + 1, buffer_path)
            for attr_name, attr_value in vars(sub).items():
                if attr_name.startswith('_') or callable(attr_value):
                    continue
                attr_path = f"{sub_path}.{attr_name}" if sub_path else attr_name
                process_container(attr_value, depth + 1, attr_path)

    def process_container(obj: Any, depth: int = 0, path: str = "") -> None:
        """Recursively scan *obj*, dispatching tensors and modules."""
        if isinstance(obj, torch.Tensor):
            count_tensor(obj, depth, path)
            return
        obj_id = id(obj)
        if obj_id in visited:
            return
        visited.add(obj_id)

        if isinstance(obj, torch.nn.Module):
            process_module(obj, depth, path)
            return

        if isinstance(obj, dict):
            for k, v in obj.items():
                process_container(v, depth + 1, f"{path}.{k}" if path else str(k))
        elif isinstance(obj, (list, tuple, set, frozenset)):
            for i, item in enumerate(obj):
                process_container(item, depth + 1, f"{path}[{i}]")
        elif hasattr(obj, '__dict__'):
            for attr_name, attr_value in vars(obj).items():
                if not attr_name.startswith('_'):
                    process_container(attr_value, depth + 1, f"{path}.{attr_name}")

    # Main pass: use the variable name as the root of every reported path.
    for name, var in list(scope.items()):
        if name.startswith('__') or inspect.ismodule(var):
            continue
        process_container(var, 0, name)

    # Convert bytes to MB.
    return {
        dev: round(mem / (1024 ** 2), 2)
        for dev, mem in total_bytes.items()
    }


def get_all_memory():
    """Empty the CUDA caching allocator, then print allocated/reserved GB per GPU.

    Note: ``torch.cuda.empty_cache()`` is called first, so the reserved figure
    printed is the value after unused cached blocks have been released.
    """
    torch.cuda.empty_cache()

    gib = 1024 ** 3
    for gpu in range(torch.cuda.device_count()):
        allocated_bytes = torch.cuda.memory_allocated(device=gpu)
        reserved_bytes = torch.cuda.memory_reserved(device=gpu)

        print(f"GPU {gpu} 显存信息:")
        print(f"  Allocated memory: {allocated_bytes / gib:.2f} GB")
        print(f"  Reserved memory:  {reserved_bytes / gib:.2f} GB")

# def get_all_memory_in_parm():
    # print(f"当前模型:")
    # model1_gpu=get_model_gpu_memory(modules[0])
    # model2_gpu=get_model_gpu_memory(modules[1])
    # print(model1_gpu)
    # print(model2_gpu)
    #
    # aaa = {}
    # aaa[0] = x_batch
    # aaa[1] = y_batch
    # aaa[2] = cur_batch_size
    # aaa[3] = record
    # aaa[4] = results
    # aaa[5] = out_loss_workers
    # aaa[6] = reg_loss_workers
    # aaa[7] = store_states_workers
    # aaa[8] = flat_seq_shape_workers
    # print(f"当前变量:")
    # parm_gpu=get_container_gpu_memory(aaa)
    # print(parm_gpu)
    #
    # total_gpu_mem={}
    # for key in parm_gpu.keys():
    #     if key not in model1_gpu.keys():
    #         total_gpu_mem[key] = parm_gpu[key] + model2_gpu[key]
    #     else:
    #         total_gpu_mem[key] = parm_gpu[key] + model1_gpu[key] + model2_gpu[key]
    # print(f"总计变量:")
    # print(total_gpu_mem)


def build_module_tree(module: nn.Module, parent_name: str = "root") -> Dict[str, Any]:
    """
    Recursively describe *module* and its direct children as a nested dict.

    :param module: module to describe
    :param parent_name: dotted name of *module*; ``"root"`` for the top level.
        Direct children of the root are named without the ``root.`` prefix.
    :return: ``{"name": class name, "full_name": parent_name, "children": [subtrees]}``
    """
    def _child_key(child_name: str) -> str:
        if parent_name == "root":
            return child_name
        return f"{parent_name}.{child_name}"

    subtrees = [
        build_module_tree(child_module, _child_key(child_name))
        for child_name, child_module in module.named_children()
    ]
    return {
        "name": str(module.__class__.__name__),
        "full_name": parent_name,
        "children": subtrees,
    }


def print_module_tree(tree: Dict[str, Any], indent: int = 0) -> None:
    """
    Print a tree from ``build_module_tree`` in pre-order, one node per line.

    :param tree: node dict with ``name``, ``full_name`` and ``children`` keys
    :param indent: nesting level of *tree*; levels above 0 get ``"    " * level + "|-- "``
    """
    # Explicit stack instead of recursion; children are pushed in reverse so
    # they pop (and print) in their original order.
    pending = [(tree, indent)]
    while pending:
        node, level = pending.pop()
        branch = "    " * level + "|-- " if level > 0 else ""
        print(f"{branch}{node['name']} ({node['full_name']})")
        pending.extend((child, level + 1) for child in reversed(node["children"]))


def compare_architecture(model1, model2):
    """Check that two modules have the same class tree.

    Compares module types, the number of direct children and the child names,
    recursing through every child pair.  Each difference found is printed
    (after a type or child-count mismatch the subtree is not descended into).

    Returns:
        True when no difference was found, otherwise False.
    """
    def _walk(left, right, where=""):
        left_type, right_type = type(left), type(right)
        if left_type != right_type:
            print(f"结构不同: {where} 的类型不同 ({left_type.__name__} vs {right_type.__name__})")
            return False

        left_children = list(left.named_children())
        right_children = list(right.named_children())
        if len(left_children) != len(right_children):
            print(f"结构不同: {where} 的子模块数量不同 ({len(left_children)} vs {len(right_children)})")
            return False

        matches = True
        for (left_name, left_child), (right_name, right_child) in zip(left_children, right_children):
            if left_name != right_name:
                print(f"结构不同: 子模块名称不同 ({left_name} vs {right_name})")
                matches = False
            # Recurse first so every subtree is checked even after a mismatch.
            matches = _walk(left_child, right_child, where=f"{where}.{left_name}") and matches
        return matches

    return _walk(model1, model2, where="Root")


def compare_parameter_shapes(model1, model2):
    """Check that two models have the same parameter names and shapes.

    Prints the missing/extra parameter names (if the name sets differ) and
    every same-named parameter whose shape differs.

    Returns:
        True when names and shapes all match, otherwise False.
    """
    params1 = dict(model1.named_parameters())
    params2 = dict(model2.named_parameters())
    names1, names2 = set(params1), set(params2)

    ok = names1 == names2
    if not ok:
        missing_keys = names1 - names2
        extra_keys = names2 - names1
        print(f"参数名称不同: 缺失的键 {missing_keys}, 多余的键 {extra_keys}")

    for name, param in params1.items():
        counterpart = params2.get(name)
        if counterpart is None or param.shape == counterpart.shape:
            continue
        print(f"参数形状不同: {name} ({param.shape} vs {counterpart.shape})")
        ok = False

    return ok


def compare_parameter_values(model1, model2, tol=1e-6, rtol=1e-5):
    """Compare the values of same-named parameters of two models.

    Only names present in both models are compared.  A pair whose shapes
    differ is reported as different; it is not compared with
    ``torch.allclose``, which would either raise or silently broadcast
    (e.g. ``(1, 3)`` vs ``(3,)``).  Before comparing, both tensors are cast to
    their promoted common dtype and the second is moved to the first one's device.

    Args:
        model1, model2: modules whose ``named_parameters()`` are compared.
        tol: absolute tolerance (``atol`` of ``torch.allclose``).
        rtol: relative tolerance (``torch.allclose``'s default, kept for compatibility).

    Returns:
        True when every common parameter matches, otherwise False.
    """
    params1 = dict(model1.named_parameters())
    params2 = dict(model2.named_parameters())

    same = True
    for name, param1 in params1.items():
        param2 = params2.get(name)
        if param2 is None:
            continue
        tensor1 = param1.detach()
        tensor2 = param2.detach()

        if tensor1.shape != tensor2.shape:
            print(f"参数形状不同: {name} ({tensor1.shape} vs {tensor2.shape})")
            same = False
            continue

        # allclose requires matching dtype and device.
        common_dtype = torch.promote_types(tensor1.dtype, tensor2.dtype)
        tensor1 = tensor1.to(dtype=common_dtype)
        tensor2 = tensor2.to(device=tensor1.device, dtype=common_dtype)

        if not torch.allclose(tensor1, tensor2, rtol=rtol, atol=tol):
            max_diff = (tensor1 - tensor2).abs().max().item()
            print(f"参数值不同: {name} (最大差异: {max_diff:.2e})")
            same = False

    return same


# model1 = modules[0]
# model2 = modules[1]
# print("模型结构比较结果:")
# structure_same = compare_architecture(model1, model2)
# print("\n参数形状比较结果:")
# shape_same = compare_parameter_shapes(model1, model2)
# if structure_same and shape_same:
#     print("\n参数值比较结果:")
#     value_same = compare_parameter_values(model1, model2)
# else:
#     print("\n跳过参数值比较（结构或形状不同）")
#     value_same = False
# print("\n=== 最终报告 ===")
# print(f"模型结构是否相同: {'是' if structure_same else '否'}")
# print(f"参数形状是否相同: {'是' if shape_same else '否'}")
# print(f"参数值是否相同: {'是' if value_same else '否'}")


def tensor_detect():
    """Print the CUDA allocator counters and every live CUDA tensor found by the GC.

    Prints ``torch.cuda.memory_allocated()`` / ``memory_reserved()`` (bytes) for
    the current device, then scans ``gc.get_objects()`` and prints each CUDA
    tensor with its size in bytes.

    Note: the tensor line interpolates the tensor itself, so its values are
    formatted for printing; this can be slow and verbose for large tensors.
    """
    print(f"当前显存占用: {torch.cuda.memory_allocated()} bytes")
    print(f"当前显存保留: {torch.cuda.memory_reserved()} bytes")

    for obj in gc.get_objects():
        try:
            is_cuda_tensor = torch.is_tensor(obj) and obj.is_cuda
        except Exception:
            # Some heap objects raise on type/attribute inspection (for example a
            # weakref proxy whose referent is gone raises ReferenceError); skip them
            # instead of aborting the whole scan.
            continue
        if is_cuda_tensor:
            print(f"Tensor: {obj}, Size: {obj.element_size() * obj.nelement()} bytes")


def print_memory_usage():
    """Print allocated/reserved CUDA memory (MB) for every visible GPU.

    Returns:
        A list with one dict per GPU, in device order, holding the keys
        ``"已分配显存"`` (allocated MB) and ``"保留显存"`` (reserved MB).
    """
    print("\n[显存使用摘要 - 全部GPU]")

    mib = 1024 ** 2
    summary = []
    for gpu_index in range(torch.cuda.device_count()):
        # Make gpu_index the current device so the argument-less queries target it.
        with torch.cuda.device(gpu_index):
            allocated = torch.cuda.memory_allocated() / mib
            reserved = torch.cuda.memory_reserved() / mib

            print(f"GPU {gpu_index}:")
            print(f"  ▸ 已分配显存: {allocated:.2f} MB")
            print(f"  ▸ 保留显存: {reserved:.2f} MB")
            print("-" * 30)
            summary.append({"已分配显存": allocated, "保留显存": reserved})
    return summary


# 详细张量占用分析
def print_tensor_details():
    """Print shape, dtype, size (MB) and device of every live CUDA tensor, plus a total.

    Tensors are discovered by scanning ``gc.get_objects()``; only tensors
    tracked by the garbage collector are seen.
    """
    print("\n[各张量占用详情]")
    total_tensor_mem = 0

    for obj in gc.get_objects():
        try:
            is_cuda_tensor = torch.is_tensor(obj) and obj.is_cuda
        except Exception:
            # Some heap objects raise on type/attribute inspection (for example a
            # weakref proxy whose referent is gone raises ReferenceError); skip them
            # instead of aborting the whole scan.
            continue
        if not is_cuda_tensor:
            continue

        # Memory held by this tensor's elements, in MB.
        tensor_mem = obj.element_size() * obj.nelement() / (1024 ** 2)
        total_tensor_mem += tensor_mem

        print(f"Tensor: {obj.shape} {obj.dtype}", end=' | ')
        print(f"Size: {tensor_mem:.2f} MB | Device: {obj.device}")

    print(f"\n总计张量显存: {total_tensor_mem:.2f} MB")
    print("=" * 50)

