from sklearn.neighbors import NearestNeighbors

import numpy as np

from rsome import dro

from rsome import square
from rsome import ro
import rsome as rso
from rsome import msk_solver as msk
from rsome import grb_solver as grb

from rsome import E


def setup_knn(X_train, K):
    """Create and fit a nearest-neighbour index on the training features.

    Args:
        X_train: Training feature data handed to ``NearestNeighbors.fit``.
        K: Default number of neighbours the index is configured with.

    Returns:
        The fitted ``NearestNeighbors`` estimator.
    """
    # ``fit`` returns the estimator itself, so construction and fitting chain.
    return NearestNeighbors(n_neighbors=K, radius=1).fit(X_train)

def knn_mean(x_test, X_train, demands, K, params, neigh):
    """Choose the decision ``w`` minimising total cost over the K nearest demands.

    The K neighbours of ``x_test`` supply K demand samples. For each sample a
    shortage variable ``yb >= max(D - w, 0)`` and a surplus variable
    ``yh >= max(w - D, 0)`` are introduced, and the summed linear plus
    quadratic cost of ``w``, ``yb`` and ``yh`` is minimised.

    Args:
        x_test: Feature vector of the query point.
        X_train: Training features (unused here; kept for a uniform signature
            with the sibling ``knn_*`` functions).
        demands: Demand observations, indexable by the neighbour index array.
        K: Number of neighbours to query.
        params: Mapping with keys ``'c_lin'``, ``'c_quad'``, ``'b_lin'``,
            ``'b_quad'``, ``'h_lin'`` and ``'h_quad'``.
        neigh: Fitted nearest-neighbour index exposing ``kneighbors``.

    Returns:
        The optimal value of ``w`` as reported by ``w.get()``.
    """
    c_lin = params['c_lin']
    c_quad = params['c_quad']
    b_lin = params['b_lin']
    b_quad = params['b_quad']
    h_lin = params['h_lin']
    h_quad = params['h_quad']

    # kneighbors answers a batch of queries, so the index array has shape
    # (1, K) for our single query. Take row 0 to obtain a flat vector of K
    # demands; otherwise len(D) would be 1 and every sample would share a
    # single shortage/surplus variable.
    indx = neigh.kneighbors([x_test], K, False)
    D = demands[indx][0]
    N = len(D)

    model = ro.Model()

    w = model.dvar()

    # One shortage / surplus variable per demand sample.
    yb = model.dvar(N)
    yh = model.dvar(N)

    # Epigraph variables standing in for the squared terms.
    w2 = model.dvar()
    yh2 = model.dvar(N)
    yb2 = model.dvar(N)

    model.st(w2 >= square(w))
    model.st(yh2 >= square(yh))
    model.st(yb2 >= square(yb))

    model.st(yh >= 0, yb >= 0)
    model.st(yb >= D - w)
    model.st(yh >= w - D)

    model.min((c_lin * w + c_quad * w2 + b_lin * yb + b_quad * yb2 + h_lin * yh + h_quad * yh2).sum())

    model.solve(msk, display=False)
    return w.get()


def knn_robust_kl(x_test, X_train, demands, K, params, neigh, epsilon):
    """Pick ``w`` against the worst-case sample weights in a KL-divergence set.

    The K nearest demands form the scenarios. Scenario weights ``pi`` are
    uncertain: they sum to one, are non-negative, and satisfy the
    ``rso.kldiv(pi, 1/N, epsilon)`` constraint. The linear cost weighted by
    ``pi`` is minimised for the worst such weight vector.

    Args:
        x_test: Feature vector of the query point.
        X_train: Training features (not used in this function).
        demands: Demand observations, indexable by the neighbour index array.
        K: Number of neighbours to query.
        params: Mapping with keys ``'c_lin'``, ``'c_quad'``, ``'b_lin'``,
            ``'b_quad'``, ``'h_lin'`` and ``'h_quad'``.
        neigh: Fitted nearest-neighbour index exposing ``kneighbors``.
        epsilon: Third argument of the ``rso.kldiv`` constraint.

    Returns:
        The optimal value of ``w`` as reported by ``w.get()``.
    """
    # All six keys are looked up, although only the linear ones enter the
    # objective below.
    c_lin, c_quad, b_lin, b_quad, h_lin, h_quad = (
        params[key]
        for key in ('c_lin', 'c_quad', 'b_lin', 'b_quad', 'h_lin', 'h_quad')
    )

    # Row 0 of the index array holds the neighbours of the single query point.
    neighbor_idx = neigh.kneighbors([x_test], K, False)
    samples = np.array(demands[neighbor_idx][0])
    n_samples = len(samples)

    model = ro.Model()

    # Uncertain scenario weights and their uncertainty set.
    pi = model.rvar(n_samples)
    weight_set = (
        pi.sum() == 1,
        pi >= 0,
        rso.kldiv(pi, 1 / n_samples, epsilon),
    )

    w = model.dvar()
    model.st(w >= 0)

    shortage = model.dvar(n_samples)
    surplus = model.dvar(n_samples)

    # NOTE: these epigraph variables are constrained but do not appear in the
    # objective; the quadratic cost terms are currently left out of it.
    w_sq = model.dvar()
    surplus_sq = model.dvar(n_samples)
    shortage_sq = model.dvar(n_samples)

    model.st(w_sq >= square(w))
    model.st(surplus_sq >= square(surplus))
    model.st(shortage_sq >= square(shortage))

    model.st(surplus >= 0, shortage >= 0)
    model.st(shortage >= samples - w)
    model.st(surplus >= w - samples)

    per_sample_cost = c_lin * w + b_lin * shortage + h_lin * surplus
    model.minmax((pi * per_sample_cost).sum(), weight_set)

    model.solve(msk, display=False)
    return w.get()

