import copy
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Literal, List

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import einsum, rearrange
from jaxtyping import Float, Int
from transformers import T5PreTrainedModel
from transformers.models.t5.modeling_t5 import ACT2FN, T5LayerNorm, T5Stack
from .modeling_t5_instance_rope import ACT2FN as ACT2FN_rope, T5LayerNorm as T5LayerNorm_rope, T5Stack as T5Stack_rope
from transformers.utils import ModelOutput

from .utils import InstanceNorm, Patch, get_log_decay_weights
from .configuration_kairos import KairosConfig

# Names of position-encoding variants that presumably route to the RoPE-enabled
# T5 stack (imported above as ``T5Stack_rope``).
# NOTE(review): the lookup site is outside this chunk — confirm.
ROPE_VARIANTS = ["instance_wise_rope"]

def size_to_mask(
        max_size: int,
        sizes: torch.Tensor,
):
    """
    Build a boolean "valid prefix" mask from per-item lengths.

    Args:
        max_size: length of the last (mask) dimension.
        sizes: tensor of lengths, any shape.

    Returns:
        Bool tensor of shape ``sizes.shape + (max_size,)`` that is True exactly
        where the position index is strictly smaller than the matching size.
    """
    positions = torch.arange(max_size, device=sizes.device)
    limits = sizes.unsqueeze(-1)
    return positions < limits


@dataclass
class KairosOutput(ModelOutput):
    """
    Output container for the Kairos model.

    Subclasses ``transformers.utils.ModelOutput``; every field is optional and
    defaults to ``None``. Field order matters for ``ModelOutput`` tuple conversion,
    so do not reorder the fields.
    """
    # NOTE(review): the field meanings below are inferred from their names only —
    # confirm against the model's forward().
    loss: Optional[torch.Tensor] = None  # presumably the training loss, when targets are given
    prediction_outputs: Optional[torch.Tensor] = None  # presumably the model's predictions
    attentions: Optional[torch.Tensor] = None  # presumably attention weights, if requested
    cross_attentions: Optional[torch.Tensor] = None  # presumably decoder cross-attention weights
    future_target: Optional[torch.Tensor] = None  # presumably the ground-truth future values


