import numpy as np

import torch
from torch.nn import BCELoss, MSELoss, CrossEntropyLoss

from KAN import KAN
from dataset import ToyDatasetManager
from modules_KAAN import KAAN, MLP_tanh, MLP_relu, MLP_silu
from modules_KAAN import FourierLayer, PolyLayer, GaussianLayer, DoGLayer, ParallelLayerV2, ParallelLayerV1
from train_function import train_single_layer_network


def train_one_toy_dataset(data_name, model_names, seed, epoches=8500):
    """Train every requested model on one toy dataset for a single seed.

    Args:
        data_name: Name of the toy dataset, passed to
            ``ToyDatasetManager.get_loader`` and to ``train_one_toy``.
        model_names: Iterable of model identifiers understood by
            ``train_one_toy``.
        seed: Integer used to seed both NumPy and PyTorch.
        epoches: Forwarded unchanged to ``train_one_toy``.

    Returns:
        A dict mapping each model name to the value returned by
        ``train_one_toy`` for that model.
    """
    # Seed PyTorch as well as NumPy; with NumPy alone, any torch-side
    # randomness (e.g. parameter initialisation) is not controlled by `seed`.
    np.random.seed(seed)
    torch.manual_seed(seed)

    data_manager = ToyDatasetManager()
    # The test loader is returned by the manager but not used in this function.
    train_loader, val_loader, test_loader, in_dim, out_dim = data_manager.get_loader(
        data_name, train_samples=10000, val_samples=2000, test_samples=2000, batch_size=1000
    )

    # Known datasets use a fixed output width; any other dataset keeps the
    # width reported by the loader instead of failing with a KeyError.
    output_dims = {
        "classification": 1,
        "blobs": 3,
        "circles": 1,
        "moons": 1,
        "friedman1": 1,
    }
    out_dim = output_dims.get(data_name, out_dim)

    return {
        model_name: train_one_toy(data_name, model_name, in_dim, out_dim, train_loader, val_loader, epoches)
        for model_name in model_names
    }


def train_one_toy(data_name, model_name, in_dim, out_dim, train_loader, val_loader, epoches=8500):
    """Build one model, train it on one toy dataset and return its best score.

    Args:
        data_name: Dataset key; selects the loss function and is forwarded to
            ``train_single_layer_network``.
        model_name: Architecture key; see the ``builders`` table below.
        in_dim: Input width, first entry of the model shape.
        out_dim: Output width, second entry of the model shape.
        train_loader: Forwarded to ``train_single_layer_network``.
        val_loader: Forwarded to ``train_single_layer_network``.
        epoches: Accepted for interface compatibility; not used here.

    Returns:
        The first element of the pair returned by
        ``train_single_layer_network`` (reported as "Accuracy" in the log).

    Raises:
        KeyError: If ``data_name`` has no registered loss function.
        ValueError: If ``model_name`` is not a known architecture.
    """
    model_shape = [in_dim, out_dim]
    criterions = {
        "classification": BCELoss(),
        "blobs": CrossEntropyLoss(),
        "circles": BCELoss(),
        "moons": BCELoss(),
        "friedman1": MSELoss(),
    }
    criterion = criterions[data_name]

    # Lazy constructors so only the requested model is instantiated.
    builders = {
        'KAN': lambda: KAN(model_shape, device='cuda:0'),
        'MLP_tanh': lambda: MLP_tanh(model_shape),
        'MLP_relu': lambda: MLP_relu(model_shape),
        'MLP_silu': lambda: MLP_silu(model_shape),
        'FourierMLP': lambda: KAAN(model_shape, FourierLayer),
        'PolyMLP': lambda: KAAN(model_shape, PolyLayer),
        # NOTE(review): the name says 16 but Poly_size is 12 -- confirm intent.
        'PolyMLP16': lambda: KAAN(model_shape, PolyLayer, Poly_size=12),
        'MultMLP4*4': lambda: KAAN(model_shape, PolyLayer, mult_size=4, repeat=4),
        'GaussianMLP': lambda: KAAN(model_shape, GaussianLayer),
        'DoGMLP': lambda: KAAN(model_shape, DoGLayer),
        # NOTE(review): V1 builds ParallelLayerV2 and V2 builds ParallelLayerV1.
        # Kept as in the original mapping; looks swapped -- confirm.
        'ParallelMLPV1': lambda: KAAN(model_shape, ParallelLayerV2),
        'ParallelMLPV2': lambda: KAAN(model_shape, ParallelLayerV1),
    }
    # train_all_toy requests 'MultMLP' and 'MultMLP16', which previously fell
    # through every branch and crashed on an unbound `model`.
    # NOTE(review): they are treated here as aliases of the Poly entries
    # because 'MultMLP4*4' is also built on PolyLayer -- confirm this mapping.
    builders['MultMLP'] = builders['PolyMLP']
    builders['MultMLP16'] = builders['PolyMLP16']

    build = builders.get(model_name)
    if build is None:
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected one of {sorted(builders)}"
        )

    model = build().cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # NOTE(review): the trailing literal 100 is passed instead of `epoches`;
    # left unchanged because its meaning is defined in train_function.
    best_accuracy, model = train_single_layer_network(model, train_loader, val_loader, criterion, optimizer, data_name, 100)
    print(f"\rModel: {model_name}, Dataset: {data_name} Finished, Accuracy: {best_accuracy}.")
    return best_accuracy


def train_all_toy(pool):
    """Run 100 seeded trainings per toy dataset on a worker pool.

    For each dataset name reported by ``ToyDatasetManager.get_names``, one
    ``train_one_toy_dataset`` task is submitted per seed in ``range(100)``.
    After a dataset's tasks have all finished, the whole recorder is written
    to the results file, replacing its previous contents.

    Args:
        pool: Object providing ``apply_async`` (e.g. ``multiprocessing.Pool``).

    Returns:
        A dict keyed by dataset name. Each processed dataset maps to the list
        of per-seed result dicts, in seed order.
    """
    results_file = 'toy_results_new_100.txt'
    model_names = ['KAN', 'MLP_tanh', 'MLP_relu', 'MLP_silu', 'ParallelMLPV1', 'ParallelMLPV2', 'FourierMLP', 'MultMLP', 'MultMLP4*4', 'GaussianMLP', 'DoGMLP', 'MultMLP16']

    manager = ToyDatasetManager()
    data_names = manager.get_names()

    # Placeholder entries; each is replaced once its dataset has been trained.
    recorder = {
        dataset: {name: [] for name in model_names}
        for dataset in data_names
    }

    for data_name in data_names:
        # Submit every seed first so the pool can run them concurrently.
        pending = []
        for seed in range(100):
            handle = pool.apply_async(train_one_toy_dataset, args=(data_name, model_names, seed))
            pending.append(handle)

        # Block until all seeds for this dataset are done.
        recorder[data_name] = [handle.get() for handle in pending]

        # Checkpoint progress after every dataset.
        with open(results_file, 'w') as out:
            out.write(str(recorder))

    return recorder


if __name__ == '__main__':
    import multiprocessing

    # Run the full benchmark on a pool of ten worker processes.
    with multiprocessing.Pool(processes=10) as pool:
        train_all_toy(pool)
