import torch
import torch.nn as nn
class Hinge_Loss(nn.Module):
    """Mean hinge loss for binary labels, with optional per-class weighting.

    Labels are mapped to signed targets via ``2 * label - 1`` (so a label of
    0 becomes -1 and a label of 1 becomes +1) and the per-element loss is
    ``relu(1 - input * target)``.

    Args:
        weight: ``None`` for an unweighted loss, or an indexable pair where
            ``weight[0]`` scales elements whose label is below 0.5 and
            ``weight[1]`` scales all other elements.
    """

    def __init__(self, weight = None):
        super().__init__()
        self.activation = nn.ReLU()
        self.weight = weight

    def forward(self, input, label):
        """Return the mean (optionally class-weighted) hinge loss.

        Args:
            input: Tensor of raw scores.
            label: Tensor of labels, multiplied element-wise with ``input``
                after being mapped to signed targets.

        Returns:
            A scalar tensor: the mean over all elements of the loss.
        """
        signed_target = 2 * label - 1
        margin_violation = self.activation(1 - input * signed_target)

        if self.weight is None:
            return torch.mean(margin_violation)

        # Pick the per-element weight from the label: below 0.5 -> weight[0].
        low_weight, high_weight = self.weight[0], self.weight[1]
        per_element_weight = torch.where(label < 0.5, low_weight, high_weight)
        return torch.mean(margin_violation * per_element_weight)

class Classification_Net(nn.Module):
    """Two-layer fully connected network with a clipped-ReLU hidden layer.

    ``fc1`` (with bias) maps ``size1`` inputs to ``size2`` hidden units and
    ``fc2`` (without bias) maps the hidden units to ``size_out`` outputs.
    The hidden non-linearity ``1 - relu(1 - relu(x))`` takes the same values
    as clamping ``x`` to the interval [0, 1].

    Args:
        size1: Number of input features.
        size2: Number of hidden units.
        size_out: Number of outputs (default 2).
    """

    def __init__(self, size1, size2, size_out = 2):
        super(Classification_Net, self).__init__()
        self.size1 = size1
        self.size2 = size2
        self.fc1 = nn.Linear(size1, size2, bias=True)
        self.fc2 = nn.Linear(size2, size_out, bias=False)
        self.activation = nn.ReLU()
        # Set by init_weight_grad(); while None, forward() never adds noise.
        self.effective_dim = None

    def forward(self, input, epoch = None):
        """Run the network.

        Args:
            input: Tensor whose last dimension is ``size1``.
            epoch: Optional epoch index. When it equals 2 and
                ``init_weight_grad`` has been called, zero-mean Gaussian
                noise with standard deviation ``1 / effective_dim**2`` is
                added to the hidden pre-activation.

        Returns:
            Tuple ``(out, feature)``: the ``fc2`` output and the hidden
            activation it was computed from.
        """
        input = self.fc1(input)
        if epoch is not None and epoch == 2 and self.effective_dim is not None:
            noise_std = 1 / (self.effective_dim ** 2)
            # empty_like keeps the noise on the same device and dtype as the
            # activations, so the in-place add also works off-CPU.
            input += torch.empty_like(input).normal_(0.0, noise_std)

        # Same values as clamp(input, 0, 1), kept in ReLU form.
        feature = - self.activation(- self.activation(input) + 1) + 1
        out = self.fc2(feature)

        return out, feature

    def init_weight_grad(self, effective_dim):
        """Re-initialise all parameters with a mirrored Gaussian scheme.

        The second half of the hidden units is made an exact copy of the
        first half (``fc1`` weight and bias), and the second half of the
        ``fc2`` input columns is the negation of the first half, so the
        network output cancels between the two halves right after
        initialisation. Intended for an even ``size2``.

        Args:
            effective_dim: Positive number controlling the standard
                deviations of the draws; also stored on the module and used
                by ``forward`` for the noise scale.

        Raises:
            AssertionError: If the module owns a parameter other than
                ``fc1.weight``, ``fc1.bias`` or ``fc2.weight``.
        """
        self.effective_dim = effective_dim
        with torch.no_grad():
            for name, param in self.named_parameters():
                if "fc1" in name and "weight" in name:
                    half = param.shape[0] // 2
                    param.normal_(0.0, 1 / effective_dim ** 0.5)
                    param[half:] = param[:half].clone()
                elif "fc1" in name and "bias" in name:
                    half = param.shape[0] // 2
                    param.normal_(0.0, 1 / (32 * effective_dim ** 2))
                    param[half:] = param[:half].clone()
                elif "fc2" in name and "weight" in name:
                    half = param.shape[1] // 2
                    # NOTE(review): the original comment called the 6 under
                    # the square root "\gamma" and also mentioned an input
                    # std of 5 that does not appear in this expression.
                    param.normal_(0.0, 1 * 6 ** 0.5 / effective_dim)
                    param[:, half:] = -param[:, :half].clone()
                else:
                    # Explicit raise: a bare assert is stripped under -O.
                    raise AssertionError(
                        "unexpected parameter in Classification_Net: " + name
                    )

