import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class TensorConvLayer2d(nn.Module):
    """Strided 2-D convolution whose per-position features are small tensors.

    A ``kh`` x ``kw`` window is slid over dimensions 2 and 3 of the input.
    At every output position the windowed input is contracted with the
    learned ``weights`` through ``torch.einsum``; ``out_type`` selects the
    contraction (``x`` is always indexed as the six-letter operand
    ``abcdef``):

    * ``0`` -- ``'abcdef,gbcdfh->ageh'``: output has two trailing dims
      ``[input_tensor[0], output_tensor[1]]``.
    * ``1`` -- ``'abcdef,gbcdfhi->agehi'``: output has three trailing dims
      ``[input_tensor[0], output_tensor[1], output_tensor[2]]``.
    * ``2`` -- ``'abcdef,gbcdfehi->aghi'``: both trailing input dims are
      contracted; output has trailing dims
      ``[output_tensor[0], output_tensor[1]]``.

    Args:
        device: device on which ``forward`` allocates its output buffer.
        pad_h, pad_w: padding along height / width.
        kh, kw: kernel height / width.
        stride_h, stride_w: stride along height / width.
        in_channel, out_channel: number of input / output channels.
        input_tensor: sizes of the trailing (per-position) input dims;
            defaults to ``[3, 3]``.
        output_tensor: sizes of the trailing output dims; defaults to
            ``[3, 3]`` (``out_type == 1`` needs three entries).
        out_type: contraction selector, one of 0, 1, 2 (see above).

    Raises:
        Exception: if ``out_type`` is not 0, 1 or 2.

    NOTE(review): for ``out_type`` 0 and 1 the contracted weight dimension
    is sized ``input_tensor[0]`` while the einsum contracts it with the
    *last* dimension of ``x``.  This only lines up when the two trailing
    input dims are equal -- confirm whether ``input_tensor[1]`` was intended.
    """

    # einsum equation used by forward() for each supported out_type.
    _EQUATIONS = {
        0: 'abcdef,gbcdfh->ageh',
        1: 'abcdef,gbcdfhi->agehi',
        2: 'abcdef,gbcdfehi->aghi',
    }

    def __init__(self, device, pad_h=0, pad_w=0, kh=3, kw=3, stride_h=2, stride_w=2, in_channel=1, out_channel=1, input_tensor=None, output_tensor=None, out_type=0):
        super(TensorConvLayer2d, self).__init__()
        self.device = device
        self.pad_h = pad_h
        self.pad_w = pad_w
        self.kh = kh
        self.kw = kw
        self.stride_h = stride_h
        self.stride_w = stride_w
        self.out_channel = out_channel
        # Fresh lists per instance: a mutable default in the signature would
        # be shared (and mutable) across every default-constructed layer.
        self.input_tensor = [3, 3] if input_tensor is None else input_tensor
        self.output_tensor = [3, 3] if output_tensor is None else output_tensor
        self.out_type = out_type

        weight_shape = [out_channel, in_channel, kh, kw, *self._weight_trailing_shape()]
        self.weights = nn.Parameter(torch.empty(weight_shape, dtype=torch.float32))
        nn.init.kaiming_normal_(self.weights, a=0, mode='fan_out')

    def _weight_trailing_shape(self):
        """Return the weight dims that follow ``[out, in, kh, kw]``."""
        if self.out_type == 0:
            return [self.input_tensor[0], self.output_tensor[1]]
        if self.out_type == 1:
            return [self.input_tensor[0], self.output_tensor[1], self.output_tensor[2]]
        if self.out_type == 2:
            return [self.input_tensor[1], self.input_tensor[0], self.output_tensor[0], self.output_tensor[1]]
        raise Exception("output tensor shape is not supported yet")

    def _output_trailing_shape(self):
        """Return the output dims that follow ``[batch, out, out_h, out_w]``."""
        if self.out_type == 0:
            return [self.input_tensor[0], self.output_tensor[1]]
        if self.out_type == 1:
            return [self.input_tensor[0], self.output_tensor[1], self.output_tensor[2]]
        if self.out_type == 2:
            return [self.output_tensor[0], self.output_tensor[1]]
        raise Exception("output tensor shape is not supported yet")

    def forward(self, x):
        """Apply the tensor convolution.

        Args:
            x: 6-D input; dims 0-3 are read as batch, channel, height and
                width, dims 4-5 are the per-position tensor.

        Returns:
            float32 tensor on ``self.device`` shaped
            ``[batch, out_channel, out_h, out_w, *trailing]`` where
            ``trailing`` depends on ``out_type`` (see the class docstring).
        """
        in_tensor_shape = x.size()
        height, width = in_tensor_shape[2], in_tensor_shape[3]
        output_h = (height + 2*self.pad_h - self.kh) // self.stride_h + 1
        output_w = (width + 2*self.pad_w - self.kw) // self.stride_w + 1
        assert output_h > 0
        assert output_w > 0

        trailing = self._output_trailing_shape()
        equation = self._EQUATIONS[self.out_type]

        # Allocate directly on the target device; going through a float64
        # NumPy array costs a host allocation, a cast and a copy per call.
        out_tensor = torch.zeros(
            [in_tensor_shape[0], self.out_channel, output_h, output_w, *trailing],
            dtype=torch.float32, device=self.device)

        for i in range(output_h):
            for j in range(output_w):
                # a*/b* slice the input window, c*/d* the matching part of
                # the kernel (the two differ where the window is clipped).
                a0, a1, b0, b1, c0, c1, d0, d1 = cal_offset(i, self.stride_h, self.pad_h, self.kh, j, self.stride_w, self.pad_w, self.kw, height, width)
                out_tensor[:, :, i, j] = torch.einsum(
                    equation,
                    x[:, :, a0:a1, b0:b1],
                    self.weights[:, :, c0:c1, d0:d1])
        return out_tensor


