import numpy as np
from model.v1.optimization import BaseOptimization
import torch
import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer

class ShortestPath(BaseOptimization):
    '''
    Shortest-path (min-cost flow) problem with a nominal and a robust solver.

    Decisions are edge flows ``z >= 0`` subject to flow conservation
    ``A @ z == b``. Two cvxpylayers solvers are built once in ``__init__``:

    - ``self.solver``:         min_z  y @ z
    - ``self.robust_solver``:  min_z  ypred @ z + lam * width @ z

    Because ``z >= 0``, the robust objective equals the worst case of
    ``y @ z`` over the weighted l_infty ball ``|y - ypred| <= lam * width``.
    '''

    def __init__(self, model, feasible_region, outcome_space, A, b, width):
        '''
        Args:
        - model:            predictive model; ``regret`` calls ``self.model.pred(x)``
                            (presumably ``self.model`` is set by ``BaseOptimization`` — confirm)
        - feasible_region:  forwarded to ``BaseOptimization.__init__``
        - outcome_space:    forwarded to ``BaseOptimization.__init__``
        - A:      [ nnode, nedge ] np, incidence matrix
        - b:      [ nnode ] np, supply/demand
        - width:  [ nedge ] list or np, width of the robust region (do not set to all the same value)
        Raises:
        - ValueError: if ``width`` does not have shape [ nedge ]
        '''
        width = np.asarray(width)       # [ nedge ] np
        self.A = A
        self.b = b
        super().__init__(model, feasible_region, outcome_space)
        A, b  = self.A, self.b
        nedge = A.shape[1]
        if width.shape != (nedge,):
            raise ValueError(
                f'width must have shape ({nedge},) to match the number of edges, '
                f'got {width.shape}')

        # -------------
        # Robust solver: nominal cost plus an l_infty-ball penalty on the flow
        # -------------
        z       = cp.Variable(nedge, nonneg=True)   # decision variable (edge flows)
        ypred   = cp.Parameter(nedge)               # predicted cost vector
        lam     = cp.Parameter(nonneg=True)         # radius of the uncertainty ball
        # (lam * width) is parameter-affine and z is parameter-free, so the
        # product stays DPP-compliant, which CvxpyLayer requires.
        objective   = cp.Minimize(ypred @ z + (lam * width) @ z)
        constraints = [A @ z == b]
        problem     = cp.Problem(objective, constraints)
        self.robust_solver = CvxpyLayer(problem, parameters=[ypred, lam], variables=[z])

        # -------------
        # Nominal solver: oracle decision under the true cost
        # -------------
        z = cp.Variable(nedge, nonneg=True)         # decision variable (edge flows)
        y = cp.Parameter(nedge)                     # cost vector
        objective   = cp.Minimize(y @ z)
        constraints = [A @ z == b]
        problem     = cp.Problem(objective, constraints)
        self.solver = CvxpyLayer(problem, parameters=[y], variables=[z])

    def objective(self, y, z):
        '''
        Per-sample linear cost ``y[i] @ z[i]``.

        Computed as a row-wise dot product, so no [ nbatch, nbatch ]
        intermediate is formed. Works for numpy arrays and torch tensors.
        Args:
        - y:    [ nbatch, ndim ] np (a single [ ndim ] vector also works)
        - z:    [ nbatch, ndim ] np
        Return:
        - obj:  [ nbatch ] np (a scalar for 1-D inputs)
        '''
        return (y * z).sum(-1)

    # NOTE: assume that the UQ set is l_infty ball
    # NOTE: efficient implementation of the regret computation
    def regret(self, x, y, lam):
        '''
        Regret of the robust decision, evaluated under the true cost.

        Computes ``y @ z_ro - y @ z_opt``, where ``z_ro`` solves the robust
        problem at the model's prediction and ``z_opt`` solves the nominal
        problem at the true cost ``y``.
        Args:
        - x:    [ nbatch, ... ] input passed to ``self.model.pred``
        - y:    [ nbatch, ndim ] np, true costs
        - lam:  [ nbatch ] np, per-sample robustness radius
        Return:
        - loss: [ nbatch ] np
        '''
        y = np.asarray(y)
        # -------------
        # Init
        # -------------
        y_batch     = torch.as_tensor(y, dtype=torch.float32)                  # [ nbatch, ndim ] torch
        ypred_batch = torch.as_tensor(self.model.pred(x), dtype=torch.float32) # [ nbatch, ndim ] torch
        lam_batch   = torch.as_tensor(lam, dtype=torch.float32)                # [ nbatch ] torch
        # -------------
        # Solve
        # -------------
        z_ro_opt = self.robust_solver(ypred_batch, lam_batch)[0]   # [ nbatch, ndim ] torch
        z_opt    = self.solver(y_batch)[0]                         # [ nbatch, ndim ] torch
        # detach/cpu so conversion works even if the layer output tracks grad
        z_ro_opt = z_ro_opt.detach().cpu().numpy()
        z_opt    = z_opt.detach().cpu().numpy()
        # NOTE: loss can come out slightly negative because of solver
        # tolerance; increase solver accuracy if that matters downstream.
        loss = self.objective(y, z_ro_opt) - self.objective(y, z_opt)  # [ nbatch ] np
        return loss
    

import numpy as np

def build_incidence_many_chains(n, *, return_ranges=False):
    """
    Build the flow-conservation system (A, b) for n disjoint s->t chains.

    Chain i (for i = 1..n) has i edges in series. It therefore passes
    through i-1 intermediate nodes of its own; the chains share only the
    source s and the sink t.

    Layout:
    - Node rows are ordered [s, intermediates of chain 1, ..., chain n, t].
    - Edge columns are grouped chain by chain, in order along each chain.
    - Each column has +1 at its tail node and -1 at its head node.
    - b routes one unit of flow from s (b = +1) to t (b = -1).

    Args:
      n:             number of chains (must be >= 1).
      return_ranges: if True, also return the per-chain edge column ranges.
    Returns:
      A: [N_nodes, N_edges] np, N_nodes = 2 + n*(n-1)/2, N_edges = n*(n+1)/2
      b: [N_nodes] np
      path_edge_ranges (only when return_ranges=True): list of n inclusive
         (first_col, last_col) tuples, one per chain.
    Raises:
      ValueError: if n < 1.

    Example:
      A, b = build_incidence_many_chains(5)
      A, b, ranges = build_incidence_many_chains(5, return_ranges=True)
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")

    n_edges = n * (n + 1) // 2
    n_nodes = 2 + n * (n - 1) // 2

    A = np.zeros((n_nodes, n_edges), dtype=float)
    b = np.zeros(n_nodes, dtype=float)

    row_s, row_t = 0, n_nodes - 1
    b[row_s] = 1.0    # one unit of supply at the source
    b[row_t] = -1.0   # one unit of demand at the sink

    row_ptr = 1       # next free row for an intermediate node
    col_ptr = 0       # first column of the current chain
    path_edge_ranges = []

    for i in range(1, n + 1):
        # this chain's edges occupy columns [c0, c_last] (inclusive)
        c0, c_last = col_ptr, col_ptr + i - 1
        path_edge_ranges.append((c0, c_last))

        A[row_s, c0] = 1.0        # first edge leaves the source
        A[row_t, c_last] = -1.0   # last edge enters the sink

        # intermediate node k sits between edge c0+k (in) and c0+k+1 (out)
        for k in range(i - 1):
            A[row_ptr, c0 + k] = -1.0
            A[row_ptr, c0 + k + 1] = 1.0
            row_ptr += 1

        col_ptr += i

    if return_ranges:
        return A, b, path_edge_ranges
    return A, b