import numpy as np
import random
from itertools import combinations

class SampleEfficiencyMatrix:
    """Directed graph stored as a dense ``n x n`` weight matrix.

    ``data[i][j]`` is the weight of the directed edge ``i -> j``. The solvers
    treat every ordered pair of nodes as connected; an edge that was never
    set simply keeps the weight it has in the matrix (0 after ``__init__``).
    """

    def __init__(self, n):
        """Create a graph with ``n`` nodes and all edge weights set to 0."""
        self.n = n
        self.data = np.zeros((n, n), dtype=np.float32)  # Adjacency matrix (directed edge weights)

    def set_edge(self, i, j, weight):
        """Set weight for directed edge from node i to j"""
        self.data[i][j] = weight

    def set_data(self, data):
        """Replace the weight matrix with ``data`` and update the node count.

        The array is stored by reference (no copy is made); its first
        dimension becomes ``self.n``.
        """
        self.n = data.shape[0]
        self.data = data

    def load(self, fname: str):
        """Load adjacency matrix from file (via ``np.load``) and update ``n``."""
        self.data = np.load(fname)
        # Keep the node count in sync with the loaded matrix; the solvers
        # iterate over range(self.n) and would otherwise index out of bounds
        # (or silently ignore nodes) when the file's size differs.
        self.n = self.data.shape[0]

    def max_hamilton_cycle(self):
        """
        Main function: search for a maximum-weight Hamiltonian cycle.

        Currently this always delegates to the randomized greedy heuristic
        (``_heuristic_solution``). The exact solver ``_dp_solution`` exists
        but is not dispatched automatically.

        Returns: {'path': node list, 'total_weight': total weight, 'method': algorithm name},
        or None when no result could be produced (e.g. empty graph).
        """
        return self._heuristic_solution()

    def _dp_solution(self):
        """Dynamic programming solution for directed complete graphs (any weights)

        Bitmask DP over subsets of nodes with node 0 fixed as the start of
        the cycle: ``dp[mask][v]`` is the best weight of a path that starts
        at node 0, visits exactly the nodes in ``mask`` and ends at ``v``.
        The table has ``2**n * n`` entries and the transitions cost
        ``O(2**n * n**2)``, so this is only practical for small ``n``.

        Returns:
            None for an empty graph (or when no finite-weight cycle exists),
            otherwise a dict with
            'path': the nodes in visiting order starting at node 0; the start
                node is NOT repeated at the end,
            'total_weight': weight of the full cycle, INCLUDING the closing
                edge from the last node back to node 0,
            'method': algorithm name.
        """
        n = self.n
        if n <= 0:
            return None

        # Work in float64 so fractional weights are not truncated.
        weights = np.asarray(self.data, dtype=np.float64)
        m = 1 << n
        # -inf marks "state not reachable"; this also makes negative weights
        # work, because any real path beats the initial value.
        dp = np.full((m, n), -np.inf)
        parent = np.full((m, n), -1, dtype=int)

        # Base case: the path consisting of the start node only.
        dp[1][0] = 0.0

        # State transitions (allow all edges, including negative weights)
        for mask in range(1, m):
            if not mask & 1:
                continue  # every partial path must contain the start node 0
            for v in range(1, n):
                if not (mask >> v) & 1:
                    continue
                prev_mask = mask ^ (1 << v)
                for u in range(n):
                    if not (prev_mask >> u) & 1:
                        continue
                    prev_weight = dp[prev_mask][u]
                    if prev_weight == -np.inf:
                        continue  # no path from 0 covers prev_mask and ends at u
                    new_weight = prev_weight + weights[u][v]
                    if new_weight > dp[mask][v]:
                        dp[mask][v] = new_weight
                        parent[mask][v] = u

        # Find the best Hamiltonian cycle (must return to start)
        final_mask = m - 1
        max_weight = -np.inf
        last_node = -1

        for v in range(n):
            if dp[final_mask][v] == -np.inf:
                continue
            total_weight = dp[final_mask][v] + weights[v][0]
            if total_weight > max_weight:
                max_weight = total_weight
                last_node = v

        if last_node == -1:
            return None

        # Reconstruct path by walking the parent pointers back to node 0.
        path = []
        mask = final_mask
        current = last_node

        while mask:
            path.append(int(current))
            new_mask = mask ^ (1 << current)
            current = parent[mask][current]
            mask = new_mask

        path.reverse()

        return {
            'path': path,
            'total_weight': float(max_weight),
            'method': 'DP (Directed Complete Graph)'
        }

    def _heuristic_solution(self):
        """Heuristic solution (for large directed graphs)

        Runs several randomized greedy trials. Each trial starts at a random
        node and repeatedly moves to an unvisited node, picking at random
        among the (up to) three heaviest outgoing edges. The heaviest
        resulting path wins.

        Returns:
            None for an empty graph, otherwise a dict with
            'path': the visiting order with the start node appended again at
                the end (closed cycle),
            'total_weight': sum of the ``n - 1`` edge weights of the open
                path. NOTE: the closing edge back to the start node is not
                included in this sum, even though the path is closed.
            'method': algorithm name.
        """
        if self.n <= 0:
            return None

        best_path = None
        best_weight = -np.inf

        # At least one trial, so tiny graphs (n == 1) still produce a result.
        trials = max(1, len(self.data) // 2)

        for _ in range(trials):
            # Random starting node
            start = random.randint(0, self.n - 1)
            path = [start]
            visited = {start}

            # Greedy path extension (following directed edges)
            while len(path) < self.n:
                last = path[-1]

                # All unvisited nodes are candidates, keyed by edge weight.
                candidates = [
                    (self.data[last][v], v)
                    for v in range(self.n)
                    if v not in visited
                ]

                if not candidates:
                    break  # Cannot extend further

                # Sort by weight and select from top 3 (with randomness)
                candidates.sort(reverse=True)
                if len(candidates) >= 3:
                    _, next_node = random.choice(candidates[:3])
                else:
                    _, next_node = candidates[0]
                path.append(next_node)
                visited.add(next_node)

            # Only complete tours (every node visited once) are scored.
            if len(path) == self.n:
                current_weight = sum(self.data[a][b] for a, b in zip(path, path[1:]))
                if current_weight > best_weight:
                    best_weight = current_weight
                    best_path = path.copy()

        if best_path is None:
            return None

        best_path.append(best_path[0])  # Close the cycle
        return {
            'path': best_path,
            'total_weight': best_weight,
            'method': 'Heuristic (Directed)'
        }


# Only a test
if __name__ == "__main__":
    # Example 1: small directed graph, edges listed as (source, target, weight).
    # max_hamilton_cycle() is used here as for any other graph size.
    small_edges = (
        (0, 1, 2),  # 0→1 (weight 2)
        (1, 2, 3),  # 1→2 (weight 3)
        (2, 3, 4),  # 2→3 (weight 4)
        (3, 0, 5),  # 3→0 (weight 5)
        (0, 2, 1),  # 0→2 (weight 1)
    )
    small_graph = SampleEfficiencyMatrix(4)
    for src, dst, weight in small_edges:
        small_graph.set_edge(src, dst, weight)
    small_result = small_graph.max_hamilton_cycle()
    print("Small directed graph result:", small_result)

    # Example 2: large directed graph with random weights.
    node_count = 50
    large_graph = SampleEfficiencyMatrix(node_count)
    # For every unordered pair, each direction independently receives a
    # random integer weight in 1-10 with probability 0.5 (forward direction
    # is drawn first, then the reverse one).
    for a, b in combinations(range(node_count), 2):
        for src, dst in ((a, b), (b, a)):
            if random.random() > 0.5:
                large_graph.set_edge(src, dst, random.randint(1, 10))
    large_result = large_graph.max_hamilton_cycle()
    if large_result:
        preview = large_result['path'][:5]
    else:
        preview = "No cycle found"
    print("Large directed graph result (first 5 nodes):", preview)