import torch.nn as nn
import torch
import math

class CustomSharedHyper(nn.Module):
    """Shared feature extractor chosen by an integer option.

    Attributes:
        net: the selected encoder module (``option_1``, ``option_3``,
            ``option_4`` or ``option_5``).
        cnn_dim: flattened feature size the selected encoder is expected to
            produce. For option 4 this (64 * 2 * 2) assumes an input size that
            yields a 2x2 map, e.g. 28x28 or 32x32.

    Args:
        shared_choice: which encoder to build; one of 1, 3, 4, 5.
        in_channels: number of input channels; only forwarded to option 4.

    Raises:
        NotImplementedError: if ``shared_choice`` is not a supported option.
    """

    def __init__(self, shared_choice, in_channels=1):
        super().__init__()
        # Ordered (choice, factory, output size) table, checked with ``==`` in
        # the same order as the supported options. Factories are lambdas so
        # only the matching encoder is ever constructed.
        candidates = (
            (1, lambda: option_1(), 16),
            (3, lambda: option_3(), 32),
            (4, lambda: option_4(in_channel=in_channels), 64 * 2 * 2),
            (5, lambda: option_5(), 512),
        )
        for choice, build, feature_dim in candidates:
            if shared_choice == choice:
                self.net = build()
                self.cnn_dim = feature_dim
                break
        else:
            raise NotImplementedError("choose shared hyper")

    def forward(self, x):
        """Run ``x`` through the selected encoder and return its output."""
        return self.net(x)


class individual_head(nn.Module):
    """Output head: a linear layer, optionally in low-rank ("lora") form.

    With ``lora=False`` the input is flattened per sample and passed through a
    single ``nn.Linear(previous_dim, final_dim)``.

    With ``lora=True`` two linear layers map the flattened input to factors of
    shape ``(d, intermediate_dim)`` and ``(intermediate_dim, e)``, where
    ``d = ceil(sqrt(final_dim))`` and ``e = ceil(final_dim / d)``. Their
    per-sample product is a ``d x e`` matrix, which is flattened and truncated
    to its first ``final_dim`` entries (``d * e`` may exceed ``final_dim``).

    Args:
        previous_dim: number of features per sample after flattening all
            non-batch dimensions of the input.
        final_dim: number of output features per sample.
        lora: use the factorized head instead of a single linear layer.
        intermediate_dim: inner dimension of the factorization (lora only).
    """

    def __init__(self, previous_dim, final_dim, lora=False, intermediate_dim=1):
        super().__init__()
        self.lora = lora
        self.final_dim = final_dim
        if lora:
            self.d = math.ceil(math.sqrt(final_dim))
            self.e = math.ceil(final_dim / self.d)
            self.fca = nn.Linear(previous_dim, self.d * intermediate_dim)
            self.fcb = nn.Linear(previous_dim, self.e * intermediate_dim)
            # Surplus entries in the d x e product. Kept as an attribute for
            # compatibility; forward() truncates to final_dim directly.
            self.extra = self.d * self.e - final_dim
        else:
            self.fc1 = nn.Linear(previous_dim, final_dim)

    def forward(self, x):
        """Map batch-first ``x`` to a ``(batch, final_dim)`` tensor."""
        batch = x.size(0)
        # reshape (not view) so non-contiguous inputs, e.g. after a transpose
        # or permute, are accepted; it only copies when a view is impossible.
        x = x.reshape(batch, -1)
        if not self.lora:
            return self.fc1(x)
        left = self.fca(x).reshape(batch, self.d, -1)   # (batch, d, intermediate_dim)
        right = self.fcb(x).reshape(batch, -1, self.e)  # (batch, intermediate_dim, e)
        out = torch.matmul(left, right).reshape(batch, self.d * self.e)
        # Drop the surplus entries so exactly final_dim features remain.
        return out[:, : self.final_dim]

