# nn_utils.py

import torch
import numpy as np

class Feedforward(torch.nn.Module):
    """Two-layer fully connected network with no bias terms.

    Computes ``fc2(activation(fc1(x)))`` where both ``torch.nn.Linear``
    layers are created with ``bias=False``.

    Args:
        input_size: number of input features (last dimension of ``x``).
        hidden_size: width of the hidden layer.
        output_size: number of outputs.
        activation: name of the hidden nonlinearity; only ``'ReLU'`` is
            supported.

    Raises:
        NotImplementedError: if ``activation`` is anything other than
            ``'ReLU'``.
    """

    def __init__(self, input_size, hidden_size, output_size, activation='ReLU'):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        # fc1 is built before the activation check and fc2 after it; this
        # order fixes how the RNG is consumed, so seeded initialisations
        # stay reproducible.
        self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size, bias=False)
        if activation == 'ReLU':
            self.activation = torch.nn.ReLU()
        else:
            # NotImplementedError subclasses Exception, so existing
            # ``except Exception`` handlers still catch it.
            raise NotImplementedError(
                'Activation {!r} is not implemented; only ReLU is supported'.format(activation))
        self.fc2 = torch.nn.Linear(self.hidden_size, self.output_size, bias=False)

    def forward(self, x):
        """Return ``fc2(activation(fc1(x)))``."""
        hidden = self.activation(self.fc1(x))
        return self.fc2(hidden)


def train_ReLU_nn(X_t, y_t, m, break_points, num_trial=10, iter_num=10000, lr=1e0,
                  verbose=False, loss_tol=0.5):
    """Train a bias-free two-layer ReLU network by full-batch gradient descent.

    Each trial builds a fresh ``Feedforward`` network with one output and
    trains it for ``iter_num`` steps on the summed loss
    ``sum(log(1 + exp(-y_t * net(X_t))))``. Trials are repeated from a new
    random initialisation, up to ``num_trial`` times, until the loss of the
    last iteration falls below ``loss_tol``.

    Args:
        X_t: input tensor. The network's input size is taken from its last
            dimension (presumably shape (N, d) -- confirm with callers).
        y_t: label tensor, multiplied elementwise with the (N, 1) network
            output. Presumably labels in {-1, +1} with shape (N, 1).
            NOTE(review): a 1-D y_t of shape (N,) would broadcast to (N, N).
        m: hidden-layer width.
        break_points: 1-based step counts, expected sorted ascending without
            duplicates (an entry that is not reached in order stops further
            recording). For each entry b, the weights are recorded at the
            start of 0-based iteration b - 1, i.e. after b - 1 updates.
            Entries greater than iter_num are never recorded.
        num_trial: maximum number of random restarts; must be >= 1.
        iter_num: number of gradient-descent steps per trial.
        lr: learning rate of the momentum-free, decay-free SGD optimizer.
        verbose: if True, print the last-iteration loss of every trial.
        loss_tol: a trial whose last-iteration loss is below this value
            ends the restart loop.

    Returns:
        ``(loss_history, theta_list)`` from the last trial that was run.
        ``loss_history`` is a float array of length ``iter_num`` holding the
        loss evaluated at each iteration, before that iteration's update.
        ``theta_list`` holds one ``[W1, w2]`` pair per recorded break point.
        ``W1`` is the transposed first-layer weight, shape (input_size, m).
        ``w2`` is the flattened second-layer weight, shape (m,).

    Raises:
        ValueError: if ``num_trial`` < 1.
    """
    if num_trial < 1:
        raise ValueError('num_trial must be at least 1, got {}'.format(num_trial))

    for trial in range(num_trial):
        net = Feedforward(X_t.shape[-1], m, 1)
        # Full-batch SGD with no momentum or weight decay is plain GD.
        optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0, weight_decay=0)
        loss_history = np.zeros(iter_num)
        theta_list = []
        j = 0  # index of the next break point to record
        for i in range(iter_num):
            # The bounds check prevents IndexError once every break point
            # has been recorded.
            if j < len(break_points) and i == break_points[j] - 1:
                theta_list.append([
                    net.fc1.weight.detach().numpy().T.copy(),
                    net.fc2.weight.detach().numpy().reshape(-1).copy(),
                ])
                j += 1
            optimizer.zero_grad()
            margins = y_t * net(X_t)
            # softplus(-z) == log(1 + exp(-z)), but softplus does not
            # overflow to inf (and NaN gradients) for large negative margins.
            loss = torch.nn.functional.softplus(-margins).sum()
            loss.backward()
            optimizer.step()
            loss_history[i] = loss.item()

        # With iter_num == 0 no loss was computed; treat the trial as unconverged.
        final_loss = loss_history[-1] if iter_num > 0 else float('inf')
        if verbose:
            print('trial: {} loss: {}'.format(trial, final_loss))
        if final_loss < loss_tol:
            break

    return loss_history, theta_list