# %%
import csv
import json
import os

import numpy as np

from solver import sagawa_et_al_solve, EXP3P_solve, TINF_solve
from synthetic import synthetic_DRO_instance
from adult import adult_DRO_instance, adult_hinge_DRO_instance

# %%
# Solver entry points, keyed by the algorithm name given on the command line.
ALGS = dict(
    Sagawa_et_al=sagawa_et_al_solve,
    EXP3P=EXP3P_solve,
    TINF=TINF_solve,
)
# Reference objective values subtracted from the objective history when the
# adult ('default') and adult_hinge ('hinge') convergence data are written.
OPT = dict(default=0.40811395237370146, hinge=0.4449120391895968)
MINIBATCH = 10   # minibatch size passed to every solver
T = 2_000_000    # horizon passed to every solver; also enters eta_q and beta
D = 10           # passed to the instance constructors and scales eta_t
# Tuned step-size constants (Ct, Cq), keyed by
# '<instance name> <algname> mbatch=<MINIBATCH> T=<T>[ m=<m>]'.
with open('best_params.json') as f:
    BEST_PARAMS = json.load(f)

# %%
def run_adult_DRO(algname):
    """Run ``algname`` on the adult DRO instance and save its convergence data.

    The solver is configured with the tuned constants ``Ct``/``Cq`` looked up in
    ``BEST_PARAMS`` for this instance, algorithm, ``MINIBATCH`` and ``T``. Each
    ``(t, obj)`` pair of the returned objective history is written to
    ``figdata/<instance name>_<algname>.csv`` as the row
    ``(t, obj - OPT['default'])``.

    Args:
        algname: key of ``ALGS`` ('Sagawa_et_al', 'EXP3P' or 'TINF').

    Raises:
        KeyError: if ``algname`` is not in ``ALGS`` or ``BEST_PARAMS`` has no
            entry for this configuration.
    """
    print(f'running {algname} on adult DRO instance...')
    alg = ALGS[algname]
    opt = OPT['default']
    droinstance = adult_DRO_instance(D)
    theta0 = np.zeros(droinstance.n)
    m = droinstance.m
    params = BEST_PARAMS[f'{droinstance.name} {algname} mbatch={MINIBATCH} T={T}']
    Ct = float(params['Ct'])
    Cq = float(params['Cq'])
    # Create the output directory before the (long) solver run, so results are
    # not lost to a FileNotFoundError after all iterations have finished.
    os.makedirs('figdata', exist_ok=True)
    _, objhist = alg(droinstance, T, theta0,
                     eta_t=lambda t: Ct * D * np.sqrt(1 / t),  # proportional to 1/sqrt(t)
                     eta_q=Cq * np.sqrt(np.log(m) / (m * T)),
                     # NOTE(review): 0.1 presumably a confidence level delta — confirm in solver.
                     beta=np.sqrt(np.log(m / 0.1) / (m * T)),
                     minibatch=MINIBATCH,
    )
    # newline='' is required by the csv module (otherwise blank rows on Windows).
    with open(f'figdata/{droinstance.name}_{algname}.csv', 'w', newline='') as f:
        # `t` (presumably the iteration index) avoids shadowing the module-level T.
        csv.writer(f).writerows((t, obj - opt) for t, obj in objhist)

# %%
def run_adult_hinge_DRO(algname):
    """Run ``algname`` on the adult hinge-loss DRO instance and save its convergence data.

    The solver is configured with the tuned constants ``Ct``/``Cq`` looked up in
    ``BEST_PARAMS`` for this instance, algorithm, ``MINIBATCH`` and ``T``. Each
    ``(t, obj)`` pair of the returned objective history is written to
    ``figdata/<instance name>_m6_<algname>.csv`` as the row
    ``(t, obj - OPT['hinge'])``.

    Args:
        algname: key of ``ALGS`` ('Sagawa_et_al', 'EXP3P' or 'TINF').

    Raises:
        KeyError: if ``algname`` is not in ``ALGS`` or ``BEST_PARAMS`` has no
            entry for this configuration.
    """
    print(f'running {algname} on adult_hinge DRO instance...')
    alg = ALGS[algname]
    opt = OPT['hinge']
    droinstance = adult_hinge_DRO_instance(D)
    theta0 = np.zeros(droinstance.n)
    m = droinstance.m
    params = BEST_PARAMS[f'{droinstance.name} {algname} mbatch={MINIBATCH} T={T}']
    Ct = float(params['Ct'])
    Cq = float(params['Cq'])
    # Create the output directory before the (long) solver run, so results are
    # not lost to a FileNotFoundError after all iterations have finished.
    os.makedirs('figdata', exist_ok=True)
    _, objhist = alg(droinstance, T, theta0,
                     eta_t=lambda t: Ct * D * np.sqrt(1 / t),  # proportional to 1/sqrt(t)
                     eta_q=Cq * np.sqrt(np.log(m) / (m * T)),
                     # NOTE(review): 0.1 presumably a confidence level delta — confirm in solver.
                     beta=np.sqrt(np.log(m / 0.1) / (m * T)),
                     minibatch=MINIBATCH,
    )
    # NOTE(review): the '_m6_' filename tag is hard-coded rather than derived
    # from m — presumably this instance has m == 6; confirm.
    # newline='' is required by the csv module (otherwise blank rows on Windows).
    with open(f'figdata/{droinstance.name}_m6_{algname}.csv', 'w', newline='') as f:
        # `t` (presumably the iteration index) avoids shadowing the module-level T.
        csv.writer(f).writerows((t, obj - opt) for t, obj in objhist)

