from src.minimax_policy import *
from src.jester_loader import get_z


def worst_case(PT, l_upper, eval_approx, n_restarts=1):
    """Search for the lowest-reward theta for ``PT`` by sampling, then re-evaluate it.

    ``evaluate_model_sampling`` is run ``n_restarts`` times and the run with the
    smallest reward is kept. Its reward and theta are printed, then ``PT`` is
    evaluated once more with that theta fixed (reshaped to a single row) and
    ``approx=False``.

    Args:
        PT: model/policy under test; forwarded unchanged to the evaluation helpers.
        l_upper: forwarded as ``l_upper`` to every ``evaluate_model_sampling`` call.
        eval_approx: ``approx`` flag for the search runs only; the final
            re-evaluation always uses ``approx=False``.
        n_restarts: number of independent search runs (default 1, which matches
            the previous hard-coded behaviour).

    Returns:
        tuple: ``(worst_reward, worst_theta)`` of the selected run.

    Raises:
        RuntimeError: if no run produced a reward below ``+inf`` (or the selected
            run returned no theta), so there is nothing to re-evaluate.
    """
    worst_reward = float('inf')
    worst_theta = None
    for _ in range(n_restarts):
        reward, theta, _ = \
            evaluate_model_sampling(PT, 100, 2000, plot=True, l_upper=l_upper, approx=eval_approx)
        if reward < worst_reward:
            worst_reward, worst_theta = reward, theta
    print(worst_reward)
    print(worst_theta)

    # Previously this crashed with an opaque AttributeError on ``None.view``.
    if worst_theta is None:
        raise RuntimeError("worst_case: no worst-case theta was found; nothing to re-evaluate")

    evaluate_model_sampling(PT, 10000, 1, plot=True, l_upper=l_upper, fix_theta=worst_theta.view(1, -1),
                            approx=False)
    return worst_reward, worst_theta


def worst_case_sha(PT, l_upper, eval_approx, num_theta=1000, reduction=.25, num_rung=3, n_sample=100, n_it=100,
                   n_restarts=1):
    """Search for the lowest-reward theta for ``PT`` via ``evaluate_model_sha``, then re-evaluate it.

    ``evaluate_model_sha`` is run ``n_restarts`` times and the run with the
    smallest reward is kept. Its reward and theta are printed, then ``PT`` is
    evaluated with ``evaluate_model_sampling`` using that theta fixed (reshaped
    to a single row) and ``approx=False``.

    Args:
        PT: model/policy under test; forwarded unchanged to the evaluation helpers.
        l_upper: used only by the final ``evaluate_model_sampling`` call; it is not
            passed to ``evaluate_model_sha``.
        eval_approx: ``approx`` flag for the search runs only.
        num_theta, reduction, num_rung, n_sample, n_it: forwarded positionally to
            ``evaluate_model_sha`` (their meaning is defined by that helper).
        n_restarts: number of independent search runs (default 1, which matches
            the previous hard-coded behaviour).

    Returns:
        tuple: ``(worst_reward, worst_theta)`` of the selected run.

    Raises:
        RuntimeError: if no run produced a reward below ``+inf`` (or the selected
            run returned no theta), so there is nothing to re-evaluate.
    """
    worst_reward = float('inf')
    worst_theta = None
    for _ in range(n_restarts):
        reward, theta, _ = \
            evaluate_model_sha(PT, num_theta, reduction, num_rung, n_sample, n_it, plot=True,
                               approx=eval_approx)
        if reward < worst_reward:
            worst_reward, worst_theta = reward, theta
    print(worst_reward)
    print(worst_theta)

    # Previously this crashed with an opaque AttributeError on ``None.view``.
    if worst_theta is None:
        raise RuntimeError("worst_case_sha: no worst-case theta was found; nothing to re-evaluate")

    evaluate_model_sampling(PT, 10000, 1, plot=True, l_upper=l_upper, fix_theta=worst_theta.view(1, -1),
                            approx=False)
    return worst_reward, worst_theta