class option_1(nn.Module):
    """Small conv encoder producing a 16-channel 1x1 feature map.

    Layout: 3x3 stride-2 conv (16 ch) -> ReLU -> 3x3 stride-2 conv (16 ch)
    -> ReLU -> 2x2 conv (16 ch) -> global average pool. For a batched
    ``(N, C, H, W)`` input large enough to pass the convolutions, the output is
    ``(N, 16, 1, 1)``.

    Args:
        in_channel: number of input channels. Defaults to 1, the value that was
            previously hard-coded; named like ``option_4``'s argument.
    """

    def __init__(self, in_channel=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channel, 16, kernel_size=3, stride=2)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=2)
        self.conv3 = nn.Conv2d(16, 16, kernel_size=2)
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.relu1(self.conv2(x))
        x = self.conv3(x)
        # No ReLU after the last conv: negative output values are allowed.
        x = self.pool(x)
        return x

class option_3(nn.Module):
    """Small conv encoder producing a 32-channel 1x1 feature map.

    Layout: 3x3 stride-2 conv (16 ch) -> ReLU -> 3x3 stride-2 conv (32 ch)
    -> ReLU -> 2x2 conv (32 ch) -> global average pool. For a batched
    ``(N, C, H, W)`` input large enough to pass the convolutions, the output is
    ``(N, 32, 1, 1)``.

    Args:
        in_channel: number of input channels. Defaults to 1, the value that was
            previously hard-coded; named like ``option_4``'s argument.
    """

    def __init__(self, in_channel=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channel, 16, kernel_size=3, stride=2)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=2)
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.relu1(self.conv2(x))
        x = self.conv3(x)
        # No ReLU after the last conv: negative output values are allowed.
        x = self.pool(x)
        return x

class option_4(nn.Module):
    """Conv / max-pool encoder returning a 64-channel spatial feature map.

    Layout: 3x3 conv (32 ch) -> ReLU -> 2x2 max-pool -> 3x3 conv (64 ch)
    -> ReLU -> 2x2 max-pool -> 3x3 stride-2 conv (64 ch). No padding and no
    global pooling are used, so the output size depends on the input size.
    A 28x28 or 32x32 input gives a 64x2x2 map. The final conv has no
    activation.

    Args:
        in_channel: number of input channels.
    """

    def __init__(self, in_channel):
        super().__init__()
        # Submodules are registered in a fixed order (conv1, pool1, relu1,
        # conv2, pool2, conv3) so parameter initialisation order and
        # state_dict keys stay stable. BatchNorm after conv1/conv2 is
        # currently disabled.
        self.conv1 = nn.Conv2d(in_channel, 32, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=2)

    def forward(self, x):
        # The ReLU module is stateless, so the same instance follows both of
        # the first two convolutions.
        stages = (
            self.conv1, self.relu1, self.pool1,
            self.conv2, self.relu1, self.pool2,
            self.conv3,
        )
        for stage in stages:
            x = stage(x)
        return x

class option_5(nn.Module):
    """Deeper conv encoder producing a 512-channel 1x1 feature map.

    Layout: 4x4 conv (32 ch) -> ReLU -> four 3x3 stride-2 convs
    (64, 128, 256, 512 ch; ReLU after the first three) -> 4x4 conv (512 ch)
    -> global average pool. For 28x28 and 32x32 inputs the last conv already
    yields a 1x1 map; the pool makes the output ``(N, 512, 1, 1)`` for larger
    inputs too.

    Args:
        in_channel: number of input channels. Defaults to 1, the value that was
            previously hard-coded; named like ``option_4``'s argument.
    """

    def __init__(self, in_channel=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channel, 32, kernel_size=4, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.conv5 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
        self.conv6 = nn.Conv2d(512, 512, kernel_size=4, stride=1, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))
        x = self.conv5(x)
        # NOTE(review): unlike the earlier stages there is no ReLU between
        # conv5 and conv6; confirm this is intentional.
        x = self.conv6(x)
        # self.pool was declared but previously never applied, so inputs larger
        # than ~32x32 left a >1x1 map (e.g. 3x3 for 64x64) and more than 512
        # features after flattening. Averaging a 1x1 map is an exact identity,
        # so outputs for 28x28 / 32x32 inputs are unchanged.
        x = self.pool(x)
        return x
