from datetime import timedelta
from functools import partial

from torch.distributed.fsdp.wrap import (
    _or_policy,
    lambda_auto_wrap_policy,
)

import os
import gc
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullStateDictConfig, FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy, StateDictType
from torch.distributed.fsdp.api import CPUOffload
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
import torch.nn as nn

from diffusers.models.transformers.transformer_flux import FluxTransformerBlock, FluxTransformer2DModel, FluxSingleTransformerBlock, \
        AdaLayerNormContinuous, FSDPWrappedLinear, CombinedTimestepGuidanceTextProjEmbeddings, FluxPosEmbed

@torch.no_grad()
def reduce_dict(log_dict: dict, device: torch.device):
    """Average the values of ``log_dict`` over all ranks and return a new dict.

    Supported value kinds:
        - ``torch.Tensor`` of any shape: collapsed to its mean first, then
          averaged across ranks; returned as a Python float.
        - ``int`` / ``float`` / ``bool`` scalars: averaged across ranks and cast
          back to the original kind (``int`` is rounded; ``bool`` becomes True
          when the cross-rank mean is non-zero, i.e. when any rank had True).
        - anything else (str, None, ...): returned unchanged, not communicated.

    Without an initialised process group, tensors are still collapsed to their
    mean (so both paths return the same value kinds) and all other values are
    passed through as-is.

    All reducible values are packed into one buffer for a single ``all_reduce``,
    so every rank must call this with the same reducible keys in the same order.

    Args:
        log_dict: metrics to reduce.
        device: device of the communication buffer (presumably must be a CUDA
            device when the backend is NCCL — confirm against the caller).

    Returns:
        A new dict with the same keys, in the same order, as ``log_dict``.
    """
    if not (dist.is_available() and dist.is_initialized()):
        # Single process: nothing to communicate, but still take the mean so a
        # multi-element tensor yields a float here exactly as it does in the
        # multi-rank path (a bare .item() raises for such tensors).
        return {
            k: (v.detach().float().mean().item() if isinstance(v, torch.Tensor) else v)
            for k, v in log_dict.items()
        }

    reduced = {}
    keys, kinds, parts = [], [], []
    for k, v in log_dict.items():
        if isinstance(v, torch.Tensor):
            kinds.append("tensor")
            parts.append(v.detach().float().mean().view(1).to(device))
        elif isinstance(v, (int, float, bool)):
            # bool must be tested before int because bool subclasses int.
            if isinstance(v, bool):
                kinds.append("bool")
            elif isinstance(v, int):
                kinds.append("int")
            else:
                kinds.append("float")
            # Explicit float32 so torch.cat never mixes dtypes with the tensor entries.
            parts.append(torch.tensor([float(v)], dtype=torch.float32, device=device))
        else:
            # Strings / None / other non-reducible values pass through untouched.
            reduced[k] = v
            continue
        keys.append(k)

    if parts:
        # One packed buffer -> one all_reduce for every key.
        stacked = torch.cat(parts)
        dist.all_reduce(stacked, op=dist.ReduceOp.SUM)
        stacked /= dist.get_world_size()
        # tolist() copies to the host once instead of one .item() sync per key.
        for k, kind, value in zip(keys, kinds, stacked.tolist()):
            if kind == "int":
                reduced[k] = int(round(value))
            elif kind == "bool":
                reduced[k] = bool(value)
            else:  # "tensor" or "float"
                reduced[k] = value

    # Non-reducible keys were inserted first above; restore the caller's order.
    return {k: reduced[k] for k in log_dict}
    
def reduce_mean(tensor):
    """Average ``tensor`` across all ranks, in place.

    The tensor is summed with ``all_reduce`` and divided by the world size,
    mutating the argument. When no process group is initialised the tensor is
    left untouched. The same tensor object is returned in both cases.
    """
    distributed = dist.is_available() and dist.is_initialized()
    if not distributed:
        return tensor
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    tensor /= dist.get_world_size()
    return tensor
    
def fsdp_state_dict(model):
    """Gather the full (unsharded) state dict of an FSDP-wrapped ``model``.

    This is a collective: every rank must call it. The gathered tensors are
    offloaded to CPU, and because ``rank0_only=True`` only rank 0 is meant to
    receive the populated dict (see FSDP ``FullStateDictConfig``). Other ranks
    should not rely on the returned value.
    """
    gather_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, gather_cfg):
        return model.state_dict()


