import numpy as np
import gurobipy as gp
from gurobipy import GRB

import numpy as np 

from rsome import ro
import rsome as rso
from rsome import msk_solver as msk
from rsome import grb_solver as grb

def get_costs(demand, x, H=1, B=5):
    """Return the overage-plus-underage cost of ordering ``x`` against ``demand``.

    The surplus ``max(x - demand, 0)`` is charged ``H`` per unit and the
    shortfall ``max(demand - x, 0)`` is charged ``B`` per unit. Both are
    summed along axis 1, so the inputs must broadcast to an array with at
    least two dimensions.

    Args:
        demand: Array of demand values; rows are reduced over axis 1.
        x: Order quantities, broadcast against ``demand``.
        H: Cost per unit of surplus (``x`` above ``demand``).
        B: Cost per unit of shortfall (``demand`` above ``x``).

    Returns:
        Array holding one total cost per row of the broadcast difference.
    """
    surplus = np.maximum(x - demand, 0)
    shortfall = np.maximum(demand - x, 0)
    overage_cost = H * np.sum(surplus, axis=1)
    underage_cost = B * np.sum(shortfall, axis=1)
    return overage_cost + underage_cost

def nominal_saa(demand, H=1, B=5):
    """Solve the sample-average newsvendor LP with Gurobi.

    One non-negative order quantity ``x[k]`` is chosen per column of
    ``demand``. For every sample ``n`` and column ``k`` the auxiliary
    variables ``yh[n, k]`` and ``yb[n, k]`` bound the surplus
    ``x[k] - demand[n, k]`` and the shortfall ``demand[n, k] - x[k]`` from
    above, and the objective sums ``H * yh + B * yb`` over all samples and
    columns (minimised under Gurobi's default model sense).

    Args:
        demand: 2-D array indexed as ``demand[n, k]`` (sample, item).
        H: Cost per unit of surplus.
        B: Cost per unit of shortfall.

    Returns:
        The first ``K`` entries of ``solver.x``, where ``K`` is the number
        of columns of ``demand``.
        NOTE(review): this relies on ``solver.x`` exposing all variable
        values in creation order, with the ``x`` variables created first.
    """
    num_samples = demand.shape[0]
    num_items = demand.shape[1]

    solver = gp.Model('solver')
    solver.setParam('OutputFlag', 0)  # suppress solver output

    items = list(range(num_items))
    samples = list(range(num_samples))
    pairs = [(n, k) for n in samples for k in items]

    # Creation order matters: the order quantities are added first so the
    # slice taken from the solution vector at the end picks them up.
    x = solver.addVars(items, lb=0, name='x')
    y_h = solver.addVars(pairs, lb=0, name='yh')
    y_b = solver.addVars(pairs, lb=0, name='yb')

    for k in items:
        order_qty = x[k]
        for n in samples:
            observed = demand[n, k]
            solver.addConstr(y_b[n, k] >= observed - order_qty)
            solver.addConstr(y_h[n, k] >= order_qty - observed)

    total_cost = sum(H * y_h[pair] + B * y_b[pair] for pair in pairs)
    solver.setObjective(total_cost)

    solver.update()
    solver.optimize()
    return solver.x[:num_items]



def kld_robust(demand, epsilon, H=1, B=5, capacity=1e10):
    """Order quantities that minimise the worst-case weighted sample cost
    over a KL-divergence set of scenario weights.

    The scenario weights ``pi`` are the uncertain quantity. They must be
    non-negative, sum to one, and satisfy the ``rso.kldiv`` constraint
    built from the reference value ``1/N`` and radius ``epsilon``. For a
    given weight vector the cost is ``sum_n pi[n] * sum_k (H * y_h[n, k] +
    B * y_b[n, k])``, where ``y_h`` and ``y_b`` bound the per-sample
    surplus and shortfall from above.

    Args:
        demand: 2-D array with ``N`` rows (samples) and ``K`` columns.
        epsilon: Radius passed to ``rso.kldiv``.
        H: Cost per unit of surplus.
        B: Cost per unit of shortfall.
        capacity: Upper bound on the sum of the order quantities.

    Returns:
        The value of ``x`` reported by RSOME after solving with the MOSEK
        interface.
    """
    num_samples = demand.shape[0]
    num_items = demand.shape[1]

    model = ro.Model()

    # Uncertain scenario weights and the set they range over.
    weights = model.rvar(num_samples)
    ambiguity_set = (
        weights.sum() == 1,
        weights >= 0,
        rso.kldiv(weights, 1/num_samples, epsilon),
    )

    x = model.dvar(num_items)
    overage = model.dvar((num_samples, num_items))
    underage = model.dvar((num_samples, num_items))

    # Epigraph form of max(x - demand, 0) and max(demand - x, 0).
    model.st(underage >= 0, overage >= 0)
    model.st(underage >= demand - x)
    model.st(overage >= x - demand)

    model.st(x.sum() <= capacity)

    scenario_cost = (H * overage + B * underage).sum(axis=1)
    weighted_cost = (weights * scenario_cost).sum()
    model.minmax(weighted_cost, ambiguity_set)

    model.solve(msk, display=False)
    return x.get()

