from util import *

class FractionalMatchingAlgorithm(ABC):
    """Base class for online fractional bipartite matching with per-edge advice.

    An instance has ``n`` offline vertices (weights ``w``) and ``n`` online
    arrivals. Each arrival is a dict mapping the index of every offline
    neighbor to the advice placed on that edge. Subclasses implement
    ``match_subroutine`` (how the arrival's unit of flow is split among its
    neighbors) and ``_g`` (the penalty used by balance-style matching).
    """

    def __init__(self, eps: float) -> None:
        # Tolerance used by the numerical sanity checks below.
        self.eps = eps
        self.name = "FractionalMatchingAlgo"

    def init_for_new_instance(self, n: int, L_weights: np.ndarray) -> None:
        """Reset all per-instance state.

        Args:
            n: number of offline vertices; the offline optimum and the
                obtained value can only be computed after ``n`` arrivals.
            L_weights: weight of each offline vertex, indexed by offline index.
        """
        self.n = n
        self.w = L_weights
        # X[u]: fraction of offline vertex u matched so far (kept in [0, 1]).
        self.X = np.zeros(self.n)
        # A[u]: advice accumulated on offline vertex u, clipped to [0, 1].
        self.A = np.zeros(self.n)
        # v_neighbors[v]: offline neighbors of the v-th arrival.
        self.v_neighbors = dict()
        # matches[v, u]: flow sent from arrival v to offline vertex u.
        self.matches = np.zeros((self.n, self.n))

    def get_X_so_far(self) -> np.ndarray:
        """Return the current matched fraction of every offline vertex."""
        return self.X

    def compute_offline_optimal(self) -> float:
        """Return the weight of a maximum-weight integral matching of the instance.

        Offline vertex ``u`` becomes graph node ``u`` and arrival ``v`` becomes
        node ``n + v``; edge (u, v) carries the weight of ``u``. Requires all
        ``n`` arrivals to have been processed.
        """
        assert len(self.v_neighbors) == self.n
        weighted_edges = {
            (u_idx, self.n + v_idx): self.w[u_idx]
            for v_idx, offline_nbrs in self.v_neighbors.items()
            for u_idx in offline_nbrs
        }
        G = nx.Graph()
        G.add_weighted_edges_from([(u, v, weight) for (u, v), weight in weighted_edges.items()])
        optimal_matching = nx.max_weight_matching(G)

        # networkx may report either endpoint first; offline nodes (< n) always
        # carry the smaller label, so (min, max) recovers the dict key.
        return sum(weighted_edges[min(a, b), max(a, b)] for a, b in optimal_matching)

    def compute_obtained_online_matching(self) -> float:
        """Return the weighted value ``w @ X`` of the online fractional matching."""
        assert len(self.v_neighbors) == self.n
        return self.w @ self.X

    def _match_preprocessing(self, v_arrival: dict) -> int:
        """Accumulate the arrival's advice into ``A``, record its neighbors, return its index.

        Raises:
            AssertionError: if the arrival's total advice lies outside
                ``[-eps, 1 + eps]``.
        """
        total_advice = 0
        for u_idx, u_advice in v_arrival.items():
            self.A[u_idx] += u_advice
            total_advice += u_advice
        # Every key occurs once per arrival, so clipping once after the loop is
        # equivalent to clipping after each addition.
        np.clip(self.A, 0.0, 1.0, out=self.A)
        assert -self.eps <= total_advice <= 1 + self.eps

        v_idx = len(self.v_neighbors)
        self.v_neighbors[v_idx] = list(v_arrival.keys())
        return v_idx

    def _match_postprocessing(self, v_idx: int, flow: np.ndarray) -> None:
        """Store the arrival's flow and sanity-check the matching state.

        ``X`` is clipped to [0, 1] to absorb numerical drift. It must then agree
        with the column sums of ``matches`` capped at 1.
        """
        # Row assignment copies the values, so no explicit copy is needed.
        self.matches[v_idx] = flow

        np.clip(self.X, 0.0, 1.0, out=self.X)
        assert len(flow) == self.n
        assert sum(flow) <= 1 + self.eps
        assert np.all(self.X <= 1)
        # Vectorized form of the per-vertex check X[u] ~= min(1, sum_v matches[v, u]).
        assert np.allclose(self.X, np.minimum(1, self.matches.sum(axis=0)))

    def match(self, v_arrival: dict) -> np.ndarray:
        """Process one online arrival and return the flow (length ``n``) it sends.

        Args:
            v_arrival: maps each offline neighbor index to the advice on that edge.
        """
        v_idx = self._match_preprocessing(v_arrival)
        if len(v_arrival) == 0:
            flow = np.zeros(self.n)
        else:
            flow = self.match_subroutine(v_arrival)
        self._match_postprocessing(v_idx, flow)
        return flow

    @abstractmethod
    def match_subroutine(self, v_arrival: dict) -> np.ndarray:
        """Split the arrival's flow among its neighbors and update ``X``.

        Implementations choose between balance-style and waterfilling-style
        matching. They must return the flow vector of length ``n``.
        """
        raise NotImplementedError

    @abstractmethod
    def _g(self, x: float, a: float = None) -> float:
        """Penalty function for balance-style matching at fill ``x`` with advice ``a``."""
        raise NotImplementedError

    def _balance_u_flow(self, u_idx: int, level: float) -> float:
        """Largest flow to ``u_idx`` keeping ``w_u * (1 - g(X_u + flow, A_u)) >= level``.

        Returns 0 when u's current marginal gain is already at most ``level``.
        """
        if self.w[u_idx] * (1 - self._g(self.X[u_idx], self.A[u_idx])) <= level:
            return 0
        return binary_search_largest_true(
            lambda flow: self.w[u_idx] * (1 - self._g(self.X[u_idx] + flow, self.A[u_idx])) >= level,
            low=0.0,
            high=1.0 - self.X[u_idx],
        )

    def _compute_u_flows_for_balance(self, neighbors: list) -> np.ndarray:
        """Compute flows to offline vertices in balance-style matching.

        Binary-searches (``binary_search_smallest_true`` over ``[0, max(w)]``)
        for a level at which the per-neighbor flows total at most 1. It then
        normalizes the flows to sum to 1. Does not modify ``X``.
        """
        neighbor_set = set(neighbors)
        level = binary_search_smallest_true(
            lambda l: sum(self._balance_u_flow(u_idx, l) for u_idx in neighbors) <= 1,
            low=0.0,
            high=max(self.w),
        )
        u_flows = np.array([
            self._balance_u_flow(u_idx, level) if u_idx in neighbor_set else 0
            for u_idx in range(self.n)
        ])
        assert sum(u_flows) <= 1 + self.eps
        if sum(u_flows) > 0:
            u_flows = u_flows / sum(u_flows)
        return u_flows

    def _balance(self, v_arrival: dict) -> np.ndarray:
        """Balance-style matching subroutine; updates ``X`` and returns the flow.

        If the neighbors' total remaining capacity is at most 1, every neighbor
        is saturated. Otherwise one unit is split by
        ``_compute_u_flows_for_balance``.
        """
        if sum(1 - self.X[u_idx] for u_idx in v_arrival) <= 1:
            flow = [1 - self.X[u_idx] if u_idx in v_arrival else 0 for u_idx in range(self.n)]
        else:
            flow = self._compute_u_flows_for_balance(list(v_arrival.keys()))

        flow = np.array(flow, dtype=float)
        self.X += flow
        return flow

    def _waterfill(self, neighbors: list, tau: float = 0.0) -> np.ndarray:
        """Waterfill-style matching subroutine sending ``1 - tau`` units to ``neighbors``.

        Finds the largest level l in [0, 1] such that raising every neighbor u
        to max(X[u], l) costs at most ``1 - tau``. It then rescales so exactly
        ``1 - tau`` is sent. If the neighbors cannot absorb ``1 - tau`` in
        total, every neighbor is saturated instead. Updates ``X`` in place.

        Args:
            neighbors: offline neighbor indices (any re-iterable, e.g. ``dict.keys()``).
            tau: flow already sent elsewhere for this arrival by the caller.
        """
        neighbors = list(neighbors)
        neighbor_set = set(neighbors)
        budget = 1 - tau

        if sum(1 - self.X[u_idx] for u_idx in neighbors) <= budget:
            # Rescaling up to `budget` here would push neighbors past full
            # capacity (and over-report the flow), so just saturate them.
            u_flows = np.array([
                1 - self.X[u_idx] if u_idx in neighbor_set else 0.0
                for u_idx in range(self.n)
            ])
        else:
            level = binary_search_largest_true(
                lambda l: sum(max(0, l - self.X[u_idx]) for u_idx in neighbors) <= budget,
                low=0.0,
                high=1.0,
            )
            u_flows = np.array([
                max(0, level - self.X[u_idx]) if u_idx in neighbor_set else 0
                for u_idx in range(self.n)
            ])
            assert sum(u_flows) <= budget + self.eps
            if sum(u_flows) > 0:
                # Remove the binary search's slack so exactly `budget` is sent.
                u_flows = u_flows / sum(u_flows) * budget

        self.X += u_flows
        return u_flows

