#!/usr/bin/env python
# coding: utf-8
"""
Covering problem: nominal and robust Gurobi models built on pyepo's optGrbModel
"""

import numpy as np
import gurobipy as gp
from gurobipy import GRB

from pyepo.model.grb.grbmodel import optGrbModel


class coverModel(optGrbModel):
    """
    Optimization model for a continuous covering problem.

    The model enforces ``concentrations @ x >= reqs`` with ``0 <= x <= ub`` and
    is minimized; no objective is set here, so the cost vector has to be
    supplied afterwards (e.g. with ``setObj``).

    Attributes:
        _model (GurobiPy model): Gurobi model
        concentrations (np.ndarray): 2-D matrix; row i gives each item's
            contribution to requirement i, shape (num_reqs, num_items)
        reqs (np.ndarray): per-requirement lower bounds, shape (num_reqs,)
        num_reqs (int): number of covering constraints
        num_items (int): number of decision variables
        ub (float or None): upper bound on every decision variable
            (None means unbounded)
    """

    def __init__(self, concentrations, reqs, ub=100):
        """
        Args:
            concentrations (np.ndarray / list): 2-D matrix of item
                contributions, one row per requirement
            reqs (np.ndarray / list): one requirement per row of
                ``concentrations``
            ub (float or None): upper bound on each decision variable;
                defaults to 100, ``None`` leaves the variables unbounded

        Raises:
            ValueError: if ``concentrations`` is not 2-D or ``reqs`` does not
                have exactly one entry per row of ``concentrations``
        """
        self.concentrations = np.array(concentrations)
        self.reqs = np.array(reqs)
        if self.concentrations.ndim != 2:
            raise ValueError(
                f"concentrations must be 2-D (num_reqs, num_items), "
                f"got shape {self.concentrations.shape}")
        self.num_reqs, self.num_items = self.concentrations.shape
        if self.reqs.shape != (self.num_reqs,):
            raise ValueError(
                f"reqs must have shape ({self.num_reqs},), got {self.reqs.shape}")
        self.ub = ub
        # attributes read by _getModel must exist before super().__init__(),
        # which presumably builds the model via _getModel (pyepo) — verify
        super().__init__()

    def _getModel(self):
        """
        Build the Gurobi covering model.

        Returns:
            tuple: (gurobipy model, tupledict of the ``num_items`` variables x)
        """
        m = gp.Model("cover")
        ub = GRB.INFINITY if self.ub is None else self.ub
        x = m.addVars(self.num_items, name="x", lb=0, ub=ub, vtype=GRB.CONTINUOUS)
        # minimize; the objective itself is attached later (e.g. via setObj)
        m.modelSense = GRB.MINIMIZE
        # one covering constraint per requirement row
        for i in range(self.num_reqs):
            covered = gp.quicksum(self.concentrations[i, j] * x[j]
                                  for j in range(self.num_items))
            m.addConstr(covered >= self.reqs[i], name=f"req_{i}")
        return m, x
    

