import torch

import operator
import warnings
from torch._utils import (
    _get_device_index,
    _get_devices_properties,
    ExceptionWrapper
)

import torch.nn as nn
from torch.nn.modules import Linear, Dropout
from torch.nn.modules import Module
from torch.nn.parallel.replicate import _replicatable_module, _broadcast_coalesced_reshape
from collections import OrderedDict
from torch.nn.parallel.scatter_gather import scatter_kwargs

from torch.cuda._utils import _get_device_index
import threading
from typing import Any, cast, Dict, List, Sequence, Tuple, Union, TypeVar, Optional
import copy

from stork.nodes import InputGroup

T = TypeVar("T", bound=Module)

def data_scatter(
    x_batch,
    y_batch,
    device_ids: Sequence[Union[int, torch.device]],
) -> Any:
    """Split a batch and its targets along dim 0 and move each chunk to a device.

    The batch is cut into ``len(device_ids)`` consecutive chunks. When the
    batch length is not divisible by the number of devices, the first
    ``len(x_batch) % len(device_ids)`` chunks receive one extra sample, so a
    chunk may be empty when there are fewer samples than devices.

    Args:
        x_batch: Input batch; must support ``len()`` and ``split(sizes, dim=0)``.
        y_batch: Targets with the same length as ``x_batch``.
        device_ids: Target devices, one per chunk, in chunk order.

    Returns:
        A tuple ``(x_batches, target_batches, split_sizes)`` where the first
        two are lists of per-device chunks and ``split_sizes`` lists the
        number of samples placed on each device.

    Raises:
        AssertionError: If ``x_batch`` and ``y_batch`` differ in length.
        ValueError: If ``device_ids`` is empty.
    """
    assert len(x_batch)==len(y_batch), "len(x_batch) != len(y_batch)"
    num_devices = len(device_ids)
    if num_devices == 0:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError("device_ids must contain at least one device")

    # Chunk sizes: distribute the remainder over the leading devices.
    base_size, remainder = divmod(len(x_batch), num_devices)
    split_sizes = [
        base_size + 1 if i < remainder else base_size
        for i in range(num_devices)
    ]

    # Split inputs and targets identically and move each chunk to its device.
    x_batches = [
        chunk.to(device)
        for chunk, device in zip(x_batch.split(split_sizes, dim=0), device_ids)
    ]
    target_batches = [
        chunk.to(device)
        for chunk, device in zip(y_batch.split(split_sizes, dim=0), device_ids)
    ]

    return x_batches, target_batches, split_sizes

def scatter(
    x_batch,
    y_batch,
    cur_batch_size,
    record,
    device_ids: Sequence[Union[int, torch.device]],
    dim,
) -> Any:
    """Scatter inputs and targets across devices with ``scatter_kwargs``.

    Args:
        x_batch: Inputs to scatter along ``dim``.
        y_batch: Targets to scatter along ``dim``.
        cur_batch_size: Ignored on input; the returned value is recomputed
            from the scattered input shards.
        record: Value handed unchanged (same object) to every device.
        device_ids: Target devices.
        dim: Dimension along which the batches are scattered.

    Returns:
        ``(x_shards, y_shards, shard_sizes, records)``: the per-device input
        and target shards, the ``size(0)`` of each input shard, and a list
        holding ``record`` once per input shard.
    """
    # Each batch is wrapped in a 1-tuple of positional args for scatter_kwargs.
    x_per_device, _ = scatter_kwargs((x_batch,), None, device_ids, dim=dim)
    y_per_device, _ = scatter_kwargs((y_batch,), None, device_ids, dim=dim)

    # Unwrap the single positional argument of every device.
    x_shards = [args[0] for args in x_per_device]
    y_shards = [args[0] for args in y_per_device]

    # Actual batch size on each device: dimension 0 of every input shard.
    shard_sizes = [shard.size(0) for shard in x_shards]

    # Every device receives a reference to the same record object.
    records = [record] * len(x_shards)

    return x_shards, y_shards, shard_sizes, records

def _check_balance(device_ids: Sequence[Union[int, torch.device]]) -> None:
    """Warn when the given GPUs differ too much in memory or core count.

    Total memory is compared first, then the multiprocessor count. A warning
    is issued for the first property whose minimum across devices is below
    75% of its maximum; at most one warning is emitted per call.

    Args:
        device_ids: Devices to compare.
    """
    imbalance_warn = """
    There is an imbalance between your GPUs. You may want to exclude GPU {} which
    has less than 75% of the memory or cores of GPU {}. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable."""
    device_ids = [_get_device_index(x, True) for x in device_ids]
    dev_props = _get_devices_properties(device_ids)

    def warn_imbalance(get_prop):
        # Returns True when a warning was emitted for this property.
        values = [get_prop(props) for props in dev_props]
        positions = range(len(values))
        weakest = min(positions, key=values.__getitem__)
        strongest = max(positions, key=values.__getitem__)
        if values[weakest] / values[strongest] < 0.75:
            warnings.warn(
                imbalance_warn.format(device_ids[weakest], device_ids[strongest])
            )
            return True
        return False

    # Stop at the first property that triggers a warning.
    for getter in (
        operator.attrgetter("total_memory"),
        operator.attrgetter("multi_processor_count"),
    ):
        if warn_imbalance(getter):
            return

