import torch
import torch.nn as nn
from .models import UNet, UNetCond

class Find(nn.Module):
    """UNet backbone whose output is squashed element-wise into (-1, 1) by tanh.

    ``ic`` and ``oc`` are forwarded unchanged to ``UNet`` (presumably the
    input / output channel counts -- the ``UNet`` signature is not visible
    here).
    """

    def __init__(self, ic, oc):
        super().__init__()
        # NOTE: the attribute name ``m`` becomes the "m.*" prefix of the
        # state_dict keys; keep it stable so saved checkpoints still load.
        backbone = UNet(ic, oc)
        self.m = backbone

    def forward(self, x):
        """Run the backbone on ``x`` and return its output passed through tanh."""
        raw = self.m(x)
        return raw.tanh()

class FindMulti(nn.Module):
    """UNet with a split output: a tanh-bounded head and an unbounded tail.

    The backbone is built with ``oc + oc_2`` output channels. Along dim 1 of
    the backbone output, the first ``oc`` channels are passed through tanh and
    the remaining channels are returned as-is.

    ``ic`` is forwarded unchanged to ``UNet`` (presumably the input channel
    count -- the ``UNet`` signature is not visible here).
    """

    def __init__(self, ic, oc, oc_2=1):
        super().__init__()
        total_channels = oc + oc_2
        # ``oc`` is remembered so forward() knows where to split dim 1.
        self.oc = oc
        # NOTE: ``m`` is the state_dict key prefix; keep the name stable.
        self.m = UNet(ic, total_channels)

    def forward(self, x):
        """Return ``(tanh(out[:, :oc]), out[:, oc:])`` where ``out = self.m(x)``."""
        out = self.m(x)
        boundary = self.oc
        head, tail = out[:, :boundary], out[:, boundary:]
        return head.tanh(), tail

class FindMultiCond(nn.Module):
    """Conditional split-output network built on ``UNetCond``.

    The backbone is built with ``oc + oc_2`` output channels and is called
    as ``self.m(x, cond)`` with a ``(batch, 1)`` condition tensor. Along dim 1
    of the backbone output, the first ``oc`` channels are passed through tanh
    and the remaining channels are returned as-is.

    ``ic`` is forwarded unchanged to ``UNetCond`` (presumably the input
    channel count -- the ``UNetCond`` signature is not visible here).
    """

    def __init__(self, ic, oc, oc_2=1):
        super().__init__()
        # ``oc`` is remembered so forward() knows where to split dim 1.
        self.oc = oc
        # NOTE: ``m`` is the state_dict key prefix; keep the name stable.
        self.m = UNetCond(ic, oc + oc_2)

    def forward(self, x, cond):
        """Run the conditional backbone and split its output.

        Args:
            x: input batch; ``x.shape[0]`` is taken as the batch size.
            cond: either a Python number or a tensor. A scalar (a number, a
                0-dim tensor or a 1-element tensor) is broadcast to every
                sample; a tensor with exactly ``x.shape[0]`` elements supplies
                one condition value per sample.

        Returns:
            ``(tanh(out[:, :oc]), out[:, oc:])`` where ``out = self.m(x, cond)``.

        Raises:
            ValueError: if ``cond`` is a tensor whose element count is neither
                1 nor the batch size.
        """
        batch = x.shape[0]
        # Match x's floating dtype so half/bf16/double models receive a
        # condition of the same precision; non-float inputs keep float32.
        dtype = x.dtype if x.is_floating_point() else torch.float32

        if isinstance(cond, torch.Tensor):
            cond = cond.to(device=x.device, dtype=dtype).reshape(-1, 1)
            if cond.shape[0] == 1:
                # Broadcast a single value to the whole batch (contiguous copy).
                cond = cond.repeat(batch, 1)
            elif cond.shape[0] != batch:
                raise ValueError(
                    f"cond has {cond.shape[0]} elements; expected 1 or the "
                    f"batch size {batch}"
                )
        else:
            cond = torch.full((batch, 1), cond, dtype=dtype, device=x.device)

        out = self.m(x, cond)
        c = torch.tanh(out[:, :self.oc])
        s = out[:, self.oc:]
        return c, s