from util import *

class OnlineBipartiteInstance(ABC):
    """Base class for an online bipartite matching instance with noisy advice.

    The instance has ``n`` offline vertices, each with a weight in
    ``L_weights``, and ``n`` online vertices. ``arrivals[t]`` lists the offline
    neighbors of the online vertex arriving at step ``t`` and
    ``noisy_future[t]`` is a perturbed version of that list which is used as
    the prediction when computing advice.

    Subclasses implement ``_generate_offline_graph`` to fill ``self.arrivals``.
    """

    def __init__(
            self,
            name: str,
            n: int,
            noise_param: float = 0.0,
            L_weights: np.ndarray = None,
            rng_seed: int = None
        ) -> None:
        """Build the true arrival lists and their noisy predictions.

        Args:
            name: Identifier stored on the instance.
            n: Number of offline vertices and of online arrivals.
            noise_param: Fraction in [0, 1] of each neighbor list that is
                replaced when generating the noisy prediction.
            L_weights: Offline vertex weights (length ``n``). When omitted,
                integer weights ``round(1000 * U[0, 1))`` are drawn.
            rng_seed: Seed for the random generator; 42 is used when omitted.

        Raises:
            AssertionError: If ``noise_param`` is outside [0, 1], if
                ``L_weights`` does not have length ``n``, or if the generated
                ``arrivals`` / ``noisy_future`` do not have length ``n``.
        """
        assert 0 <= noise_param <= 1
        self.folder = "Placeholder"
        self.name = name
        self.n = n
        self.noise_param = noise_param
        self.rng = np.random.default_rng(42 if rng_seed is None else rng_seed)
        if L_weights is not None:
            assert len(L_weights) == self.n
            self.L_weights = L_weights
        else:
            # One scalar draw per weight keeps the random stream identical to
            # previously generated instances.
            self.L_weights = np.array([round(1000 * self.rng.random()) for _ in range(self.n)])

        self.arrivals = []
        self.noisy_future = []
        self._generate_offline_graph()
        assert len(self.arrivals) == self.n
        self._generate_noisy_future()
        assert len(self.noisy_future) == self.n

    @abstractmethod
    def _generate_offline_graph(self) -> None:
        """Populate ``self.arrivals`` with one neighbor list per online vertex."""
        raise NotImplementedError

    def _generate_noisy_future(self) -> None:
        """Fill ``self.noisy_future`` with a perturbed neighbor list per online vertex.

        For every online vertex, about a ``1 - noise_param`` fraction of its
        neighbors is kept and the dropped ones are replaced by up to the same
        number of randomly chosen non-neighbors.
        """
        self.noisy_future = []
        all_offline = set(range(self.n))
        for v_idx in range(self.n):
            # Shuffle a copy: shuffling self.arrivals[v_idx] directly would
            # permute the true arrival list as a side effect.
            neighbors = list(self.arrivals[v_idx])
            non_neighbors = list(all_offline - set(neighbors))
            self.rng.shuffle(neighbors)
            self.rng.shuffle(non_neighbors)
            n_keep = int(round(len(neighbors) * (1 - self.noise_param)))
            n_flip = len(neighbors) - n_keep
            # Slicing silently yields fewer than n_flip replacements when
            # there are not enough non-neighbors.
            v_noisy_neighbors = sorted(neighbors[:n_keep] + non_neighbors[:n_flip])
            self.noisy_future.append(v_noisy_neighbors)

    def get_n(self) -> int:
        """Return the number of offline vertices (and online arrivals)."""
        return self.n

    def get_L_weights(self) -> np.ndarray:
        """Return the offline vertex weights."""
        return self.L_weights

    def get_next_arrival(self, t: int) -> list:
        """Return the offline neighbors of the online vertex arriving at step ``t``.

        Raises:
            AssertionError: If ``t`` is not in ``[0, n)``.
        """
        assert 0 <= t < self.n
        return self.arrivals[t]

    def get_next_arrival_with_advice(self, t: int, X_so_far: np.ndarray) -> dict:
        """Return ``{u_idx: advice}`` for every neighbor of the vertex arriving at ``t``.

        ``X_so_far[u_idx]`` is the amount already matched to offline vertex
        ``u_idx``; the advice values come from ``_solve_for_advice``.
        """
        arrival = self.get_next_arrival(t)
        advice = self._solve_for_advice(t, X_so_far)
        return {u_idx: advice[u_idx] for u_idx in arrival}

    def _solve_for_advice(self, t: int, X_so_far: np.ndarray) -> list:
        """Solve the advice LP for step ``t`` and return one value per offline vertex.

        The LP routes fractional flow from the current online vertex (along
        its true edges) and from the predicted future vertices (along their
        noisy edges) to the offline vertices, maximizing the weighted flow
        subject to the remaining capacity ``1 - X_so_far[u_idx]`` of each
        offline vertex and a unit budget per online vertex.

        Returns:
            A list of length ``n`` holding the flow the LP solution sends from
            the current online vertex to each offline vertex.

        Raises:
            AssertionError: If the solver does not report an optimal solution.
        """
        # v_t is arriving and each future predicts neighbors of v_{t+1}, ... , v_n
        future = self.noisy_future[t+1:]

        # === Define LP ===
        prob = pulp.LpProblem("", pulp.LpMaximize)

        # === Define LP variables ===
        prediction_edges_vars = pulp.LpVariable.dicts(
            "prediction",
            indices=np.arange(self.n),
            lowBound=0.0,
            upBound=1.0,
            cat=pulp.LpContinuous
        )
        future_edges_vars = dict()
        for future_v_idx, predicted_neighbors in enumerate(future):
            for u_idx in predicted_neighbors:
                # NOTE(review): the name lacks its closing "]". It is kept
                # byte-identical because variable names may influence the
                # ordering seen by the solver - confirm before renaming.
                future_edges_vars[(u_idx, future_v_idx)] = pulp.LpVariable(
                    f"future[{u_idx},{future_v_idx}",
                    lowBound=0.0,
                    upBound=1.0,
                    cat=pulp.LpContinuous
                )

        # Group the future variables by offline vertex once, instead of
        # scanning every future neighbor list for every offline vertex. The
        # dict preserves insertion order, so each group is ordered by
        # increasing future vertex index.
        future_vars_into = dict()
        for (u_idx, _), var in future_edges_vars.items():
            future_vars_into.setdefault(u_idx, []).append(var)

        def incoming_flow(u_idx):
            # Total flow entering offline vertex u_idx (current + future).
            return (
                prediction_edges_vars[u_idx] +
                pulp.lpSum(future_vars_into.get(u_idx, []))
            )

        # === Define LP constraints ===
        # Total predicted flow from current online vertex v_t <= 1
        prob += pulp.lpSum(prediction_edges_vars[u_idx] for u_idx in range(self.n)) <= 1

        # Set variables to 0 if edge does not exist
        current_neighbors = set(self.arrivals[t])
        for u_idx in range(self.n):
            if u_idx not in current_neighbors:
                prob += prediction_edges_vars[u_idx] == 0

        # Total flow to any offline vertex is <= 1 - matched_so_far
        for u_idx in range(self.n):
            prob += incoming_flow(u_idx) <= 1 - X_so_far[u_idx]

        # For each future online vertex, total predicted flow <= 1
        for future_v_idx, predicted_neighbors in enumerate(future):
            prob += pulp.lpSum(
                future_edges_vars[(u_idx, future_v_idx)]
                for u_idx in predicted_neighbors
            ) <= 1

        # === Define LP objective ===
        prob += pulp.lpSum(
            self.L_weights[u_idx] * incoming_flow(u_idx)
            for u_idx in range(self.n)
        )

        # === Solve LP and extract solution ===
        prob.solve(pulp.PULP_CBC_CMD(msg=0))
        assert prob.status == pulp.LpStatusOptimal
        advice = [prediction_edges_vars[u_idx].varValue for u_idx in range(self.n)]

        return advice

