import sys
sys.path.append(".")

import torch
import numpy as np
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer
from typing import TypedDict, List, Any, Dict, Optional, Union

class Instance(TypedDict):
    """Schema of one fully materialised sample with tensors and bookkeeping fields.

    The keys mirror the dictionaries assembled by ``AudioDataset.__getitem__``
    in this file.
    """

    # Token IDs of the window.
    input_ids: torch.Tensor
    # Attention mask matching ``input_ids``.
    attention_mask: torch.Tensor
    # Training targets; ``AudioDataset`` fills context positions with -100.
    labels: torch.Tensor
    # Index of the sample within its source recording.
    serial_id: int
    # Number of target tokens in the window.
    new_tokens: int
    # Emotion token ID, or -1 for the audio-prediction samples built by ``AudioDataset``.
    emotion_label: int

class SimpleInstance(TypedDict):
    """Lightweight index record describing one window by slice bounds only.

    ``AudioDataset.__getitem__`` slices the shared token tensor as
    ``[start:middle]`` (context) and ``[middle:end]`` (target).
    """

    # Running index of the record inside ``raw_dataset``.
    serial_id: int
    # Number of tokens to predict.
    new_tokens: int
    # First position of the context.
    start: int
    # First position of the target (end of the context).
    middle: int
    # Exclusive end of the target.
    end: int
    # Token ID of the emotion to predict.
    emotion_label: int

