from src.minimax_policy import *
from src.utils import float2str
import sys


if __name__ == "__main__":
    # Fine-tune a saved policy/generator pair and write the result to a new
    # checkpoint. The names torch, train, M, L, N_arms, N_GEN and N_pulls are
    # expected to come from the star import of src.minimax_policy.
    usage = "Usage: python finetune.py game_mode model_id l_upper load_name"
    # Explicit check rather than ``assert`` so argument validation still runs
    # under ``python -O`` (asserts are stripped there).
    if len(sys.argv) != 5:
        sys.exit(usage)

    # Passed to train(); also selects the output file prefix below.
    binary = True
    # game_mode is documented as the first CLI argument. The original script
    # never read it and relied on the star import to define the name.
    # NOTE(review): confirm src.minimax_policy does not expect a differently
    # derived game_mode.
    game_mode = sys.argv[1]
    save_id = int(sys.argv[2])
    l_upper = float(sys.argv[3])
    load_name = sys.argv[4]

    # One template for both variants; the two names differ only in this word.
    variant = 'binary' if binary else 'regret'
    fname = f'../model/finetune_{variant}_{game_mode}_{save_id:d}_{float2str(l_upper)}.dict'
    print("Loading file from", load_name)
    print("Saving file to", fname)
    # NOTE(review): torch.load unpickles its input; only pass trusted files
    # here, since the path comes straight from the command line.
    minimax_dict = torch.load(load_name)
    reward, p, po, g, go, s, so = train(minimax_dict['policy'], None, minimax_dict['generator'], None, M, 500000,
                                        sample_size=L, sampler=None, binary=binary, finetune=True,
                                        l_upper=l_upper, fname=fname, num_samples=N_arms*N_GEN, num_checks=5,
                                        n_pulls=N_pulls)

    # train() receives fname=fname and presumably writes its checkpoint there.
    # Reloading it right away fails immediately if that file is missing or
    # unreadable.
    minimax_dict = torch.load(fname)
