import torch
import torch.nn as nn
import math
import torch.nn.functional as F


class OurModel(nn.Module):
    """Polynomial-feature network: bias-free encoder, powers, linear decoder.

    The encoder maps each sample to ``hidden_size`` values, where
    ``hidden_size = C(input_size + order - 1, order)`` (the number of
    degree-``order`` monomials in ``input_size`` variables).  The negated
    encoding is raised elementwise to the powers ``1..order`` and the
    concatenated powers are combined by a linear decoder.

    Both layers keep ``nn.Linear``'s default initialization; no orthogonal
    initialization is applied.

    Args:
        input_size: Number of features per input sample.
        order: Highest power applied to the negated encoding.
        output_size: Number of values produced per sample.
    """

    def __init__(self, input_size: int, order: int, output_size: int = 1) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = math.comb(input_size + order - 1, order)
        self.output_size = output_size
        self.order = order

        self.encoder = nn.Linear(input_size, self.hidden_size, bias=False)
        self.decoder = nn.Linear(order * self.hidden_size, output_size)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the decoder output for ``input``.

        For a 2-D ``input`` of shape ``(batch, input_size)`` the result has
        shape ``(batch, output_size)``.
        """
        negated = -self.encoder(input)
        # The first-order term is always present; higher powers follow in
        # increasing order so the flattened layout is power-major
        # (all power-1 values, then all power-2 values, ...).
        powers = [negated]
        powers.extend(
            torch.pow(negated, exponent) for exponent in range(2, self.order + 1)
        )
        # A single stack replaces growing the tensor with torch.cat on every
        # iteration, which re-copied all earlier powers each time.
        features = torch.stack(powers, dim=1)
        return self.decoder(torch.flatten(features, 1))

class OurModel_pooling(nn.Module):
    """Polynomial-feature network applied to an input averaged over dim 1.

    The input is averaged over dimension 1, the remaining per-sample
    dimensions are flattened, and the result goes through a bias-free linear
    encoder of width ``C(input_size + order - 1, order)``.  The negated
    encoding is raised elementwise to the powers ``1..order`` and the
    concatenated powers are combined by a linear decoder.

    ``self.pooling`` (``nn.MaxPool2d(3)``) is registered on the module but is
    not applied in ``forward``.  Both linear layers keep ``nn.Linear``'s
    default initialization; no orthogonal initialization is applied.

    Args:
        input_size: Number of features per sample after averaging over
            dimension 1 and flattening the remaining dimensions.
        order: Highest power applied to the negated encoding.
        output_size: Number of values produced per sample.
    """

    def __init__(self, input_size: int, order: int, output_size: int = 1) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = math.comb(input_size + order - 1, order)
        self.output_size = output_size
        self.order = order

        # Kept as an attribute for compatibility; forward() does not call it.
        self.pooling = nn.MaxPool2d(3)
        self.encoder = nn.Linear(input_size, self.hidden_size, bias=False)
        self.decoder = nn.Linear(order * self.hidden_size, output_size)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the decoder output for ``input``.

        ``input`` is reduced with a mean over dimension 1 and every remaining
        non-batch dimension is flattened before encoding, so the flattened
        size must equal ``input_size``.
        """
        averaged = torch.mean(input, 1)
        flat = averaged.reshape(averaged.shape[0], -1)
        negated = -self.encoder(flat)
        # The first-order term is always present; higher powers follow in
        # increasing order so the flattened layout is power-major.
        powers = [negated]
        powers.extend(
            torch.pow(negated, exponent) for exponent in range(2, self.order + 1)
        )
        # A single stack replaces growing the tensor with torch.cat on every
        # iteration, which re-copied all earlier powers each time.
        features = torch.stack(powers, dim=1)
        return self.decoder(torch.flatten(features, 1))

