#!/usr/bin/env python
# coding: utf-8
"""
Knapsack problem
"""

import numpy as np
import gurobipy as gp
from gurobipy import GRB

from pyepo.model.grb.grbmodel import optGrbModel


class knapsackModelRel(optGrbModel):
    """
    This class is relaxed optimization model for knapsack problem.

    It subclasses ``optGrbModel`` directly rather than ``knapsackModel``:
    ``knapsackModel`` is defined further down in this file, so naming it as
    the base class here raised ``NameError`` when the module was imported.
    The constructor mirrors ``knapsackModel.__init__`` so that
    ``knapsackModel.relax`` can build it from the same
    ``(weights, capacity)`` arguments.

    Attributes:
        weights (np.ndarray): weights of items; row i holds the item
            weights for capacity constraint i
        capacity (np.ndarray): one capacity per constraint
        items (list): list of item indices
    """

    def __init__(self, weights, capacity):
        """
        Args:
            weights (np.ndarray / list): weights of items, one row per
                capacity constraint, one column per item
            capacity (np.ndarray / list): total capacity per constraint
        """
        self.weights = np.array(weights)
        self.capacity = np.array(capacity)
        # items are the columns of the (constraints x items) weight matrix
        self.items = list(range(self.weights.shape[1]))
        super().__init__()

    def _getModel(self):
        """
        A method to build the relaxed Gurobi model

        Returns:
            tuple: optimization model and variables
        """
        # create a model
        m = gp.Model("knapsack")
        # turn off output
        m.Params.outputFlag = 0
        # continuous variables in [0, 1]: the LP relaxation of 0/1 selection
        x = m.addVars(self.items, name="x", ub=1, vtype=GRB.CONTINUOUS)
        # sense
        m.modelSense = GRB.MAXIMIZE
        # one capacity constraint per row of the weight matrix
        for i in range(len(self.capacity)):
            m.addConstr(gp.quicksum(self.weights[i, j] * x[j]
                        for j in self.items) <= self.capacity[i])
        return m, x

    def relax(self):
        """
        A forbidden method to relax MIP model

        Raises:
            RuntimeError: always, since this model is already relaxed
        """
        raise RuntimeError("Model has already been relaxed.")
        
        
if __name__ == "__main__":
    # NOTE(review): this demo executes while the module is still being
    # defined, i.e. before ``knapsackModel`` (declared further down this
    # file) exists, so running the file as a script raises NameError at the
    # ``knapsackModel(...)`` call. Move this block to the end of the file to
    # make the demo runnable.

    import random

    num_item = 16

    def _solve_and_report(model):
        """Solve ``model``, print its objective and the selected item indices."""
        sol, obj = model.solve()
        print('Obj: {}'.format(obj))
        for i in range(num_item):
            if sol[i] > 1e-3:
                print(i)

    # random seed: ``cost`` is drawn from ``random`` but ``weights`` from
    # ``np.random``, so both generators must be seeded for reproducibility
    random.seed(42)
    np.random.seed(42)
    # set random cost for test
    cost = [random.random() for _ in range(num_item)]
    weights = np.random.choice(range(300, 800), size=(2, num_item)) / 100
    capacity = [20, 20]

    # solve model
    optmodel = knapsackModel(weights=weights, capacity=capacity)  # init model
    optmodel = optmodel.copy()
    optmodel.setObj(cost)  # set objective function
    _solve_and_report(optmodel)

    # relax
    optmodel = optmodel.relax()
    optmodel.setObj(cost)  # set objective function
    _solve_and_report(optmodel)

    # add constraint: first-row weights must total at most 10
    optmodel = optmodel.addConstr([weights[0, i] for i in range(num_item)], 10)
    optmodel.setObj(cost)  # set objective function
    _solve_and_report(optmodel)




class toy_knapsackModel(optGrbModel):
    """
    Two-item toy selection model.

    Maximizes the objective over continuous x with upper bound 1 subject to
    x[0] + x[1] == 1, i.e. it selects the item with the best cost value.
    """

    def __init__(self):
        # problem size is fixed for this toy example
        self.num_item = 2
        super().__init__()

    def _getModel(self):
        """
        Build the Gurobi model.

        Returns:
            tuple: optimization model and variables
        """
        model = gp.Model()
        x = model.addVars(self.num_item, vtype=GRB.CONTINUOUS, ub=1, name="x")
        model.modelSense = GRB.MAXIMIZE
        # the selection weights must sum to exactly one
        model.addConstr(x.sum() == 1)
        return model, x