def fsdp_wrap(module, sharding_strategy="full", mixed_precision=False, wrap_strategy="size", min_num_params=int(5e7), transformer_module=None, ignored_modules=None, cpu_offload=False):
    """Wrap ``module`` in FullyShardedDataParallel.

    Args:
        module: the model to shard.
        sharding_strategy: one of ``"full"``, ``"hybrid_full"``,
            ``"hybrid_zero2"``, ``"no_shard"``.
        mixed_precision: if True, parameters are cast to bfloat16 for compute
            while gradient reduction and buffers stay in float32.
        wrap_strategy: ``"transformer"`` wraps every instance of the Flux layer
            classes listed below; ``"size"`` wraps any submodule holding at least
            ``min_num_params`` not-yet-wrapped parameters; any other value wraps
            every ``nn.Linear`` plus any submodule with at least 1e6 parameters.
        min_num_params: parameter threshold for the ``"size"`` strategy.
        transformer_module: currently ignored (see NOTE below); kept for API
            compatibility.
        ignored_modules: forwarded to FSDP unchanged.
        cpu_offload: if True, FSDP offloads parameters to CPU.

    Returns:
        The FSDP-wrapped module.

    Raises:
        ValueError: if ``sharding_strategy`` is not one of the names above.
    """
    if mixed_precision:
        mixed_precision_policy = MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.float32,
            cast_forward_inputs=True,
        )
    else:
        mixed_precision_policy = None

    if wrap_strategy == "transformer":
        # NOTE(review): `transformer_module` is not used here; the Flux layer
        # classes are always wrapped. Left as-is so existing callers keep the
        # exact same wrapping.
        auto_wrap_policy = partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={FluxTransformerBlock, AdaLayerNormContinuous, FSDPWrappedLinear, CombinedTimestepGuidanceTextProjEmbeddings, FluxPosEmbed, FluxSingleTransformerBlock},
        )
    elif wrap_strategy == "size":
        auto_wrap_policy = partial(
            size_based_auto_wrap_policy,
            min_num_params=min_num_params,
        )
    else:
        # Wrap every nn.Linear plus any submodule with >= 1e6 unwrapped params.
        # FSDP invokes policies with keyword arguments (module=, recurse=,
        # nonwrapped_numel=), so a positional lambda like `lambda m, r, n: ...`
        # raises TypeError. The stock policies accept those keywords and also
        # keep recursing into children when recurse=True.
        auto_wrap_policy = partial(
            _or_policy,
            policies=[
                partial(lambda_auto_wrap_policy, lambda_fn=lambda m: isinstance(m, torch.nn.Linear)),
                partial(size_based_auto_wrap_policy, min_num_params=int(1e6)),
            ],
        )

    # Process-wide side effect kept from the original. Presumably it only takes
    # effect if set before NCCL communicators are created — TODO confirm.
    os.environ["NCCL_CROSS_NIC"] = "1"

    strategies = {
        "full": ShardingStrategy.FULL_SHARD,
        "hybrid_full": ShardingStrategy.HYBRID_SHARD,
        "hybrid_zero2": ShardingStrategy._HYBRID_SHARD_ZERO2,
        "no_shard": ShardingStrategy.NO_SHARD,
    }
    if sharding_strategy not in strategies:
        raise ValueError(
            f"unknown sharding_strategy {sharding_strategy!r}; expected one of {sorted(strategies)}"
        )

    return FSDP(
        module,
        auto_wrap_policy=auto_wrap_policy,
        sharding_strategy=strategies[sharding_strategy],
        mixed_precision=mixed_precision_policy,
        device_id=torch.cuda.current_device(),
        limit_all_gathers=True,
        use_orig_params=True,
        ignored_modules=ignored_modules,
        cpu_offload=CPUOffload(offload_params=cpu_offload),
        backward_prefetch=None,   # no backward prefetch, to avoid extra cached parameters
        sync_module_states=False,  # no broadcast from rank 0: every rank must already hold identical weights
    )


def barrier():
    """Block until every rank reaches this point; a no-op without a process group."""
    if not dist.is_initialized():
        return
    dist.barrier()


