import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.SelfAttention_Family_rev import FullAttention, FullAttention_temp, AttentionLayer, FocalSelfAttention
from layers.Embed_rev import PositionalEmbedding, PatchTimeRangeEmbedding, ChannelEmbedding
import numpy as np


class SingleSharedDecoder(nn.Module):
    """Forecast head that decodes global tokens with one head shared by all variables.

    The trailing ``num_global_tokens`` positions along the last axis of the
    input are treated as the global tokens of each variable. They are
    flattened, optionally mixed by ``token_projection`` (only created when
    ``num_global_tokens > 1``) and mapped to ``target_window`` values by
    ``unified_head``.

    Args:
        n_vars: Number of variables. Stored only; ``forward`` reads the
            variable count from its input.
        d_model: Embedding width of each token.
        target_window: Number of values produced per variable.
        dropout: Dropout probability applied to the head output.
        features: ``'M'`` decodes every variable; ``'MS'`` or ``'S'`` decodes
            only the last variable of the input.
        num_global_tokens: Number of trailing global tokens per variable.
    """

    def __init__(self, n_vars, d_model, target_window, dropout=0.0, features='M', num_global_tokens=1):
        super().__init__()
        self.n_vars = n_vars
        self.d_model = d_model
        self.target_window = target_window
        self.dropout = nn.Dropout(dropout)
        self.features = features
        self.num_global_tokens = num_global_tokens

        # Single decoder shared by all variables.
        self.unified_head = nn.Linear(d_model * num_global_tokens, target_window)

        # Optional layer mixing the flattened global tokens; only present
        # when there is more than one global token.
        if num_global_tokens > 1:
            self.token_projection = nn.Linear(d_model * num_global_tokens, d_model * num_global_tokens)
        else:
            self.token_projection = None

    def forward(self, x):  # x: [bs x nvars x d_model x patch_num]
        """Decode the global tokens into a forecast.

        Args:
            x: Tensor of shape ``[bs, nvars, d_model, patch_num]`` whose last
                ``num_global_tokens`` positions are the global tokens.

        Returns:
            Tensor of shape ``[bs, target_window, nvars]`` for ``'M'`` and
            ``[bs, target_window, 1]`` for ``'MS'`` / ``'S'``.

        Raises:
            ValueError: If ``features`` is not one of ``'M'``, ``'MS'``, ``'S'``.
        """
        bs, nvars, d_model, _ = x.shape
        flat_dim = d_model * self.num_global_tokens

        if self.features == 'M':
            # Global tokens of every variable, flattened per variable.
            tokens = x[:, :, :, -self.num_global_tokens:].reshape(bs, nvars, flat_dim)
        elif self.features in ('MS', 'S'):
            # Only the target (last) variable is decoded; keep a length-1
            # variable axis so both paths share the code below.
            tokens = x[:, -1, :, -self.num_global_tokens:].reshape(bs, 1, flat_dim)
        else:
            raise ValueError(
                f"Unsupported features mode {self.features!r}; expected 'M', 'MS' or 'S'."
            )

        if self.token_projection is not None:
            tokens = self.token_projection(tokens)

        # nn.Linear acts on the last axis, so the shared head handles all
        # variables in one batched call instead of a Python loop.
        # NOTE: with dropout > 0 in training mode the mask is now drawn in a
        # single call, so the RNG stream differs from a per-variable loop.
        out = self.dropout(self.unified_head(tokens))  # bs x vars x target_window

        # Put time before variables and return a contiguous tensor.
        return out.transpose(1, 2).contiguous()  # bs x target_window x vars