class OurModel_app(nn.Module):
    """Polynomial-feature network with a caller-chosen encoder width.

    A bias-free linear encoder maps each sample to ``hidden_size`` values.
    The negated encoding is raised elementwise to the powers ``1..order`` and
    the concatenated powers are combined by a linear decoder.

    Args:
        input_size: Number of features per input sample.
        order: Highest power applied to the negated encoding.
        hidden_size: Width of the encoder output.
        output_size: Number of values produced per sample.
    """

    def __init__(
        self, input_size: int, order: int, hidden_size: int, output_size: int = 1
    ) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.order = order

        self.encoder = nn.Linear(input_size, self.hidden_size, bias=False)
        self.decoder = nn.Linear(order * self.hidden_size, output_size)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the decoder output for ``input``.

        For a 2-D ``input`` of shape ``(batch, input_size)`` the result has
        shape ``(batch, output_size)``.
        """
        negated = -self.encoder(input)
        # The first-order term is always present; higher powers follow in
        # increasing order so the flattened layout is power-major.
        powers = [negated]
        powers.extend(
            torch.pow(negated, exponent) for exponent in range(2, self.order + 1)
        )
        # A single stack replaces growing the tensor with torch.cat on every
        # iteration, which re-copied all earlier powers each time.
        features = torch.stack(powers, dim=1)
        return self.decoder(torch.flatten(features, 1))


class Feed_foward_same_depth(nn.Module):
    """ReLU multilayer perceptron with a configurable number of hidden layers.

    Structure: an input projection followed by ReLU, then
    ``number_of_hidden_layer`` repetitions of (Linear, ReLU) at constant width
    ``hidden_size``, then a final linear projection to ``output_size``.
    """

    def __init__(self, input_size,  hidden_size, number_of_hidden_layer,output_size=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.layer_first = nn.Linear(input_size, self.hidden_size)

        # Even indices hold the Linear layers, odd indices the ReLU modules.
        hidden_stack = []
        for _ in range(number_of_hidden_layer):
            hidden_stack += [nn.Linear(self.hidden_size, self.hidden_size), nn.ReLU()]
        self.layer = nn.ModuleList(hidden_stack)

        self.layer_last = nn.Linear(self.hidden_size, output_size)

    def forward(self, input):
        """Run ``input`` through the first layer, the hidden stack and the last layer."""
        activation = F.relu(self.layer_first(input))
        for module in self.layer:
            activation = module(activation)
        return self.layer_last(activation)

class Feed_foward_same_width(nn.Module):
    """Single-hidden-layer ReLU network.

    The hidden width is ``C(input_size + order - 1, order)``, computed from
    ``input_size`` and ``order``; ``order`` itself is not stored.
    """

    def __init__(self, input_size,  order, output_size=1):
        super().__init__()
        width = math.comb(input_size + order - 1, order)
        self.input_size = input_size
        self.hidden_size = width
        self.output_size = output_size

        self.layer1 = nn.Linear(input_size, width)
        self.layer2 = nn.Linear(width, output_size)

    def forward(self, input):
        """Apply the hidden layer, a ReLU, then the output layer."""
        hidden = self.layer1(input)
        activated = F.relu(hidden)
        return self.layer2(activated)

class Feed_foward_same_width_quadratic_activation(nn.Module):
    """Single-hidden-layer network with an elementwise square activation.

    This class was previously defined twice in a row with identical bodies;
    the second definition simply rebound the name, so a single definition is
    kept.

    Args:
        input_size: Number of features per input sample.
        hidden_size: Width of the hidden layer.
        output_size: Number of values produced per sample.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int = 1) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.layer1 = nn.Linear(input_size, self.hidden_size)
        self.layer2 = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the hidden layer, square each value, then the output layer."""
        out = torch.pow(self.layer1(input), 2)
        out = self.layer2(out)
        return out

class Feed_foward_same_width_quadratic_activation_2(nn.Module):
    """Two-hidden-layer network with an elementwise square after each hidden layer.

    Both hidden layers have width ``hidden_size``; the last layer is a plain
    linear projection to ``output_size``.
    """

    def __init__(self, input_size,  hidden_size, output_size=1):
        super().__init__()
        self.input_size, self.hidden_size, self.output_size = (
            input_size,
            hidden_size,
            output_size,
        )

        self.layer1 = nn.Linear(input_size, hidden_size)
        self.layer2 = nn.Linear(hidden_size, hidden_size)
        self.layer3 = nn.Linear(hidden_size, output_size)

    def forward(self, input):
        """Square the output of each hidden layer, then apply the final layer."""
        first = torch.square(self.layer1(input))
        second = torch.square(self.layer2(first))
        return self.layer3(second)
