import torch
import torch.nn as nn
from modules.sonic import Sonic
from typing import Union
from utils.utils import DropPath, safe_gn_groups


class SonicBlock(nn.Module):
    def __init__(
        self,
        dim: int = 2,
        in_channels: int = 3,
        out_channels: int = 64,
        sonic_kwargs: dict | None = None,
        droppath_p: float = 0.0,
        norm_groups: int | None = None,
        conv_block: nn.Module = nn.Conv2d,
        depth_idx: int = 0,
        depth_total: int = 1,
    ):
        """
        Sonic Block with GroupNorm, GELU, Sonic, Residual connection, and DropPath.
        dim: 2 or 3 for 2D or 3D data, higher dims not implemented
        in_channels: input channels
        out_channels: hidden or output channels
        sonic_kwargs: dict of kwargs for Sonic module
        droppath_p: drop mode prob (0.0 = no drop path)
        norm_groups: number of groups for GroupNorm (None = auto)
        conv_block: convolution block to use (nn.Conv2d or nn.Conv3d)
        depth_idx: index of this block in the network (0 = first, depth_idx=depth-1 = last, this is used to set the initialization of the Sonic layer, first layers more global receptive field, later layers more local)
        depth_total: total number of blocks in the network (used with depth_idx)
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        g = norm_groups if (norm_groups is not None and in_channels % norm_groups == 0) else safe_gn_groups(in_channels)

        self.norm = nn.GroupNorm(num_groups=g, num_channels=in_channels, affine=True)
        self.act  = nn.GELU()

        self.sonic = Sonic(dim=dim, in_channels=in_channels, num_hidden=out_channels, depth_idx=depth_idx, depth_total=depth_total, **sonic_kwargs)

        if in_channels == out_channels:
            self.skip = nn.Identity()
        else:
            self.skip = conv_block(in_channels, out_channels, kernel_size=1, bias=False)

        self.droppath = DropPath(droppath_p, dim=dim)

    def forward(self, x):
        y = self.sonic(self.act(self.norm(x)))
        y = self.droppath(y)
        return self.skip(x) + y

class SonicNet(nn.Module):
    def __init__(self, dim=2, n_channels=3, num_hidden=64, num_classes=2, depth=4, sonic_kwargs={}, droppath_max=0.1, norm_groups=None):
        """
        SonicNet with multiple SonicBlocks, final GroupNorm, GELU, and Conv layer to num_classes.
        dim: 2 or 3
        n_channels: input channels
        num_hidden: hidden channels in SonicBlocks, now the same for all blocks
        num_classes: output channels
        depth: number of SonicBlocks
        sonic_kwargs: dict of kwargs for Sonic module
        droppath_max: final block drop path prob (linearly scaled 0 -> droppath_max)
        norm_groups: number of groups for GroupNorm (None = auto)
        
        """
        super().__init__()
        self.model_type=None
        self.dim = dim
        self.conv_block= nn.Conv2d if dim==2 else nn.Conv3d

        blocks = []

        blocks.append(
            SonicBlock(
                dim=dim,
                in_channels=n_channels,
                out_channels=num_hidden,
                sonic_kwargs=sonic_kwargs,
                droppath_p=0.0 if depth == 1 else (0.0),
                norm_groups=norm_groups,
                conv_block=self.conv_block,
                depth_idx=0,
                depth_total=depth
            )
        )

        for i in range(1, depth):
            dp = (droppath_max * i) / (depth - 1) if depth > 1 else 0.0
            blocks.append(
                SonicBlock(
                    dim=dim,
                    in_channels=num_hidden,
                    out_channels=num_hidden,
                    sonic_kwargs=sonic_kwargs,
                    droppath_p=dp,
                    norm_groups=norm_groups,
                    conv_block=self.conv_block,
                    depth_idx=i,
                    depth_total=depth
                )
            )

        self.blocks = nn.ModuleList(blocks)

        g = norm_groups if (norm_groups is not None and num_hidden % norm_groups == 0) else safe_gn_groups(num_hidden)

        self.final_norm = nn.GroupNorm(num_groups=g, num_channels=num_hidden, affine=True)
        self.final_act  = nn.GELU()
        self.final_proj = self.conv_block(num_hidden, num_classes, kernel_size=3, padding=1, bias=True)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        x = self.final_act(self.final_norm(x))
        x = self.final_proj(x)
        return x

if __name__ == "__main__":
    # Example usage: build a small 2D SonicNet and print how many trainable
    # parameters it has.
    sonic_kwargs = {
        "M_modes": 16,
        "normalize_input": True,
        # Spatial resolution along H if known; otherwise assume 1.0.
        "dy": 1.0,
        # Presumably the same convention for the remaining axes — TODO confirm.
        "dx": 1.0,
        "dz": 1.0,
        "dropout_p": 0.05,
        # Original note: "fixes orientation" (the rest of the note was
        # garbled) — TODO confirm exact meaning against the Sonic module.
        "fix_v": False,
        "v_noise": 0.0,
        "rho": 0.0,
    }

    model = SonicNet(
        dim=2,
        n_channels=3,
        num_hidden=32,
        num_classes=2,
        depth=4,
        sonic_kwargs=sonic_kwargs,
    )

    # Count only the parameters that would receive gradients.
    n_params = 0
    for param in model.parameters():
        if param.requires_grad:
            n_params += param.numel()

    print(f"SonicNet model with {n_params} parameters")