class Greedy(FractionalMatchingAlgorithm):
    """Integral greedy baseline that ignores advice.

    Each arrival is matched fully to its heaviest still-unmatched neighbor.
    """

    def __init__(self, eps: float = 1e-5) -> None:
        super().__init__(eps)
        self.name = "Greedy"

    def _g(self, x: float, a: float = None) -> float:
        # Greedy does not use balance-style matching; defined only to satisfy the ABC.
        pass

    def match_subroutine(self, v_arrival: dict) -> np.ndarray:
        """Put the full unit of flow on the free neighbor with the highest weight.

        A neighbor is free when its matched amount ``X[u]`` is exactly 0. Ties
        in weight go to the largest offline index, because of how the
        upper-triangular instances are generated. If no neighbor is free, no
        flow is sent. The result is always integral.
        """
        free_neighbors = [
            (self.w[u_idx], u_idx)
            for u_idx in v_arrival.keys()
            if self.X[u_idx] == 0
        ]
        flow = np.zeros(self.n)
        if free_neighbors:
            # Lexicographic max of (weight, index): highest weight first, then
            # the largest index among equal weights. (The previous code sorted
            # ascending and read element [0], i.e. the LOWEST weight.)
            _, choice = max(free_neighbors)
            flow[choice] = 1.0
        self.X += flow
        return flow

