"""
工具函数
"""

from turtle import position
import numpy as np
import torch
import random
import math
from pathlib import Path
from typing import Dict, Any, Optional, List, Tuple, Callable
import time
from scipy.spatial.transform import Rotation as R
import cv2
from robosuite.controllers import load_controller_config
# from spatialmath import UnitQuaternion

# Default OSC_POSE controller limits from robosuite. normalize_action_controller
# and unnormalize_action_controller use these to map between the policy's
# input range and the controller's output range.
controller_config = load_controller_config(default_controller="OSC_POSE")
output_max, output_min, input_max, input_min = (
    np.array(controller_config[key])
    for key in ("output_max", "output_min", "input_max", "input_min")
)

def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs and force
    deterministic, non-benchmarking cuDNN kernels for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

def get_eef_from_real(real_eef):
    """Split a dual-arm real-robot EEF vector into per-arm pose vectors.

    Args:
        real_eef: array-like of length >= 14 laid out as
            ``[pos(3), rpy(3), gripper(1)]`` for the left arm followed by the
            same 7 values for the right arm. Angles are in radians.

    Returns:
        (left_eef, right_eef): two 7-D arrays ``[pos(3), quat(4)]`` with the
        quaternion in (x, y, z, w) order. Gripper values are dropped.
    """
    real_eef = np.asarray(real_eef, dtype=float)

    def _pose(arm):
        # Fixed-axis roll/pitch/yaw, R = Rz(yaw) @ Ry(pitch) @ Rx(roll).
        # This is the same 'xyz' convention process_obs_for_real uses. The
        # previous spatialmath UnitQuaternion path was broken: it was not
        # imported, and its [v, s] list could not be concatenated.
        quat_xyzw = R.from_euler('xyz', arm[3:6]).as_quat()
        return np.concatenate([arm[:3], quat_xyzw], axis=0)

    return _pose(real_eef[:7]), _pose(real_eef[7:14])

class Timer:
    """Wall-clock stopwatch tracking both lap time and total elapsed time."""

    def __init__(self):
        now = time.time()
        self._start_time = now  # when the timer was created
        self._last_time = now   # when reset() was last called

    def reset(self):
        """Start a new lap.

        Returns:
            (lap, total): seconds since the previous reset() (or creation),
            and seconds since creation.
        """
        now = time.time()
        lap = now - self._last_time
        self._last_time = now
        return lap, now - self._start_time

    def total_time(self):
        """Seconds elapsed since the timer was created."""
        return time.time() - self._start_time


class Every:
    """Predicate that fires on steps divisible by ``every``.

    It never fires when ``every`` is None or <= 0.
    """

    def __init__(self, every: int):
        self._every = every

    def __call__(self, step: int) -> bool:
        period = self._every
        disabled = period is None or period <= 0
        return False if disabled else step % period == 0


class Until:
    """Predicate that stays True while ``step < until``.

    It is always True when ``until`` is None.
    """

    def __init__(self, until: int):
        self._until = until

    def __call__(self, step: int) -> bool:
        limit = self._until
        return True if limit is None else step < limit