def squash(s, dim=-1):
    """Apply the capsule "squashing" non-linearity along ``dim``.

    Short vectors are shrunk to almost zero length and long vectors to a
    length slightly below 1:
    ``v = ||s||^2 / (1 + ||s||^2) * s / ||s||``.

    Args:
        s: tensor holding the vectors before activation.
        dim: dimension along which the norm is computed.

    Returns:
        Tensor of the same shape as ``s`` with squashed vectors.
    """
    sq_len = (s ** 2).sum(dim=dim, keepdim=True)
    shrink = sq_len / (1 + sq_len)
    # The small epsilon keeps the division finite for all-zero vectors.
    unit_denominator = sq_len.sqrt() + 1e-8
    return shrink * s / unit_denominator


class Squash(nn.Module):
    """Reduce ``dim`` to the squared L2 norm of the vectors along it.

    Despite its name, this module does not rescale its input: it only
    returns ``sum(x**2)`` over ``dim`` and that dimension is dropped.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return the squared norm of ``x`` along ``self.dim``."""
        return x.pow(2).sum(dim=self.dim, keepdim=False)

class SumLayer(nn.Module):
    """Squash vectors along ``dim``.

    Despite its name, the forward pass computes
    ``||x||^2 / (1 + ||x||^2) * x / (||x|| + 1e-8)`` and returns a tensor of
    the same shape as the input.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return ``x`` rescaled by the squashing factor along ``self.dim``."""
        sq_len = (x ** 2).sum(dim=self.dim, keepdim=True)
        shrink = sq_len / (1 + sq_len)
        # Epsilon avoids a division by zero for all-zero vectors.
        return shrink * x / (sq_len.sqrt() + 1e-8)


class MarginLoss(nn.Module):
    """Margin loss for class existence.

    ``L_k = T_k * max(0, m+ - ||v_k||)^2
           + lambda * (1 - T_k) * max(0, ||v_k|| - m-)^2``
    with ``m+ = 0.9`` and ``m- = 0.1``.

    Args:
        size_average: average (True) or sum (False) the per-sample losses
            over the minibatch.
        loss_lambda: down-weighting factor for the absent-class term.
    """

    def __init__(self, size_average=False, loss_lambda=0.5):
        super().__init__()
        self.size_average = size_average
        self.m_plus = 0.9
        self.m_minus = 0.1
        self.loss_lambda = loss_lambda

    def forward(self, inputs, labels):
        """Return the scalar margin loss.

        ``inputs`` and ``labels`` are combined element-wise and the result
        is summed over dimension 1 before the batch reduction.
        """
        present_term = labels * F.relu(self.m_plus - inputs) ** 2
        absent_term = self.loss_lambda * (1 - labels) * F.relu(inputs - self.m_minus) ** 2
        per_sample = (present_term + absent_term).sum(dim=1)
        return per_sample.mean() if self.size_average else per_sample.sum()

