"""Experiment script to check if the maximal update parameterization is imple-
mented correctly.
"""
from data_proc.california_housing import PrepareCaliforniaHousing, CaliforniaHousing

from archs.mlp_mfp import MLPMFP
from archs.mlp_mup import MLPMUP

from trainer.check_activations import CheckActivations
from setup.seeding import seed_all

from torch.utils.data import DataLoader
from torch.optim import SGD, Adam
from torch.nn import MSELoss

# Parameterizations this script knows how to build models for.
_SUPPORTED_PARAMETERIZATIONS = ("mup", "mfp")


def _build_models(parameterization: str, widths: list, num_seeds: int) -> dict:
    """Build ``num_seeds`` models for every width configuration.

    Args:
        parameterization: either ``"mup"`` (builds ``MLPMUP``) or ``"mfp"``
            (builds ``MLPMFP``).
        widths: list of per-layer hidden-width lists, e.g. ``[[5] * 4, ...]``.
        num_seeds: number of seeds; one model is built per (seed, width) pair.

    Returns:
        Dict mapping the first hidden width of each configuration to the list
        of models built for it, one per seed, in seed order.
    """
    models = {w[0]: [] for w in widths}
    for seed in range(num_seeds):
        for w in widths:
            # Re-seed before each construction so every width starts from the
            # same seed value (presumably seeds all RNGs — TODO confirm
            # against setup.seeding.seed_all).
            seed_all(seed=seed)
            if parameterization == "mup":
                model = MLPMUP(
                    in_dim=13,
                    out_dim=1,
                    widths=w,
                    is_bias=False,
                )
            else:
                # NOTE(review): MLPMUP takes ``widths=`` but MLPMFP is called
                # with ``width=`` — confirm against MLPMFP's signature.
                model = MLPMFP(
                    in_dim=13,
                    out_dim=1,
                    width=w,
                    is_bias=False,
                )
            # Key by ``w[0]``: ``w`` itself is a list and therefore unhashable.
            models[w[0]].append(model)
    return models


def _build_trainloader() -> DataLoader:
    """Return an unshuffled, batch-size-1 loader over the stored training subsample."""
    calif = PrepareCaliforniaHousing()
    # NOTE(review): the "*_mfp" split is loaded regardless of parameterization;
    # presumably the data split is shared — confirm.
    calif.read_from_disk(
        exp_name='california_housing_mfp',
        file_name="trial_0_california_housing_mfp.pickle",
    )
    trainset = CaliforniaHousing(
        Xarray=calif.X_train_sub,
        yarray=calif.y_train_sub,
    )
    return DataLoader(
        dataset=trainset,
        shuffle=False,
        batch_size=1,
        num_workers=1,
    )


def main(
    parameterization: str = "mup",
    num_seeds: int = 5,
    output_path: str = "/path/to/output/directory",
) -> None:
    """Train MLPs of increasing width for a few SGD steps and plot their activations.

    Args:
        parameterization: ``"mup"`` or ``"mfp"``.
        num_seeds: number of independently seeded models per width.
        output_path: directory passed to ``CheckActivations.plot``.

    Raises:
        ValueError: if ``parameterization`` is not one of the supported values.
    """
    # Fail fast instead of silently running the experiment with no models.
    if parameterization not in _SUPPORTED_PARAMETERIZATIONS:
        raise ValueError(
            f"Unknown parameterization {parameterization!r}; "
            f"expected one of {_SUPPORTED_PARAMETERIZATIONS}."
        )

    # Four hidden layers, all of the same width.
    widths = [[n] * 4 for n in (5, 10, 50, 100, 500, 1000)]
    models = _build_models(parameterization, widths, num_seeds)

    # Set up the dataloader for a small data sample.
    trainloader = _build_trainloader()

    check_activations = CheckActivations(
        models=models,
        steps=3,
        dataloader=trainloader,
        optimizer=SGD,
        loss_fn=MSELoss(),
        lr=0.1,  # note: Yang recommends large lr. But note that for many steps, gradients explode.
    )
    check_activations.train_step(parameterization=parameterization)

    check_activations.plot(path=output_path)


if __name__ == "__main__":
    main()