def vd_robust(demand, epsilon, H=1, B=5, capacity=1e10):
    """Order quantities that minimise the worst-case weighted sample cost
    over an L1-norm set of scenario weights.

    The scenario weights ``pi`` are the uncertain quantity. They must be
    non-negative, sum to one, and lie within 1-norm distance ``epsilon``
    of the value ``1/N`` in every component. For a given weight vector
    the cost is ``sum_n pi[n] * sum_k (H * y_h[n, k] + B * y_b[n, k])``,
    where ``y_h`` and ``y_b`` bound the per-sample surplus and shortfall
    from above.

    Args:
        demand: 2-D array with ``N`` rows (samples) and ``K`` columns.
        epsilon: Bound on ``rso.norm(pi - 1/N, 1)``.
        H: Cost per unit of surplus.
        B: Cost per unit of shortfall.
        capacity: Upper bound on the sum of the order quantities.

    Returns:
        The value of ``x`` reported by RSOME after solving with the MOSEK
        interface.
    """
    num_samples = demand.shape[0]
    num_items = demand.shape[1]

    model = ro.Model()

    # Uncertain scenario weights and the set they range over.
    weights = model.rvar(num_samples)
    deviation = rso.norm(weights - 1/num_samples, 1)
    ambiguity_set = (
        weights.sum() == 1,
        weights >= 0,
        deviation <= epsilon,
    )

    x = model.dvar(num_items)
    overage = model.dvar((num_samples, num_items))
    underage = model.dvar((num_samples, num_items))

    # Epigraph form of max(x - demand, 0) and max(demand - x, 0).
    model.st(underage >= 0, overage >= 0)
    model.st(underage >= demand - x)
    model.st(overage >= x - demand)

    model.st(x.sum() <= capacity)

    scenario_cost = (H * overage + B * underage).sum(axis=1)
    weighted_cost = (weights * scenario_cost).sum()
    model.minmax(weighted_cost, ambiguity_set)

    model.solve(msk, display=False)
    return x.get()


def adaptive_robust(demand, Gamma, H=1, B=5, capacity=1e10):
    """Adaptive robust newsvendor with linear decision rules on demand.

    The demand vector ``d`` (length ``K``) is the uncertain quantity. The
    surplus and shortfall bounds ``y_h`` and ``y_b`` are linear decision
    rules that depend affinely on ``d``, while the order quantities ``x``
    are chosen before ``d`` is known. The worst case of
    ``sum(H * y_h + B * y_b)`` over the uncertainty set is minimised.

    The uncertainty set is the intersection of a box and a budget
    constraint: every component of ``d`` lies between the smallest and
    largest entry of ``demand``, and the 1-norm of the column-wise
    standardised deviation ``(d - mean) / std`` is at most ``Gamma``.

    Args:
        demand: 2-D array with samples in rows and ``K`` columns.
        Gamma: Budget bounding the 1-norm of the standardised deviation.
        H: Cost per unit of surplus.
        B: Cost per unit of shortfall.
        capacity: Upper bound on the sum of the order quantities.

    Returns:
        The value of ``x`` reported by RSOME after solving with the MOSEK
        interface.
    """
    K = demand.shape[1]

    # NOTE(review): the box bounds are the global min/max over the whole
    # matrix, so every item shares the same range -- confirm that
    # per-column bounds were not intended.
    dmin = np.min(demand)
    dmax = np.max(demand)

    # Column statistics used to standardise the deviation of d.
    # NOTE(review): a column with zero variance makes inv_scale infinite.
    center = np.mean(demand, axis=0)
    inv_scale = 1.0 / np.std(demand, axis=0)

    model = ro.Model()

    x = model.dvar(K)
    y_h = model.ldr(K)
    y_b = model.ldr(K)
    d = model.rvar(K)

    # The recourse decisions depend affinely on the realised demand.
    y_h.adapt(d)
    y_b.adapt(d)

    model.st(y_h >= 0, y_b >= 0)
    model.st(y_b >= d - x)
    model.st(y_h >= x - d)

    uset = (d >= dmin,
            d <= dmax,
            rso.norm((d - center) * inv_scale, 1) <= Gamma)

    model.st(x.sum() <= capacity)
    model.minmax((H * y_h + B * y_b).sum(), uset)

    model.solve(msk, display=False)
    return x.get()