class AudioSlowRawDataset:
    """Index one recording's audio-token sequence into emotion-prediction windows.

    For every ``(timestamp, emotion)`` label an audio-end token followed by the
    matching emotion token is spliced into the token sequence at the frame the
    timestamp maps to. ``raw_dataset`` then holds one ``SimpleInstance`` per
    marker pair, expressed as slice bounds into ``input_ids``:

    * training (``is_eval=False``): the context is limited to the
      ``prefix_length`` tokens preceding the audio-end token;
    * evaluation (``is_eval=True``): the context starts at position 0.
    """

    # Token IDs of the emotion markers, keyed by the label strings.
    EMOTION_TO_TOKEN_ID = {
        "hap": 38561,  # <emo_hap>
        "sad": 38562,  # <emo_sad>
        "ang": 38563,  # <emo_ang>
        "neu": 38564,  # <emo_neu>
    }
    # Token ID inserted right before every emotion token.
    AUDIO_END_TOKEN_ID = 38565
    # Value added to the raw features loaded from the npy file to obtain token IDs.
    AUDIO_TOKEN_OFFSET = 32000
    # Frames per second used to convert label timestamps into token positions.
    FRAME_RATE_HZ = 25

    def __init__(
        self,
        feature_path: str,
        emotion_label: List[Any],
        tokenizer: "PreTrainedTokenizer",
        prefix_length: int,
        stride_size: int,
        is_eval: bool = False,
    ):
        """Store the configuration; nothing is loaded until a ``load_*`` call.

        Args:
            feature_path: Path of the ``.npy`` file read by ``load_from_npy``.
            emotion_label: Iterable of ``(timestamp, emotion)`` pairs, where
                ``emotion`` is a key of ``EMOTION_TO_TOKEN_ID`` and
                ``timestamp * FRAME_RATE_HZ`` gives the insertion frame.
            tokenizer: Stored as-is; not used by this class itself.
            prefix_length: Maximum context length (in tokens) before the
                audio-end token for training windows.
            stride_size: Stored as-is; not used by this class itself.
            is_eval: Select evaluation-style windows (context from position 0).
        """
        self.tokenizer = tokenizer
        self.feature_path = feature_path
        self.emotion_label = emotion_label
        self.prefix_length = prefix_length
        self.init_window_size = 256  # Changed from 1024 to 256 as requested
        self.stride_size = stride_size
        self.input_ids: Optional[torch.Tensor] = None
        self.raw_dataset: List[SimpleInstance] = []
        self.is_eval = is_eval

    def load_from_input_ids(self, input_ids: torch.Tensor):
        """Insert the emotion markers into ``input_ids`` and rebuild ``raw_dataset``.

        Args:
            input_ids: Token tensor with a leading batch dimension of size 1
                (it is indexed as ``input_ids[0, pos]`` / ``input_ids[:, a:b]``).

        Raises:
            KeyError: If an in-range label uses an unknown emotion string.
            ValueError: In training mode, if no audio-end/emotion pair exists
                after insertion.
        """
        self.input_ids = input_ids
        self.raw_dataset.clear()

        marked_ids, markers = self._insert_emotion_markers(input_ids)
        self.input_ids = marked_ids

        if self.is_eval:
            self._index_eval_windows(markers)
            return

        self._index_training_windows()
        if not self.raw_dataset:
            raise ValueError("No valid audio end token and emotion token pairs found")

    def _insert_emotion_markers(self, input_ids: torch.Tensor):
        """Return ``(marked_ids, markers)`` with one end/emotion pair per valid label.

        ``markers`` lists ``(audio_end_position, emotion_token_id)`` in sequence
        order, with positions referring to ``marked_ids``.
        """
        total = input_ids.size(-1)
        pieces = []
        markers = []
        consumed = 0
        # Process labels in time order so each pair lands at
        # ``frame + 2 * (number of earlier pairs)`` even when the caller's
        # labels are not pre-sorted. ``sorted`` is stable, so ties keep their order.
        for timestamp, emotion in sorted(self.emotion_label, key=lambda pair: pair[0]):
            frame = int(timestamp * self.FRAME_RATE_HZ)
            # Labels that fall outside the recording are ignored.
            if frame < 0 or frame >= total:
                continue
            emotion_token_id = self.EMOTION_TO_TOKEN_ID[emotion]
            pieces.append(input_ids[:, consumed:frame])
            # ``new_tensor`` keeps the dtype and device of ``input_ids``.
            pieces.append(input_ids.new_tensor([[self.AUDIO_END_TOKEN_ID, emotion_token_id]]))
            markers.append((frame + 2 * len(markers), emotion_token_id))
            consumed = frame

        if not markers:
            return input_ids, markers

        pieces.append(input_ids[:, consumed:])
        # A single concatenation avoids re-copying the sequence per label.
        return torch.cat(pieces, dim=-1), markers

    def _index_training_windows(self):
        """Add one window per audio-end token found in ``self.input_ids``."""
        total = self.input_ids.size(-1)
        end_positions = torch.nonzero(
            self.input_ids[0] == self.AUDIO_END_TOKEN_ID
        ).flatten().tolist()

        next_free = 0
        for end_pos in end_positions:
            # A hit in the slot right after an accepted end token is that
            # pair's label slot, not a new marker.
            if end_pos < next_free:
                continue
            # An end token with nothing after it has no label to predict.
            if end_pos + 1 >= total:
                break
            entry: SimpleInstance = {
                "serial_id": len(self.raw_dataset),
                "new_tokens": 1,  # Only predict emotion token
                "start": max(0, end_pos - self.prefix_length),
                "middle": end_pos + 1,  # Position of emotion token
                "end": end_pos + 2,  # Include emotion token
                # Use actual token ID as label
                "emotion_label": self.input_ids[0, end_pos + 1].item(),
            }
            self.raw_dataset.append(entry)
            next_free = end_pos + 2

    def _index_eval_windows(self, markers):
        """Add one window per inserted marker, with the context starting at 0."""
        for end_pos, emotion_token_id in markers:
            entry: SimpleInstance = {
                "serial_id": len(self.raw_dataset),
                "new_tokens": 1,  # Only predict emotion token
                "start": 0,
                "middle": end_pos + 1,  # Position of emotion token
                "end": end_pos + 2,  # Sequence end
                "emotion_label": emotion_token_id,
            }
            self.raw_dataset.append(entry)

    def load_from_npy(self):
        """Load audio features from npy file and convert to input_ids"""
        features = np.load(self.feature_path)
        # Widen integer features first: adding the offset in a narrow dtype
        # such as int16 would overflow silently.
        if np.issubdtype(features.dtype, np.integer):
            features = features.astype(np.int64)
        # Convert numpy array to tensor, add batch dimension, and apply the offset
        input_ids = torch.from_numpy(features).unsqueeze(0) + self.AUDIO_TOKEN_OFFSET
        self.load_from_input_ids(input_ids)