def normalize_obs(obs: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Z-score ``obs`` with the given statistics.

    A 1e-8 epsilon keeps zero-std dims finite.
    """
    centered = obs - mean
    return centered / (std + 1e-8)


def denormalize_obs(obs: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Map a z-scored observation back to its original scale (inverse of normalize_obs, ignoring the epsilon)."""
    return mean + std * obs


def normalize_action(action: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Z-score ``action`` with the given statistics.

    A 1e-8 epsilon keeps zero-std dims finite.
    """
    centered = action - mean
    return centered / (std + 1e-8)


def denormalize_action(action: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Map a z-scored action back to its original scale (inverse of normalize_action, ignoring the epsilon)."""
    return mean + std * action


def compute_statistics(data: np.ndarray, axis: int = 0) -> Dict[str, np.ndarray]:
    """Return the mean/std/min/max of ``data`` reduced along ``axis``, keyed by statistic name."""
    reducers = (('mean', np.mean), ('std', np.std), ('min', np.min), ('max', np.max))
    return {name: reduce_fn(data, axis=axis) for name, reduce_fn in reducers}


def to_torch(data: Any, device: torch.device, dtype: Optional[torch.dtype] = None) -> Any:
    """
    Recursively convert ``data`` to PyTorch tensors on ``device``.

    Dicts and lists are rebuilt with every element converted. NumPy arrays are
    wrapped with ``torch.from_numpy`` and existing tensors are reused. When
    ``dtype`` is given, the tensor is cast before it is moved to ``device``.
    Any other value (scalars, strings, tuples, ...) is returned unchanged.

    Args:
        data: input data (NumPy array, tensor, dict, list or anything else)
        device: target device
        dtype: optional target dtype

    Returns:
        The converted structure.
    """
    if isinstance(data, dict):
        return {key: to_torch(value, device, dtype) for key, value in data.items()}
    if isinstance(data, list):
        return [to_torch(item, device, dtype) for item in data]
    if isinstance(data, (np.ndarray, torch.Tensor)):
        tensor = torch.from_numpy(data) if isinstance(data, np.ndarray) else data
        if dtype is not None:
            tensor = tensor.to(dtype)
        return tensor.to(device)
    return data


def to_numpy(data: Any) -> Any:
    """
    Recursively convert tensors (also inside dicts and lists) to NumPy arrays.

    Tensors are detached and copied to the CPU. Everything else, including
    existing NumPy arrays, is returned as-is.

    Args:
        data: input data (tensor, dict, list or anything else)

    Returns:
        The converted structure.
    """
    if isinstance(data, torch.Tensor):
        return data.detach().cpu().numpy()
    if isinstance(data, dict):
        return {key: to_numpy(value) for key, value in data.items()}
    if isinstance(data, list):
        return list(map(to_numpy, data))
    return data


class Logger:
    """Minimal metrics logger that forwards scalars to Weights & Biases and/or TensorBoard."""

    def __init__(self, log_dir: Path, use_wandb: bool = False, use_tensorboard: bool = False):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.use_wandb = use_wandb
        self.use_tensorboard = use_tensorboard

        # Backends are imported lazily, so neither package is needed unless enabled.
        # NOTE(review): wandb is only imported here, never initialised;
        # presumably the caller runs wandb.init() itself — confirm.
        if use_wandb:
            import wandb
            self.wandb = wandb
        if use_tensorboard:
            from torch.utils.tensorboard import SummaryWriter
            tb_dir = self.log_dir / 'tensorboard'
            self.tb_writer = SummaryWriter(log_dir=str(tb_dir))

    def log_metrics(self, metrics: Dict[str, float], step: int, prefix: str = ""):
        """Log scalar ``metrics`` at ``step``.

        Keys become ``"{prefix}/{key}"`` when ``prefix`` is non-empty.
        """
        if prefix:
            metrics = {f"{prefix}/{name}": value for name, value in metrics.items()}
        if self.use_wandb:
            self.wandb.log(metrics, step=step)
        if self.use_tensorboard:
            for name, value in metrics.items():
                self.tb_writer.add_scalar(name, value, step)

    def close(self):
        """Close the TensorBoard writer, if one was created."""
        if self.use_tensorboard:
            self.tb_writer.close()


def create_output_dir(base_dir: str, exp_name: str) -> Path:
    """Create ``base_dir/exp_name`` (including missing parents, if absent) and return it."""
    path = Path(base_dir).joinpath(exp_name)
    path.mkdir(parents=True, exist_ok=True)
    return path


def save_config(config: Dict[str, Any], output_dir: Path):
    """Write ``config`` as 2-space-indented JSON to ``output_dir/config.json`` and print the path."""
    import json

    target = output_dir / 'config.json'
    with open(target, 'w') as fh:
        json.dump(config, fh, indent=2)
    print(f"配置已保存至: {target}")


def load_config(config_path: Path) -> Dict[str, Any]:
    """Read and return the JSON document stored at ``config_path``."""
    import json

    with open(config_path, 'r') as fh:
        return json.load(fh)


# ==================== AdaDS 相关工具函数 ====================

def merge_delta_actions(actions: np.ndarray) -> np.ndarray:
    """
    Collapse a sequence of delta actions into a single equivalent delta action.

    Args:
        actions: [N, 7] rows of (delta_pos(3), delta_axis_angle(3), gripper(1))
    Returns:
        merged_action: [7]. This is the summed translation, the composed
        rotation (as a rotation vector) and the gripper command of the last row.
    """
    translation = actions[:, :3].sum(axis=0)

    # LIBERO deltas are expressed in the global frame, so each later rotation
    # is applied on the left: merged = r_N * ... * r_2 * r_1.
    composed = R.identity()
    for step_rot in R.from_rotvec(actions[:, 3:6]):
        composed = step_rot * composed

    last_gripper = actions[-1, 6]
    return np.hstack([translation, composed.as_rotvec(), last_gripper])


def normalize_action_controller(action):
    """Map policy-range actions into the OSC controller's output range.

    The first 6 columns (pose deltas) go through the same affine map as
    ``patched_scale_action`` (no clipping), using the module-level
    ``input_*``/``output_*`` limits. The remaining columns (e.g. gripper)
    are passed through unchanged.

    Args:
        action: [N, D] array with D >= 6.
    Returns:
        [N, D] array.
    """
    in_center = (input_max + input_min) / 2.0
    out_center = (output_max + output_min) / 2.0
    gain = abs(output_max - output_min) / abs(input_max - input_min)
    pose = (action[:, :6] - in_center) * gain + out_center
    return np.concatenate([pose, action[:, 6:]], axis=1)

def unnormalize_action_controller(action):
    """Inverse of normalize_action_controller: map controller-range actions back to the policy range.

    Only the first 6 columns are transformed; the remaining columns are
    passed through unchanged.

    Args:
        action: [N, D] array with D >= 6.
    Returns:
        [N, D] array.
    """
    in_center = (input_max + input_min) / 2.0
    out_center = (output_max + output_min) / 2.0
    gain = abs(output_max - output_min) / abs(input_max - input_min)
    pose = (action[:, :6] - out_center) / gain + in_center
    return np.concatenate([pose, action[:, 6:]], axis=1)


def create_rot_from_wxyz(quat_wxyz):
    """Build a scipy ``Rotation`` from [N, 4] quaternions stored scalar-first (w, x, y, z)."""
    # scipy's from_quat expects scalar-last (x, y, z, w), so column 0 moves to the end.
    return R.from_quat(quat_wxyz[:, [1, 2, 3, 0]])
    
def compute_eef_error(position_desired, quat_desired, position_actual, quat_actual, pos_coef=0.5, rot_coef=0.5, mode="max"):
    """Weighted tracking error between desired and actual EEF trajectories.

    Sample 0 is skipped. For every remaining step, the Euclidean position
    error and the rotation angle (radians) of ``q_des^-1 * q_act`` are
    blended as ``(pos_coef * pos_err + rot_coef * rot_err) / (pos_coef + rot_coef)``.

    Args:
        position_desired, position_actual: [T, 3] positions.
        quat_desired, quat_actual: [T, 4] quaternions in scipy (x, y, z, w) order.
        pos_coef, rot_coef: blend weights.
        mode: "max" or "mean" reduction over the T - 1 steps.
    Returns:
        Scalar error.
    Raises:
        ValueError: for an unknown ``mode``.
    """
    norm_axis = None if len(position_desired.shape) == 1 else 1
    pos_err = np.linalg.norm(position_desired[1:] - position_actual[1:], axis=norm_axis)
    relative = R.from_quat(quat_desired[1:]).inv() * R.from_quat(quat_actual[1:])
    rot_err = relative.magnitude()
    blended = (pos_coef * pos_err + rot_coef * rot_err) / (pos_coef + rot_coef)
    if mode == "max":
        return np.max(blended)
    if mode == "mean":
        return np.mean(blended)
    raise ValueError(f"Invalid mode: {mode}")

def compute_eef_error_ratio(position_desired, quat_desired, position_actual, quat_actual, mode="max",
                            pos_coef=0.5, rot_coef=0.5, eps=1e-8):
    """Tracking error of each step relative to the size of the commanded step.

    For step t >= 1:
      * the position error ``|p_des[t] - p_act[t]|`` is divided by the
        commanded displacement ``|p_des[t] - p_des[t-1]|``;
      * the rotation error ``angle(q_des[t]^-1 q_act[t])`` is divided by the
        commanded rotation ``angle(q_des[t]^-1 q_des[t-1])``.

    The two ratios are blended with ``pos_coef``/``rot_coef``; the defaults
    reproduce the original 0.5/0.5 mix.

    Each denominator is floored at ``eps``. When the desired pose does not
    move during a step (e.g. a gripper-only step), the ratio would otherwise
    be 0/0 = NaN, and that NaN would poison the max/mean reduction. Steps
    that do move by at least ``eps`` are unaffected.

    Args:
        position_desired, position_actual: [T, 3] positions.
        quat_desired, quat_actual: [T, 4] quaternions in scipy (x, y, z, w) order.
        mode: "max" or "mean" reduction over the T - 1 steps.
        pos_coef, rot_coef: blend weights.
        eps: lower bound for the commanded-step denominators.
    Returns:
        Scalar error ratio.
    Raises:
        ValueError: for an unknown ``mode``.
    """
    norm_axis = None if len(position_desired.shape) == 1 else 1
    pos_step = np.linalg.norm(position_desired[1:] - position_desired[:-1], axis=norm_axis)
    pos_dev = np.linalg.norm(position_desired[1:] - position_actual[1:], axis=norm_axis)

    rot_des_after_inv = R.from_quat(quat_desired[1:]).inv()
    rot_step = (rot_des_after_inv * R.from_quat(quat_desired[:-1])).magnitude()
    rot_dev = (rot_des_after_inv * R.from_quat(quat_actual[1:])).magnitude()

    pos_err = pos_dev / np.maximum(pos_step, eps)
    rot_err = rot_dev / np.maximum(rot_step, eps)
    err = (pos_coef * pos_err + rot_coef * rot_err) / (pos_coef + rot_coef)
    if mode == "max":
        return np.max(err)
    if mode == "mean":
        return np.mean(err)
    raise ValueError(f"Invalid mode: {mode}")

def convert_delta_to_absolute_actions(
        delta_actions: np.ndarray,
        states: np.ndarray,
        # gripper_speed: float=0.01
    ) -> np.ndarray:
    """Turn delta (policy-range) actions into absolute pose targets.

    ``delta_actions`` is first scaled into controller units with
    ``normalize_action_controller``. The translation is then added to the
    current position. The rotation vector is composed on the left (global
    frame) with the current orientation.

    Args:
        delta_actions: [N, 7] rows of (delta_pos(3), delta_axis_angle(3), gripper(1)).
        states: [N, >=6] current pose as (pos(3), intrinsic 'XYZ' Euler(3), ...).
    Returns:
        [N, 7] float32 array of (goal_pos(3), goal intrinsic 'XYZ' Euler(3), gripper(1)).
    """
    scaled = normalize_action_controller(delta_actions)
    target_pos = states[:, :3] + scaled[:, :3]
    target_rot = R.from_rotvec(scaled[:, 3:6]) * R.from_euler('XYZ', states[:, 3:6])
    absolute = np.hstack([target_pos, target_rot.as_euler('XYZ'), scaled[:, 6:7]])
    return absolute.astype(np.float32)



def interpolate_pos_quat(position, quat, gripper, time_indices, target_indices):
    """Resample a pose + gripper trajectory at ``target_indices``.

    Positions and gripper values are linearly interpolated. Orientations use
    NLERP: component-wise linear interpolation followed by renormalisation.
    Quaternion signs are flipped first so consecutive samples lie in the same
    hemisphere, which keeps the interpolation on the shorter arc.

    Args:
        position: [T, >=3] sample positions (first 3 columns used).
        quat: [T, 4] sample quaternions.
        gripper: [T] sample gripper values.
        time_indices: [T] sample times (np.interp expects them increasing).
        target_indices: [M] query times.
    Returns:
        (pos [M, 3], quat [M, 4] renormalised, gripper [M]).
    """
    pos_samples = np.array(position)
    quat_samples = np.array(quat)  # a copy, so the sign fix below never touches the caller's data
    grip_samples = np.array(gripper)

    for k in range(1, len(quat_samples)):
        if np.dot(quat_samples[k - 1], quat_samples[k]) < 0:
            quat_samples[k] *= -1

    def resample(samples, n_cols):
        out = np.zeros((len(target_indices), n_cols))
        for col in range(n_cols):
            out[:, col] = np.interp(target_indices, time_indices, samples[:, col])
        return out

    pos_out = resample(pos_samples, 3)
    quat_out = resample(quat_samples, 4)
    quat_out /= np.linalg.norm(quat_out, axis=1, keepdims=True) + 1e-12  # epsilon avoids division by zero
    grip_out = np.interp(target_indices, time_indices, grip_samples)
    return pos_out, quat_out, grip_out


def extract_eef_from_obs(obs: np.ndarray) -> Tuple[np.ndarray, np.ndarray, float]:
    """Split an observation vector into EEF position, quaternion and gripper value.

    Args:
        obs: [obs_dim] observation. The first 7 entries are [pos(3), quat(4)]
            and the 8th entry, if present, is the gripper.

    Returns:
        pos: [3] position
        quat: [4] quaternion
        gripper: ``obs[7]`` when available, otherwise 0.0

    Raises:
        ValueError: if ``obs`` has fewer than 7 entries.
    """
    n = len(obs)
    if n < 7:
        raise ValueError(f"观测维度不足: {len(obs)}")
    gripper = obs[7] if n >= 8 else 0.0
    return obs[:3], obs[3:7], gripper

def quat2axisangle(quat):
    """
    Converts quaternion to axis-angle format.
    Returns a unit vector direction scaled by its angle in radians.

    The input is never modified: clipping of ``w`` happens on a private copy.
    Lists and tuples are accepted as well as arrays.

    Args:
        quat (array-like): (x,y,z,w) quaternion

    Returns:
        np.array: (ax,ay,az) axis-angle exponential coordinates
    """
    q = np.array(quat, dtype=float)  # private copy; the caller's array stays untouched
    # Clamp w to [-1, 1] so acos/sqrt stay valid if numerical noise pushes |w| past 1.
    w = min(max(float(q[3]), -1.0), 1.0)

    den = np.sqrt(1.0 - w * w)
    if math.isclose(den, 0.0):
        # (Close to) a zero-degree rotation: the axis is undefined.
        return np.zeros(3)

    return (q[:3] * 2.0 * math.acos(w)) / den
    
def axisangle2quat(vec):
    """
    Converts a scaled axis-angle vector to a quaternion.

    Args:
        vec (np.array): (ax,ay,az) axis-angle exponential coordinates; the
            vector norm is the rotation angle in radians.

    Returns:
        np.array: (x,y,z,w) quaternion (float64)
    """
    theta = np.linalg.norm(vec)
    if math.isclose(theta, 0.0):
        # Identity rotation: the axis is undefined, so skip the division.
        return np.array([0.0, 0.0, 0.0, 1.0])

    half = theta / 2.0
    unit_axis = vec / theta
    result = np.empty(4)
    result[3] = np.cos(half)
    result[:3] = unit_axis * np.sin(half)
    return result



def put_text(img, text, font_size=1, thickness=2, resize=False, position="top"):
    """Return a copy of ``img`` with ``text`` drawn near its left edge.

    Args:
        img: image array. It is copied, never modified.
        text: string to draw.
        font_size: OpenCV font scale.
        thickness: stroke thickness in pixels.
        resize: if True, cast to uint8 and resize to 256x256 before drawing.
        position: "top" (text origin at (10, 30)) or "bottom" (origin 10 px
            above the bottom edge).

    Returns:
        The annotated image (drawn with colour (0, 255, 255), Hershey simplex, anti-aliased).

    Raises:
        ValueError: for any other ``position``. Previously this crashed with
            an UnboundLocalError because the anchor point was never assigned.
    """
    if position not in ("top", "bottom"):
        raise ValueError(f"Invalid position: {position!r} (expected 'top' or 'bottom')")
    img = img.copy()
    if resize:
        img = cv2.resize(np.uint8(img), (256, 256))
    h = img.shape[0]
    # OpenCV's text origin is the bottom-left corner of the string, in (x, y) pixels.
    anchor = (10, 30) if position == "top" else (10, h - 10)
    return cv2.putText(
        img,
        text,
        anchor,
        cv2.FONT_HERSHEY_SIMPLEX,
        font_size,
        (0, 255, 255),
        thickness,
        cv2.LINE_AA,
    )

import robosuite.controllers.base_controller as base_controller_module
import robosuite.models.grippers.panda_gripper as panda_gripper_module
import robosuite.environments.base as base_module

# Monkey-patches for high-speed LIBERO rollouts (installed by apply_patches, partially undone by revert_patches)

def patched_scale_action(self, action):
    """Replacement for robosuite ``Controller.scale_action`` (installed by apply_patches) that does NOT clip.

    On first use it caches the affine map from the controller's input range
    to its output range. It then applies that map to ``action`` as-is, so
    values outside ``[input_min, input_max]`` are extrapolated, not clipped.
    """
    if self.action_scale is None:
        in_span = abs(self.input_max - self.input_min)
        out_span = abs(self.output_max - self.output_min)
        self.action_scale = out_span / in_span
        self.action_output_transform = (self.output_max + self.output_min) / 2.0
        self.action_input_transform = (self.input_max + self.input_min) / 2.0
    offset = action - self.action_input_transform
    return offset * self.action_scale + self.action_output_transform

def step(self, action):
    """Replacement for robosuite ``MujocoEnv.step``, installed by apply_patches.

    Advances the simulation by one control step of
    ``int(control_timestep / model_timestep)`` physics sub-steps.
    ``_pre_action`` receives ``policy_step=True`` only on the first sub-step.

    Returns:
        (observations, reward, done, info). Observations come from the
        viewer when ``viewer_get_obs`` is set.
    """
    self.timestep += 1
    n_substeps = int(self.control_timestep / self.model_timestep)
    for substep in range(n_substeps):
        self.sim.forward()
        self._pre_action(action, substep == 0)
        self.sim.step()
        self._update_observables()
    self.cur_time += self.control_timestep

    reward, done, info = self._post_action(action)
    if self.viewer is not None and self.renderer != "mujoco":
        self.viewer.update()
    if self.viewer_get_obs:
        observations = self.viewer._get_observations()
    else:
        observations = self._get_observations()
    return observations, reward, done, info

# Gripper speeds returned by the PandaGripper.speed overrides below.
_GRIPPER_SPEED_ORIGINAL = 0.01
_GRIPPER_SPEED_FAST = 0.02
_GRIPPER_SPEED_FASTER = 0.04


def patched_speed_property(self):
    """PandaGripper ``speed`` installed by ``apply_patches()``: twice the original 0.01."""
    return _GRIPPER_SPEED_FAST


def patched_speed_property_higher(self):
    """PandaGripper ``speed`` installed by ``apply_patches(higher=True)``: four times the original 0.01."""
    return _GRIPPER_SPEED_FASTER


def original_speed_property(self):
    """PandaGripper ``speed`` restored by ``revert_patches()``."""
    return _GRIPPER_SPEED_ORIGINAL

def apply_patches(higher=False):
    """Monkey-patch robosuite classes for high-speed LIBERO rollouts.

    Installs three class attributes, so every instance in the process is affected:
      * ``patched_scale_action`` (no clipping) on ``Controller``;
      * this module's ``step`` on ``MujocoEnv``;
      * a faster ``PandaGripper.speed`` (``patched_speed_property_higher``
        when ``higher`` is True, otherwise ``patched_speed_property``).
    """
    speed_getter = patched_speed_property_higher if higher else patched_speed_property
    base_controller_module.Controller.scale_action = patched_scale_action
    panda_gripper_module.PandaGripper.speed = property(speed_getter)
    base_module.MujocoEnv.step = step

def revert_patches():
    """Restore the original PandaGripper speed (``original_speed_property``).

    NOTE(review): despite the name, this re-installs ``patched_scale_action``
    and this module's ``step`` rather than robosuite's originals. The
    unclipped action scaling and the replacement ``step`` therefore stay
    active after a "revert". Confirm this is intentional; a true revert would
    need the original methods saved before ``apply_patches`` first runs.
    """
    panda_gripper_module.PandaGripper.speed = property(original_speed_property)
    # Kept exactly as before (see NOTE): these remain the patched versions.
    base_controller_module.Controller.scale_action = patched_scale_action
    base_module.MujocoEnv.step = step

def batch_interpolate_eef(position_actuals_batch, quat_actuals_batch, gripper_actuals_batch, time_indices, target_indices):
    """Resample a batch of EEF trajectories at ``target_indices``.

    Each batch element gets linear interpolation of positions and NLERP of
    quaternions. Quaternions are sign-aligned first so neighbours share a
    hemisphere, then renormalised.

    Args:
        position_actuals_batch: [B, T, >=3] positions (first 3 columns used).
        quat_actuals_batch: [B, T, 4] quaternions; never modified (copied per element).
        gripper_actuals_batch: [B, T] gripper values.
        time_indices: [T] sample times shared by the whole batch.
        target_indices: [M] query times shared by the whole batch.

    Returns:
        [B, M, 7] array of (pos(3), quat(4)).

    NOTE(review): the gripper channel is interpolated but is NOT part of the
    returned array. Confirm whether it should be appended.
    """
    n_batch = position_actuals_batch.shape[0]
    n_out = len(target_indices)

    def lerp(series):
        return np.interp(target_indices, time_indices, series)

    positions = np.zeros((n_batch, n_out, 3))
    for b in range(n_batch):
        for axis in range(3):
            positions[b, :, axis] = lerp(position_actuals_batch[b, :, axis])

    quats = np.zeros((n_batch, n_out, 4))
    grippers = np.zeros((n_batch, n_out))
    for b in range(n_batch):
        aligned = quat_actuals_batch[b].copy()
        # Flip signs so neighbouring quaternions share a hemisphere (shortest arc).
        for t in range(1, len(aligned)):
            if np.dot(aligned[t - 1], aligned[t]) < 0:
                aligned[t] *= -1
        for axis in range(4):
            quats[b, :, axis] = lerp(aligned[:, axis])
        quats[b] /= np.linalg.norm(quats[b], axis=1, keepdims=True) + 1e-12
        grippers[b] = lerp(gripper_actuals_batch[b])  # computed but unused (see NOTE)

    return np.concatenate([positions, quats], axis=2)


def process_obs_for_libero(predicted_obs_unnorm_batch, batch_obs, time_indices, target_indices):
    """Build resampled LIBERO EEF trajectories from predicted observations.

    The current state is prepended to each predicted sequence. Axis-angle
    orientations are converted to (x, y, z, w) quaternions with
    ``axisangle2quat``, and the result is resampled with ``batch_interpolate_eef``.

    Args:
        predicted_obs_unnorm_batch: [B, T, >=7] predicted states laid out as
            (pos(3), axis_angle(3), gripper(1), ...).
        batch_obs: [B, >=7] current states with the same layout.
        time_indices: sample times for the T + 1 states (current + predicted).
        target_indices: query times passed to ``batch_interpolate_eef``.

    Returns:
        List of B arrays, each [len(target_indices), 7] = (pos(3), quat(4)).
    """
    n_batch, horizon = predicted_obs_unnorm_batch.shape[:2]

    def with_current(current, predicted):
        # Prepend the current value as time step 0.
        return np.concatenate([current[:, None], predicted], axis=1)

    positions = with_current(batch_obs[:, :3], predicted_obs_unnorm_batch[:, :, :3])
    rotvecs = with_current(batch_obs[:, 3:6], predicted_obs_unnorm_batch[:, :, 3:6])
    grippers = with_current(batch_obs[:, 6], predicted_obs_unnorm_batch[:, :, 6])

    quats = np.array([axisangle2quat(v) for v in rotvecs.reshape(-1, 3)])
    quats = quats.reshape(n_batch, horizon + 1, 4)
    trajectories = batch_interpolate_eef(positions, quats, grippers, time_indices, target_indices)
    return list(trajectories)


def process_obs_for_real(predicted_obs_unnorm_batch, batch_obs, time_indices, target_indices, single_arm=False):
    """Build resampled real-robot EEF trajectories from predicted observations.

    The last 14 (dual-arm) or 7 (single-arm) state dims are read as
    per-arm blocks of (pos(3), rpy(3), gripper(1)). For each arm:
      * the current state is prepended to the prediction;
      * fixed-axis 'xyz' Euler angles are converted to (x, y, z, w) quaternions;
      * the trajectory is resampled with ``batch_interpolate_eef``.

    Returns:
        List of B arrays, each [len(target_indices), 7] for a single arm, or
        [len(target_indices), 14] (left pose(7) then right pose(7)) for two arms.
    """
    # real state: [42]: 14D qpos 14D eef
    eef_dim = 7 if single_arm else 14
    predicted_eef = predicted_obs_unnorm_batch[:, :, -eef_dim:]
    current_eef = batch_obs[:, -eef_dim:]

    def arm_trajectory(predicted, current):
        # Prepend the current (pos, rpy, gripper) as time step 0.
        track = np.concatenate([current[:, None, :7], predicted[:, :, :7]], axis=1)
        # Same Euler convention as spatialmath's RPY, converted per batch element.
        quats = np.stack(
            [R.from_euler('xyz', rpy, degrees=False).as_quat() for rpy in track[:, :, 3:6]],
            axis=0,
        )
        return batch_interpolate_eef(track[:, :, :3], quats, track[:, :, 6], time_indices, target_indices)

    if single_arm:
        trajectories = arm_trajectory(predicted_eef, current_eef)
    else:
        left = arm_trajectory(predicted_eef[:, :, :7], current_eef[:, :7])
        right = arm_trajectory(predicted_eef[:, :, 7:14], current_eef[:, 7:14])
        trajectories = np.concatenate([left, right], axis=-1)
    return list(trajectories)


def calculate_model_size(model: torch.nn.Module, verbose: bool = True) -> Dict[str, Any]:
    """
    Compute parameter counts and memory footprint of ``model``.

    Sizes use each tensor's real element size (``Tensor.element_size()``).
    This measures fp16/bf16/fp64 parameters and integer buffers (e.g.
    BatchNorm's int64 ``num_batches_tracked``) correctly, instead of
    assuming every element is a 4-byte float32.

    Args:
        model: PyTorch model.
        verbose: whether to print a summary table.

    Returns:
        dict with keys ``total_params``, ``trainable_params``,
        ``non_trainable_params``, ``model_size_mb``, ``buffer_size_mb``,
        ``total_size_mb``, ``total_params_million`` and
        ``trainable_params_million`` (sizes in MiB, i.e. 1024 ** 2 bytes).
    """
    total_params = 0
    trainable_params = 0
    param_bytes = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        param_bytes += count * param.element_size()
        if param.requires_grad:
            trainable_params += count
    non_trainable_params = total_params - trainable_params

    # Buffers (e.g. BatchNorm running statistics) occupy memory but are not parameters.
    buffer_bytes = sum(buf.numel() * buf.element_size() for buf in model.buffers())

    mib = 1024 ** 2
    model_size_mb = param_bytes / mib
    buffer_size_mb = buffer_bytes / mib
    total_size_mb = model_size_mb + buffer_size_mb

    result = {
        'total_params': total_params,
        'trainable_params': trainable_params,
        'non_trainable_params': non_trainable_params,
        'model_size_mb': model_size_mb,
        'buffer_size_mb': buffer_size_mb,
        'total_size_mb': total_size_mb,
        'total_params_million': total_params / 1e6,
        'trainable_params_million': trainable_params / 1e6,
    }

    if verbose:
        print("=" * 60)
        print("模型大小统计")
        print("=" * 60)
        print(f"总参数数量: {total_params:,} ({total_params / 1e6:.2f}M)")
        print(f"可训练参数: {trainable_params:,} ({trainable_params / 1e6:.2f}M)")
        print(f"不可训练参数: {non_trainable_params:,} ({non_trainable_params / 1e6:.2f}M)")
        print(f"模型参数大小: {model_size_mb:.2f} MB")
        print(f"缓冲区大小: {buffer_size_mb:.2f} MB")
        print(f"总内存占用: {total_size_mb:.2f} MB")
        print("=" * 60)

    return result


def calculate_flops_per_layer(module: torch.nn.Module, input_shape: Tuple, output_shape: Tuple = None) -> int:
    """
    Estimate the FLOPs of a single leaf layer.

    Linear and conv layers count one FLOP per multiply-accumulate, plus one
    per bias addition. Supported types are Linear, Conv2d (including grouped
    and depthwise), the element-wise activations ReLU, LeakyReLU, GELU,
    Sigmoid and Tanh, LayerNorm, and GRU/LSTM (rough estimate, doubled when
    bidirectional). Any other module type, or an unknown ``input_shape``
    (None), yields 0.

    Args:
        module: leaf PyTorch module.
        input_shape: shape of the layer's first input, batch dimension first.
        output_shape: shape of the layer's first output (optional; used by Conv2d).

    Returns:
        Estimated FLOP count.
    """
    if input_shape is None:
        # Shape could not be determined (e.g. a non-tensor input): nothing to estimate.
        return 0

    if isinstance(module, torch.nn.Linear):
        # Every leading dim (batch, sequence, ...) is an independent row,
        # so (B, T, in_features) inputs are fully counted.
        rows = math.prod(input_shape[:-1])
        flops = rows * module.in_features * module.out_features
        if module.bias is not None:
            flops += rows * module.out_features
        return int(flops)

    if isinstance(module, torch.nn.Conv2d):
        batch_size = input_shape[0]
        kernel = module.kernel_size
        kernel_area = kernel[0] * kernel[1] if isinstance(kernel, tuple) else kernel ** 2
        # Each output channel only sees in_channels / groups inputs (grouped/depthwise conv).
        in_per_group = module.in_channels // module.groups
        if output_shape is not None:
            out_h, out_w = output_shape[2], output_shape[3]
        else:
            # Rough estimate: assume the spatial size is preserved.
            out_h, out_w = input_shape[2], input_shape[3]
        flops = batch_size * module.out_channels * kernel_area * in_per_group * out_h * out_w
        if module.bias is not None:
            flops += batch_size * module.out_channels * out_h * out_w
        return int(flops)

    if isinstance(module, (torch.nn.ReLU, torch.nn.LeakyReLU, torch.nn.GELU, torch.nn.Sigmoid, torch.nn.Tanh)):
        # One operation per element.
        return int(math.prod(input_shape))

    if isinstance(module, torch.nn.LayerNorm):
        # Mean, variance and normalisation: rough estimate of 4 ops per element.
        return int(math.prod(input_shape) * 4)

    if isinstance(module, (torch.nn.GRU, torch.nn.LSTM)):
        if len(input_shape) < 2:
            return 0
        if len(input_shape) == 2:
            # Probably PackedSequence.data of shape (total_length, feature_dim):
            # assume 32 sequences to split total_length into batch x seq_len.
            estimated_batch_size = 32
            seq_len = max(1, input_shape[0] // estimated_batch_size)
            batch_size = estimated_batch_size
        else:
            # (batch, seq, feat) or (seq, batch, feat); only the product matters.
            batch_size, seq_len = input_shape[0], input_shape[1]
        gates = 4 if isinstance(module, torch.nn.LSTM) else 3
        flops_per_step = gates * (module.input_size + module.hidden_size) * module.hidden_size
        # A bidirectional RNN runs a second, equally sized pass in reverse.
        directions = 2 if module.bidirectional else 1
        return int(batch_size * seq_len * module.num_layers * directions * flops_per_step)

    return 0


def _single_object_shape(obj):
    """Return ``obj.shape``, an estimated 3-D shape for a PackedSequence, or None."""
    # A PackedSequence has both .data and .batch_sizes. A plain tensor also has
    # .data but no .batch_sizes, so it falls through to the .shape branch.
    if hasattr(obj, 'data') and hasattr(obj, 'batch_sizes'):
        if not hasattr(obj.data, 'shape'):
            return None
        data_shape = obj.data.shape  # (total_length, feature_dim)
        batch_sizes = obj.batch_sizes
        if len(batch_sizes) > 0 and len(data_shape) >= 2:
            first = batch_sizes[0]
            max_batch_size = first.item() if hasattr(first, 'item') else int(first)
            # Estimated as (largest batch, number of time steps, feature_dim).
            return (max_batch_size, len(batch_sizes), data_shape[1])
        return data_shape
    if hasattr(obj, 'shape'):
        return obj.shape
    return None


def _hook_value_shape(value):
    """Shape of a forward-hook input/output; for a tuple/list only the first element is inspected."""
    if isinstance(value, (tuple, list)):
        return _single_object_shape(value[0]) if len(value) > 0 else None
    return _single_object_shape(value)


def calculate_loss_iteration_flops(
    model: torch.nn.Module,
    loss_fn: Callable,
    input_data: Any,
    target_data: Any = None,
    verbose: bool = True
) -> Dict[str, Any]:
    """
    Estimate the FLOPs of one loss iteration (forward + backward pass).

    The estimate has three parts:
      * forward FLOPs, summed over every leaf module via forward hooks and
        ``calculate_flops_per_layer``;
      * backward FLOPs, approximated as 2x the forward FLOPs;
      * loss FLOPs, taken as the element count of the returned loss.

    The measurement forward pass runs in eval mode under ``torch.no_grad()``.
    Afterwards each module's previous train/eval flag is restored and every
    hook is removed, even if the forward pass or loss raises.

    Args:
        model: PyTorch model.
        loss_fn: loss function called as ``loss_fn(output, target_data)``;
            for tuple/list outputs the first element is used.
        input_data: model input; a tuple/list is unpacked as positional args.
        target_data: loss target; loss FLOPs are skipped when None.
        verbose: whether to print a summary (the first 10 layers in detail).

    Returns:
        dict with ``forward_flops``, ``backward_flops``, ``loss_flops``,
        ``total_flops``, ``forward_flops_giga``, ``backward_flops_giga``,
        ``total_flops_giga`` and ``layer_details`` (one dict per hooked call
        with module name, FLOPs and input/output shapes).
    """
    layer_flops = []

    def hook_fn(module, input, output):
        input_shape = _hook_value_shape(input)
        output_shape = _hook_value_shape(output)
        flops = calculate_flops_per_layer(module, input_shape, output_shape)
        layer_flops.append({
            'module': type(module).__name__,
            'flops': flops,
            'input_shape': str(input_shape) if input_shape is not None else 'N/A',
            'output_shape': str(output_shape) if output_shape is not None else 'N/A'
        })
        return None

    # Remember each module's mode so mixed train/eval setups survive model.eval().
    saved_modes = [(module, module.training) for module in model.modules()]
    model.eval()

    # Hook only leaf modules so each operation is counted once.
    hooks = [
        module.register_forward_hook(hook_fn)
        for module in model.modules()
        if len(list(module.children())) == 0
    ]

    loss_flops = 0
    try:
        with torch.no_grad():
            if isinstance(input_data, (tuple, list)):
                output = model(*input_data)
            else:
                output = model(input_data)

            if target_data is not None and loss_fn is not None:
                if isinstance(output, (tuple, list)):
                    loss = loss_fn(output[0], target_data) if len(output) > 0 else None
                else:
                    loss = loss_fn(output, target_data)
                # Loss cost approximated as one op per element of the loss tensor.
                if loss is not None and hasattr(loss, 'shape'):
                    loss_flops = loss.numel()
    finally:
        for hook in hooks:
            hook.remove()
        for module, was_training in saved_modes:
            module.training = was_training

    forward_flops = sum(item['flops'] for item in layer_flops)
    # Backward is typically 2-3x the forward cost; approximated as 2x.
    backward_flops = forward_flops * 2
    total_flops = forward_flops + backward_flops + loss_flops

    result = {
        'forward_flops': forward_flops,
        'backward_flops': backward_flops,
        'loss_flops': loss_flops,
        'total_flops': total_flops,
        'forward_flops_giga': forward_flops / 1e9,
        'backward_flops_giga': backward_flops / 1e9,
        'total_flops_giga': total_flops / 1e9,
        'layer_details': layer_flops
    }

    if verbose:
        print("=" * 60)
        print("一次Loss迭代FLOPs统计")
        print("=" * 60)
        print(f"前向传播FLOPs: {forward_flops:,} ({forward_flops / 1e9:.4f} GFLOPs)")
        print(f"反向传播FLOPs: {backward_flops:,} ({backward_flops / 1e9:.4f} GFLOPs)")
        print(f"Loss计算FLOPs: {loss_flops:,}")
        print(f"总FLOPs: {total_flops:,} ({total_flops / 1e9:.4f} GFLOPs)")
        print("\n各层FLOPs详情:")
        print("-" * 60)
        for i, layer_info in enumerate(layer_flops[:10]):  # show only the first 10 layers
            print(f"  [{i+1}] {layer_info['module']}: {layer_info['flops']:,} FLOPs")
            print(f"      输入形状: {layer_info['input_shape']}")
            print(f"      输出形状: {layer_info['output_shape']}")
        if len(layer_flops) > 10:
            print(f"  ... 还有 {len(layer_flops) - 10} 层未显示")
        print("=" * 60)

    return result