class cspo_coverModel(optGrbModel):
    """
    Robust covering model with a 2-norm safety margin.

    Each nominal constraint ``a_hat_i^T x >= reqs[i]`` (``a_hat_i`` = row i of
    the predicted concentrations) is tightened to

        a_hat_i^T x - q_hat[i] * ||x||_2 >= reqs[i],

    which is the worst case of ``a_i^T x >= reqs[i]`` over all ``a_i`` within
    2-norm distance ``q_hat[i]`` of ``a_hat_i``. The norm is modelled with an
    auxiliary variable ``t[i] >= ||x||_2`` (second-order cone constraint).
    With ``q_hat[i] == 0`` the constraint reduces to the nominal one.
    The model is minimized; no objective is set here (use e.g. ``setObj``).

    Attributes:
        pred_concentrations (np.ndarray): predicted matrix, (num_reqs, num_items)
        reqs (np.ndarray): per-requirement lower bounds, shape (num_reqs,)
        q_hat: margin radii as passed by the caller (scalar or one per row)
        num_reqs (int): number of covering constraints
        num_items (int): number of decision variables
        ub (float or None): upper bound on each x (None means unbounded)
    """

    def __init__(self, pred_concentrations, q_hat, reqs, ub=100):
        """
        Args:
            pred_concentrations (np.ndarray / list): 2-D predicted contribution
                matrix, one row per requirement
            q_hat (float / np.ndarray / list): non-negative margin radius,
                either a scalar or one value per requirement
            reqs (np.ndarray / list): one requirement per row
            ub (float or None): upper bound on each decision variable;
                defaults to 100, ``None`` leaves the variables unbounded

        Raises:
            ValueError: on a non-2-D matrix, mismatched ``reqs``/``q_hat``
                lengths, or a negative ``q_hat``
        """
        self.pred_concentrations = np.array(pred_concentrations)
        self.reqs = np.array(reqs)
        if self.pred_concentrations.ndim != 2:
            raise ValueError(
                f"pred_concentrations must be 2-D (num_reqs, num_items), "
                f"got shape {self.pred_concentrations.shape}")
        self.num_reqs, self.num_items = self.pred_concentrations.shape
        if self.reqs.shape != (self.num_reqs,):
            raise ValueError(
                f"reqs must have shape ({self.num_reqs},), got {self.reqs.shape}")
        try:
            q_vec = np.broadcast_to(np.asarray(q_hat, dtype=float), (self.num_reqs,))
        except ValueError as err:
            raise ValueError(
                f"q_hat must be a scalar or have {self.num_reqs} entries") from err
        if np.any(q_vec < 0):
            raise ValueError("q_hat must be non-negative")
        self.q_hat = q_hat
        self._q_vec = q_vec
        self.ub = ub
        # attributes read by _getModel must exist before super().__init__(),
        # which presumably builds the model via _getModel (pyepo) — verify
        super().__init__()

    def _getModel(self):
        """
        Build the robust Gurobi covering model.

        Returns:
            tuple: (gurobipy model, tupledict of the ``num_items`` variables x);
            the auxiliary ``t`` variables stay internal to the model
        """
        m = gp.Model("robust_cover")
        m.setParam('OutputFlag', 0)
        m.setParam('NonConvex', 2)

        # x must be created first so it precedes t in the model's variable order
        ub = GRB.INFINITY if self.ub is None else self.ub
        x = m.addVars(self.num_items, name="x", lb=0, ub=ub, vtype=GRB.CONTINUOUS)
        t = m.addVars(self.num_reqs, name="t", lb=0, vtype=GRB.CONTINUOUS)

        m.modelSense = GRB.MINIMIZE

        # ||x||_2^2 is the same for every row, so build it once
        sq_norm = gp.quicksum(x[j] * x[j] for j in range(self.num_items))
        for i in range(self.num_reqs):
            lhs = gp.quicksum(self.pred_concentrations[i, j] * x[j]
                              for j in range(self.num_items))
            # ">=" (not "==") so that q_hat[i] == 0 yields the nominal
            # constraint lhs >= reqs[i] rather than forcing lhs == reqs[i]
            m.addConstr(lhs - self.reqs[i] >= self._q_vec[i] * t[i], name=f"req_{i}")
            # with t[i] >= 0 this is the cone constraint ||x||_2 <= t[i]
            m.addQConstr(sq_norm <= t[i] * t[i], name=f"soc_{i}")
        return m, x


if __name__ == "__main__":

    import random

    num_items = 16

    # ---- nominal covering model ----
    # seed both RNGs: `cost` draws from `random`, `concentrations` from `np.random`
    random.seed(42)
    np.random.seed(42)
    cost = [random.random() for _ in range(num_items)]
    concentrations = np.random.choice(range(300, 800), size=(2, num_items)) / 100
    reqs = [20, 20]

    optmodel = coverModel(concentrations=concentrations, reqs=reqs)  # init model
    optmodel = optmodel.copy()  # demonstrate that a copied model can be solved
    optmodel.setObj(cost)  # set objective function
    sol, obj = optmodel.solve()
    print('Obj: {}'.format(obj))
    for i, val in enumerate(sol):
        if val > 1e-3:
            print(i)

    # ---- robust covering model ----
    random.seed(42)
    np.random.seed(42)
    cost = [random.random() for _ in range(num_items)]

    # fake predicted concentrations: 2 constraints x 16 items, in [3, 8]
    num_reqs = 2
    concentrations = np.random.uniform(3.0, 8.0, size=(num_reqs, num_items))

    # required totals (h vector)
    reqs = [20, 25]

    # margin radii (e.g. confidence quantiles q_hat), one per constraint
    q_hat = [1.5, 1.5]

    model = cspo_coverModel(pred_concentrations=concentrations, q_hat=q_hat, reqs=reqs)

    # the model is built without an objective, so it must be installed with
    # setObj(); merely assigning an attribute would leave the solver's
    # objective empty
    model.setObj(cost)

    solution, obj_val = model.solve()

    print(f"Objective value: {obj_val:.4f}")

    # print non-zero solution values
    for i, val in enumerate(solution):
        if val > 1e-3:
            print(f"x[{i}] = {val:.4f}")

