from src.minimax_policy import *
from src.utils import float2str

if __name__ == "__main__":
    # Script entry point: fine-tune a policy with `train(..., finetune=True)`
    # and save the result under ../model/.
    #
    # Command line (see the usage message below):
    #   argv[1]  game_mode  -- NOTE(review): not parsed in this file; the name
    #                          `game_mode` is presumably provided by the star
    #                          import of src.minimax_policy -- confirm.
    #   argv[2]  model_id   -- integer id embedded in the output file name
    #   argv[3]  l_upper    -- float forwarded to train(l_upper=...)
    #   argv[4:] model files to load; either "f_star <model>" or a list of
    #            baseline model files.
    binary = True

    # Validate with an explicit exit instead of `assert`: assertions are
    # stripped under `python -O`, which would turn a bad invocation into a
    # bare IndexError on sys.argv[3] below.
    if len(sys.argv) <= 4:
        sys.exit("Usage: python train_overall.py game_mode model_id l_upper model1 model2 ...")

    l_upper = float(sys.argv[3])
    save_id = int(sys.argv[2])
    load_names = sys.argv[4:]
    fname = '../model/finetune_overall_%s_%d_%s.dict' % (game_mode, save_id, float2str(l_upper))
    print("Loading file from", load_names)
    print("Saving file to", fname)

    if load_names[0] == "f_star":
        # "f_star" mode: exactly one model file follows the keyword and its
        # 'policy' entry is the starting point for fine-tuning.
        if len(load_names) != 2:
            sys.exit("f_star must be followed by one and only one model name to finetune on")
        init_dict = torch.load(load_names[1])
        init_policy = init_dict['policy']
        overall = "f_star"
        num_checks = 2 if game_mode == "q20" else 5
    else:
        # Baseline mode: load the 'policy' entry of every listed model, start
        # from the middle one (lower-middle when the count is even) and pass
        # the full list to train via `overall`.
        baselines = [torch.load(load_name)['policy'] for load_name in load_names]
        init_policy = baselines[(len(baselines) - 1) // 2]
        overall = baselines
        num_checks = 3

    # Both modes share every other argument of the training call.
    reward, p, po, g, go, s, so = train(init_policy, None, None, None, M, N_it_overall,
                                        sample_size=L, sampler=None, binary=binary, finetune=True,
                                        l_upper=l_upper, fname=fname, num_samples=N_arms * N_GEN,
                                        num_checks=num_checks, overall=overall)

    # Reload the file at `fname`; presumably a check that train() wrote the
    # checkpoint there -- the loaded value is not used further in this script.
    minimax_dict = torch.load(fname)
