import argparse
import yaml
import wandb
import matplotlib.pyplot as plt

from train_diag_net import train_diagonal_net


def parse_args():
    """Parse the command-line options for the n_train sweep.

    All three options are required:
      --d      (int)   dimensionality d
      --k      (int)   sparsity k of w*
      --delta  (float) noise magnitude delta

    Returns:
        argparse.Namespace with attributes ``d``, ``k`` and ``delta``.
    """
    cli = argparse.ArgumentParser(
        description="Run diagonal-net optimizer comparison over varying n_train"
    )
    option_specs = (
        ('--d', int, 'Dimensionality d'),
        ('--k', int, 'The sparsity k of w*'),
        ('--delta', float, 'Noise magnitude delta'),
    )
    for flag, value_type, help_text in option_specs:
        cli.add_argument(flag, type=value_type, required=True, help=help_text)
    namespace = cli.parse_args()
    return namespace


def main():
    args = parse_args()
    d, k, delta = args.d, args.k, args.delta

    optimizer_configs = []
    
    ## AdamE: beta1=0.9, beta2 in [0.999], lr in [1e-2], exponent in [0.25, 0.5, 0.75, 1] 
    for beta2 in [0.999]:
        for lr in [1e-1]:
            for exponent in [0.01, 0.02, 0.05]:
            # for exponent in [0.1, 0.9, 1.1]:
            # for exponent in [0.25, 0.5, 0.75, 1]:
                name = f"AdamE_b1=0.9_b2={beta2}_lr={lr}"
                optimizer_configs.append({
                    'name': name,
                    'optimizer_type': 'AdamE',
                    'lr': lr,
                    'beta1': 0.9,
                    'beta2': beta2,
                    'exponent': exponent
                })

    # all: list(range(300, 500, 10)) + list(range(500, 1050, 50))
    # e = 0.001: lr = 1e-2, steps = 1e5  300-500, 500-1050 (with test)
    # e = 0.01: lr = 1e-2, steps = 1e5  300-400, 500-1050 (with test)
    # e = 0.1: lr = 1e-2, steps = 1e6  
    # e = 0.5: lr = 1e-2, steps = 5e5  300-500, 500-1050 (with test)
    # e = 0.9: lr = 1e-3, steps = 1e6  300-500 (with test) 
    # sgd: lr = 1e-2, steps = 1e5  50-500, 500-1050 (in loss-trend)

    # n_trains = list(range(300, 380, 10)) + list(range(430, 500, 10))
    # n_trains = list(range(380, 430, 10))
    n_trains = list(range(300, 500, 10))
    results = {}
    results = {opt['name']: {'train': [], 'test': []} for opt in optimizer_configs}

    for opt in optimizer_configs:
        for n_train in n_trains:
            print(f"Running {opt['name']} with n_train={n_train}...")
            # steps = int(1000 / (opt['lr'] ** 2))
            # eval_first = int(10 / opt['lr'])
            steps = int(1e6)
            train_loss, test_loss = train_diagonal_net(
                d=d,
                k=k,
                delta=delta,
                n_train=n_train,
                lr=opt['lr'],
                optimizer_type=opt['optimizer_type'],
                beta1=opt['beta1'],
                beta2=opt['beta2'],
                steps=steps,
                eval_first=0,
                eval_period=100, ## eval every 10 steps, probably change later
                exponent=opt['exponent']
            )
            results[opt['name']]['train'].append(train_loss)
            results[opt['name']]['test'].append(test_loss)


if __name__ == '__main__':
    main()
