import argparse
import yaml
import wandb
import matplotlib.pyplot as plt

from train_diag_net import train_diagonal_net


def parse_args():
    """Read the sweep's command-line options.

    All three options are mandatory.

    Returns:
        argparse.Namespace with ``d`` (int), ``k`` (int) and ``delta`` (float).
    """
    cli = argparse.ArgumentParser(
        description="Run diagonal-net optimizer comparison over varying n_train"
    )
    # (flag, converter, help text) for each required option, in CLI order.
    option_table = (
        ('--d', int, 'Dimensionality d'),
        ('--k', int, 'The sparsity k of w*'),
        ('--delta', float, 'Noise magnitude delta'),
    )
    for flag, converter, help_text in option_table:
        cli.add_argument(flag, type=converter, required=True, help=help_text)
    namespace = cli.parse_args()
    return namespace


def _build_optimizer_configs():
    """Return the list of optimizer settings swept by ``main``.

    Each entry is a dict with keys ``name``, ``optimizer_type`` and ``lr``;
    Adam entries additionally carry ``beta1`` and ``beta2``.
    """
    learning_rates = [5e-3, 1e-2, 5e-2, 1e-1]
    configs = []

    # Adam: beta1=0.9, beta2 in [0.95, 0.999], lr in learning_rates.
    for beta2 in [0.95, 0.999]:
        for lr in learning_rates:
            configs.append({
                'name': f"Adam_b1=0.9_b2={beta2}_lr={lr}",
                'optimizer_type': 'Adam',
                'lr': lr,
                'beta1': 0.9,
                'beta2': beta2,
            })

    # SGD: lr in learning_rates.
    for lr in learning_rates:
        configs.append({
            'name': f"SGD_lr={lr}",
            'optimizer_type': 'SGD',
            'lr': lr,
        })
    return configs


def _num_steps(lr):
    """Return the training budget ``10 / lr**2`` as an integer step count."""
    # round(), not int(): lr**2 is inexact in binary floating point
    # (0.1**2 == 0.010000000000000002), so int() would truncate
    # 10 / 0.1**2 == 999.999... down to 999 instead of 1000.
    return round(10 / lr ** 2)


def main():
    """Sweep optimizers and training-set sizes for the diagonal net.

    For every optimizer configuration and every ``n_train`` in
    ``range(200, 350, 10)``, call ``train_diagonal_net`` with a step budget
    of ``round(10 / lr**2)`` and collect the two losses it returns.

    Returns:
        dict mapping each optimizer config name to
        ``{'train': [...], 'test': [...]}``. The i-th entry of each list is
        the value ``train_diagonal_net`` returned for the i-th ``n_train``.
    """
    args = parse_args()
    d, k, delta = args.d, args.k, args.delta

    optimizer_configs = _build_optimizer_configs()
    n_trains = list(range(200, 350, 10))
    results = {opt['name']: {'train': [], 'test': []} for opt in optimizer_configs}

    # Optional W&B logging (currently disabled):
    # config = yaml.load(open('.config.yml'), Loader=yaml.FullLoader)
    # entity = config['wandb_entity']
    # run = wandb.init(
    #     project="diagonal-net-loss-trend",
    #     entity=entity,
    #     name=f"comparison-d{d}-k{k}-delta{delta}",
    #     reinit=True
    # )

    for opt in optimizer_configs:
        # The step budget depends only on lr, so compute it once per optimizer.
        steps = _num_steps(opt['lr'])
        # eval_first = int(10 / opt['lr'])
        for n_train in n_trains:
            print(f"Running {opt['name']} with n_train={n_train}...")
            train_loss, test_loss = train_diagonal_net(
                d=d,
                k=k,
                delta=delta,
                n_train=n_train,
                lr=opt['lr'],
                optimizer_type=opt['optimizer_type'],
                # SGD configs have no betas, so these pass None for SGD.
                beta1=opt.get('beta1'),
                beta2=opt.get('beta2'),
                steps=steps,
                eval_first=0,
                eval_period=10,  # eval every 10 steps, probably change later
            )
            results[opt['name']]['train'].append(train_loss)
            results[opt['name']]['test'].append(test_loss)

    return results


if __name__ == "__main__":
    # Launch the sweep only when run as a script, never on import.
    main()