# channel-wise patching / tokenization / channel-token insertion
class ChannelWisePatchEmbedding(nn.Module):
    """Tokenize each channel with its own patch length and append global tokens.

    Every channel is cut into non-overlapping patches of its own length,
    each patch is linearly embedded, and a positional plus a channel
    embedding are added. Channels with fewer patches are zero-padded to the
    longest channel, then ``num_global_tokens`` learned tokens are appended
    to every channel.

    Args:
        n_vars: Number of channels.
        d_model: Embedding width.
        patch_lens: Patch length per channel. A list is stored as a tensor;
            any other indexable container is stored as given.
        sampling_rates: Stored on the module (list -> tensor) but not used
            by the code in this class.
        num_global_tokens: Number of learned global tokens per channel.
    """

    def __init__(self, n_vars, d_model, patch_lens, sampling_rates, num_global_tokens=1):
        super(ChannelWisePatchEmbedding, self).__init__()
        self.n_vars = n_vars
        self.d_model = d_model
        self.num_global_tokens = num_global_tokens

        if isinstance(patch_lens, list):
            self.patch_lens = torch.tensor(patch_lens)
        else:
            self.patch_lens = patch_lens

        if isinstance(sampling_rates, list):
            self.sampling_rates = torch.tensor(sampling_rates)
        else:
            self.sampling_rates = sampling_rates

        # One linear patch tokenizer per channel, sized to that channel's
        # patch length. int() keeps in_features a plain Python int even when
        # patch_lens is a tensor.
        self.value_embeddings = nn.ModuleList([
            nn.Linear(int(patch_len), d_model) for patch_len in self.patch_lens
        ])

        # Learned global (channel) tokens: one set per channel.
        self.glb_tokens = nn.Parameter(torch.randn(1, n_vars, num_global_tokens, d_model))

        self.position_embedding = PositionalEmbedding(d_model)
        self.channel_embedding = ChannelEmbedding(d_model, n_vars)

    def forward(self, x, masking=True):
        """
        Apply channel-wise patching with different patch lengths per channel.

        Trailing samples that do not fill a whole patch are dropped.

        Args:
            x (torch.Tensor): Input tensor of shape [B, n_vars, L]
            masking (bool): Forwarded to ``_generate_mask``, which currently
                ignores it; a mask is always returned.

        Returns:
            x (torch.Tensor): Embedded patches with shape
                [B*n_vars, max_patches+num_global_tokens, d_model]
            n_vars (int): Number of channels
            attn_mask (torch.Tensor): Boolean mask, False at padded tokens
        """
        B, n_vars, L = x.shape

        # Patch length and patch count per channel; int() accepts tensor
        # elements as well as plain ints.
        patch_sizes = [int(self.patch_lens[v]) for v in range(n_vars)]
        num_patches_list = [L // size for size in patch_sizes]
        max_patches = max(num_patches_list, default=0)

        per_channel = []
        for v, patch_len in enumerate(patch_sizes):
            # Non-overlapping patches: (B, num_patches, patch_len)
            patches = x[:, v, :].unfold(dimension=-1, size=patch_len, step=patch_len)

            # (B, num_patches, d_model), plus positional and channel embeddings
            embedded = self.value_embeddings[v](patches)
            embedded = (
                embedded
                + self.position_embedding(embedded)
                + self.channel_embedding(torch.tensor(v, device=embedded.device))
            )

            # Zero-pad shorter channels to max_patches. The padding matches
            # the embedding's dtype and device so concatenation never mixes
            # precisions.
            pad_len = max_patches - embedded.size(1)
            if pad_len > 0:
                padding = torch.zeros(
                    B, pad_len, self.d_model, dtype=embedded.dtype, device=embedded.device
                )
                embedded = torch.cat([embedded, padding], dim=1)

            # Append this channel's global tokens (no positional/channel
            # embedding is added to them): (B, max_patches+num_global_tokens, d_model)
            glb = self.glb_tokens[:, v, :, :].expand(B, -1, -1)
            per_channel.append(torch.cat([embedded, glb], dim=1))

        # (B, n_vars, T, d_model) -> (B*n_vars, T, d_model)
        x = torch.stack(per_channel, dim=1)
        x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))

        # The padding mask is always generated.
        attn_mask = self._generate_mask(B, n_vars, num_patches_list, max_patches, masking, x.device)

        return x, n_vars, attn_mask

    def _generate_mask(self, B, n_vars, num_patches_list, max_patches, masking, device):
        """
        Generate attention mask for padded tokens.

        Args:
            B (int): Batch size
            n_vars (int): Number of variables/channels
            num_patches_list (list): Number of real patches for each variable
            max_patches (int): Maximum number of patches across all variables
            masking (bool): Accepted for interface compatibility; not used
            device (torch.device): Device to create tensor on

        Returns:
            torch.Tensor: Boolean mask of shape [B*n_vars, max_patches+num_global_tokens]
                         where True indicates valid tokens and False indicates padded tokens
        """
        total_len = max_patches + self.num_global_tokens

        # The mask depends only on the channel, so build it once per channel
        # and broadcast it over the batch.
        lengths = torch.tensor(
            list(num_patches_list[:n_vars]), dtype=torch.long, device=device
        ).unsqueeze(-1)                                              # (n_vars, 1)
        positions = torch.arange(total_len, device=device).unsqueeze(0)  # (1, total_len)

        # Valid = real patch, or one of the trailing global tokens (never masked).
        per_var = (positions < lengths) | (positions >= max_patches)    # (n_vars, total_len)

        # reshape() materialises the broadcast, so the caller receives an
        # independent contiguous tensor ordered batch-major, channel-minor.
        return per_var.unsqueeze(0).expand(B, -1, -1).reshape(B * n_vars, total_len)


