# -*- coding: utf-8 -*-
import os, math
from typing import Dict, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F

def _try_import(path_mod, name):
    try:
        mod = __import__(path_mod, fromlist=[name])
        return getattr(mod, name)
    except Exception:
        return None

def _import_optional(module_path, name):
    """Resolve optional component ``name`` from ``module_path``.

    The bare path is tried first, then the same path under the ``model.``
    package. Returns ``None`` when neither import succeeds.
    """
    return _try_import(module_path, name) or _try_import("model." + module_path, name)


# Optional project components; each is ``None`` when it cannot be imported.
DynamicSegmenter = _import_optional("layers.dynamic_segmenter", "DynamicSegmenter")
SimpleDynamicSegmenter = _import_optional("layers.simple_dynamic_segmenter", "SimpleDynamicSegmenter")
GatedMultiscaleAttention = _import_optional("layers.gma", "GatedMultiscaleAttention")
SegmentImportanceScorer = _import_optional("layers.segment_importance", "SegmentImportanceScorer")
compute_irr_bcr = _import_optional("backbone.eamc_utils", "compute_irr_bcr")


class DynamicSegmentationModule(nn.Module):
    """Split a feature sequence into a bounded number of pooled segment tokens.

    ``forward`` segments ``ts_feat`` into ``S`` segments. It uses an optional
    project segmenter, or fixed-length zero-padded chunks as a fallback. Each
    segment is pooled into a token, and tokens are optionally merged down to a
    node budget (NBC). Diagnostic statistics (IRR / BCR) are recorded in
    ``self.last_info``.

    Configuration is read from the ``config`` dict and, for several switches,
    from environment variables parsed by :meth:`_get_bool_env`.
    """

    def __init__(self, shared_dim: int, config: Dict):
        super().__init__()
        from types import SimpleNamespace
        import warnings

        self.config = config
        self.shared_dim = shared_dim

        # Feature switches: ``use_nbc`` prefers the config value and falls back
        # to the USE_NBC environment variable; the others come from the environment.
        self.use_nbc = self.config.get('use_nbc', self._get_bool_env('USE_NBC', True))
        self.frequency_branch_enabled = self._get_bool_env('FREQUENCY_BRANCH_ENABLED', True)
        self.use_fft = self._get_bool_env('USE_FFT', True)
        self.use_wavelet = self._get_bool_env('USE_WAVELET', True)

        self.use_causal_frequency = self._get_bool_env('USE_CAUSAL_FREQUENCY', True)
        self.causal_frequency_config = {
            'use_stft': self.use_fft,
            'use_wavelet': self.use_wavelet,
            'stft_config': {
                'n_fft': self.config.get('fft_bins', 64),
                'hop_length': self.config.get('hop_length', 16),
                'win_length': self.config.get('win_length', 32),
                'window': 'hann'
            },
            'wavelet_config': {
                'wavelet_type': self.config.get('wavelet_type', 'db4'),
                'levels': self.config.get('wavelet_levels', 4)
            }
        }

        # The segmenter may never produce more segments than the node budget allows.
        max_segments = min(self.config.get("max_segments", 8), self.config.get("node_budget", 8))

        # Plain attribute bag handed to the segmenter constructors.
        seg_cfg = SimpleNamespace(
            d_model=shared_dim, d_pe=shared_dim,
            max_len=self.config.get("max_seq_len", 96),
            desired_threshold=self.config.get("desired_threshold", 0.8),
            fixed_max_segments=max_segments,
            fixed_max_len=self.config.get("segment_len", 12),
            segment_mask_top_k_ratio=self.config.get("segment_mask_top_k_ratio", 0.3),
            trimmer_types=self.config.get("trimmer_types", ["Decomposition"]),
            num_heads=self.config.get("num_heads", 8),
        )

        # Best-effort: prefer the simple segmenter, then the full one. If construction
        # fails, forward() falls back to fixed-length chunking.
        self.dynamic_segmenter = None
        segmenter_cls = SimpleDynamicSegmenter or DynamicSegmenter
        if segmenter_cls:
            try:
                self.dynamic_segmenter = segmenter_cls(seg_cfg)
            except Exception as exc:
                warnings.warn(
                    f"DynamicSegmentationModule: could not build {getattr(segmenter_cls, '__name__', segmenter_cls)} "
                    f"({exc!r}); falling back to fixed-length segmentation"
                )
                self.dynamic_segmenter = None

        # Multiscale attention fusion is currently disabled: segments pass through unchanged.
        self.gma_fusion = nn.Identity()

        self.segment_importance_scorer = SegmentImportanceScorer(
            feature_dim=shared_dim,
            scoring_method=self.config.get("eamc_scoring_method", "energy"),
            hidden_dim=self.config.get("eamc_hidden_dim", 128),
            num_heads=self.config.get("eamc_num_heads", 8),
        ) if SegmentImportanceScorer else None

        self.last_info: Dict = {}

    # ===== Causal frequency features =====
    def _causal_frequency_transform(self, x: torch.Tensor) -> torch.Tensor:
        """Return frequency features for ``x`` (``[B, T, D]``).

        The enabled branches (STFT magnitude, Haar-style wavelet) are computed.
        When both are active, their outputs are truncated to the shorter time axis
        and concatenated along the feature axis. ``x`` is returned unchanged when
        the transform is disabled or no branch is enabled.
        """
        if not self.use_causal_frequency:
            return x

        freq_features = []
        if self.causal_frequency_config['use_stft']:
            freq_features.append(self._causal_stft(x))
        if self.causal_frequency_config['use_wavelet']:
            freq_features.append(self._causal_wavelet(x))

        if not freq_features:
            return x
        if len(freq_features) == 1:
            return freq_features[0]

        min_length = min(feat.size(1) for feat in freq_features)
        return torch.cat([feat[:, :min_length, :] for feat in freq_features], dim=2)

    def _causal_stft(self, x: torch.Tensor) -> torch.Tensor:
        """Framewise mean FFT magnitude: ``[B, T, D]`` -> ``[B, n_frames, D]``.

        Frames of ``win_length`` samples start every ``hop_length`` steps, giving
        ``n_frames = (T - win_length) // hop_length + 1``. When ``T < win_length``,
        a single right-zero-padded frame is used. Each frame gets an
        ``n_fft``-point real FFT per channel, and the magnitude is averaged over
        the frequency bins.

        NOTE(review): ``stft_config['window']`` ('hann') is not applied; frames use
        a rectangular window.
        """
        T = x.size(1)
        stft_config = self.causal_frequency_config['stft_config']
        n_fft = stft_config['n_fft']
        hop_length = stft_config['hop_length']
        win_length = stft_config['win_length']

        if T < win_length:
            # Pad the time axis so that exactly one (zero-padded) frame exists.
            x = F.pad(x, (0, 0, 0, win_length - T))
        frames = x.unfold(1, win_length, hop_length)  # [B, n_frames, D, win_length]
        spectrum = torch.fft.rfft(frames, n=n_fft, dim=-1)
        return spectrum.abs().mean(dim=-1)

    def _causal_wavelet(self, x: torch.Tensor) -> torch.Tensor:
        """Multi-level Haar-style split of ``x`` (``[B, T, D]``).

        At most ``min(levels, 3)`` levels are computed, stopping early once fewer
        than 4 steps remain. At each level, an odd trailing step is dropped, the
        detail ``odd - even`` is kept, and the even samples become the next
        approximation. All detail signals and the final approximation are
        truncated to the shortest length and concatenated on the feature axis,
        giving ``[B, T', (n_levels + 1) * D]``.
        """
        levels = self.causal_frequency_config['wavelet_config']['levels']

        coeffs = []
        approx = x
        for _ in range(min(levels, 3)):
            if approx.size(1) < 4:
                break
            if approx.size(1) % 2 == 1:
                approx = approx[:, :-1, :]
            even, odd = approx[:, ::2, :], approx[:, 1::2, :]
            coeffs.append(odd - even)
            approx = even
        coeffs.append(approx)

        min_length = min(c.size(1) for c in coeffs)
        return torch.cat([c[:, :min_length, :] for c in coeffs], dim=2)

    # ===== Helpers =====
    def _get_bool_env(self, env_var: str, default: bool = True) -> bool:
        """Parse environment variable ``env_var`` as a boolean.

        ``true/1/yes/on`` and ``false/0/no/off`` are recognised case-insensitively.
        Any other value, including an unset variable, yields ``default``.
        """
        value = os.environ.get(env_var, "").lower()
        if value in ("true", "1", "yes", "on"):
            return True
        if value in ("false", "0", "no", "off"):
            return False
        return default

    def _keep_topB_merge(self, seg_tokens, scores, B):
        """Merge ``S`` segment tokens down to ``B`` by keeping the top-scoring ones.

        Args:
            seg_tokens: ``[Bsz, S, D]`` segment tokens.
            scores: ``[Bsz, S]`` importance scores; higher scores are kept first.
            B: target token count, clamped to ``[1, S]``.

        Returns:
            ``(new_tokens, assign)``:
            - ``assign`` (``[Bsz, S]``): ``assign[b, s]`` is the slot, in
              descending-score order, of the kept segment nearest in position to
              ``s``. Ties go to the higher-scoring kept segment.
            - ``new_tokens`` (``[Bsz, B, D]``): the mean of the segments assigned
              to each slot.
        """
        Bsz, S, D = seg_tokens.shape
        B = max(1, min(B, S))
        keep = torch.topk(scores, k=B, dim=1, largest=True, sorted=True).indices  # [Bsz, B]

        positions = torch.arange(S, device=keep.device)
        dist = (keep.unsqueeze(1) - positions.view(1, S, 1)).abs()  # [Bsz, S, B]
        # argmin returns the first minimum, i.e. the higher-ranked kept segment on ties.
        assign = dist.argmin(dim=-1).to(seg_tokens.device)  # [Bsz, S]

        # Match the token dtype so the matmul works for half/double tokens too.
        A = F.one_hot(assign, num_classes=B).to(seg_tokens.dtype)  # [Bsz, S, B]
        count = A.sum(dim=1, keepdim=True).clamp_min(1e-6)  # [Bsz, 1, B]
        new_tokens = (A.transpose(1, 2) @ seg_tokens) / count.transpose(1, 2)
        return new_tokens, assign

    def _compute_irr_bcr(self, ts_feat: torch.Tensor, assign_T: torch.Tensor):
        """Compute batch-averaged IRR and BCR for a time-step -> segment assignment.

        IRR is computed per sample as ``1 - SSE / SST``:
        - ``SSE`` is the squared error of replacing each step by its segment mean.
        - ``SST`` is the squared deviation from the sample's global mean.
        Both are summed over features and averaged over time. IRR is clamped to
        ``[0, 1]`` and is ``1.0`` when ``SST <= 1e-6``.

        BCR is the fraction of adjacent time steps whose segment ids differ.

        Args:
            ts_feat: ``[B, T, D]`` features (no gradient flows into the metric).
            assign_T: ``[B, T]`` integer segment ids.

        Returns:
            ``{"IRR": float, "BCR": float}`` averaged over the batch.
        """
        B, T, D = ts_feat.shape
        if B == 0 or T == 0:
            return {"IRR": 1.0, "BCR": 0.0}

        feats = ts_feat.detach()
        # Compact arbitrary ids to 0..n_seg-1 (equality between ids is preserved).
        _, ids = torch.unique(assign_T.to(feats.device), return_inverse=True)
        ids = ids.long()
        n_seg = int(ids.max().item()) + 1
        idx = ids.unsqueeze(-1).expand(-1, -1, D)

        # Per-segment means via scatter-add, then reconstruct each step by its segment mean.
        sums = feats.new_zeros(B, n_seg, D).scatter_add_(1, idx, feats)
        counts = feats.new_zeros(B, n_seg).scatter_add_(1, ids, feats.new_ones(B, T))
        means = sums / counts.clamp_min(1.0).unsqueeze(-1)
        rec = means.gather(1, idx)

        # Same normalisation for both terms (sum over D, mean over T).
        sse = ((feats - rec) ** 2).sum(dim=-1).mean(dim=-1)  # [B]
        sst = ((feats - feats.mean(dim=1, keepdim=True)) ** 2).sum(dim=-1).mean(dim=-1)  # [B]
        valid = sst > 1e-6
        safe_sst = torch.where(valid, sst, torch.ones_like(sst))
        irr = torch.where(valid, 1.0 - sse / safe_sst, torch.ones_like(sst)).clamp(0.0, 1.0)

        bcr = (ids[:, 1:] != ids[:, :-1]).float().mean().item() if T > 1 else 0.0
        return {"IRR": irr.mean().item(), "BCR": bcr}

    # ===== Forward =====
    def forward(self, ts_feat: torch.Tensor, ts_lengths: torch.Tensor):
        """Segment ``ts_feat`` and return ``(seg_tokens, seg_seq, info)``.

        Args:
            ts_feat: ``[B, T, D]`` input features.
            ts_lengths: not used by this implementation (kept for interface
                compatibility).

        Returns:
            seg_tokens: ``[B, S, D']`` pooled token per segment.
            seg_seq: ``[B, S, L, D]`` segmented sequence, or ``[B, S, 1, D']``
                after NBC merging.
            info: dict (also stored as ``self.last_info``) with keys
                ``S`` (``[B]`` long tensor), ``budget_B``, ``IRR``, ``BCR`` and
                ``assign_T`` (``[B, T]`` uniform time -> segment map).
        """
        B, T, D = ts_feat.shape
        device = ts_feat.device

        if self.dynamic_segmenter is not None:
            seg_out = self.dynamic_segmenter(ts_feat, times=None)
            # Tuple outputs carry the segmented sequence at index 1.
            seg_seq = seg_out[1] if isinstance(seg_out, tuple) else seg_out
        else:
            # Fallback: S equal-length chunks, zero-padding the end of the time axis.
            S = min(self.config.get("max_segments", 8), self.config.get("node_budget", 8))
            L = math.ceil(T / S)
            x = F.pad(ts_feat, (0, 0, 0, L * S - T)) if T < L * S else ts_feat[:, :L * S]
            seg_seq = x.reshape(B, S, L, D)

        B, S, L, D = seg_seq.shape
        # reshape (not view): segmenter output is not guaranteed to be contiguous.
        gma_out = self.gma_fusion(seg_seq.reshape(B * S, L, D))
        seg_tokens = gma_out.mean(1) if gma_out.dim() == 3 else gma_out
        seg_tokens = seg_tokens.reshape(B, S, -1)

        budget_override = getattr(self, '_node_budget_override', None)
        budget = int(budget_override) if budget_override is not None else self.config.get("node_budget", S)
        if self.use_nbc and S > budget:
            if self.segment_importance_scorer is not None:
                scores = self.segment_importance_scorer(seg_tokens)
            else:
                scores = seg_tokens.norm(p=2, dim=-1)
            seg_tokens, _ = self._keep_topB_merge(seg_tokens, scores, budget)
            seg_seq = seg_tokens.unsqueeze(2)
            S = seg_tokens.size(1)

        # Uniform time -> segment map over the (possibly merged) S segments.
        bounds = torch.linspace(0, T, steps=S + 1, device=device).long().tolist()
        assign_T = torch.zeros(B, T, dtype=torch.long, device=device)
        for s in range(S):
            assign_T[:, bounds[s]:bounds[s + 1]] = s
        irr_bcr = self._compute_irr_bcr(ts_feat, assign_T)

        self.last_info = {
            "S": torch.full((B,), S, dtype=torch.long, device=device),
            "budget_B": int(budget),
            "IRR": irr_bcr["IRR"],
            "BCR": irr_bcr["BCR"],
            "assign_T": assign_T
        }
        return seg_tokens, seg_seq, self.last_info