def launch_distributed_job(backend: str = "nccl"):
    """Initialise the default process group from launcher environment variables.

    Reads ``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``, ``MASTER_ADDR`` and
    ``MASTER_PORT`` (a missing variable raises ``KeyError``). It binds this
    process to CUDA device ``LOCAL_RANK`` and then joins the group over a TCP
    rendezvous with a 30-minute timeout.

    Args:
        backend: ``torch.distributed`` backend name (default ``"nccl"``).
    """
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    host = os.environ["MASTER_ADDR"]
    port = int(os.environ["MASTER_PORT"])

    # Bind the GPU before creating the process group, so nothing initialised
    # during/after init_process_group lands on the default device 0.
    torch.cuda.set_device(local_rank)

    # IPv6 literals need brackets in a URL; don't double-bracket an address
    # that was already given as "[::1]".
    if ":" in host and not host.startswith("["):
        host = f"[{host}]"
    init_method = f"tcp://{host}:{port}"

    dist.init_process_group(rank=rank, world_size=world_size, backend=backend,
                            init_method=init_method, timeout=timedelta(minutes=30))

# ---------- EMA ----------
class EMA_FSDP:
    """Exponential moving average of an FSDP model's weights.

    The shadow weights are float32 CPU tensors held only on rank 0; other ranks
    keep ``shadow = None``. ``__init__``, ``update`` and ``copy_to`` gather or
    load a full state dict, which are collectives, so they must be called on
    every rank.
    """

    def __init__(self, fsdp_module: torch.nn.Module, decay: float = 0.999):
        print("we init FSDP EMA")
        self.decay = decay
        # Only rank 0 holds the shadow copy.
        self.shadow = {} if dist.get_rank() == 0 else None
        self._init_shadow(fsdp_module)

    @torch.no_grad()
    def _init_shadow(self, fsdp_module):
        """Seed the shadow copy with the model's current full weights (collective)."""
        state = fsdp_state_dict(fsdp_module)   # every rank must take part in the gather
        if dist.get_rank() == 0:               # only rank 0 stores the result
            # One fresh float32 CPU copy per tensor (no intermediate clone).
            self.shadow = {
                k: v.detach().to(device="cpu", dtype=torch.float32, copy=True)
                for k, v in state.items()
            }
        del state
        gc.collect()
        torch.cuda.empty_cache()

    @torch.no_grad()
    def update(self, fsdp_module):
        """shadow <- decay * shadow + (1 - decay) * current weights (collective)."""
        d = self.decay
        state = fsdp_state_dict(fsdp_module)  # every rank must take part in the gather
        if dist.get_rank() == 0:              # only rank 0 updates the shadow
            for k, v in state.items():
                self.shadow[k].mul_(d).add_(v.to(device="cpu", dtype=torch.float32), alpha=1. - d)
        del state
        gc.collect()
        torch.cuda.empty_cache()

    # ---- serialization ----
    def state_dict(self):
        """Return the shadow weights (only populated on rank 0; None elsewhere)."""
        return self.shadow

    def load_state_dict(self, sd):
        """Restore the shadow weights on rank 0 from ``sd``; other ranks ignore it."""
        if dist.get_rank() == 0:
            self.shadow = {k: v.clone() for k, v in sd.items()}

    # ---- copy the EMA weights back into the model ----
    @torch.no_grad()
    def copy_to(self, fsdp_module):
        """Load the EMA weights into ``fsdp_module`` (collective: call on every rank).

        FSDP's FULL_STATE_DICT ``load_state_dict`` unshards parameters, which is a
        collective. Loading on rank 0 alone would therefore hang the other
        ranks. Rank 0 broadcasts the shadow tensors one by one and every rank
        then loads the same full dict. Each rank temporarily holds a full float32
        copy on CPU.
        """
        rank = dist.get_rank()
        # Rank 0 publishes key/shape/dtype so other ranks can allocate receive buffers.
        layout = [
            [(k, tuple(v.shape), v.dtype) for k, v in self.shadow.items()] if rank == 0 else None
        ]
        dist.broadcast_object_list(layout, src=0)

        device = torch.device("cuda", torch.cuda.current_device())
        state = {}
        for key, shape, dtype in layout[0]:
            if rank == 0:
                buf = self.shadow[key].to(device).contiguous()
            else:
                buf = torch.empty(shape, dtype=dtype, device=device)
            dist.broadcast(buf, src=0)
            state[key] = buf.cpu()  # keep the full dict off the GPU
            del buf

        with FSDP.state_dict_type(
            fsdp_module,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=False, rank0_only=False),
        ):
            fsdp_module.load_state_dict(state, strict=True)
        del state
        gc.collect()
        torch.cuda.empty_cache()
