from src.minimax_policy import *


def evaluate_model_sampling(policy, pre_theta, sample_size, it, l_upper=None, binary=False, approx=True, fix=False):
    """Simulate ``policy`` on a batch of instances and score its final pick.

    For each row of ``pre_theta`` the allocation returned by ``policy`` is
    simulated ``sample_size`` times: pull counts are drawn from a multinomial
    with ``N_pulls * N_budget`` trials, the number of +1 outcomes per
    coordinate from a binomial with success probability ``(theta + 1) / 2``,
    and the column of ``Z`` with the highest empirical score is compared with
    the column that is best under the true ``theta``.

    Args:
        policy: Callable applied to the detached ``theta``; its output is used
            as the multinomial probabilities (one row per instance).
        pre_theta: 2-D tensor with one instance per row. Treated as logits
            (``theta = 2 * sigmoid(pre_theta) - 1``) unless ``fix`` is true, in
            which case it is used as ``theta`` directly.
        sample_size: Number of simulated runs per instance.
        it: Iteration counter; only shown in the progress line.
        l_upper: Optional upper bound on ``log2(c)``, where ``c`` comes from
            ``c_fn`` or ``rho_optimal`` (presumably an instance-difficulty
            measure -- confirm against ``src.minimax_policy``). Instances above
            the bound receive a linear penalty scaled by ``1e4``.
        binary: If true, score runs by whether the best column was picked
            (scaled by ``R_lim``); otherwise by the true score of the picked
            column minus the best true score (always <= 0).
        approx: If true use ``c_fn(theta)``, else
            ``rho_optimal(theta, difficulty=True)`` to obtain ``c``.
        fix: Evaluation-only mode. Prints summary statistics and returns a
            NumPy array with the per-instance accuracy.

    Returns:
        When ``fix`` is true: NumPy array of per-instance accuracies.
        Otherwise the tuple ``(return1, g_loss, in_range, theta.data)`` where
        ``return1`` is the detached per-instance mean score, ``g_loss`` the
        per-instance differentiable surrogate (plus the barrier penalty when
        ``l_upper`` is given), and ``in_range`` a per-instance mask of
        ``log2(c) < l_upper``. Without ``l_upper`` every instance counts as in
        range.
    """
    num_theta = pre_theta.size(0)
    if fix:
        theta = pre_theta
    else:
        theta = torch.sigmoid(pre_theta) * 2 - 1
    # Log-probabilities of a +1 / -1 outcome under the logit parameterisation.
    log_theta = F.logsigmoid(pre_theta)
    log_theta_neg = F.logsigmoid(-pre_theta)
    if l_upper is not None:
        if approx:
            c, _ = c_fn(theta)
        else:
            _, c = rho_optimal(theta, difficulty=True)
        l = torch.log2(c)
        in_range = l < l_upper
        barrier_loss = (l - l_upper) * (l > l_upper).float() * 1e4
        # Disabled experiment, kept for reference:
        # random_walk = F.mse_loss(theta,
        #                          theta.data + torch.randn(theta.size()).float().cuda(),
        #                          reduction='none')
        # barrier_loss = barrier_loss + torch.sum(random_walk, dim=-1) * (1 if approx else 1e2)
    else:
        # No bound requested: every instance is in range. Previously this name
        # was left unbound and the final ``return`` raised UnboundLocalError.
        in_range = torch.ones(num_theta, device=theta.device) > 0

    # The policy sees a detached theta, so gradients only flow through the
    # log-likelihood term below.
    p = policy(theta.data).repeat(sample_size, 1)
    count = torch.distributions.multinomial.Multinomial(N_pulls * N_budget, p).sample()
    expand_theta = theta.repeat(sample_size, 1)
    expand_log_theta = log_theta.repeat(sample_size, 1)
    expand_log_theta_neg = log_theta_neg.repeat(sample_size, 1)
    sample = torch.distributions.binomial.Binomial(total_count=count,
                                                   probs=(expand_theta + 1) / 2).sample()
    sample_neg = count - sample
    # Log-likelihood of the observed outcomes given the pull counts.
    log_prob = torch.sum(sample * expand_log_theta + sample_neg * expand_log_theta_neg, dim=-1)
    log_sum = log_prob / N_pulls
    # Sum of +/-1 outcomes per coordinate.
    samples = sample * 2 - count

    # Avoid 0/0 for coordinates that were never pulled; their sum is already 0.
    count[count == 0] = 1
    samples = samples / count
    accumulated_theta = torch.mm(expand_theta, Z)
    accumulated_samples = torch.mm(samples, Z)

    max_ind = argmax_uniform(accumulated_samples, dim=-1)
    # squeeze(1) rather than squeeze() so a single-row batch keeps its batch axis.
    reward_max = torch.gather(accumulated_theta, 1, max_ind.unsqueeze(1)).squeeze(1) - \
                 torch.max(accumulated_theta, dim=-1)[0]
    acc_max = (max_ind == torch.argmax(accumulated_theta, dim=-1)).float() * R_lim

    if fix:
        print()
        print(torch.mean(reward_max).data.cpu().numpy(), torch.min(reward_max).data.cpu().numpy())
        print("Accuracy:", torch.mean(acc_max).data.cpu().numpy() / R_lim,
              torch.min(acc_max).data.cpu().numpy() / R_lim)
        return torch.mean(acc_max.view(sample_size, num_theta), dim=0).cpu().numpy() / R_lim

    # Progress line, overwritten in place via the leading carriage return.
    if binary:
        mean_score = torch.mean(acc_max).data.cpu().numpy() / R_lim
    else:
        mean_score = torch.mean(reward_max).data.cpu().numpy()
    if l_upper is None:
        sys.stdout.write('\r% 4d  % 9.8f' % (it, mean_score))
    else:
        per_run = acc_max / R_lim if binary else reward_max
        # Out-of-range instances are pushed up by 1000 so they never win the min.
        worst_in_range = torch.min(
            torch.mean(per_run.view(sample_size, num_theta), dim=0) + (1 - in_range.float()) * 1000)
        sys.stdout.write('\r% 4d  % 9.8f  % 9.8f  % 5d' %
                         (it,
                          mean_score,
                          worst_in_range.data.cpu().numpy(),
                          torch.sum(in_range).data.cpu().numpy()))
    sys.stdout.flush()

    if binary:
        avg_acc_max = torch.mean(acc_max.view(sample_size, num_theta), dim=0)
        return1 = avg_acc_max.data / R_lim
        # acc_max carries no gradient; the gradient comes from log_sum only.
        g_loss = torch.mean((acc_max * log_sum).view(sample_size, num_theta), dim=0)
    else:
        avg_reward_max = torch.mean(reward_max.view(sample_size, num_theta), dim=0)
        return1 = avg_reward_max.data
        # Direct gradient of the reward plus the detached reward weighted by
        # the log-likelihood of the sampled outcomes.
        g_loss = torch.mean((reward_max + reward_max.data * log_sum).view(sample_size, num_theta),
                            dim=0)

    if l_upper is not None:
        g_loss = barrier_loss + g_loss
    return return1, g_loss, in_range, theta.data


