import torch
import torch.nn as nn

class EnergyEnhancer(nn.Module):
    """Mix a series' spectrum with its flipped, rescaled copy and denoise it.

    The input is moved to the frequency domain with a real FFT along the
    time axis. A copy of the spectrum is reversed along the frequency axis
    and multiplied by a learnable per-(bin, channel) gain. The two spectra
    are then combined by ``freq_mix``, passed through a ``DiffFourierBlock``
    and transformed back to the time domain.

    Args:
        seq_len: Length of the input sequences along dim 1.
        channels: Number of variables (size of the last input dimension).
        embed_dim: Embedding width forwarded to ``DiffFourierBlock``.
        lambda_init: Initial lambda value forwarded to ``DiffFourierBlock``.
        alpha: Weight of the original amplitude in the amplitude blend;
            the flipped spectrum's amplitude gets ``1 - alpha``.
    """

    def __init__(self, seq_len, channels, embed_dim, lambda_init, alpha=0.5):
        super().__init__()
        self.seq_len = seq_len
        self.channels = channels
        self.embed_dim = embed_dim
        self.lambda_init = lambda_init
        # rfft of a length-n signal has n // 2 + 1 bins; one gain per
        # (bin, channel), initialised to 1 so the flipped copy starts unscaled.
        n_bins = int(self.seq_len) // 2 + 1
        self.scaling_matrix = nn.Parameter(torch.ones(n_bins, self.channels))
        self.diff_fourier_denoiser = DiffFourierBlock(self.embed_dim, self.lambda_init)
        self.alpha = alpha

    def forward(self, x):
        '''
        input:
            x: (bs, seq_len, n_vars)

        return:
            x_denoised: (bs, seq_len, n_vars)
            x_inverse_fft: (bs, seq_len // 2 + 1, n_vars), the flipped and
                rescaled (complex) spectrum
            loss_nonstat: scalar returned by the denoiser block
        '''
        n_steps = x.shape[1]
        x_fft = torch.fft.rfft(x, dim=1)   # domain conversion
        # Reverse the bin order, then apply the learnable gains.
        x_inverse_fft = torch.flip(x_fft, dims=[1]) * self.scaling_matrix
        x_enhanced_fft = self.freq_mix(x_fft, x_inverse_fft)
        x_denoised_fft, loss_nonstat = self.diff_fourier_denoiser(x_enhanced_fft)
        # Pass n explicitly: without it irfft assumes an even length
        # (2 * (bins - 1)) and would drop one sample when seq_len is odd.
        x_denoised = torch.fft.irfft(x_denoised_fft, n=n_steps, dim=1)
        return x_denoised, x_inverse_fft, loss_nonstat

    def freq_mix(self, x_fft, x_inverse_fft):
        """Combine two complex spectra in polar form.

        The amplitude is ``alpha * |x_fft| + (1 - alpha) * |x_inverse_fft|``;
        the phase comes from ``phase_mix``. Returns a complex tensor with the
        same shape as the inputs.
        """
        amp = self.alpha * x_fft.abs() + (1 - self.alpha) * x_inverse_fft.abs()
        phase = self.phase_mix(x_fft.angle(), x_inverse_fft.angle())
        return torch.polar(amp, phase)  # back to a complex spectrum

    def phase_mix(self, phase_x, phase_y):
        """Return ``phase_x`` shifted by the wrapped difference to ``phase_y``.

        ``dtheta`` is ``phase_x - phase_y`` wrapped into ``(-pi, pi]``. The
        phase is then moved from ``phase_x`` by ``|dtheta|`` in the direction
        opposite to the sign of ``dtheta``.
        """
        phase_difference = phase_x - phase_y
        dtheta = phase_difference % (2 * torch.pi)
        dtheta = torch.where(dtheta > torch.pi, dtheta - 2 * torch.pi, dtheta)
        clockwise = dtheta > 0
        sign = torch.where(clockwise, -1, 1)
        # NOTE(review): |dtheta| * sign equals -dtheta in both branches, so the
        # result is phase_x - dtheta, i.e. phase_y up to a multiple of 2*pi.
        # Confirm that a full step onto phase_y (rather than a partial blend)
        # is what is intended here.
        mixed_phase = phase_x + torch.abs(dtheta) * sign
        return mixed_phase