# %%
def run_synthetic_DRO(algname, m):
    """Run ``algname`` on a synthetic DRO instance and save its raw objective history.

    Builds ``synthetic_DRO_instance(m, 500, D)``. The solver is configured with
    the tuned ``Ct``/``Cq`` from ``BEST_PARAMS``, keyed additionally by ``m``.
    The objective history is written unchanged to
    ``figdata/<instance name>_m<m>_<algname>.csv``. Unlike the adult runs, no
    reference optimum is subtracted.

    Args:
        algname: key of ``ALGS`` ('Sagawa_et_al', 'EXP3P' or 'TINF').
        m: the ``m`` parameter of the synthetic instance; it also enters the
            ``eta_q`` and ``beta`` formulas and the ``BEST_PARAMS`` key.

    Raises:
        KeyError: if ``algname`` is not in ``ALGS`` or ``BEST_PARAMS`` has no
            entry for this configuration.
    """
    print(f'running {algname} on synthetic DRO instance of m={m}...')
    alg = ALGS[algname]
    n = 500
    droinstance = synthetic_DRO_instance(m, n, D)
    theta0 = np.zeros(droinstance.n)
    params = BEST_PARAMS[f'{droinstance.name} {algname} mbatch={MINIBATCH} T={T} m={m}']
    Ct = float(params['Ct'])
    Cq = float(params['Cq'])
    # Create the output directory before the (long) solver run, so results are
    # not lost to a FileNotFoundError after all iterations have finished.
    os.makedirs('figdata', exist_ok=True)
    _, objhist = alg(droinstance, T, theta0,
                     eta_t=lambda t: Ct * D * np.sqrt(1 / t),  # proportional to 1/sqrt(t)
                     eta_q=Cq * np.sqrt(np.log(m) / (m * T)),
                     # NOTE(review): 0.1 presumably a confidence level delta — confirm in solver.
                     beta=np.sqrt(np.log(m / 0.1) / (m * T)),
                     minibatch=MINIBATCH,
    )
    # newline='' is required by the csv module (otherwise blank rows on Windows).
    with open(f'figdata/{droinstance.name}_m{m}_{algname}.csv', 'w', newline='') as f:
        csv.writer(f).writerows(objhist)

# %%
if __name__ == '__main__':
    # Command-line entry point: `<script> INSTANCE ALGNAME [--m M]`.
    import argparse
    parser = argparse.ArgumentParser(description='reproduce figdata')
    parser.add_argument('instance',
                        help="DRO instance name: 'adult' or 'adult_hinge'; "
                             "any other value runs the synthetic instance")
    # choices gives an immediate usage error instead of a KeyError traceback.
    parser.add_argument('algname', choices=list(ALGS), help='Algorithm name')
    # No nargs='?' here: with it (and no const), a bare `--m` parsed to None
    # and crashed the synthetic run; argparse now rejects it with a usage error.
    parser.add_argument('--m', dest='m', type=int, default=50,
                        help='m for the synthetic instance (default: 50)')
    args = parser.parse_args()
    algname = args.algname
    if args.instance == 'adult':
        run_adult_DRO(algname)
    elif args.instance == 'adult_hinge':
        run_adult_hinge_DRO(algname)
    else:
        run_synthetic_DRO(algname, args.m)