def get_a_var(
    obj: Union[torch.Tensor, List[Any], Tuple[Any, ...], Dict[Any, Any]],
) -> Optional[torch.Tensor]:
    """Return the first tensor found in a nested list/tuple/dict structure.

    The search is depth-first and in iteration order. Dictionaries are
    searched through their ``(key, value)`` item tuples.

    Args:
        obj: A tensor, or an arbitrarily nested list/tuple/dict.

    Returns:
        The first ``torch.Tensor`` encountered, or ``None`` if there is none.
    """
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, (list, tuple)):
        candidates = obj
    elif isinstance(obj, dict):
        candidates = obj.items()
    else:
        return None

    for candidate in candidates:
        found = get_a_var(candidate)
        if isinstance(found, torch.Tensor):
            return found
    return None


def parallel_apply(
    modules: Sequence[Module],
    x_batch,
    cur_batch_size,
    record,
    devices: Optional[Sequence[Optional[Union[int, torch.device]]]] = None,
):
    """Run every replica on its own input shard, one thread per replica.

    For replica ``i`` the worker calls ``module.run(x_batch[i],
    cur_batch_size[i], record[i])``, stores the result of
    ``module.compute_regularizer_losses()`` in ``module.reg_loss`` and calls
    ``module.state_sequence_toTensor()``, all inside the replica's CUDA
    device, stream and autocast context.

    Args:
        modules: Replicas to run; at least two are required, and each must
            expose ``groups``, ``run``, ``compute_regularizer_losses`` and
            ``state_sequence_toTensor``.
        x_batch: Per-replica inputs, one entry per module.
        cur_batch_size: Per-replica batch sizes, or ``None``.
        record: Per-replica record arguments (iterated in lockstep with
            ``modules``).
        devices: Per-replica devices, or ``None`` to resolve the default
            device index for every replica.

    Returns:
        List with the ``run`` output of every replica, in replica order.

    Raises:
        AssertionError: On inconsistent argument lengths or fewer than two
            modules.
        Exception: Any exception raised in a worker is re-raised here via
            ``ExceptionWrapper.reraise()``.
    """
    num_replicas = len(modules)
    assert num_replicas > 1, \
        "At least two modules are required for DataParallel training."
    assert num_replicas == len(x_batch), \
        f"The number of modules {len(modules)} is not equal to the number of inputs {len(x_batch)}"
    assert len(modules[0].groups) == len(modules[1].groups), \
        "The number of groups in the modules is not equal"

    if cur_batch_size is None:
        # NOTE(review): the placeholder is an empty dict per replica, which
        # looks inherited from a kwargs default -- confirm `run` accepts it.
        cur_batch_size = (cast(Dict[str, Any], {}),) * num_replicas
    else:
        assert num_replicas == len(cur_batch_size)

    if devices is None:
        devices = [None] * num_replicas
    else:
        assert num_replicas == len(devices)

    device_indices = [_get_device_index(dev, True) for dev in devices]
    streams = [torch.cuda.current_stream(idx) for idx in device_indices]
    lock = threading.Lock()
    results = {}
    # Captured here so each worker thread can mirror the caller's settings.
    grad_enabled = torch.is_grad_enabled()
    autocast_enabled = torch.is_autocast_enabled()

    def _worker_of_parallel_apply(
        i: int,
        module: Module,
        x_batch,
        cur_batch_size,
        record,
        device: Optional[Union[int, torch.device]] = None,
        stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        torch.set_grad_enabled(grad_enabled)
        assert device is not None, \
            "Device cannot be None. Please ensure the device is properly initialized."
        if stream is None:
            stream = torch.cuda.current_stream(device)
        try:
            with torch.cuda.device(device):
                with torch.cuda.stream(stream):
                    with torch.amp.autocast("cuda", enabled=autocast_enabled):
                        output = module.run(x_batch, cur_batch_size, record)
                        module.reg_loss = module.compute_regularizer_losses()
                        module.state_sequence_toTensor()

            with lock:
                results[i] = output

        except Exception:
            # Keep the exception so the calling thread can re-raise it.
            with lock:
                results[i] = ExceptionWrapper(
                    where=f"in replica {i} on device {device}"
                )

    threads = []
    jobs = zip(modules, x_batch, cur_batch_size, record, device_indices, streams)
    for idx, (replica, shard, shard_size, shard_record, dev, stream) in enumerate(jobs):
        threads.append(
            threading.Thread(
                target=_worker_of_parallel_apply,
                args=(idx, replica, shard, shard_size, shard_record, dev, stream),
            )
        )

    for worker in threads:
        worker.start()
    for worker in threads:
        worker.join()

    # Collect in replica order, re-raising the first captured failure.
    outputs = []
    for idx in range(len(x_batch)):
        result = results[idx]
        if isinstance(result, ExceptionWrapper):
            result.reraise()
        outputs.append(result)
    return outputs

def get_replicas(
    model: T,
    batchsize,
    detach: bool = False,
) -> List[T]:
    """Replicate ``model`` onto its configured devices and reset each replica.

    Args:
        model: Network to replicate; its ``multiGPU['device_ids']`` entry
            lists the target devices.
        batchsize: Per-device batch sizes, forwarded to ``replicate``.
        detach: Forwarded to ``replicate``.

    Returns:
        One replica per configured device, each with ``reset_states()``
        already called.
    """
    replicas = replicate(model, model.multiGPU['device_ids'], batchsize, detach)
    expected = len(model.multiGPU['device_ids'])
    assert len(replicas) == expected, "len(replicas) != len(self.device_ids)"

    for clone in replicas:
        clone.reset_states()
    return replicas


def replicate(
    network: T,
    devices: Sequence[Union[int, torch.device]],
    batchsize,
    detach: bool = False,
) -> List[T]:
    """Create one replica of ``network`` per device.

    Parameters and buffers are broadcast to all devices with
    ``_broadcast_coalesced_reshape``. The module tree of every replica comes
    from ``deep_copy_non_module_parameters(network)``; its parameters are
    stripped and replaced by the broadcast copies, which are attached as
    plain attributes. Modules that have a ``device`` / ``batch_size``
    attribute get it set to the replica's device index / ``batchsize[j]``.

    Args:
        network: Module to replicate.
        devices: Target devices; an empty sequence yields ``[]``.
        batchsize: Indexable by replica position; ``batchsize[j]`` is written
            to every module of replica ``j`` that has a ``batch_size``
            attribute.
        detach: Passed to ``_broadcast_coalesced_reshape``; when true,
            buffers that require grad are treated like those that do not.

    Returns:
        The root module of every replica, in device order.

    Raises:
        RuntimeError: If ``_replicatable_module(network)`` is false.
    """
    if not _replicatable_module(network):
        raise RuntimeError(
            "Cannot replicate network where python modules are "
            "childrens of ScriptModule"
        )

    if not devices:
        return []

    devices = [_get_device_index(x, True) for x in devices]
    num_replicas = len(devices)

    # Broadcast the parameters.
    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)
    # Make the first device's parameters independent copies as well. The
    # broadcast result can be empty (e.g. nothing to broadcast), so guard the
    # index the same way the buffer branches below do.
    # NOTE(review): requires_grad_() marks every copy on the first device as
    # requiring grad, regardless of the source tensor -- confirm intended.
    if len(param_copies) and len(param_copies[0]) > 0:
        param_copies[0] = [p.clone().detach().requires_grad_() for p in param_copies[0]]

    # Broadcast the buffers, split by whether gradients must flow through them.
    buffers = list(network.buffers())
    buffers_rg: list[torch.Tensor] = []
    buffers_not_rg: list[torch.Tensor] = []
    for buf in buffers:
        if buf.requires_grad and not detach:
            buffers_rg.append(buf)
        else:
            buffers_not_rg.append(buf)

    buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)}
    buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)}

    buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach)
    buffer_copies_not_rg = _broadcast_coalesced_reshape(
        buffers_not_rg, devices, detach=True
    )
    if len(buffer_copies_rg) and len(buffer_copies_rg[0]) > 0:
        buffer_copies_rg[0] = [b.clone().detach().requires_grad_() for b in buffer_copies_rg[0]]
    if len(buffer_copies_not_rg) and len(buffer_copies_not_rg[0]) > 0:
        buffer_copies_not_rg[0] = [b.clone().detach() for b in buffer_copies_not_rg[0]]

    # Deep-copy the module tree once per replica and strip its parameters.
    modules = list(network.modules())
    module_copies: List[List[Module]] = [[] for _ in devices]

    for j in range(num_replicas):
        network_deepcopy = deep_copy_non_module_parameters(network)
        module_copies[j] = list(network_deepcopy.modules())

        for replica in module_copies[j]:
            replica._parameters = {}

            # This is a temporary fix for DDP. DDP needs to access the
            # replicated model parameters. It used to do so through
            # `mode.parameters()`. The fix added in #33907 for DP stops the
            # `parameters()` API from exposing the replicated parameters.
            # Hence, we add a `_former_parameters` dict here to support DDP.
            replica._former_parameters = OrderedDict()

    # Point per-replica bookkeeping attributes at the replica's own device
    # and batch size.
    for j in range(num_replicas):
        for module in module_copies[j]:
            if hasattr(module, "device"):
                setattr(module, 'device', devices[j])
            if hasattr(module, "batch_size"):
                setattr(module, 'batch_size', batchsize[j])

    # Attach the broadcast parameters and buffers to the replicas. The i-th
    # module of the original and of every deep copy are assumed to correspond
    # (both lists come from `.modules()` on structurally identical trees).
    for i, module in enumerate(modules):
        for j in range(num_replicas):
            replica = module_copies[j][i]
            for key, param in module._parameters.items():
                if param is None:
                    replica._parameters[key] = None
                else:
                    param_idx = param_indices[param]
                    param_copy = param_copies[j][param_idx]
                    # parameters in replicas are no longer leaves,
                    # so setattr them as non-parameter attributes
                    setattr(replica, key, param_copy)
                    # expose the parameter for DDP
                    replica._former_parameters[key] = param_copy

            for key, buf in module._buffers.items():  # type: ignore[assignment]
                if buf is None:
                    replica._buffers[key] = None
                else:
                    if buf.requires_grad and not detach:
                        buffer_copies = buffer_copies_rg
                        buffer_idx = buffer_indices_rg[buf]
                    else:
                        buffer_copies = buffer_copies_not_rg
                        buffer_idx = buffer_indices_not_rg[buf]
                    replica._buffers[key] = buffer_copies[j][buffer_idx]

    return [cast(T, module_copies[j][0]) for j in range(num_replicas)]

