from src.minimax_policy import *


def _project_onto_alternatives(theta, Z_diff, lambda_pre, n_steps):
    """Approximately project every theta onto each arm's tie hyperplane.

    Repeats ``n_steps`` times: (a) the ``lambda_pre``-weighted projection of
    the current point onto the hyperplane ``<x, Z_diff> = 0`` and (b) a clamp
    of the result into the box ``[-1, 1]``. No autograd graph is recorded.

    Args:
        theta: detached parameters, N * N_dim.
        Z_diff: best-arm minus arm differences, N * N_arms * N_dim.
        lambda_pre: per-coordinate weights, N * N_dim.
        n_steps: number of project-then-clamp rounds.

    Returns:
        Tensor of projected points, N * N_arms * N_dim (an N * 1 * N_dim view
        of ``theta`` when ``n_steps`` is 0).
    """
    with torch.no_grad():
        # Both quantities depend only on Z_diff and lambda_pre, so they are
        # computed once instead of on every projection step.
        scaled_diff = Z_diff / lambda_pre.unsqueeze(1)
        denom = torch.clamp(
            torch.sum((Z_diff ** 2) / lambda_pre.unsqueeze(1), dim=-1).unsqueeze(-1),
            1e-20)  # guards the 0 / 0 case where an arm equals the best arm
        theta_z = theta.unsqueeze(1)  # N * 1 * N_dim
        for _ in range(n_steps):
            inner = torch.sum(theta_z * Z_diff, dim=-1).unsqueeze(-1)
            theta_z = theta_z - inner / denom * scaled_diff
            theta_z.clamp_(-1, 1)
    return theta_z


def constrained_rho_optimal(theta, n_it=1000, difficulty=False):
    """Optimise a per-theta weight vector ``lambda`` with Frank-Wolfe-style steps.

    For every row of ``theta`` the best arm is ``argmax(theta @ Z)``. For each
    arm, ``theta`` is moved (see ``_project_onto_alternatives``) towards the
    hyperplane on which that arm ties with the best one, inside the box
    ``[-1, 1]``, and the ``lambda``-weighted squared distance to that point is
    measured. Each iteration takes the gradient, with respect to ``lambda``,
    of ``sum_rows max_arms 1 / distance`` and moves ``lambda`` towards the
    one-hot vertex at the gradient's argmin with step ``2 / (it + 3)``.

    Relies on ``Z`` (N_dim * N_arms according to the shape comments),
    ``N_dim`` and ``to_one_hot`` being provided by the module's imports.

    Args:
        theta: parameter tensor, N * N_dim. It is always detached for the
            inner optimisation, so the caller's autograd graph and
            ``theta.grad`` are never touched by the loop.
        n_it: number of outer iterations.
        difficulty: when true, also return ``1 / min_arms distance``
            evaluated at the final ``lambda``; this value is computed from the
            non-detached ``theta`` and is therefore differentiable w.r.t. it.

    Returns:
        ``(lambda, 1 / min_distance)`` if ``difficulty`` else ``lambda``
        (detached), where ``lambda`` is N * N_dim and ``1 / min_distance``
        has one entry per row of ``theta``.
    """
    theta_wgrad = theta
    theta = theta.detach()

    # Uniform start, created on theta's device rather than a hard-coded GPU.
    lambda_pre = torch.ones(theta.size(), dtype=torch.float, device=theta.device) / N_dim  # N * N_dim
    i_star = torch.argmax(torch.mm(theta, Z), dim=-1)  # N
    Zs = Z.transpose(0, 1)  # N_arms * N_dim
    Z_star = Zs[i_star, :]  # N * N_dim
    Z_diff = Z_star.unsqueeze(1) - Zs.unsqueeze(0)  # N * N_arms * N_dim

    for it in range(n_it):
        theta_z = _project_onto_alternatives(theta, Z_diff, lambda_pre, 10)

        # Fresh leaf so that .grad holds only this iteration's gradient.
        lambda_pre = lambda_pre.detach().requires_grad_(True)
        dist = torch.sum(((theta.unsqueeze(1) - theta_z) ** 2) * lambda_pre.unsqueeze(1), dim=-1)  # N * N_arms
        # Exactly-zero distances get a huge value so that 1 / dist is ~0 and
        # they never win the max below.
        dist = dist + 1e20 * (dist == 0).float()
        objective = torch.sum(torch.max(1 / dist, dim=-1)[0])
        objective.backward()

        vertex = to_one_hot(torch.argmin(lambda_pre.grad, dim=-1), theta.size(-1))
        gamma = 2. / (it + 3)
        lambda_pre = lambda_pre.detach() * (1 - gamma) + gamma * vertex

    if not difficulty:
        return lambda_pre.detach()

    # Longer projection for the reported value; gradients flow only through
    # theta_wgrad, the projected points are treated as constants.
    theta_z = _project_onto_alternatives(theta, Z_diff, lambda_pre, 100)
    dist = torch.sum(((theta_wgrad.unsqueeze(1) - theta_z) ** 2) * lambda_pre.unsqueeze(1),
                     dim=-1)  # N * N_arms
    dist = dist + 1e20 * (dist == 0).float()
    dist = torch.min(dist, dim=-1)[0]
    return lambda_pre, 1 / dist


if __name__ == "__main__":
    # Ad-hoc driver: prints the difficulty of a fixed set of thetas, then
    # evaluates a saved policy on them. Requires CUDA and the checkpoint file
    # referenced below.

    # Use the arm vectors themselves as thetas, rescaled entry-wise by
    # 2 * z - 1 (presumably mapping {0, 1} entries to {-1, +1} -- confirm
    # against how Z is built).
    theta = Z.transpose(0, 1) * 2 - 1
    print(theta)
    print(constrained_rho_optimal(theta, difficulty=True)[1])
    l_upper = 7
    print("=" * 20)
    torch.manual_seed(12345)
    binary = True
    eval_approx = False
    fname = "../model/finetune_overall_1d_1_7.dict"
    print(fname, l_upper)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    minimax_dict = torch.load(fname)
    num_theta = 1
    PT = Training(minimax_dict['policy'],
                  'None', 'None', 'None', 'None', 'None', l_upper, 0,
                  binary, 0, eval_approx, num_theta * 16, 100).cuda()
    print(evaluate_model_sampling(PT, 10000, 1, fix_theta=theta))