class ErdosRenyiInstance(OnlineBipartiteInstance):
    """Random instance in which each (offline, online) pair is an edge when a uniform draw falls below ``p``."""

    def __init__(
            self,
            n: int,
            p: float,
            noise_param: float = 0.0,
            L_weights: np.ndarray = None,
            rng_seed: int = None
        ) -> None:
        """Create an ``n`` x ``n`` random instance with edge threshold ``p``.

        The instance name encodes ``n``, ``p``, ``rng_seed`` and
        ``noise_param``; the remaining arguments are forwarded to the base
        constructor unchanged.
        """
        # p has to be stored before the base constructor runs, because that
        # constructor calls _generate_offline_graph, which reads self.p.
        self.p = p
        instance_name = f"ER({n},{p})_{rng_seed}_{noise_param}"
        super().__init__(instance_name, n, noise_param, L_weights, rng_seed)
        self.folder = "ER"

    def _generate_offline_graph(self) -> None:
        """Draw an ``n`` x ``n`` boolean matrix and store its columns as neighbor lists."""
        # Rows index offline vertices, columns index online vertices.
        edge_mask = self.rng.random((self.n, self.n)) < self.p
        neighbor_lists = []
        for online_idx in range(self.n):
            offline_hits = np.flatnonzero(edge_mask[:, online_idx])
            # tolist() yields plain Python ints in increasing order.
            neighbor_lists.append(offline_hits.tolist())
        self.arrivals = neighbor_lists