class DiffFourierBlock(nn.Module):
    """
    Embedding-based differential denoising of a complex spectrum.

    Every complex frequency coefficient is lifted, on its own, to an
    ``embed_dim``-wide complex embedding. The embedding is split into two
    halves, each half has a learnable multiple of the other subtracted from
    it, and the result is projected back to a single complex value.

    Args:
        embed_dim: Width of the complex embedding. Must be a positive even
            number because the embedding is split into two equal halves.
        lambda_init: Initial raw value of both subtraction weights; the
            effective weight is ``softplus(lambda)``.

    Raises:
        ValueError: If ``embed_dim`` is not a positive even integer.
    """
    def __init__(self, embed_dim=64, lambda_init=0.1):
        super().__init__()
        # forward() chunks the embedding in two and subtracts the halves from
        # each other; an odd or too-small width would only fail there with an
        # opaque shape/unpack error, so reject it up front.
        if embed_dim < 2 or embed_dim % 2 != 0:
            raise ValueError(
                f"embed_dim must be a positive even integer, got {embed_dim!r}"
            )
        self.embed_dim = embed_dim
        # float() so that an integer lambda_init still yields a floating-point
        # tensor (integer tensors cannot be trainable parameters).
        self.lambda_param1 = nn.Parameter(torch.tensor(float(lambda_init)))
        self.lambda_param2 = nn.Parameter(torch.tensor(float(lambda_init)))
        self.freq_embedding = nn.Linear(1, embed_dim).to(torch.cfloat)
        self.freq_projection = nn.Linear(embed_dim, 1).to(torch.cfloat)
        self.softplus1 = nn.Softplus()
        self.softplus2 = nn.Softplus()
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform init for both linear weights (biases are left as is)."""
        nn.init.xavier_uniform_(self.freq_embedding.weight)
        nn.init.xavier_uniform_(self.freq_projection.weight)

    def forward(self, x_fft):
        """
        Embedding-based differential denoising
        Args:
            x_fft: Complex input tensor [B, F, C]
        Returns:
            out: Denoised signal in frequency domain [B, F, C]
            loss_nonstat: Scalar, standard deviation of the magnitudes of the
                differenced embedding
        """
        batch, n_freq, n_ch = x_fft.shape

        # Treat every coefficient independently: [B*F*C, 1] -> [B*F*C, embed_dim]
        flat = x_fft.reshape(-1, 1)
        embedded = self.freq_embedding(flat)

        # Two equally wide paths along the feature axis: [B*F*C, embed_dim // 2]
        half_a, half_b = embedded.chunk(2, dim=-1)

        # Each path has a non-negative (softplus) multiple of the other removed.
        weight_a = self.softplus1(self.lambda_param1)
        weight_b = self.softplus2(self.lambda_param2)
        diff_embed = torch.cat(
            [half_a - weight_a * half_b, half_b - weight_b * half_a], dim=-1
        )

        loss_nonstat = torch.std(torch.abs(diff_embed))

        # Project back to one complex value per coefficient and restore the shape.
        out = self.freq_projection(diff_embed).reshape(batch, n_freq, n_ch)

        return out, loss_nonstat

class EnergyPredictor(nn.Module):
    """Refine a forecast in the frequency domain using a history spectrum.

    The history spectrum is compressed per variable by a complex linear
    layer, concatenated with the real FFT of the forecast ``y``, mapped by a
    second complex linear layer to ``pred_len // 2 + 1`` bins, and
    transformed back to the time domain.

    Args:
        seq_len: Length of the history window; fixes the number of input bins.
        pred_len: Length of the forecast window.
        embed_dim_out: Output width of the first linear layer.
    """

    def __init__(self, seq_len, pred_len, embed_dim_out=64):
        super().__init__()
        # Number of rfft bins for a length-n real signal is n // 2 + 1.
        self.fft_seq = int(seq_len) // 2 + 1
        self.fft_pred = int(pred_len) // 2 + 1
        self.embed_dim_out = embed_dim_out

        self.linear1 = nn.Linear(self.fft_seq, self.embed_dim_out).to(torch.cfloat)
        self.linear2 = nn.Linear(self.embed_dim_out + self.fft_pred, self.fft_pred).to(torch.cfloat)

    def forward(self, x_inverse_fft, y):
        '''
        x_inverse_fft: (bs, seq_len // 2 + 1, n_vars), complex
        y: (bs, pred_len, n_vars)
        out: (bs, pred_len, n_vars)
        '''
        n_steps = y.shape[1]
        # Put the axis the linear layers act on last.
        y = y.permute(0, 2, 1) # (bs, n_vars, pred_len)
        x_inverse_fft = x_inverse_fft.permute(0, 2, 1) # (bs, n_vars, seq_len // 2 + 1)
        y_fft = torch.fft.rfft(y, dim=-1) # (bs, n_vars, pred_len // 2 + 1)
        history = self.linear1(x_inverse_fft) # (bs, n_vars, embed_dim_out)
        inp = torch.cat([history, y_fft], dim=-1) # (bs, n_vars, embed_dim_out + pred_len // 2 + 1)
        inp = self.linear2(inp).permute(0, 2, 1) # (bs, pred_len // 2 + 1, n_vars)
        # Pass n explicitly: the default (2 * (bins - 1)) is one sample short
        # whenever pred_len is odd.
        out = torch.fft.irfft(inp, n=n_steps, dim=1) # (bs, pred_len, n_vars)
        return out

if __name__ == "__main__":
    # Smoke test / rough timing of one forward pass through the denoiser and
    # the predictor on random data.
    import time

    pred_len = 96
    seq_len = 96
    enc_in = 7
    batch_size = 32
    embed_dim = 64
    embed_dim_out = 64

    # Build modules and inputs first so that only the forward passes are timed.
    denoiser = DiffFourierBlock(embed_dim=embed_dim, lambda_init=0.1)
    predictor = EnergyPredictor(seq_len, pred_len, embed_dim_out)
    x_inverse_fft = torch.randn(batch_size, seq_len // 2 + 1, enc_in).to(torch.cfloat)
    y = torch.randn(batch_size, pred_len, enc_in)

    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time();
    # no_grad keeps autograd bookkeeping out of the measurement.
    with torch.no_grad():
        t1 = time.perf_counter()
        out, loss_nonstat = denoiser(x_inverse_fft)
        t2 = time.perf_counter()
        out = predictor(x_inverse_fft, y)
        t3 = time.perf_counter()

    print(f"Time taken for denoiser: {t2 - t1} seconds")
    print(f"Time taken for predictor: {t3 - t2} seconds")
    print(f"Total time taken: {t3 - t1} seconds")