import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.SelfAttention_Family import FullAttention, FullAttention_temp, AttentionLayer
from layers.Embed import PositionalEmbedding, ChannelEmbedding
import numpy as np
import matplotlib.pyplot as plt
import os

class PerVarProjectionHead(nn.Module):
    """Forecast head with one linear layer per variable.

    Variable ``v`` sampled at rate ``sampling_rates[v]`` predicts
    ``pred_len // int(sampling_rates[v] / min(sampling_rates))`` steps from its
    flattened global token(s). The result is right-padded with zeros to ``pred_len``
    so every variable can be stacked along the last axis.
    """

    def __init__(self, n_vars, d_model, pred_len, sampling_rates, num_global_tokens=1, dropout=0.0):
        super().__init__()
        self.n_vars = n_vars
        self.d_model = d_model
        self.pred_len = pred_len
        self.sampling_rates = sampling_rates
        self.num_global_tokens = num_global_tokens
        self.dropout = nn.Dropout(dropout)

        # Output length per variable, derived from its sampling rate relative to the fastest one.
        min_sr = min(sampling_rates)
        self.out_lens = [pred_len // int(sr / min_sr) for sr in sampling_rates]

        # One projection head per variable.
        self.heads = nn.ModuleList([
            nn.Linear(d_model * num_global_tokens, self.out_lens[v])
            for v in range(n_vars)
        ])

    def forward(self, x, n_vars, num_global_tokens=1):
        """Project each variable's global token(s) to its forecast.

        Args:
            x: [B, n_vars, D] if num_global_tokens == 1,
               [B, n_vars, num_global_tokens, D] otherwise.
            n_vars: number of variables to project (the first ``n_vars`` heads are used).
            num_global_tokens: kept for interface compatibility; ``x[:, v]`` is
                flattened to ``(B, -1)`` in either layout.

        Returns:
            out: [B, pred_len, n_vars]; variable ``v`` fills the first ``out_lens[v]``
            steps and the remainder is zero. The dtype and device follow the head output.
        """
        B = x.shape[0]
        outputs = []
        for v in range(n_vars):
            # (B, D) or (B, G, D) -> (B, G*D)
            glb_token = x[:, v].reshape(B, -1)
            out_v = self.dropout(self.heads[v](glb_token))  # (B, L_v)
            # F.pad keeps out_v's dtype/device (a torch.zeros buffer would force the default dtype).
            outputs.append(F.pad(out_v, (0, self.pred_len - out_v.shape[1])))
        return torch.stack(outputs, dim=-1)  # (B, pred_len, n_vars)

class PerRateProjectionHead(nn.Module):
    """Forecast head that shares one linear layer among variables with the same sampling rate.

    Variables sampled at rate ``r`` predict ``pred_len // int(r / min(sampling_rates))``
    steps. The result is right-padded with zeros to ``pred_len`` so every variable
    can be stacked along the last axis.
    """

    def __init__(self, n_vars, d_model, pred_len, sampling_rates, num_global_tokens=1, dropout=0.0):
        super().__init__()
        self.n_vars = n_vars
        self.d_model = d_model
        self.pred_len = pred_len
        self.num_global_tokens = num_global_tokens
        self.dropout = nn.Dropout(dropout)

        min_sr = min(sampling_rates)
        self.unique_rates = sorted(set(sampling_rates))
        self.rate_to_outlen = {
            rate: pred_len // int(rate / min_sr) for rate in self.unique_rates
        }

        # Head keys stay "rate_<int(rate*100)>" so existing state_dicts still load.
        # Distinct rates can truncate to the same key (e.g. 0.281 and 0.289). The
        # ModuleDict would then silently keep a single head with the wrong output
        # length for one of the rates, so reject that configuration explicitly.
        key_to_rate = {}
        for rate in self.unique_rates:
            key = f"rate_{int(rate * 100)}"
            if key in key_to_rate:
                raise ValueError(
                    f"sampling rates {key_to_rate[key]} and {rate} map to the same head key '{key}'"
                )
            key_to_rate[key] = rate

        # One Linear layer per distinct rate (created in ascending-rate order).
        self.rate_heads = nn.ModuleDict({
            key: nn.Linear(d_model * num_global_tokens, self.rate_to_outlen[rate])
            for key, rate in key_to_rate.items()
        })
        self.var_to_rate = [f"rate_{int(rate * 100)}" for rate in sampling_rates]

    def forward(self, x, n_vars, num_global_tokens=1):
        """Project each variable's global token(s) with the head for its sampling rate.

        Args:
            x: [B, n_vars, D] if num_global_tokens == 1,
               [B, n_vars, num_global_tokens, D] otherwise.
            n_vars: number of variables to project.
            num_global_tokens: kept for interface compatibility; ``x[:, v]`` is
                flattened to ``(B, -1)`` in either layout.

        Returns:
            out: [B, pred_len, n_vars]; each variable fills the first L_v steps and
            the remainder is zero. The dtype and device follow the head output.
        """
        B = x.shape[0]
        outputs = []
        for v in range(n_vars):
            head = self.rate_heads[self.var_to_rate[v]]
            # (B, D) or (B, G, D) -> (B, G*D)
            glb_token = x[:, v].reshape(B, -1)
            out_v = self.dropout(head(glb_token))  # (B, L_v)
            # F.pad keeps out_v's dtype/device (a torch.zeros buffer would force the default dtype).
            outputs.append(F.pad(out_v, (0, self.pred_len - out_v.shape[1])))
        return torch.stack(outputs, dim=-1)  # (B, pred_len, n_vars)



class SingleProjectionHead(nn.Module):
    """Linear forecast head shared by all variables.

    Each variable's global token(s) are flattened and mapped to ``target_window``
    steps; the output is returned as [B, target_window, n_vars].
    """

    def __init__(self, n_vars, nf, target_window, head_dropout=0, num_global_tokens=1):
        super().__init__()
        self.n_vars = n_vars
        self.num_global_tokens = num_global_tokens
        # A variable's G global tokens are flattened into a single G*nf feature vector.
        # (Previously nf*2 was hard-coded, which only matched num_global_tokens == 2.)
        self.linear = nn.Linear(nf * num_global_tokens, target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x, n_vars, num_global_tokens=1):
        """Apply the shared head.

        Args:
            x: [B, n_vars, nf] when num_global_tokens == 1,
               otherwise [B, n_vars, num_global_tokens, nf].
            n_vars: number of variables in ``x``.
            num_global_tokens: number of global tokens per variable in ``x``.

        Returns:
            Tensor of shape [B, target_window, n_vars].
        """
        if num_global_tokens == 1:
            x = self.linear(x)
        else:
            x = self.linear(x.reshape(x.shape[0], n_vars, -1))
        x = self.dropout(x)
        return x.permute(0, 2, 1)


class DropPath(nn.Module):
    """Stochastic depth: randomly drop a residual branch for whole samples.

    In training mode each sample along dim 0 is kept with probability
    ``1 - drop_prob`` and zeroed otherwise. With ``scale_by_keep`` (the default)
    kept samples are divided by the keep probability, so the expected output
    equals the input, as it is in eval mode where this module is the identity.
    """

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # Per-sample mask of shape (B, 1, 1, ...), broadcast over all other dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        keep_mask = (torch.rand(shape, dtype=x.dtype, device=x.device) < keep_prob).to(x.dtype)
        if self.scale_by_keep and keep_prob > 0.0:
            # Without this rescale the branch's expected contribution would shrink by
            # keep_prob during training only, creating a train/eval mismatch.
            keep_mask = keep_mask / keep_prob
        return x * keep_mask

    
class GlobalPatchEmbedding(nn.Module):
    """Embed fixed-length, non-overlapping patches and append a learned global token per variable."""

    def __init__(self, n_vars, d_model, patch_len, dropout):
        super(GlobalPatchEmbedding, self).__init__()
        self.patch_len = patch_len
        # Bias-free projection of each raw patch to d_model.
        self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
        # Learnable global token for every variable: (1, n_vars, 1, d_model).
        self.glb_token = nn.Parameter(torch.randn(1, n_vars, 1, d_model))
        self.position_embedding = PositionalEmbedding(d_model)
        self.n_vars = n_vars
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Patch, embed and append the global tokens.

        Args:
            x: tensor whose dim 0 is batch and dim 1 is variables; patches are cut
               along the last axis with stride ``patch_len``.

        Returns:
            (tokens, n_vars) where tokens has shape
            (batch * n_vars, num_patches + 1, d_model) after dropout.
        """
        batch, n_vars = x.shape[0], x.shape[1]

        # (batch, n_vars, num_patches, patch_len) -> (batch * n_vars, num_patches, patch_len)
        patches = x.unfold(dimension=-1, size=self.patch_len, step=self.patch_len).flatten(0, 1)

        tokens = self.value_embedding(patches) + self.position_embedding(patches)
        tokens = tokens.reshape(-1, n_vars, tokens.shape[-2], tokens.shape[-1])

        # Global token goes after the patch tokens of each variable.
        global_tokens = self.glb_token.repeat((batch, 1, 1, 1))
        tokens = torch.cat([tokens, global_tokens], dim=2)

        tokens = tokens.reshape(-1, tokens.shape[2], tokens.shape[3])
        return self.dropout(tokens), n_vars


class ChannelWisePatchEmbedding(nn.Module):
    """Patch-embed every variable with its own patch length and sampling stride.

    Variable ``v`` is subsampled with stride
    ``int(sampling_rates[v] / min(sampling_rates))`` and cut into non-overlapping
    patches of ``patch_lens[v]`` samples. The patches are embedded with a
    per-variable linear layer plus positional embedding, followed by
    ``num_global_tokens`` learned global tokens and a channel embedding. All
    variables' token runs are concatenated along the token axis.
    """

    def __init__(self, n_vars, d_model, patch_lens, sampling_rates, num_global_tokens=1):
        super(ChannelWisePatchEmbedding, self).__init__()
        self.n_vars = n_vars
        self.d_model = d_model
        self.num_global_tokens = num_global_tokens
        if isinstance(patch_lens, list):
            self.patch_lens = torch.tensor(patch_lens)
        else:
            self.patch_lens = patch_lens
        if isinstance(sampling_rates, list):
            self.sampling_rates = torch.tensor(sampling_rates)
        else:
            self.sampling_rates = sampling_rates
        # One patch projection per variable. int() so in_features is a plain int
        # rather than a 0-d tensor element.
        self.value_embeddings = nn.ModuleList([
            nn.Linear(int(patch_len), d_model) for patch_len in self.patch_lens
        ])
        # Learned global tokens (shared by all variables), appended after each variable's patches.
        self.glb_tokens = nn.Parameter(torch.randn(1, num_global_tokens, d_model))
        self.position_embedding = PositionalEmbedding(d_model)
        self.channel_embedding = ChannelEmbedding(d_model, n_vars)

    def _extract_patches(self, var_data, patch_len, sampling_factor, keep_prob):
        """Return the kept patches of one variable as (B, n_kept, patch_len).

        In training mode every patch except the last is kept with probability
        ``keep_prob`` (one ``torch.rand(1)`` draw per candidate, in patch order).
        When no patch is available (input too short) a single zero patch is returned.
        """
        B, L = var_data.shape
        num_patches = L // (patch_len * sampling_factor)
        span = num_patches * patch_len * sampling_factor
        # Sample i*patch_len + j of the strided view is index (i*patch_len + j) * sampling_factor.
        all_patches = var_data[:, :span:sampling_factor].reshape(B, num_patches, patch_len)

        kept = [
            i for i in range(num_patches)
            if not (self.training and i != num_patches - 1 and torch.rand(1).item() > keep_prob)
        ]
        if not kept:
            return var_data.new_zeros((B, 1, patch_len))
        if len(kept) == num_patches:
            return all_patches
        return all_patches[:, kept, :]

    def forward(self, x, patch_prob = 1, masking=True):
        """Apply channel-wise patching with different patch lengths per channel.

        Args:
            x (torch.Tensor): input of shape [B, n_vars, L].
            patch_prob (float): in training mode, probability of keeping each patch
                other than the last one of a variable.
            masking (bool): accepted for backward compatibility; not used.

        Returns:
            patch_embeddings (torch.Tensor): (B, total_tokens, d_model), where each
                variable contributes its kept patches followed by num_global_tokens tokens.
            num_patches_list (List[int]): kept patch count per variable (excluding global tokens).
            n_vars (int): number of variables (x.shape[1]).
        """
        B, n_vars, L = x.shape
        min_sampling_rate = torch.min(self.sampling_rates).item()
        patch_embeddings = []
        num_patches_list = []
        for v in range(n_vars):
            var_data = x[:, v, :]  # (B, L)
            patch_len = int(self.patch_lens[v].item())
            sampling_factor = int(self.sampling_rates[v].item() / min_sampling_rate)

            patches_tensor = self._extract_patches(var_data, patch_len, sampling_factor, patch_prob)
            num_patches_list.append(patches_tensor.shape[1])

            # patch embedding + position, then global tokens, then channel embedding
            embedded = self.value_embeddings[v](patches_tensor)  # (B, n_kept, d_model)
            embedded = embedded + self.position_embedding(embedded)
            embedded = torch.cat([embedded, self.glb_tokens.repeat(B, 1, 1)], dim=1)
            embedded = self.channel_embedding(torch.tensor(v, device=embedded.device)) + embedded
            patch_embeddings.append(embedded)
        patch_embeddings = torch.cat(patch_embeddings, dim=1)  # (B, total_tokens, d_model)
        return patch_embeddings, num_patches_list, n_vars

class ChannelWisePatchEmbedding_share(nn.Module):
    """Channel-wise patch embedding whose projection layers are shared by patch length.

    Variable ``v`` is subsampled with stride
    ``int(sampling_rates[v] / min(sampling_rates))`` and cut into non-overlapping
    patches of ``patch_lens[v]`` samples. The patches are embedded with the
    ``nn.Linear`` registered for that patch length plus positional embedding,
    followed by ``num_global_tokens`` learned global tokens and a channel
    embedding. All variables' token runs are concatenated along the token axis.
    """

    def __init__(self, n_vars, d_model, patch_lens, sampling_rates, num_global_tokens=1):
        super(ChannelWisePatchEmbedding_share, self).__init__()
        self.n_vars = n_vars
        self.d_model = d_model
        self.num_global_tokens = num_global_tokens
        if isinstance(patch_lens, list):
            self.patch_lens = torch.tensor(patch_lens)
        else:
            self.patch_lens = patch_lens
        if isinstance(sampling_rates, list):
            self.sampling_rates = torch.tensor(sampling_rates)
        else:
            self.sampling_rates = sampling_rates
        # One nn.Linear per distinct patch length, keyed "pl<patch_len>";
        # var_to_key[v] selects the layer used for variable v.
        self.embedding_dict = nn.ModuleDict()
        self.var_to_key = []

        for v in range(n_vars):
            patch_len = int(self.patch_lens[v].item())
            key = f"pl{patch_len}"
            self.var_to_key.append(key)
            if key not in self.embedding_dict:
                self.embedding_dict[key] = nn.Linear(patch_len, d_model)

        # Learned global tokens (shared by all variables), appended after each variable's patches.
        self.glb_tokens = nn.Parameter(torch.randn(1, num_global_tokens, d_model))
        self.position_embedding = PositionalEmbedding(d_model)
        self.channel_embedding = ChannelEmbedding(d_model, n_vars)
        # Print the "missing occurs" notice at most once. This flag was previously read
        # in forward() without ever being initialised (AttributeError on first use).
        self._printed_missing_notice = False

    def _extract_patches(self, var_data, v, patch_len, sampling_factor, keep_prob, missing_flag):
        """Return the kept patches of variable ``v`` as (B, n_kept, patch_len).

        In training mode every patch except the last is kept with probability
        ``keep_prob`` (one ``torch.rand(1)`` draw per candidate, in patch order).
        When ``missing_flag`` is given, a patch whose flags (read for batch element 0
        only) are all 0 is skipped. When nothing is left, a single zero patch is returned.
        """
        B, L = var_data.shape
        num_patches = L // (patch_len * sampling_factor)
        span = num_patches * patch_len * sampling_factor
        # Sample i*patch_len + j of the strided view is index (i*patch_len + j) * sampling_factor.
        all_patches = var_data[:, :span:sampling_factor].reshape(B, num_patches, patch_len)

        observed = None
        if missing_flag is not None:
            # Indexed as missing_flag[batch, time, variable]; one read per variable
            # instead of one .item() per sample.
            flags = torch.as_tensor(missing_flag[0, :span:sampling_factor, v]).reshape(num_patches, patch_len)
            observed = (flags != 0).any(dim=1).tolist()

        kept = []
        for i in range(num_patches):
            if self.training and i != num_patches - 1 and torch.rand(1).item() > keep_prob:
                continue
            # Skip patches that are entirely missing.
            if observed is not None and not observed[i]:
                if not getattr(self, "_printed_missing_notice", False):
                    print(f"[missing occurs at channel{v} during patching.]")
                    self._printed_missing_notice = True
                continue
            kept.append(i)

        if not kept:
            return var_data.new_zeros((B, 1, patch_len))
        if len(kept) == num_patches:
            return all_patches
        return all_patches[:, kept, :]

    def forward(self, x, patch_prob = 1, missing_flag = None):
        """Apply channel-wise patching with different patch lengths per channel.

        Args:
            x (torch.Tensor): input of shape [B, n_vars, L].
            patch_prob (float): in training mode, probability of keeping each patch
                other than the last one of a variable.
            missing_flag (torch.Tensor, optional): indexed as [batch, time, variable];
                only batch element 0 is read, and patches whose flags are all 0 are skipped.

        Returns:
            patch_embeddings (torch.Tensor): (B, total_tokens, d_model), where each
                variable contributes its kept patches followed by num_global_tokens tokens.
            num_patches_list (List[int]): kept patch count per variable (excluding global tokens).
            n_vars (int): number of variables (x.shape[1]).
        """
        B, n_vars, L = x.shape
        min_sampling_rate = torch.min(self.sampling_rates).item()
        patch_embeddings = []
        num_patches_list = []
        for v in range(n_vars):
            var_data = x[:, v, :]  # (B, L)
            patch_len = int(self.patch_lens[v].item())
            sampling_factor = int(self.sampling_rates[v].item() / min_sampling_rate)

            patches_tensor = self._extract_patches(
                var_data, v, patch_len, sampling_factor, patch_prob, missing_flag
            )
            num_patches_list.append(patches_tensor.shape[1])

            # patch embedding (shared by patch length) + position, then global tokens, then channel
            linear = self.embedding_dict[self.var_to_key[v]]
            embedded = linear(patches_tensor)  # (B, n_kept, d_model)
            embedded = embedded + self.position_embedding(embedded)
            embedded = torch.cat([embedded, self.glb_tokens.repeat(B, 1, 1)], dim=1)
            embedded = self.channel_embedding(torch.tensor(v, device=embedded.device)) + embedded
            patch_embeddings.append(embedded)
        patch_embeddings = torch.cat(patch_embeddings, dim=1)  # (B, total_tokens, d_model)
        return patch_embeddings, num_patches_list, n_vars


class Encoder(nn.Module):
    """Stack of encoder layers with optional final normalisation and projection.

    Every layer is invoked as
    ``layer(x, num_patches_list, n_vars, x_mask=..., tau=..., delta=...)`` and must
    return ``(x, attn, mask, num_tokens_list)``.
    """

    def __init__(self, layers, norm_layer=None, projection=None):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, num_patches_list, n_vars, x_mask=None, tau=None, delta=None):
        """Run ``x`` through all layers.

        Returns:
            (x, list of per-layer attention outputs, mask of the last layer,
             num_tokens_list of the last layer)
        """
        attentions = []
        for enc_layer in self.layers:
            x, layer_attn, mask, num_tokens_list = enc_layer(
                x, num_patches_list, n_vars, x_mask=x_mask, tau=tau, delta=delta
            )
            attentions.append(layer_attn)

        # Norm first, then projection; either may be absent.
        for post in (self.norm, self.projection):
            if post is not None:
                x = post(x)

        return x, attentions, mask, num_tokens_list
    


class EncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer with a per-variable attention mask.

    Tokens are packed along the sequence axis as
    ``[var0 patches, var0 global tokens, var1 patches, var1 global tokens, ...]``.
    The mask from :meth:`build_attention_mask` lets every token of a variable
    attend only to that variable's patch tokens. In training mode, a fraction
    ``drop_rate`` of the allowed links is additionally blocked at random.
    """

    def __init__(self, self_attention,
                 d_model, d_ff=None, dropout=0.1, activation="relu",
                 n_vars=7, num_global_tokens=1, n_heads=8, ind_glb=0, batch_size=16, drop_rate=0, drop_path=0):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model

        self.self_attention = self_attention

        # Position-wise feed-forward implemented as two kernel-size-1 convolutions.
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.d_model = d_model
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.num_global_tokens = num_global_tokens
        self.n_vars = n_vars
        self.n_heads = n_heads
        self.ind_glb = ind_glb
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu
        self.batch_size = batch_size  # stored only; not read in this class
        self.drop_rate = drop_rate
        self.drop_path = DropPath(drop_prob=drop_path)

    def forward(self, x, num_patches_list, n_vars, x_mask=None, tau=None, delta=None):
        """Apply masked self-attention and the feed-forward sub-layer.

        Args:
            x: (B, total_tokens, D) packed tokens of all variables.
            num_patches_list: patch-token count per variable.
            n_vars: accepted for interface compatibility; ``self.n_vars`` is used.
            x_mask: accepted for interface compatibility; not used.

        Returns:
            (out, attention scores, attention mask, per-variable token counts)
        """
        B, L, D = x.shape
        # Random link dropping is a training-time regulariser; evaluation must be deterministic.
        drop_rate = self.drop_rate if self.training else 0.0
        self.mask, self.num_tokens_list = self.build_attention_mask(
            B, self.n_heads, self.n_vars, L, num_patches_list, self.num_global_tokens,
            device=x.device, ind_glb=self.ind_glb, drop_rate=drop_rate,
        )
        attn_output, attn_scores = self.self_attention(x, x, x, attn_mask=self.mask, tau=tau, delta=delta)

        # Residual connection + LayerNorm
        x = self.norm1(x + self.drop_path(attn_output))

        # Position-wise feed-forward
        y = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))

        out = self.norm2(x + self.drop_path(y))
        out = out.reshape(B, L, D)
        return out, attn_scores, self.mask, self.num_tokens_list

    def build_attention_mask(self, B, n_heads, n_vars, L, num_patches_list, num_global_tokens=1, device='cuda', ind_glb=0, drop_rate=0.1):
        """Build the boolean attention mask (True = blocked) for packed variable tokens.

        Args:
            B: batch size.
            n_heads: number of attention heads.
            n_vars: number of variables.
            L: total token count; must equal
               ``sum(num_patches_list) + n_vars * num_global_tokens``.
            num_patches_list: patch-token count of each variable.
            num_global_tokens: global tokens placed after each variable's patches.
            device: device on which the mask is created.
            ind_glb: only used by the (currently disabled) cross-variable strategy.
            drop_rate: fraction of allowed links to block at random; 0 leaves the
                structural mask unchanged.

        Returns:
            (mask of shape (B, n_heads, L, L), per-variable token counts)

        Raises:
            ValueError: if ``L`` does not match the token counts.
        """
        expected = sum(num_patches_list) + n_vars * num_global_tokens
        if expected != L:
            raise ValueError(
                f"token count mismatch: sum(num_patches_list) + n_vars * num_global_tokens = {expected}, but L = {L}"
            )

        num_tokens_list = [num_global_tokens + n for n in num_patches_list]
        # First token index of each variable's run.
        starts = [sum(num_tokens_list[:v]) for v in range(n_vars)]
        # Variable that owns each token position.
        owner = [v for v in range(n_vars) for _ in range(num_tokens_list[v])]

        def drop_random_links(mask, rate):
            """Block a random fraction ``rate`` of the open entries of ``mask``, then
            re-open one patch key of the query's own variable for any row left fully blocked."""
            Bm, H, N, _ = mask.shape
            flat = mask.view(-1)

            open_idx = torch.where(~flat)[0]
            num_to_mask = int(open_idx.numel() * rate)
            if num_to_mask > 0:
                selected = open_idx[torch.randperm(open_idx.numel(), device=mask.device)[:num_to_mask]]
                flat[selected] = True

            mask = flat.view(Bm, H, N, N)
            fully_blocked = mask.all(dim=-1)  # (B, H, N)
            if fully_blocked.any():
                for b, h, q in torch.nonzero(fully_blocked).tolist():
                    # Re-open a key among the query variable's own patch tokens, so
                    # attention never leaks to other variables or to global tokens.
                    # (Previously q // L was used, which is always 0 because L is the
                    # total token count.)
                    var = owner[q]
                    lo = starts[var]
                    hi = lo + num_patches_list[var]
                    if hi > lo:
                        mask[b, h, q, int(torch.randint(lo, hi, (1,)).item())] = False
            return mask

        mask = torch.ones((L, L), dtype=torch.bool, device=device)
        # 'CI' (channel-independent): no attention across variables. Any other value
        # enables the global-token links between variables below.
        strategy = 'CI'

        for q in range(n_vars):
            q_start = starts[q]
            q_glb = q_start + num_patches_list[q]
            q_end = q_start + num_tokens_list[q]
            for k in range(n_vars):
                k_start = starts[k]
                k_glb = k_start + num_patches_list[k]
                k_end = k_start + num_tokens_list[k]
                if k == q:
                    # Every token of the variable may attend to that variable's patches.
                    mask[q_start:q_end, k_start:k_glb] = False
                elif strategy != 'CI':
                    if ind_glb:
                        for i in range(num_global_tokens):
                            mask[q_glb + i, k_glb + i] = False
                    else:
                        mask[q_glb:q_end, k_glb:k_end] = False

        mask = mask.repeat(B, n_heads, 1, 1)
        mask = drop_random_links(mask, drop_rate)

        # NOTE(review): mask is a bool tensor, so masked_fill keeps dtype bool and this
        # line appears to leave the mask unchanged; kept as-is pending confirmation of
        # what the attention module expects.
        mask = mask.masked_fill(mask == 1, float('-inf'))

        return mask, num_tokens_list

