import train_model
from scipy.stats import mode
import numpy as np


def test_train_shallow():
    """Evaluate ``train_model.train_shallow`` with a per-test-seed vote.

    For each of 10 test seeds, ``train_model.train_shallow`` is called once
    per validation seed (9 calls), each time with a freshly built config
    dict and the same output path. Every call's first return value is
    collected; the collected predictions are stacked and reduced with
    ``scipy.stats.mode`` along the run axis, i.e. the most frequent
    prediction per position is kept. The accuracy for the test seed is the
    fraction of positions where that vote equals ``test_y``.

    Prints each test seed as it starts, then the array of per-seed
    accuracies, then their mean and standard deviation. Returns nothing.
    """
    model_path = "/home/anonymous/data/trainable_scattering/test/model.pt"

    def make_args(test_seed, val_seed):
        # Build a brand-new dict (and new nested containers) on every call,
        # so each training run gets its own config object, as before.
        return {
            "dataset": "IMDB-MULTI",
            "early_stopping": True,
            "model": "fast_rbf_net",
            "model_args": {
                "epsilon": 1e-16,
                "num_layers": 1,
                "trainable_laziness": False,
            },
            "model_dir": "/home/anonymous/data/trainable_scattering/test/0",
            "num_epochs": 1000,
            "splits": [8, 1, 1],
            "test_seed": test_seed,
            "transform": "fast_scatter",
            "val_seed": val_seed,
        }

    seed_accuracies = []
    for test_seed in range(10):
        print(test_seed)

        run_predictions = []
        for val_seed in range(9):
            run_args = make_args(test_seed, val_seed)
            predict, test_y = train_model.train_shallow(run_args, model_path)
            run_predictions.append(predict)

        # Most frequent prediction per position across the 9 runs.
        votes = mode(np.array(run_predictions), axis=0).mode

        # test_y is whatever the last run returned.
        # NOTE(review): presumably identical for every val_seed of a given
        # test_seed -- confirm against train_shallow.
        num_correct = sum(votes.squeeze() == test_y)
        seed_accuracies.append(num_correct / len(test_y))

    seed_accuracies = np.array(seed_accuracies)
    print(seed_accuracies)
    print(np.mean(seed_accuracies), np.std(seed_accuracies))
    

# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    test_train_shallow()