class TopkPooling(nn.Module):
    """Zero every element below a quantile-style threshold.

    The threshold is the ``int(n * k)``-th smallest value of the flattened
    input (``n`` = number of elements).  Elements greater than or equal to
    the threshold are kept, all others are set to zero.

    Args:
        k: fraction in terms of which the threshold rank is computed.
    """

    def __init__(self, k):
        super(TopkPooling, self).__init__()
        self.k = k
        self.flatten = nn.Flatten(0)

    def forward(self, x):
        """Return ``x`` with sub-threshold elements zeroed (same shape)."""
        y = self.flatten(x)
        n = y.size(dim=0)
        if n == 0:
            # Nothing to rank; kthvalue would raise on an empty tensor.
            return x
        # kthvalue needs a rank in [1, n]; int(n * k) is 0 for small k or
        # small inputs and exceeds n for k > 1, both of which used to raise.
        rank = min(max(int(n*self.k), 1), n)
        thre, _ = torch.kthvalue(y, rank)
        # Compare against the 0-d tensor directly to avoid a host sync.
        mask = x.ge(thre)
        return mask * x

class TensorBatchNorm(nn.Module):
    """Batch-normalise a tensor by flattening everything after dim 1.

    The input is flattened from dimension 2 onwards, passed through
    ``nn.BatchNorm1d(num_features)`` and viewed back to its original shape.

    Args:
        num_features: feature count given to ``nn.BatchNorm1d``.
    """

    def __init__(self, num_features):
        super().__init__()
        self.flatten = nn.Flatten(2)
        self.bn = nn.BatchNorm1d(num_features)

    def forward(self, x):
        """Return the normalised tensor with the same shape as ``x``."""
        original_shape = x.size()
        normalised = self.bn(self.flatten(x))
        return normalised.view(original_shape)

class TensorBatchNorm2d(nn.Module):
    """Batch-normalise a 6-D tensor with ``nn.BatchNorm2d``.

    Dims 2 and 3 are moved to the end (positions 4 and 5), dims 1-3 of the
    result are merged into a single feature dim, ``BatchNorm2d`` is applied
    and the original layout is restored.

    Args:
        num_features: feature count given to ``nn.BatchNorm2d``; it must
            equal the size of the merged dim.
    """

    def __init__(self, num_features):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features)

    def forward(self, x):
        """Return the normalised tensor with the same shape as ``x``."""
        moved = x.movedim((2, 3), (4, 5))
        moved_shape = moved.size()
        merged = moved.flatten(start_dim=1, end_dim=3)
        normalised = self.bn(merged).view(moved_shape)
        return normalised.movedim((4, 5), (2, 3))

class TensorBatchNorm3d(nn.Module):
    """Batch-normalise a 7-D tensor with ``nn.BatchNorm2d``.

    Dims 2 and 3 are moved to the end (positions 5 and 6), dims 1-4 of the
    result are merged into a single feature dim, ``BatchNorm2d`` is applied
    and the original layout is restored.

    Args:
        num_features: feature count given to ``nn.BatchNorm2d``; it must
            equal the size of the merged dim.
    """

    def __init__(self, num_features):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features)

    def forward(self, x):
        """Return the normalised tensor with the same shape as ``x``."""
        moved = x.movedim((2, 3), (5, 6))
        moved_shape = moved.size()
        merged = moved.flatten(start_dim=1, end_dim=4)
        normalised = self.bn(merged).view(moved_shape)
        return normalised.movedim((5, 6), (2, 3))

