
import torch
from torch import nn
# CUDA 32-bit float tensor type, exported as a module-level alias.
# (Not referenced anywhere in the visible part of this file.)
dtype = torch.cuda.FloatTensor

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")







class SineLayer(nn.Module):
    """Linear layer followed by a sine activation applied twice.

    Computes ``sin(sin(omega_0 * linear(input)))``.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        bias: Whether the underlying ``nn.Linear`` has a bias term.
        omega_0: Scale applied to the linear output before the first sine.
    """

    def __init__(self, in_features, out_features, bias=True, omega_0=1):
        super().__init__()
        self.in_features = in_features
        self.omega_0 = omega_0
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        # Overwrite nn.Linear's default weight init; the bias is left as is.
        self.init_weights()

    def init_weights(self):
        """Redraw the weights uniformly from ``[-1/in_features, 1/in_features]``."""
        limit = 1 / self.in_features
        with torch.no_grad():
            self.linear.weight.uniform_(-limit, limit)

    def forward(self, input):
        """Return the doubly sine-activated, ``omega_0``-scaled linear output."""
        scaled = self.omega_0 * self.linear(input)
        once = torch.sin(scaled)
        return torch.sin(once)




class FTM_3D(nn.Module):
    """Three coordinate networks whose outputs act as per-axis factor vectors.

    ``U_net``, ``V_net`` and ``W_net`` each take an input whose last dimension
    is 1 and produce ``R[0]``, ``R[1]`` and ``R[2]`` values respectively,
    squashed into ``(-1, 1)`` by a final ``tanh``.

    The module has two modes, selected through the ``mode`` property:

    * ``"training"`` (default): ``forward`` returns the three factor tensors
      separately as ``(U, V, W)``.
    * ``"sampling"``: ``forward`` returns, per row, the Kronecker product of
      the three factor vectors.

    Args:
        R: Indexable with at least three entries; ``R[0]``, ``R[1]``, ``R[2]``
            are the output widths of the U, V and W networks.
        omega: Passed as ``omega_0`` to every ``SineLayer``.
        mid_channel: Hidden width of each network. The default of 1024 is the
            value that used to be hard-coded.
    """

    _VALID_MODES = ("training", "sampling")

    def __init__(self, R: tuple, omega=10, mid_channel=1024):
        super().__init__()
        self.r_1 = R[0]
        self.r_2 = R[1]
        self.r_3 = R[2]
        self._mode = "training"

        # Built in U, V, W order so parameter initialisation consumes the
        # random number generator in the same sequence as before.
        self.U_net = self._make_branch(self.r_1, mid_channel, omega)
        self.V_net = self._make_branch(self.r_2, mid_channel, omega)
        self.W_net = self._make_branch(self.r_3, mid_channel, omega)

    @staticmethod
    def _make_branch(rank, mid_channel, omega):
        """Build one network mapping a 1-feature input to ``rank`` outputs."""
        return nn.Sequential(
            SineLayer(1, mid_channel, omega_0=omega),
            SineLayer(mid_channel, mid_channel, omega_0=omega),
            # Dropout with p=0 is a no-op; it is kept so the positions of the
            # following layers inside the Sequential stay the same.
            nn.Dropout(0),
            nn.Linear(mid_channel, rank),
            nn.Tanh(),
        )

    @property
    def mode(self):
        """Current forward mode: ``"training"`` or ``"sampling"``."""
        return self._mode

    @mode.setter
    def mode(self, mode):
        if mode not in self._VALID_MODES:
            raise ValueError("Mode should be 'training' or 'sampling'")
        self._mode = mode

    def kronecker_product_einsum_batched(self, A: torch.Tensor, B: torch.Tensor):
        """
        Batched Version of Kronecker Products
        :param A: has shape (b, a, c)
        :param B: has shape (b, k, p)
        :return: (b, ak, cp)
        """
        assert A.dim() == 3 and B.dim() == 3, "A and B must both be 3-D tensors"

        outer = torch.einsum("bac,bkp->bakcp", A, B)
        # reshape (not view): the einsum result is not guaranteed to have
        # strides that allow a view, e.g. when A or B is non-contiguous.
        return outer.reshape(A.size(0),
                             A.size(1) * B.size(1),
                             A.size(2) * B.size(2))

    def forward(self, input_ind_train=None, input_ind_sampl=None):
        """Evaluate the factor networks according to the current mode.

        :param input_ind_train: used in ``"training"`` mode; a sequence of
            three tensors (U, V and W coordinates). Each gets a new dimension
            at position 1 before entering its network, so they are presumably
            1-D tensors of length B -- confirm against the caller.
        :param input_ind_sampl: used in ``"sampling"`` mode; a tensor sliced
            as ``[:, 0:1]``, ``[:, 1:2]``, ``[:, 2:3]`` for the U, V and W
            coordinates, i.e. shape (B, 3) per the original comment.
        :return: ``(U, V, W)`` in training mode; in sampling mode a tensor of
            shape (B, r_1 * r_2 * r_3) for a (B, 3) input.
        :raises ValueError: if the argument required by the current mode is
            ``None``, or the stored mode is not a valid one.
        """
        if self._mode == "training":
            if input_ind_train is None:
                raise ValueError(
                    "input_ind_train is required when mode is 'training'")
            U = self.U_net(input_ind_train[0].unsqueeze(1))  # B * r_1
            V = self.V_net(input_ind_train[1].unsqueeze(1))  # B * r_2
            W = self.W_net(input_ind_train[2].unsqueeze(1))  # B * r_3
            return (U, V, W)

        if self._mode == "sampling":
            if input_ind_sampl is None:
                raise ValueError(
                    "input_ind_sampl is required when mode is 'sampling'")
            U = self.U_net(input_ind_sampl[:, :1]).unsqueeze(1)   # B * 1 * r_1
            V = self.V_net(input_ind_sampl[:, 1:2]).unsqueeze(1)  # B * 1 * r_2
            W = self.W_net(input_ind_sampl[:, 2:3]).unsqueeze(1)  # B * 1 * r_3
            UV = self.kronecker_product_einsum_batched(U, V)
            UVW = self.kronecker_product_einsum_batched(UV, W).squeeze(1)
            return UVW

        # Only reachable if _mode was assigned directly, bypassing the setter.
        raise ValueError("Mode should be 'training' or 'sampling'")

    