class Balance(FractionalMatchingAlgorithm):
    """Balance-style matching with the exponential penalty g(x) = exp(x - 1).

    The advice argument of the penalty is ignored.
    """

    def __init__(self, eps: float = 1e-5) -> None:
        super().__init__(eps)
        self.name = "Balance"

    def _g(self, x: float, a: float = None) -> float:
        # Penalty depends only on the fill level x; `a` is unused.
        exponent = x - 1
        return np.exp(exponent)

    def match_subroutine(self, v_arrival: dict) -> np.ndarray:
        """Delegate to the shared balance-style subroutine."""
        flow = self._balance(v_arrival)
        return flow

class Waterfill(FractionalMatchingAlgorithm):
    """Waterfilling without advice: the arrival's unit of flow raises its neighbors evenly."""

    def __init__(self, eps: float = 1e-5) -> None:
        super().__init__(eps)
        self.name = "Waterfill"

    def _g(self, x: float, a: float = None) -> float:
        # Unused: waterfilling needs no penalty function; present only for the ABC.
        return None

    def match_subroutine(self, v_arrival: dict) -> np.ndarray:
        """Waterfill one full unit of flow over all of the arrival's neighbors."""
        neighbors = v_arrival.keys()
        return self._waterfill(neighbors)

class PushAndWaterfill(FractionalMatchingAlgorithm):
    """Push-then-waterfill: follow the advice up to ``lambda_val``, then waterfill."""

    def __init__(self, lambda_val: float, eps: float = 1e-5) -> None:
        self.lambda_val = lambda_val
        super().__init__(eps)
        self.name = fr"PAW ($\lambda$ = {self.lambda_val})"

    def _g(self, x: float, a: float = None) -> float:
        # Unused: this algorithm does no balance-style matching; present only for the ABC.
        return None

    def match_subroutine(self, v_arrival: dict) -> np.ndarray:
        """Push toward the most-advised neighbor, then waterfill what is left.

        Phase 1: take the neighbor with the largest advice (the first such key
        on ties). If its advice is positive, send it
        ``tau = max(0, lambda_val - X[u])``, raising it to at least
        ``lambda_val``. Only this single neighbor receives pushed flow.
        Phase 2: waterfill the remaining ``1 - tau`` units over all neighbors.
        """
        push_flow = np.zeros(self.n)
        tau = 0
        advised_u = max(v_arrival, key=v_arrival.get)

        if v_arrival[advised_u] > 0:
            tau = max(0, self.lambda_val - self.X[advised_u])
            push_flow[advised_u] += tau
            # Apply phase 1 immediately so phase 2 sees the updated fill levels.
            self.X += push_flow

        fill_flow = self._waterfill(v_arrival.keys(), tau)
        return push_flow + fill_flow

class LABalance(FractionalMatchingAlgorithm):
    """Advice-aware balance-style matching.

    The penalty ``_g`` depends both on the fill level and on the advice
    accumulated on the offline vertex, with trade-off parameter ``lambda_val``.
    """

    def __init__(self, lambda_val: float, eps: float = 1e-5) -> None:
        self.lambda_val = lambda_val
        super().__init__(eps)
        self.name = fr"LAB ($\lambda$ = {self.lambda_val})"

    def _g(self, x: float, a: float = None) -> float:
        """Penalty g(x, a) = f1(x) if a > x, else max(f0(x - a), f1(x)).

        Args:
            x: (tentative) fill level of the offline vertex.
            a: advice accumulated on that vertex; must be a number.

        Raises:
            ValueError: if f1 is evaluated at a fill level below 0 or at NaN.
        """
        lam = self.lambda_val

        def t(l: float) -> float:
            return l * np.exp(1 - l)

        def f0(z: float) -> float:
            return min(np.exp(z + lam - 1), 1)

        def f1(z: float) -> float:
            if lam == 0:
                return np.exp(z - 1)
            if z >= 1:
                # A vertex at (or, through rounding, just above) full capacity
                # has penalty 1. This is continuous with the Lambert-W branch
                # as z -> 1. Previously z > 1 fell through and returned None.
                return 1
            if 0 <= z < t(lam):
                return (np.exp(lam - 1) - lam) / (1 - z)
            if t(lam) <= z:
                denom = np.real(lambertw(-lam * np.exp(1 - lam - z)))
                assert denom != 0
                return -lam / denom
            # z < 0 or NaN: no branch applies.
            raise ValueError(f"f1 is undefined for fill level z={z}")

        if a > x:
            return f1(x)
        return max(f0(x - a), f1(x))

    def match_subroutine(self, v_arrival: dict) -> np.ndarray:
        """Delegate to the shared balance-style subroutine using this penalty."""
        return self._balance(v_arrival)