class Encoder(nn.Module):
    """Stack of encoder layers with an optional final norm and projection.

    Args:
        layers: Iterable of modules. Each is called as
            ``layer(x, x_mask=..., tau=..., delta=...)`` and must return
            ``(x, attn)``.
        norm_layer: Optional module applied after the last layer.
        projection: Optional module applied after ``norm_layer``.
    """

    def __init__(self, layers, norm_layer=None, projection=None):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, x_mask=None, tau=None, delta=None):
        """Run ``x`` through every layer and collect each layer's attention.

        Returns:
            Tuple ``(x, attn_list)`` where ``attn_list`` holds the second
            return value of every layer, in order.
        """
        hidden = x
        attentions = []
        for block in self.layers:
            hidden, layer_attn = block(hidden, x_mask=x_mask, tau=tau, delta=delta)
            attentions.append(layer_attn)

        # Optional post-processing, in fixed order: norm first, then projection.
        for post in (self.norm, self.projection):
            if post is not None:
                hidden = post(hidden)

        return hidden, attentions


class EncoderLayer(nn.Module):
    """Encoder layer with separate local, global and gated cross attention.

    The input sequence is split into local tokens and the trailing
    ``num_global_tokens`` global tokens. One forward pass runs, in order:

    1. self-attention among local tokens (``norm1``),
    2. gated attention where global tokens query local tokens (``norm3``),
    3. attention among global tokens (``norm2``),
    4. gated attention where local tokens query global tokens (``norm4``),
    5. a position-wise feed-forward block on the recombined sequence (``norm5``).

    Args:
        self_attention: Attention module for step 1.
        global_token_attention: Attention module for step 3; its attention
            output is returned by ``forward``.
        local_to_global_attention: Attention module for step 2.
        global_to_local_attention: Attention module for step 4.
        d_model: Token width.
        d_ff: Hidden width of the feed-forward block (default ``4 * d_model``).
        dropout: Dropout probability.
        activation: ``"relu"`` selects ReLU; any other value selects GELU.
        n_vars: Stored on the module; not used by ``forward``.
        num_global_tokens: Number of trailing global tokens.

    Each attention module is called as
    ``attn(q, k, v, attn_mask=..., tau=..., delta=None)`` and must return
    ``(output, attention)``.
    """

    def __init__(self, self_attention, global_token_attention, local_to_global_attention, global_to_local_attention, 
                 d_model, d_ff=None, dropout=0.1, activation="relu", n_vars=7, num_global_tokens=1):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        # Four independent attention layers are used.
        self.self_attention = self_attention  # attention among local tokens
        self.global_token_attention = global_token_attention  # attention among global tokens
        self.local_to_global_attention = local_to_global_attention  # local -> global attention
        self.global_to_local_attention = global_to_local_attention  # global -> local attention
        self.n_vars = n_vars
        self.num_global_tokens = num_global_tokens
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.norm4 = nn.LayerNorm(d_model)
        self.norm5 = nn.LayerNorm(d_model)

        # Sigmoid gates for the two cross-attention directions.
        self.local_to_global_gate = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Sigmoid()
        )
        self.global_to_local_gate = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Sigmoid()
        )

        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, x_mask=None, tau=None, delta=None):
        """Apply the layer.

        Args:
            x: Tensor ``[N, local_len + num_global_tokens, d_model]``.
            x_mask: Optional mask ``[N, local_len + num_global_tokens]``. Its
                local part is passed as ``attn_mask`` to the local and
                local->global attention modules and is multiplied into the
                global->local update.
            tau: Forwarded to every attention module.
            delta: Accepted for interface compatibility; the attention
                modules are always called with ``delta=None``.

        Returns:
            Tuple ``(x, attn_glb)``: the updated sequence (same shape as the
            input) and the attention returned by ``global_token_attention``.
        """
        k = self.num_global_tokens

        # Split local tokens from the trailing global tokens.
        x_local = x[:, :-k, :]
        x_global = x[:, -k:, :]

        # Mask restricted to the local tokens (global tokens excluded).
        local_mask = None if x_mask is None else x_mask[:, :-k]

        # 1) Local self-attention with the padding mask.
        local_attn_output, _ = self.self_attention(
            x_local, x_local, x_local,
            attn_mask=local_mask,
            tau=tau, delta=None
        )
        x_local = self.norm1(x_local + self.dropout(local_attn_output))

        # 2) Local -> global: global tokens query the local tokens; the
        #    local mask is passed along. The update is gated by a sigmoid
        #    of the global tokens.
        local_to_global_output, _ = self.local_to_global_attention(
            x_global, x_local, x_local,
            attn_mask=local_mask,
            tau=tau, delta=None
        )
        local_to_global_output = self.local_to_global_gate(x_global) * local_to_global_output
        x_global = self.norm3(x_global + self.dropout(local_to_global_output))

        # 3) Attention among global tokens (no mask).
        # NOTE(review): presumably this exchanges information across
        # channels inside the attention module - confirm against its code.
        global_attn_output, attn_glb = self.global_token_attention(
            x_global, x_global, x_global,
            attn_mask=None,
            tau=tau, delta=None
        )
        x_global = self.norm2(x_global + self.dropout(global_attn_output))

        # 4) Global -> local: local tokens query the global tokens (no mask
        #    needed on the key side), gated by a sigmoid of the local tokens.
        global_to_local_output, _ = self.global_to_local_attention(
            x_local, x_global, x_global,
            attn_mask=None,
            tau=tau, delta=None
        )
        global_to_local_output = self.global_to_local_gate(x_local) * global_to_local_output
        if local_mask is not None:
            # Suppress the update at padded positions. Those positions are
            # not zero at this point (they already went through
            # self-attention and norm1); they simply receive no global
            # update.
            global_to_local_output = global_to_local_output * local_mask.unsqueeze(-1)
        x_local = self.norm4(x_local + self.dropout(global_to_local_output))

        # Recombine local and global tokens.
        x = torch.cat([x_local, x_global], dim=1)

        # 5) Position-wise feed-forward via 1x1 convolutions over the
        #    channel-first layout.
        y = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))

        return self.norm5(x + y), attn_glb