def evaluate_model_sha(policy, num_theta=1000 if game_mode == "1d" else 100, reduction=.25, n_rung=3, sample_size=100,
                       n_it=100, binary=False, l_upper=None, plot=False, approx=True):
    """Search for the instance on which ``policy`` scores worst.

    A population of candidate instances is optimised with Adam to minimise the
    surrogate loss returned by ``evaluate_model_sampling``. The search runs in
    ``n_rung`` rungs: after each rung only the ``reduction`` fraction of
    candidates with the lowest recent mean score is kept, and the number of
    optimisation steps per batch is multiplied by ``1 / reduction``.

    Args:
        policy: Allocation policy forwarded to ``evaluate_model_sampling``.
        num_theta: Batch size. The initial population is
            ``num_theta / reduction ** (n_rung - 1)`` candidates, so that the
            last rung works on roughly one batch.
        reduction: Fraction of candidates kept between rungs.
        n_rung: Number of rungs.
        sample_size: Simulated runs per candidate and step.
        n_it: Optimisation steps per batch in the first rung.
        binary: Forwarded to ``evaluate_model_sampling``.
        l_upper: Forwarded to ``evaluate_model_sampling``.
        plot: If true, also return the worst instance and all final scores.
        approx: Forwarded to ``evaluate_model_sampling``.

    Returns:
        ``worst_reward`` (lowest final-rung score), or the tuple
        ``(worst_reward, worst_theta, worst_rewards)`` when ``plot`` is true,
        where ``worst_theta`` is mapped to ``[-1, 1]`` and ``worst_rewards`` is
        the NumPy array of final-rung scores.
    """
    it = 0
    batch_num_theta = num_theta
    num_theta = int(num_theta / (reduction ** (n_rung - 1)))

    # Optimise in logit space so the mapped theta always stays inside (-1, 1).
    pre_theta = sample_theta(num_theta)
    pre_theta = torch.log(pre_theta) - torch.log(1 - pre_theta)
    pre_theta = nn.Parameter(pre_theta)

    lr = 1e-3 if game_mode == "1d" else 1e-2
    optim = Adam([pre_theta], lr=lr)

    # Indices (into pre_theta) of the candidates still alive.
    inds = torch.arange(num_theta).cuda()
    for i in range(n_rung):
        rewards = []
        n_batches = (num_theta - 1) // batch_num_theta + 1
        for j in range(n_batches):
            ind_set = inds[j * batch_num_theta:min((j + 1) * batch_num_theta, num_theta)]
            rs = []
            for _ in range(n_it):
                it += 1
                return1, g_loss, in_range, _ = evaluate_model_sampling(policy, pre_theta[ind_set], sample_size,
                                                                       it, l_upper=l_upper, binary=binary,
                                                                       approx=approx, fix=False)
                rs.append(return1.data.unsqueeze(-1))
                optim.zero_grad()
                g_loss = torch.sum(g_loss)
                g_loss.backward()
                optim.step()
            # Candidates that ended the batch out of range get a score of 1000
            # so they are never selected as the worst case.
            in_range = in_range.float()
            rewards.append(torch.cat(rs, dim=-1) * in_range.unsqueeze(-1) + 1000 * (1 - in_range.unsqueeze(-1)))
        # Score each candidate by its mean over the trailing window of steps.
        rewards = torch.cat(rewards, dim=0)[:, -min(n_it // 4, 100):]
        rewards = torch.mean(rewards, dim=-1)
        num_theta = int(num_theta * reduction)
        if i != n_rung - 1:
            inds = inds[torch.topk(rewards, num_theta, largest=False, sorted=False)[1]]
            n_it = int(n_it / reduction)
    del g_loss
    rewards = rewards.cpu().numpy()
    min_theta = np.argmin(rewards)
    worst_reward = rewards[min_theta]
    print()
    if plot:
        worst_theta = torch.sigmoid(pre_theta[inds[int(min_theta)]].data) * 2 - 1
        return worst_reward, worst_theta, rewards
    return worst_reward


def approximate_optimal(theta, visualize=False):
    """Return a normalised per-coordinate allocation derived from score gaps.

    Each column of ``Z`` is weighted by ``Z_sym_diff_ct`` (looked up for the
    best-scoring column) divided by its squared score gap to that best column.
    For every coordinate the largest such weight among the columns selected by
    ``Z_sym_diff`` is kept, and each row is scaled to sum to one.

    Args:
        theta: 2-D tensor, one instance per row.
        visualize: Unused; kept for interface compatibility.

    Returns:
        Tensor with one allocation per row of ``theta``.
    """
    scores = torch.mm(theta, Z)
    best_score, best_idx = torch.max(scores, dim=-1)

    # Squared gap to the best column, floored to keep the division finite.
    sq_gap = ((scores - best_score.unsqueeze(-1)) ** 2).clamp(min=1e-8)
    ratio = Z_sym_diff_ct[best_idx] / sq_gap

    # Contraction with an identity matrix over the last axis, as in the
    # original implementation.
    eye = torch.eye(N_dim).cuda()  # 50 * 50
    diff_mask = torch.tensordot(Z_sym_diff[best_idx], eye, dims=([-1], [-1]))

    weights, _ = torch.max(ratio.unsqueeze(-1) * diff_mask, dim=1)
    return weights / weights.sum(dim=-1, keepdim=True)


if __name__ == "__main__":
    # Validate the command line once, before any expensive work. Explicit
    # checks are used instead of ``assert`` because assertions are removed
    # under ``python -O``, which would let an unknown oracle name silently run
    # the "approx" branch.
    if len(sys.argv) != 3:
        sys.exit("Usage: python static_allocs.py game_mode oracle")
    oracle_name = sys.argv[2]
    if oracle_name not in ("rho", "approx"):
        sys.exit("oracle name should be either rho or approx")

    if oracle_name == "rho":
        oracle_label, oracle = "Rho*", rho_optimal
    else:
        oracle_label, oracle = "Approximate", approximate_optimal

    binary = False
    eval_approx = False
    for l_upper in eval_range:
        print("=" * 20)
        # Same seed for every bound so runs differ only in l_upper.
        torch.manual_seed(12345)
        print(l_upper)

        print(oracle_label)
        worst_reward, worst_theta, _ = \
            evaluate_model_sha(oracle, plot=True, l_upper=l_upper, binary=binary, approx=eval_approx)
        print(worst_reward)
        print(worst_theta.cpu())
        # Re-evaluate the worst instance found with a large fixed sample.
        evaluate_model_sampling(oracle, worst_theta.view(1, -1), 10000, 1, fix=True)

    # def run(optimal, h):
    #     probs = []
    #     for _ in range(10):
    #         Zs = Z.transpose(0, 1).repeat(10 if game_mode == '1d' else 1, 1) * 2 - 1
    #         Zs = Zs * (torch.bernoulli(h * torch.ones(Zs.size()).float().cuda()) * 2 - 1)
    #         probs.extend(
    #             evaluate_model_sampling(optimal, 1, 1, 1, fix_theta=Zs))
    #     print(np.mean(np.array(probs)))
    #     return np.mean(np.array(probs))
    #
    # optimal = rho_optimal if oracle_name == "rho" else approximate_optimal
    # acc = []
    # for h in [.5, .6, .7, .8, .9, 1]:
    #     print()
    #     print("=" * 20)
    #     print("Rescaled Z Average Case")
    #     print(h)
    #     acc.append(run(optimal, h))
    # print(acc)
    #
    # if game_mode == "q20":
    #     print()
    #     print("=" * 20)
    #     print("Dataset Average Case")
    #     torch.manual_seed(12345)
    #     QEnv = Q20Env()
    #     theta = torch.from_numpy(QEnv.get_probs()[0]).float().cuda()
    #     theta = theta * 2 - 1
    #     probs = []
    #     n = 100
    #     for i in range(theta.size(0) // n):
    #         probs.extend(evaluate_model_sampling(optimal, n, 100, 1, fix_theta=theta[i * n:i * n + n]))
    #     weight = np.array([p[1] for p in QEnv.posList])
    #     weight /= np.sum(weight)
    #     print(np.sum(probs * weight))
    #     print(np.mean(probs))
