"""
Simplified dynamic segmenter
Does not depend on SeriesDecomposition and WaveletTransformer
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Dict, List, Tuple, Optional


class SimpleDynamicSegmenter(nn.Module):
    """Cut a sequence into variable-length segments packed into a fixed-size tensor.

    A self-attention layer followed by a small MLP yields one score in (0, 1)
    per timestep. Positions where the smoothed absolute change of that score
    exceeds a threshold become segment boundaries. Every sample is then packed
    into ``fixed_max_segments`` slots of ``fixed_max_len`` timesteps each
    (zero-padded or truncated as needed).

    Attributes read from ``args``:
        d_model: Feature dimension of the input; must be divisible by the
            8 attention heads used here.
        max_len: Stored as-is; not used inside this class.
        desired_threshold: Multiplier applied to the mean smoothed score
            change to obtain the boundary threshold.
        fixed_max_len: Number of timesteps per output segment slot.
        segment_mask_top_k_ratio: Optional (default 0.3); stored as-is and
            not used inside this class.
        pos_embed_len: Optional (default 167); number of positions in the
            learned positional table, i.e. the longest accepted sequence.
    """

    def __init__(self, args):
        super().__init__()
        self.d_model = args.d_model
        self.max_len = args.max_len
        self.desired_threshold = args.desired_threshold
        self.fixed_max_segments = 64
        self.fixed_max_len = args.fixed_max_len
        self.segment_mask_top_k_ratio = getattr(args, 'segment_mask_top_k_ratio', 0.3)
        # The default of 167 keeps the parameter shape of existing checkpoints.
        pos_embed_len = getattr(args, 'pos_embed_len', 167)
        self.pos_embed = nn.Parameter(torch.randn(1, pos_embed_len, args.d_model))
        # Simple attention mechanism
        self.attention = nn.MultiheadAttention(
            embed_dim=self.d_model,
            num_heads=8,
            batch_first=True
        )

        # Segment decision network: one score in (0, 1) per timestep.
        self.segment_decision = nn.Sequential(
            nn.Linear(self.d_model, self.d_model // 2),
            nn.ReLU(),
            nn.Linear(self.d_model // 2, 1),
            nn.Sigmoid()
        )

    def forward(
        self, seq: torch.Tensor, times: Optional[torch.Tensor] = None
    ) -> Tuple[List[Dict], torch.Tensor, torch.Tensor]:
        """Score every timestep and split ``seq`` into segments.

        Args:
            seq: Input of shape ``[bs, t, d_model]``.
            times: Accepted for interface compatibility; currently unused.

        Returns:
            The ``(segment_info, seg_seq, seg_mask)`` triple produced by
            :meth:`_generate_segments`. The segments are cut from the
            position-embedded, layer-normalised input, not from the
            attention output.

        Raises:
            ValueError: If ``t`` exceeds the length of the positional table.
        """
        _, t, _ = seq.shape
        max_positions = self.pos_embed.size(1)
        if t > max_positions:
            raise ValueError(
                f"sequence length {t} exceeds the positional embedding "
                f"length {max_positions}"
            )

        seq = seq + self.pos_embed[:, :t, :]
        seq = F.layer_norm(seq, seq.shape[-1:])
        attn_output, _ = self.attention(seq, seq, seq)

        segment_probs = self.segment_decision(attn_output).squeeze(-1)  # [bs, t]

        return self._generate_segments(seq, segment_probs, seq.device)

    @staticmethod
    def _enforce_min_length(candidates: List[int], t: int, min_len: int) -> List[int]:
        """Return sorted boundaries from 0 to ``t`` with gaps of at least ``min_len``.

        Candidates are kept greedily from the left. The closing boundary ``t``
        is always present: if the last kept cut lies closer than ``min_len``
        to ``t`` it is replaced by ``t``, so the short tail is merged into the
        previous segment instead of being dropped.
        """
        ordered = sorted({0, t, *candidates})
        kept = [ordered[0]]
        for point in ordered[1:]:
            if point - kept[-1] >= min_len:
                kept.append(point)
        if kept[-1] != t:
            if len(kept) > 1:
                kept[-1] = t
            else:
                # Whole sequence is shorter than min_len: emit it as one segment.
                kept.append(t)
        return kept

    def _fit_length(self, segment: torch.Tensor) -> torch.Tensor:
        """Zero-pad or truncate ``segment`` along dim 0 to ``fixed_max_len`` rows."""
        seg_len = segment.size(0)
        if seg_len < self.fixed_max_len:
            return F.pad(segment, (0, 0, 0, self.fixed_max_len - seg_len))
        if seg_len > self.fixed_max_len:
            return segment[:self.fixed_max_len]
        return segment

    def _generate_segments(
        self, seq: torch.Tensor, segment_probs: torch.Tensor, device: torch.device
    ) -> Tuple[List[Dict], torch.Tensor, torch.Tensor]:
        """Turn per-timestep scores into fixed-size segment tensors.

        Args:
            seq: Tensor of shape ``[bs, t, d]`` from which segments are sliced.
            segment_probs: Per-timestep scores of shape ``[bs, t]``.
            device: Device on which helper tensors are created.

        Returns:
            segment_info: Flat list with ``bs * fixed_max_segments`` dicts
                (keys ``batch_id``, ``seg_id``, ``start``, ``end``,
                ``attention_mean``). Padding slots have ``start == end == 0``.
            seg_seq: ``[bs, fixed_max_segments, fixed_max_len, d]`` tensor.
            seg_mask: ``[bs, fixed_max_segments]`` float tensor, 1 for real
                segments and 0 for padding slots.
        """
        bs, t, d = seq.shape
        min_len = max(2, int(0.05 * t))

        diffs = torch.abs(segment_probs[:, 1:] - segment_probs[:, :-1])  # [bs, t-1]

        if diffs.size(1) > 0:
            smoothed = F.avg_pool1d(
                diffs.unsqueeze(1), kernel_size=3, stride=1, padding=1
            ).squeeze(1)  # [bs, t-1]
            # The threshold is relative to the mean over the whole batch.
            threshold = self.desired_threshold * smoothed.mean()
            candidate_mask = smoothed > threshold
        else:
            # t < 2: there are no differences to pool (avg_pool1d would raise).
            smoothed = diffs
            candidate_mask = None

        segment_info, all_segments, all_masks = [], [], []

        for b in range(bs):
            if candidate_mask is None:
                candidates = []
            else:
                # diffs[i] compares timesteps i and i + 1, so the cut is at i + 1.
                candidates = (torch.where(candidate_mask[b])[0] + 1).tolist()

            if not candidates:
                # No change stands out: fall back to evenly spaced cuts.
                candidates = torch.linspace(
                    0, t - 1, self.fixed_max_segments + 1, device=device
                ).long()[1:-1].tolist()

            boundaries = self._enforce_min_length(candidates, t, min_len)

            if len(boundaries) > self.fixed_max_segments + 1:
                # Too many cuts: keep the interior ones with the largest change.
                interior = torch.tensor(boundaries[1:-1], device=device)
                importance = smoothed[b][interior - 1]
                top_k = min(self.fixed_max_segments - 1, importance.numel())
                _, top_indices = torch.topk(importance, top_k)
                boundaries = [0] + sorted(interior[top_indices].tolist()) + [t]

            segments, seg_info = [], []
            for i, (start, end) in enumerate(zip(boundaries[:-1], boundaries[1:])):
                segments.append(self._fit_length(seq[b, start:end]))
                seg_info.append({
                    'batch_id': b,
                    'seg_id': i,
                    'start': start,
                    'end': end,
                    'attention_mean': segment_probs[b, start:end].mean().item()
                })

            num_real = min(len(segments), self.fixed_max_segments)

            # Fill up segment count with all-zero slots.
            for pad_id in range(len(segments), self.fixed_max_segments):
                segments.append(
                    torch.zeros(self.fixed_max_len, d, dtype=seq.dtype, device=device)
                )
                seg_info.append({
                    'batch_id': b,
                    'seg_id': pad_id,
                    'start': 0,
                    'end': 0,
                    'attention_mean': 0.0
                })

            segments = segments[:self.fixed_max_segments]
            seg_info = seg_info[:self.fixed_max_segments]

            mask = torch.zeros(self.fixed_max_segments, device=device)
            mask[:num_real] = 1.0

            all_segments.append(torch.stack(segments))
            all_masks.append(mask)
            segment_info.extend(seg_info)

        seg_seq = torch.stack(all_segments)  # [bs, max_segments, max_len, d]
        seg_mask = torch.stack(all_masks)    # [bs, max_segments]

        return segment_info, seg_seq, seg_mask