# See Equation (8) of "Mathematical Models of Overparameterized Neural Networks".
class NTK(nn.Module):
    """Two-branch network built from a reference and a "train" copy of two layers.

    ``fc1``/``fc2`` act as reference layers and ``fc1_train``/``fc2_train``
    as the layers that ``init_weight_grad`` leaves trainable. The output is
    the sum of two terms:

    * ``fc2_train(relu(fc1(x)))``
    * ``fc2(fc1_train(x) * 1[fc1(x) > 0])``, i.e. the first-layer "train"
      output gated by the sign pattern of the reference pre-activation.

    Args:
        size1: Number of input features.
        size2: Number of hidden units.
        size_out: Number of outputs (default 1).
    """

    def __init__(self, size1, size2, size_out = 1):
        super(NTK, self).__init__()
        self.fc1_train = nn.Linear(size1, size2, bias=True)
        self.fc2_train = nn.Linear(size2, size_out, bias=True)
        self.fc1 = nn.Linear(size1, size2, bias=True)
        self.fc2 = nn.Linear(size2, size_out, bias=True)
        self.activation = nn.ReLU()

    def forward(self, input, epoch = None):
        """Evaluate both branches and return their sum.

        Args:
            input: Tensor whose last dimension is ``size1``.
            epoch: Accepted for signature parity with other models in this
                file; not used here.

        Returns:
            Tuple ``(out, feature)`` where ``feature`` is ``relu(fc1(input))``.
        """
        # The reference pre-activation feeds both branches; compute it once.
        pre_activation = self.fc1(input)
        feature = self.activation(pre_activation)
        out = self.fc2_train(feature)

        # 0/1 gate in the activations' own dtype, so the product below keeps
        # the dtype of the layer weights regardless of the global default.
        gate = (pre_activation > 0).to(pre_activation.dtype)
        out2 = self.fc2(self.fc1_train(input) * gate)
        return out + out2, feature

    def init_weight_grad(self,):
        """Initialise all layers and freeze the reference ones.

        Applies the module-level ``weights_init`` to every submodule, then
        for each parameter:

        * names containing ``"train"`` are made trainable; ``fc1_train``'s
          weight is re-drawn with Xavier-normal at gain ``1e-6`` and
          ``fc2_train``'s weight is set to a copy of ``fc2``'s weight;
        * all other parameters get ``requires_grad = False``.
        """
        self.apply(weights_init)
        for name, param in self.named_parameters():
            is_trainable = "train" in name
            param.requires_grad = is_trainable
            if not is_trainable:
                continue
            if "fc1" in name and "weight" in name:
                torch.nn.init.xavier_normal_(param.data, gain = 0.000001)
            elif "fc2" in name and "weight" in name:
                param.data = self.fc2.weight.data.detach().clone()


def weights_init(m):
    """Initialise one module in place; meant for ``Module.apply``.

    The layer kind is chosen by substring match on the module's class name:

    * ``Conv``: weight drawn from N(0, 0.02); bias left untouched.
    * ``BatchNorm``: weight drawn from N(1, 0.02) and bias set to 0, each
      only if present (both are ``None`` for ``affine=False`` layers).
    * ``Linear``: weight Xavier-normal with gain 0.001; bias set to 0 if
      present.

    Modules whose class name matches none of these are left unchanged.

    Args:
        m: The module to initialise.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        # Non-affine norm layers carry no learnable scale/shift.
        if m.weight is not None:
            m.weight.data.normal_(1.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif 'Linear' in classname:
        torch.nn.init.xavier_normal_(m.weight.data, gain = 0.001)
        if m.bias is not None:
            m.bias.data.fill_(0)