import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from utils.utils import ModeDropout, normalize_input, pad_input, get_freq_grids_2d, get_freq_grids_3d
from utils.vkfft import VkFFTBackend

# Enables autograd anomaly detection process-wide as soon as this module is imported.
# NOTE(review): PyTorch documents anomaly detection as a debugging aid that slows
# execution; confirm whether it is meant to stay enabled outside of debugging runs.
torch.autograd.set_detect_anomaly(True)

class Sonic(nn.Module):
    def __init__(self, dim=2, in_channels=3, num_hidden=64, M_modes=12,
                 normalize_input=True, dx=1.0, dy=1.0, dz=1.0,
                 blockdiag_per_channel=False, dropout_p=0.0,
                 dtype=torch.float32, fix_v=False, v_noise=0.05, rho=1.0,
                 depth_idx: int = 0,
                 depth_total: int = 1,
                 alpha_range=(0.2, 2.0),
                 tau_range=(0.20, 0.01),
                 per_mode_jitter=0.03,
                 set_beta_zero=True):
        """
        Sonic element.

        dim: 2 or 3
        in_channels: input channels
        num_hidden: output channels
        M_modes: number of spectral modes
        normalize_input: whether to normalize input
        dx, dy, dz: grid spacing (physical units, leave 1.0 if unknown)
        blockdiag_per_channel: whether to use block-diagonal mixer mask
        dropout_p: mode dropout probability
        dtype: real dtype used for parameters and buffers
        fix_v: whether to fix directions v (buffers) or learn them (parameters)
        v_noise: initial noise for directions v if not fixed
        rho: scaling for imaginary part of α (damping oscillations)
        depth_idx, depth_total: position of this layer in the stack, used for init scheduling
        alpha_range: (shallow, deep) forward-space α, interpolated in log-space over depth
        tau_range: (shallow, deep) forward-space τ, interpolated in log-space over depth
        per_mode_jitter: relative std of the per-mode noise added to the raw α/τ/scale init
        set_beta_zero: initialise beta_raw to zeros (True) or to small random values (False)

        Raises:
            ValueError: if `dim` is neither 2 nor 3.
        """
        super().__init__()
        self.C, self.dim, self.M = int(in_channels), int(dim), int(M_modes)
        if self.dim not in (2, 3):
            # Anything other than 2 used to fall silently into the 3-D branch below.
            raise ValueError(f"dim must be 2 or 3, got {dim!r}")
        self.normalize_input = bool(normalize_input)
        self.dx, self.dy, self.dz = float(dx), float(dy), float(dz)
        self.dtype = dtype
        self.mode_dropout = ModeDropout(dropout_p) if dropout_p > 0 else nn.Identity()
        self.fix_v = bool(fix_v)
        self.v_noise = float(v_noise)
        self.fft_backend = VkFFTBackend()
        self.K = int(num_hidden)
        self.rho_val = float(rho)

        # --- optional block-diagonal mixer mask ---
        # The M modes are partitioned into C contiguous groups; mode m may only read
        # from the channel whose group contains it.
        if blockdiag_per_channel:
            groups = torch.tensor_split(torch.arange(self.M), self.C)
            mask = torch.zeros(self.M, self.C)
            for c, g in enumerate(groups):
                mask[g, c] = 1.0
            self.register_buffer('Bmask', mask)
        else:
            self.register_buffer('Bmask', None)

        # --- directions v ---
        if self.dim == 2:
            # M angles strictly inside (0, pi): a semi-circle without its end points
            base = torch.linspace(0, np.pi, steps=self.M + 2, dtype=torch.float32)[1:-1]
            vx0 = torch.cos(base)
            vy0 = torch.sin(base)
            if not self.fix_v:
                # add noise and renormalize
                # fix_v is not used in the experiments but can be useful for
                # environments with strong, fixed anisotropy
                vx0 = vx0 + self.v_noise * torch.randn(self.M, dtype=torch.float32)
                vy0 = vy0 + self.v_noise * torch.randn(self.M, dtype=torch.float32)
                v = torch.stack([vx0, vy0], dim=0)
                v = v / (v.norm(dim=0, keepdim=True) + 1e-8)
                self.vx = nn.Parameter(v[0].to(self.dtype))
                self.vy = nn.Parameter(v[1].to(self.dtype))
            else:
                self.register_buffer('vx', vx0.to(self.dtype))
                self.register_buffer('vy', vy0.to(self.dtype))
        else:
            # Fibonacci sphere sampling for near-uniform directions on 3D unit sphere
            k = torch.arange(self.M, dtype=torch.float32) + 0.5
            phi = 2.0 * np.pi * (k / ((1 + 5 ** 0.5) / 2.0))
            z = 1.0 - (2.0 * k / self.M)
            r = torch.sqrt((1.0 - z * z).clamp_min(0))
            x, y = r * torch.cos(phi), r * torch.sin(phi)
            v0 = torch.stack([x, y, z], dim=0)
            v0 = v0 / (v0.norm(dim=0, keepdim=True) + 1e-8)
            if not self.fix_v:
                # add noise and renormalize
                v0 = v0 + self.v_noise * torch.randn_like(v0)
                v0 = v0 / (v0.norm(dim=0, keepdim=True) + 1e-8)
                self.vx = nn.Parameter(v0[0].to(self.dtype))
                self.vy = nn.Parameter(v0[1].to(self.dtype))
                self.vz = nn.Parameter(v0[2].to(self.dtype))
            else:
                self.register_buffer('vx', v0[0].to(self.dtype))
                self.register_buffer('vy', v0[1].to(self.dtype))
                self.register_buffer('vz', v0[2].to(self.dtype))

        # --- mixers (unit-complex rows/cols) ---
        def _unit_complex(shape, dim_to_normalize):
            """Draw a random complex matrix (as re/im parts) with unit L2 norm along `dim_to_normalize`."""
            real = torch.randn(*shape, dtype=torch.float32)
            imag = torch.randn(*shape, dtype=torch.float32)
            denom = torch.sqrt((real**2 + imag**2).sum(dim=dim_to_normalize, keepdim=True)).clamp_min(1e-12)
            real = (real / denom).to(self.dtype)
            imag = (imag / denom).to(self.dtype)
            return real, imag

        C_re, C_im = _unit_complex((self.K, self.M), dim_to_normalize=0)
        self.C_re = nn.Parameter(C_re)
        self.C_im = nn.Parameter(C_im)

        B_re, B_im = _unit_complex((self.M, self.C), dim_to_normalize=1)
        if self.Bmask is not None:
            # B mask learns 1-to-1 mapping per channel: zero the disallowed entries,
            # then restore unit norm on every row that still has mass.
            mask = self.Bmask.to(B_re.dtype)
            B_re = B_re * mask
            B_im = B_im * mask
            row_norm = torch.sqrt((B_re**2 + B_im**2).sum(dim=1, keepdim=True))
            # Row selector must be 1-D: an (M, 1) boolean mask cannot index an (M, C)
            # tensor when C > 1 (the previous code raised IndexError there).
            alive = (row_norm > 1e-12).squeeze(-1)
            B_re[alive] = B_re[alive] / row_norm[alive]
            B_im[alive] = B_im[alive] / row_norm[alive]
            dead = ~alive
            if dead.any():
                # rows left with no mass after masking get a fresh random unit row
                rr, ii = _unit_complex((int(dead.sum().item()), self.C), dim_to_normalize=1)
                B_re[dead] = rr
                B_im[dead] = ii
        self.B_re = nn.Parameter(B_re)
        self.B_im = nn.Parameter(B_im)

        # later we use softplus(·) to ensure positivity of alpha, tau, scale;
        # for initialization we need softplus^{-1}(·) to range-map
        def inv_softplus(y):
            """Return x such that softplus(x) == y, for y > 0 (float32)."""
            y = torch.clamp(torch.as_tensor(y, dtype=torch.float32), min=1e-12)
            direct = torch.log(torch.expm1(y))
            # expm1(y) overflows float32 for large y (e.g. very small grid spacing gives
            # a large scale target); y + log(1 - exp(-y)) is the same quantity and stays finite.
            stable = y + torch.log(-torch.expm1(-y))
            return torch.where(y > 20.0, stable, direct)

        # depth fraction t in [0,1]
        t = float(depth_idx) / max(1, depth_total - 1) if depth_total > 1 else 0.0

        # α schedule (forward-space): small → big with depth (log-space)
        a_min, a_max = map(float, alpha_range)
        a_min = max(a_min, 1e-6)
        a_fwd = a_min * ((a_max / a_min) ** t)

        # τ schedule (forward-space): big → small with depth (log-space)
        tau_hi, tau_lo = float(tau_range[0]), float(tau_range[1])
        tau_hi = max(tau_hi, 1e-6)  # guard the ratio below against division by zero
        tau_lo = max(tau_lo, 1e-6)
        tau_fwd_sched = tau_hi * ((tau_lo / tau_hi) ** t)
        # physical-space τ (isotropic diffusion) for reference
        tau_phys = 1.0 / (self.dim * (np.pi ** 2))
        tau_fwd = 0.9 * tau_fwd_sched + 0.1 * tau_phys

        # scale derived from the coarsest grid spacing
        max_d = max(self.dx, self.dy) if self.dim == 2 else max(self.dx, self.dy, self.dz)
        s_target = 0.25 * (2.0 * np.pi / max(max_d, 1e-12))

        alpha_raw = inv_softplus(a_fwd).repeat(self.M).to(self.dtype)
        tau_raw = inv_softplus(tau_fwd).repeat(self.M).to(self.dtype)
        scale_raw = inv_softplus(s_target).repeat(self.M).to(self.dtype)

        # optional tiny per-mode jitter to promote diversity
        if per_mode_jitter and per_mode_jitter > 0:
            j = float(per_mode_jitter)
            with torch.no_grad():
                alpha_raw += j * torch.randn_like(alpha_raw) * alpha_raw.abs().clamp_min(1e-3)
                tau_raw += j * torch.randn_like(tau_raw) * tau_raw.abs().clamp_min(1e-3)
                scale_raw += j * torch.randn_like(scale_raw) * scale_raw.abs().clamp_min(1e-3)

        # Raw (pre-softplus) parameters; `alpha` is stored raw as well.
        self.alpha = nn.Parameter(alpha_raw)
        self.tau_raw = nn.Parameter(tau_raw)
        self.scale_raw = nn.Parameter(scale_raw)

        # beta init (zero or small random)
        if set_beta_zero:
            self.beta_raw = nn.Parameter(torch.zeros(self.M, dtype=self.dtype))
        else:
            self.beta_raw = nn.Parameter(0.01 * torch.randn(self.M, dtype=self.dtype))
        self.rho = nn.Parameter(torch.tensor(self.rho_val, dtype=self.dtype))

    def _get_params(self, real_mixers=False):
        """
        Map the raw learnable tensors to the quantities used in the forward pass.

        real_mixers: if True, the imaginary parts of both mixers are replaced by zeros.

        Returns the tuple (a, s, tau, v, B, C_mixer):
            a:       complex per-mode coefficient, real part -softplus(alpha),
                     imaginary part rho * tanh(beta_raw)
            s:       softplus(scale_raw), in the dtype of the stored parameter
            tau:     softplus(tau_raw), in the dtype of the stored parameter
            v:       direction components stacked on dim 0 (2 or 3 rows) and
                     renormalised to unit length per mode, cast to self.dtype
            B:       complex input mixer built from B_re/B_im (multiplied by Bmask if set)
            C_mixer: complex output mixer built from C_re/C_im

        Complex outputs are complex64 when self.dtype is float16/bfloat16/float32 and
        complex128 otherwise.
        """
        low_precision = self.dtype in (torch.float16, torch.bfloat16, torch.float32)
        # torch.complex requires real and imaginary parts of the SAME dtype, so every
        # component is first brought to one working real dtype. (Previously the real
        # part stayed in self.dtype while the imaginary part was forced to float32,
        # which failed for any dtype other than float32.)
        real_dtype = torch.float32 if low_precision else torch.float64
        complex_dtype = torch.complex64 if low_precision else torch.complex128

        a_re = -F.softplus(self.alpha.to(real_dtype))
        a_im = self.rho.to(real_dtype) * torch.tanh(self.beta_raw.to(real_dtype))
        a = torch.complex(a_re, a_im).to(complex_dtype)

        tau = F.softplus(self.tau_raw)
        s = F.softplus(self.scale_raw)

        components = [self.vx, self.vy]
        if self.dim == 3:
            components.append(self.vz)
        v = torch.stack([comp.to(real_dtype) for comp in components], dim=0)
        # learned directions can drift away from unit length, so renormalise on every call
        v = (v / (v.norm(dim=0, keepdim=True) + 1e-6)).to(self.dtype)

        def _as_complex(re_part, im_part):
            """Combine re/im parameter tensors into one complex tensor of complex_dtype."""
            re_part = re_part.to(real_dtype)
            im_part = torch.zeros_like(re_part) if real_mixers else im_part.to(real_dtype)
            return torch.complex(re_part, im_part).to(complex_dtype)

        B = _as_complex(self.B_re, self.B_im)
        C_mixer = _as_complex(self.C_re, self.C_im)
        if self.Bmask is not None:
            # re-apply the mask so masked entries stay zero even after gradient updates
            B = B * self.Bmask.to(dtype=B.dtype, device=B.device)

        return a, s, tau, v, B, C_mixer

    def forward(self, x, pad_linear=False, block_h=None, **kwargs):
        """
        Apply the layer to `x` and dispatch to the 2-D or 3-D spectral path.

        x: input tensor; optionally passed through `normalize_input`, then through
           `pad_input` (padding is there to avoid wrap-around artifacts).
        pad_linear: forwarded unchanged to `pad_input`.
        block_h: forwarded to the spectral path, which uses it as the slab height.
        kwargs: optional 'dx', 'dy', 'dz' grid-spacing overrides for this call only.

        Returns whatever `_forward_2d` / `_forward_3d` returns.
        """
        if self.normalize_input:
            x = normalize_input(self.dim, x)

        # D, H, W are later used to crop the result back after the inverse transform
        x, D, H, W = pad_input(self.dim, x, pad_linear)

        # per-call spacing, falling back to the values given at construction
        step_x = float(kwargs.get('dx', self.dx))
        step_y = float(kwargs.get('dy', self.dy))
        if self.dim == 3:
            step_z = float(kwargs.get('dz', self.dz))

        a, s, tau, v, B_mixer, C_mixer = self._get_params(real_mixers=False)

        planar = self.dim == 2

        # Spacing may differ from the construction-time one (e.g. train on a coarse grid,
        # run on a finer one, or train progressively on increasing resolution); the
        # scale is then multiplied by the ratio of the largest spacings.
        if any(key in kwargs for key in ('dx', 'dy', 'dz')):
            if planar:
                reference = max(self.dx, self.dy)
                current = max(step_x, step_y)
            else:
                reference = max(self.dx, self.dy, self.dz)
                current = max(step_x, step_y, step_z)
            s = s * (reference / current)

        if planar:
            return self._forward_2d(x, H, W, step_x, step_y, block_h, a, s, tau, v, B_mixer, C_mixer)
        return self._forward_3d(x, D, H, W, step_z, step_x, step_y, block_h, a, s, tau, v, B_mixer, C_mixer)

    def _forward_2d(self, x, H, W, dx_eff, dy_eff, block_h, a, s, tau, v, B_mixer, C_mixer):
        """
        Apply the per-mode spectral transfer function to 2-D data.

        x: padded input; its last two dims are taken as (Hp, Wp). The einsum
           signatures below treat its spectrum as 4-D (batch, channel, rows, cols).
        H, W: size the output is cropped back to.
        dx_eff, dy_eff: grid spacing used for the frequency grids and direction rescaling.
        block_h: optional slab height; rows of the spectrum are processed in slabs of
                 this height to limit peak memory. None (or >= Hp) means one slab.
                 The result does not depend on this value.
        a, s, tau, v, B_mixer, C_mixer: as returned by `_get_params`.

        Returns the inverse-transformed result cropped to [..., :H, :W].

        Raises:
            ValueError: if `block_h` is given and is smaller than 1.
        """
        Hp, Wp = x.shape[-2:]
        # NOTE(review): OX/OY are indexed below as 2-D (row, col) grids matching the
        # rfft spectrum layout; confirm against get_freq_grids_2d.
        OX, OY = get_freq_grids_2d(Hp, Wp, dx_eff, dy_eff, x.device, self.dtype)

        complex_dtype = torch.complex64 if self.dtype in (torch.float16, torch.bfloat16, torch.float32) else torch.complex128

        # the FFT is run with autocast disabled
        with torch.amp.autocast(device_type='cuda', enabled=False):
            Xf = self.fft_backend.rfftn(x.to(self.dtype), dim=(-2, -1))

        Bsz, C, Wq = Xf.shape[0], Xf.shape[1], Xf.shape[-1]
        K, M = int(self.K), int(v[0].numel())
        # the einsum below expects the input mixer as (C, M)
        B_tc = B_mixer.transpose(0, 1).contiguous() if B_mixer.shape == (M, C) else B_mixer.contiguous()

        # make the directions resolution-aware (physical space), then back to unit length
        scale = torch.tensor([dx_eff, dy_eff], device=v.device, dtype=v.dtype)[:, None]
        v_phys = v / (scale + 1e-8)
        v_phys = v_phys / (v_phys.norm(dim=0, keepdim=True) + 1e-8)
        v_vx, v_vy = v_phys[0][None, None, :], v_phys[1][None, None, :]

        s_s, t_t, a_a = s[None, None, :], tau[None, None, :], a[None, None, :]

        # To fit in memory, the frequency-domain data CAN be processed in horizontal
        # slabs of height block_h.
        if block_h is None or block_h >= Hp:
            bh = Hp
        else:
            bh = int(block_h)
            if bh < 1:
                # 0 used to fail inside range(); negatives silently produced an all-zero output
                raise ValueError(f"block_h must be a positive integer, got {block_h!r}")
        slabs = [(y0, min(Hp, y0 + bh)) for y0 in range(0, Hp, bh)]

        def _raw_transfer(y0, y1):
            """Un-normalised transfer function on spectrum rows y0:y1, shape (y1 - y0, Wq, M)."""
            OX_sl, OY_sl = OX[y0:y1, :, None], OY[y0:y1, :, None]
            dot = OX_sl * v_vx + OY_sl * v_vy               # frequency component along v
            wn2 = OX_sl**2 + OY_sl**2
            wperp = (wn2 - dot**2).clamp_min(0.0)           # squared component orthogonal to v
            denom = 1j * (s_s * dot) - a_a + t_t * wperp
            # 1/denom written as conj/|denom|^2, with |denom|^2 clamped away from zero
            magn_dn = (denom.real.square() + denom.imag.square()).clamp_min(1e-8)
            return denom.conj() / magn_dn

        # Per-mode RMS of the transfer function over the WHOLE frequency grid.
        # Computing it per slab (as before) made the output depend on block_h.
        T_single = None
        if len(slabs) == 1:
            T_single = _raw_transfer(*slabs[0])
            mean_sq = (T_single.real**2 + T_single.imag**2).mean(dim=(0, 1), keepdim=True)
        else:
            sum_sq, count = 0.0, 0
            for y0, y1 in slabs:
                T_sl = _raw_transfer(y0, y1)
                sum_sq = sum_sq + (T_sl.real**2 + T_sl.imag**2).sum(dim=(0, 1), keepdim=True)
                count += T_sl.shape[0] * T_sl.shape[1]
            mean_sq = sum_sq / count
        rms = torch.sqrt(mean_sq.clamp_min(1e-8))

        Yf_total = x.new_zeros((Bsz, K, Hp, Wq), dtype=complex_dtype)

        for y0, y1 in slabs:
            # with several slabs the transfer function is recomputed instead of cached,
            # so that only one slab is held at a time
            T = T_single if T_single is not None else _raw_transfer(y0, y1)
            T = (T / rms).permute(2, 0, 1).contiguous()     # -> (M, h, Wq)

            # Mixing: channels -> modes, per-mode filtering, modes -> output channels
            Xf_ = Xf[:, :, y0:y1, :]
            U = torch.einsum('bchq,cm->bmhq', Xf_, B_tc)
            V = U * T.unsqueeze(0)
            V = self.mode_dropout(V)
            Yf = torch.einsum('km,bmhq->bkhq', C_mixer, V)
            Yf_total[:, :, y0:y1, :] = Yf

        # ensure real output: drop the imaginary part of the first and, for even Wp,
        # the last column of the half-spectrum
        if Wq > 0:
            Yf_total[..., 0].imag.zero_()
        if Wp % 2 == 0:
            Yf_total[..., -1].imag.zero_()

        with torch.amp.autocast(device_type='cuda', enabled=False):
            y_spatial = self.fft_backend.irfftn(Yf_total, s=(Hp, Wp), dim=(-2, -1))
        return y_spatial[..., :H, :W].contiguous()

    def _forward_3d(self, x, D, H, W, dz_eff, dx_eff, dy_eff, block_h, a, s, tau, v, B_mixer, C_mixer, k_chunk: int = 8):
        """
        Processes 3D volumetric data in the frequency domain.

        x has the shape BxCxDxHxW (padded); its last three dims are taken as (Dp, Hp, Wp).
        D, H, W: size the output is cropped back to.
        dz_eff, dx_eff, dy_eff: grid spacing used for the frequency grids and direction rescaling.
        block_h: optional slab height along the H axis of the spectrum; None (or >= Hp)
                 means one slab. The result does not depend on this value.
        a, s, tau, v, B_mixer, C_mixer: as returned by `_get_params`.
        k_chunk: accepted for interface compatibility; not used in this method.

        The logic mirrors the 2-D path with one additional axis: each mode's transfer
        function is normalised by its RMS over all frequencies.

        Returns the inverse-transformed result cropped to [..., :D, :H, :W].

        Raises:
            ValueError: if `block_h` is given and is smaller than 1.
        """
        Dp, Hp, Wp = x.shape[-3:]
        # NOTE(review): spacing is passed as (dz, dx, dy) and the grids come back as
        # (OZ, OY, OX); confirm this ordering against get_freq_grids_3d.
        OZ, OY, OX = get_freq_grids_3d(Dp, Hp, Wp, dz_eff, dx_eff, dy_eff, x.device, self.dtype)

        complex_dtype = torch.complex64 if self.dtype in (torch.float16, torch.bfloat16, torch.float32) else torch.complex128
        # the FFT is run with autocast disabled
        with torch.amp.autocast(device_type='cuda', enabled=False):
            Xf = self.fft_backend.rfftn(x.to(self.dtype), dim=(-3, -2, -1)).contiguous()

        Bsz, C, _, _, Wq = Xf.shape
        K, M = int(self.K), int(v[0].numel())
        # the einsum below expects the input mixer as (C, M)
        B_tc = B_mixer.transpose(0, 1).contiguous() if B_mixer.shape == (M, C) else B_mixer.contiguous()

        # make the directions resolution-aware (physical space), then back to unit length
        scale = torch.tensor([dx_eff, dy_eff, dz_eff], device=v.device, dtype=v.dtype)[:, None]
        v_phys = v / (scale + 1e-8)
        v_phys = v_phys / (v_phys.norm(dim=0, keepdim=True) + 1e-8)
        v_x, v_y, v_z = v_phys[0][None, None, None, :], v_phys[1][None, None, None, :], v_phys[2][None, None, None, :]

        s_s, t_t, a_a = s[None, None, None, :], tau[None, None, None, :], a[None, None, None, :]

        if block_h is None or block_h >= Hp:
            bh = Hp
        else:
            bh = int(block_h)
            if bh < 1:
                # 0 used to fail inside range(); negatives silently produced an all-zero output
                raise ValueError(f"block_h must be a positive integer, got {block_h!r}")
        slabs = [(y0, min(Hp, y0 + bh)) for y0 in range(0, Hp, bh)]

        def _raw_transfer(y0, y1):
            """Un-normalised transfer function on spectrum rows y0:y1, shape (Dp, y1 - y0, Wq, M)."""
            OZ_sl = OZ[:, y0:y1, :, None]
            OX_sl = OX[:, y0:y1, :, None]
            OY_sl = OY[:, y0:y1, :, None]
            dot = OX_sl * v_x + OY_sl * v_y + OZ_sl * v_z   # frequency component along v
            wn2 = OZ_sl**2 + OX_sl**2 + OY_sl**2
            wperp = (wn2 - dot**2).clamp_min(0.0)           # squared component orthogonal to v
            denom = 1j * (s_s * dot) - a_a + t_t * wperp
            # 1/denom written as conj/|denom|^2, with |denom|^2 clamped away from zero
            denom_mag2 = (denom.real.square() + denom.imag.square()).clamp_min(1e-8)
            return denom.conj() / denom_mag2

        # Per-mode RMS over the WHOLE frequency grid. The raw transfer function is laid
        # out (D, h, q, M), so the reduction runs over dims (0, 1, 2) and keeps the mode
        # axis. The previous reduction over (-3, -2, -1) averaged across modes instead
        # and was also computed per slab.
        T_single = None
        if len(slabs) == 1:
            T_single = _raw_transfer(*slabs[0])
            mean_sq = (T_single.real**2 + T_single.imag**2).mean(dim=(0, 1, 2), keepdim=True)
        else:
            sum_sq, count = 0.0, 0
            for y0, y1 in slabs:
                T_sl = _raw_transfer(y0, y1)
                sum_sq = sum_sq + (T_sl.real**2 + T_sl.imag**2).sum(dim=(0, 1, 2), keepdim=True)
                count += T_sl.shape[0] * T_sl.shape[1] * T_sl.shape[2]
            mean_sq = sum_sq / count
        rms = torch.sqrt(mean_sq.clamp_min(1e-8))

        Yf_total = x.new_zeros((Bsz, K, Dp, Hp, Wq), dtype=complex_dtype)

        for y0, y1 in slabs:
            # with several slabs the transfer function is recomputed instead of cached,
            # so that only one slab is held at a time
            T = T_single if T_single is not None else _raw_transfer(y0, y1)
            T = (T / rms).permute(3, 0, 1, 2).contiguous()  # -> (M, Dp, h, Wq)

            # Mixing: channels -> modes, per-mode filtering, modes -> output channels
            Xf_ = Xf[:, :, :, y0:y1, :]
            U = torch.einsum('bcdhq,cm->bmdhq', Xf_, B_tc)
            V = U * T.unsqueeze(0)
            V = self.mode_dropout(V)
            Yf = torch.einsum('km,bmdhq->bkdhq', C_mixer, V)

            Yf_total[:, :, :, y0:y1, :] = Yf

        # ensure real output: drop the imaginary part of the first and, for even Wp,
        # the last column of the half-spectrum
        if Wq > 0:
            Yf_total[..., 0].imag.zero_()
        if Wp % 2 == 0:
            Yf_total[..., -1].imag.zero_()

        with torch.amp.autocast(device_type='cuda', enabled=False):
            y_spatial = self.fft_backend.irfftn(Yf_total, s=(Dp, Hp, Wp), dim=(-3, -2, -1))
        return y_spatial[..., :D, :H, :W].contiguous()


