# -*- coding: utf-8 -*-
import importlib
import warnings
from typing import Dict, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

# -----
def _try_import(path_mod, name):
    """Best-effort lookup of ``name`` inside the module ``path_mod``.

    Args:
        path_mod: Absolute dotted module path (``"pkg.mod"``) or a relative one
            with leading dots (``"..pkg.mod"``). Relative paths are resolved
            against the package that contains this module.
        name: Attribute (or submodule) to fetch from the imported module.

    Returns:
        The requested object, or ``None`` when the module cannot be imported,
        the relative path cannot be resolved (e.g. this file is not part of a
        package), or the attribute does not exist.
    """
    try:
        # ``__import__`` with its default ``level=0`` never treats leading dots
        # as a relative import, so relative candidates could not succeed.
        # ``import_module`` resolves them against the anchor package instead.
        anchor = globals().get("__package__") or getattr(
            globals().get("__spec__"), "parent", None
        )
        mod = importlib.import_module(path_mod, package=anchor)
        try:
            return getattr(mod, name)
        except AttributeError:
            # Keep the old ``fromlist`` behaviour: ``name`` may be a submodule
            # that has not been imported yet.
            return importlib.import_module(f"{mod.__name__}.{name}")
    except Exception:
        # Deliberately broad: an optional dependency may fail in arbitrary
        # ways while importing, and callers only need "available or not".
        return None

# Optional project layers. For each name the package-relative module path is
# tried first, then the absolute "model.layers.*" path; the name is bound to
# None when neither lookup yields a usable object.
ImageEncoder = (
    _try_import("..layers.image_encoder", "ImageEncoder")
    or _try_import("model.layers.image_encoder", "ImageEncoder")
)
TextEncoder = (
    _try_import("..layers.text_encoder", "TextEncoder")
    or _try_import("model.layers.text_encoder", "TextEncoder")
)
TimeSeriesEncoder = (
    _try_import("..layers.ts_encoder", "TimeSeriesEncoder")
    or _try_import("model.layers.ts_encoder", "TimeSeriesEncoder")
)
StaticEncoder = (
    _try_import("..layers.st_encoder", "StaticEncoder")
    or _try_import("model.layers.st_encoder", "StaticEncoder")
)
TemperatureScaledAttention = (
    _try_import("..layers.temperature_attention", "TemperatureScaledAttention")
    or _try_import("model.layers.temperature_attention", "TemperatureScaledAttention")
)


if TemperatureScaledAttention is None:
    class TemperatureScaledAttention(nn.Module):
        """Fallback attention layer used when the project implementation is unavailable.

        Wraps ``nn.MultiheadAttention`` and divides query, key and value by a
        temperature ``tau`` before calling it.

        Args:
            d: Embedding dimension.
            nhead: Number of attention heads.
            dropout: Dropout probability inside the attention module.
            batch_first: Forwarded to ``nn.MultiheadAttention``.
            init_tau: Initial temperature value.
            learnable: If True ``tau`` is a trainable parameter, otherwise a
                fixed constant.
        """

        def __init__(self, d, nhead=8, dropout=0.1, batch_first=True, init_tau=1.0, learnable=True):
            super().__init__()
            self.attn = nn.MultiheadAttention(d, nhead, dropout=dropout, batch_first=batch_first)
            tau = torch.tensor(float(init_tau))
            if learnable:
                self.tau = nn.Parameter(tau)
            else:
                # Register the constant as a buffer so it follows the module
                # across .to()/.cuda()/.half(); non-persistent keeps the
                # state_dict keys identical to the plain-attribute version.
                self.register_buffer("tau", tau, persistent=False)

        def forward(self, q, k, v, key_padding_mask=None, attn_mask=None):
            """Return the attention output for ``q``, ``k``, ``v`` (weights are dropped)."""
            # NOTE(review): the inputs (including the values) are scaled, not
            # the attention logits — confirm this matches the project layer.
            out, _ = self.attn(
                q / self.tau,
                k / self.tau,
                v / self.tau,
                key_padding_mask=key_padding_mask,
                attn_mask=attn_mask,
            )
            return out