class TensorBatchNorm4d(nn.Module):
    """Batch-normalise an 8-D tensor with ``nn.BatchNorm2d``.

    Dims 2 and 3 are moved to the end (positions 6 and 7), dims 1-5 of the
    result are merged into a single feature dim, ``BatchNorm2d`` is applied
    and the original layout is restored.

    Args:
        num_features: feature count given to ``nn.BatchNorm2d``; it must
            equal the size of the merged dim.
    """

    def __init__(self, num_features):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features)

    def forward(self, x):
        """Return the normalised tensor with the same shape as ``x``."""
        moved = x.movedim((2, 3), (6, 7))
        moved_shape = moved.size()
        merged = moved.flatten(start_dim=1, end_dim=5)
        normalised = self.bn(merged).view(moved_shape)
        return normalised.movedim((6, 7), (2, 3))

class TensorReLU(nn.Module):
    """Group-wise gate: keep a group only if its maximum is positive.

    The input is flattened from dimension 2 onwards and split into
    consecutive groups of ``dim`` elements.  A group whose maximum is
    greater than zero is passed through unchanged; every other group is
    multiplied by zero.

    Args:
        device: stored on the module; the forward pass no longer needs it
            because the mask is derived from the input itself.
        dim: number of consecutive flattened elements per group.
    """

    def __init__(self, device, dim):
        super(TensorReLU, self).__init__()
        self.flatten = nn.Flatten(2)
        self.dim = dim
        self.device = device

    def forward(self, x):
        """Return ``x`` with non-positive-max groups zeroed (same shape)."""
        in_tensor_shape = x.size()
        group_max = F.max_pool1d(self.flatten(x), kernel_size=self.dim)
        # float32 0/1 gate per group, matching the dtype the former
        # einsum-with-ones mask produced.
        gate = (group_max > 0).to(torch.float32)
        # Broadcast each gate over its group instead of materialising a
        # NumPy ones array and copying it to the device on every call.
        mask = gate.unsqueeze(-1).expand(-1, -1, -1, self.dim).reshape(in_tensor_shape)
        return torch.mul(x, mask)

class TensorMaxPool(nn.Module):
    """Max-pool over the flattened trailing dims and drop the last two dims.

    The input is flattened from dimension 2 onwards, pooled with
    ``F.max_pool1d`` and viewed to the input shape without its last two
    dimensions, so the pooled length must match that target shape.

    Args:
        kernel_size: window size passed to ``F.max_pool1d``.
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.flatten = nn.Flatten(2)
        self.kernel_size = kernel_size

    def forward(self, x):
        """Return the pooled tensor shaped like ``x.size()[:-2]``."""
        full_shape = x.size()
        pooled = F.max_pool1d(self.flatten(x), kernel_size=self.kernel_size)
        return pooled.view(full_shape[:-2])


def _axis_window(offset, kernel, size):
    """Clip a kernel window starting at ``offset`` to ``[0, size)`` on one axis.

    Returns ``(in_start, in_end, k_start, k_end)``: the input slice covered
    by the window and the matching slice of the kernel.
    """
    in_start = max(offset, 0)
    in_end = min(offset + kernel, size)
    k_start = in_start - offset
    k_end = in_end - offset
    # The window must not lie entirely before the start of the input.
    assert k_start < kernel
    return in_start, in_end, k_start, k_end


def cal_offset(i, stride_h, pad_h, kh, j, stride_w, pad_w, kw, h, w):
    """Compute input and kernel slices for output position ``(i, j)``.

    The kernel window starts at ``i*stride_h - pad_h`` / ``j*stride_w - pad_w``.
    Where it overhangs the input it is clipped, and the kernel slice is
    shortened by the same amount so both slices keep equal lengths.

    Args:
        i, j: output row / column index.
        stride_h, stride_w: stride along height / width.
        pad_h, pad_w: padding along height / width.
        kh, kw: kernel height / width.
        h, w: input height / width.

    Returns:
        Tuple ``(h_start, h_end, w_start, w_end,
        kh_start, kh_end, kw_start, kw_end)``: the first four values slice
        the input, the last four slice the kernel.

    Raises:
        AssertionError: if the window starts a full kernel size or more
            before the input on either axis.
    """
    h_start, h_end, kh_start, kh_end = _axis_window(i*stride_h - pad_h, kh, h)
    # The width axis uses kw throughout; the window end was previously
    # computed with kh, which is wrong for non-square kernels.
    w_start, w_end, kw_start, kw_end = _axis_window(j*stride_w - pad_w, kw, w)
    return h_start, h_end, w_start, w_end, kh_start, kh_end, kw_start, kw_end