def my_robust_exact(demand, t, H=1, B=5, capacity=1e10, big_m=8000):
    """Maximise the number of samples whose cost stays within ``t`` (MIP).

    A binary indicator ``q[n]`` is attached to each sample. When
    ``q[n] == 1`` the sample cost ``sum_k (H * y_h[n, k] + B * y_b[n, k])``
    must not exceed ``t``; when ``q[n] == 0`` the bound is relaxed by
    ``big_m``. The objective maximises ``sum(q)``.

    Args:
        demand: 2-D array with ``N`` rows (samples) and ``K`` columns.
        t: Cost threshold a sample must meet to be counted.
        H: Cost per unit of surplus.
        B: Cost per unit of shortfall.
        capacity: Upper bound on the sum of the order quantities.
        big_m: Big-M constant relaxing the threshold for uncounted
            samples. It must be at least as large as the amount by which
            any sample cost can exceed ``t``, otherwise samples with
            ``q[n] == 0`` still restrict ``x``. Defaults to 8000, the
            value previously hard-coded.

    Returns:
        The value of ``x`` reported by RSOME after solving with the Gurobi
        interface at a relative MIP gap of 1e-2.
    """
    N = demand.shape[0]
    K = demand.shape[1]
    model = ro.Model()

    x = model.dvar(K)
    y_h = model.dvar((N, K))
    y_b = model.dvar((N, K))
    q = model.dvar(N, vtype="B")

    model.st(y_b >= 0, y_h >= 0)
    model.st(q >= 0)

    # Epigraph form of max(demand - x, 0) and max(x - demand, 0).
    model.st(y_b >= demand - x)
    model.st(y_h >= x - demand)

    # Big-M link: the threshold only binds for samples with q[n] == 1.
    sample_cost = (H * y_h + B * y_b).sum(axis=1)
    model.st(t >= sample_cost - big_m * (1 - q))

    model.st(x.sum() <= capacity)
    model.max(q.sum())
    # Alternative objective (disabled): model.min(rso.norm(q, 2))

    model.solve(grb, display=False, params={'MIPGap': 1e-2})
    return x.get()


def my_robust_relax(demand, t, H=1, B=5, capacity=1e10):
    """Minimise the total amount by which sample costs exceed ``t`` (LP).

    A continuous, non-negative variable ``q[n]`` per sample bounds the
    excess ``sum_k (H * y_h[n, k] + B * y_b[n, k]) - t`` from above, and
    the objective minimises ``sum(q)``.

    Args:
        demand: 2-D array with ``N`` rows (samples) and ``K`` columns.
        t: Cost threshold above which a sample is penalised.
        H: Cost per unit of surplus.
        B: Cost per unit of shortfall.
        capacity: Upper bound on the sum of the order quantities.

    Returns:
        The value of ``x`` reported by RSOME after solving with the Gurobi
        interface.
    """
    num_samples = demand.shape[0]
    num_items = demand.shape[1]

    model = ro.Model()

    x = model.dvar(num_items)
    overage = model.dvar((num_samples, num_items))
    underage = model.dvar((num_samples, num_items))
    excess = model.dvar(num_samples)

    model.st(underage >= 0, overage >= 0)
    model.st(excess >= 0)

    # Epigraph form of max(demand - x, 0) and max(x - demand, 0).
    model.st(underage >= demand - x)
    model.st(overage >= x - demand)

    # excess[n] is at least the part of the sample cost above t.
    sample_cost = (H * overage + B * underage).sum(axis=1)
    model.st(excess >= sample_cost - t)

    model.st(x.sum() <= capacity)
    model.min(excess.sum())
    # Alternative objective (disabled): model.min(rso.norm(excess, 2))

    model.solve(grb, display=False)
    return x.get()