class AudioDataset(Dataset):
    """Torch ``Dataset`` view over the windows indexed by an ``AudioSlowRawDataset``.

    Each item is built lazily by slicing ``dataset.input_ids`` with the bounds
    stored in ``dataset.raw_dataset``.
    """

    # Used when the wrapped dataset does not expose ``init_window_size``.
    _DEFAULT_AUDIO_WINDOW = 256

    def __init__(self, dataset: "AudioSlowRawDataset", is_eval: bool = False):
        """
        Args:
            dataset: Object providing ``input_ids`` (token tensor with a
                leading batch dimension) and ``raw_dataset`` (window records).
            is_eval: When true, ``__getitem__`` returns only the
                emotion-prediction sample.
        """
        self.dataset = dataset
        self.is_eval = is_eval

    def __len__(self) -> int:
        """Return the number of indexed windows."""
        return len(self.dataset.raw_dataset)

    def _build_window(self, start: int, split: int, stop: int) -> Dict[str, Any]:
        """Build tensors for a window whose ``[start:split]`` part is context.

        Context positions get the label -100; ``[split:stop]`` is the target
        and keeps its token IDs as labels.
        """
        token_ids = self.dataset.input_ids
        context = token_ids[:, start:split].long()
        target = token_ids[:, split:stop].long()
        window_ids = torch.cat([context, target], dim=-1)
        masked_context = torch.full_like(context, -100, dtype=torch.long)
        return {
            "input_ids": window_ids,
            "attention_mask": torch.ones_like(window_ids).long(),
            "labels": torch.cat([masked_context, target], dim=-1),
        }

    def __getitem__(self, idx: int) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """Return the sample(s) for window ``idx``.

        Returns:
            In evaluation mode, the emotion-prediction dictionary. Otherwise a
            two-element list ``[emotion_dict, audio_dict]`` where ``audio_dict``
            is an audio-token prediction sample over the same window with the
            trailing audio-end and emotion tokens removed.
        """
        ins: SimpleInstance = self.dataset.raw_dataset[idx]
        start, middle, end = ins["start"], ins["middle"], ins["end"]

        # First dictionary: emotion prediction
        emotion_dict = self._build_window(start, middle, end)
        emotion_dict["serial_id"] = ins["serial_id"]
        emotion_dict["new_tokens"] = ins["new_tokens"]
        emotion_dict["emotion_label"] = ins["emotion_label"]

        if self.is_eval:
            return emotion_dict  # Test set only returns emotion prediction dictionary

        # Second dictionary: audio token self-regression prediction.
        # The split point lies one window (counted from ``end``) back, but
        # never before ``start``. The window size follows the wrapped dataset
        # so the two classes cannot drift apart.
        window = getattr(self.dataset, "init_window_size", self._DEFAULT_AUDIO_WINDOW)
        audio_stop = end - 2  # Exclude audio end token and emotion token
        # Clamp so a window smaller than 2 cannot pull the markers back in.
        audio_split = min(max(start, end - window), audio_stop)

        audio_dict = self._build_window(start, audio_split, audio_stop)
        audio_dict["serial_id"] = ins["serial_id"]
        # Number of audio tokens to predict
        audio_dict["new_tokens"] = audio_dict["input_ids"].size(-1) - max(0, audio_split - start)
        audio_dict["emotion_label"] = -1  # Indicates this is an audio prediction task

        return [emotion_dict, audio_dict]