"""Experiment script to train models on label corrupted MNIST.

The implementation uses parallel processing and an efficient trainer script.
"""

import os
import time
import multiprocessing
import argparse

from torch.utils.data import DataLoader
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR
from torch.nn import CrossEntropyLoss

from data_proc.mnist import PrepareCorruptMNIST
from archs.mlp_mfp import MLPMFP
from setup.configuration import TrainArgs
from setup.seeding import seed_all
from env.directories import create_folder
from env.user import PROJECT_PATH
from helpers.logger import get_logger, get_logging_queue, setup_logging
from trainer.efficient_mnist_trainer import train

# First entry of the split passed to PrepareCorruptMNIST.subsample (the second
# entry is 1 - DATA_SPLIT). 0.9 corresponds to the "large data limit (90%)"
# that the learning-rate and batch-size comments in this file refer to.
# NOTE(review): presumably the fraction of training data each trial keeps —
# confirm against data_proc.mnist.
DATA_SPLIT = 0.9

def launch_trial_training_wrapper(t, subpred, model_folder, batch_size, logging_queue):
    """Train one model of width ``subpred`` on the data subsample of trial ``t``.

    Used as a ``multiprocessing.Process`` target. It first attaches this
    process to the shared logging queue and then calls ``seed_all`` with the
    same seed (123) for every trial. The trial index enters only as the
    ``trial_seed`` of the data subsample.

    Args:
        t: Trial index; subsampling seed and part of the run name.
        subpred: Model width; the learning rate is scaled linearly by it.
        model_folder: Directory handed to the trainer as ``save_folder``.
        batch_size: Mini-batch size of the training loader.
        logging_queue: Queue passed to ``setup_logging`` in this process.
    """
    setup_logging(logging_queue)
    seed_all(123)

    run_name = f"mnist_model_trial-{t}_subpred-{subpred}"

    corrupted = PrepareCorruptMNIST(train_corrupt_seed=123, train_corrupt_percentage=0.2)
    corrupted.subsample(trial_seed=t, split=[DATA_SPLIT, 1 - DATA_SPLIT])

    loaders = (
        DataLoader(
            dataset=corrupted.train_subset,
            shuffle=True,
            batch_size=batch_size,
            num_workers=0,
        ),
        # The whole test set is evaluated as one single batch.
        DataLoader(
            dataset=corrupted.test_set,
            shuffle=False,
            batch_size=len(corrupted.test_set),
            num_workers=0,
        ),
    )

    network = MLPMFP(in_dim=784, out_dim=10, width=subpred, is_bias=False)

    # Learning rate grows linearly with width. For the large data limit (90%)
    # a factor of 0.6 works well; for the small data limit (10%) use 0.15.
    base_lr = 0.6

    config = TrainArgs(
        model=network,
        optim=SGD,
        optim_kwargs={"momentum": 0.9},
        fn_loss=CrossEntropyLoss(),
        lr=base_lr * subpred,
        lr_sched=StepLR,
        lr_sched_kwargs={"step_size": 5, "gamma": 0.99},
        dataloaders=loaders,
        wandb_kwargs={"project": "MNIST-bvd-mfp-dd-corrupt", "name": run_name},
        save_folder=model_folder,
        ckpt_name=run_name,
        max_epochs=1000,
        save_final=True,
        warm_up=100,
        device="cuda:0",
    )
    train(train_args=config)


def parse_args(argv=None):
    """Parse the command-line options of the parallel training launcher.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        An ``argparse.Namespace`` whose ``num_runs`` is the maximum number of
        concurrent training processes (a positive integer).

    Exits through ``parser.error`` (``SystemExit`` with status 2) when
    ``--num_runs`` is smaller than 1, since at least one process must be
    allowed to run at a time.
    """
    parser = argparse.ArgumentParser(description="Parallel MNIST Training")
    parser.add_argument("--num_runs", type=int, default=5,
                        help="Maximum number of concurrent processes (default: 5)")
    args = parser.parse_args(argv)
    if args.num_runs < 1:
        # A non-positive limit can never be satisfied by the scheduler.
        parser.error(f"--num_runs must be at least 1, got {args.num_runs}")
    return args


def _reap_finished(running, failed):
    """Join the finished processes in ``running`` and return the live ones.

    Args:
        running: Started ``multiprocessing.Process`` objects.
        failed: List to which every joined process with a non-zero exit code
            is appended (mutated in place).

    Returns:
        A new list with the processes that are still alive. ``running`` is
        not modified, so no element is skipped while it is being scanned.
    """
    still_running = []
    for proc in running:
        if proc.is_alive():
            still_running.append(proc)
            continue
        proc.join()
        if proc.exitcode != 0:
            failed.append(proc)
    return still_running


def main():
    """Launch every (width, trial) training run with bounded concurrency.

    Creates the experiment's data/models/analysis folders plus one model
    folder per width. It then starts one ``multiprocessing.Process`` running
    ``launch_trial_training_wrapper`` per (width, trial) pair and keeps at
    most ``--num_runs`` of them alive at once. Runs whose process exits with
    a non-zero exit code are listed on stderr once all processes finished.

    Raises:
        ValueError: if ``--num_runs`` is smaller than 1 (a slot could never
            become available).
    """
    import sys
    from multiprocessing.connection import wait

    args = parse_args()
    num_runs = args.num_runs
    if num_runs < 1:
        raise ValueError(f"--num_runs must be at least 1, got {num_runs}")

    experiment_name = "mnist_mfp_dd_02_label_corruption_test"
    trials = 50
    batch_size = 4096  # for large data limit (90%) 4096 works well
    # batch_size = 512 # for small data limit (10%) 512 works well

    for subdir in ("data", "models", "analysis"):
        create_folder(path=os.path.join(PROJECT_PATH, subdir, experiment_name), safe_mode=False)

    tasks = []
    logging_queue = get_logging_queue()

    for subpred in [5, 10, 50, 100, 500, 1000, 5000]:

        model_folder = os.path.join(PROJECT_PATH, "models", experiment_name, str(subpred))
        create_folder(path=model_folder, safe_mode=False)

        for t in range(trials):
            tasks.append((t, subpred, model_folder, batch_size, logging_queue))

    running = []
    failed = []
    for task in tasks:
        t, subpred = task[0], task[1]
        proc = multiprocessing.Process(
            target=launch_trial_training_wrapper,
            args=task,
            name=f"trial-{t}_subpred-{subpred}",
        )
        proc.start()
        running.append(proc)

        # Limit concurrent processes: block until at least one worker exits
        # (its sentinel becomes ready) instead of busy-polling, then rebuild
        # the list rather than removing items while iterating over it.
        while len(running) >= num_runs:
            wait([p.sentinel for p in running])
            running = _reap_finished(running, failed)

    # Wait for remaining processes.
    for proc in running:
        proc.join()
        if proc.exitcode != 0:
            failed.append(proc)

    if failed:
        names = ", ".join(proc.name for proc in failed)
        print(f"{len(failed)} training run(s) failed: {names}", file=sys.stderr)


if __name__ == "__main__":
    # Use the "spawn" start method, so each worker runs in a fresh interpreter.
    # force=True overrides a start method that was already set.
    # NOTE(review): presumably required because workers train on CUDA
    # (device="cuda:0"), which cannot be re-initialised in a forked child.
    multiprocessing.set_start_method('spawn', force=True)
    # NOTE(review): the returned logger is bound to a module-level name but is
    # not used in the code visible here; get_logger() may also configure the
    # main-process side of the logging queue — confirm before removing.
    logger = get_logger()
    main()
