import torch
import numpy as np
import random
import torch.nn as nn
from get_args import get_args
import os
from MLP import MLP
from data_loaders import get_dataloaders

def _checkpoint_filename(args):
    """Return the checkpoint file name shared by saved biases and weights.

    The name encodes ``train_weights``, ``width``, ``dataset``, ``c`` and
    ``num_hidden_layers`` so that runs differing in those settings write to
    distinct files inside the same output directory.
    """
    return (
        f"train_weights={args.train_weights}_width={args.width}"
        f"_dataset={args.dataset}_c={args.c}"
        f"_num_hidden_layers={args.num_hidden_layers}.pth"
    )


def main(args):
    """Train an ``MLP`` configured by ``args`` and optionally save the results.

    Steps: build the train/test loaders and the network, optionally load
    pre-trained weights and biases, optionally freeze parameters, then train
    for ``args.n_epochs`` epochs with Adam (lr=0.001) and cross-entropy loss.
    After training it prints the run parameters and the final validation
    accuracy (``None`` if no epoch ran). It then saves the biases, the
    weights and/or the whole pickled model, depending on which output paths
    are set in ``args``.

    Args:
        args: Parsed command-line namespace (as produced by ``get_args()``).
    """
    trainloader, testloader = get_dataloaders(args)

    net = MLP(
        args.num_hidden_layers,
        args.width,
        args.c,
        args.weight_distribution,
        args.weight_gain,
        args.bias_distribution,
        args.bias_gain,
        args.train_weights,
        args.input_layer_bias,
        args.output_layer_bias,
        args.middle_layers_bias,
        args.l1_weight,
        args.bias_l1_weight,
        args.bias_l1_baseline,
    )

    if args.load_weights_path is not None:
        net.set_weights(args.load_weights_path)
    if args.load_biases_path is not None:
        net.set_biases(args.load_biases_path)

    if not args.train_weights:
        # NOTE(review): presumably marks weight tensors as non-trainable — confirm in MLP.
        net.disable_weights_training()
    if args.finetune_output:
        net.only_finetune_output()

    criterion = nn.CrossEntropyLoss()
    # Hand Adam only the parameters still requiring grad after the freezing calls above.
    optimizer = torch.optim.Adam(
        (p for p in net.parameters() if p.requires_grad), lr=0.001
    )

    # Create every output directory before training so a bad path fails fast
    # instead of after the whole run has finished.
    if args.save_biases_path is not None:
        os.makedirs(args.save_biases_path, exist_ok=True)
    if args.save_weights_path is not None:
        os.makedirs(args.save_weights_path, exist_ok=True)
    if args.modelpath is not None:
        model_dir = os.path.dirname(args.modelpath)
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)

    # Pre-bind so the final report works (prints None) when n_epochs == 0,
    # instead of raising NameError.
    val_accuracy = None
    for epoch in range(args.n_epochs):
        print(f'Epoch {epoch + 1}/{args.n_epochs}')

        train_loss = net.train_epoch(trainloader, optimizer, criterion, args.noise_variance)
        # The test loader doubles as the validation set here.
        val_loss, val_accuracy, _ = net.eval_epoch(testloader, criterion)
        print(f"train_loss={train_loss} val_loss={val_loss} val_accuracy={val_accuracy}")

    print("parameters", args)
    print("Final validation accuracy:", val_accuracy)

    checkpoint_name = _checkpoint_filename(args)
    if args.save_biases_path is not None:
        net.save_biases(os.path.join(args.save_biases_path, checkpoint_name))
    if args.save_weights_path is not None:
        net.save_weights(os.path.join(args.save_weights_path, checkpoint_name))

    if args.modelpath is not None:
        # Pickles the entire nn.Module (not just its state_dict).
        torch.save(net, args.modelpath)

    
if __name__ == "__main__":
    # Script entry point: parse CLI options, seed the Python, NumPy and
    # PyTorch RNGs (in that order) for reproducibility, then train.
    args = get_args()
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(args.seed)
    main(args)