import gurobipy as gb
from gurobipy import GRB

import numpy as np

from sklearn.base import clone
from sklearn.metrics import accuracy_score

class PowerCaps:
    """Learn-then-optimize scheme for choosing a supply schedule ``s``.

    Workflow, as wired by the methods below:

    1. :meth:`get_solutions` solves the full-information LP (:meth:`solve`)
       for every row of ``demands``.
    2. :meth:`init_n` thresholds those optimal supplies against the grid
       ``splits`` to produce binary labels.
    3. :meth:`init_models` fits one classifier per (split, hour) pair.
    4. :meth:`setup_model` builds a reusable LP whose objective weights are
       filled in from the classifier probabilities by :meth:`predict`.
    """

    def __init__(self, X, demands, params, EPS=0, step=0.02):
        """Store training data, cost parameters and the threshold grid.

        Args:
            X: feature matrix handed to each classifier's ``fit``; it must
                have one row per row of ``demands``.
            demands: 2-D array; each row ``demands[i, :]`` is one demand
                profile of length ``dim``.
            params: dict with keys ``'c_ramp'`` (bound on
                ``|s[i] - s[i+1]|``), ``'gamma_under'`` (per-unit shortfall
                cost) and ``'gamma_over'`` (per-unit surplus cost).
            EPS: stored on the instance; :meth:`init_n` overwrites it.
            step: spacing of the threshold grid ``splits``.  Defaults to
                0.02, the value that used to be hard-coded.
        """
        self.X = X
        self.demands = demands
        self.EPS = EPS
        # Threshold grid over the observed demand range; np.arange excludes
        # the upper end point, so max(demands) itself is not a split.
        self.splits = np.arange(np.min(self.demands), np.max(self.demands), step)

        self.c_ramp = params['c_ramp']
        self.gamma_under = params['gamma_under']
        self.gamma_over = params['gamma_over']

        self.n_splits = self.splits.shape[0]
        self.n = demands.shape[0]
        self.dim = demands.shape[1]

    def get_solutions(self):
        """Solve :meth:`solve` for every demand row; store results in ``self.opt_s``."""
        self.opt_s = [self.solve(self.demands[i, :]) for i in range(self.n)]

    def solve(self, d):
        """Return the cost-minimising supply for a known demand profile ``d``.

        Minimises ``gamma_under * sum(b) + gamma_over * sum(h)`` where
        ``b >= d - s`` and ``h >= s - d`` linearise shortfall and surplus,
        subject to ``-c_ramp <= s[i] - s[i+1] <= c_ramp``.  Variables from
        ``addVars`` get Gurobi's default lower bound of 0, so ``s`` is
        non-negative as well.

        Returns:
            list of ``dim`` floats: the optimal ``s``.
        """
        model = gb.Model("lp")
        model.Params.LogToConsole = 0
        s = model.addVars(self.dim)
        b = model.addVars(self.dim)
        h = model.addVars(self.dim)

        for i in range(self.dim):
            model.addConstr(b[i] >= 0)
            model.addConstr(h[i] >= 0)
            model.addConstr(b[i] >= d[i] - s[i])
            model.addConstr(h[i] >= s[i] - d[i])

        # Ramp limits between consecutive entries of the schedule.
        for i in range(self.dim - 1):
            model.addConstr(s[i] - s[i + 1] <= self.c_ramp)
            model.addConstr(s[i] - s[i + 1] >= -self.c_ramp)

        model.setObjective(
            self.gamma_under * gb.quicksum(b[i] for i in range(self.dim))
            + self.gamma_over * gb.quicksum(h[i] for i in range(self.dim))
        )
        model.update()
        model.optimize()
        return [s[i].X for i in range(self.dim)]

    def objective(self, s, d):
        """Return the total cost: the sum of :meth:`get_hour_cost` over all entries."""
        return np.sum(self.get_hour_cost(s, d))

    def get_hour_cost(self, s, d):
        """Return the element-wise cost of supplying ``s`` against demand ``d``.

        Shortfall ``max(d - s, 0)`` is charged at ``gamma_under`` and surplus
        ``max(s - d, 0)`` at ``gamma_over``.
        """
        return self.gamma_under * np.maximum(d - s, 0) + self.gamma_over * np.maximum(s - d, 0)

    def init_n(self, EPS=0):
        """Build binary training labels from ``self.opt_s``.

        ``self.labels[i, k, j]`` is 1 when ``opt_s[k][j] >= splits[i]`` and 0
        otherwise.  The result is an int array of shape
        ``(n_splits, n, dim)``.  Requires :meth:`get_solutions` (or an
        equivalent assignment of ``self.opt_s``) beforehand.
        """
        self.EPS = EPS
        # reshape keeps the (n, dim) layout even when n == 0.
        opt = np.asarray(self.opt_s, dtype=float).reshape(self.n, self.dim)
        # Broadcast splits against every (sample, hour) entry at once.
        self.labels = (
            opt[np.newaxis, :, :] >= self.splits[:, np.newaxis, np.newaxis]
        ).astype(int)

    def init_models(self, model_class):
        """Fit one clone of ``model_class`` per (split, hour) pair.

        Models are stored flat in ``self.models`` at index
        ``i * dim + k`` for split ``i`` and hour ``k``.  :meth:`predict`
        relies on this ordering.
        """
        self.models = []
        for i in range(len(self.splits)):
            print('training: ', i, '/', len(self.splits))
            for k in range(self.dim):
                model = clone(model_class)
                model.fit(self.X, self.labels[i, :, k])
                self.models.append(model)

    def setup_model(self, eps=0):
        """Build the reusable LP that :meth:`predict` re-weights and solves.

        For every split ``i`` and hour ``k`` (with ``t = 2 * (i * dim + k)``):
        ``b[i, k] >= splits[i] - s[k]``, ``h[i, k] >= s[k] - splits[i]``,
        ``q[t] >= gamma_under * b[i, k] - eps`` and
        ``q[t + 1] >= gamma_over * h[i, k] - eps``, plus the same ramp limits
        as :meth:`solve`.  No objective is set here; :meth:`predict` assigns
        the ``q`` objective coefficients.  The model, ``q`` and ``s`` are
        stored on the instance.
        """
        model = gb.Model("lp")
        model.Params.LogToConsole = 0
        s = model.addVars(self.dim)

        h = model.addVars(self.n_splits, self.dim)
        b = model.addVars(self.n_splits, self.dim)
        # Two cost variables (under, over) per (split, hour) pair.
        q = model.addVars(2 * self.n_splits * self.dim)
        model.update()

        t = 0
        for i in range(self.n_splits):
            # NOTE(review): redundant with addVars' default lb=0, and it only
            # touches the first n_splits entries of q; kept so the model is
            # unchanged.
            model.addConstr(q[i] >= 0)
            for k in range(self.dim):
                model.addConstr(h[i, k] >= 0)
                model.addConstr(b[i, k] >= 0)
                model.addConstr(b[i, k] >= self.splits[i] - s[k])
                model.addConstr(h[i, k] >= s[k] - self.splits[i])

                model.addConstr(q[t] >= self.gamma_under * b[i, k] - eps)
                model.addConstr(q[t + 1] >= self.gamma_over * h[i, k] - eps)
                t += 2

        for i in range(self.dim - 1):
            model.addConstr(s[i] - s[i + 1] <= self.c_ramp)
            model.addConstr(s[i] - s[i + 1] >= -self.c_ramp)
        self.cap_solver = model
        self.q = q
        self.s = s

    def predict(self, x, eps=0):
        """Return the supply schedule for a single feature vector ``x``.

        For classifier ``t = i * dim + k``, its probability ``p`` of label 1
        becomes the objective coefficient of ``q[2t]`` (under-cost term).
        ``1 - p`` becomes the coefficient of ``q[2t + 1]`` (over-cost term).
        Requires :meth:`init_models` and :meth:`setup_model` first.  ``eps``
        is not used here; it is kept for interface compatibility.

        Returns:
            list of ``dim`` floats: the optimal ``s``.
        """
        for t in range(self.n_splits * self.dim):
            proba = self.models[t].predict_proba([x])[0]
            if len(proba) == 1:
                # NOTE(review): the classifier saw a single class in training;
                # both terms are dropped instead of weighting by that class.
                # Confirm this is intended.
                w_under, w_over = 0, 0
            else:
                w_under = proba[1]
                # Compute from the probability itself: reading q[2t].Obj back
                # before update() returns the previously applied coefficient
                # (Gurobi lazy updates), not the value just assigned.
                w_over = 1 - proba[1]
            self.q[2 * t].Obj = w_under
            self.q[2 * t + 1].Obj = w_over

        self.cap_solver.update()
        self.cap_solver.optimize()
        return [self.s[i].X for i in range(self.dim)]
