"""Experiment script to train an MLP on the California Housing data set using
the mean-field parameterization.
"""

from __future__ import annotations

import logging
import os

from torch.nn import MSELoss
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

from data_proc.california_housing import PrepareCaliforniaHousing, CaliforniaHousing
from archs.mlp_mfp import MLPMFP
from setup.seeding import seed_all
from setup.configuration import TrainArgs
from trainer.loop import train
from env.directories import create_folder
from env.user import PROJECT_PATH
from helpers.logger import get_logger

logger = get_logger()
# Emit INFO-level messages and above from this logger (logging.INFO == 20).
logger.setLevel(logging.INFO)


def main() -> None:
    """Train mean-field-parameterized MLPs over a sweep of widths and trials.

    For every hidden width in the sweep, one ``MLPMFP`` model is trained on each
    subsampled training set generated by ``PrepareCaliforniaHousing.subsample``.
    All runs share one validation loader. The RNG is reseeded before every run
    so that weight initialisation is identical across trials.

    Side effects: creates the experiment's ``data``, ``models`` and ``analysis``
    folders under ``PROJECT_PATH``, generates the subsampled training sets, and
    calls ``train`` for every (width, trial) pair. That call is given a
    per-width save folder and checkpoint name; presumably ``train`` writes
    checkpoints there and logs to W&B.
    """
    # --------------------------------------------------------------------------
    # Prepare experiment directories
    experiment_name = 'california_housing_mfp_fixed_epoch'
    trials = 50
    widths = [5, 10, 50, 100, 500, 1000, 5000]

    data_dir = os.path.join(PROJECT_PATH, 'data', experiment_name)
    models_dir = os.path.join(PROJECT_PATH, 'models', experiment_name)
    analysis_dir = os.path.join(PROJECT_PATH, 'analysis', experiment_name)
    for path in (data_dir, models_dir, analysis_dir):
        create_folder(path=path, safe_mode=False)

    # --------------------------------------------------------------------------
    # Data
    calif = PrepareCaliforniaHousing()
    # Generate the subsampled training sets (one per trial).
    calif.subsample(
        trials=trials,
        exp_name=experiment_name,
        safe_mode=False,
    )

    valset = CaliforniaHousing(
        Xarray=calif.X_val,
        yarray=calif.y_val,
    )
    valloader = DataLoader(
        dataset=valset,
        shuffle=False,
        batch_size=4096,
        num_workers=4,
    )

    # The set of subsampled training files does not change during the sweep,
    # so list it once. Sort it because os.listdir order is arbitrary and a
    # reproducible run order is wanted. Skip hidden files (e.g. .DS_Store),
    # which are not data sets.
    trial_files = sorted(
        entry for entry in os.listdir(data_dir) if not entry.startswith('.')
    )

    # Loop over subpredictors (model widths).
    for subpred in widths:
        # Checkpoints for this width go in models/<experiment_name>/<width>.
        folder = os.path.join(models_dir, str(subpred))
        create_folder(path=folder, safe_mode=False)

        # Scan over the trials.
        for t in trial_files:
            # Reseed so that model weights are initialised the same way for
            # every run.
            seed_all(seed=123)

            # Read the subsampled training set from disk, then wrap it in a
            # shuffling loader.
            calif.read_from_disk(exp_name=experiment_name, file_name=t)
            trainset = CaliforniaHousing(
                Xarray=calif.X_train_sub,
                yarray=calif.y_train_sub,
            )
            trainloader = DataLoader(
                dataset=trainset,
                shuffle=True,
                batch_size=512,
                num_workers=4,
            )

            # ------------------------------------------------------------------
            # Trainer
            # Strip only a trailing '.pickle'; str.replace would also remove
            # the substring if it appeared elsewhere in the file name.
            name = t[:-len('.pickle')] if t.endswith('.pickle') else t
            model_name = f"model-{name}_subpred-{subpred}"
            logger.info("Training %s", model_name)

            model = MLPMFP(
                # NOTE(review): in_dim is hard-coded; confirm it equals the
                # number of features produced by PrepareCaliforniaHousing.
                in_dim=13,
                out_dim=1,
                width=subpred,
                is_bias=False,
            )
            train_args = TrainArgs(
                model=model,
                optim=SGD,
                optim_kwargs={
                    "momentum": 0.9,
                },
                fn_loss=MSELoss(),
                # The learning rate grows linearly with width. This is
                # presumably the mean-field LR scaling that MLPMFP expects;
                # confirm against MLPMFP.
                lr=0.15 * subpred,
                lr_sched=StepLR,
                lr_sched_kwargs={
                    "step_size": 5,
                    "gamma": 0.99,
                },
                dataloaders=(trainloader, valloader),
                wandb_kwargs={
                    "project": experiment_name,
                    "name": model_name,
                    # Placeholder value; set to your own W&B entity.
                    "entity": "your-entity",
                },
                save_folder=folder,
                ckpt_name=model_name,
                max_epochs=2000,
            )

            train(
                train_args=train_args,
            )

if __name__ == "__main__":
    # Run the experiment sweep only when executed as a script, not on import.
    main()








