import gurobipy as gb
from gurobipy import GRB

import numpy as np

from sklearn.base import clone
from sklearn.metrics import accuracy_score

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim



class PowerCaps:
    """Choose a supply schedule ``s`` against a set of demand scenarios.

    Workflow visible in this class:

    1. :meth:`get_solutions` solves, for every demand row, an LP minimising
       ``gamma_under * under-supply + gamma_over * over-supply`` subject to
       ramp limits on ``s``.
    2. :meth:`init_n` labels, for every demand row ``i``, which of those
       per-scenario schedules are cheap on row ``i``.
    3. :meth:`init_models` trains ``self.phat`` (a small MLP) to predict a
       label row from the matching row of ``X``.
    4. :meth:`setup_model` + :meth:`predict`, or :meth:`predict_quick`, solve
       an LP whose scenario weights are given scores ``phat``.

    Args:
        X: 2-D feature array; its rows are paired, in order, with rows of
            ``demands`` (via ``self.labels``) during training.
        demands: 2-D array of shape ``(n, dim)``, one demand profile per row.
            NOTE(review): ``dim`` looks like hours of a day (see
            ``get_hour_cost``); confirm.
        params: dict with keys ``'c_ramp'`` (max absolute change of ``s``
            between consecutive periods), ``'gamma_under'`` (per-unit
            under-supply cost) and ``'gamma_over'`` (per-unit over-supply cost).
        EPS: stored as ``self.EPS``; overwritten by :meth:`init_n`.
    """

    def __init__(self, X, demands, params, EPS=0):
        self.X = X
        self.demands = demands
        self.EPS = EPS

        self.c_ramp = params['c_ramp']
        self.gamma_under = params['gamma_under']
        self.gamma_over = params['gamma_over']

        self.n = demands.shape[0]    # number of demand scenarios
        self.dim = demands.shape[1]  # periods per scenario

    # ------------------------------------------------------------------ #
    # Shared LP helpers
    # ------------------------------------------------------------------ #
    def _add_ramp_constraints(self, model, s):
        """Add ``-c_ramp <= s[i] - s[i+1] <= c_ramp`` for consecutive periods."""
        for i in range(self.dim - 1):
            model.addConstr(s[i] - s[i + 1] <= self.c_ramp)
            model.addConstr(s[i] - s[i + 1] >= -self.c_ramp)

    def _read_schedule(self, s):
        """Return the solved values of the ``self.dim`` schedule variables ``s``."""
        return [s[i].X for i in range(self.dim)]

    # ------------------------------------------------------------------ #
    # Per-scenario optimal schedules
    # ------------------------------------------------------------------ #
    def get_solutions(self):
        """Solve :meth:`solve` for every demand row; store the results in ``self.opt_s``."""
        self.opt_s = [self.solve(self.demands[i, :]) for i in range(self.n)]

    def solve(self, d):
        """Return the cost-minimising schedule for the single demand profile ``d``.

        Minimises ``gamma_under * sum(b) + gamma_over * sum(h)`` subject to
        ``b >= d - s``, ``h >= s - d``, ``b, h >= 0`` and the ramp limits.
        ``addVars`` is called without bounds, so Gurobi's default lower bound
        of 0 applies to ``s`` as well.

        Returns:
            list of ``self.dim`` floats.
        """
        model = gb.Model("lp")
        model.Params.LogToConsole = 0
        s = model.addVars(self.dim)
        b = model.addVars(self.dim)  # under-supply slack
        h = model.addVars(self.dim)  # over-supply slack

        for i in range(self.dim):
            model.addConstr(b[i] >= 0)
            model.addConstr(h[i] >= 0)
            model.addConstr(b[i] >= d[i] - s[i])
            model.addConstr(h[i] >= s[i] - d[i])

        self._add_ramp_constraints(model, s)

        model.setObjective(
            self.gamma_under * gb.quicksum(b[i] for i in range(self.dim))
            + self.gamma_over * gb.quicksum(h[i] for i in range(self.dim))
        )
        model.update()
        model.optimize()
        return self._read_schedule(s)

    # ------------------------------------------------------------------ #
    # Cost evaluation
    # ------------------------------------------------------------------ #
    def objective(self, s, d):
        """Total under/over-supply cost of schedule ``s`` against demand ``d``."""
        return np.sum(self.get_hour_cost(s, d))

    def get_hour_cost(self, s, d):
        """Element-wise (per-period) under/over-supply cost of ``s`` against ``d``."""
        return self.gamma_under * np.maximum(d - s, 0) + self.gamma_over * np.maximum(s - d, 0)

    def get_hour_cost_quad(self, s, d):
        """Per-period cost from :meth:`get_hour_cost` plus ``0.5 * (s - d)**2``."""
        return (self.gamma_under * np.maximum(d - s, 0)
                + self.gamma_over * np.maximum(s - d, 0)
                + 0.5 * (s - d) * (s - d))

    def _costs_against(self, schedules, d):
        """Return ``objective(schedules[k], d)`` for every row ``k`` of 2-D ``schedules``."""
        under = np.maximum(d - schedules, 0)
        over = np.maximum(schedules - d, 0)
        return np.sum(self.gamma_under * under + self.gamma_over * over, axis=1)

    # ------------------------------------------------------------------ #
    # Labels and classifier
    # ------------------------------------------------------------------ #
    def init_n(self, frac=-1, EPS=-1):
        """Build the ``(n, n)`` boolean matrix ``self.labels``.

        ``labels[i, k]`` is True when the stored schedule ``self.opt_s[k]``
        (from :meth:`get_solutions`) costs no more on demand row ``i`` than
        the ``frac``-quantile of all those costs for row ``i``.

        Args:
            frac: quantile level in ``[0, 1]``. The default ``-1`` is not a
                valid level, so callers must pass it explicitly.
            EPS: stored as ``self.EPS``; not otherwise used here.

        Raises:
            ValueError: if ``frac`` is outside ``[0, 1]``.
        """
        self.EPS = EPS
        if not 0 <= frac <= 1:
            raise ValueError(f"frac must be in [0, 1], got {frac!r}")

        # One row per candidate schedule; only the first n are used, matching
        # the n outputs of the network trained in init_models.
        schedules = np.asarray(self.opt_s[:self.n], dtype=float)
        labels = []
        for i in range(self.n):
            costs = self._costs_against(schedules, self.demands[i])
            labels.append(costs <= np.quantile(costs, frac))
        self.labels = np.array(labels)

    class Net(nn.Module):
        """Two hidden layers of 200 ReLU units; returns raw logits of size ``out_dim``."""

        def __init__(self, in_dim, out_dim):
            super().__init__()
            self.fc1 = nn.Linear(in_dim, 200)
            self.fc2 = nn.Linear(200, 200)
            self.fc3 = nn.Linear(200, out_dim)

        def forward(self, x):
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x

    def init_models(self, epochs):
        """Train ``self.phat`` to predict rows of ``self.labels`` from rows of ``self.X``.

        Uses ``BCEWithLogitsLoss`` (one logit per scenario), Adam with
        ``lr=0.002`` and mini-batches of 100 consecutive rows (no shuffling).
        After each epoch it prints the mean loss of the last 100 batches.
        Requires :meth:`init_n` to have been called.
        """
        self.phat = self.Net(self.X.shape[1], self.n)

        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.Adam(self.phat.parameters(), lr=0.002)
        losses = []
        batch_size = 100
        for epoch in range(epochs):
            for start in range(0, self.X.shape[0], batch_size):
                stop = start + batch_size
                inp = torch.tensor(self.X[start:stop, :]).float()
                target = torch.tensor(self.labels[start:stop]).float()

                optimizer.zero_grad()
                loss = criterion(self.phat(inp), target)
                loss.backward()
                optimizer.step()
                losses.append(loss.item())
            print(epoch, np.mean(losses[-100:]))

        # The label says "MSE" but the tracked quantity is the BCE loss.
        print("MEAN MSE: ", np.mean(losses[-100:]))

    # ------------------------------------------------------------------ #
    # Weighted scenario LPs
    # ------------------------------------------------------------------ #
    def setup_model(self, eps):
        """Build, without solving, an LP with one cost variable ``q[i]`` per scenario.

        ``q[i] >= gamma_under * sum_k b[i,k] + gamma_over * sum_k h[i,k]``, where
        ``b``/``h`` are under/over-supply slacks loosened by ``eps / 24`` per
        period. No objective is set here; :meth:`predict` sets the coefficients
        on ``q``. Stores the model in ``self.cap_solver`` and the variables in
        ``self.q`` and ``self.s``.

        NOTE(review): the divisor 24 is hard-coded and independent of
        ``self.dim``; presumably 24 hourly periods — confirm.
        """
        model = gb.Model("lp")
        model.Params.LogToConsole = 0
        s = model.addVars(self.dim)
        h = model.addVars(self.n, self.dim)
        b = model.addVars(self.n, self.dim)
        q = model.addVars(self.n)
        model.update()

        slack = eps / 24
        for i in range(self.n):
            model.addConstr(q[i] >= 0)
            for k in range(self.dim):
                model.addConstr(h[i, k] >= 0)
                model.addConstr(b[i, k] >= 0)
                model.addConstr(b[i, k] >= self.demands[i, k] - s[k] - slack)
                model.addConstr(h[i, k] >= s[k] - self.demands[i, k] - slack)

            model.addConstr(q[i] >= self.gamma_under * gb.quicksum(b[i, k] for k in range(self.dim))
                                   + self.gamma_over * gb.quicksum(h[i, k] for k in range(self.dim)))

        self._add_ramp_constraints(model, s)
        self.cap_solver = model
        self.q = q
        self.s = s

    def predict_quick(self, x, phat, n_keep=50, phi=0):
        """Build and solve an LP over only the ``n_keep`` highest-scoring scenarios.

        For each kept scenario ``j`` (ranked by ``phat``), a variable ``q[i] >= 0``
        with ``q[i] >= cost_j(s) - phi`` is weighted by ``phat[j]`` in the
        objective. ``n_keep`` is clamped to the length of ``phat``.

        Args:
            x: unused; kept for interface compatibility.
            phat: 1-D tensor of per-scenario scores.
            n_keep: number of top-scoring scenarios to include.
            phi: constant subtracted from each scenario's cost bound.

        Returns:
            list of ``self.dim`` floats (the schedule ``s``).

        NOTE(review): this replaces ``self.cap_solver``, ``self.q`` and
        ``self.s`` with the reduced model, so :meth:`setup_model` must be
        called again before :meth:`predict`.
        """
        n_keep = min(n_keep, phat.shape[-1])
        model = gb.Model("lp")
        model.Params.LogToConsole = 0
        s = model.addVars(self.dim)
        h = model.addVars(n_keep, self.dim)
        b = model.addVars(n_keep, self.dim)
        q = model.addVars(n_keep)
        model.update()

        keep = torch.topk(phat, n_keep)[1].tolist()
        for i, j in enumerate(keep):
            q[i].Obj = float(phat[j])
            model.addConstr(q[i] >= 0)
            for k in range(self.dim):
                model.addConstr(h[i, k] >= 0)
                model.addConstr(b[i, k] >= 0)
                model.addConstr(b[i, k] >= self.demands[j, k] - s[k])
                model.addConstr(h[i, k] >= s[k] - self.demands[j, k])

            model.addConstr(q[i] >= self.gamma_under * gb.quicksum(b[i, k] for k in range(self.dim))
                                   + self.gamma_over * gb.quicksum(h[i, k] for k in range(self.dim))
                                   - phi)

        self._add_ramp_constraints(model, s)
        self.cap_solver = model
        self.cap_solver.update()
        self.cap_solver.optimize()

        print("Status:", self.cap_solver.status)

        self.q = q
        self.s = s
        return self._read_schedule(s)

    def predict(self, x, phat, t=0, eps=0, n_keep=10):
        """Re-weight and solve the model built by :meth:`setup_model`.

        The ``n_keep`` scenarios with the largest ``phat`` get objective
        coefficient ``phat[i]`` on ``q[i]``; every other ``q[i]`` gets 0.
        ``n_keep`` is clamped to the length of ``phat``.

        Args:
            x, t, eps: unused; kept for interface compatibility.
            phat: 1-D tensor of per-scenario scores (length ``self.n``).
            n_keep: number of top-scoring scenarios to weight (previously fixed at 10).

        Returns:
            list of ``self.dim`` floats (the schedule ``s``).
        """
        k = min(n_keep, phat.shape[-1])
        keep = set(torch.topk(phat, k)[1].tolist())
        for i in range(self.n):
            self.q[i].Obj = float(phat[i]) if i in keep else 0
        self.cap_solver.update()
        self.cap_solver.optimize()
        return self._read_schedule(self.s)
