import torch
import torch.nn as nn
import torch.nn.functional as F

##############################################################################
# 1) Multi-Kernel MA Module
##############################################################################

class MultiKernelMAModule(nn.Module):
    """
    Mixture of H learnable causal moving-average (MA) filters of length p.

    Each kernel is softmax-normalised over its p taps (non-negative, sums to 1).
    The H filtered series are computed in a single conv1d call and blended with
    softmax-normalised combination weights.

    Kernels are initialised with distinct recency biases: kernel 0 is a flat
    average, and later kernels put progressively more weight on recent samples.
    An all-equal initialisation (e.g. all zeros) would give every kernel, and
    every combination weight, identical gradients. The kernels could then never
    become different and the mixture would collapse to a single filter.

    Args:
        p: kernel length (number of taps).
        H: number of kernels.
    """
    def __init__(self, p=5, H=3):
        super().__init__()
        self.p = p
        self.H = H
        # Tap j of every kernel multiplies x[t - (p - 1) + j] (see forward), so the
        # offsets run from -(p-1) (oldest sample) to 0 (current sample).
        offsets = torch.arange(p, dtype=torch.float32) - (p - 1)
        # Per-kernel logit slope: 0 -> uniform MA, larger -> favours recent taps.
        slopes = torch.linspace(0.0, 2.0, H)
        init_logits = slopes.view(H, 1, 1) * offsets.view(1, 1, p) / max(p - 1, 1)
        # All H kernels live in one weight tensor of shape (H, 1, p).
        self.raw_weights = nn.Parameter(init_logits)
        # Combination logits, shape (H,); zeros -> equal blend at the start.
        self.raw_combination = nn.Parameter(torch.zeros(H))

    def forward(self, x):
        """
        x: [B, L]
        Returns: [B, L], the softmax-weighted blend of the H causal MA outputs.
        The first p-1 steps see implicit zero padding on the left.
        """
        _batch, seq_len = x.shape
        # 1) Add a channel axis so that in_channels=1.
        x_3d = x.unsqueeze(1)   # (B, 1, L)

        # 2) Softmax over the taps makes each kernel a valid MA (weights sum to 1).
        w = F.softmax(self.raw_weights, dim=-1)  # (H, 1, p)

        # 3) One convolution with out_channels=H. Padding p-1 on both sides and then
        #    keeping the first L outputs makes the filter causal: output t only uses
        #    x[t-p+1 .. t].
        out_3d = F.conv1d(x_3d, w, padding=self.p - 1)  # (B, H, L+p-1)
        out_3d = out_3d[:, :, :seq_len]                 # (B, H, L)

        # 4) Blend the H outputs with softmax-normalised combination weights.
        combine_w = F.softmax(self.raw_combination, dim=0)  # (H,)
        return (out_3d * combine_w.view(1, -1, 1)).sum(dim=1)  # (B, L)

##############################################################################
# 2) Adaptive RBF Sum Module
##############################################################################

class AdaptiveRBF(nn.Module):
    """
    Per-sample sum of K Gaussian bumps ("lumps") evaluated on the grid t = 0..L-1.

    An MLP maps each input row to K centres c_k, K log-widths and K amplitudes
    alpha_k. The output at step t is
        sum_k alpha_k * exp(-0.5 * ((t - c_k) / sigma_k) ** 2),
    with sigma_k = exp(log_width_k) clamped to at least 1e-5.
    """
    def __init__(self, input_dim, L=96, K=10, hidden_dim=32):
        super().__init__()
        self.L = L
        self.K = K

        # Head emits 3*K values per row, laid out as [centres | log-sigmas | amplitudes].
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 3 * K),
        )

    def forward(self, x):
        """
        x: [B, D] conditioning vector per sample.
        Returns: [B, L], the sum of the K bumps for each sample.
        """
        _batch, _features = x.shape  # unpacking rejects non-2-D input

        centres, log_widths, amps = torch.split(self.mlp(x), self.K, dim=1)  # each [B, K]
        widths = torch.exp(log_widths).clamp(min=1e-5)

        # [1, 1, L] time grid; broadcasting against [B, K, 1] yields [B, K, L].
        grid = torch.arange(self.L, device=x.device, dtype=torch.float32).view(1, 1, -1)
        z = (grid - centres.unsqueeze(-1)) / widths.unsqueeze(-1)
        bumps = amps.unsqueeze(-1) * torch.exp(-0.5 * z ** 2)

        return bumps.sum(dim=1)  # collapse the K bumps -> [B, L]



##############################################################################
# 3) A small Post-Projection MLP (Residual)
##############################################################################