def knn_robust_budget(x_test, X_train, demands, K, params, neigh, epsilon):
    """Choose ``w`` minimising total cost over the K nearest demands shifted by ``epsilon``.

    Identical in structure to the sample-based model over the K neighbours of
    ``x_test``, except that ``epsilon`` is added to every neighbour demand
    before optimisation.

    Args:
        x_test: Feature vector of the query point.
        X_train: Training features (unused here; kept for a uniform signature
            with the sibling ``knn_*`` functions).
        demands: Demand observations, indexable by the neighbour index array.
        K: Number of neighbours to query.
        params: Mapping with keys ``'c_lin'``, ``'c_quad'``, ``'b_lin'``,
            ``'b_quad'``, ``'h_lin'`` and ``'h_quad'``.
        neigh: Fitted nearest-neighbour index exposing ``kneighbors``.
        epsilon: Amount added to each neighbour demand.

    Returns:
        The optimal value of ``w`` as reported by ``w.get()``.
    """
    c_lin = params['c_lin']
    c_quad = params['c_quad']
    b_lin = params['b_lin']
    b_quad = params['b_quad']
    h_lin = params['h_lin']
    h_quad = params['h_quad']

    # kneighbors answers a batch of queries, so the index array has shape
    # (1, K) for our single query. Take row 0 to obtain a flat vector of K
    # demands; otherwise len(D) would be 1 and every sample would share a
    # single shortage/surplus variable.
    indx = neigh.kneighbors([x_test], K, False)
    D = demands[indx][0] + epsilon
    N = len(D)

    model = ro.Model()

    w = model.dvar()

    # One shortage / surplus variable per (shifted) demand sample.
    yb = model.dvar(N)
    yh = model.dvar(N)

    # Epigraph variables standing in for the squared terms.
    w2 = model.dvar()
    yh2 = model.dvar(N)
    yb2 = model.dvar(N)

    model.st(w2 >= square(w))
    model.st(yh2 >= square(yh))
    model.st(yb2 >= square(yb))

    model.st(yh >= 0, yb >= 0)
    model.st(yb >= D - w)
    model.st(yh >= w - D)

    model.min((c_lin * w + c_quad * w2 + b_lin * yb + b_quad * yb2 + h_lin * yh + h_quad * yh2).sum())

    model.solve(msk, display=False)
    return w.get()


def knn_robust_wass(x_test, X_train, demands, K, params, neigh, epsilon):
    """Pick ``w`` against the worst-case distribution in a scenario-wise ambiguity set.

    Each of the K nearest demands defines one scenario. Within scenario ``s``
    the random demand ``d`` satisfies ``|d - D_s| <= u`` and ``d >= 0``; the
    auxiliary random variable ``u`` obeys ``E(u) <= epsilon`` (the
    Wasserstein metric constraint), and every scenario has probability
    ``1/N``. The worst-case expected linear cost is minimised.

    Args:
        x_test: Feature vector of the query point.
        X_train: Training features (not used in this function).
        demands: Demand observations, indexable by the neighbour index array.
        K: Number of neighbours to query.
        params: Mapping with keys ``'c_lin'``, ``'c_quad'``, ``'b_lin'``,
            ``'b_quad'``, ``'h_lin'`` and ``'h_quad'``.
        neigh: Fitted nearest-neighbour index exposing ``kneighbors``.
        epsilon: Upper bound on ``E(u)``.

    Returns:
        The optimal value of ``w`` as reported by ``w.get()``.
    """
    # All six keys are looked up, although only the linear ones enter the
    # objective below.
    c_lin, c_quad, b_lin, b_quad, h_lin, h_quad = (
        params[key]
        for key in ('c_lin', 'c_quad', 'b_lin', 'b_quad', 'h_lin', 'h_quad')
    )

    # Row 0 of the index array holds the neighbours of the single query point.
    neighbor_idx = neigh.kneighbors([x_test], K, False)
    samples = demands[neighbor_idx][0]
    n_scenarios = len(samples)

    model = dro.Model(n_scenarios)
    w = model.dvar()
    d = model.rvar()    # random demand
    u = model.rvar()    # auxiliary variable bounding |d - sample|

    # Ambiguity set: per-scenario support, expectation bound, equal weights.
    fset = model.ambiguity()
    for s, sample in enumerate(samples):
        fset[s].suppset(d - sample <= u, d - sample >= -u, d >= 0)
    fset.exptset(E(u) <= epsilon)
    fset.probset(model.p == 1 / n_scenarios)

    # Shortage and surplus decisions, adapting affinely to d and to each
    # scenario.
    yb = model.dvar()
    yh = model.dvar()
    yb.adapt(d)
    yh.adapt(d)
    for s in range(n_scenarios):
        yh.adapt(s)
        yb.adapt(s)

    model.st(yb >= 0, yh >= 0)
    model.st(yb >= d - w)
    model.st(yh >= w - d)

    # NOTE: the quadratic cost terms are not part of this objective.
    model.minsup(c_lin * w + E((b_lin * yb + h_lin * yh).sum()), fset)
    model.solve(grb, display=False)
    return w.get()
