import torch

from KAN import KAN
from dataset import TabularDatasetManager
from modules_KAAN import KAAN, MLP_tanh, MLP_relu, MLP_silu
from modules_KAAN import FourierLayer, PolyLayer, GaussianLayer, DoGLayer, ParallelLayerV2, ParallelLayerV1
from train_function import train_multilayer_network


def train_one(data_name, model_name, data_manager, epoches=8500):
    """Train a single model on a single dataset and return its best accuracy.

    The network has one hidden layer of width ``2 * out_dim``. Training runs in
    two stages through ``train_multilayer_network``: 50 epochs with
    ``pre_train=True`` (Adam, lr=0.01), then 500 epochs with ``pre_train=False``
    (fresh Adam, lr=0.1).

    Args:
        data_name: Dataset identifier passed to ``data_manager.get_local_data``.
        model_name: Key selecting the architecture; see ``builders`` below.
        data_manager: Object whose ``get_local_data(name)`` returns
            ``((X_train, y_train), (X_test, y_test), in_dim, out_dim)``.
        epoches: Currently unused; both stage lengths are fixed below. Kept so
            existing callers that pass it keep working.

    Returns:
        The ``best_accuracy`` value reported by the second training stage.

    Raises:
        ValueError: If ``model_name`` is not one of the supported keys.
    """
    (X_train, y_train), (X_test, y_test), in_dim, out_dim = data_manager.get_local_data(data_name)
    model_shape = [in_dim, out_dim * 2, out_dim]

    # Builders are zero-argument callables so that only the requested model is
    # ever constructed.
    builders = {
        'KAN': lambda: KAN(model_shape, device='cuda:0'),
        'MLP_tanh': lambda: MLP_tanh(model_shape),
        'MLP_relu': lambda: MLP_relu(model_shape),
        'MLP_silu': lambda: MLP_silu(model_shape),
        'FourierKAAN': lambda: KAAN(model_shape, FourierLayer),
        'PolyKAAN': lambda: KAAN(model_shape, PolyLayer),
        # The "16" variant uses 16 basis terms, matching the 4*4 variant's
        # total. NOTE(review): inferred from the model name - confirm against
        # the experiment spec.
        'PolyKAAN16': lambda: KAAN(model_shape, PolyLayer, Poly_size=16),
        'PolyKAAN4*4': lambda: KAAN(model_shape, PolyLayer, mult_size=4, repeat=4),
        'GaussianKAAN': lambda: KAAN(model_shape, GaussianLayer),
        'DoGKAAN': lambda: KAAN(model_shape, DoGLayer),
        # Each Parallel variant is paired with the layer of the same version.
        'ParallelKAANV1': lambda: KAAN(model_shape, ParallelLayerV1),
        'ParallelKAANV2': lambda: KAAN(model_shape, ParallelLayerV2),
    }
    try:
        build_model = builders[model_name]
    except KeyError:
        # Fail loudly with the offending name instead of hitting an unbound
        # local variable further down.
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected one of {sorted(builders)}"
        ) from None

    model = build_model().cuda()
    criterion = torch.nn.CrossEntropyLoss()

    # Stage 1: short pre-training run.
    optimizer = torch.optim.Adam(model.parameters(), lr=.01)
    best_accuracy, model = train_multilayer_network(
        model, X_train, y_train, X_test, y_test, criterion, optimizer,
        epochs=50, batch_size=1000, pre_train=True)

    # Stage 2: main training run with a new optimizer (optimizer state from
    # stage 1 is discarded).
    optimizer = torch.optim.Adam(model.parameters(), lr=.1)
    best_accuracy, model = train_multilayer_network(
        model, X_train, y_train, X_test, y_test, criterion, optimizer,
        epochs=500, batch_size=1000, pre_train=False)

    print(f"\rModel: {model_name}, Dataset: {data_name} Finished.")
    return best_accuracy


def train_tabular(pool):
    """Benchmark every model on every tabular dataset and record the results.

    For each (dataset, model) pair, ten independent ``train_one`` runs are
    submitted to ``pool`` and awaited before moving on to the next pair. After
    each pair finishes, the full results table is rewritten to
    ``easy_results.txt`` so partial progress survives an interrupted run.

    Args:
        pool: Worker pool exposing ``apply_async`` (e.g. ``multiprocessing.Pool``).

    Returns:
        Nested dict ``{data_name: {model_name: [result, ...]}}`` holding the
        value returned by each ``train_one`` run; pairs not yet processed map
        to an empty list.
    """
    data_manager = TabularDatasetManager()
    data_names = data_manager.get_names()
    # These identifiers must match the names train_one dispatches on; any other
    # name makes the worker fail and the error is re-raised by ``res.get()``.
    model_names = [
        'KAN',
        'ParallelKAANV1',
        'ParallelKAANV2',
        'FourierKAAN',
        'PolyKAAN',
        'PolyKAAN4*4',
        'GaussianKAAN',
        'DoGKAAN',
        'MLP_tanh',
        'MLP_relu',
        'MLP_silu',
        'PolyKAAN16',
    ]
    results_file = 'easy_results.txt'
    runs_per_pair = 10
    recorder = {data_name: {model_name: [] for model_name in model_names} for data_name in data_names}

    for data_name in data_names:
        for model_name in model_names:
            # Launch all repetitions first so they run concurrently, then
            # collect them.
            pending = [
                pool.apply_async(train_one, args=(data_name, model_name, data_manager))
                for _ in range(runs_per_pair)
            ]
            recorder[data_name][model_name] = [res.get() for res in pending]
            # Overwrite the file with the whole table as a running checkpoint.
            with open(results_file, 'w') as f:
                f.write(str(recorder))
    return recorder


if __name__ == "__main__":
    # Script entry point: run the full benchmark sweep on five worker processes.
    # NOTE(review): workers call .cuda(); if CUDA is ever initialised in the
    # parent before the pool is created, a fork-based pool will fail - confirm
    # the start method suits the target platform.
    import multiprocessing

    worker_pool = multiprocessing.Pool(5)
    with worker_pool:
        train_tabular(worker_pool)