def deep_copy_non_module_parameters(module):
    """Return a deep copy of ``module`` after preparing it for copying.

    The module is first passed through
    ``prepare_deep_copy_non_module_parameters`` (which modifies it in place)
    and the result is then duplicated with ``copy.deepcopy``.
    """
    prepared = prepare_deep_copy_non_module_parameters(module)
    return copy.deepcopy(prepared)

def prepare_deep_copy_non_module_parameters(module):
    """Replace plain tensor attributes of ``module`` with detached clones, in place.

    Child modules are processed recursively, except ``InputGroup``,
    ``Dropout`` and ``Linear`` children, which are left untouched. On each
    module, every public, non-callable instance attribute that is not a
    child, parameter or buffer is inspected:

    * a tensor is replaced by a detached clone with the same
      ``requires_grad`` flag;
    * a dict has its tensor values, and tensors inside its list values,
      replaced the same way (one level deep only).

    Attributes whose name contains ``scheduler``, ``optimizer``,
    ``loss_stack``, ``data_generator``, ``regularizers`` or ``constraints``
    are skipped.

    Returns:
        The same ``module`` object that was passed in.
    """
    skipped_name_parts = (
        'scheduler',
        'optimizer',
        'loss_stack',
        'data_generator',
        'regularizers',
        'constraints',
    )

    def _detached_copy(tensor):
        # Fresh tensor with the same data and requires_grad flag.
        duplicate = tensor.clone().detach()
        duplicate.requires_grad_(tensor.requires_grad)
        return duplicate

    # Recurse into child modules first.
    for child_name, child in module.named_children():
        if isinstance(child, (InputGroup, Dropout, Linear)):
            continue
        setattr(module, child_name, prepare_deep_copy_non_module_parameters(child))

    # Then handle non-standard attributes (user-defined plain tensors/objects).
    attributes = vars(module)
    child_names = {child_name for child_name, _ in module.named_children()}
    managed_params = set(module._parameters.keys()) | set(module._buffers.keys())
    for attr_name in list(attributes.keys()):
        # Children, parameters and buffers are handled elsewhere.
        if attr_name in child_names or attr_name in managed_params:
            continue

        attr_value = attributes[attr_name]

        # Skip private attributes, callables and training infrastructure.
        if attr_name.startswith('_') or callable(attr_value):
            continue
        if any(part in attr_name for part in skipped_name_parts):
            continue

        if isinstance(attr_value, torch.Tensor):
            setattr(module, attr_name, _detached_copy(attr_value))
        elif isinstance(attr_value, dict):
            for key, entry in attr_value.items():
                if isinstance(entry, list):
                    for pos, item in enumerate(entry):
                        if isinstance(item, torch.Tensor):
                            entry[pos] = _detached_copy(item)
                elif isinstance(entry, torch.Tensor):
                    attr_value[key] = _detached_copy(entry)
    return module

# def prepare_deep_copy_non_module_parameters(module):
#
#     return module.prepare_for_deepcopy()
