import logging
import torch
from dcase_util.data import DecisionEncoder

def tag2_framelevel_multihot(tag_strings, label2id, start_times, end_times, max_feature_len,
                            downsample_factor=4, hop_size_sec=0.01, center_offset_frames=2.0):
    """
    Convert per-sample tags with onset/offset times to a frame-level multi-hot target.

    Args:
        tag_strings: List with one ";"-separated tag string per sample.
        label2id: Dictionary mapping tags to class indices.
        start_times: Per-sample sequences of onset times in seconds, aligned
            with the tags of the corresponding tag string.
        end_times: Per-sample sequences of offset times in seconds, aligned
            with the tags of the corresponding tag string.
        max_feature_len: Number of input feature frames; the output has
            ``max_feature_len // downsample_factor`` frames.
        downsample_factor: Downsampling factor of the model.
        hop_size_sec: Hop size of the input features in seconds.
        center_offset_frames: Center offset (in input frames) forwarded to
            ``time_to_frame_index``.

    Returns:
        A float32 tensor of shape (num_samples, num_output_frames, num_classes)
        where entry [i, t, c] is 1.0 when class c is active in sample i at
        output frame t, and 0.0 otherwise.

    Raises:
        KeyError: If a tag is not present in ``label2id``.
    """
    num_classes = len(label2id)
    num_samples = len(tag_strings)
    # Number of encoder output frames after temporal downsampling
    # (assumes max_feature_len is measured in input feature frames).
    num_output_frames = max_feature_len // downsample_factor

    frame_level_target = torch.zeros(
        (num_samples, num_output_frames, num_classes),
        dtype=torch.float32,
    )

    for i, (tag_str, start_t, end_t) in enumerate(zip(tag_strings, start_times, end_times)):
        tags = tag_str.split(";")
        # zip() stops at the shortest sequence, so a sample with no onset/offset
        # times contributes nothing even when tag_str is "".
        for tag, s, e in zip(tags, start_t, end_t):
            class_idx = int(label2id[tag])
            start_frame = time_to_frame_index(s, downsample_factor, hop_size_sec, center_offset_frames)
            end_frame = time_to_frame_index(e, downsample_factor, hop_size_sec, center_offset_frames)
            # Clip the inclusive [start_frame, end_frame] range to valid frames.
            start_frame = max(0, start_frame)
            end_frame = min(num_output_frames - 1, end_frame)
            if end_frame < start_frame:
                # The event falls entirely outside the output window (or is empty).
                # Without this guard a negative end_frame would produce a
                # wrap-around slice that marks almost every frame as active.
                continue
            frame_level_target[i, start_frame:end_frame + 1, class_idx] = 1.0

    return frame_level_target

def decode_intervals(predictions, id2label, downsample_factor=4, hop_size_sec=0.01,
                     center_offset_frames=2.0):
    """
    Decode sound-event intervals from frame-level binary predictions.

    Args:
        predictions: Binary predictions of shape (T, C), where T is the number
            of frames and C the number of classes. A ``torch.Tensor`` is
            detached, moved to CPU and converted to NumPy before decoding.
        id2label: Mapping from class index (column of ``predictions``) to label.
        downsample_factor: Forwarded to ``frame_index_to_time``.
        hop_size_sec: Forwarded to ``frame_index_to_time``.
        center_offset_frames: Forwarded to ``frame_index_to_time``.

    Returns:
        List of ``[label, onset_sec, offset_sec]`` entries, grouped by class in
        column order. Empty when ``predictions`` has no frames.
    """
    if isinstance(predictions, torch.Tensor):
        # The contiguous-region finder works on NumPy arrays; on a torch.Tensor
        # ``.size`` is a method rather than an int, so convert up front.
        predictions = predictions.detach().cpu().numpy()

    result_labels = []
    if len(predictions) == 0:
        # No frames: nothing to decode (and indexing an empty column would fail).
        return result_labels

    decision_encoder = DecisionEncoder()
    for class_idx, label_column in enumerate(predictions.T):
        # Each row describes one contiguous active region of this class.
        regions = decision_encoder.find_contiguous_regions(label_column)
        for row in regions:
            # NOTE(review): dcase_util appears to report the offset as the first
            # frame *after* the region (exclusive end); confirm before relying on it.
            start_frame, end_frame = row[0], row[-1]
            result_labels.append([
                id2label[class_idx],
                frame_index_to_time(start_frame, downsample_factor, hop_size_sec, center_offset_frames),
                frame_index_to_time(end_frame, downsample_factor, hop_size_sec, center_offset_frames),
            ])

    return result_labels

def time_to_frame_index(
    time_sec: float,
    downsample_factor: int = 4,
    hop_size_sec: float = 0.01,
    center_offset_frames: float = 2.0
) -> int:
    """
    Convert a timestamp in seconds to the closest Zipformer output frame index.

    The timestamp is shifted back by ``center_offset_frames`` input hops and
    divided by the length of one output frame
    (``downsample_factor * hop_size_sec``). The quotient is rounded with
    Python's built-in ``round``, which resolves exact .5 ties to the even integer.

    Args:
        time_sec: Timestamp in seconds.
        downsample_factor: Total downsampling factor of the encoder (e.g., 4).
        hop_size_sec: Input frame hop size in seconds (usually 0.01s).
        center_offset_frames: Effective center offset, in input frames, caused
            by asymmetric truncation.

    Returns:
        Nearest output frame index as an int (may be negative for very early times).
    """
    offset_sec = center_offset_frames * hop_size_sec
    output_frame_sec = downsample_factor * hop_size_sec
    return int(round((time_sec - offset_sec) / output_frame_sec))

def frame_index_to_time(
    frame_idx: int,
    downsample_factor: int = 4,
    hop_size_sec: float = 0.01,
    center_offset_frames: float = 2.0
) -> float:
    """
    Convert a Zipformer output frame index to the time of that frame's center.

    Output frame ``k`` starts at ``k * downsample_factor * hop_size_sec``.
    Its center is then shifted forward by ``center_offset_frames`` input hops.

    Args:
        frame_idx: Output frame index.
        downsample_factor: Total downsampling factor of the encoder (e.g., 4).
        hop_size_sec: Input frame hop size in seconds (usually 0.01s).
        center_offset_frames: Effective center offset, in input frames, caused
            by asymmetric truncation.

    Returns:
        Center time of the frame in seconds.
    """
    output_frame_sec = downsample_factor * hop_size_sec
    offset_sec = center_offset_frames * hop_size_sec
    return frame_idx * output_frame_sec + offset_sec

def get_timestamps(frame_indices, downsample_factor=4, hop_size_sec=0.01, center_offset_frames=2.0):
    """
    Map output frame indices to their center times in seconds.

    This is the element-wise counterpart of ``frame_index_to_time``, using the
    same formula. It works on scalars and on array-likes that support ``*`` and
    ``+`` with floats (e.g. torch tensors or NumPy arrays).

    Args:
        frame_indices: Output frame index or array-like of indices.
        downsample_factor: Total downsampling factor of the encoder.
        hop_size_sec: Input frame hop size in seconds.
        center_offset_frames: Center offset in input frames. Pass 0.0 to obtain
            frame start times without the offset.

    Returns:
        Time(s) in seconds, with the same container type as ``frame_indices``.
    """
    frame_duration_sec = downsample_factor * hop_size_sec
    # Previously ignored; applied here so results agree with frame_index_to_time.
    center_offset_sec = center_offset_frames * hop_size_sec
    return frame_indices * frame_duration_sec + center_offset_sec