import torch
import argparse
from argparse import Namespace
import torch.optim as optim
import copy
import wandb
from tqdm import tqdm
from utils import generate_loaders, generate_model, init_wandb, get_criterion, set_seed, log_results
from train import train_epoch
import json

def main(args: dict) -> None:
    """Train a model, keep the best-validation snapshot, and evaluate it on the test set.

    Args:
        args: Run configuration as a plain dict (the ``__main__`` block passes
            ``vars(parsed_args)``, so it is indexed with ``args[...]``, not
            attribute access). Keys read directly here: ``seed``, ``device``,
            ``lr``, ``weight_decay``, ``epochs``. The whole dict is forwarded to
            the project helpers. ``num_params`` is written into it before
            ``init_wandb`` is called.
    """
    # Seed the run before anything that may consume randomness.
    set_seed(args['seed'])

    # Build loaders and model. generate_loaders also returns the input sizes
    # the model constructor needs.
    train_loader, val_loader, test_loader, pos_in, edge_in, feat_in = generate_loaders(args)
    # NOTE: generate_model takes (pos_in, feat_in, edge_in), which is a
    # different order from the tuple returned above.
    model = generate_model(args, pos_in, feat_in, edge_in)
    model.to(args["device"])

    # Count trainable parameters. The count is stored in args before wandb is
    # initialised and is also set on wandb.config explicitly.
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Number of parameters: {num_params}.')
    args['num_params'] = num_params
    init_wandb(args)
    wandb.config['num_params'] = num_params

    # Optimisation setup.
    criterion = get_criterion(args)
    optimizer = optim.Adam(model.parameters(), lr=args["lr"], weight_decay=args["weight_decay"])
    # Halve the LR (floor min_lr=1e-6) once validation loss has not improved
    # for more than 25 epochs.
    # NOTE(review): `verbose` is deprecated in recent PyTorch releases; consider
    # logging scheduler.get_last_lr() instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=25, factor=0.5, verbose=True, min_lr=1e-6)
    best_val_loss, best_model = float('inf'), None

    for _ in tqdm(range(args["epochs"])):
        # train_epoch is called with the optimizer for the training pass and
        # without it for the validation pass.
        results = {
            'Train': train_epoch(args, model, train_loader, criterion, optimizer),
            'Validation': train_epoch(args, model, val_loader, criterion)
        }
        val_loss = results['Validation']['Loss']
        scheduler.step(val_loss)

        # Snapshot the model whenever validation loss improves. The deep copy
        # keeps the snapshot independent of further training updates.
        if val_loss < best_val_loss:
            best_model = copy.deepcopy(model)
            best_val_loss = val_loss

        log_results(results)

        # Stop early once the scheduler has reduced the LR below 1e-5.
        # If the initial lr is already below 1e-5, this stops after the
        # first epoch.
        if optimizer.param_groups[-1]['lr'] < 1e-5:
            break

    # No snapshot exists when no epoch beat inf (epochs == 0, or every
    # validation loss was NaN). Evaluate the current model rather than
    # passing None to train_epoch.
    if best_model is None:
        best_model = model

    results = {'Test': train_epoch(args, best_model, test_loader, criterion)}
    log_results(results)


if __name__ == "__main__":
    # Command-line entry point: parse arguments, derive device and task type,
    # then run main() with the arguments as a plain dict.
    parser = argparse.ArgumentParser()

    # General parameters
    parser.add_argument('--epochs', type=int, default=1000,
                        help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='batch size')
    parser.add_argument('--num_workers', type=int, default=0,
                        help='num workers')

    # Model parameters
    parser.add_argument('--model_name', type=str, default='pe_mpnn',
                        help='model')
    parser.add_argument('--layer_type', type=str, default='gcn',
                        help='layer type')
    parser.add_argument('--state_type', type=str, default='tensor',
                        help='state type')
    parser.add_argument('--ent_deg', type=int, default=0,
                        help='number of entanglements')
    parser.add_argument('--feat_hidden', type=int, default=18,  # 72
                        help='hidden features')
    parser.add_argument('--pos_hidden', type=int, default=18,  # 72
                        help='hidden positional features')
    parser.add_argument('--num_out', type=int, default=1,
                        help='number of output features')
    parser.add_argument('--num_layers', type=int, default=4,
                        help='number of layers')
    parser.add_argument('--aggr', type=str, default='add',
                        help='aggregate function')
    parser.add_argument('--seed', type=int, default=42,
                        help='seed')
    parser.add_argument('--res', type=int, default=0,
                        help='res')
    parser.add_argument('--struc_info_type', type=str, default='random_walk',
                        help='structural information type')
    parser.add_argument('--struc_dim', type=int, default=20,
                        help='structural information dimension')
    parser.add_argument('--red', type=int, default=0,
                        help='red')
    parser.add_argument('--data_size', type=int, default=10000,
                        help='data size')

    # Optimizer parameters
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='learning rate')
    parser.add_argument('--weight_decay', type=float, default=0,
                        help='weight decay')

    # Dataset arguments
    parser.add_argument('--dataset', type=str, default='zinc',
                        help='dataset')

    parsed_args = parser.parse_args()
    # Use the GPU when one is available.
    parsed_args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Task type: 'reg' for the datasets listed here, 'class' for everything else.
    parsed_args.task = 'reg' if parsed_args.dataset in ['zinc', 'peptides_struct'] else 'class'

    # main() indexes its argument like a dict, so convert the Namespace.
    parsed_args = vars(parsed_args)
    main(parsed_args)