class toy_knapsackModel_small_x(optGrbModel):
    """
    Two-item toy selection model ("small x" variant).

    Maximizes the objective over continuous x with upper bound 1 subject to
    x[0] + x[1] == 1, i.e. it selects the item with the best cost value.
    """

    def __init__(self):
        # problem size is fixed for this toy example
        self.num_item = 2
        super().__init__()

    def _getModel(self):
        """
        Build the Gurobi model.

        Returns:
            tuple: optimization model and variables
        """
        model = gp.Model()
        x = model.addVars(self.num_item, vtype=GRB.CONTINUOUS, ub=1, name="x")
        model.modelSense = GRB.MAXIMIZE
        # the selection weights must sum to exactly one
        model.addConstr(x.sum() == 1)
        return model, x

class toy_knapsackModel_large_x(optGrbModel):
    """
    Two-item toy selection model ("large x" variant).

    Same as the plain two-item selection (x[0] + x[1] == 1, continuous x
    with upper bound 1, maximize) plus the extra constraint x[0] <= 0, which
    pushes all of the selection onto item 1.
    """

    def __init__(self):
        # problem size is fixed for this toy example
        self.num_item = 2
        super().__init__()

    def _getModel(self):
        """
        Build the Gurobi model.

        Returns:
            tuple: optimization model and variables
        """
        model = gp.Model()
        x = model.addVars(self.num_item, vtype=GRB.CONTINUOUS, ub=1, name="x")
        model.modelSense = GRB.MAXIMIZE
        # the selection weights must sum to exactly one ...
        model.addConstr(x.sum() == 1)
        # ... and item 0 may not take any of it
        model.addConstr(x[0] <= 0)
        return model, x
    
class cspo_toy_knapsackModel_example3(optGrbModel):
    """
    Two-item toy knapsack with unit capacity and a norm-penalized weight
    constraint.

    With u = (capacity - pred_weights . x) / q_hat, the model adds
    sum_i x_i^2 <= u^2. Because u keeps Gurobi's default lower bound of 0,
    for q_hat > 0 this is equivalent to
        pred_weights . x + q_hat * ||x||_2 <= capacity.
    """

    def __init__(self, pred_weights, q_hat):
        """
        Args:
            pred_weights (np.ndarray / list): predicted item weights; must
                contain exactly two entries
            q_hat (float): multiplier of the norm term (presumably a
                positive radius - TODO confirm with callers)
        """
        self.pred_weights = np.array(pred_weights)
        self.num_item = len(pred_weights)
        assert self.num_item == 2
        self.q_hat = q_hat
        # capacity is fixed to 1 in this toy example
        self.capacity = 1
        super().__init__()

    def _getModel(self):
        """
        Build the Gurobi model.

        Returns:
            tuple: optimization model and variables
        """
        m = gp.Model()
        m.setParam('OutputFlag', 0)
        # permit the quadratic constraint below even if Gurobi does not
        # recognise x'x <= u*u as a convex cone
        m.setParam('NonConvex', 2)
        idx = range(self.num_item)
        x = m.addVars(self.num_item, vtype=GRB.CONTINUOUS, ub=1, name="x")
        # scaled remaining capacity; see the class docstring
        u = m.addVar(vtype=GRB.CONTINUOUS, name="u")
        m.modelSense = GRB.MAXIMIZE
        scaled_load = gp.quicksum(self.pred_weights[i] * x[i] / self.q_hat for i in idx)
        m.addConstr(self.capacity / self.q_hat - scaled_load == u)
        m.addConstr(gp.quicksum(x[i] * x[i] for i in idx) <= u * u)
        return m, x

class cspo_toy_knapsackModel_example_comp_truncation(optGrbModel):
    """
    Two-item toy selection model with a shrunken, truncated capacity.

    The capacity is pred_capa - q_hat, clipped below at zero. It caps how
    much of the unit selection (x[0] + x[1] == 1) may go to item 1.
    """

    def __init__(self, pred_capa, q_hat):
        """
        Args:
            pred_capa (float): predicted capacity
            q_hat (float): amount subtracted from the predicted capacity
        """
        self.num_item = 2
        # shrink the predicted capacity by q_hat, never going below zero
        self.capacity = max(pred_capa - q_hat, 0)
        super().__init__()

    def _getModel(self):
        """
        Build the Gurobi model.

        Returns:
            tuple: optimization model and variables
        """
        model = gp.Model()
        model.setParam('OutputFlag', 0)
        model.setParam('NonConvex', 2)
        x = model.addVars(self.num_item, vtype=GRB.CONTINUOUS, ub=1, name="x")
        model.modelSense = GRB.MAXIMIZE
        # only item 1 is limited by the (truncated) capacity
        model.addConstr(x[1] <= self.capacity)
        # select the item with best cost value
        model.addConstr(x.sum() == 1)
        return model, x

