from .base_gp import BaseGP
import numpy as np
from joblib import Parallel, delayed
import pulp

class Ilp(BaseGP):
    """Cluster-ensemble model that recombines harvested clusters with an ILP.

    The inherited ``_single_stochastic_run`` is executed once per seed; every
    non-empty cluster those runs produce becomes a candidate. A
    set-partitioning integer linear program then picks exactly ``s``
    candidates that cover every point exactly once while minimising the total
    weighted within-cluster squared error.
    """

    def __init__(self, max_cluster_size, weights=None, n_harvest_runs=1000, top_k=5):
        """Configure the ensemble.

        Parameters
        ----------
        max_cluster_size
            Forwarded to ``BaseGP.__init__``; the stochastic runs later read
            it back as ``self.max_cluster_size``.
        weights
            Optional per-point weights (``None`` means unit weights); forwarded
            to ``BaseGP.__init__`` and later read as ``self.weights``.
        n_harvest_runs
            Number of stochastic runs (seeds ``0 .. n_harvest_runs - 1``) used
            to harvest candidate clusters.
        top_k
            Passed unchanged to every stochastic run.
        """
        # ``super(BaseGP, self)`` began the MRO lookup *after* BaseGP and so
        # skipped BaseGP.__init__. Presumably BaseGP.__init__ is what stores
        # ``self.max_cluster_size`` / ``self.weights`` read below — confirm.
        super().__init__(max_cluster_size, weights)
        self.n_harvest_runs = n_harvest_runs
        self.top_k = top_k

    def _point_weights(self, n_points):
        """Return the per-point weights as a flat array (ones if unweighted)."""
        if self.weights is None:
            return np.ones(n_points)
        # np.asarray accepts plain lists as well as arrays; ``.flatten()`` on
        # the raw attribute would fail for a list.
        return np.asarray(self.weights).flatten()

    @staticmethod
    def _weighted_sse(pts, w):
        """Return ``(centroid, sse)`` for a single cluster.

        ``centroid`` is the ``w``-weighted mean of the rows of ``pts``; ``sse``
        is the ``w``-weighted sum of squared Euclidean distances of those rows
        to that centroid.
        """
        centroid = np.average(pts, axis=0, weights=w)
        sse = np.sum(w * np.sum((pts - centroid) ** 2, axis=1))
        return centroid, sse

    def _harvest_one(self, X, s, seed):
        """Run one stochastic clustering and return its clusters.

        Returns ``(clusters, inertia)``, where ``clusters`` is a list of
        ascending tuples of point indices (hashable, so they can be
        deduplicated). Returns ``None`` if the run raised or produced no
        result.
        """
        try:
            res = self._single_stochastic_run(
                X, s, self.weights, self.max_cluster_size, self.top_k, seed
            )
            if res is None:
                return None
            # Only labels and inertia are used; the first element is ignored.
            _, labels, inertia = res
            labels = np.asarray(labels)
        except Exception:
            # Best-effort harvest: one failing run must not abort the whole
            # ensemble. (This was a bare ``except:``, which also swallowed
            # KeyboardInterrupt / SystemExit.)
            return None

        clusters = []
        for k in range(s):
            # flatnonzero yields ascending indices; .tolist() gives plain ints.
            members = np.flatnonzero(labels == k).tolist()
            if members:
                clusters.append(tuple(members))
        return clusters, inertia

    @staticmethod
    def _solve_partition(candidates, costs, n_points, s):
        """Solve the set-partitioning ILP over ``candidates``.

        Chooses exactly ``s`` candidates such that every point in
        ``range(n_points)`` lies in exactly one chosen candidate, minimising
        the sum of their ``costs``. Returns the list of chosen candidate
        indices, or ``None`` when no optimal partition is found.
        """
        # point index -> indices of the candidates that contain it
        point_map = [[] for _ in range(n_points)]
        for j, members in enumerate(candidates):
            for i in members:
                point_map[i].append(j)
        if not all(point_map):
            # A point that no candidate contains makes its "== 1" row
            # unsatisfiable; report that instead of sending an empty row to CBC.
            print("   -> Some points are not covered by any candidate.")
            return None

        prob = pulp.LpProblem("Cluster_Ensemble", pulp.LpMinimize)
        # z[j] == 1 iff candidate j is part of the final partition.
        z = pulp.LpVariable.dicts("z", range(len(candidates)), cat=pulp.LpBinary)

        # Objective: minimise total weighted squared error of chosen clusters.
        prob += pulp.lpSum(cost * z[j] for j, cost in enumerate(costs))
        # Constraint 1: exactly s clusters.
        prob += pulp.lpSum(z[j] for j in range(len(candidates))) == s
        # Constraint 2: every point covered exactly once.
        for containing in point_map:
            prob += pulp.lpSum(z[j] for j in containing) == 1

        prob.solve(pulp.PULP_CBC_CMD(msg=0))
        status = pulp.LpStatus[prob.status]
        print(f"   -> Solver Status: {status}")
        if status != 'Optimal':
            return None
        # Binary variables come back as floats (or None if unset); threshold.
        return [j for j in range(len(candidates)) if (pulp.value(z[j]) or 0) > 0.5]

    def clustering(self, X, s):
        """
        Harvests clusters from multiple stochastic runs and finds the optimal
        combination using ILP (Set Partitioning).

        Parameters
        ----------
        X
            Data matrix whose rows are points. It is indexed with integer
            lists and boolean masks, and ``X.shape[1]`` is the feature count.
        s
            Number of clusters to return.

        Returns
        -------
        tuple or None
            ``(centroids, labels, inertia)``: an ``(s, X.shape[1])`` centroid
            array, an integer label in ``0 .. s-1`` per point, and the total
            weighted within-cluster squared error. Returns ``None`` if no
            harvest run succeeded or no exact partition into ``s`` harvested
            candidates exists.
        """
        n_points = len(X)

        print(f"1. Harvesting clusters from {self.n_harvest_runs} runs...")
        # Fixed seeds 0..n_harvest_runs-1, so repeated calls use the same seeds.
        seeds = np.arange(self.n_harvest_runs)
        results = Parallel(n_jobs=-1)(
            delayed(self._harvest_one)(X, s, seed) for seed in seeds
        )

        # Failed runs come back as None. Previously they were ``[]`` and the
        # ``run[0]`` / ``run[1]`` indexing below crashed with IndexError.
        runs = [run for run in results if run is not None]
        if not runs:
            print("   -> No successful harvest runs; nothing to combine.")
            return None

        all_clusters = [c for clusters, _ in runs for c in clusters]
        print('Min inertia across runs', min(inertia for _, inertia in runs))
        # dict.fromkeys deduplicates while keeping first-seen (seed) order,
        # rather than the arbitrary iteration order of a set.
        unique_clusters = list(dict.fromkeys(all_clusters))

        print(f"   -> Found {len(all_clusters)} total clusters.")
        print(f"   -> Reduced to {len(unique_clusters)} UNIQUE candidates.")

        # --- Calculate Costs (Exact Squared Error) ---
        print("2. Calculating candidate costs...")
        W = self._point_weights(n_points)
        candidate_costs = []
        for members in unique_clusters:
            idx = list(members)
            candidate_costs.append(self._weighted_sse(X[idx], W[idx])[1])

        # --- Setup & solve ILP ---
        print("3. Solving Set Partitioning...")
        chosen = self._solve_partition(unique_clusters, candidate_costs, n_points, s)
        if chosen is None:
            print("   -> Could not find valid partition (unlikely if runs were valid).")
            return None

        # --- Reconstruct Result ---
        final_labels = np.zeros(n_points, dtype=int)
        for label, j in enumerate(chosen):
            final_labels[list(unique_clusters[j])] = label

        final_centroids_list = []
        final_inertia = 0.0
        for k in range(s):
            mask = final_labels == k
            if np.any(mask):
                centroid, sse = self._weighted_sse(X[mask], W[mask])
                final_centroids_list.append(centroid)
                final_inertia += sse
            else:
                # Defensive only: harvested candidates are never empty and the
                # partition constraints keep chosen ones disjoint.
                final_centroids_list.append(np.zeros(X.shape[1]))

        return np.array(final_centroids_list), final_labels, final_inertia