import torch
import torch.nn as nn
from .modules import ConvBlk
from .models import UNet, UNetCond

class Hide(nn.Module):
    """Two-stage stack of ``ConvBlk`` layers that merges ``s`` into ``c``.

    ``s`` is first run through ``s_hider`` (widths ``ic -> 64 -> 128 -> 256
    -> 256``). The result is concatenated behind ``c`` along dim 1 and fed to
    ``sc_hider`` (widths ``ic_2 + 256 -> 256 -> 128 -> 64 -> ic_2``), whose
    output is squashed with ``tanh``.

    NOTE(review): ``s``/``c`` presumably stand for "secret" and "cover"
    inputs -- confirm against callers. The width description assumes
    ``ConvBlk(a, b)`` maps ``a`` input channels to ``b`` output channels.

    Args:
        ic: First argument given to the first ``ConvBlk`` of ``s_hider``.
        ic_2: Used as ``ic_2 + 256`` for the first ``ConvBlk`` of
            ``sc_hider`` and as the second argument of its last ``ConvBlk``.
    """

    def __init__(self, ic, ic_2):
        super().__init__()

        # Consecutive pairs of these widths become one ConvBlk each; blocks
        # are created in the listed order so Sequential indices stay 0..3.
        s_widths = (ic, 64, 128, 256, 256)
        self.s_hider = nn.Sequential(
            *(ConvBlk(w_in, w_out)
              for w_in, w_out in zip(s_widths[:-1], s_widths[1:]))
        )

        # The extra 256 accounts for the s_hider features concatenated to c.
        sc_widths = (ic_2 + 256, 256, 128, 64, ic_2)
        self.sc_hider = nn.Sequential(
            *(ConvBlk(w_in, w_out)
              for w_in, w_out in zip(sc_widths[:-1], sc_widths[1:]))
        )

    def forward(self, s, c):
        """Return ``tanh(sc_hider(cat([c, s_hider(s)], dim=1)))``.

        Args:
            s: Input passed through ``s_hider``.
            c: Tensor placed first in the dim-1 concatenation; it must match
                the ``s_hider`` output in every other dimension.

        Returns:
            The ``sc_hider`` output after ``tanh``, i.e. values in [-1, 1].
        """
        fused = torch.cat((c, self.s_hider(s)), dim=1)
        return torch.tanh(self.sc_hider(fused))

class Hide2(nn.Module):
    """Wrap a ``UNet`` that processes ``s`` and ``c`` stacked along dim 1.

    The wrapped network is built as ``UNet(ic + ic_2, ic_2)``.

    NOTE(review): ``s``/``c`` presumably stand for "secret" and "cover"
    inputs with ``ic`` and ``ic_2`` channels respectively -- confirm against
    callers and the ``UNet`` signature.

    Args:
        ic: Contributes to the first ``UNet`` argument (``ic + ic_2``).
        ic_2: Contributes to the first ``UNet`` argument and is passed as the
            second one.
        residual: When truthy, ``c`` is added to the network output before
            the final ``tanh``.
    """

    def __init__(self, ic, ic_2, residual=False):
        super().__init__()
        self.residual = residual
        in_width = ic + ic_2
        self.m = UNet(in_width, ic_2)

    def forward(self, s, c):
        """Run the wrapped network on ``cat([s, c], dim=1)`` and apply tanh.

        Args:
            s: Tensor placed first in the dim-1 concatenation.
            c: Tensor placed second; also the residual term when
                ``self.residual`` is truthy.

        Returns:
            ``tanh(m(x))`` or, with the residual enabled, ``tanh(m(x) + c)``,
            where ``x`` is the concatenated input.
        """
        stacked = torch.cat((s, c), dim=1)
        pred = self.m(stacked)
        # The skip connection is added before the squashing non-linearity.
        pred = pred + c if self.residual else pred
        return torch.tanh(pred)

class Hide2Cond(nn.Module):
    """Wrap a ``UNetCond`` fed with ``s`` and ``c`` plus a scalar condition.

    The wrapped network is built as ``UNetCond(ic + ic_2, ic_2)`` and is
    called with the dim-1 concatenation of ``s`` and ``c`` together with a
    ``(batch, 1)`` float32 tensor filled with ``cond``.

    NOTE(review): ``s``/``c`` presumably stand for "secret" and "cover"
    inputs with ``ic`` and ``ic_2`` channels respectively -- confirm against
    callers and the ``UNetCond`` signature.

    Args:
        ic: Contributes to the first ``UNetCond`` argument (``ic + ic_2``).
        ic_2: Contributes to the first ``UNetCond`` argument and is passed as
            the second one.
        residual: When truthy, ``c`` is added to the network output before
            the final ``tanh``.
    """

    def __init__(self, ic, ic_2, residual=False):
        super().__init__()
        self.residual = residual
        in_width = ic + ic_2
        self.m = UNetCond(in_width, ic_2)

    def forward(self, s, c, cond):
        """Run the conditioned network on ``cat([s, c], dim=1)`` and apply tanh.

        Args:
            s: Tensor placed first in the dim-1 concatenation; its first
                dimension and device define the condition tensor.
            c: Tensor placed second; also the residual term when
                ``self.residual`` is truthy.
            cond: Fill value for the condition tensor (the same value is
                used for every row).

        Returns:
            ``tanh(m(x, k))`` or, with the residual enabled,
            ``tanh(m(x, k) + c)``, where ``x`` is the concatenated input and
            ``k`` the ``(batch, 1)`` condition tensor.
        """
        n_rows = s.shape[0]
        # One identical condition value per batch row, on the same device as s.
        cond_col = torch.full(
            (n_rows, 1), cond, dtype=torch.float32, device=s.device
        )
        stacked = torch.cat((s, c), dim=1)
        pred = self.m(stacked, cond_col)
        # The skip connection is added before the squashing non-linearity.
        pred = pred + c if self.residual else pred
        return torch.tanh(pred)