# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0

import logging
import os
import sys
import datetime
import pytz
import torch
import torch.distributed as dist


def create_logger(logging_dir, rank, filename="log"):
    """
    Create the module logger; only rank 0 with a logging directory gets real output.

    For ``rank == 0`` with a non-None ``logging_dir`` the root logger is configured
    through ``logging.basicConfig`` at INFO level with two handlers: a
    ``StreamHandler`` (which writes to ``sys.stderr`` by default) and a
    ``FileHandler`` on ``<logging_dir>/<filename>.txt``. The directory is created
    if it does not exist yet.

    In every other case a "dummy" logger is returned: the same module logger with
    a single ``NullHandler`` attached.

    Args:
        logging_dir: Directory for the log file, or None to disable real logging.
        rank: Process rank; only rank 0 configures real handlers.
        filename: Log file name without the ``.txt`` extension.

    Returns:
        The ``logging.Logger`` named after this module.
    """
    logger = logging.getLogger(__name__)
    if rank == 0 and logging_dir is not None:  # real logger
        # FileHandler opens the file immediately and fails on a missing directory.
        os.makedirs(logging_dir, exist_ok=True)
        # NOTE: basicConfig is a no-op when the root logger already has handlers.
        logging.basicConfig(
            level=logging.INFO,
            format='[\033[34m%(asctime)s\033[0m] %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
            handlers=[
                logging.StreamHandler(),
                logging.FileHandler(os.path.join(logging_dir, f"{filename}.txt")),
            ],
        )
    elif not any(isinstance(h, logging.NullHandler) for h in logger.handlers):
        # dummy logger: attach the NullHandler once, not once per call
        logger.addHandler(logging.NullHandler())
    return logger


def get_latest_ckpt(checkpoint_dir):
    """
    Return the path of the highest-numbered step directory in ``checkpoint_dir``.

    Only sub-directories whose names parse as integers are considered; plain
    files and directories with non-numeric names are ignored instead of raising.

    Args:
        checkpoint_dir: Directory that holds one sub-directory per saved step.

    Returns:
        ``os.path.join(checkpoint_dir, <name>)`` for the sub-directory with the
        largest integer name, or None when no such sub-directory exists.

    Raises:
        OSError: Propagated from ``os.listdir`` (e.g. ``checkpoint_dir`` is missing).
    """
    latest_step = None
    latest_name = None
    for name in os.listdir(checkpoint_dir):
        if not os.path.isdir(os.path.join(checkpoint_dir, name)):
            continue
        try:
            step = int(name)
        except ValueError:
            # Not a step directory (e.g. a temp or tooling folder); skip it.
            continue
        # ">=" keeps the later entry on ties, like taking the last item of a stable sort.
        if latest_step is None or step >= latest_step:
            latest_step, latest_name = step, name
    if latest_name is None:
        return None
    return os.path.join(checkpoint_dir, latest_name)


def _change_builtin_print(is_master):
    """
    Replace ``builtins.print`` with a wrapper that only prints on the master process.

    The replacement accepts three extra keyword arguments, all removed before the
    call is forwarded to the original ``print``:
      * ``force``  - print even when ``is_master`` is false.
      * ``clean``  - forward the arguments untouched, without the prefix.
      * ``deeper`` - report the caller's caller (when there is one) in the prefix.

    Without ``clean`` the output is prefixed with an Asia/Shanghai timestamp and
    the last 24 characters of the calling file name plus the calling line number.

    Nothing is patched when the current ``print`` does not have the same type as
    the builtin ``open`` (so an already-installed wrapper is left in place).
    """
    import builtins

    original_print = builtins.print
    if type(original_print) is not type(open):
        return

    def prt(*args, **kwargs):
        # Always strip the wrapper-only options so they never reach the real print.
        force = kwargs.pop('force', False)
        clean = kwargs.pop('clean', False)
        deeper = kwargs.pop('deeper', False)

        if not (is_master or force):
            return
        if clean:
            original_print(*args, **kwargs)
            return

        # Frame of whoever called print(); optionally step one level further out.
        caller = sys._getframe().f_back
        if deeper and caller.f_back is not None:
            caller = caller.f_back
        # Pad to at least 24 chars, then keep only the trailing 24.
        source_file = f'{caller.f_code.co_filename:24s}'[-24:]
        shanghai_now = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai'))
        stamp = shanghai_now.strftime('[%m-%d %H:%M:%S]')
        original_print(f'{stamp} ({source_file}, line{caller.f_lineno:-4d})=>', *args, **kwargs)

    builtins.print = prt


