#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Small fully-connected networks whose outputs are L2-normalized along dim 1.
"""

import torch
import torch.nn as nn

class TransformLatenttoOrig(nn.Module):
    """MLP mapping ``dim_latent``-sized inputs to unit-norm ``dim_orig`` outputs.

    Architecture: three ReLU-activated hidden layers of width ``dim_hidden``,
    then a linear projection to ``dim_orig`` followed by a Sigmoid, so each
    output component lies in [0, 1] before normalization.

    Args:
        dim_latent: Size of the last dimension of the input.
        dim_orig: Size of the last dimension of the output.
        dim_hidden: Width of the hidden layers (default 10, the value that
            was previously hard-coded).
    """

    def __init__(self, dim_latent, dim_orig, dim_hidden=10):
        super().__init__()
        self.dim_latent = dim_latent
        self.dim_orig = dim_orig
        self.dim_hidden = dim_hidden
        # The layer order (and therefore the state_dict keys net.0, net.2,
        # net.4, net.6) is kept unchanged so previously saved weights load.
        self.net = nn.Sequential(
            nn.Linear(self.dim_latent, self.dim_hidden),
            nn.ReLU(),
            nn.Linear(self.dim_hidden, self.dim_hidden),
            nn.ReLU(),
            nn.Linear(self.dim_hidden, self.dim_hidden),
            nn.ReLU(),
            nn.Linear(self.dim_hidden, self.dim_orig),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Run the MLP and L2-normalize the result along dim 1.

        For a 2-D input of shape (batch, dim_latent), every row of the
        (batch, dim_orig) output has unit Euclidean norm.

        Args:
            input: Tensor whose last dimension is ``dim_latent``.

        Returns:
            ``self.net(input)`` divided by its L2 norm over dim 1.
        """
        out = self.net(input)
        # Same as out / ||out||_2 along dim 1, except the norm is clamped to
        # at least 1e-12 (normalize's default eps). A row whose Sigmoid
        # outputs all underflow to 0 then yields zeros instead of NaN.
        return nn.functional.normalize(out, p=2.0, dim=1)

    
class TransformNet(nn.Module):
    """MLP mapping ``size``-dimensional inputs to unit-norm ``size`` outputs.

    Architecture: three hidden layers of width ``hidden`` with in-place
    LeakyReLU(0.2) activations, followed by a linear projection back to
    ``size`` features (no output activation).

    Args:
        size: Size of the last dimension of both input and output.
        hidden: Width of the hidden layers (default 50, the value that was
            previously hard-coded).
    """

    def __init__(self, size, hidden=50):
        super().__init__()
        self.size = size
        self.hidden = hidden
        # The layer order (and therefore the state_dict keys net.0, net.2,
        # net.4, net.6) is kept unchanged so previously saved weights load.
        self.net = nn.Sequential(
            nn.Linear(self.size, self.hidden),
            nn.LeakyReLU(0.2, True),
            nn.Linear(self.hidden, self.hidden),
            nn.LeakyReLU(0.2, True),
            nn.Linear(self.hidden, self.hidden),
            nn.LeakyReLU(0.2, True),
            nn.Linear(self.hidden, self.size),
        )

    def forward(self, input):
        """Run the MLP and L2-normalize the result along dim 1.

        For a 2-D input of shape (batch, size), every row of the output has
        unit Euclidean norm.

        Args:
            input: Tensor whose last dimension is ``size``.

        Returns:
            ``self.net(input)`` divided by its L2 norm over dim 1.
        """
        out = self.net(input)
        # Same as out / ||out||_2 along dim 1, except the norm is clamped to
        # at least 1e-12 (normalize's default eps). An all-zero output row
        # then yields zeros instead of 0/0 = NaN.
        return nn.functional.normalize(out, p=2.0, dim=1)