import sys
import torch
import numpy
from nn.deep_sets import DeepSet, MLP
from nn.abelian_group_net import (AbelianGroupNetwork, AbelianSemigroupNetwork1,
                                  AbelianSemigroupNetwork2, AbelianSemigroupNetwork3)
from nn.monotonic_net import MonotonicNetwork, MultiMonotonicNetwork
from setting import Xs, xs, func1, func2, func3, func4, func5, func6, func7, func8, func9, func10, func11


def mean_rae(y, pred):
    """ mean relative absolute error: [batch, dim], [batch, dim] -> scalar

    Each absolute residual |y - pred| is divided by |y| + 1e-12 (the small
    constant keeps the ratio finite where y is exactly zero) and the ratios
    are averaged over all elements.
    """
    abs_residual = (y - pred).abs()
    scale = y.abs() + 1e-12
    return torch.mean(abs_residual / scale)


def mean_ae(y, pred):
    """ mean absolute error: [batch, dim], [batch, dim] -> scalar

    Averages |y - pred| over every element of the residual.
    """
    residual = y - pred
    return residual.abs().mean()


def train(model, func, train_xs=xs.train, valid_xs=xs.valid, batch_size=32, epoch_num=2000,
          learning_rate=5e-3, verbose=True, print_error_test=False):
    """ Fit `model` to the targets `Xs.calc_ys(inputs, func)` and return its validation error.

    Optimisation uses Adam on a mean-reduced MSE loss. Every epoch draws a fresh
    NumPy permutation of the training indices and walks it in mini-batches of at
    most `batch_size` samples.

    Args:
        model: object exposing `parameters()`, `train()`, `eval()` and callable on
            the inputs (presumably a `torch.nn.Module`). Its repr is always
            printed first, independently of `verbose`.
        func: target function handed to `Xs.calc_ys` to label each input set.
        train_xs: training inputs; default `xs.train` is bound at definition time.
            Assumed to support `len()` and indexing by a NumPy integer array.
        valid_xs: validation inputs; default `xs.valid` is bound at definition time.
        batch_size: upper bound on the number of samples per optimizer step.
        epoch_num: number of passes over the training inputs.
        learning_rate: step size given to Adam.
        verbose: when true, print the train / valid / test-large mean absolute
            error to stdout every 50 epochs (before that epoch's updates).
        print_error_test: when true, print the final MSE loss and mean absolute
            error on `xs.test_small` and `xs.test_large` to stderr.

    Returns:
        The value of `mean_ae` on `valid_xs` after training, computed in eval
        mode without gradient tracking.
    """
    def to_stderr(message):
        # Final test metrics go to stderr so they stay separate from progress logs.
        print(message, file=sys.stderr, flush=True)

    print(model)
    adam = torch.optim.Adam(model.parameters(), lr=learning_rate)
    loss_fn = torch.nn.MSELoss(reduction="mean")

    # Targets for every split are computed once up front.
    train_ys, valid_ys = Xs.calc_ys(train_xs, func), Xs.calc_ys(valid_xs, func)
    small_ys, large_ys = Xs.calc_ys(xs.test_small, func), Xs.calc_ys(xs.test_large, func)

    for epoch in range(epoch_num):
        # Periodic progress report, evaluated before this epoch's updates.
        if verbose and not epoch % 50:
            print(f"{epoch} epoch")
            model.eval()
            with torch.no_grad():
                print(f"[train] error: {mean_ae(model(train_xs), train_ys)}", flush=True)
                print(f"[valid] error: {mean_ae(model(valid_xs), valid_ys)}", flush=True)
                print(f"[test large] error: {mean_ae(model(xs.test_large), large_ys)}", flush=True)

        # One pass over the reshuffled training set.
        model.train()
        shuffled = numpy.random.permutation(len(train_xs))
        batch_count = (len(shuffled) + batch_size - 1) // batch_size  # ceil(n / batch_size)
        for batch in numpy.array_split(shuffled, batch_count):
            loss = loss_fn(model(train_xs[batch]), train_ys[batch])
            adam.zero_grad()
            loss.backward()
            adam.step()

    # Final evaluation.
    model.eval()
    with torch.no_grad():
        if print_error_test:
            to_stderr(f"test_small loss: {loss_fn(model(xs.test_small), small_ys)}")
            to_stderr(f"test_large loss: {loss_fn(model(xs.test_large), large_ys)}")
            to_stderr(f"test_small error: {mean_ae(model(xs.test_small), small_ys):.3}")
            to_stderr(f"test_large error: {mean_ae(model(xs.test_large), large_ys):.3}")
        return mean_ae(model(valid_xs), valid_ys)


if __name__ == "__main__":
    # Script entry point: train each architecture on func2 with the default settings.
    numpy.random.seed()  # no argument: reseed NumPy's global RNG with fresh entropy

    # Build all models first, in the same order as before.
    ds = DeepSet(MLP(1, 8), MLP(8, 1))
    agn = AbelianGroupNetwork(MonotonicNetwork(16, 16, const_sign=1.))
    asn1 = AbelianSemigroupNetwork1(MonotonicNetwork(16, 16, const_sign=None))
    asn2 = AbelianSemigroupNetwork2(MonotonicNetwork(16, 16, const_sign=None))
    asn3 = AbelianSemigroupNetwork3(MonotonicNetwork(16, 16, const_sign=None))

    # Train order: the three semigroup variants, then the group network, then DeepSet.
    for net in (asn1, asn2, asn3, agn, ds):
        train(net, func2)