def _evaluate_in_chunks(PT, theta, l_upper, chunk_size=100):
    """Evaluate ``PT`` with every row of ``theta`` as ``fix_theta``, ``chunk_size`` rows per call.

    The last slice may be shorter than ``chunk_size`` so that no row is dropped.

    Returns:
        list: concatenation of the values returned by each
        ``evaluate_model_sampling`` call (presumably one value per row -- TODO confirm).
    """
    probs = []
    for start in range(0, theta.size(0), chunk_size):
        # The leading ``100, 1`` arguments are kept exactly as in the original calls;
        # their meaning is defined by evaluate_model_sampling.
        probs.extend(evaluate_model_sampling(PT, 100, 1, plot=True, l_upper=l_upper, approx=False,
                                             fix_theta=theta[start:start + chunk_size]))
    return probs


def average_case_dataset(PT, l_upper):
    """Print the average-case evaluation of ``PT`` over the real dataset for ``game_mode``.

    The RNG is seeded with 12345 first.

    - ``"q20"``: thetas come from ``Q20Env().get_probs()[0]`` mapped via ``2x - 1``.
      Two values are printed: the mean weighted by ``QEnv.posList`` weights
      (normalised to sum to 1), then the unweighted mean.
    - ``"jester"``: thetas come from ``get_z(True)`` divided by 10. The unweighted
      mean is printed.
    - any other mode: only the header is printed.

    Every row of the dataset is evaluated. Previously, the trailing partial chunk of
    fewer than 100 rows was skipped.
    """
    print()
    print("=" * 20)
    print("Dataset Average Case")
    torch.manual_seed(12345)
    if game_mode == "q20":
        q_env = Q20Env()
        theta = torch.from_numpy(q_env.get_probs()[0]).float().cuda()
        # Affine map 2x - 1 (sends [0, 1] onto [-1, 1] if these are probabilities).
        theta = theta * 2 - 1
        probs = np.asarray(_evaluate_in_chunks(PT, theta, l_upper))
        # dtype=float: in-place true division on an integer array raises a NumPy casting error.
        weight = np.array([p[1] for p in q_env.posList], dtype=float)
        weight /= np.sum(weight)
        print(np.sum(probs * weight))
        print(np.mean(probs))
    elif game_mode == "jester":
        _, theta = get_z(True)
        # Scale by 1/10 (presumably Jester ratings in [-10, 10] -> [-1, 1]; confirm).
        theta = torch.from_numpy(theta).float().cuda() / 10
        probs = _evaluate_in_chunks(PT, theta, l_upper)
        print(np.mean(probs))


def average_case_rescale(PT, h, l_upper):
    """Print and return the mean evaluation of ``PT`` on randomly sign-flipped copies of ``Z``.

    The base thetas are ``Z`` transposed and mapped via ``2x - 1``. In ``'1d'`` game
    mode they are also tiled 10x along dim 0. The RNG is seeded with 12345. In each
    of 10 rounds, every entry is multiplied by an independent sign: ``+1`` with
    probability ``h`` and ``-1`` otherwise. ``PT`` is then evaluated with the
    result as ``fix_theta``.

    Args:
        PT: model/policy under test; forwarded to ``evaluate_model_sampling``.
        h: Bernoulli probability of keeping each entry's sign (printed as well).
        l_upper: forwarded to ``evaluate_model_sampling``.

    Returns:
        The mean of all collected evaluation values (also printed).
    """
    print()
    print("=" * 20)
    print("Rescaled Z Average Case")
    print(h)
    torch.manual_seed(12345)
    # Neither tensor depends on the round and neither consumes RNG, so build them once.
    base = Z.transpose(0, 1).repeat(10 if game_mode == '1d' else 1, 1) * 2 - 1
    keep_prob = h * torch.ones_like(base, dtype=torch.float)
    probs = []
    for _ in range(10):
        signs = torch.bernoulli(keep_prob) * 2 - 1
        probs.extend(evaluate_model_sampling(PT, 1, 1, plot=True, l_upper=l_upper, approx=False,
                                             fix_theta=base * signs))
    mean_prob = np.mean(np.array(probs))
    print(mean_prob)
    return mean_prob
