import numpy as np
from torch import optim
import torch 

from baselines import sample_robust

def get_gradients(opt, rs, x):
    """Return an array shaped like ``rs`` that holds ``x`` on every tight row.

    A row ``r`` of ``rs`` is treated as tight when ``np.dot(x, r)`` is within
    ``1e-5`` of ``opt``. Tight rows receive a copy of ``x``; every other row
    stays zero.
    """
    out = np.zeros_like(rs)
    # One flag per row, evaluated with the same per-row dot product and
    # tolerance so the selection is decided row by row.
    tight = np.array(
        [np.abs(opt - np.dot(x, row)) < 1e-5 for row in rs],
        dtype=bool,
    )
    out[tight, :] = x
    return out


def update_rs(rs, grad, generator, lr, t, max_iter=None, proj_lr=0.01):
    """Take one gradient step on ``rs``, then raise low ``log_prob`` entries.

    The step ``rs - grad * lr`` is computed out of place, so the caller's
    ``rs`` is never modified. Afterwards Adam repeatedly maximises
    ``generator.log_prob`` on the entries whose value is below ``t`` until
    none remain below the threshold (or ``max_iter`` is reached).

    Args:
        rs: Current points (NumPy array or torch tensor).
        grad: Step direction, broadcast-compatible with ``rs``.
        generator: Object exposing a differentiable ``log_prob(tensor)``.
            NOTE(review): assumed to return one value per row of ``rs`` --
            confirm against the generator implementation.
        lr: Step size of the initial gradient step.
        t: Log-probability threshold every entry should reach.
        max_iter: Optional cap on the number of Adam steps. ``None`` (the
            default) keeps iterating until every entry reaches ``t``, which
            may never terminate if the threshold is unreachable.
        proj_lr: Learning rate of the Adam optimiser used for the projection.

    Returns:
        A detached torch tensor holding the updated points.
    """
    start = torch.as_tensor(rs)
    if not start.is_floating_point():
        # Autograd needs a floating-point leaf tensor.
        start = start.to(torch.get_default_dtype())
    step = torch.as_tensor(grad, dtype=start.dtype, device=start.device)

    # Subtraction allocates a fresh tensor, leaving the caller's data intact.
    cur_rs = (start - step * lr).detach().requires_grad_(True)
    optimizer = optim.Adam([cur_rs], lr=proj_lr)

    n_steps = 0
    below = generator.log_prob(cur_rs) < t
    while below.any():
        if max_iter is not None and n_steps >= max_iter:
            break
        n_steps += 1

        optimizer.zero_grad()
        # Only entries currently below the threshold contribute to the loss.
        loss = -torch.mean(generator.log_prob(cur_rs) * below)
        loss.backward()
        optimizer.step()

        below = generator.log_prob(cur_rs) < t

    return cur_rs.detach()


def solve_maximal(train_data, test_data, generator, epochs=10, lr=1e-3, t=-17):
    """Alternate between solving on ``rs`` and updating ``rs``.

    Each epoch calls ``sample_robust`` on the current ``rs``, reads the
    solution and objective value from the result, builds a step direction
    with ``get_gradients`` and moves ``rs`` with ``update_rs``.

    Args:
        train_data: Initial points; copied with ``.clone()`` so the input is
            left untouched.
        test_data: Not used by this function; kept for interface
            compatibility.
        generator: Passed through to ``update_rs``.
        epochs: Number of solve/update rounds; must be at least 1.
        lr: Step size passed to ``update_rs``.
        t: Log-probability threshold passed to ``update_rs``.

    Returns:
        ``(x_cur, rs)``. ``x_cur`` is the solution from the last solve and
        ``rs`` is the set of points after the last update, so ``rs`` is one
        update ahead of the points ``x_cur`` was solved against.

    Raises:
        ValueError: If ``epochs`` is smaller than 1 (no solution would exist
            to return).
    """
    if epochs < 1:
        raise ValueError(f"epochs must be at least 1, got {epochs}")

    rs = train_data.clone()

    for _ in range(epochs):
        sample_solver = sample_robust(rs)

        # NOTE(review): the last entry of X is dropped -- presumably an
        # auxiliary variable of the robust model; confirm in sample_robust.
        x_cur = sample_solver.X[:-1]
        opt = sample_solver.objVal

        grad = get_gradients(opt, rs, x_cur)
        rs = update_rs(rs, grad, generator, lr, t)

    return x_cur, rs
