# # Smooth Monotonic Networks: Experiments on fully monotonic functions
# The results are stored in files, which are read by ``MonotonicNNPaperEvaluate.ipynb``.
# ## General definitions
# Among others, we compare against XGBoost, which can be installed via `pip install xgboost`, and 
# the Hierarchical Lattice Layer, which can be installed via 
# `pip install pmlayer`.

import numpy as np
import random

import torch 
import torch.nn as nn

from sklearn.metrics import mean_squared_error as mse
from sklearn.metrics import r2_score as r2
from sklearn.isotonic import IsotonicRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import PolynomialFeatures

import matplotlib.pyplot as plt
from tqdm.notebook import tnrange

from xgboost import XGBRegressor

from pmlayer.torch.layers import HLattice

from MonotonicNN import SmoothMonotonicNN, MonotonicNN, MonotonicNNAlt
from MonotonicNNPaperUtils import Progress, total_params, fit_torch

from monotonenorm import GroupSort, direct_norm, SigmaNet

prefix = "iclr-"  # prepended to the names of the result files written below


# ## Univariate experiments
# Section 4.1 in the manuscript.

# Shared settings for the univariate experiments.
T = 21  # trials per setting; odd so that a "median trial" exists
# Lattice sizes: ls is the number of lattice points (k in the original paper).
ls, ls_small = 75, 35
K = 6  # number of SMM groups; H_k = K is used throughout
N_train, N_test = 100, 1000  # training / test set sizes
sigma = 0.01  # std of the Gaussian label noise (feel free to vary)


def generate1D(function_name, sigma=0., random=False, xrange=1., N=50):
    if random:
        x = np.random.rand(N) * xrange
        x = np.sort(x, axis=0)
    else:
        xstep = xrange / N
        x = np.arange(0, xrange, xstep)
    match function_name:
        case 'sigmoid10':
            y = 1. /(1. + np.exp(-(x-xrange/2.) * 10.))
        case 'sq':
            y = x**2
        case 'sqrt':
            y = np.sqrt(x)
    y = y + sigma*np.random.normal(0, 1., N)
    return x.reshape(N, 1), y


# Hyperparameter sweep configuration; overrides the T = 21 used above.
T = 11  # number of trials, odd number for having a "median trial"
tasks = ['sq', 'sqrt', 'sigmoid10']
K_values = (2, 4, 6, 8)
beta_values = (-3., -2., -1., 0., 1.)

N_tasks, N_K, N_beta = len(tasks), len(K_values), len(beta_values)

# Result arrays, indexed as [task, K, beta, trial].
MSE_train, MSE_test, MSE_clip = (
    np.zeros((N_tasks, N_K, N_beta, T)) for _ in range(3)
)
# Model parameter count, indexed by K.
no_params = np.zeros(N_K)

# Hyperparameter sweep for SmoothMonotonicNN: every (K, beta) pair is trained
# on every task for T trials, recording train MSE, test MSE, and test MSE
# after clipping predictions to [0, 1] (the range of all targets on [0, 1)).
# NOTE(review): this loop rebinds the module-level K (set to 6 earlier); after
# the sweep K == K_values[-1].
for K_id, K in enumerate(K_values):
    for beta_id, beta in enumerate(beta_values):
        for trial in range(T):
            for task_id, task in enumerate(tasks):
                print("K:", K, "beta:", beta, task, trial)
                # The seed depends only on (task, trial), so every (K, beta)
                # setting sees the same data and the same initial RNG state.
                seed = task_id + trial*N_tasks
                random.seed(seed)
                np.random.seed(seed)
                torch.manual_seed(seed)

                x_train, y_train = generate1D(task, sigma=sigma, random=True, N=N_train)
                x_test, y_test = generate1D(task, sigma=0., random=False, N=N_test)
                x_train_torch = torch.from_numpy(x_train.astype(np.float32)).clone()
                y_train_torch = torch.from_numpy(y_train.astype(np.float32)).clone()
                x_test_torch = torch.from_numpy(x_test.astype(np.float32)).clone()
                y_test_torch = torch.from_numpy(y_test.astype(np.float32)).clone()

                model = SmoothMonotonicNN(1, K, K, beta=beta)
                if trial == 0 and task_id == 0:
                    # no_params is indexed by K only; it is (re)written once
                    # per beta value, on the first (trial, task).
                    no_params[K_id] = total_params(model)
                fit_torch(model, x_train_torch, y_train_torch)
                # Inference only: no_grad avoids building an autograd graph.
                with torch.no_grad():
                    y_pred_train = model(x_train_torch).detach().numpy()
                    y_pred_test = model(x_test_torch).detach().numpy()

                MSE_train[task_id, K_id, beta_id, trial] = mse(y_train, y_pred_train)
                MSE_test[task_id, K_id, beta_id, trial] = mse(y_test, y_pred_test)
                MSE_clip[task_id, K_id, beta_id, trial] = mse(y_test, np.clip(y_pred_test, 0., 1.))


fn = prefix + "hyper.npz"
np.savez(fn, MSE_train=MSE_train, MSE_test=MSE_test, MSE_clip=MSE_clip, no_params=no_params)


