# -*- coding: utf-8 -*-
from typing import Dict, Optional, Union
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..common.features_and_fusion import FeatureExtractorAndFusion
from ..modules.dynamic_segmentation import DynamicSegmentationModule
from ..modules.causal_graph import CausalGraphModuleLite

def _try_import(path_mod, name):
    """Best-effort lookup of attribute ``name`` from module ``path_mod``.

    Returns the attribute when both the import and the attribute access
    succeed, and ``None`` if either raises any ``Exception`` (the failure is
    deliberately swallowed so callers can chain fallbacks with ``or``).
    """
    try:
        module = __import__(path_mod, fromlist=[name])
        found = getattr(module, name)
    except Exception:
        found = None
    return found

# Optional backbone components. Each lookup tries the top-level ``backbone``
# package first and then the ``model.backbone`` layout; the name is bound to
# None when neither path resolves, because ``_try_import`` swallows failures.
MLPHead = (
    _try_import("backbone.head", "MLPHead")
    or _try_import("model.backbone.head", "MLPHead")
)
AEBlock = (
    _try_import("backbone.ae", "AEBlock")
    or _try_import("model.backbone.ae", "AEBlock")
)

class TCNBlock(nn.Module):
    """Temporal Convolutional Network block.

    Two dilated convolutions (conv -> BatchNorm -> ReLU -> Dropout ->
    conv -> BatchNorm) plus a residual connection, followed by ReLU.
    Input ``x`` is ``[B, C_in, T]``; the output is ``[B, C_out, T]``.

    ``nn.Conv1d`` pads both ends by ``(kernel_size - 1) * dilation``; keeping
    only the first ``T`` outputs of each convolution makes output step ``t``
    depend only on input steps ``<= t`` (in eval mode -- in train mode the
    BatchNorm batch statistics are computed over all positions).
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, dropout=0.1):
        super().__init__()

        # Symmetric padding; the right-hand surplus is trimmed in forward().
        self.padding = (kernel_size - 1) * dilation

        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size,
                               dilation=dilation, padding=self.padding)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size,
                               dilation=dilation, padding=self.padding)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.norm1 = nn.BatchNorm1d(out_channels)
        self.norm2 = nn.BatchNorm1d(out_channels)

        # 1x1 projection so the residual matches the output channel count.
        self.downsample = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else None

    def forward(self, x):
        """Apply the block to ``x`` of shape ``[B, C_in, T]``; returns ``[B, C_out, T]``."""
        seq_len = x.size(-1)

        # Trim each convolution's padded tail *before* normalisation so the
        # BatchNorm statistics only see the T real positions (previously the
        # tail built from right-side zero padding leaked into them).
        out = self.conv1(x)[..., :seq_len]
        out = self.norm1(out)
        out = self.relu(out)
        out = self.dropout(out)

        out = self.conv2(out)[..., :seq_len]
        out = self.norm2(out)

        residual = x if self.downsample is None else self.downsample(x)
        return self.relu(out + residual)

class PositionalEncoding(nn.Module):
    """Positional encoding for transformer.

    Adds the fixed sinusoidal table (sine on even channels, cosine on odd
    channels) to an input ``x`` of shape ``[B, T, D]``. A ``max_len``-row
    table is precomputed into the ``pe`` buffer; inputs longer than that get
    a table computed on the fly instead of failing to broadcast.
    """
    def __init__(self, d_model, max_len=100):
        super().__init__()
        pe = self._sinusoid_table(max_len, d_model)
        self.register_buffer('pe', pe.unsqueeze(0))

    @staticmethod
    def _sinusoid_table(length, d_model, device=None):
        """Return the ``[length, d_model]`` sinusoidal table.

        For odd ``d_model`` the cosine (odd) channels number one fewer than
        the sine (even) channels, so the frequency vector is trimmed to fit.
        """
        pe = torch.zeros(length, d_model, device=device)
        position = torch.arange(0, length, dtype=torch.float, device=device).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, device=device).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe

    def forward(self, x):
        # x: [B, T, D]
        seq_len = x.size(1)
        if seq_len <= self.pe.size(1):
            return x + self.pe[:, :seq_len]
        # Longer than the cached table: build a longer table on the fly.
        pe = self._sinusoid_table(seq_len, self.pe.size(-1), device=self.pe.device)
        return x + pe.unsqueeze(0)

class ForecastHead(nn.Module):
    """Forecast head for time series prediction.

    Maps a feature vector ``[B, in_dim]`` to per-step predictions: the
    projected feature is added to one learned query per forecast step, the
    steps are mixed by a stack of TCN blocks (or a GRU), and each step is
    projected to a scalar.

    Args:
        in_dim: size of the input feature vector.
        horizon: number of forecast steps.
        hidden: width of the per-step hidden representation.
        n_layers: number of TCN blocks (dilations 1, 2, 4, ...) or GRU layers.
        dropout: dropout probability.
        use_tcn: mix steps with TCN blocks; otherwise with a GRU.
        use_multiscale: return a dict with one ``forecast_scale_<s>`` entry of
            shape ``[B, s]`` per scale; otherwise a single ``[B, horizon]``.
        use_positional_encoding: add sinusoidal encodings over the steps.
        use_time_bias: add learned global, per-step and linear-trend biases.
        scales: output resolutions for multiscale mode (default
            ``[24, 12, 6, 3, 1]``). A scale equal to ``horizon`` uses every
            step; any other scale reads ``scale`` evenly spaced steps via
            ``torch.linspace`` (so indices repeat when ``scale > horizon``).
    """
    _DEFAULT_SCALES = (24, 12, 6, 3, 1)

    def __init__(self, in_dim, horizon=24, hidden=128, n_layers=2, dropout=0.1,
                 use_tcn=True, use_multiscale=True, use_positional_encoding=True,
                 use_time_bias=True, scales=None):
        super().__init__()
        self.horizon = horizon
        self.hidden = hidden
        self.use_tcn = use_tcn
        self.use_multiscale = use_multiscale
        self.use_positional_encoding = use_positional_encoding
        self.use_time_bias = use_time_bias

        self.proj_in = nn.Linear(in_dim, hidden)

        # One learned query vector per forecast step.
        self.horizon_queries = nn.Parameter(torch.randn(horizon, hidden))
        nn.init.normal_(self.horizon_queries, std=0.02)

        if use_positional_encoding:
            self.pos_encoding = PositionalEncoding(hidden, max_len=horizon * 2)

        if use_tcn:
            # TCN blocks with exponentially growing dilation.
            self.tcn_layers = nn.ModuleList(
                TCNBlock(hidden, hidden, kernel_size=3, dilation=2 ** i, dropout=dropout)
                for i in range(n_layers)
            )
        else:
            self.rnn = nn.GRU(hidden, hidden, num_layers=n_layers,
                              batch_first=True, dropout=dropout if n_layers > 1 else 0)

        if use_time_bias:
            self.global_time_bias = nn.Parameter(torch.zeros(1))     # offset shared by all steps
            self.step_time_bias = nn.Parameter(torch.zeros(horizon))  # per-step offset (full horizon only)
            self.trend_bias = nn.Parameter(torch.zeros(1))            # slope over the step index
            nn.init.normal_(self.global_time_bias, std=0.01)
            nn.init.normal_(self.step_time_bias, std=0.01)
            nn.init.normal_(self.trend_bias, std=0.01)

        if use_multiscale:
            self.scales = list(self._DEFAULT_SCALES if scales is None else scales)
            self.scale_heads = nn.ModuleList(nn.Linear(hidden, 1) for _ in self.scales)
        else:
            self.out = nn.Linear(hidden, 1)

        self.dropout = nn.Dropout(dropout)

    def _apply_time_bias(self, out, steps):
        """Add the learned time biases to ``out`` of shape ``[B, steps]``.

        The global offset and the linear trend ``trend_bias * [0..steps-1]``
        always apply; the per-step bias exists only for the full horizon, so
        it is added only when ``steps == self.horizon``. No-op when time
        biases are disabled.
        """
        if not self.use_time_bias:
            return out
        out = out + self.global_time_bias
        if steps == self.horizon:
            out = out + self.step_time_bias.unsqueeze(0)
        trend = self.trend_bias * torch.arange(steps, device=out.device).float()
        return out + trend.unsqueeze(0)

    def forward(self, feat):
        """Forecast from ``feat`` of shape ``[B, in_dim]`` (must be 2-D)."""
        batch_size, _ = feat.shape

        # Broadcast the projected feature onto every step's query.
        queries = self.horizon_queries.unsqueeze(0).expand(batch_size, -1, -1)
        h = self.proj_in(feat).unsqueeze(1) + queries  # [B, horizon, hidden]

        if self.use_positional_encoding:
            h = self.pos_encoding(h)

        if self.use_tcn:
            h = h.transpose(1, 2)  # Conv1d expects [B, hidden, horizon]
            for layer in self.tcn_layers:
                h = layer(h)
            h = h.transpose(1, 2)
        else:
            h, _ = self.rnn(h)

        h = self.dropout(h)

        if not self.use_multiscale:
            return self._apply_time_bias(self.out(h).squeeze(-1), self.horizon)

        outputs = {}
        for scale, head in zip(self.scales, self.scale_heads):
            if scale == self.horizon:
                h_scale = h
            else:
                indices = torch.linspace(0, self.horizon - 1, scale, dtype=torch.long, device=h.device)
                h_scale = h[:, indices, :]  # [B, scale, hidden]
            out = head(h_scale).squeeze(-1)  # [B, scale]
            outputs[f'forecast_scale_{scale}'] = self._apply_time_bias(out, scale)
        return outputs

class ClassificationHead(nn.Module):
    """
    Output two heads:
    1. Regression prediction ``y_pred``: [B, F]
    2. Missing prediction ``m_pred`` (sigmoid of the miss head): [B, F]

    With ``out_len > 1`` both heads emit ``F * out_len`` values per sample,
    returned as [B, out_len, F] (``out_len`` consecutive blocks of ``F``).
    """
    def __init__(self, projection_dim, n_feature, out_len=1, dropout=0.1,
                 use_batch_norm=False, use_layer_norm=False):
        super().__init__()
        if MLPHead is None:
            raise ImportError(
                "ClassificationHead requires MLPHead, which could not be imported "
                "from 'backbone.head' or 'model.backbone.head'"
            )
        hidden = min(256, projection_dim * 2)

        def make_head():
            # Both heads share the same architecture; only their weights differ.
            return MLPHead(
                projection_dim,
                n_feature * out_len,
                hidden=hidden,
                dropout=dropout,
                use_batch_norm=use_batch_norm,
                use_layer_norm=use_layer_norm,
            )

        self.reg_head = make_head()
        self.miss_head = make_head()

        self.n_feature = n_feature
        self.out_len = out_len

    def forward(self, x):
        """Return ``{"y_pred", "m_pred"}`` for the batch ``x``."""
        B = x.size(0)
        # Previously this always viewed as (B, F), which fails for out_len > 1.
        if self.out_len == 1:
            shape = (B, self.n_feature)
        else:
            shape = (B, self.out_len, self.n_feature)
        y_pred = self.reg_head(x).view(shape)
        m_pred = torch.sigmoid(self.miss_head(x).view(shape))
        return {"y_pred": y_pred, "m_pred": m_pred}

class MultiModalPredictor(nn.Module):
    """
    Classification/temporal prediction dual heads:
      encoding → EAMC → cross-modal fusion → causal graph → concat → heads

    Each feature flag resolves as: the ``config`` value if present, else the
    matching environment variable (parsed by ``_get_bool_env``), else True.
    The resolved flags (plus ``pooling_tau``) are written back into the
    ``config`` dict passed in -- the caller's dict is mutated in place --
    before the sub-modules are built from it.
    """
    def __init__(self, config: Dict):
        super().__init__()
        self.config = config
        self.shared_dim = self.config.get("shared_dim", 128)

        # Regularisation / normalisation options for the heads.
        self.dropout_rate = self.config.get("dropout_rate", 0.3)
        self.use_batch_norm = self.config.get("use_batch_norm", False)  # Changed to False to avoid batch size issues
        self.use_layer_norm = self.config.get("use_layer_norm", True)  # Enable layer norm by default

        # Module / branch toggles.
        self.use_dynseg = self._resolve_flag('use_dynseg', 'USE_DYNSEG')
        self.use_causal = self._resolve_flag('use_causal', 'USE_CAUSAL')
        self.use_nbc = self._resolve_flag('use_nbc', 'USE_NBC')
        self.frequency_branch_enabled = self._resolve_flag('frequency_branch_enabled', 'FREQUENCY_BRANCH_ENABLED')
        self.use_fft = self._resolve_flag('use_fft', 'USE_FFT')
        self.use_wavelet = self._resolve_flag('use_wavelet', 'USE_WAVELET')

        # Modality toggles.
        self.use_multimodal = self._resolve_flag('use_multimodal', 'USE_MULTIMODAL')
        self.use_image = self._resolve_flag('use_image', 'USE_IMAGE')
        self.use_text = self._resolve_flag('use_text', 'USE_TEXT')
        self.use_static = self._resolve_flag('use_static', 'USE_STATIC')

        # Causal-branch toggles, resolved like every other flag so that a value
        # supplied in ``config`` is honoured instead of being overwritten below.
        self.use_causal_frequency = self._resolve_flag('use_causal_frequency', 'USE_CAUSAL_FREQUENCY')
        self.use_causal_pooling = self._resolve_flag('use_causal_pooling', 'USE_CAUSAL_POOLING')

        # Publish the resolved flags so the sub-modules built from ``config`` see them.
        config.update({
            'use_dynseg': self.use_dynseg,
            'use_causal': self.use_causal,
            'use_nbc': self.use_nbc,
            'frequency_branch_enabled': self.frequency_branch_enabled,
            'use_fft': self.use_fft,
            'use_wavelet': self.use_wavelet,
            'use_causal_frequency': self.use_causal_frequency,
            'use_causal_pooling': self.use_causal_pooling,
            'use_multimodal': self.use_multimodal,
            'use_image': self.use_image,
            'use_text': self.use_text,
            'use_static': self.use_static,
            'pooling_tau': self.config.get('pooling_tau', 5.0),
        })

        self.ff = FeatureExtractorAndFusion(config)

        if self.use_dynseg:
            self.seg_mod = DynamicSegmentationModule(shared_dim=self.shared_dim, config=config)
        else:
            self.seg_mod = None

        if self.use_causal:
            self.cg_mod = CausalGraphModuleLite(shared_dim=self.shared_dim, config=config)
        else:
            self.cg_mod = None

        # Fused global features and pooled graph features are concatenated.
        final_dim = self.shared_dim * 2
        self.feature_dropout = nn.Dropout(self.dropout_rate)

        self.feature_projection = nn.Sequential(
            nn.Linear(final_dim, final_dim // 2),
            nn.LayerNorm(final_dim // 2),
            nn.GELU(),
            nn.Dropout(self.dropout_rate)
        )

        projection_dim = final_dim // 2
        self.classification_head = ClassificationHead(
            projection_dim,
            self.config.get("out_dim", 20),
            out_len=self.config.get("out_len", 1),
            dropout=self.dropout_rate,
            use_batch_norm=self.use_batch_norm,
            use_layer_norm=self.use_layer_norm
        )

        self.forecast_head = ForecastHead(
            projection_dim,
            horizon=self.config.get("out_len", 24),
            hidden=min(256, projection_dim * 2),
            n_layers=2,
            dropout=self.dropout_rate,
            use_tcn=self.config.get("forecast_use_tcn", True),
            use_multiscale=self.config.get("forecast_use_multiscale", True),
            use_positional_encoding=self.config.get("forecast_use_positional_encoding", True),
            use_time_bias=self.config.get("forecast_use_time_bias", True)
        )

    def _get_bool_env(self, env_var: str, default: bool = True) -> bool:
        """Get boolean from environment variable.

        'true'/'1'/'yes'/'on' -> True, 'false'/'0'/'no'/'off' -> False
        (case-insensitive); unset or anything else -> ``default``.
        """
        import os
        value = os.environ.get(env_var, '')
        if value.lower() in ('true', '1', 'yes', 'on'):
            return True
        elif value.lower() in ('false', '0', 'no', 'off'):
            return False
        else:
            return default

    def _resolve_flag(self, key: str, env_var: str, default: bool = True):
        """Return ``config[key]`` if present, else env var ``env_var`` as a bool, else ``default``."""
        return self.config.get(key, self._get_bool_env(env_var, default))

    def _primary_forecast(self, forecast_output):
        """Pick the full-horizon forecast from the forecast head's output.

        A tensor output is returned as-is. For a multiscale dict, the entry
        whose scale equals the head's horizon is used, falling back to
        ``'forecast_scale_24'`` when that key is absent.
        """
        if not isinstance(forecast_output, dict):
            return forecast_output
        key = f'forecast_scale_{self.forecast_head.horizon}'
        if key not in forecast_output:
            key = 'forecast_scale_24'
        return forecast_output[key]

    def forward(self,
                images: torch.Tensor,
                text_item: Optional[Dict[str, torch.Tensor]],
                packed_ts: Union[Dict[str, torch.Tensor], torch.Tensor],
                static_data: torch.Tensor,
                task: Optional[str] = None,
                return_ae: bool = False) -> Dict[str, torch.Tensor]:
        """Run the full pipeline.

        Args:
            task: ``'forecasting'`` or ``'regression'`` takes ``logits`` from
                the forecast head; any other value uses the classification head.
            return_ae: accepted for interface compatibility; unused here.

        Returns:
            dict with ``logits``, ``clip_loss``, ``eamc_info``, ``graph_info``
            and ``forecast``; a multiscale forecast head also contributes its
            ``forecast_scale_<s>`` entries.
        """
        img_tok, txt_tok, ts_feat, static_vec = self.ff.extract_features(images, text_item, packed_ts, static_data)
        # Single-path mode: keep index 0 along dim 1 of the time-series and
        # static features; image/text tokens are used as returned.
        ts_feat = ts_feat[:, 0]
        static_vec = static_vec[:, 0]

        # Per-sample lengths from ``packed_ts`` (2-D lengths are reduced with a
        # row-wise min); otherwise every sample uses the full length dim 1.
        if isinstance(packed_ts, dict) and ("seq_lengths" in packed_ts) and packed_ts["seq_lengths"].dim() >= 1:
            ts_lengths = packed_ts["seq_lengths"]
            if ts_lengths.dim() == 2:
                ts_lengths = ts_lengths.min(dim=1)[0]
        else:
            ts_lengths = torch.full((ts_feat.size(0),), ts_feat.size(1), device=ts_feat.device, dtype=torch.long)

        if self.use_dynseg and self.seg_mod is not None:
            seg_tokens, seg_seq, eamc_info = self.seg_mod(ts_feat, ts_lengths)
        else:
            # Bypass: use the raw time-series features as the segment tokens.
            seg_tokens = ts_feat
            seg_seq = ts_feat.unsqueeze(2) if ts_feat.dim() == 3 else ts_feat  # [B,T,1,D]
            eamc_info = {"use_dynseg": False, "use_nbc": False}

        fused_global = self.ff.cross_modal_fusion(img_tok, txt_tok, seg_tokens, static_vec)

        if self.use_causal and self.cg_mod is not None:
            graph_nodes, graph_info = self.cg_mod(seg_seq, img_tok, txt_tok)
            graph_pool = graph_nodes.mean(dim=1)
        else:
            graph_pool = seg_tokens.mean(dim=1) if seg_tokens.dim() == 3 else seg_tokens
            graph_info = {"use_causal": False}

        feat = torch.cat([fused_global, graph_pool], dim=-1)
        feat = self.feature_projection(self.feature_dropout(feat))

        eamc_info["use_dynseg"] = self.use_dynseg
        eamc_info["use_nbc"] = self.use_nbc
        graph_info["use_causal"] = self.use_causal

        # Keep the latest diagnostics for inspection after the call.
        self.last_eamc_info = eamc_info
        self.last_graph_info = graph_info

        # Run the forecast head once and reuse the result for ``logits`` and
        # ``forecast``: a second call would draw fresh dropout masks during
        # training (so the two outputs would disagree) and double the cost.
        forecast_output = self.forecast_head(feat)
        forecast = self._primary_forecast(forecast_output)

        if task == 'forecasting' or task == 'regression':
            logits = forecast
        else:
            logits = self.classification_head(feat)

        clip_loss = self._compute_alignment_loss(img_tok, txt_tok, seg_tokens, static_vec)

        result = {
            "logits": logits,
            "clip_loss": clip_loss,
            "eamc_info": eamc_info,
            "graph_info": graph_info,
            "forecast": forecast,
        }
        if isinstance(forecast_output, dict):
            result.update(forecast_output)
        return result

    @staticmethod
    def _has_signal(tok) -> bool:
        """True when ``tok`` is not None and not all (numerically) zero."""
        return tok is not None and not torch.allclose(tok, torch.zeros_like(tok))

    @staticmethod
    def _global_vector(tok, pool_tokens: bool = True):
        """Reduce ``tok`` to one 2-D row per sample.

        3-D inputs are mean-pooled over dim 1 when ``pool_tokens`` is set;
        2-D inputs are returned unchanged; anything else is flattened to
        ``[B, -1]``.
        """
        if pool_tokens and tok.dim() == 3:
            return tok.mean(dim=1)
        if tok.dim() == 2:
            return tok
        return tok.view(tok.size(0), -1)

    def _compute_alignment_loss(self, img_tok, txt_tok, ts_tok, static_vec):
        """Compute multimodal alignment loss.

        Symmetric cross-entropy over cosine-similarity logits (temperature
        0.07, matching rows are positives), averaged over every pair of the
        modalities that are present and not all zero. Returns a zero tensor
        when fewer than two modalities qualify.
        """
        device = next(self.parameters()).device
        alignment_loss = torch.tensor(0.0, device=device)

        # Static features are flattened (not mean-pooled) when not 2-D.
        candidates = ((img_tok, True), (txt_tok, True), (ts_tok, True), (static_vec, False))
        modalities = [self._global_vector(tok, pool) for tok, pool in candidates if self._has_signal(tok)]
        if len(modalities) < 2:
            return alignment_loss

        # Truncate to the narrowest feature width so every pair is comparable,
        # then L2-normalise once per modality (hoisted out of the pair loop).
        min_dim = min(m.size(1) for m in modalities)
        modalities = [F.normalize(m[:, :min_dim], p=2, dim=1) for m in modalities]

        temperature = 0.07
        labels = torch.arange(modalities[0].size(0), device=device)

        pair_losses = []
        for i in range(len(modalities)):
            for j in range(i + 1, len(modalities)):
                sim_matrix = torch.mm(modalities[i], modalities[j].t()) / temperature
                loss_i2j = F.cross_entropy(sim_matrix, labels)
                loss_j2i = F.cross_entropy(sim_matrix.t(), labels)
                pair_losses.append((loss_i2j + loss_j2i) / 2.0)

        return sum(pair_losses) / len(pair_losses)