class Model(nn.Module):
    """Forecasting model: channel-wise patch embedding -> encoder -> shared head.

    ``configs`` must provide: ``task_name``, ``features``, ``seq_len``,
    ``pred_len``, ``use_norm``, ``patch_lens``, ``sampling_rates``,
    ``enc_in``, ``num_global_tokens``, ``cross_temp``, ``self_temp``,
    ``d_model``, ``d_ff``, ``n_heads``, ``factor``, ``dropout``,
    ``activation`` and ``e_layers``.

    With ``features == 'M'`` all input variables are embedded and forecast.
    For any other value only the last input variable is embedded and
    forecast (see ``forecast``).
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.features = configs.features
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.use_norm = configs.use_norm
        self.patch_lens = configs.patch_lens
        self.n_vars = 1 if configs.features == 'MS' else configs.enc_in
        self.num_global_tokens = configs.num_global_tokens
        # Temperature of the global-token attention: learnable (one value
        # per variable, initialised to 0.1) when configs.cross_temp == -1,
        # otherwise the fixed configured value.
        if configs.cross_temp == -1:
            self.cross_temp = nn.Parameter(torch.ones(self.n_vars) * 0.1, requires_grad=True)
        else:
            self.cross_temp = configs.cross_temp

        # Embedding
        # NOTE(review): in 'MS' mode n_vars is 1 while the full patch_lens is
        # passed, so the single embedded series uses patch_lens[0] even
        # though forecast() feeds the LAST input variable - confirm intended.
        self.channel_wise_patch_embedding = ChannelWisePatchEmbedding(
            self.n_vars, 
            configs.d_model, 
            self.patch_lens, 
            configs.sampling_rates,
            num_global_tokens=self.num_global_tokens
        )
        self.latest_attention = 0

        # Each attention layer is created independently to implement the
        # bidirectional local-global attention.
        self.encoder = Encoder(
            [
                EncoderLayer(
                    # Attention among local tokens (padding mask applied)
                    self_attention=AttentionLayer(
                        FocalSelfAttention(False, configs.factor, attention_dropout=configs.dropout,
                                      output_attention=False, temp=configs.self_temp),
                        configs.d_model, configs.n_heads),
                    # Local -> global attention (global tokens attend to local tokens)
                    local_to_global_attention=AttentionLayer(
                        FullAttention(False, configs.factor, attention_dropout=configs.dropout,
                                      output_attention=False, temp=configs.self_temp),
                        configs.d_model, configs.n_heads),
                    # Attention among global tokens (information exchange across channels)
                    global_token_attention=AttentionLayer(
                        FullAttention_temp(False, configs.factor, attention_dropout=configs.dropout,
                                      output_attention=True, temp=self.cross_temp, num_global_tokens=self.num_global_tokens),
                        configs.d_model, configs.n_heads),
                    # Global -> local attention (local tokens attend to global tokens)
                    global_to_local_attention=AttentionLayer(
                        FullAttention(False, configs.factor, attention_dropout=configs.dropout,
                                      output_attention=False, temp=configs.self_temp),
                        configs.d_model, configs.n_heads),
                    d_model=configs.d_model,
                    d_ff=configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation,
                    n_vars=self.n_vars,
                    num_global_tokens=self.num_global_tokens
                )
                for _ in range(configs.e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(configs.d_model)
        )
        self.head = SingleSharedDecoder(configs.enc_in, configs.d_model, configs.pred_len,
                                dropout=configs.dropout, features=configs.features, num_global_tokens=self.num_global_tokens)

    def _normalize(self, x_enc):
        """Standardize every variable over the time axis (dim 1).

        Returns:
            Tuple ``(normalized, means, stdev)``; ``means`` is detached,
            ``stdev`` is ``sqrt(biased variance + 1e-5)``.
        """
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        centered = x_enc - means
        stdev = torch.sqrt(torch.var(centered, dim=1, keepdim=True, unbiased=False) + 1e-5)
        return centered / stdev, means, stdev

    def _encode_decode(self, series):
        """Embed ``series`` ([B, n_vars, L]), encode it and apply the head.

        Returns:
            Tuple ``(dec_out, attn)``: the head output and the attention
            list returned by the encoder.
        """
        # The embedding always returns a padding mask for zero-padded tokens.
        tokens, n_vars, attn_mask = self.channel_wise_patch_embedding(series, masking=True)

        enc_out, attn = self.encoder(tokens, x_mask=attn_mask)

        # [B*n_vars, tokens, d_model] -> [B, n_vars, d_model, tokens]
        enc_out = torch.reshape(
            enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]))
        enc_out = enc_out.permute(0, 1, 3, 2)

        return self.head(enc_out), attn

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Forecast the last input variable only.

        Args:
            x_enc: Input of shape ``[B, L, N]``.
            x_mark_enc, x_dec, x_mark_dec: Unused.

        Returns:
            Tuple ``(dec_out, attn)``.
        """
        if self.use_norm:
            x_enc, means, stdev = self._normalize(x_enc)

        # Keep only the last variable: [B, L] -> [B, 1, L]
        target = x_enc[:, :, -1].unsqueeze(-1).permute(0, 2, 1)
        dec_out, attn = self._encode_decode(target)

        if self.use_norm:
            # De-Normalization from Non-stationary Transformer, using the
            # statistics of the last variable.
            dec_out = dec_out * (stdev[:, 0, -1:].unsqueeze(1).repeat(1, self.pred_len, 1))
            dec_out = dec_out + (means[:, 0, -1:].unsqueeze(1).repeat(1, self.pred_len, 1))

        return dec_out, attn

    def forecast_multi(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Forecast every input variable.

        Also stores the encoder attention list in ``self.latest_attention``.

        Args:
            x_enc: Input of shape ``[B, L, N]``.
            x_mark_enc, x_dec, x_mark_dec: Unused.

        Returns:
            Tuple ``(dec_out, attn)``.
        """
        if self.use_norm:
            x_enc, means, stdev = self._normalize(x_enc)

        # [B, L, N] -> [B, N, L]
        dec_out, attn = self._encode_decode(x_enc.permute(0, 2, 1))

        self.latest_attention = attn

        if self.use_norm:
            # De-Normalization from Non-stationary Transformer
            dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
            dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))

        return dec_out, attn

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Run the forecast for the configured task.

        Returns:
            The last ``pred_len`` steps of the forecast for the tasks
            ``'long_term_forecast'``, ``'short_term_forecast'`` and
            ``'missing'``; ``None`` for any other task. ``mask`` is unused.
        """
        if self.task_name not in ('long_term_forecast', 'short_term_forecast', 'missing'):
            return None

        # Both forecast methods return (dec_out, attn); the tuple must be
        # unpacked before slicing (slicing the tuple raises TypeError).
        if self.features == 'M':
            dec_out, _ = self.forecast_multi(x_enc, x_mark_enc, x_dec, x_mark_dec)
        else:
            dec_out, _ = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)

        return dec_out[:, -self.pred_len:, :]  # [B, L, D]

# Simple linear layer tokenizer for time patches
class TimePatchTokenizer(nn.Module):
    """Embed time patches with a single linear layer (no activation).

    Args:
        patch_len: Length of each input patch.
        d_model: Width of the produced tokens.
        sampling_rate: Accepted but not used by this class.
    """

    def __init__(self, patch_len, d_model, sampling_rate=1.0):
        super().__init__()
        # The tokenizer is just one affine map from patch_len to d_model.
        self.linear = nn.Linear(in_features=patch_len, out_features=d_model)

    def forward(self, x):
        """Map patches ``[B, num_patches, patch_len]`` to tokens ``[B, num_patches, d_model]``."""
        tokens = self.linear(x)
        return tokens