class DynamicPatch(nn.Module):
    def __init__(
            self,
            max_patch_size: int,
            patch_stride: int,
            levels: int,
            n_null_experts: int = 0,
            n_activated_experts: int = 1,
            moe_inter_dim: int = 1408,
            update_bias_rate: float = 0.001,
            target_dist: list = None,
            route_scale: float = 1.0,
    ):
        """
        Build the fixed-size patcher and the MoE router that picks a split depth per patch.

        Args:
            max_patch_size: width of the initial patches; also the router's input dim.
            patch_stride: stride passed to ``Patch``.
            levels: number of granularity levels; the router gets one "real" expert per level.
            n_null_experts: extra router experts that do not correspond to a level.
            n_activated_experts: experts selected per patch.
            moe_inter_dim, update_bias_rate, target_dist, route_scale: forwarded to ``ModelArgs``.
        """
        super().__init__()
        self.max_patch_size = max_patch_size
        self.patch_stride = patch_stride
        self.levels = levels
        self.patch = Patch(max_patch_size, patch_stride)

        # Router configuration: real experts map to levels, null experts are extra choices.
        router_config = {
            "dim": max_patch_size,
            "n_real_experts": levels,
            "n_null_experts": n_null_experts,
            "n_routed_experts": levels + n_null_experts,
            "n_activated_experts": n_activated_experts,
            "moe_inter_dim": moe_inter_dim,
            "update_bias_rate": update_bias_rate,
            "target_dist": target_dist,
            "route_scale": route_scale,
        }
        moe_args = ModelArgs()
        for field_name, field_value in router_config.items():
            setattr(moe_args, field_name, field_value)
        self.moe = MoE(moe_args)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        input:
            x: [1, 8] = [[ 1.5550, -0.0704,  0.4447, -0.4042,  0.0302, -0.1819,  0.0508,  0.4408]]
            mask: [1, 8] = [[1., 1., 1., 1., 1., 1., 1., 1.]]
        output:
            patched_x: [1, 3, 4] = [[[ 1.5550, -0.0704,  0.0000,  0.0000],
                                [ 0.4447, -0.4042,  0.0000,  0.0000],
                                [ 0.0302, -0.1819,  0.0508,  0.4408]]]
            patched_mask: [1, 3, 4] = [[[1., 1., 0., 0.],
                                    [1., 1., 0., 0.],
                                    [1., 1., 1., 1.]]]
            size: [1, 3] = [[2, 2, 4]]
        """
        patched_x, patched_mask = self.patch(x), self.patch(mask)
        size = torch.full(patched_x.shape[:-1], self.max_patch_size, dtype=torch.int64, device=x.device)
        patched_x, patched_mask, size, weights, indices, x_final = self._divide_patches_by_moe(patched_x, patched_mask, size)
        return patched_x, patched_mask, size, weights, indices, x_final

    def _divide_patches(
            self, x: torch.Tensor, size: torch.Tensor, to_divide: torch.Tensor, weights: Optional[torch.Tensor] = None,
            expert_indices: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        batch, patch_len, patch_size = x.shape

        # Calculate indices for the new positions
        div_counts = to_divide.sum(dim=1)  # [batch]
        if div_counts.max().item() == 0:
            if weights is not None:
                return x, size, weights, expert_indices
            else:
                return x, size
        new_patch_len: int = patch_len + div_counts.max().item()  # type: ignore

        # Create position mapping for each batch
        batch_idx = torch.arange(batch, device=x.device)[:, None].expand(-1, patch_len)  # [batch, patch_len]

        # Create indices for scattered elements
        base_idx = torch.arange(patch_len, device=x.device)[None, :].expand(batch, -1)  # [batch, patch_len]
        offset = torch.cumsum(to_divide.float(), dim=1).long()  # [batch, patch_len]
        new_positions = base_idx + offset  # [batch, patch_len]

        # Initialize output tensors
        new_x = torch.zeros(batch, new_patch_len, patch_size, device=x.device, dtype=x.dtype)
        new_size = torch.zeros(batch, new_patch_len, device=size.device, dtype=size.dtype)
        if weights is not None:
            new_weights_shape = (batch, new_patch_len) + weights.shape[2:]
            new_weights = torch.zeros(new_weights_shape, dtype=weights.dtype, device=weights.device)
        if expert_indices is not None:
            new_expert_indices_shape = (batch, new_patch_len) + expert_indices.shape[2:]
            new_expert_indices = torch.zeros(new_expert_indices_shape, dtype=expert_indices.dtype,
                                             device=expert_indices.device)

        # Scatter undivided patches
        undivided = ~to_divide
        new_x[batch_idx[undivided], new_positions[undivided]] = x[undivided]
        new_size[batch_idx[undivided], new_positions[undivided]] = size[undivided]
        if weights is not None:
            new_weights[batch_idx[undivided], new_positions[undivided]] = weights[undivided]
        if expert_indices is not None:
            new_expert_indices[batch_idx[undivided], new_positions[undivided]] = expert_indices[undivided]

        # Scatter divided patches
        divided = to_divide
        # Get the sizes for divided patches
        div_sizes = size[divided].div(2, rounding_mode="floor")

        # First half of divided patches
        first_half_idx = torch.arange(patch_size, device=x.device)[None, :] < div_sizes[:, None]
        new_x[batch_idx[divided], new_positions[divided] - 1] = torch.where(
            first_half_idx, x[divided], torch.zeros_like(x[divided])
        )
        new_size[batch_idx[divided], new_positions[divided] - 1] = div_sizes
        if weights is not None:
            new_weights[batch_idx[divided], new_positions[divided] - 1] = weights[divided]
        if expert_indices is not None:
            new_expert_indices[batch_idx[divided], new_positions[divided] - 1] = expert_indices[divided]

        # Second half of divided patches
        second_half_idx = (torch.arange(patch_size, device=x.device)[None, :] >= div_sizes[:, None]) & (
            torch.arange(patch_size, device=x.device)[None, :] < size[divided][:, None]
        )
        second_half_values = torch.where(second_half_idx, x[divided], torch.zeros_like(x[divided]))
        new_x[batch_idx[divided], new_positions[divided]] = torch.roll(
            second_half_values, -div_sizes.max().item(), dims=1
        )
        new_size[batch_idx[divided], new_positions[divided]] = size[divided] - div_sizes
        if weights is not None:
            new_weights[batch_idx[divided], new_positions[divided]] = weights[divided]
            if expert_indices is not None:
                new_expert_indices[batch_idx[divided], new_positions[divided]] = expert_indices[divided]
                return new_x, new_size, new_weights, new_expert_indices
            return new_x, new_size, new_weights, None
        else:
            return new_x, new_size

    def _divide_patches_by_moe(
            self, x: torch.Tensor, mask: torch.Tensor, size: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        input:
            x: [1, 2, 4] = [[[1.5550, -0.0704,  0.4447, -0.4042], [0.0302, -0.1819,  0.0508,  0.4408]]]
            mask: [1, 2, 4] = [[[1., 1., 1., 1.], [1., 1., 1., 1.]]]
            size: [1, 2] = [[4, 4]]
            weights: [1, 2, 2] = [[[0,6, 0.3], [0.5, 0.3]]]
            indices: [1, 2, 2] = [[[1, 0], [1, 2]]]
        output:
            new_x: [1, 6, 4] = [[[ 1.5550, -0.0704,  0.0000,  0.0000],
                                [ 0.4447, -0.4042,  0.0000,  0.0000],
                                [ 0.0302, 0.0000,  0.0000,  0.0000],
                                [ -0.1819, 0.0000,  0.0000,  0.0000],
                                [ 0.0508, 0.0000,  0.0000,  0.0000],
                                [ 0.4408, 0.0000,  0.0000,  0.0000]]]
            new_mask: [1, 6, 4] = [[[1., 1., 0., 0.],
                                    [1., 1., 0., 0.],
                                    [1., 0., 0., 0.],
                                    [1., 0., 0., 0.],
                                    [1., 0., 0., 0.],
                                    [1., 0., 0., 0.]]]
            new_size: [1, 6] = [[2, 2, 1, 1, 1, 1]]
            weights: [1, 6, 2] = [[[0,6, 0.3], [0,6, 0.3], [0.5, 0.3], [0.5, 0.3], [0.5, 0.3], [0.5, 0.3]]]
            indices: [1, 6, 2] = [[[1, 0], [1, 0], [0, 1], [0, 1], [0, 1], [0, 1]]]
            parent_blocks: [1, 6, 4] = [[[ 1.5550, -0.0704,  0.4447, -0.4042],
                                        [ 1.5550, -0.0704,  0.4447, -0.4042],
                                        [ 0.0302, -0.1819,  0.0508,  0.4408],
                                        [ 0.0302, -0.1819,  0.0508,  0.4408],
                                        [ 0.0302, -0.1819,  0.0508,  0.4408],
                                        [ 0.0302, -0.1819,  0.0508,  0.4408]]]
            parent_blocks_mask: [1, 6, 4] = [[[1., 1., 1., 1.],
                        [1., 1., 1., 1.],
                        [1., 1., 1., 1.],
                        [1., 1., 1., 1.],
                        [1., 1., 1., 1.],
                        [1., 1., 1., 1.]]]
            granularity_mask: [1, 6, 3, 4] = [[[[1., 1., 1., 1.], [1., 1., 0., 0.], [0., 0., 0., 0.]],
                        [[1., 1., 1., 1.], [0., 0., 1., 1.], [0., 0., 0., 0.]],
                        [[0., 0., 0., 0.], [1., 1., 0., 0.], [1., 0., 0., 0.]],
                        [[0., 0., 0., 0.], [1., 1., 0., 0.], [0., 1., 0., 0.]],
                        [[0., 0., 0., 0.], [0., 0., 1., 1.], [0., 0., 1., 0.]],
                        [[0., 0., 0., 0.], [0., 0., 1., 1.], [0., 0., 0., 1.]]]]
        """
        batch, patch_num, patch_len = x.shape
        weights, indices = self.moe(x)
        expert_indices = indices.view(batch, patch_num, -1)  # [batch, patch_num, n_experts]
        weights = weights.view(batch, patch_num, -1)  # [batch, patch_num, n_experts]
        n_real_experts = self.moe.n_real_experts 
        is_real_expert_mask = (expert_indices < n_real_experts)
        masked_indices = torch.where(is_real_expert_mask, expert_indices, -1)
        indices, _ = torch.max(masked_indices, dim=-1) # [batch, patch_num]
        indices[indices == -1] = 0

        # Create initial parent mapping and blocks
        original_patches, original_mask, parent_mapping = self._create_initial_setup(x, mask, size)

        # Save original expert indices for granularity mask generation
        original_expert_indices = expert_indices

        history_x: List[torch.Tensor] = []
        history_mask: List[torch.Tensor] = []
        history_parent_mapping: List[torch.Tensor] = []
        history_position_mapping: List[torch.Tensor] = []

        # Initialize position mapping for tracking patch positions within original patches
        position_mapping = self._create_initial_position_mapping(x, size)

        history_x.append(x)
        history_mask.append(mask)
        history_parent_mapping.append(parent_mapping)
        history_position_mapping.append(position_mapping)

        # Apply division to both x and mask
        for i in range(self.levels - 1):
            current_patch_num = x.size(1)
            to_divide = (indices > i).to(x.device)
            div_counts = to_divide.sum(dim=1)  # [B]

            new_x, new_size, weights, expert_indices = self._divide_patches(x, size, to_divide, weights, expert_indices)
            new_mask, _ = self._divide_patches(mask, size, to_divide)

            # Update parent mapping and position mapping
            new_parent_mapping = self._update_parent_mapping(parent_mapping, to_divide, div_counts, x.device)
            new_position_mapping = self._update_position_mapping(position_mapping, to_divide, div_counts, x.device)

            history_x.append(new_x)
            history_mask.append(new_mask)
            history_parent_mapping.append(new_parent_mapping)
            history_position_mapping.append(new_position_mapping)

            new_patch_nums = current_patch_num + div_counts
            total_elements = new_patch_nums.sum().item()

            expand_mask = torch.stack([
                torch.ones_like(to_divide),
                to_divide
            ], dim=-1).view(batch, -1)  # [B, 2L]

            expanded_indices = indices.unsqueeze(-1).expand(-1, -1, 2).reshape(batch, -1)
            valid_indices = expanded_indices[expand_mask]

            assert expand_mask.sum().item() == total_elements, \
                f"Mask sum {expand_mask.sum().item()} vs total {total_elements}"
            assert valid_indices.numel() == total_elements, \
                f"Indices {valid_indices.numel()} vs total {total_elements}"

            max_len = new_patch_nums.max()
            indices = torch.full((batch, max_len), -1, dtype=torch.long, device=x.device)

            valid_pos_mask = (torch.arange(max_len, device=x.device)[None, :] < new_patch_nums[:, None])

            indices[valid_pos_mask] = valid_indices

            x, mask, size = new_x, new_mask, new_size
            parent_mapping = new_parent_mapping
            position_mapping = new_position_mapping

        parent_blocks = self._map_to_parent_blocks(original_patches, parent_mapping, x.shape)
        parent_blocks_mask = self._map_to_parent_blocks(original_mask, parent_mapping, mask.shape)
        granularity_mask = self._create_granularity_mask(original_expert_indices, parent_mapping, position_mapping,
                                                         x.shape, patch_len)
        # Generate x_final with rearranged features
        x_final = self._generate_x_final(parent_blocks, parent_blocks_mask, granularity_mask)

        return new_x, new_mask, new_size, weights, expert_indices, x_final

    def _create_initial_setup(self, x: torch.Tensor, mask: torch.Tensor, size: torch.Tensor) -> Tuple[
        torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Create initial original patches, masks and parent mapping.
        Assumes patch_num == num_original_patches (one patch per original patch).

        Args:
            x: [batch, patch_num, patch_len] - current patches
            mask: [batch, patch_num, patch_len] - current masks
            size: [batch, num_original_patches] - sizes of original patches

        Returns:
            original_patches: [batch, num_original_patches, patch_len] - reconstructed original patches
            original_mask: [batch, num_original_patches, patch_len] - reconstructed original masks
            parent_mapping: [batch, patch_num] - mapping from current patches to original patch indices
        """
        batch, patch_num, patch_len = x.shape
        num_original_patches = size.shape[1]

        # Verify expected structure
        assert patch_num == num_original_patches, f"patch_num {patch_num} != num_original_patches {num_original_patches}"

        # x and mask already represent one patch per original patch
        original_patches = x.clone()
        original_mask = mask.clone()

        # Create simple parent mapping: each patch maps to itself
        parent_mapping = torch.arange(patch_num, device=x.device).unsqueeze(0).expand(batch, -1)

        return original_patches, original_mask, parent_mapping

    def _create_initial_position_mapping(self, x: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
        """
        Create initial position mapping for tracking patch positions within original patches.

        Args:
            x: [batch, patch_num, patch_len] - current patches
            size: [batch, num_original_patches] - sizes of original patches

        Returns:
            position_mapping: [batch, patch_num, 2] - [start_pos, end_pos] within original patch
        """
        batch, patch_num, patch_len = x.shape

        # Initially, each patch covers the entire original patch space
        # position_mapping stores [start_pos, end_pos] within the original patch
        position_mapping = torch.zeros((batch, patch_num, 2), dtype=torch.long, device=x.device)

        # Each patch initially spans the full range [0, patch_len)
        position_mapping[:, :, 0] = 0  # start_pos = 0
        position_mapping[:, :, 1] = patch_len  # end_pos = patch_len

        return position_mapping

    def _update_parent_mapping(self, parent_mapping: torch.Tensor, to_divide: torch.Tensor,
                               div_counts: torch.Tensor, device: torch.device) -> torch.Tensor:
        """
        Update parent mapping for divided patches.
        """
        batch, current_patch_num = parent_mapping.shape
        new_patch_nums = current_patch_num + div_counts
        max_new_patch_num = new_patch_nums.max()

        # Create repeat counts: 1 for non-divided patches, 2 for divided patches
        repeat_counts = 1 + to_divide.long()  # [batch, current_patch_num]

        # Initialize output tensor
        new_parent_mapping = torch.full((batch, max_new_patch_num), -1, dtype=torch.long, device=device)

        # Calculate positions for each element
        positions = torch.cumsum(repeat_counts, dim=1) - repeat_counts  # Starting position for each element
        batch_idx = torch.arange(batch, device=device).unsqueeze(1).expand(-1,
                                                                           current_patch_num)  # [batch, current_patch_num]

        # Place first copy (always exists)
        valid_mask_1 = positions < max_new_patch_num
        new_parent_mapping[batch_idx[valid_mask_1], positions[valid_mask_1]] = parent_mapping[valid_mask_1]

        # Place second copy (only for divided patches)
        second_positions = positions + 1
        valid_mask_2 = to_divide.bool() & (second_positions < max_new_patch_num)
        new_parent_mapping[batch_idx[valid_mask_2], second_positions[valid_mask_2]] = parent_mapping[valid_mask_2]

        return new_parent_mapping

    def _update_position_mapping(self, position_mapping: torch.Tensor, to_divide: torch.Tensor,
                                 div_counts: torch.Tensor, device: torch.device) -> torch.Tensor:
        """
        Update position mapping for divided patches.
        """
        batch, current_patch_num, _ = position_mapping.shape
        new_patch_nums = current_patch_num + div_counts
        max_new_patch_num = new_patch_nums.max()

        # Create repeat counts: 1 for non-divided patches, 2 for divided patches
        repeat_counts = 1 + to_divide.long()  # [batch, current_patch_num]

        # Initialize output tensor
        new_position_mapping = torch.zeros((batch, max_new_patch_num, 2), dtype=torch.long, device=device)

        # Calculate positions for each element
        positions = torch.cumsum(repeat_counts, dim=1) - repeat_counts  # Starting position for each element
        batch_idx = torch.arange(batch, device=device).unsqueeze(1).expand(-1,
                                                                           current_patch_num)  # [batch, current_patch_num]

        # For non-divided patches: keep original position mapping
        valid_mask_1 = positions < max_new_patch_num
        new_position_mapping[batch_idx[valid_mask_1], positions[valid_mask_1]] = position_mapping[valid_mask_1]

        # For divided patches: split the position range
        second_positions = positions + 1
        valid_mask_2 = to_divide.bool() & (second_positions < max_new_patch_num)

        if valid_mask_2.any():
            # Get original start and end positions - correct indexing
            divided_position_mapping = position_mapping[valid_mask_2]  # [num_divided, 2]
            orig_start = divided_position_mapping[:, 0]  # [num_divided]
            orig_end = divided_position_mapping[:, 1]  # [num_divided]
            mid_pos = (orig_start + orig_end) // 2

            # Get batch and position indices for divided patches
            batch_indices_divided = batch_idx[valid_mask_2]  # [num_divided]
            first_positions_divided = positions[valid_mask_2]  # [num_divided]
            second_positions_divided = second_positions[valid_mask_2]  # [num_divided]

            # First half: [start, mid]
            new_position_mapping[batch_indices_divided, first_positions_divided, 0] = orig_start
            new_position_mapping[batch_indices_divided, first_positions_divided, 1] = mid_pos

            # Second half: [mid, end]
            new_position_mapping[batch_indices_divided, second_positions_divided, 0] = mid_pos
            new_position_mapping[batch_indices_divided, second_positions_divided, 1] = orig_end

        return new_position_mapping

    def _map_to_parent_blocks(self, original_patches: torch.Tensor, parent_mapping: torch.Tensor,
                              target_shape: Tuple[int, int, int]) -> torch.Tensor:
        """
        Map parent indices to actual parent blocks.

        Args:
            original_patches: [batch, num_original_patches, patch_len] - original patch representations
            parent_mapping: [batch, new_patch_num] - mapping from new patches to original patch indices
            target_shape: (batch, new_patch_num, patch_len) - target shape

        Returns:
            parent_blocks: [batch, new_patch_num, patch_len] - parent blocks for each new patch
        """
        batch, new_patch_num, patch_len = target_shape

        # Create mask for valid indices
        valid_mask = parent_mapping >= 0  # [batch, new_patch_num]

        # Clamp negative indices to 0 to avoid indexing errors
        safe_parent_mapping = torch.clamp(parent_mapping, min=0)  # [batch, new_patch_num]

        # Use advanced indexing to gather parent blocks
        batch_indices = torch.arange(batch, device=original_patches.device).view(-1, 1)  # [batch, 1]
        parent_blocks = original_patches[batch_indices, safe_parent_mapping]  # [batch, new_patch_num, patch_len]

        # Zero out invalid positions using broadcasting
        parent_blocks = parent_blocks * valid_mask.unsqueeze(-1).float()

        return parent_blocks

    def _create_granularity_mask(self, original_expert_indices: torch.Tensor, parent_mapping: torch.Tensor,
                                 position_mapping: torch.Tensor, target_shape: Tuple[int, int, int],
                                 original_patch_len: int) -> torch.Tensor:
        """
        Create granularity mask.

        Args:
            original_expert_indices: [batch, num_original_patches, n_experts] - expert indices for original patches
            parent_mapping: [batch, new_patch_num] - mapping from new patches to original patch indices
            position_mapping: [batch, new_patch_num, 2] - position ranges within original patches
            target_shape: (batch, new_patch_num, patch_len) - target shape
            original_patch_len: length of original patches

        Returns:
            granularity_mask: [batch, new_patch_num, max_granularity, original_patch_len] - granularity masks
        """
        batch, new_patch_num, patch_len = target_shape

        # Determine max granularity from expert indices
        max_granularity = self.levels

        granularity_mask = torch.zeros((batch, new_patch_num, max_granularity, original_patch_len),
                                       dtype=torch.float32, device=parent_mapping.device)

        # Create mask for valid patches
        valid_mask = parent_mapping >= 0  # [batch, new_patch_num]
        safe_parent_mapping = torch.clamp(parent_mapping, min=0)  # [batch, new_patch_num]

        # Get expert indices for each patch's parent
        batch_indices = torch.arange(batch, device=parent_mapping.device).view(-1, 1)  # [batch, 1]
        patch_expert_indices = original_expert_indices[
            batch_indices, safe_parent_mapping]  # [batch, new_patch_num, n_experts]

        # Get position ranges for each patch
        if hasattr(position_mapping, 'shape') and len(position_mapping.shape) == 3:
            start_positions = position_mapping[:, :, 0]  # [batch, new_patch_num]
            end_positions = position_mapping[:, :, 1]  # [batch, new_patch_num]
        else:
            # Fallback: compute positions based on patch structure
            start_positions, end_positions = self._compute_patch_positions(
                parent_mapping, new_patch_num, original_patch_len, batch)

        # For each granularity level, generate masks
        for granularity in range(max_granularity):
            # Check which patches have this granularity level
            has_granularity = (patch_expert_indices == granularity).any(dim=-1) & valid_mask  # [batch, new_patch_num]

            if has_granularity.any():
                # Generate mask pattern based on granularity level
                mask_pattern = self._generate_granularity_pattern(
                    granularity, start_positions, end_positions, original_patch_len,
                    has_granularity, max_granularity)

                granularity_mask[:, :, granularity, :] = mask_pattern

        return granularity_mask

    def _compute_patch_positions(self, parent_mapping: torch.Tensor, new_patch_num: int,
                                 original_patch_len: int, batch: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute patch positions.
        """
        device = parent_mapping.device

        # For simplicity in this implementation, assume equal division within parents
        # This works for the given example structure

        # Initialize positions - fallback to simple equal division
        start_positions = torch.zeros((batch, new_patch_num), dtype=torch.long, device=device)
        end_positions = torch.full((batch, new_patch_num), original_patch_len, dtype=torch.long, device=device)

        # Valid mask
        valid_mask = parent_mapping >= 0

        if valid_mask.any():
            # Create a mask for each possible parent value using broadcasting
            max_parent = parent_mapping.max().item() if parent_mapping.numel() > 0 else 0
            parent_values = torch.arange(max_parent + 1, device=device)  # [max_parent + 1]

            # Broadcasting: [batch, new_patch_num, 1] == [max_parent + 1]
            parent_masks = (parent_mapping.unsqueeze(-1) == parent_values.unsqueeze(0).unsqueeze(
                0))  # [batch, new_patch_num, max_parent + 1]

            # Count patches per parent per batch - [batch, max_parent + 1]
            patches_per_parent = parent_masks.sum(dim=1)

            # Create cumulative indices within each parent group using broadcasting
            # For each parent, create relative positions
            cumsum_masks = torch.cumsum(parent_masks.float(), dim=1)  # [batch, new_patch_num, max_parent + 1]
            relative_indices = (cumsum_masks - 1) * parent_masks.float()  # [batch, new_patch_num, max_parent + 1]

            # Calculate patch sizes per parent - [batch, max_parent + 1]
            patch_sizes = original_patch_len // patches_per_parent.clamp(min=1)

            # Calculate start positions using broadcasting - [batch, new_patch_num, max_parent + 1]
            start_positions_expanded = (relative_indices * patch_sizes.unsqueeze(1)).long()
            end_positions_expanded = start_positions_expanded + patch_sizes.unsqueeze(1)

            # Sum across parent dimension (only one will be non-zero per patch)
            final_start_positions = start_positions_expanded.sum(dim=-1)  # [batch, new_patch_num]
            final_end_positions = end_positions_expanded.sum(dim=-1)  # [batch, new_patch_num]

            # Apply only to valid patches
            start_positions = torch.where(valid_mask, final_start_positions, start_positions)
            end_positions = torch.where(valid_mask, final_end_positions.clamp(max=original_patch_len), end_positions)

        return start_positions, end_positions

    def _generate_granularity_pattern(self, granularity: int, start_positions: torch.Tensor,
                                      end_positions: torch.Tensor, original_patch_len: int,
                                      has_granularity: torch.Tensor, max_granularity: int) -> torch.Tensor:
        """
        Generate granularity pattern.
        """
        batch, new_patch_num = has_granularity.shape
        device = has_granularity.device

        mask_pattern = torch.zeros((batch, new_patch_num, original_patch_len), dtype=torch.float32, device=device)

        if has_granularity.any():
            if granularity == 0:
                # Coarsest granularity: activate entire patch
                mask_pattern[has_granularity] = 1.0

            else:
                # For other granularities, calculate activation length
                activation_length = original_patch_len // (2 ** granularity)
                activation_length = max(1, activation_length)  # Ensure at least 1

                # Get patches that have this granularity
                batch_idx, patch_idx = torch.where(has_granularity)

                if len(batch_idx) > 0:
                    # Get start positions for patches with this granularity
                    start_pos = start_positions[batch_idx, patch_idx]  # [num_patches_with_granularity]

                    # Calculate slot indices for all patches
                    slot_indices = start_pos // activation_length  # [num_patches_with_granularity]

                    # Calculate activation ranges for all patches
                    activation_starts = slot_indices * activation_length  # [num_patches_with_granularity]
                    activation_ends = torch.clamp(activation_starts + activation_length,
                                                  max=original_patch_len)  # [num_patches_with_granularity]

                    # Use broadcasting to create range masks
                    # Create position grid: [1, original_patch_len]
                    pos_grid = torch.arange(original_patch_len, device=device).unsqueeze(0)  # [1, original_patch_len]

                    # Expand activation ranges: [num_patches_with_granularity, 1]
                    starts_expanded = activation_starts.unsqueeze(1)  # [num_patches_with_granularity, 1]
                    ends_expanded = activation_ends.unsqueeze(1)  # [num_patches_with_granularity, 1]

                    # Create range masks using broadcasting: [num_patches_with_granularity, original_patch_len]
                    range_masks = (pos_grid >= starts_expanded) & (pos_grid < ends_expanded)

                    # Apply range masks to corresponding positions in mask_pattern
                    mask_pattern[batch_idx, patch_idx] = range_masks.float()

        return mask_pattern

    def _generate_x_final(self, parent_blocks: torch.Tensor, parent_blocks_mask: torch.Tensor,
                          granularity_mask: torch.Tensor) -> torch.Tensor:
        """
        Generate x_final with rearranged features.

        Args:
            parent_blocks: [batch, patch_num, patch_len]
            parent_blocks_mask: [batch, patch_num, patch_len]
            granularity_mask: [batch, patch_num, max_granularity, patch_len]

        Returns:
            x_final: [max_granularity, batch*patch_num, max_feat_size * 2]
        """
        batch, patch_num, patch_len = parent_blocks.shape
        max_granularity = granularity_mask.shape[2]
        total_patches = batch * patch_num

        # Assume we have predefined feature sizes for each granularity level
        # This should be consistent with self.in_features_ls
        feat_sizes = [patch_len // (2 ** i) for i in range(max_granularity)]
        feat_sizes = [max(1, size) for size in feat_sizes]  # Ensure at least 1
        max_feat_size = max(feat_sizes) if feat_sizes else patch_len

        # Initialize output tensor
        x_final = torch.zeros((max_granularity, total_patches, max_feat_size * 2),
                              dtype=parent_blocks.dtype, device=parent_blocks.device)

        # Flatten spatial dimensions for easier processing
        parent_blocks_flat = parent_blocks.reshape(total_patches, patch_len)  # [total_patches, patch_len]
        parent_blocks_mask_flat = parent_blocks_mask.reshape(total_patches, patch_len)  # [total_patches, patch_len]
        granularity_mask_flat = granularity_mask.reshape(total_patches, max_granularity,
                                                         patch_len)  # [total_patches, max_granularity, patch_len]

        # Process each granularity level
        for granularity in range(max_granularity):
            feat_size = feat_sizes[granularity]

            # Get granularity mask for this level
            current_granularity_mask = granularity_mask_flat[:, granularity, :]  # [total_patches, patch_len]

            # Apply granularity mask to parent blocks and parent blocks mask
            masked_parent_blocks = parent_blocks_flat * current_granularity_mask  # [total_patches, patch_len]
            masked_parent_blocks_mask = parent_blocks_mask_flat * current_granularity_mask  # [total_patches, patch_len]

            # Rearrange each part separately to maintain structure
            rearranged_parent_blocks = self._rearrange_effective_values(
                masked_parent_blocks, current_granularity_mask, feat_size, max_feat_size)

            rearranged_parent_blocks_mask = self._rearrange_effective_values(
                masked_parent_blocks_mask, current_granularity_mask, feat_size, max_feat_size)

            # Combine rearranged parts: [parent_blocks | parent_blocks_mask]
            rearranged_features = torch.cat([rearranged_parent_blocks, rearranged_parent_blocks_mask], dim=-1)

            # Store in x_final
            x_final[granularity] = rearranged_features

        return x_final

    def _rearrange_effective_values(self, features: torch.Tensor, mask: torch.Tensor,
                                    target_feat_size: int, max_feat_size: int) -> torch.Tensor:
        """
        Rearrange effective values to front positions.

        Args:
            features: [total_patches, feat_dim] - input features (single part: either parent_blocks or parent_blocks_mask)
            mask: [total_patches, feat_dim] - mask indicating effective positions
            target_feat_size: actual feature size for this granularity
            max_feat_size: maximum feature size across all granularities (for output shape consistency)

        Returns:
            rearranged: [total_patches, max_feat_size] - rearranged features (single part)
        """
        total_patches, feat_dim = features.shape
        device = features.device

        # Initialize output with max_feat_size to ensure consistent shape
        rearranged = torch.zeros((total_patches, max_feat_size), dtype=features.dtype, device=device)

        # Find effective positions
        effective_mask = mask > 0  # [total_patches, feat_dim]

        # Create position indices for sorting
        position_indices = torch.arange(feat_dim, device=device).unsqueeze(0).expand(total_patches,
                                                                                     -1)  # [total_patches, feat_dim]

        # Sort to get effective positions first - use large value for invalid positions
        masked_positions = torch.where(effective_mask, position_indices.float(),
                                       float('inf'))  # [total_patches, feat_dim]
        sorted_positions, sort_indices = torch.sort(masked_positions, dim=1)  # [total_patches, feat_dim]

        # Count effective values per patch
        effective_counts = effective_mask.sum(dim=1)  # [total_patches]

        # Create output position indices using broadcasting - limit to target_feat_size
        output_indices = torch.arange(target_feat_size, device=device).unsqueeze(0)  # [1, target_feat_size]

        # Determine which output positions should be filled
        take_counts = torch.clamp(effective_counts, max=target_feat_size).unsqueeze(1)  # [total_patches, 1]
        valid_output_mask = output_indices < take_counts  # [total_patches, target_feat_size]

        # Get all valid (batch, output_pos) pairs
        batch_indices, output_positions = torch.where(valid_output_mask)  # [num_valid], [num_valid]

        if len(batch_indices) > 0:
            # Get corresponding source
            source_indices = sort_indices[batch_indices, output_positions]  # [num_valid]

            # Gather source values
            source_values = features[batch_indices, source_indices]  # [num_valid]

            # Place values in output (only in the first target_feat_size positions)
            rearranged[batch_indices, output_positions] = source_values

        return rearranged


class ResidualBlock(nn.Module):
    """
    Two-layer MLP with a linear skip connection and optional T5 layer norm.

    Computes ``dropout(output_layer(act(hidden_layer(x)))) + residual_layer(x)`` and, when
    ``use_layer_norm`` is set, applies a T5 layer norm to the sum. For RoPE variants
    (``position_embedding_type in ROPE_VARIANTS``) the activation table and layer norm come from
    ``modeling_t5_instance_rope`` instead of ``transformers``.
    """

    def __init__(
            self,
            in_dim: int,
            h_dim: int,
            out_dim: int,
            act_fn_name: str,
            dropout_p: float = 0.0,
            use_layer_norm: bool = False,
            position_embedding_type: str = "T5PE",
    ) -> None:
        super().__init__()
        self.position_embedding_type = position_embedding_type
        is_rope = position_embedding_type in ROPE_VARIANTS
        activation_table = ACT2FN_rope if is_rope else ACT2FN

        # Submodules are registered in a fixed order: dropout, hidden, act, output, residual, norm.
        self.dropout = nn.Dropout(dropout_p)
        self.hidden_layer = nn.Linear(in_dim, h_dim)
        self.act = activation_table[act_fn_name]
        self.output_layer = nn.Linear(h_dim, out_dim)
        self.residual_layer = nn.Linear(in_dim, out_dim)

        self.use_layer_norm = use_layer_norm
        if use_layer_norm:
            norm_cls = T5LayerNorm_rope if is_rope else T5LayerNorm
            self.layer_norm = norm_cls(out_dim)

    def forward(self, x: torch.Tensor):
        mlp_out = self.dropout(self.output_layer(self.act(self.hidden_layer(x))))
        summed = mlp_out + self.residual_layer(x)
        return self.layer_norm(summed) if self.use_layer_norm else summed


class MultiInResidualBlock(nn.Module):
    """
    ResidualBlock counterpart whose input projections are ``MultiInSizeLinear`` expert banks.

    The hidden and residual projections both receive the input together with the routing
    arguments (``in_feat_size``, ``expert_weights``, ``expert_indices``, ``x_final``). The result
    is ``dropout(output_layer(act(hidden))) + residual``, optionally followed by a T5 layer norm.
    RoPE variants take the activation table and layer norm from ``modeling_t5_instance_rope``.
    """

    def __init__(
            self,
            in_dim_ls: Tuple[int, ...],
            h_dim: int,
            out_dim: int,
            act_fn_name: str,
            dropout_p: float = 0.0,
            use_layer_norm: bool = False,
            position_embedding_type: str = "T5PE",
    ) -> None:
        super().__init__()
        self.position_embedding_type = position_embedding_type
        is_rope = position_embedding_type in ROPE_VARIANTS
        activation_table = ACT2FN_rope if is_rope else ACT2FN

        # Submodules are registered in a fixed order: dropout, hidden, act, output, residual, norm.
        self.dropout = nn.Dropout(dropout_p)
        self.hidden_layer = MultiInSizeLinear(in_dim_ls, h_dim)
        self.act = activation_table[act_fn_name]
        self.output_layer = nn.Linear(h_dim, out_dim)
        self.residual_layer = MultiInSizeLinear(in_dim_ls, out_dim)

        self.use_layer_norm = use_layer_norm
        if use_layer_norm:
            norm_cls = T5LayerNorm_rope if is_rope else T5LayerNorm
            self.layer_norm = norm_cls(out_dim)

    def forward(self, x: torch.Tensor, in_feat_size: torch.Tensor, expert_weights: Optional[torch.Tensor] = None,
                expert_indices: Optional[torch.Tensor] = None, x_final: Optional[torch.Tensor] = None):
        routing = (in_feat_size, expert_weights, expert_indices, x_final)
        mlp_out = self.dropout(self.output_layer(self.act(self.hidden_layer(x, *routing))))
        summed = mlp_out + self.residual_layer(x, *routing)
        return self.layer_norm(summed) if self.use_layer_norm else summed


class MultiInSizeLinear(nn.Module):
    """
    A bank of linear layers ("experts") sharing one output width but with different input widths.

    Input rows have width ``2 * max(in_features_ls)`` and are laid out as ``[values | mask]``.
    Expert ``i`` is restricted by the non-persistent ``mask`` buffer to the first
    ``in_features_ls[i]`` entries of each half; its remaining weight columns are multiplied by 0.
    """

    def __init__(
            self,
            in_features_ls: Tuple[int, ...],
            out_features: int,
            bias: bool = True,
            dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.in_features_ls = in_features_ls
        self.out_features = out_features

        self.weight = nn.Parameter(
            torch.empty((len(in_features_ls), out_features, max(in_features_ls) * 2), dtype=dtype)
        )

        if bias:
            self.bias = nn.Parameter(torch.empty((len(in_features_ls), out_features), dtype=dtype))
        else:
            self.register_parameter("bias", None)

        # [num_experts, 1, 2*max_feat]: True on the first in_features_ls[i] entries of both the
        # value half and the mask half of an input row.
        half_mask = size_to_mask(max(in_features_ls), torch.as_tensor(in_features_ls))
        self.register_buffer(
            "mask",
            rearrange(
                torch.cat((half_mask, half_mask), dim=-1),
                "num_feats max_feat -> num_feats 1 max_feat",
            ),
            persistent=False,
        )
        self.register_buffer(
            "in_features_buffer",
            torch.tensor(in_features_ls),
            persistent=False,
        )

        # The parameters above come from torch.empty (uninitialised memory). Initialise them here,
        # as nn.Linear does, so the layer is usable standalone; owners may re-initialise later.
        self.reset_parameters()

    def reset_parameters(self):
        """
        Kaiming-uniform init of each expert's first ``feat_size`` weight columns; zero the rest.

        The bias uses the ``nn.Linear`` rule ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))`` with
        ``fan_in = feat_size``.

        NOTE(review): everything from column ``feat_size`` on is zeroed, including the active
        mask-half columns ``[max_feat, max_feat + feat_size)``. Those columns still train because
        ``mask`` keeps them, but confirm that the zero start is intended.
        """
        for idx, feat_size in enumerate(self.in_features_ls):
            nn.init.kaiming_uniform_(self.weight[idx, :, :feat_size], a=math.sqrt(5))
            nn.init.zeros_(self.weight[idx, :, feat_size:])
            if self.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[idx, :, :feat_size])
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                nn.init.uniform_(self.bias[idx], -bound, bound)

    def forward(
            self,
            x: Float[torch.Tensor, "*batch max_feat"],
            in_feat_size: Int[torch.Tensor, "*batch"],
            expert_weights: Optional[torch.Tensor] = None,
            expert_indices: Optional[torch.Tensor] = None,  # [*batch, n_experts]
            x_final: Optional[torch.Tensor] = None,
    ) -> Float[torch.Tensor, "*batch out_feat"]:
        """
        Apply the expert bank.

        Routed mode (``expert_indices`` given):
            Each token goes to the experts listed in ``expert_indices``. Indices
            ``>= len(in_features_ls)`` are null experts and contribute nothing. Expert ``i`` reads
            its input from ``x_final[i]`` (NaNs replaced by 0), presumably of shape
            ``[num_experts, prod(*batch), 2*max_feat]`` (TODO confirm with the caller).
            Each expert output is scaled by the token's routing weight, renormalised over the
            real experts the token selected. ``x`` only supplies the batch shape, dtype and device.

        Size-selected mode (otherwise):
            Each token uses the expert whose input size equals ``in_feat_size``. The result is
            optionally scaled by ``expert_weights``.
        """
        if expert_indices is not None:
            return self._routed_forward(x, expert_weights, expert_indices, x_final)

        # Size-selected mode: the boolean factor zeroes every expert except the matching one.
        out = torch.tensor(0)
        for idx, feat_size in enumerate(self.in_features_ls):
            weight = self.weight[idx] * self.mask[idx]  # [out_feat, 2*max_feat]
            bias = self.bias[idx] if self.bias is not None else 0
            out = out + (
                    torch.eq(in_feat_size, feat_size).unsqueeze(-1)
                    * (einsum(weight, x, "out inp, ... inp -> ... out") + bias)
            )
        if expert_weights is not None:
            out = expert_weights.unsqueeze(-1) * out
        return out

    def _routed_forward(
            self,
            x: torch.Tensor,
            expert_weights: torch.Tensor,
            expert_indices: torch.Tensor,
            x_final: torch.Tensor,
    ) -> torch.Tensor:
        """Mixture-of-experts path of ``forward`` (see its docstring)."""
        batch_shape = x.shape[:-1]
        n_tokens = batch_shape.numel()
        n_real_experts = len(self.in_features_ls)

        x_final = torch.nan_to_num(x_final, nan=0.0)
        expert_weights = expert_weights.reshape(-1, expert_weights.size(-1))  # [n_tokens, top_k]
        expert_indices = expert_indices.reshape(-1, expert_indices.size(-1))  # [n_tokens, top_k]

        # Renormalise each token's routing weights over the real (non-null) experts it picked.
        # Tokens that only picked null experts keep a divisor of 1. This is loop-invariant.
        is_real = expert_indices < n_real_experts
        real_weight_sum = (expert_weights * is_real.float()).sum(dim=-1)
        real_weight_sum = torch.where(real_weight_sum == 0, torch.ones_like(real_weight_sum), real_weight_sum)
        norm_weights = expert_weights / real_weight_sum.unsqueeze(-1)

        # Masked weight / bias of every real expert (also loop-invariant).
        expert_params = [
            (self.weight[i] * self.mask[i], self.bias[i] if self.bias is not None else 0)
            for i in range(n_real_experts)
        ]

        out = torch.zeros((n_tokens, self.out_features), device=x.device, dtype=x.dtype)
        # Routing slots are the outer loop and experts the inner one, so each token accumulates
        # its contributions in slot order.
        for k in range(expert_indices.size(1)):
            indices_k = expert_indices[:, k]
            if not is_real[:, k].any():
                continue
            for expert_idx, (weight, bias) in enumerate(expert_params):
                rows = indices_k == expert_idx
                if not rows.any():
                    continue
                expert_out = einsum(weight, x_final[expert_idx][rows], "out inp, ... inp -> ... out") + bias
                out[rows] += expert_out * norm_weights[rows, k].unsqueeze(-1)

        return out.view(*batch_shape, self.out_features)

    def extra_repr(self) -> str:
        return (
            f"in_features_ls={self.in_features_ls}, "
            f"out_features={self.out_features}, "
            f"bias={self.bias is not None}, "
            f"dtype={self.weight.dtype}"
        )


@dataclass
class ModelArgs:
    """
    Router hyper-parameters consumed by ``Gate`` / ``MoE``.

    Every field has a default, so ``ModelArgs()`` works and callers may still override attributes
    after construction. ``n_routed_experts`` defaults to ``n_real_experts + n_null_experts`` when
    it is not given explicitly. Real dataclass fields give a value-based ``__eq__`` and an
    informative ``__repr__``.
    """

    dim: int = 4  # width of the vectors scored by the gate
    n_real_experts: int = 3
    n_null_experts: int = 1
    n_routed_experts: Optional[int] = None  # filled in by __post_init__ when None
    n_activated_experts: int = 1  # top-k experts chosen per token
    moe_inter_dim: int = 1408
    update_bias_rate: float = 0.001
    target_dist: Optional[list] = None  # desired routed-weight share per expert (Gate also accepts a bare float)
    route_scale: float = 1.0

    def __post_init__(self):
        if self.n_routed_experts is None:
            self.n_routed_experts = self.n_real_experts + self.n_null_experts


class Gate(nn.Module):
    """
    Top-k softmax router over ``n_routed_experts`` experts with a load-balancing bias.

    ``forward`` does the following:
      1. Scores each row of ``x`` (NaNs treated as 0) against ``weight`` and adds the ``bias``
         buffer.
      2. Applies a softmax and keeps the top ``n_activated_experts`` experts.
      3. Returns their scores, renormalised to sum to 1 and multiplied by ``route_scale``,
         together with the expert indices.

    In training mode only, and only when ``target_dist`` is set, ``bias`` is then nudged without
    gradients so that the routed weight mass per expert moves toward ``target_dist``.

    NOTE(review): the bias is added before the softmax, so it also changes the returned combine
    weights (some aux-loss-free routers use it for top-k selection only). This is kept as-is for
    checkpoint compatibility; confirm that it is intended.
    """

    def __init__(self, args: "ModelArgs"):
        super().__init__()
        self.dim = args.dim
        self.topk = args.n_activated_experts
        self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
        # Load-balancing offsets (saved with the model); updated without gradients in forward.
        self.register_buffer('bias', torch.zeros(args.n_routed_experts, dtype=torch.float32))
        self.update_bias_rate = args.update_bias_rate
        self.route_scale = args.route_scale

        target_dist = args.target_dist
        if target_dist is not None:
            if isinstance(target_dist, float):
                target_dist = [target_dist]
            assert abs(sum(target_dist) - 1.0) < 1e-10
            # Build the tensor from the normalised value so a bare float becomes shape (1,).
            target_dist = torch.tensor(target_dist)
        self.target_dist = target_dist

        # nn.Module starts in training mode, so this runs at construction.
        if self.training:
            self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x: torch.Tensor):
        """
        Route each row of ``x`` (shape [tokens, dim]) to its top-k experts.

        Returns:
            ``(weights, indices)``, both of shape [tokens, topk]. ``weights`` has the dtype of
            ``x``; ``indices`` are expert ids.
        """
        scores = F.linear(torch.nan_to_num(x, nan=0.0), self.weight)
        scores = scores + self.bias
        scores = scores.softmax(dim=-1, dtype=torch.float32)
        original_scores: torch.Tensor = scores

        indices = torch.topk(scores, self.topk, dim=-1)[1]
        weights = original_scores.gather(1, indices)
        weights /= weights.sum(dim=-1, keepdim=True)
        weights *= self.route_scale

        # Load balancing is a training-time mechanism: updating the buffer in eval mode would
        # make inference depend on previously seen batches. Without a target there is nothing
        # to balance toward.
        if self.training and self.target_dist is not None:
            self._update_bias(indices, weights)

        return weights.type_as(x), indices

    @torch.no_grad()
    def _update_bias(self, indices: torch.Tensor, weights: torch.Tensor) -> None:
        """Move ``bias`` toward ``target_dist`` in proportion to the per-expert routing error."""
        expert_weights_sum = torch.bincount(
            indices.reshape(-1), weights=weights.reshape(-1), minlength=self.bias.size(0)
        )
        total_weights = expert_weights_sum.sum()
        target_dist = self.target_dist.to(device=expert_weights_sum.device, dtype=expert_weights_sum.dtype)
        # Positive error = expert received less than its target share, so its bias goes up.
        load_error = target_dist * total_weights - expert_weights_sum
        self.bias += self.update_bias_rate * (load_error / total_weights)


class MoE(nn.Module):
    """
    Routing-only mixture-of-experts wrapper around ``Gate``.

    ``forward`` flattens its input to rows of width ``args.dim`` and returns the gate's top-k
    routing weights and expert indices. No expert computation happens here.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.n_activated_experts = args.n_activated_experts
        self.n_real_experts = args.n_real_experts
        self.n_null_experts = args.n_null_experts
        self.gate = Gate(args)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(weights, indices)`` for ``x`` flattened to ``[-1, dim]``."""
        # reshape (not view) also accepts non-contiguous inputs.
        x = x.reshape(-1, self.dim)
        weights, indices = self.gate(x)
        return weights, indices


class KairosModel(T5PreTrainedModel):
    config_class = KairosConfig
    _keys_to_ignore_on_load_missing = [
        r"input_patch_embedding\.",
        r"output_patch_embedding\.",
    ]
    _keys_to_ignore_on_load_unexpected = [r"lm_head.weight"]
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]


    def __init__(self, config: KairosConfig):
        """
        Build the Kairos forecaster.

        Components: instance normalisation, dynamic patching, a multi-size input patch embedding,
        T5 encoder/decoder stacks (instance-wise-RoPE variants when configured) and a quantile
        output head.

        NOTE(review): ``config`` is modified in place (``reg_token_id``, ``vocab_size``) before the
        encoder/decoder configs are deep-copied from it.
        """
        super().__init__(config)
        self.model_dim = config.d_model
        self.config = config
        # Number of decoder output segments; forward() emits num_segments * prediction_length steps.
        self.num_segments = getattr(self.config, "num_decoder_segments", 1)
        # The embedding table only holds special tokens; id 1 is the [REG] token appended in encode().
        if self.config.use_reg_token:
            config.reg_token_id = 1
        config.vocab_size = 2 if self.config.use_reg_token else 1
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        self.instance_norm = InstanceNorm()
        
        # Encoder: non-causal and cache-free copy of the config.
        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        if self.config.position_embedding_type in ROPE_VARIANTS:
            self.encoder = T5Stack_rope(encoder_config, self.shared)
        else:
            self.encoder = T5Stack(encoder_config, self.shared)

        self._init_decoder(config)

        # Quantile levels as a non-persistent buffer (follows .to(device) but is not checkpointed).
        self.num_quantiles = len(self.config.quantiles)
        quantiles = torch.tensor(self.config.quantiles, dtype=self.dtype)
        self.register_buffer("quantiles", quantiles, persistent=False)

        # Output head: each decoder hidden state (d_model) -> num_quantiles * prediction_length values.
        self.output_patch_embedding = ResidualBlock(
            in_dim=config.d_model,
            h_dim=config.d_ff,
            out_dim=self.num_quantiles * self.config.prediction_length,
            act_fn_name=config.dense_act_fn,
            dropout_p=config.dropout_rate,
            position_embedding_type=self.config.position_embedding_type
        )

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # FFT/ACF feature normalization
        # (RoPE variants only; fft_norm is applied in fft_process, acf_norm presumably in acf_process)
        if self.config.position_embedding_type in ROPE_VARIANTS:
            feature_type = getattr(self.config, 'feature_type', 'fft')  # Default to fft
            if feature_type == 'fft':
                self.fft_norm = nn.LayerNorm(self.config.instance_rope_input_feature_dim)
            elif feature_type == 'acf':
                self.acf_norm = nn.LayerNorm(self.config.instance_rope_input_feature_dim)

        # Dynamic patching with a router over `levels` patch granularities.
        self.patch = DynamicPatch(
            max_patch_size=config.input_patch_size,
            patch_stride=config.input_patch_stride,
            levels=config.levels,
            n_null_experts=config.n_null_experts,
            n_activated_experts=config.n_activated_experts,
            moe_inter_dim=config.moe_inter_dim,
            update_bias_rate=config.update_bias_rate,
            target_dist=config.target_dist,
            route_scale=config.route_scale,
        )
        # Expert input sizes: input_patch_size halved once per extra level,
        # e.g. [P, P/2, P/4] for levels=3; every size that gets halved must be even.
        in_dim_ls = [config.input_patch_size]
        current_patch_size = config.input_patch_size
        for _ in range(config.levels - 1):
            assert current_patch_size % 2 == 0
            current_patch_size = current_patch_size // 2
            in_dim_ls.append(current_patch_size)
        self.input_patch_embedding = MultiInResidualBlock(
            in_dim_ls=tuple(in_dim_ls),
            h_dim=config.d_ff,
            out_dim=config.d_model,
            act_fn_name=config.dense_act_fn,
            dropout_p=config.dropout_rate,
            position_embedding_type=self.config.position_embedding_type
        )
        self.loss_weight_scheme = config.loss_weight_scheme

        self.post_init()

    def _init_weights(self, module):
        """
        Initialise ``module`` on top of the defaults from ``T5PreTrainedModel._init_weights``.

        * The model itself: ``shared`` embedding ~ N(0, factor).
        * ``ResidualBlock``: hidden/residual weights ~ N(0, factor * (2 * input_patch_size)^-0.5);
          output weights ~ N(0, factor * d_ff^-0.5); all biases zeroed.
        * ``MultiInResidualBlock``: hidden/residual expert banks via their ``reset_parameters``;
          output layer as for ``ResidualBlock``.
        """
        super()._init_weights(module)
        factor = self.config.initializer_factor

        def normal_weight_zero_bias(layer, std):
            # The weight is drawn before the bias is cleared.
            layer.weight.data.normal_(mean=0.0, std=std)
            if getattr(layer, "bias", None) is not None:
                layer.bias.data.zero_()

        if isinstance(module, self.__class__):
            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, ResidualBlock):
            # NOTE(review): in the code visible here ResidualBlock is built only as
            # output_patch_embedding (in_dim=d_model), yet these stds assume a fan-in of
            # 2 * input_patch_size. Confirm that this is intended.
            input_std = factor * ((self.config.input_patch_size * 2) ** -0.5)
            normal_weight_zero_bias(module.hidden_layer, input_std)
            normal_weight_zero_bias(module.residual_layer, input_std)
            normal_weight_zero_bias(module.output_layer, factor * ((self.config.d_ff) ** -0.5))
        elif isinstance(module, MultiInResidualBlock):
            module.hidden_layer.reset_parameters()
            module.residual_layer.reset_parameters()
            normal_weight_zero_bias(module.output_layer, factor * ((self.config.d_ff) ** -0.5))

    def _init_decoder(self, config):
        """
        Create ``self.decoder`` from a decoder-specific deep copy of ``config``.

        The copy is marked as a decoder (not encoder-decoder) with ``num_decoder_layers`` layers
        and shares the ``self.shared`` embedding. The RoPE stack is used for RoPE variants.
        """
        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers

        use_rope = self.config.position_embedding_type in ROPE_VARIANTS
        stack_cls = T5Stack_rope if use_rope else T5Stack
        self.decoder = stack_cls(decoder_config, self.shared)

    def fft_process(self, context: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Summarise each series by the LayerNorm-ed amplitudes of its lowest rFFT frequencies.

        Masked-out positions (``mask == 0``) are zeroed before the transform, so NaN placeholders
        do not propagate. The amplitude spectrum is truncated, or right-padded with zeros, along
        the last dimension to ``config.instance_rope_input_feature_dim`` and then passed through
        ``self.fft_norm``.

        Args:
            context: [..., seq_len] series (2-D [batch, seq_len] when called from ``encode``).
            mask: same shape as ``context``; non-zero marks valid values.

        Returns:
            [..., config.instance_rope_input_feature_dim] normalised FFT features.
        """
        masked_context = torch.where(mask.bool(), context, torch.zeros_like(context))
        fft_amplitude = torch.abs(torch.fft.rfft(masked_context, dim=-1))

        s_feature = self.config.instance_rope_input_feature_dim
        fft_length = fft_amplitude.shape[-1]
        if fft_length >= s_feature:
            # Keep the first (lowest-frequency) s_feature components.
            fft_features = fft_amplitude[..., :s_feature]
        else:
            # Zero-pad the frequency axis; F.pad works for any number of leading dimensions.
            fft_features = F.pad(fft_amplitude, (0, s_feature - fft_length))

        return self.fft_norm(fft_features)


    def encode(
            self, context: torch.Tensor, mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Tuple[
        torch.Tensor,
        Tuple[torch.Tensor, torch.Tensor],
        torch.Tensor,
        torch.Tensor,
        Optional[Tuple[torch.Tensor, ...]],
        Optional[torch.Tensor],
    ]:
        """
        Normalise, patch, embed and encode a batch of context series.

        Args:
            context: [batch, seq_len] history; NaN marks missing values.
            mask: optional observation mask of the same shape (non-zero = observed). When omitted
                it is derived from ``~isnan(context)``.
            output_attentions: ask the encoder stack to return its attention maps.

        Returns:
            A 6-tuple ``(last_hidden_state, loc_scale, input_embeds, attention_mask,
            encoder_attentions, s_features)``.
              * ``loc_scale`` is the second output of ``self.instance_norm``; forward uses it to
                scale targets and undo the prediction scaling.
              * ``attention_mask`` marks patches with at least one observed value, plus the [REG]
                slot when enabled.
              * ``s_features`` is ``None`` unless a RoPE variant with FFT/ACF features is
                configured.
        """
        mask = mask.to(context.dtype) if mask is not None else torch.isnan(context).logical_not().to(context.dtype)

        batch_size, _ = context.shape
        # Keep only the most recent context_length steps.
        if context.shape[-1] > self.config.context_length:
            context = context[..., -self.config.context_length:]
            mask = mask[..., -self.config.context_length:]

        # scaling: unobserved steps are set to NaN before instance normalisation
        context = torch.where(mask > 0.0, context, torch.nan)
        context, loc_scale = self.instance_norm(context)
        s_features = None
        if self.config.position_embedding_type in ROPE_VARIANTS:
            feature_type = getattr(self.config, 'feature_type', 'fft')  # Default to fft
            if feature_type == 'fft':
                s_features = self.fft_process(context, mask)
            elif feature_type == 'acf':
                s_features = self.acf_process(context, mask)
        # the scaling op above is done in 32-bit precision,
        # then the context is moved to model's dtype
        context = context.to(self.dtype)
        mask = mask.to(self.dtype)

        # patching: the last three outputs (routing weights/indices, per-level inputs) feed the
        # expert-routed input embedding below
        patched_context, patched_mask, size, expert_weights, expert_indices, x_final = self.patch(context, mask)
        patched_mask = torch.nan_to_num(patched_mask, nan=0.0)
        patched_context = torch.where(patched_mask > 0.0, patched_context, 0.0)
        # concat context and mask along patch dim
        patched_context = torch.cat([patched_context, patched_mask], dim=-1)

        # attention_mask = 1 if at least one item in the patch is observed
        attention_mask = patched_mask.sum(dim=-1) > 0  # (batch_size, patched_seq_length)

        input_embeds = self.input_patch_embedding(patched_context, size, expert_weights, expert_indices, x_final)

        if self.config.use_reg_token:
            # Append [REG]
            reg_input_ids = torch.full(
                (batch_size, 1),
                self.config.reg_token_id,
                device=input_embeds.device,
            )
            reg_embeds = self.shared(reg_input_ids)
            input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)
            attention_mask = torch.cat(
                [
                    attention_mask.to(self.dtype),
                    torch.ones_like(reg_input_ids).to(self.dtype),
                ],
                dim=-1,
            )

        encoder_kwargs = dict(
            attention_mask=attention_mask,
            inputs_embeds=input_embeds,
            return_dict=True,
            output_attentions=output_attentions,
        )
        # `size` and `s_features` are extras of the instance-wise-RoPE T5Stack. The stock
        # transformers T5Stack.forward declares neither, so they are only passed to the RoPE stack.
        if self.config.position_embedding_type in ROPE_VARIANTS:
            encoder_kwargs["size"] = size
            if s_features is not None:
                encoder_kwargs["s_features"] = s_features
        encoder_outputs = self.encoder(**encoder_kwargs)

        return encoder_outputs.last_hidden_state, loc_scale, input_embeds, attention_mask, encoder_outputs.attentions, s_features

    def forward(
            self,
            past_target: torch.Tensor,
            past_is_pad: Optional[torch.Tensor] = None,
            past_observed_values: Optional[torch.Tensor] = None,
            future_target: Optional[torch.Tensor] = None,
            future_is_pad: Optional[torch.Tensor] = None,
            future_observed_values: Optional[torch.Tensor] = None,
            generation: bool = False,
            prediction_length: Optional[int] = None,
            output_attentions: bool = False,
            infer_is_positive: Optional[bool] = False,
            force_flip_invariance: Optional[bool] = False,
            *args,
            **kwargs,
    ) -> KairosOutput:
        """
        Forecast quantiles for ``past_target`` and, when ``future_target`` is given, the loss.

        Args:
            past_target: [batch, seq_len] history (``encode`` requires 2-D); NaN = missing.
            past_is_pad: optional; truthy entries are excluded from the context mask.
            past_observed_values: optional; falsy entries are excluded from the context mask.
            future_target: optional [batch, horizon] target with
                ``horizon <= num_segments * config.prediction_length``; NaN = missing.
            future_is_pad / future_observed_values: optional masks for ``future_target``, applied
                like their ``past_*`` counterparts.
            generation: if True, return ``self.generate(...)`` instead.
            prediction_length: only forwarded on the ``generation`` path; a plain forward always
                produces ``num_segments * config.prediction_length`` steps.
            output_attentions: also return encoder self-attention / decoder cross-attention.
            infer_is_positive, force_flip_invariance: only forwarded on the ``generation`` path.
            *args, **kwargs: accepted and ignored.

        Returns:
            A ``KairosOutput`` (or whatever ``generate`` returns when ``generation`` is True) with:
              * ``prediction_outputs``: [batch, num_quantiles, num_segments * prediction_length]
                in the original scale.
              * ``loss``: a per-series tensor [batch] (``None`` without ``future_target``).
              * ``future_target``: the normalised, zero-filled, padded target used for the loss.
        """
        if generation:
            return self.generate(
                past_target=past_target,
                past_is_pad=past_is_pad,
                past_observed_values=past_observed_values,
                prediction_length=prediction_length,
                infer_is_positive=infer_is_positive,
                force_flip_invariance=force_flip_invariance,
            )
        batch_size = past_target.size(0)
        # Generate masks based on padding and observed values
        # (True = usable: not NaN, not padding, and observed when those inputs are given)
        mask = ~torch.isnan(past_target)
        if past_is_pad is not None:
            mask = mask & ~past_is_pad.bool()

        if past_observed_values is not None:
            mask = mask & past_observed_values.bool()

        # Same rules for the optional future target.
        future_mask = ~torch.isnan(future_target) if future_target is not None else None
        if future_is_pad is not None and future_mask is not None:
            future_mask = future_mask & ~future_is_pad.bool()

        if future_observed_values is not None and future_mask is not None:
            future_mask = future_mask & future_observed_values.bool()

        (
            hidden_states,
            loc_scale,
            input_embeds,
            attention_mask,
            attention_scores,
            s_features
        ) = self.encode(context=past_target, mask=mask, output_attentions=output_attentions)

        # `decode` is defined elsewhere in this class; it returns the decoder states and the
        # cross-attention maps.
        sequence_output, cross_attention_scores = self.decode(
            input_embeds, attention_mask, hidden_states, output_attentions=output_attentions, s_features=s_features
        )

        total_prediction_length = self.num_segments * self.config.prediction_length
        # [B, N, num_quantiles * prediction_length]
        raw_preds = self.output_patch_embedding(sequence_output)

        # [B, N, Q * L] -> [B, N, Q, L]
        raw_preds_reshaped = raw_preds.view(
            batch_size,
            self.num_segments,
            self.num_quantiles,
            self.config.prediction_length
        )

        # [B, N, Q, L] -> [B, Q, N, L]
        raw_preds_permuted = raw_preds_reshaped.permute(0, 2, 1, 3)

        # [B, Q, N*L]
        quantile_preds_shape = (
            batch_size,
            self.num_quantiles,
            total_prediction_length,
        )
        quantile_preds = raw_preds_permuted.contiguous().view(*quantile_preds_shape)

        loss = None
        if future_target is not None:
            # normalize target
            # (with the context's loc/scale, so the loss is computed in normalised space)
            future_target, _ = self.instance_norm(future_target, loc_scale)
            future_target = future_target.unsqueeze(1)  # type: ignore
            assert self.num_segments * self.config.prediction_length >= future_target.shape[-1]

            future_target = future_target.to(quantile_preds.device)
            future_mask = (
                future_mask.unsqueeze(1).to(quantile_preds.device) & ~torch.isnan(future_target)
                if future_mask is not None
                else ~torch.isnan(future_target)
            )
            # Zero-fill unusable targets; future_mask removes them from the loss below.
            future_target[~future_mask] = 0.0

            # pad target and target_mask if they are shorter than model's prediction_length
            if total_prediction_length > future_target.shape[-1]:
                padding_shape = (
                    *future_target.shape[:-1],
                    total_prediction_length - future_target.shape[-1],
                )
                future_target = torch.cat([future_target, torch.zeros(padding_shape).to(future_target)], dim=-1)
                future_mask = torch.cat([future_mask, torch.zeros(padding_shape).to(future_mask)], dim=-1)

            # Pinball (quantile) loss scaled by 2, masked to usable steps: [B, Q, N*L]
            loss = (
                2
                * torch.abs(
                    (future_target - quantile_preds)
                    * ((future_target <= quantile_preds).float() - self.quantiles.view(1, self.num_quantiles, 1))
                )
                * future_mask.float()
            )
            # quantile_preds: [B, Q, N*L]

            loss = loss.nanmean(dim=-2)  # Mean over quantile levels

            # Optional per-horizon-step weighting.
            if self.config.loss_weight_scheme == 'log_decay':
                weights = get_log_decay_weights(total_prediction_length, device=loss.device)
                loss = loss * weights
            # NOTE(review): the result is per series ([B]) and not averaged over the batch;
            # callers presumably reduce it before backward(). Confirm.
            loss = loss.nansum(dim=-1)  # Sum over prediction horizon

        # Unscale predictions
        quantile_preds = self.instance_norm.inverse(
            quantile_preds.view(batch_size, -1),
            loc_scale,
        ).view(*quantile_preds_shape)
        # On failure, the message shows the input rows whose forecasts contain NaN.
        assert not torch.isnan(quantile_preds).any(), (
            f"{past_target[torch.isnan(quantile_preds).any(dim=-1).any(dim=-1)]}"
        )
        return KairosOutput(
            loss=loss,
            prediction_outputs=quantile_preds,
            attentions=attention_scores if output_attentions else None,
            cross_attentions=cross_attention_scores if output_attentions else None,
            future_target=future_target,
        )

    def _autoregressive_generate(
        self,
        past_target: torch.Tensor,
        past_is_pad: Optional[torch.Tensor],
        past_observed_values: Optional[torch.Tensor],
        prediction_length: int,
    ) -> torch.Tensor:
        """Forecast ``prediction_length`` steps, rolling the model forward when needed.

        One forward pass covers at most ``num_segments * config.prediction_length``
        steps. For longer horizons, the forecast of the quantile closest to 0.5 is
        appended to the context (marked as observed and not padded) and the model is
        called again until the requested horizon is covered.

        Parameters
        ----------
        past_target
            Context values; forecasts are concatenated onto its last dimension.
        past_is_pad, past_observed_values
            Optional masks aligned with ``past_target``. They are extended alongside
            it and keep their own dtype and device, so every step receives masks of
            the same type as the first step.
        prediction_length
            Total number of future steps to return.

        Returns
        -------
        torch.Tensor
            ``prediction_outputs`` of every step concatenated on the last dimension
            and truncated to ``prediction_length``; ``[:, q]`` selects quantile ``q``.
        """
        # Quantile nearest the median; its forecast is fed back as the point estimate.
        central_idx = int(torch.abs(self.quantiles.detach() - 0.5).argmin())
        max_pred_len_per_step = self.num_segments * self.config.prediction_length

        output = self(
            past_target=past_target,
            past_is_pad=past_is_pad,
            past_observed_values=past_observed_values,
            prediction_length=min(prediction_length, max_pred_len_per_step),
        )
        predictions_list = [output.prediction_outputs]
        remaining = prediction_length - max_pred_len_per_step

        while remaining > 0:
            point_forecast = predictions_list[-1][:, central_idx]
            past_target = torch.cat([past_target, point_forecast], dim=-1)

            # Extend the masks with tensors built from the masks themselves: using
            # ones_like/zeros_like of the (float) forecast would silently change a
            # bool/int mask's dtype (or fail on torch without cat type promotion).
            if past_observed_values is not None:
                past_observed_values = torch.cat(
                    [past_observed_values, past_observed_values.new_ones(point_forecast.shape)], dim=-1
                )
            if past_is_pad is not None:
                past_is_pad = torch.cat(
                    [past_is_pad, past_is_pad.new_zeros(point_forecast.shape)], dim=-1
                )

            output = self(
                past_target=past_target,
                past_is_pad=past_is_pad,
                past_observed_values=past_observed_values,
                prediction_length=min(remaining, max_pred_len_per_step),
            )
            predictions_list.append(output.prediction_outputs)
            remaining -= max_pred_len_per_step

        prediction = torch.cat(predictions_list, dim=-1)
        return prediction[:, :, :prediction_length]

    def generate(
        self,
        past_target: torch.Tensor,
        past_is_pad: Optional[torch.Tensor] = None,
        past_observed_values: Optional[torch.Tensor] = None,
        prediction_length: Optional[int] = None,
        infer_is_positive: Optional[bool] = False,
        force_flip_invariance: Optional[bool] = False,
    ) -> KairosOutput:
        """Produce quantile forecasts for ``past_target``.

        Parameters
        ----------
        past_target
            Context values; ``dim=1`` is scanned for negatives when
            ``infer_is_positive`` is set.
        past_is_pad, past_observed_values
            Optional masks forwarded unchanged to ``_autoregressive_generate``.
        prediction_length
            Horizon to forecast. Defaults to ``num_segments * config.prediction_length``
            (the longest horizon a single pass covers). A longer value triggers a
            warning outside training mode and is produced autoregressively.
        infer_is_positive
            When True, series whose context has no negative value get their
            forecasts clamped from below at 0.
        force_flip_invariance
            When True, average the forecast with ``-flip(forecast(-past_target))``,
            reversing ``dim=1`` (the quantile axis). NOTE(review): reversing the
            quantile axis maps q to 1-q only if ``self.quantiles`` is symmetric
            around 0.5 — confirm for the configured quantiles.

        Returns
        -------
        KairosOutput
            Only ``prediction_outputs`` is populated.
        """
        max_supported_len = self.num_segments * self.config.prediction_length
        if prediction_length is None:
            prediction_length = max_supported_len
        elif prediction_length > max_supported_len and not self.training:
            warnings.warn(
                f"Prediction length {prediction_length} is greater than the model's prediction length {max_supported_len}. "
            )

        # Decide positivity from the raw context, before any forecasting happens.
        is_positive_mask = None
        if infer_is_positive:
            has_negative = torch.any(past_target < 0, dim=1)
            is_positive_mask = (~has_negative).view(-1, 1, 1)

        def run(context: torch.Tensor) -> torch.Tensor:
            return self._autoregressive_generate(
                context, past_is_pad, past_observed_values, prediction_length
            )

        prediction = run(past_target)
        if force_flip_invariance:
            mirrored = run(-past_target)
            # Undo the sign flip and reverse the quantile axis before averaging.
            mirrored_back = -torch.flip(mirrored, dims=[1])
            prediction = (prediction + mirrored_back) / 2

        if is_positive_mask is not None:
            zero = torch.tensor(0.0, device=prediction.device)
            clamped = torch.maximum(prediction, zero)
            prediction = torch.where(is_positive_mask, clamped, prediction)

        return KairosOutput(
            prediction_outputs=prediction,
        )

    def _init_decoder(self, config):
        """Create ``self.decoder``, a decoder-only T5 stack sharing ``self.shared``.

        ``config`` is deep-copied, so the caller's object is not modified. The copy
        is flagged as a (non encoder-decoder) decoder and gets
        ``config.num_decoder_layers`` layers. The RoPE-capable stack is chosen when
        ``self.config.position_embedding_type`` is listed in ``ROPE_VARIANTS``.
        """
        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers

        uses_rope = self.config.position_embedding_type in ROPE_VARIANTS
        stack_cls = T5Stack_rope if uses_rope else T5Stack
        self.decoder = stack_cls(decoder_config, self.shared)

    def decode(
            self,
            input_embeds,
            attention_mask,
            hidden_states,
            output_attentions=False,
            s_features: Optional[torch.Tensor] = None,
    ):
        """
        Run the decoder on ``num_segments`` query tokens that cross-attend to the encoder output.

        Parameters
        ----------
        input_embeds: torch.Tensor
            Patched and embedded inputs. Shape (batch_size, patched_context_length, d_model).
            Only its batch size and device are used here.
        attention_mask: torch.Tensor
            Attention mask for the patched context. Shape (batch_size, patched_context_length), type: torch.int64
        hidden_states: torch.Tensor
            Hidden states returned by the encoder. Shape (batch_size, patched_context_length, d_model)
        output_attentions: bool
            Forwarded to the decoder.
        s_features: Optional[torch.Tensor]
            Extra features forwarded to the decoder as ``s_features`` only when given,
            so decoders whose forward does not accept that keyword still work.

        Returns
        -------
        last_hidden_state
            Last hidden state returned by the decoder, of shape (batch_size, num_segments, d_model)
            (one position per decoder input token).
        cross_attentions
            ``cross_attentions`` as returned by the decoder output.
        """
        batch_size = input_embeds.shape[0]
        device = input_embeds.device

        if self.config.diff_decoder_token_id:
            # One distinct token per segment: 0, then 2..num_segments. Id 1 is skipped,
            # presumably because it is reserved (T5's default eos id) — TODO confirm.
            # For num_segments == 1 this is just [0].
            custom_ids = [0] + list(range(2, self.num_segments + 1))
            assert len(custom_ids) == self.num_segments
            decoder_input_ids = (
                torch.tensor(custom_ids, device=device, dtype=torch.long)
                .unsqueeze(0)
                .expand(batch_size, -1)
            )
        else:
            # Explicit dtype: token ids must be integers regardless of the fill value's type.
            decoder_input_ids = torch.full(
                (batch_size, self.num_segments),
                self.config.decoder_start_token_id,
                dtype=torch.long,
                device=device,
            )

        extra_kwargs = {}
        if s_features is not None:
            extra_kwargs["s_features"] = s_features

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            output_attentions=output_attentions,
            return_dict=True,
            **extra_kwargs,
        )

        # sequence_outputs: (batch_size, num_segments, d_model)
        return decoder_outputs.last_hidden_state, decoder_outputs.cross_attentions