import os
import time
import multiprocessing
import argparse

import torch
from torch.utils.data import DataLoader
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR
from torch.nn import CrossEntropyLoss

from data_proc.cifar10 import PrepareCorruptCIFAR
from archs.cnn import CNN
from setup.configuration import TrainArgs
from setup.seeding import seed_all
from env.directories import create_folder
from env.user import PROJECT_PATH
from helpers.logger import get_logger, get_logging_queue, setup_logging
from trainer.loop_with_accuracy_and_gpu_cache import train

# Fraction used as the first entry of the ``split`` list handed to
# ``PrepareCorruptCIFAR.subsample``; the second entry is ``1 - DATA_SPLIT``.
DATA_SPLIT = 0.1

def launch_trial_training_wrapper(t, subpred, model_folder, batch_size, logging_queue,
                                  device="cuda:2"):
    """Train one CNN on a subsample of corrupted CIFAR inside a worker process.

    Intended as a ``multiprocessing.Process`` target: it configures logging for
    the current process, builds the data loaders and the model, and hands
    everything to ``train``.

    Args:
        t: Trial index. Passed as ``trial_seed`` to the dataset subsampling and
            embedded in the checkpoint / wandb run name.
        subpred: Second entry of the CNN ``width`` list (the first is fixed at
            16); also embedded in the run name.
        model_folder: Folder passed to ``TrainArgs`` as ``save_folder``.
        batch_size: Batch size of the training loader.
        logging_queue: Queue forwarded to ``setup_logging``.
        device: Device string forwarded to ``TrainArgs``. Defaults to
            ``"cuda:2"``; override it on hosts with a different GPU layout.
    """
    setup_logging(logging_queue)
    # NOTE(review): the same global seed is used for every trial; only the
    # subsample below varies with ``t`` — confirm this is intended.
    seed_all(123)

    # presumably corrupts 10% of the training labels — confirm against
    # PrepareCorruptCIFAR.
    pcifar = PrepareCorruptCIFAR(train_corrupt_seed=123, train_corrupt_percentage=0.1)
    pcifar.subsample(trial_seed=t, split=[DATA_SPLIT, 1 - DATA_SPLIT])

    train_loader = DataLoader(
        dataset=pcifar.train_subset,
        shuffle=True,
        batch_size=batch_size,
        num_workers=0,
    )
    # The whole test set is served as a single batch.
    testloader = DataLoader(
        dataset=pcifar.test_set,
        shuffle=False,
        batch_size=len(pcifar.test_set),
        num_workers=0,
    )

    # Shared by the wandb run name and the checkpoint file name.
    model_name = f"cifar_model_trial-{t}_subpred-{subpred}"

    model = CNN(
        width=[16, subpred],
        input_channels=3,
        kernel_size=3,
        add_bias=False,
        out_dim=10,
    )
    train_args = TrainArgs(
        model=model,
        optim=SGD,
        optim_kwargs={"momentum": 0.9},
        fn_loss=CrossEntropyLoss(),
        lr=0.1,
        lr_sched=StepLR,
        lr_sched_kwargs={"step_size": 5, "gamma": 0.99},
        dataloaders=(train_loader, testloader),
        wandb_kwargs={
            "project": "CNN-cifar-dd-corrupt-small",
            "name": model_name,
        },
        save_folder=model_folder,
        ckpt_name=model_name,
        max_epochs=1000,
        save_final=True,
        warm_up=100,
        device=device,
    )
    train(train_args=train_args)


def parse_args():
    """Parse the command-line options of the parallel training launcher.

    Returns:
        argparse.Namespace: holds ``num_runs`` (int, default 1), the maximum
        number of training processes allowed to run concurrently.
    """
    parser = argparse.ArgumentParser(description="Parallel CNN Training")
    parser.add_argument(
        "--num_runs",
        type=int,
        default=1,
        # %(default)s keeps the help text in sync with the real default.
        help="Maximum number of concurrent processes (default: %(default)s)",
    )
    return parser.parse_args()


def main():
    """Launch every (width, trial) training job with bounded concurrency.

    For each CNN width in the sweep a model folder is created and ``trials``
    tasks are queued. Each task runs ``launch_trial_training_wrapper`` in its
    own process; at most ``--num_runs`` processes are kept alive at once.

    Raises:
        ValueError: If ``--num_runs`` is smaller than 1, since the throttling
            loop could then never make progress.
    """
    args = parse_args()
    num_runs = args.num_runs
    if num_runs < 1:
        # With a non-positive limit the "wait for a free slot" loop below
        # would spin forever, so reject it up front.
        raise ValueError(f"--num_runs must be at least 1, got {num_runs}")

    experiment_name = "cifar_cnn_dd_01_label_corruption_small"
    trials = 50
    batch_size = 256

    # NOTE: only the per-width folders under models/ are created here; creation
    # of the data/ and analysis/ experiment folders is currently disabled.

    tasks = []
    logging_queue = get_logging_queue()

    # Alternative width sweep: [200, 500, 1000, 5000]
    for subpred in [2048, 4096]:
        model_folder = os.path.join(PROJECT_PATH, "models", experiment_name, str(subpred))
        create_folder(path=model_folder, safe_mode=False)
        tasks.extend(
            (t, subpred, model_folder, batch_size, logging_queue)
            for t in range(trials)
        )

    running = []
    for task_args in tasks:
        proc = multiprocessing.Process(target=launch_trial_training_wrapper, args=task_args)
        proc.start()
        running.append(proc)

        # Block until a slot is free before launching the next task.
        while len(running) >= num_runs:
            # Build a new list instead of removing from the one being
            # iterated, which would skip elements.
            still_running = []
            for candidate in running:
                if candidate.is_alive():
                    still_running.append(candidate)
                else:
                    candidate.join()
            running = still_running
            if len(running) >= num_runs:
                time.sleep(0.5)

    # Wait for the remaining processes.
    for proc in running:
        proc.join()


if __name__ == "__main__":
    # Start child processes with the "spawn" method; force=True replaces any
    # start method that was already set.
    # NOTE(review): presumably required because the workers use CUDA — confirm.
    multiprocessing.set_start_method('spawn', force=True)
    # NOTE(review): `logger` is not referenced anywhere else in the visible
    # file; presumably get_logger() is called for its setup side effects —
    # confirm before removing.
    logger = get_logger()
    main()
