# MLP base model for Pareto MTL
import torch
import torch.nn as nn
import numpy as np


class MLP(nn.Module):
    """Three-layer, bias-free perceptron with a sigmoid output.

    Layer widths are ``input_size -> input_size -> input_size // 2 -> 1``,
    with ReLU between the linear layers and a sigmoid on the output, so every
    score lies in (0, 1).

    Args:
        input_size: Number of input features per sample.
    """

    def __init__(self, input_size):
        super().__init__()
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()
        self.linear_1 = torch.nn.Linear(input_size, input_size, bias=False)
        self.linear_2 = torch.nn.Linear(input_size, input_size // 2, bias=False)
        self.linear_3 = torch.nn.Linear(input_size // 2, 1, bias=False)

    def forward(self, x):
        """Return sigmoid scores for ``x`` with all size-1 dimensions squeezed.

        A ``(batch, input_size)`` input yields a ``(batch,)`` tensor; note that
        a single-sample batch is squeezed all the way down to a 0-d tensor.
        """
        x = self.relu(self.linear_1(x))
        x = self.relu(self.linear_2(x))
        x = self.sigmoid(self.linear_3(x))
        return torch.squeeze(x)

    def _to_tensor(self, data):
        """Convert array-like ``data`` to a float32 tensor on this model's device."""
        device = next(self.parameters()).device
        return torch.as_tensor(data, dtype=torch.float32, device=device)

    def fit(self, x, y, num_epoch=1000, lr=0.01, lamb=0, tol=1e-4, batch_size=20, verbose=True):
        """Train the network with mini-batch Adam on a mean-squared-error loss.

        Every epoch reshuffles the samples (using the global NumPy RNG) and walks
        over them in batches of ``batch_size``. Samples left over after the last
        full batch are skipped for that epoch. If there are fewer samples than
        ``batch_size``, the whole dataset is used as a single batch.

        Early stopping compares each epoch's summed loss with the *previous*
        epoch's. Training stops on the seventh epoch whose loss fails to improve
        by more than ``tol``. The counter is never reset, so those epochs need
        not be consecutive.

        Args:
            x: Features indexable by an integer index array (e.g. a NumPy array
                of shape ``(n, input_size)``).
            y: Targets, one per sample, indexable like ``x``. Any per-sample shape
                that flattens to one value works, e.g. ``(n,)`` or ``(n, 1)``.
            num_epoch: Maximum number of passes over the data.
            lr: Adam learning rate.
            lamb: L2 penalty, passed to Adam as ``weight_decay``.
            tol: Minimum decrease in epoch loss that counts as an improvement.
            batch_size: Mini-batch size.
            verbose: If True, print the epoch loss every 10 epochs. The
                early-stop message and the non-convergence warning are
                printed regardless.

        Raises:
            ValueError: If ``x`` contains no samples.
        """
        num_sample = len(x)
        if num_sample == 0:
            raise ValueError("fit() requires at least one sample")

        optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=lamb)
        criterion = nn.MSELoss()
        last_loss = float("inf")

        # Drop the trailing partial batch, but never train on zero batches:
        # with fewer samples than batch_size, use everything as one batch.
        total_batch = max(1, num_sample // batch_size)
        early_stop = 0
        for epoch in range(num_epoch):
            all_idx = np.arange(num_sample)
            np.random.shuffle(all_idx)
            epoch_loss = 0.0
            for idx in range(total_batch):
                # mini-batch training
                selected_idx = all_idx[batch_size * idx:(idx + 1) * batch_size]

                sub_x = self._to_tensor(x[selected_idx])
                # Flatten both sides so a (b, 1) target or a squeezed 0-d
                # single-sample prediction cannot broadcast into a (b, b) loss.
                sub_y = self._to_tensor(y[selected_idx]).reshape(-1)
                pred = self.forward(sub_x).reshape(-1)

                loss = criterion(pred, sub_y)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # .item() works for CPU and CUDA tensors alike.
                epoch_loss += loss.item()

            # NOTE: the log label says "xent", but the value is the summed MSE.
            if epoch_loss > last_loss - tol:
                if early_stop > 5:
                    print("[IPS_model] epoch:{}, xent:{}".format(epoch, epoch_loss))
                    break
                early_stop += 1

            last_loss = epoch_loss

            if epoch % 10 == 0 and verbose:
                print("[IPS_model] epoch:{}, xent:{}".format(epoch, epoch_loss))

            if epoch == num_epoch - 1:
                print("[Warning] Reach preset epochs, it seems does not converge.")

    def predict(self, x):
        """Return the model's scores for ``x`` as a NumPy array.

        The forward pass runs under ``torch.no_grad()`` so no autograd graph
        is built.
        """
        with torch.no_grad():
            out = self.forward(self._to_tensor(x))
        return out.detach().cpu().numpy()
        
    
   