import numpy as np
import random
import torch

from typing import List, Dict, Any, Optional, Union
import torch
import warnings

def merge_step_dicts(
    step_dicts: List[Dict[str, Any]],
    sum_keys: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """
    Merge the log dicts of several training steps into one averaged (or summed) record.

    Args:
        step_dicts: one dict per step (e.g. what ``fwdbwd_one_step`` returns).
        sum_keys: keys listed here are summed across steps instead of averaged.

    Returns:
        A new dict with one entry per key seen in any step. Keys appear in the
        order they are first seen. Values are reduced with ``_sum_values`` (for
        ``sum_keys``) or ``_mean_values`` (for all other keys). An empty
        ``step_dicts`` gives ``{}``.

    A key that is missing from some steps triggers a warning. It is then
    reduced over only the steps that contain it.
    """
    if not step_dicts:
        return {}

    sum_key_set = set(sum_keys or ())
    num_steps = len(step_dicts)

    # Single pass: collect each key's values in step order. Dict insertion order
    # keeps keys ordered by first appearance, and every collected list is
    # non-empty by construction.
    grouped: Dict[str, List[Any]] = {}
    for step in step_dicts:
        for key, value in step.items():
            grouped.setdefault(key, []).append(value)

    merged: Dict[str, Any] = {}
    for key, values in grouped.items():
        if len(values) != num_steps:
            warnings.warn(
                f"Key '{key}' 只在 {len(values)}/{num_steps} 个 step 中出现，"
                "缺失 step 会被忽略。"
            )
        reduce_fn = _sum_values if key in sum_key_set else _mean_values
        merged[key] = reduce_fn(values)

    return merged


# --------- 内部辅助 ---------
def _to_python_scalar(x: Union[torch.Tensor, float, int]) -> Union[float, int]:
    if isinstance(x, torch.Tensor):
        if x.numel() != 1:
            raise ValueError(f"期望标量 Tensor，却得到 shape {x.shape}")
        x = x.detach().cpu().item()
    return float(x) if isinstance(x, float) else int(x)


def _mean_values(values: List[Any]) -> Any:
    """
    Average ``values`` after converting each one with ``_to_python_scalar``.

    The mean uses true division, so integer inputs produce a ``float``.

    If any value cannot be converted (``ValueError`` or ``TypeError``), a warning
    is issued and the last step's value is kept instead. It is converted to a
    scalar when possible and otherwise returned unchanged, so the fallback
    itself never raises.
    """
    try:
        scalars = [_to_python_scalar(v) for v in values]
    except (ValueError, TypeError) as e:
        # Cannot average: fall back to the last step's value and warn.
        warnings.warn(
            f"字段无法平均（{e}），已退化为保留最后一个 step 的值。"
        )
        last = values[-1]
        try:
            return _to_python_scalar(last)
        except (ValueError, TypeError):
            # The last value is itself not a scalar; keep it as-is rather than
            # re-raising from inside the fallback path.
            return last
    return sum(scalars) / len(scalars)


def _sum_values(values: List[Any]) -> Union[float, int]:
    """
    Sum ``values`` after converting each one with ``_to_python_scalar``.

    An empty list sums to ``0``. Any exception raised by the conversion
    propagates unchanged.
    """
    return sum(map(_to_python_scalar, values))
    
def set_seed(seed: int, deterministic: bool = False):
    """
    Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.

    Seeds Python's `random`, NumPy's global RNG, torch's CPU RNG and the RNGs of
    all CUDA devices.

    Args:
        seed (`int`):
            The seed to set.
        deterministic (`bool`, *optional*, defaults to `False`):
            Whether to use deterministic algorithms where available. Can slow down training.
            When enabled, this also:
            - forces deterministic cuDNN kernels and disables cuDNN benchmarking;
            - sets `CUBLAS_WORKSPACE_CONFIG=:4096:8` unless it is already set.
              PyTorch requires this for deterministic cuBLAS ops on CUDA >= 10.2.
    """
    import os

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if deterministic:
        # Without this env var, use_deterministic_algorithms(True) makes some cuBLAS
        # ops raise at runtime. It only takes effect if cuBLAS has not been
        # initialised yet in this process, and an existing user value is respected.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        # Benchmark mode may pick different conv algorithms run to run.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)


def merge_dict_list(dict_list):
    """
    Merge a list of dicts with the same keys into a single dict.

    For each key of the first dict:
    - 0-dim tensors are stacked into a 1-D tensor of length ``len(dict_list)``;
    - tensors with ndim >= 1 are concatenated along dim 0;
    - non-tensor values are copied from the first dict.

    A key present in the first dict but missing from a later one raises KeyError.
    A single-element list returns that dict itself (no copy). An empty list
    returns ``{}``.
    """
    if not dict_list:
        return {}
    if len(dict_list) == 1:
        return dict_list[0]

    merged_dict = {}
    for key, value in dict_list[0].items():
        if not isinstance(value, torch.Tensor):
            # For non-tensor values we just keep the first item's value.
            merged_dict[key] = value
            continue
        parts = [d[key] for d in dict_list]
        # torch.cat cannot join 0-dim tensors, so scalars get a new leading dim.
        if value.ndim == 0:
            merged_dict[key] = torch.stack(parts, dim=0)
        else:
            merged_dict[key] = torch.cat(parts, dim=0)
    return merged_dict
