import gurobipy as gb
from gurobipy import GRB

import numpy as np

from sklearn.base import clone
from sklearn.metrics import accuracy_score

class Problem:
    """Per-level demand classifiers combined into an LP that picks an order quantity.

    A grid of demand levels ``self.D`` is laid over the observed demand range.
    For each level ``D[i]``, a binary classifier is trained on ``X`` to predict
    whether demand is at least ``D[i]`` (see ``init_n`` / ``init_models``).
    ``predict`` combines the classifiers' probabilities at a feature vector
    into a small LP and returns its optimal decision ``s``.

    ``params`` must provide the keys ``c_lin``, ``c_quad``, ``h_lin``,
    ``h_quad``, ``b_lin`` and ``b_quad``. The ``*_quad`` values are stored but
    are not used by ``objective`` or ``predict``.
    """

    def __init__(self, X, demands, params):
        """Store the data and cost coefficients and build the demand grid.

        Args:
            X: feature matrix; ``X.shape[0]`` is treated as the sample count.
            demands: array of observed demands, presumably one per row of ``X``
                (``init_n`` iterates ``demands.shape[0]`` but compares against
                ``X.shape[0]``) -- TODO confirm with callers.
            params: mapping holding the cost coefficients listed on the class.
        """
        self.X = X
        self.demands = demands
        # Grid from min to max demand in steps of 0.1. The +0.02 slack on the
        # (exclusive) stop keeps max(demands) in the grid when it falls on a step.
        self.D = np.arange(min(demands), max(demands) + 0.02, .1)

        self.params = params
        self.c_lin = params['c_lin']
        self.c_quad = params['c_quad']
        self.h_lin = params['h_lin']
        self.h_quad = params['h_quad']
        self.b_lin = params['b_lin']
        self.b_quad = params['b_quad']

        # Number of grid levels (one classifier per level after init_models).
        self.n = self.D.shape[0]

        # Not read anywhere in this class; kept for external callers.
        self.opt_w = []

    def objective(self, s, d):
        """Return the linear cost of decision ``s`` when demand is ``d``.

        cost = max(b_lin * (d - s), h_lin * (s - d)) + c_lin * s

        Works elementwise on NumPy arrays. A quadratic term built from the
        ``*_quad`` coefficients is currently disabled (contributes 0).
        """
        lin = np.maximum((d - s) * self.b_lin, (s - d) * self.h_lin) + self.c_lin * s
        quad = 0
        return lin + quad

    def init_n(self):
        """Build ``self.N``: for each grid level, the sample indices at or above it.

        ``self.N[i]`` lists every ``k`` with ``demands[k] >= D[i]``. If that
        covers every sample, the last two indices are dropped. Presumably this
        guarantees the classifier for that level sees both classes.
        """
        self.N = []
        n_demands = self.demands.shape[0]
        for i in range(self.n):
            level = self.D[i]
            cur = [k for k in range(n_demands) if self.demands[k] >= level]
            if len(cur) == self.X.shape[0]:
                cur = cur[:-2]
            self.N.append(cur)

    def init_models(self, model_class):
        """Fit one binary classifier per grid level.

        Must be called after ``init_n``. Label vector ``i`` is 1 at the indices
        in ``self.N[i]`` and 0 elsewhere.

        Args:
            model_class: unfitted scikit-learn estimator used as a template.
                It is only cloned, never fitted itself.

        Sets ``self.models`` (fitted estimators) and ``self.labels`` (their
        0/1 training targets), both with one entry per grid level.
        """
        self.models = []
        self.labels = []
        for i in range(self.D.shape[0]):
            labels = np.zeros(self.X.shape[0])
            labels[np.asarray(self.N[i], dtype=int)] = 1
            # clone() copies hyper-parameters only, so cloning the previous
            # model yields the same configuration as cloning the template.
            template = self.models[-1] if self.models else model_class
            m = clone(template)
            m.fit(self.X, labels)
            self.models.append(m)
            self.labels.append(labels)

    def _prob_at_least(self, i, x):
        """Return classifier ``i``'s probability of label 1 (demand >= D[i]) at ``x``."""
        proba = self.models[i].predict_proba([x])[0]
        if len(proba) == 1:
            # The classifier was fit on a single class, so the probability of
            # label 1 is just that class's label value (0 or 1).
            return self.labels[i][0]
        return proba[1]

    def predict(self, x, eps=0):
        """Solve the per-level LP for feature vector ``x`` and return the optimal ``s``.

        For each level ``i`` with p_i = P(label 1 at D[i] | x), the LP minimises
            sum_i  p_i * q[2i] + (1 - p_i) * q[2i+1]
        subject to
            b[i] >= D[i] - s,  q[2i]   >= b_lin * b[i] - eps,
            h[i] >= s - D[i],  q[2i+1] >= h_lin * h[i] - eps,
        with every variable (including ``s``) bounded below by 0.

        Args:
            x: a single feature vector, passed to each model's ``predict_proba``.
            eps: slack subtracted from every per-level cost constraint.

        Returns:
            The optimal value of ``s``.

        NOTE(review): unlike ``objective``, this LP has no ``c_lin * s`` term --
        confirm that is intended.
        """
        model = gb.Model("lp")
        model.Params.LogToConsole = 0
        n = len(self.models)

        # addVar / addMVar default to lb=0, so s, h, b and q are non-negative;
        # this replaces the earlier explicit ">= 0" constraints, which only
        # covered the first n entries of q.
        s = model.addVar()
        h = model.addMVar(n, lb=0.0)
        b = model.addMVar(n, lb=0.0)
        q = model.addMVar(2 * n, lb=0.0)
        model.update()

        # Objective weights, interleaved as (p_i, 1 - p_i) per level.
        r = np.zeros(2 * n)
        for i in range(n):
            p_at_least = self._prob_at_least(i, x)
            r[2 * i] = p_at_least
            r[2 * i + 1] = 1 - p_at_least

            model.addConstr(b[i] >= (self.D[i] - s))
            model.addConstr(q[2 * i] >= self.b_lin * b[i] - eps)

            model.addConstr(h[i] >= (s - self.D[i]))
            model.addConstr(q[2 * i + 1] >= self.h_lin * h[i] - eps)

        model.setObjective(q @ r)
        model.update()
        model.optimize()

        return s.X