import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ANet(nn.Module):
    """Top-level network: an ``MSRN`` stage followed by the ``MSMN`` stage."""

    def __init__(self, args):
        """Keep ``args`` on the module and build the ``MSRN`` sub-module from it."""
        super().__init__()
        self.args = args
        self.MSRN = MSRN(args)

    def MSMN(self, x):
        """Placeholder stage: hand ``x`` back untouched."""
        return x

    def forward(self, x):
        """Feed ``x`` through ``self.MSRN`` and then through ``self.MSMN``."""
        rotated = self.MSRN(x)
        return self.MSMN(rotated)

class MSRN(nn.Module):
    """Predict three rotation angles per image and return the rotated copies.

    A convolutional localization network regresses one scalar per rotation
    branch. Each scalar is squashed with ``tanh`` and mapped to an angle whose
    magnitude lies in a band that grows with the branch index (0-10, 10-20 and
    20-30 degrees, either sign). The input batch is then resampled once per
    branch with the matching pure-rotation affine transform (no translation,
    scale or shear).

    The convolutions are unpadded, so the localization output is only 3 x 3
    (what ``fc_loc`` expects) for inputs whose height and width are between
    222 and 253 pixels, e.g. 224 x 224.
    """

    # Number of rotated copies produced per input image.
    NUM_ROTATIONS = 3
    # Width, in degrees, of the angle band assigned to each branch.
    ANGLE_STEP_DEG = 10
    # Flattened size of the localization output: 32 channels on a 3 x 3 map.
    LOC_FEATURES = 32 * 3 * 3

    def __init__(self, args):
        """Build the localization network and the angle regressor.

        Args:
            args: Configuration object; stored as ``self.args`` and not
                otherwise read by this class.
        """
        super(MSRN, self).__init__()
        self.args = args

        # Localization network: five unpadded conv -> max-pool -> ReLU stages.
        self.localization = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 16, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(16, 32, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(32, 32, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(32, 32, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
        )

        # Regressor producing one raw (pre-tanh) value per rotation branch.
        self.fc_loc = nn.Sequential(
            nn.Linear(self.LOC_FEATURES, 32),
            nn.ReLU(True),
            nn.Linear(32, self.NUM_ROTATIONS),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate every image in ``x`` by each of the predicted angles.

        Args:
            x: Image batch of shape ``(N, 3, H, W)``; see the class docstring
                for the supported range of ``H`` and ``W``.

        Returns:
            Tensor of shape ``(N, NUM_ROTATIONS, 3, H, W)`` where index ``i``
            along dim 1 holds ``x`` resampled with the ``i``-th rotation.
        """
        features = self.localization(x)
        # Keep the batch dimension explicit: an unsupported input size then
        # fails in the Linear layer instead of silently folding the surplus
        # features into extra batch rows.
        features = features.view(features.size(0), -1)
        theta = torch.tanh(self.fc_loc(features))  # (N, NUM_ROTATIONS) in (-1, 1)

        rotated = []
        for i in range(self.NUM_ROTATIONS):
            theta_i = theta[:, i]
            # Branch i covers magnitudes in [step * i, step * (i + 1)) degrees,
            # with the sign taken from the prediction.
            angle_deg = (theta_i * self.ANGLE_STEP_DEG
                         + torch.sign(theta_i) * self.ANGLE_STEP_DEG * i)
            angle = angle_deg * math.pi / 180

            cos, sin = torch.cos(angle), torch.sin(angle)
            zero = torch.zeros_like(angle)
            # Row-major 2 x 3 affine matrices: rotation with zero translation.
            rot_matrix = torch.stack(
                [cos, -sin, zero, sin, cos, zero], dim=1
            ).view(-1, 2, 3)

            # affine_grid returns the grid on rot_matrix's device, so no
            # explicit device move is needed and the module also runs on CPU.
            # align_corners=False is the library default, stated explicitly to
            # avoid the per-call warning.
            grid = F.affine_grid(rot_matrix, x.size(), align_corners=False)
            rotated.append(F.grid_sample(x, grid, align_corners=False))

        return torch.stack(rotated, dim=1)

class MSMN(nn.Module):
    """Pass-through module: it owns no layers and returns its input as-is."""

    def __init__(self, args):
        """Remember ``args`` on the instance; nothing else is constructed."""
        super().__init__()
        self.args = args

    def forward(self, x):
        """Return ``x`` unchanged."""
        return x
