#!/usr/bin/env python
# coding: utf-8
"""
Linear packing problem models (max c^T x s.t. A x <= b, x >= 0) built with Gurobi
"""

import numpy as np
import gurobipy as gp
from gurobipy import GRB

from pyepo.model.grb.grbmodel import optGrbModel


class packingModel(optGrbModel):
    """
    This class is an optimization model for a linear packing problem:

        max  c^T x
        s.t. A x <= b,  x >= 0 (continuous)

    where ``A`` is ``adjacency_matrix`` and ``b`` is ``capacity``. The
    objective coefficients ``c`` are not set here; they are supplied through
    the base-class objective API (e.g. ``setObj``).

    Attributes:
        _model (GurobiPy model): Gurobi model
        adjacency_matrix (np.ndarray): non-negative constraint coefficient
            matrix of shape (num_edges, num_paths)
        capacity (np.ndarray): constraint right-hand side, one entry per row
            of ``adjacency_matrix``
        num_edges (int): number of constraints (rows of ``adjacency_matrix``)
        num_paths (int): number of decision variables (columns of
            ``adjacency_matrix``)
    """

    def __init__(self, adjacency_matrix, capacity):
        """
        Args:
            adjacency_matrix (np.ndarray / list): 2-D non-negative coefficient
                matrix of shape (num_edges, num_paths)
            capacity (np.ndarray / list): right-hand side, one value per row
                of ``adjacency_matrix``

        Raises:
            ValueError: if any entry of ``adjacency_matrix`` is negative (or NaN)
        """
        self.adjacency_matrix = np.array(adjacency_matrix)
        # ``assert array >= 0`` is ambiguous for arrays with more than one
        # element (numpy raises ValueError) and is stripped under ``python -O``;
        # check every entry explicitly. ``not all(>= 0)`` also rejects NaN.
        if not np.all(self.adjacency_matrix >= 0):
            raise ValueError("adjacency_matrix ill defined, should be non-negative")
        self.capacity = np.array(capacity)
        self.num_edges, self.num_paths = self.adjacency_matrix.shape
        # NOTE: attributes are set before super().__init__() because the
        # base-class constructor is expected to call _getModel(), which reads them.
        super().__init__()

    def _getModel(self):
        """
        A method to build Gurobi model

        Returns:
            tuple: optimization model and variables
        """
        # create a model
        m = gp.Model("packing")
        # one continuous, non-negative decision variable per column
        x = m.addVars(self.num_paths, name="x", lb=0, vtype=GRB.CONTINUOUS)
        # sense
        m.modelSense = GRB.MAXIMIZE
        # constraints: sum_j A[i, j] * x[j] <= capacity[i] for every row i
        for i in range(self.num_edges):
            m.addConstr(
                gp.quicksum(self.adjacency_matrix[i, j] * x[j]
                            for j in range(self.num_paths)) <= self.capacity[i],
                name=f"req_{i}",
            )
        return m, x
    

class cspo_packingModel(optGrbModel):
    """
    Packing model variant built as the Gurobi model "robust_packing":

        max  c^T x
        s.t. A x <= capacity_lb,  0 <= x <= x_ub (continuous)

    where ``A`` is ``adjacency_matrix``. The objective coefficients ``c`` are
    supplied through the base-class objective API (e.g. ``setObj``).

    Attributes:
        adjacency_matrix (np.ndarray): constraint coefficient matrix of shape
            (num_edges, num_paths)
        capacity_lb (np.ndarray / list): constraint right-hand side, one entry
            per row of ``adjacency_matrix`` (stored as given)
        q_hat: stored as given. NOTE(review): it is not used when building the
            model in this class; confirm whether the constraints should use it.
        x_ub (float): upper bound applied to every decision variable
        num_edges (int): number of constraints
        num_paths (int): number of decision variables
    """

    def __init__(self, adjacency_matrix, q_hat, capacity_lb, x_ub=100):
        """
        Args:
            adjacency_matrix (np.ndarray / list): 2-D coefficient matrix of
                shape (num_edges, num_paths)
            q_hat: stored on the instance as ``self.q_hat``
            capacity_lb (np.ndarray / list): right-hand side, one value per
                row of ``adjacency_matrix``
            x_ub (float): upper bound on each decision variable; defaults to
                100, the value previously hard-coded
        """
        self.adjacency_matrix = np.array(adjacency_matrix)
        self.capacity_lb = capacity_lb
        self.num_edges, self.num_paths = self.adjacency_matrix.shape
        self.q_hat = q_hat
        self.x_ub = x_ub
        # NOTE: attributes are set before super().__init__() because the
        # base-class constructor is expected to call _getModel(), which reads them.
        super().__init__()

    def _getModel(self):
        """
        Build the Gurobi model.

        Returns:
            tuple: optimization model and variables
        """
        m = gp.Model("robust_packing")
        m.setParam('OutputFlag', 0)
        # NOTE(review): the model below is purely linear; NonConvex only affects
        # non-convex quadratic terms. Kept so solver settings stay unchanged.
        m.setParam('NonConvex', 2)

        # Variables: continuous, bounded in [0, x_ub]
        x = m.addVars(self.num_paths, name="x", lb=0, ub=self.x_ub,
                      vtype=GRB.CONTINUOUS)

        # Objective sense (coefficients are set later)
        m.modelSense = GRB.MAXIMIZE

        # Constraints: sum_j A[i, j] * x[j] <= capacity_lb[i] for every row i
        for i in range(self.num_edges):
            m.addConstr(
                gp.quicksum(self.adjacency_matrix[i, j] * x[j]
                            for j in range(self.num_paths)) <= self.capacity_lb[i],
                name=f"req_{i}",
            )
        return m, x


# if __name__ == "__main__":
    
#     import random
#     # random seed
#     random.seed(42)
#     # set random cost for test
#     cost = [random.random() for _ in range(16)]
#     adjacency_matrix = np.random.choice(range(300, 800), size=(2,16)) / 100
#     capacity = [20, 20]
    
#     # solve model
#     optmodel = packingModel(adjacency_matrix=adjacency_matrix, capacity=capacity) # init model
#     optmodel = optmodel.copy()
#     optmodel.setObj(cost) # set objective function
#     sol, obj = optmodel.solve() # solve
#     # print res
#     print('Obj: {}'.format(obj))
#     for i in range(16):
#         if sol[i] > 1e-3:
#             print(i)
    
#     # Set random seed for reproducibility
#     random.seed(42)
#     np.random.seed(42)

#     # Random cost vector
#     num_items = 16
#     cost = [random.random() for _ in range(num_items)]

#     # Generate a fake predicted adjacency_matrix matrix (e.g., 2 constraints × 16 items)
#     num_capacity = 2
#     adjacency_matrix = np.random.uniform(3.0, 8.0, size=(num_capacity, num_items))  # scaled [300, 800]/100

#     # Capacity per constraint (right-hand side h vector)
#     capacity = [20, 25]

#     # Confidence quantiles Q_m (e.g., q_hat), one per constraint
#     q_hat = [1.5, 1.5]

#     # Initialize model
#     model = cspo_packingModel(adjacency_matrix=adjacency_matrix, q_hat=q_hat, capacity_lb=capacity)

#     # Set objective coefficients (the model does not read a `cost` attribute)
#     model.setObj(cost)

#     # Solve the problem
#     solution, obj_val = model.solve()

#     # Print objective
#     print(f"Objective value: {obj_val:.4f}")

#     # Print non-zero solution values
#     for i, val in enumerate(solution):
#         if val > 1e-3:
#             print(f"x[{i}] = {val:.4f}")

