from src.eval_helper import *


def binary_search(X, beta):
    """Noisy binary-search policy: query the dimension that best bisects the posterior.

    Assumptions (not checked here; TODO confirm against src.eval_helper):
    - Each row of ``X`` holds a per-dimension ``samples`` statistic in its
      first ``N_dim`` columns and per-dimension ``counts`` in the next
      ``N_dim`` columns.
    - ``Z`` is an ``N_dim x N_arms`` 0/1 matrix whose columns are the
      candidate arms.
    - Each observed answer is flipped with probability ``beta``.

    For every candidate arm the unnormalised posterior weight is
    ``(1 - beta) ** pos * beta ** neg``. Here ``pos``/``neg`` are built from
    ``counts +/- samples * z``, i.e. the evidence agreeing / disagreeing
    with the arm. The weights are computed in log space and normalised with
    a softmax. This is mathematically identical to multiplying them out and
    dividing by their sum, but it does not underflow to ``0 / 0`` (NaN) when
    many observations have accumulated.

    Args:
        X: batch of state rows; only ``X[:, :2 * N_dim]`` is read.
        beta: answer-noise level, expected to lie in ``[0, 1]``.

    Returns:
        ``to_one_hot(actions, N_dim) * 1e20``. For each row, ``actions`` is
        the dimension whose posterior mass is most evenly split between
        ``z = +1`` and ``z = -1``. Tie handling is delegated to
        ``argmax_uniform``. The large scale presumably makes the one-hot act
        as near-deterministic logits (review: confirm).
    """
    samples = X[:, :N_dim].unsqueeze(1)  # N * 1 * N_dim
    counts = X[:, N_dim:2 * N_dim].unsqueeze(1)  # N * 1 * N_dim
    # Map arm entries 0/1 -> -1/+1 (assuming Z is 0/1).
    Zs = Z.transpose(0, 1).unsqueeze(0) * 2 - 1  # 1 * N_arms * N_dim
    pos = torch.sum(counts + samples * Zs, dim=-1) / 2  # N * N_arms
    neg = torch.sum(counts - samples * Zs, dim=-1) / 2  # N * N_arms
    # log((1 - beta) ** pos * beta ** neg). xlogy(0, 0) == 0, so beta in
    # {0, 1} keeps the 0 ** 0 == 1 convention of the direct product.
    log_probs = torch.xlogy(pos, 1 - beta) + torch.xlogy(neg, beta)
    probs = torch.softmax(log_probs, dim=-1)  # N * N_arms, rows sum to 1
    # With Zs in {-1, +1}, sum_k probs_k * z_kd = P(z_d=+1) - P(z_d=-1).
    # Its magnitude is smallest on the most evenly split dimension.
    sums = -torch.abs(torch.sum(probs.unsqueeze(-1) * Zs, dim=1))  # N * N_dim
    actions = argmax_uniform(sums, dim=-1)
    return to_one_hot(actions, N_dim) * 1e20


if __name__ == "__main__":
    # Check the command line once, before any work. An explicit SystemExit
    # (unlike `assert`) is not stripped under `python -O`.
    # NOTE(review): `game_mode` is not parsed from sys.argv in this file; it
    # is presumably defined by src.eval_helper -- confirm.
    if len(sys.argv) != 2:
        raise SystemExit("Usage: python binary_search.py game_mode")

    # Evaluation settings that depend on neither l_upper nor beta.
    binary = False
    eval_approx = True
    num_theta = 100 if game_mode == '1d' else 10

    for l_upper in eval_range:
        print("=" * 20)
        print(l_upper)
        for beta in [.01, .03, .1, .3]:
            torch.manual_seed(12345)
            print("Beta:", beta)
            # The lambda looks up the module-level name `beta` each time it
            # is *called* (late binding), not when PT is built. The
            # average-case sweep below relies on this to re-evaluate the
            # last PT under different beta values.
            PT = Training(lambda X: binary_search(X, beta),
                          'None', 'None', 'None', 'None', 'None', l_upper + 1, 0,
                          binary, 0, eval_approx, num_theta * 16, 100).cuda()
            # Worst-case evaluations (currently disabled):
            # worst_case(PT, l_upper, eval_approx)
            # worst_case_sha(PT, l_upper, eval_approx, num_theta=num_theta, n_it=100)

    # NOTE(review): this sweep is outside the l_upper loop. It runs once, with
    # the PT and l_upper left over from the final iteration; only beta
    # changes, via the late-bound lambda above. Confirm this is intended.
    for beta in [.01, .03, .1, .2, .3, .4]:
        print("Beta:", beta)
        acc = [average_case_rescale(PT, h, l_upper)
               for h in [.5, .6, .7, .8, .9, 1]]
        print(acc)

    if game_mode != "1d":
        for beta in [.01, .03, .1, .2, .3, .4]:
            print("Beta:", beta)
            # '1d' is already excluded by the enclosing branch, so only
            # '2d25' selects the larger setting here.
            PT = Training(lambda X: binary_search(X, beta),
                          'None', 'None', 'None', 'None', 'None', l_upper + 1, 0,
                          binary, 0, eval_approx, 100 if game_mode == '2d25' else 10,
                          100).cuda()
            average_case_dataset(PT, l_upper)
