import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


class ConvNet(nn.Module):
    """Small CNN for single-channel images that returns 10 raw scores.

    Three 3x3 convolutions (the last one with stride 2), each followed by a
    ReLU, feed a 2x2 max-pool and two fully connected layers.  ``fc1``
    expects 1568 = 32 * 7 * 7 flattened features, which is what e.g. a
    28x28 input produces.

    Args:
        args: object whose ``initialization`` attribute is forwarded to
            ``init_weights`` to select the weight-initialization scheme.
    """

    def __init__(self, args):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(in_features=1568, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)
        init_weights(self, args.initialization)

    def forward(self, x):
        """Return the un-normalised class scores for the batch ``x``."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))
        pooled = F.max_pool2d(features, 2)
        # Keep the batch dimension, collapse everything else for the MLP head.
        hidden = F.relu(self.fc1(torch.flatten(pooled, 1)))
        return self.fc2(hidden)

class ConvNet3D(nn.Module):
    """Three-input-channel counterpart of the small CNN; returns 10 raw scores.

    Despite the name, the layers are 2-D convolutions: the "3" refers to the
    three input channels of ``conv1``.  The stack is conv-ReLU three times
    (last convolution with stride 2), a 2x2 max-pool, then two linear layers.
    ``fc1`` expects 1568 = 32 * 7 * 7 flattened features, which is what e.g.
    a 28x28 input produces.

    Args:
        args: object whose ``initialization`` attribute is forwarded to
            ``init_weights`` to select the weight-initialization scheme.
    """

    def __init__(self, args):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(in_features=1568, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)
        init_weights(self, args.initialization)

    def forward(self, x):
        """Return the un-normalised class scores for the batch ``x``."""
        activ = F.relu(self.conv1(x))
        activ = F.relu(self.conv2(activ))
        activ = F.relu(self.conv3(activ))
        # Pool, then flatten everything except the batch dimension.
        flat = torch.flatten(F.max_pool2d(activ, 2), 1)
        return self.fc2(F.relu(self.fc1(flat)))

def custom_gaussian_init(m):
    """Re-initialise an ``nn.Linear`` layer in place with Gaussian values.

    Weights are drawn from a zero-mean normal distribution with standard
    deviation ``1 / sqrt(in_features)``; the bias, when the layer has one,
    is drawn from a standard normal distribution.  Modules that are not
    ``nn.Linear`` are left untouched, so the function is safe to call on
    every sub-module of a network.

    Args:
        m: the module to (possibly) re-initialise.
    """
    if not isinstance(m, nn.Linear):
        return

    fan_in = m.weight.shape[1]
    # ``nn.init`` functions already run under ``torch.no_grad()``, so no
    # extra guard is needed here.
    if fan_in > 0:  # an empty weight matrix has nothing to sample
        # A plain Python float is the documented type for ``std``.
        nn.init.normal_(m.weight, mean=0.0, std=fan_in ** -0.5)
    # Layers built with ``bias=False`` expose ``bias`` as None.
    if m.bias is not None:
        nn.init.normal_(m.bias, mean=0.0, std=1.0)

def init_weights(net, method):
    """Initialise every ``nn.Conv2d`` / ``nn.Linear`` inside ``net`` in place.

    Args:
        net: module whose sub-modules are visited via ``net.modules()``.
        method: name of the scheme applied to the weights:
            ``""`` keeps the existing parameters untouched,
            ``"uniform"`` uses ``init.uniform_`` with its defaults,
            ``"gaussian"`` uses ``init.normal_`` with its defaults,
            ``"kaiming_uniform"`` / ``"kaiming_normal"`` use the Kaiming
            initialisers with ``nonlinearity='relu'``.
            For every scheme except ``""`` existing biases are set to 0.

    Raises:
        NotImplementedError: if ``method`` is none of the names above.
    """
    if method == "":
        return

    schemes = (
        ("uniform", init.uniform_),
        ("gaussian", init.normal_),
        ("kaiming_uniform", lambda w: init.kaiming_uniform_(w, nonlinearity='relu')),
        ("kaiming_normal", lambda w: init.kaiming_normal_(w, nonlinearity='relu')),
    )
    fill_weight = next((fn for name, fn in schemes if method == name), None)
    if fill_weight is None:
        raise NotImplementedError("Initialization method [{}] is not implemented".format(method))

    for layer in net.modules():
        if not isinstance(layer, (nn.Conv2d, nn.Linear)):
            continue
        fill_weight(layer.weight)
        if layer.bias is not None:
            init.constant_(layer.bias, 0)

class FcNet(nn.Module):
    """Four-layer MLP mapping 28*28 flattened inputs to 10 raw scores.

    Layer widths are 784 -> 1024 -> 256 -> 128 -> 10, with a ReLU after each
    of the first three layers and no activation on the output.

    Args:
        args: object whose ``initialization`` attribute is forwarded to
            ``init_weights`` to select the weight-initialization scheme.
    """

    def __init__(self, args):
        super().__init__()
        self.fc1 = nn.Linear(in_features=28 * 28, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=256)
        self.fc3 = nn.Linear(in_features=256, out_features=128)
        self.fc4 = nn.Linear(in_features=128, out_features=10)
        init_weights(self, args.initialization)

    def forward(self, x):
        """Flatten ``x`` per sample and return the un-normalised scores."""
        hidden = torch.flatten(x, 1)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return self.fc4(hidden)

class FcNet3D(nn.Module):
    """Four-layer MLP mapping 28*28*3 flattened inputs to 100 raw scores.

    Layer widths are 2352 -> 1024 -> 256 -> 128 -> 100, with a ReLU after
    each of the first three layers and no activation on the output.

    Args:
        args: object whose ``initialization`` attribute is forwarded to
            ``init_weights`` to select the weight-initialization scheme.
    """

    def __init__(self, args):
        super().__init__()
        self.fc1 = nn.Linear(in_features=28 * 28 * 3, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=256)
        self.fc3 = nn.Linear(in_features=256, out_features=128)
        self.fc4 = nn.Linear(in_features=128, out_features=100)
        init_weights(self, args.initialization)

    def forward(self, x):
        """Flatten ``x`` per sample and return the un-normalised scores."""
        hidden = torch.flatten(x, 1)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        # Raw scores are returned on purpose; no log-softmax is applied here.
        return self.fc4(hidden)

def Fc_change_head(model, num_cls):
    """Replace ``model.fc4`` with a fresh linear head of ``num_cls`` outputs.

    Nothing happens when ``num_cls`` is not greater than 1 or when the
    current head already has ``num_cls`` outputs.  Otherwise a newly
    initialised ``nn.Linear`` with the same input width is installed and a
    one-line notice is printed.

    Args:
        model: module exposing an ``nn.Linear`` as attribute ``fc4``.
        num_cls: desired number of output classes.

    Returns:
        The same ``model`` object (modified in place when the head changed).
    """
    if num_cls > 1:
        old_head = model.fc4
        in_features, out_features = old_head.in_features, old_head.out_features
        if out_features != num_cls:
            new_head = nn.Linear(in_features, num_cls)
            # The layer is created with default settings (so the random
            # initialisation is unchanged) and then moved to the device and
            # dtype of the head it replaces; otherwise a model that was
            # already moved or cast would end up with a mismatched head.
            model.fc4 = new_head.to(device=old_head.weight.device, dtype=old_head.weight.dtype)
            print(f"fc4 changed {in_features}x{out_features} -> {in_features}x{num_cls}")
    return model

def Conv_change_head(model, num_cls):
    """Replace ``model.fc2`` with a fresh linear head of ``num_cls`` outputs.

    Nothing happens when ``num_cls`` is not greater than 1 or when the
    current head already has ``num_cls`` outputs.  Otherwise a newly
    initialised ``nn.Linear`` with the same input width is installed and a
    one-line notice is printed.

    Args:
        model: module exposing an ``nn.Linear`` as attribute ``fc2``.
        num_cls: desired number of output classes.

    Returns:
        The same ``model`` object (modified in place when the head changed).
    """
    if num_cls > 1:
        old_head = model.fc2
        in_features, out_features = old_head.in_features, old_head.out_features
        if out_features != num_cls:
            new_head = nn.Linear(in_features, num_cls)
            # The layer is created with default settings (so the random
            # initialisation is unchanged) and then moved to the device and
            # dtype of the head it replaces; otherwise a model that was
            # already moved or cast would end up with a mismatched head.
            model.fc2 = new_head.to(device=old_head.weight.device, dtype=old_head.weight.dtype)
            print(f"fc2 changed {in_features}x{out_features} -> {in_features}x{num_cls}")
    return model

class SubFcNet(nn.Module):
    """Standalone copy of the first ``num_layer`` linear layers of a model.

    The layers are named ``fc1`` .. ``fc4`` inside ``self.model`` and receive
    cloned values of the first ``2 * num_layer`` parameters of ``model``
    (assumed to be ordered weight, bias, weight, bias, ... as in an
    FcNet-style network -- confirm for other model types).

    The forward pass applies a ReLU after every layer, except that the output
    of the complete four-layer network is returned as raw scores.

    Args:
        model: source module whose leading parameters are copied.
        num_layer: number of leading layers to keep; values above 4 keep the
            whole four-layer network.

    Raises:
        ValueError: if ``num_layer`` is smaller than 1.
    """

    # (name, in_features, out_features) of the layers that can be copied.
    _LAYER_SPECS = (
        ("fc1", 28 * 28, 1024),
        ("fc2", 1024, 256),
        ("fc3", 256, 128),
        ("fc4", 128, 10),
    )

    def __init__(self, model, num_layer):
        super().__init__()
        # Zero layers cannot be evaluated and negative values would slice the
        # layer table from the wrong end, so reject both up front.
        if num_layer < 1:
            raise ValueError(f"num_layer must be at least 1, got {num_layer}")
        self.num_layer = num_layer

        self.model = nn.Sequential()
        for name, in_features, out_features in self._LAYER_SPECS[:num_layer]:
            self.model.add_module(name, nn.Linear(in_features, out_features))

        source_params = list(model.parameters())[:num_layer * 2]
        for new_param, src_param in zip(self.model.parameters(), source_params):
            new_param.data = src_param.data.clone()

        # The copied tensors may differ in shape from the table above (e.g.
        # when the source head was replaced); keep the layer metadata in sync
        # with the weights that are actually used.
        for layer in self.model:
            if layer.weight.dim() == 2:
                layer.out_features, layer.in_features = layer.weight.shape

    def forward(self, x):
        """Run the kept layers on ``x`` (flattened per sample)."""
        x = torch.flatten(x, 1)

        for layer in self.model[:-1]:
            x = F.relu(layer(x))

        x = self.model[-1](x)
        # Decide on the layers actually built, so that num_layer > 4 (which
        # still yields the full network) does not rectify the final scores.
        if len(self.model) == len(self._LAYER_SPECS):
            return x
        return F.relu(x)

class SubFcNet3D(nn.Module):
    """Standalone copy of the first ``num_layer`` linear layers of a model.

    Variant for 28*28*3 flattened inputs.  The layers are named ``fc1`` ..
    ``fc4`` inside ``self.model`` and receive cloned values of the first
    ``2 * num_layer`` parameters of ``model`` (assumed to be ordered weight,
    bias, weight, bias, ... as in an FcNet3D-style network -- confirm for
    other model types).

    The forward pass applies a ReLU after every layer, except that the output
    of the complete four-layer network is returned as raw scores.

    Args:
        model: source module whose leading parameters are copied.
        num_layer: number of leading layers to keep; values above 4 keep the
            whole four-layer network.

    Raises:
        ValueError: if ``num_layer`` is smaller than 1.
    """

    # (name, in_features, out_features) of the layers that can be copied.
    _LAYER_SPECS = (
        ("fc1", 28 * 28 * 3, 1024),
        ("fc2", 1024, 256),
        ("fc3", 256, 128),
        ("fc4", 128, 10),
    )

    def __init__(self, model, num_layer):
        super().__init__()
        # Zero layers cannot be evaluated and negative values would slice the
        # layer table from the wrong end, so reject both up front.
        if num_layer < 1:
            raise ValueError(f"num_layer must be at least 1, got {num_layer}")
        self.num_layer = num_layer

        self.model = nn.Sequential()
        for name, in_features, out_features in self._LAYER_SPECS[:num_layer]:
            self.model.add_module(name, nn.Linear(in_features, out_features))

        source_params = list(model.parameters())[:num_layer * 2]
        for new_param, src_param in zip(self.model.parameters(), source_params):
            new_param.data = src_param.data.clone()

        # The copied tensors may differ in shape from the table above (e.g.
        # a source head with a different number of outputs); keep the layer
        # metadata in sync with the weights that are actually used.
        for layer in self.model:
            if layer.weight.dim() == 2:
                layer.out_features, layer.in_features = layer.weight.shape

    def forward(self, x):
        """Run the kept layers on ``x`` (flattened per sample)."""
        x = torch.flatten(x, 1)

        for layer in self.model[:-1]:
            x = F.relu(layer(x))

        x = self.model[-1](x)
        # Decide on the layers actually built, so that num_layer > 4 (which
        # still yields the full network) does not rectify the final scores.
        if len(self.model) == len(self._LAYER_SPECS):
            return x
        return F.relu(x)