from sklearn.neighbors import NearestNeighbors

import numpy as np

from rsome import dro

from rsome import square
from rsome import ro
import rsome as rso
from rsome import msk_solver as msk
from rsome import grb_solver as grb

from rsome import E

import torch

def setup_knn(X_train, K):
    """Build and fit a nearest-neighbour index over ``X_train``.

    Args:
        X_train: Training feature matrix to index.
        K: Default number of neighbours returned by ``kneighbors``.

    Returns:
        The fitted ``NearestNeighbors`` estimator. ``radius=1`` only affects
        radius-based queries, not ``kneighbors``.
    """
    # ``fit`` returns the estimator itself, so construction and fitting chain.
    return NearestNeighbors(n_neighbors=K, radius=1).fit(X_train)

def knn_robust_kl(x_test, X_train, demands, K, params, neigh, epsilon):
    """Solve a KL-divergence distributionally robust scheduling problem.

    The ``K`` nearest training neighbours of ``x_test`` (found with the
    pre-fitted ``neigh``) supply ``K`` empirical demand scenarios ``D``.
    The scenario weights ``pi`` form a probability vector whose KL divergence
    from the uniform distribution is bounded by ``epsilon``. The schedule
    ``w`` minimises the worst-case ``pi``-weighted cost

        gamma_under * [D - w]_+ + gamma_over * [w - D]_+ + 0.5 * (w - D)^2

    summed over periods. It is subject to ``w >= 0`` and a ramp limit of
    ``c_ramp`` between consecutive periods.

    Args:
        x_test: Feature vector of the query point.
        X_train: Unused; kept for backward compatibility (``neigh`` is
            already fitted).
        demands: Array-like of demand profiles, indexed like the data
            ``neigh`` was fitted on; each row is one profile.
        K: Number of nearest neighbours used as scenarios.
        params: Mapping with keys ``'c_ramp'``, ``'gamma_under'`` and
            ``'gamma_over'``.
        neigh: Fitted ``NearestNeighbors`` instance.
        epsilon: Radius of the KL-divergence ball around the uniform weights.

    Returns:
        The optimal schedule from ``w.get()``, one value per demand column.
    """
    c_ramp = params['c_ramp']
    gamma_under = params['gamma_under']
    gamma_over = params['gamma_over']

    # Indices of the K nearest neighbours; shape (1, K) for a single query.
    indx = neigh.kneighbors([x_test], n_neighbors=K, return_distance=False)
    D = np.asarray(demands)[indx[0]]

    # N scenarios, T periods per demand profile.
    N, T = D.shape

    model = ro.Model()

    # Uncertain scenario weights: a probability vector within KL distance
    # epsilon of the uniform empirical distribution.
    pi = model.rvar(N)
    uset = (pi.sum() == 1, pi >= 0,
            rso.kldiv(pi, 1 / N, epsilon))

    w = model.dvar(T)
    model.st(w >= 0)

    # Ramp limit between consecutive periods.
    for t in range(T - 1):
        model.st(w[t] - w[t + 1] <= c_ramp)
        model.st(w[t] - w[t + 1] >= -c_ramp)

    # Per-scenario shortfall (yb) and surplus (yh) of the schedule.
    yb = model.dvar((N, T))
    yh = model.dvar((N, T))
    model.st(yh >= 0, yb >= 0)
    model.st(yb >= D - w)
    model.st(yh >= w - D)

    # Quadratic term: e is pinned to the signed error, so e2 bounds the full
    # squared error (w - D)^2. A one-sided ``e >= w - D`` would let the
    # solver pick e = 0 whenever w < D and silently drop the quadratic
    # penalty on under-scheduling.
    e = model.dvar((N, T))
    e2 = model.dvar((N, T))
    model.st(e == w - D)
    model.st(e2 >= square(e))

    scenario_cost = (gamma_under * yb + gamma_over * yh + 0.5 * e2).sum(axis=1)
    model.minmax((pi * scenario_cost).sum(), uset)

    model.solve(msk, display=False)
    return w.get()

def robust_predict(model, x_test, Lambda, params):
    """Compute a robust schedule around a neural-network demand forecast.

    The forecast ``pred`` is the first element of ``model(x_test)``,
    evaluated under ``torch.no_grad()``. The uncertainty set is built from a
    perturbation vector ``d`` with ``|d_t| <= 1`` and ``||d||_1 <= Lambda``.
    ``dd`` is constrained to lie within +/-1 of ``(1 + d) * pred``. The
    schedule ``w`` minimises

        gamma_under * [dd - w]_+ + gamma_over * [w - dd]_+ + 0.5 * (w - dd)^2

    summed over periods. It is subject to ``w >= 0`` and a ramp limit of
    ``c_ramp`` between consecutive periods.

    Args:
        model: Callable (e.g. a torch module) whose output's first element is
            a tensor holding the demand forecast.
        x_test: Input passed to ``model``.
        Lambda: L1 budget of the perturbation ``d``.
        params: Mapping with keys ``'c_ramp'``, ``'gamma_under'`` and
            ``'gamma_over'``.

    Returns:
        The optimal schedule from ``w.get()``, one value per forecast entry.
    """
    c_ramp = params['c_ramp']
    gamma_under = params['gamma_under']
    gamma_over = params['gamma_over']

    # Point forecast without building an autograd graph.
    with torch.no_grad():
        pred = model(x_test)[0].numpy()
    T = len(pred)

    robust = ro.Model()
    w = robust.dvar(T)
    dd = robust.dvar(T)
    d = robust.rvar(T)

    # Relative forecast perturbation: each |d_t| <= 1, total L1 budget Lambda.
    uset = (rso.norm(d, 1) <= Lambda, d <= 1, d >= -1)

    # NOTE(review): dd is a decision variable, not an expression in d, so the
    # solver chooses dd itself; d only enters through these two band
    # constraints and the objective contains no random variable. Confirm this
    # is the intended robust formulation.
    # NOTE(review): these constraints have no explicit ``.forall(uset)``;
    # presumably rsome applies the objective's uncertainty set. TODO confirm.
    robust.st(dd >= (1 + d) * pred - 1)
    robust.st(dd <= (1 + d) * pred + 1)

    # Non-negative generation, as in knn_robust_kl. This cannot cause
    # infeasibility: w = 0 satisfies the ramp constraints.
    robust.st(w >= 0)

    # Ramp limit between consecutive periods.
    for t in range(T - 1):
        robust.st(w[t] - w[t + 1] <= c_ramp)
        robust.st(w[t] - w[t + 1] >= -c_ramp)

    # Shortfall (yb) and surplus (yh) relative to dd.
    yb = robust.dvar(T)
    yh = robust.dvar(T)
    robust.st(yh >= 0, yb >= 0)
    robust.st(yb >= dd - w)
    robust.st(yh >= w - dd)

    # Quadratic term: pin e to the signed error so e2 bounds the full squared
    # error (w - dd)^2. A one-sided ``e >= w - dd`` would let the solver pick
    # e = 0 whenever w < dd and drop the penalty.
    e = robust.dvar(T)
    e2 = robust.dvar(T)
    robust.st(e == w - dd)
    robust.st(e2 >= square(e))

    robust.minmax((gamma_under * yb + gamma_over * yh + 0.5 * e2).sum(), uset)
    robust.solve(grb, display=False)
    return w.get()