def count_bin_avg_loss(packed_timesteps, mse, device, curr_step, loss_dict):
    """
    Record the average MSE loss per timestep bin in ``loss_dict`` every 10th step.

    Timesteps are bucketed into 10 equal-width bins ([0, 0.1), ..., [0.9, 1.0];
    out-of-range bin indices are clamped into the first/last bin). Per-bin loss
    sums and sample counts are computed locally, summed across ranks when a
    process group is initialized, and turned into per-bin means.

    Every rank takes part in the all-reduce on a logging step, including ranks
    that have no timesteps locally (they contribute zeros). The collective must
    not depend on rank-local data, otherwise ranks can disagree on whether to
    enter it.

    Args:
        packed_timesteps: Tensor of timesteps for the local samples, or None /
            empty when this rank has none.
        mse: Loss tensor aligned with ``packed_timesteps`` along its leading
            dimension(s); when it has more than one dimension the last one is
            averaged. Only read when there are local timesteps.
        device: Device on which the bin statistics are allocated.
        curr_step: Current training step; statistics are computed only when it
            is a multiple of 10.
        loss_dict: Dict that receives the results.

    Returns:
        ``loss_dict`` (the same object). On logging steps where at least one
        sample was seen globally it gains ``"bin_avg_loss"`` (per-bin mean, NaN
        for empty bins) and ``"bin_count"`` (per-bin global sample count), both
        float tensors of length 10 on ``device``.
    """
    # Gate only on the step so that all ranks agree on running the collectives below.
    if curr_step % 10 != 0:
        return loss_dict

    # Define 10 bins: [0, 0.1), [0.1, 0.2), ..., [0.9, 1.0]
    num_bins = 10
    bin_loss_sum = torch.zeros(num_bins, device=device)
    bin_count = torch.zeros(num_bins, device=device)

    if packed_timesteps is not None and len(packed_timesteps) > 0:
        # Bin index per timestep (clamp to handle the edge case of timestep=1.0)
        bin_indices = torch.clamp(
            torch.floor(packed_timesteps * num_bins).long(),
            0, num_bins - 1,
        ).reshape(-1).to(bin_loss_sum.device)

        # Per-sample MSE loss (mean over the last dimension)
        per_sample_mse = mse.detach()
        if per_sample_mse.dim() > 1:
            per_sample_mse = per_sample_mse.mean(dim=-1)
        # One value per timestep, in the accumulator's dtype and device; any
        # remaining trailing dimensions are summed into their sample.
        per_sample_mse = per_sample_mse.to(bin_loss_sum).reshape(bin_indices.numel(), -1).sum(dim=1)

        # membership[b, i] is True when sample i falls into bin b.
        bin_ids = torch.arange(num_bins, device=bin_loss_sum.device)
        membership = bin_indices.unsqueeze(0) == bin_ids.unsqueeze(1)
        bin_count = membership.sum(dim=1).to(bin_count)
        # masked_fill (not multiplication) so a NaN/inf sample only affects its own bin.
        bin_loss_sum = per_sample_mse.expand(num_bins, -1).masked_fill(~membership, 0.0).sum(dim=1)

    # All-reduce to get global statistics (skipped in single-process runs)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(bin_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(bin_count, op=dist.ReduceOp.SUM)

    # No rank contributed a sample: leave loss_dict untouched.
    if not bool((bin_count > 0).any()):
        return loss_dict

    # Global average loss for each bin; empty bins are reported as NaN
    empty_bin_value = torch.full_like(bin_loss_sum, float('nan'))
    bin_avg_loss = torch.where(bin_count > 0, bin_loss_sum / bin_count, empty_bin_value)

    # Store bin statistics in loss_dict for logging
    loss_dict["bin_avg_loss"] = bin_avg_loss
    loss_dict["bin_count"] = bin_count
    return loss_dict