class Model(nn.Module):
    """Channel-wise multi-rate patch Transformer forecaster.

    Each variable is patch-embedded at its own sampling stride, encoded with a
    per-variable attention mask, and forecast from its global token(s) by a
    rate-specific projection head. Only ``features == 'M'`` is implemented
    (see :meth:`forecast_multi`).
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.features = configs.features
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.use_norm = configs.use_norm
        self.patch_lens = configs.patch_lens

        self.n_vars = 1 if configs.features == 'MS' else configs.enc_in
        self.num_global_tokens = configs.num_global_tokens
        # Temperature: learnable per variable when configs.cross_temp == -1, otherwise fixed.
        # NOTE(review): not read anywhere else in this class.
        if configs.cross_temp == -1:
            self.cross_temp = nn.Parameter(torch.ones(self.n_vars) * 0.1, requires_grad=True)
        else:
            self.cross_temp = configs.cross_temp
        self.sampling_rates = configs.sampling_rates
        # Patch embedding whose linear layers are shared between variables with equal patch length.
        self.channel_wise_patch_embedding = ChannelWisePatchEmbedding_share(
            self.n_vars,
            configs.d_model,
            self.patch_lens,
            configs.sampling_rates,
            num_global_tokens=self.num_global_tokens
        )
        # Diagnostics from the latest forward pass.
        self.latest_attention = 0
        self.latest_mask = 0
        self.keep_prob = configs.keep_prob

        self.encoder = Encoder(
            [
                EncoderLayer(
                    self_attention=AttentionLayer(
                        FullAttention(True, configs.factor, attention_dropout=configs.dropout,
                                      output_attention=True, temp=configs.self_temp),
                        configs.d_model, configs.n_heads),
                    d_model=configs.d_model,
                    d_ff=configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation,
                    n_vars=self.n_vars,
                    num_global_tokens=self.num_global_tokens,
                    n_heads=configs.n_heads,
                    ind_glb=configs.ind_glb,
                    batch_size=configs.batch_size,
                    drop_rate=configs.drop_rate,
                    drop_path=configs.drop_path
                )
                for l in range(configs.e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(configs.d_model)
        )

        # One projection head per distinct sampling rate.
        self.head = PerRateProjectionHead(
            n_vars=configs.enc_in,
            d_model=configs.d_model,
            pred_len=configs.pred_len,
            sampling_rates=configs.sampling_rates,
            num_global_tokens=configs.num_global_tokens,
            dropout=configs.dropout
        )

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Univariate ('S' / 'MS') forecasting path — not implemented.

        The previous body referenced a ``global_patch_embedding`` attribute that this
        model never creates, and called the encoder and head with signatures they do
        not accept, so it always failed. It now fails with an explicit message instead.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "Model.forecast (features 'S'/'MS') is not implemented; "
            "this model only supports features == 'M' via forecast_multi."
        )

    def forecast_multi(self, x_enc, x_mark_enc, x_dec, x_mark_dec, missing_flag):
        """Multivariate forecast from the per-variable global tokens.

        Args:
            x_enc: input unpacked as (B, L, n_vars).
            missing_flag: optional flags forwarded to the patch embedding.

        Returns:
            (dec_out of shape (B, pred_len, n_vars), per-layer attention, last attention mask)
        """
        B, L, n_vars = x_enc.shape
        # Take the minimum rate and each variable's rate from the same tensor that the
        # patch embedding builds (torch.tensor(sampling_rates)), so the strides used for
        # normalisation match the strides used for patching. Previously a float64 rate
        # was divided by a float32 minimum, e.g. rates [0.1, 0.2] gave a factor of 1
        # here but 2 in the embedding.
        rates = torch.as_tensor(self.sampling_rates)
        min_sampling_rate = torch.min(rates).item()

        if self.use_norm:
            # Per-variable statistics over that variable's own sampling grid.
            means_list = []
            stdevs_list = []
            for v in range(n_vars):
                sampling_factor = int(rates[v].item() / min_sampling_rate)
                x_enc_sampled = x_enc[:, ::sampling_factor, v]  # (B, L_sampled)
                mean_v = x_enc_sampled.mean(1, keepdim=True).detach()  # (B, 1)
                stdev_v = torch.sqrt(torch.var(x_enc_sampled, dim=1, keepdim=True, unbiased=False) + 1e-5)  # (B, 1)
                means_list.append(mean_v)
                stdevs_list.append(stdev_v)
            means = torch.cat(means_list, dim=1).unsqueeze(1)  # (B, 1, n_vars)
            stdevs = torch.cat(stdevs_list, dim=1).unsqueeze(1)  # (B, 1, n_vars)
            x_enc = (x_enc - means) / stdevs

        channel_wise_patch_embed, num_patches_list, n_vars = self.channel_wise_patch_embedding(
            x_enc.permute(0, 2, 1), self.keep_prob, missing_flag
        )

        enc_out, attn, mask, num_tokens_list = self.encoder(channel_wise_patch_embed, num_patches_list, n_vars)
        self.latest_attention = attn
        self.latest_mask = mask
        self.latest_num_tokens_list = num_tokens_list

        # Each variable's run is [patches, global tokens]. The global tokens start after
        # the patches; previously the first *patch* token of each variable was read here.
        glb_token_list = []
        for i in range(n_vars):
            glb_start = sum(num_tokens_list[:i]) + num_patches_list[i]
            if self.num_global_tokens == 1:
                glb_token_list.append(enc_out[:, glb_start, :])
            else:
                glb_token_list.append(enc_out[:, glb_start:glb_start + self.num_global_tokens, :])
        glb_tokens = torch.stack(glb_token_list, dim=1)

        dec_out = self.head(glb_tokens, n_vars, num_global_tokens=self.num_global_tokens)  # (B, pred_len, n_vars)
        if self.use_norm:
            # De-normalisation (broadcast over the pred_len axis).
            dec_out = dec_out * stdevs
            dec_out = dec_out + means
        return dec_out, attn, mask

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, missing_flag=None):
        """Dispatch forecasting tasks; returns None for other task names.

        Returns:
            (B, pred_len, n_vars) forecast for features == 'M'.

        Raises:
            NotImplementedError: for forecasting tasks with features other than 'M'.
        """
        if self.task_name not in ('long_term_forecast', 'short_term_forecast'):
            return None
        if self.features == 'M':
            dec_out, _, _ = self.forecast_multi(x_enc, x_mark_enc, x_dec, x_mark_dec, missing_flag)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        return dec_out[:, -self.pred_len:, :]  # [B, L, D]