class knapsackModel(optGrbModel):
    """
    This class is optimization model for (multi-constraint) knapsack problem

    The decision variables are continuous in [0, 1], so the model built
    here is the linear relaxation of the 0/1 knapsack.

    Attributes:
        _model (GurobiPy model): Gurobi model
        weights (np.ndarray): weights of items; row i holds the item weights
            for capacity constraint i
        capacity (np.ndarray): total capacity, one entry per constraint
        items (list): list of item indices
    """

    def __init__(self, weights, capacity):
        """
        Args:
            weights (np.ndarray / list): weights of items, one row per
                capacity constraint, one column per item
            capacity (np.ndarray / list): total capacity per constraint
        """
        self.weights = np.array(weights)
        self.capacity = np.array(capacity)
        n_items = self.weights.shape[1]
        self.items = list(range(n_items))
        super().__init__()

    def _getModel(self):
        """
        A method to build Gurobi model

        Returns:
            tuple: optimization model and variables
        """
        model = gp.Model("knapsack")
        x = model.addVars(self.items, vtype=GRB.CONTINUOUS, ub=1, name="x")
        model.modelSense = GRB.MAXIMIZE
        # one "weighted load <= capacity" constraint per capacity entry
        for i, cap in enumerate(self.capacity):
            row = self.weights[i]
            load = gp.quicksum(row[j] * x[j] for j in self.items)
            model.addConstr(load <= cap)
        return model, x

    def relax(self):
        """
        A method to get linear relaxation model

        Returns:
            knapsackModelRel: model built from the same weights and capacity
        """
        return knapsackModelRel(self.weights, self.capacity)


class fractional_knapsackModel(optGrbModel):
    """
    Fractional knapsack with a single capacity constraint.

    Maximizes the objective over continuous x with upper bound 1 subject to
    weights . x <= capacity.
    """

    def __init__(self, weights, capacity):
        """
        Args:
            weights (np.ndarray / list): one weight per item
            capacity (float): knapsack capacity
        """
        self.weights = np.array(weights)
        self.num_item = len(weights)
        self.capacity = capacity
        super().__init__()

    def _getModel(self):
        """
        Build the Gurobi model.

        Returns:
            tuple: optimization model and variables
        """
        model = gp.Model()
        model.setParam('OutputFlag', 0)
        x = model.addVars(self.num_item, vtype=GRB.CONTINUOUS, ub=1, name="x")
        model.modelSense = GRB.MAXIMIZE
        load = gp.quicksum(self.weights[i] * x[i] for i in range(self.num_item))
        model.addConstr(load <= self.capacity)
        return model, x

class cspo_fractional_knapsackModel(optGrbModel):
    """
    Fractional single-constraint knapsack with a norm-penalized weight
    constraint.

    An auxiliary variable u = (capacity - pred_weights . x) / q_hat is
    introduced, and the model requires sum_i x_i^2 <= u^2. Since u keeps
    Gurobi's default lower bound of 0, this reads ||x||_2 <= u, so for
    q_hat > 0 the pair is equivalent to
        pred_weights . x + q_hat * ||x||_2 <= capacity.
    By Cauchy-Schwarz, the left-hand side is the largest value of w . x over
    all w with ||w - pred_weights||_2 <= q_hat. So x satisfies the capacity
    for every weight vector in that ball around the prediction.
    """

    def __init__(self, pred_weights, q_hat, capacity):
        """
        Args:
            pred_weights (np.ndarray / list): predicted item weights
            q_hat (float): multiplier of the norm term (presumably a
                positive radius - TODO confirm with callers)
            capacity (float): knapsack capacity
        """
        self.pred_weights = np.array(pred_weights)
        self.num_item = len(pred_weights)
        self.q_hat = q_hat
        self.capacity = capacity
        super().__init__()

    def _getModel(self):
        """
        Build the Gurobi model.

        Returns:
            tuple: optimization model and variables
        """
        m = gp.Model()
        m.setParam('OutputFlag', 0)
        # permit the quadratic constraint below even if Gurobi does not
        # recognise x'x <= u*u as a convex cone
        m.setParam('NonConvex', 2)
        idx = range(self.num_item)
        x = m.addVars(self.num_item, vtype=GRB.CONTINUOUS, ub=1, name="x")
        # scaled remaining capacity; see the class docstring
        u = m.addVar(vtype=GRB.CONTINUOUS, name="u")
        m.modelSense = GRB.MAXIMIZE
        scaled_load = gp.quicksum(self.pred_weights[i] * x[i] / self.q_hat for i in idx)
        m.addConstr(self.capacity / self.q_hat - scaled_load == u)
        # ||x||_2^2 <= u^2, i.e. ||x||_2 <= u given u >= 0
        m.addConstr(gp.quicksum(x[i] * x[i] for i in idx) <= u * u)
        return m, x