class FeatureExtractorAndFusion(nn.Module):
    """Per-modality feature extraction followed by attention-based fusion.

    ``extract_features`` converts image, text, time-series and static inputs
    into token tensors whose last dimension is ``shared_dim``.
    ``cross_modal_fusion`` concatenates the token groups that are not all-zero,
    applies multi-head self-attention plus a residual MLP and pools the result
    to one vector per sample.

    Encoders that could not be imported at module load are replaced by
    ``nn.Identity``.

    Recognised ``config`` keys (defaults in parentheses):
        shared_dim (128), image_dim (128), text_dim (128), ts_dim (64),
        static_dim (16), use_multimodal / use_image / use_text / use_static
        (True), tau_img_txt (1.0), tau_img_ts (0.1), tau_txt_ts (0.5),
        tau_img_txt_learnable (True), tau_img_ts_learnable (False),
        tau_txt_ts_learnable (True), use_causal_pooling (True),
        pooling_tau (5.0).
    """

    def __init__(self, config: Dict):
        """Build encoders, projections and fusion layers from ``config``."""
        super().__init__()
        self.config = config
        self.shared_dim = self.config.get("shared_dim", 128)

        # ---- Ablation control flags ----
        self.use_multimodal = self.config.get("use_multimodal", True)
        self.use_image = self.config.get("use_image", True)
        self.use_text = self.config.get("use_text", True)
        self.use_static = self.config.get("use_static", True)

        image_dim = self.config.get("image_dim", 128)
        text_dim = self.config.get("text_dim", 128)
        ts_dim = self.config.get("ts_dim", 64)
        static_dim = self.config.get("static_dim", 16)

        # NOTE: keep the submodule creation order stable — it determines the
        # parameter order and the random initialisation under a fixed seed.

        # ---- Encoders ----
        if ImageEncoder:
            self.image_encoder = ImageEncoder(out_dim=image_dim)
            self.img_norm = nn.LayerNorm(image_dim)
        else:
            self.image_encoder = nn.Identity()
            self.img_norm = nn.Identity()

        # The text projection is needed with or without a real encoder because
        # zero placeholders are still pushed through it.
        self.text_encoder = TextEncoder() if TextEncoder else nn.Identity()
        self.text_proj = nn.Linear(768, text_dim)
        self.txt_norm = nn.LayerNorm(text_dim)

        if TimeSeriesEncoder:
            self.ts_encoder = TimeSeriesEncoder(out_dim=ts_dim)
            self.ts_proj = nn.Linear(ts_dim, self.shared_dim)
        else:
            # Without an encoder the raw series is projected directly, so its
            # last dimension has to equal ``shared_dim``.
            self.ts_encoder = nn.Identity()
            self.ts_proj = nn.Linear(self.shared_dim, self.shared_dim)

        if self.use_static and StaticEncoder:
            self.static_encoder = StaticEncoder(in_dim=static_dim, out_dim=static_dim)
            self.static_norm = nn.LayerNorm(static_dim)
        else:
            self.static_encoder = nn.Identity()
            self.static_norm = nn.Identity()

        # ---- Projections into the shared token space ----
        self.img_proj = nn.Linear(image_dim, self.shared_dim)
        self.txt_proj_unified = nn.Linear(text_dim, self.shared_dim)
        self.static_proj_unified = nn.Linear(static_dim, self.shared_dim)

        # ---- Pairwise cross-modal attention layers ----
        self.img_txt_attn = TemperatureScaledAttention(
            self.shared_dim, 8, 0.1, True,
            init_tau=self.config.get("tau_img_txt", 1.0),
            learnable=self.config.get("tau_img_txt_learnable", True))
        self.img_ts_attn = TemperatureScaledAttention(
            self.shared_dim, 8, 0.1, True,
            init_tau=self.config.get("tau_img_ts", 0.1),
            learnable=self.config.get("tau_img_ts_learnable", False))
        self.txt_ts_attn = TemperatureScaledAttention(
            self.shared_dim, 8, 0.1, True,
            init_tau=self.config.get("tau_txt_ts", 0.5),
            learnable=self.config.get("tau_txt_ts_learnable", True))

        # ---- Global fusion ----
        self.fusion_attn = nn.MultiheadAttention(self.shared_dim, 8, 0.1, batch_first=True)
        self.fusion_out = nn.Sequential(
            nn.Linear(self.shared_dim, self.shared_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(self.shared_dim, self.shared_dim),
        )

        self.use_causal_pooling = self.config.get("use_causal_pooling", True)
        self.pooling_tau = self.config.get("pooling_tau", 5.0)

    # --------
    def extract_features(self,
                         images: Optional[torch.Tensor],
                         text_item: Optional[Dict[str, torch.Tensor]],
                         packed_ts: Union[Dict[str, torch.Tensor], torch.Tensor],
                         static_data: torch.Tensor):
        """Encode every modality into tokens of width ``shared_dim``.

        Args:
            images: ``[B, N, C, H, W]`` or ``[B, C, H, W]`` tensor, or None.
            text_item: Dict passed to the text encoder, or None.
            packed_ts: Time-series tensor, a dict holding it under
                ``"ts_data"``, or None. Supported ranks: 5
                (``[B, N_ts, T, n, 2]`` with value/mask pairs, mask 1 treated
                as missing), 4 (``[B, N_ts, T, F]``) and 3 (``[B, T, F]``).
            static_data: Static features; its first dimension is the batch
                size used for placeholder tokens.

        Returns:
            Tuple ``(img_tok, txt_tok, ts_tok, st_tok)``. ``ts_tok`` is None
            when ``packed_ts`` is None and 4-D otherwise; the other three are
            ``[B, N, shared_dim]``.

        Raises:
            ValueError: If the time series has an unsupported rank, or a 5-D
                series whose last dimension is not 2.
        """
        device = next(self.parameters()).device
        batch_size = static_data.size(0)

        img_tok = self._image_tokens(images, batch_size, device)
        txt_tok = self._text_tokens(text_item, batch_size, device)
        ts_tok = self._timeseries_tokens(packed_ts)
        st_tok = self._static_tokens(static_data, batch_size, device)
        return img_tok, txt_tok, ts_tok, st_tok

    def _image_tokens(self, images, batch_size, device):
        """Return image tokens ``[B, Ni, shared_dim]`` honouring the ablation flags."""
        enabled = self.use_multimodal and self.use_image
        ablated = not self.use_multimodal and not self.use_image

        if enabled and images is not None:
            img_feat = self._encode_images(images)
        elif ablated:
            # Ablation: feed all-zero images so only the encoder's response to
            # an empty input reaches the fusion stage.
            if images is not None and images.ndim in (4, 5):
                img_feat = self._encode_images(torch.zeros_like(images))
            else:
                img_feat = torch.zeros(batch_size, 1, self.img_proj.in_features, device=device)
        else:
            # Modality absent: an all-zero token, which cross_modal_fusion drops.
            return torch.zeros(batch_size, 1, self.shared_dim, device=device)

        return self.img_proj(self.img_norm(img_feat))

    def _encode_images(self, images):
        """Run the image encoder and return features shaped ``[B, Ni, image_dim]``."""
        has_encoder = not isinstance(self.image_encoder, nn.Identity)
        feat_dim = self.img_proj.in_features

        if images.ndim == 5:  # [B, N, C, H, W]
            b, n, c, h, w = images.shape
            if has_encoder:
                feat = self.image_encoder(images.reshape(b * n, c, h, w))
            else:
                # No encoder available: constant placeholder features.
                feat = torch.full((b * n, feat_dim), 0.01, device=images.device)
            return feat.reshape(b, n, -1)

        # Any other rank is treated as a single image per sample: [B, C, H, W].
        if has_encoder:
            feat = self.image_encoder(images)
        else:
            feat = torch.full((images.size(0), feat_dim), 0.01, device=images.device)
        return feat.unsqueeze(1)

    def _text_tokens(self, text_item, batch_size, device):
        """Return text tokens ``[B, Nt, shared_dim]`` honouring the ablation flags."""
        has_encoder = not isinstance(self.text_encoder, nn.Identity)
        enabled = self.use_multimodal and self.use_text
        ablated = not self.use_multimodal and not self.use_text

        if enabled and text_item is not None and has_encoder:
            txt_raw = self._run_text_encoder(text_item)
        elif ablated and text_item is not None and has_encoder:
            # Ablation: zero every tensor in the item, keep other entries.
            zeroed = {
                key: torch.zeros_like(value) if isinstance(value, torch.Tensor) else value
                for key, value in text_item.items()
            }
            txt_raw = self._run_text_encoder(zeroed)
        else:
            # No usable text (or no encoder): zero features are still projected,
            # so the resulting token carries the projection/norm biases.
            txt_raw = torch.zeros(batch_size, 1, self.text_proj.in_features, device=device)

        return self.txt_proj_unified(self.txt_norm(self.text_proj(txt_raw)))

    def _run_text_encoder(self, text_item):
        """Call the text encoder and return its output as ``[B, Nt, D]``."""
        try:
            txt_raw = self.text_encoder(text_item=text_item, preprocessed=True, return_dict=False)
        except TypeError:
            # Encoder does not accept the keyword interface; use the positional one.
            txt_raw = self.text_encoder(text_item)
        if txt_raw.dim() == 2:
            txt_raw = txt_raw.unsqueeze(1)
        return txt_raw

    def _timeseries_tokens(self, packed_ts):
        """Return time-series tokens ``[B, N_ts, T, shared_dim]`` or None."""
        if packed_ts is None:
            return None

        if isinstance(packed_ts, dict) and "ts_data" in packed_ts:
            ts_raw = packed_ts["ts_data"].float()
        else:
            ts_raw = packed_ts.float()

        if ts_raw.dim() == 5:  # [B, N_ts, T, n, 2]
            ts_lat = self._encode_masked_timeseries(ts_raw)
        elif ts_raw.dim() == 4:  # [B, N_ts, T, F]
            b, n_ts, t, feat_dim = ts_raw.shape
            # An nn.Identity encoder simply passes the tensor through.
            ts_lat = self.ts_encoder(ts_raw.reshape(b * n_ts, t, feat_dim))
            ts_lat = ts_lat.reshape(b, n_ts, t, -1)
        elif ts_raw.dim() == 3:  # [B, T, F]
            ts_lat = self.ts_encoder(ts_raw).unsqueeze(1)
        else:
            raise ValueError(f"Unexpected time series shape: {ts_raw.shape}")

        return self.ts_proj(ts_lat)

    def _encode_masked_timeseries(self, ts_raw):
        """Encode a ``[B, N_ts, T, n, 2]`` (value, mask) series; mask 1 = missing."""
        b, n_ts, t, _, pair = ts_raw.shape
        if pair != 2:
            raise ValueError(f"Expected last dim=2 (value,mask), got {pair}")

        values = ts_raw[..., 0]  # [B, N_ts, T, n]
        masks = ts_raw[..., 1]   # [B, N_ts, T, n]

        # One-step forward fill: a missing entry whose predecessor was observed
        # takes the predecessor's value. Only observed predecessors are read,
        # so this can be done for all time steps at once.
        filled = values.clone()
        if t > 1:
            gap_start = (masks[:, :, 1:] == 1) & (masks[:, :, :-1] == 0)
            filled[:, :, 1:] = torch.where(gap_start, values[:, :, :-1], values[:, :, 1:])

        # NOTE(review): this zeroes every masked position, including the ones
        # just forward-filled above — confirm that is the intended behaviour.
        filled = filled * (1 - masks)

        observed = (1 - masks).float()  # 1 where a value is present
        ts_input = torch.cat([filled, observed], dim=-1)  # [B, N_ts, T, 2n]

        # Segments without any observation are cleared completely (this also
        # removes NaNs that survived the multiplication above).
        all_missing = observed.sum(dim=(2, 3)) == 0  # [B, N_ts]
        ts_input = ts_input.masked_fill(all_missing[:, :, None, None], 0.0)

        if isinstance(self.ts_encoder, nn.Identity):
            ts_lat = ts_input
        else:
            ts_lat = self.ts_encoder(ts_input)
            finite = torch.isfinite(ts_lat)
            if not finite.all():
                warnings.warn(
                    "Time-series encoder produced NaN/Inf values; replacing them with zeros.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                ts_lat = torch.where(finite, ts_lat, torch.zeros_like(ts_lat))

        return ts_lat.reshape(b, n_ts, t, -1)

    def _static_tokens(self, static_data, batch_size, device):
        """Return static tokens ``[B, Ns, shared_dim]``."""
        if self.use_static:
            st = self.static_norm(self.static_encoder(static_data))
            if st.dim() == 2:
                st = st.unsqueeze(1)
        else:
            # The placeholder must match the projection's input width
            # (static_dim), not shared_dim, or the Linear layer rejects it.
            st = torch.zeros(batch_size, 1, self.static_proj_unified.in_features, device=device)
        return self.static_proj_unified(st)

    # --------
    def cross_modal_fusion(self, img_tok, txt_tok, ts_seg_tokens, static_vec):
        """Fuse the available modality tokens into one vector per sample.

        Args:
            img_tok: Image tokens or None.
            txt_tok: Text tokens or None.
            ts_seg_tokens: Time-series tokens or None.
            static_vec: Static features already projected to ``shared_dim``;
                a token axis is inserted at dim 1.

        Returns:
            ``[B, shared_dim]`` tensor. All-zero / None inputs are ignored; if
            nothing remains the result is all zeros, if a single group remains
            it is mean-pooled without attention.
        """
        st_tok = static_vec.unsqueeze(1)
        candidates = (
            ("image", img_tok),
            ("text", txt_tok),
            ("timeseries", ts_seg_tokens),
            ("static", st_tok),
        )

        sequences = []
        for name, token in candidates:
            if token is None or self._is_all_zero(token):
                continue
            if token.dim() == 2:    # [B, D] -> [B, 1, D]
                token = token.unsqueeze(1)
            elif token.dim() == 4:  # [B, N, L, D] -> [B, N*L, D]
                token = token.flatten(1, 2)
            elif token.dim() != 3:
                warnings.warn(
                    f"Skipping '{name}' tokens with unsupported shape {tuple(token.shape)}.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                continue
            sequences.append(token)

        if not sequences:
            return torch.zeros(static_vec.size(0), self.shared_dim, device=static_vec.device)

        tokens = torch.cat(sequences, dim=1)  # [B, Nall, D]

        # Count the groups that actually survived, not the candidates.
        if len(sequences) == 1:
            return tokens.mean(dim=1)

        fused, _ = self.fusion_attn(tokens, tokens, tokens)
        fused = fused + self.fusion_out(fused)
        return self._causal_weighted_pooling(fused)

    @staticmethod
    def _is_all_zero(token: torch.Tensor) -> bool:
        """Return True when every element of ``token`` is (numerically) zero."""
        return torch.allclose(token, torch.zeros_like(token))

    def _causal_weighted_pooling(self, fused: torch.Tensor) -> torch.Tensor:
        """Pool ``[B, N, D]`` tokens to ``[B, D]``.

        With ``use_causal_pooling`` the token at position ``j`` gets a weight
        proportional to ``exp(-(N - 1 - j) / pooling_tau)`` (later tokens weigh
        more); otherwise a plain mean over the token axis is returned.
        """
        if not self.use_causal_pooling:
            return fused.mean(dim=1)

        num_tokens = fused.size(1)
        if num_tokens == 1:
            return fused.squeeze(1)

        # Distance of each token from the last one: N-1, ..., 1, 0.
        distance = torch.arange(num_tokens - 1, -1, -1, device=fused.device, dtype=torch.float)
        weights = torch.exp(-distance / self.pooling_tau)
        weights = weights / weights.sum()

        return (fused * weights.view(1, -1, 1)).sum(dim=1)