class UpperTriangular(OnlineBipartiteInstance):
    """Deterministic instance where online vertex ``v`` neighbors offline vertices ``v, ..., n-1``."""

    def __init__(
            self,
            n: int,
            noise_param: float = 0.0,
            L_weights: np.ndarray = None,
            rng_seed: int = None
        ) -> None:
        """Create the upper-triangular instance on ``n`` vertices per side.

        The instance name encodes ``n``, ``rng_seed`` and ``noise_param``; the
        remaining arguments are forwarded to the base constructor unchanged.
        """
        instance_name = f"UT({n})_{rng_seed}_{noise_param}"
        super().__init__(instance_name, n, noise_param, L_weights, rng_seed)
        self.folder = "UT"

    def _generate_offline_graph(self) -> None:
        """Set ``arrivals[v]`` to ``[v, v + 1, ..., n - 1]`` for every online vertex ``v``."""
        neighbor_lists = []
        for first_neighbor in range(self.n):
            neighbor_lists.append(list(range(first_neighbor, self.n)))
        self.arrivals = neighbor_lists

class RealWorldInstance(OnlineBipartiteInstance):
    """Instance whose arrival lists are read from a given bipartite networkx graph."""

    def __init__(
            self,
            name: str,
            G: nx.Graph,
            noise_param: float = 0.0,
            L_weights: np.ndarray = None,
            rng_seed: int = None
        ) -> None:
        """Wrap a deep copy of ``G`` as an online instance.

        ``n`` is taken to be half the number of nodes of ``G``. The instance
        name is ``name`` followed by ``rng_seed`` and ``noise_param``, and
        ``name`` itself is stored as the folder.

        Raises:
            AssertionError: If ``G`` is not bipartite or has an odd number of
                nodes.
        """
        # Work on a private copy so the caller's graph object is not shared.
        self.G = copy.deepcopy(G)
        assert nx.is_bipartite(self.G)
        node_count = self.G.number_of_nodes()
        assert node_count % 2 == 0
        instance_name = f"{name}_{rng_seed}_{noise_param}"
        super().__init__(instance_name, node_count // 2, noise_param, L_weights, rng_seed)
        self.folder = name

    def _generate_offline_graph(self) -> None:
        """Collect one neighbor list for every node whose ``'bipartite'`` attribute equals 1.

        Each neighbor label is converted to an offline index by dropping its
        first character and parsing the remainder as an int (presumably the
        labels are a one-character prefix followed by the index - verify
        against the graph loader).
        """
        neighbor_lists = []
        for node in self.G.nodes():
            if self.G.nodes[node].get('bipartite') != 1:
                continue
            offline_indices = [int(label[1:]) for label in self.G.neighbors(node)]
            neighbor_lists.append(offline_indices)
        self.arrivals = neighbor_lists
