from src.eval_helper import *


def uncertainty_sampling(X, *, n_dim=None, arm_matrix=None):
    """Flag the dimensions on which the top-scoring arms disagree.

    For every row of ``X``:

    1. The first ``n_dim`` columns are divided element-wise by the next
       ``n_dim`` columns, clamped to at least 1 (so a zero count does not
       divide by zero).
    2. The resulting per-dimension values are projected onto the arms with
       ``arm_matrix`` to give one score per arm.
    3. A set of arms is selected:
       - if exactly one arm has the highest score, that arm plus every arm
         tied for the second-highest score;
       - if several arms tie for the highest score, exactly those arms.
    4. A dimension is flagged when the selected arms' column sum for it is
       non-zero but differs from the number of selected arms. Assuming
       ``arm_matrix`` is a 0/1 incidence matrix (TODO confirm), this means
       "covered by some, but not all, of the selected arms".

    Args:
        X: 2-D tensor with at least ``2 * n_dim`` columns.
        n_dim: Number of dimensions. Defaults to the module-level ``N_dim``.
        arm_matrix: ``(n_dim, n_arms)`` tensor mapping dimensions to arms.
            Defaults to the module-level ``Z``.

    Returns:
        ``(N, n_dim)`` float tensor holding ``1e20`` for flagged dimensions
        and ``0`` elsewhere. (1e20 presumably acts as an "effectively
        infinite" score for the caller; confirm.)
    """
    # Fall back to the globals supplied by the star import only when the
    # caller does not pass them explicitly (keeps the original call form).
    if n_dim is None:
        n_dim = N_dim
    if arm_matrix is None:
        arm_matrix = Z

    sums = X[:, :n_dim]  # N * n_dim
    counts = X[:, n_dim:2 * n_dim]  # N * n_dim
    means = sums / torch.clamp(counts, 1)  # N * n_dim
    scores = torch.mm(means, arm_matrix)  # N * n_arms

    best = (scores == scores.max(dim=-1, keepdim=True)[0]).float()  # N * n_arms
    # Keep the best-arm indicator only on rows where the best arm is unique.
    unique_best = (best.sum(dim=-1, keepdim=True) == 1).float() * best

    # Remove the unique best arm so the next max finds the runner-up(s).
    # -inf (rather than a finite offset) guarantees the masked arm can never
    # win again, however large the score gap is.
    masked = scores.masked_fill(unique_best.bool(), float('-inf'))
    selected = (masked == masked.max(dim=-1, keepdim=True)[0]).float() + unique_best

    # Per dimension: how many selected arms include it.
    per_dim = torch.mm(selected, arm_matrix.transpose(0, 1))  # N * n_dim
    n_selected = selected.sum(dim=-1, keepdim=True)  # N * 1
    disagree = ((per_dim != 0) & (per_dim != n_selected)).float()
    return disagree * 1e20


if __name__ == "__main__":
    # Validate the command line once, before any work or output. An `assert`
    # is stripped under `python -O`; inside the loop it was also reached only
    # after printing, and never at all if `eval_range` is empty.
    # NOTE(review): `game_mode` is not defined in this file; it presumably
    # comes from `src.eval_helper` (star import), likely parsed from
    # sys.argv[1] — confirm.
    if len(sys.argv) != 2:
        sys.exit("Usage: python uncertainty_sampling.py game_mode")

    # Settings that do not depend on l_upper; computed once.
    binary = True
    eval_approx = True
    num_theta = 1000 if game_mode == '1d' else 100

    for l_upper in eval_range:
        print("=" * 20)
        # Re-seed at the start of every iteration so each l_upper run starts
        # from the same RNG state.
        torch.manual_seed(12345)
        print(l_upper)
        PT = Training(uncertainty_sampling,
                      'None', 'None', 'None', 'None', 'None', l_upper, 0,
                      binary, 0, eval_approx, num_theta * 16, 100).cuda()
        # worst_case(PT, l_upper, eval_approx)
        worst_case_sha(PT, l_upper, eval_approx, num_theta=num_theta, n_it=100)

    # acc = []
    # for h in [.5, .6, .7, .8, .9, 1]:
    #     acc.append(average_case_rescale(PT, h, l_upper))
    # print(acc)
    #
    # if game_mode != "1d":
    #     average_case_dataset(PT, l_upper)