class PostProjBlock(nn.Module):
    """
    Three-layer MLP refinement of a [..., hidden_size] feature vector.

    Computes
        h1 = LayerNorm(ReLU(fc1(x)))
        h2 = ReLU(fc2(h1 + x))
        y  = ReLU(fc3(h2 + x))
    so the block input is re-injected before the second and third linear layers,
    and the output is non-negative.
    """
    def __init__(self, hidden_size):
        super().__init__()
        # Submodules are created in this order (fc1, act1, norm2, fc2, act2, fc3,
        # act3) so that parameter initialisation under a fixed seed is unchanged.
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.act1 = nn.ReLU()
        self.norm2 = nn.LayerNorm(hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.act2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.act3 = nn.ReLU()

    def forward(self, x):
        skip = x
        hidden = self.norm2(self.act1(self.fc1(x)))
        hidden = self.act2(self.fc2(hidden + skip))
        return self.act3(self.fc3(hidden + skip))


class SmallConvBlock(nn.Module):
    """
    Residual 1-D conv refinement of a univariate series: x + proj(ReLU(conv(x))).

    `conv` maps 1 -> `channels` feature maps with a `kernel`-wide filter. It is
    zero-padded so the output length equals the input length. `proj` collapses
    the channels back to one with a pointwise (kernel_size=1) convolution.
    """
    def __init__(self, channels=8, kernel=3):
        super().__init__()
        self.kernel = kernel

        left = (kernel - 1) // 2
        right = (kernel - 1) - left
        # Odd kernels split the padding evenly, so Conv1d can pad by itself. Even
        # kernels need one extra zero on the right, which Conv1d's single `padding`
        # value cannot express, so forward() pads explicitly instead.
        self.use_asymmetric = left != right
        if self.use_asymmetric:
            self.left_pad = left
            self.right_pad = right

        self.conv = nn.Conv1d(
            in_channels=1,
            out_channels=channels,
            kernel_size=kernel,
            padding=0 if self.use_asymmetric else left,
        )
        self.proj = nn.Conv1d(in_channels=channels, out_channels=1, kernel_size=1)
        self.act = nn.ReLU()

    def forward(self, x):
        """x: [B, L] -> [B, L]"""
        signal = x.unsqueeze(1)  # [B, 1, L]
        if self.use_asymmetric:
            signal = F.pad(signal, (self.left_pad, self.right_pad))

        features = self.act(self.conv(signal))       # [B, channels, L]
        correction = self.proj(features).squeeze(1)  # [B, L]
        return x + correction


class StackedSmallConv(nn.Module):
    """
    Chain of `num_blocks` SmallConvBlock refinements plus a learnable gated skip.

    Returns blocks(x) + sigmoid(gate) * x. The raw gate starts at 0, which gives
    a skip weight of 0.5.
    """
    def __init__(self, channels=4, kernel=3, num_blocks=2):
        super().__init__()
        self.blocks = nn.ModuleList(
            [SmallConvBlock(channels=channels, kernel=kernel) for _ in range(num_blocks)]
        )
        # Raw (pre-sigmoid) weight of the outer skip connection.
        self.gate = nn.Parameter(torch.tensor(0.0))

    def forward(self, x):
        refined = x
        for stage in self.blocks:
            # Each stage is itself residual: y <- y + conv(y).
            refined = stage(refined)
        skip_weight = torch.sigmoid(self.gate)  # scalar in (0, 1)
        return refined + skip_weight * x


##############################################################################
# 4) The Enhanced MA+RBF Module that combines everything
##############################################################################

class MARBF_Enhanced(nn.Module):
    """
    Moving-average decomposition + conv residual refinement => [B, hidden_size].

    In forward():
      1. The D input channels are averaged into one length-L series s.
      2. m = MultiKernelMAModule(s) is its smooth part.
      3. The remainder s - m is refined by StackedSmallConv.
      4. m + refined + skip_alpha * s is projected to hidden_size by a Linear
         layer and, when `residual` is True, passed through PostProjBlock.

    NOTE: `rbf_module` (AdaptiveRBF "lumps") is constructed, but its output is
    not used in forward, so its parameters receive no gradient. `K_long` and
    `seasonal_patterns` are accepted but unused.

    Args:
        L: input sequence length (also the in-features of the final projection).
        hidden_size: size of the output feature vector.
        p: MA kernel length.
        H: number of MA kernels.
        K_short: number of lumps in `rbf_module`.
        K_long: unused.
        residual: apply PostProjBlock after the projection.
        channels, kernel, num_blocks: StackedSmallConv configuration.
        hidden_dim: hidden width of the RBF MLP. None (default) uses K_short.
        skip_alpha: weight of the raw averaged series added to the sum.
        seasonal_patterns: unused.
    """
    def __init__(self,
                 L=96,
                 hidden_size=128,
                 p=5,
                 H=2,
                 K_short=8,
                 K_long=8,
                 residual=True,
                 channels=8,
                 kernel=3,
                 num_blocks=2,
                 hidden_dim=None,
                 skip_alpha=0.0,
                 seasonal_patterns=None):
        super(MARBF_Enhanced, self).__init__()
        self.L = L
        self.hidden_size = hidden_size

        # Multi-kernel causal moving average.
        self.ma_module = MultiKernelMAModule(p=p, H=H)
        self.skip_alpha = skip_alpha

        # RBF lumps branch (currently unused in forward). hidden_dim is honoured when
        # given. The None default sizes the hidden layer to K_short, which keeps the
        # parameter shapes of existing checkpoints.
        rbf_hidden = K_short if hidden_dim is None else hidden_dim
        self.rbf_module = AdaptiveRBF(K=K_short, L=L, input_dim=L, hidden_dim=rbf_hidden)
        self.residual_correction = StackedSmallConv(channels=channels, kernel=kernel, num_blocks=num_blocks)

        self.proj = nn.Linear(L, hidden_size)
        self.residual = residual
        if self.residual:
            self.post_block = PostProjBlock(hidden_size)

    def forward(self, x):
        """
        x: [B, L, D]
        returns: [B, hidden_size]
        Raises ValueError if x is not 3-D.
        """
        if x.dim() != 3:
            raise ValueError(f"expected x of shape [B, L, D], got {tuple(x.shape)}")

        x_univ = x.mean(dim=-1)                  # [B, L]: collapse channels
        ma_out = self.ma_module(x_univ)          # [B, L]: smooth component

        # Refine whatever the moving average does not explain.
        x_corrected = self.residual_correction(x_univ - ma_out)

        # Sum of components plus a small skip to the averaged raw input.
        out_sum = ma_out + x_corrected + self.skip_alpha * x_univ
        feat = self.proj(out_sum)                # [B, hidden_size]
        if self.residual:
            feat = self.post_block(feat)
        return feat


##############################################################################
# 5) Finally, the main Model class (only the forecasting decoders are implemented)
##############################################################################
class Model(nn.Module):
    """
    Forecasting model: MARBF_Enhanced encoder followed by a linear decoder.

    Only task_name 'long_term_forecast' and 'short_term_forecast' are
    implemented. For any other task the constructor prints "Task not supported"
    and builds no decoder, and forward() returns None.
    """
    def __init__(self, configs, dropout=0.005, num_heads=4, orth_lambda=0.0):
        """
        configs: read for task_name, c_out, pred_len, seq_len, K_short,
            hidden_size, seasonal_patterns, channels, kernel, num_blocks,
            skip_alpha.
        dropout, num_heads, orth_lambda: accepted for interface compatibility.
            They are currently unused; orth_lambda is not stored, so a training
            loop that wants an orthogonality penalty must keep its own copy.
        """
        super(Model, self).__init__()

        self.task_name = configs.task_name
        self.c_out = configs.c_out
        self.pred_len = configs.pred_len
        self.L = configs.seq_len
        self.num_class = configs.c_out
        # K_short is taken from the config as-is; seasonal_patterns does not override it.
        self.K_short = configs.K_short
        self.hidden_size = configs.hidden_size

        # MA kernel length: ~10% of the input length, at least 5 steps.
        p = max(self.L // 10, 5)
        # Number of MA kernels; the Yearly and Quarterly patterns use more.
        H = 3
        if configs.seasonal_patterns == "Yearly":
            H = 4
        elif configs.seasonal_patterns == "Quarterly":
            H = 5

        self.model = MARBF_Enhanced(L=self.L,
                                    hidden_size=self.hidden_size,
                                    p=p, H=H,
                                    K_short=self.K_short,
                                    channels=configs.channels,
                                    kernel=configs.kernel,
                                    num_blocks=configs.num_blocks,
                                    skip_alpha=configs.skip_alpha,
                                    residual=True,
                                    seasonal_patterns=configs.seasonal_patterns)

        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
            self.forecast_decoder = nn.Linear(self.hidden_size, self.pred_len * self.c_out)
        else:
            # NOTE(review): unsupported tasks are only reported, not rejected;
            # forward() will return None for them. Consider raising instead.
            print("Task not supported")

    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
        """
        x_enc: [B, L, D] input window. x_mark_enc, x_dec, x_mark_dec and mask
            are accepted for interface compatibility and ignored.

        Returns [B, pred_len, c_out] for the forecasting tasks, None otherwise.
        """
        # Per-sample, per-channel standardisation over the time axis.
        mean = x_enc.mean(dim=1, keepdim=True)                 # [B, 1, D]
        x_enc_norm = x_enc - mean
        std = torch.sqrt(torch.var(x_enc_norm, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc_norm = x_enc_norm / std

        final_feat = self.model(x_enc_norm)                    # [B, hidden_size]

        if self.task_name not in ('long_term_forecast', 'short_term_forecast'):
            return None

        forecast_out = self.forecast_decoder(final_feat)       # [B, pred_len * c_out]
        dec_out = forecast_out.view(-1, self.pred_len, self.c_out)
        # Undo the input scaling. NOTE(review): mean/std are per *input* channel (D),
        # so this assumes c_out == D or that they broadcast (c_out == 1 broadcasts the
        # output up to D channels) — confirm for configs where c_out != D.
        dec_out = dec_out * std[:, 0, :].unsqueeze(1) + mean[:, 0, :].unsqueeze(1)
        return dec_out
