import numpy as np
import torch
from collections import defaultdict, Counter
from verl import DataProto
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch
import numpy as np
import time


def to_hashable(x):
    """Convert an object into a hashable type (used for clustering/grouping).

    - Python scalars (int/float/str/bool) are returned unchanged.
    - NumPy scalars (any ``np.generic`` that is not already a Python scalar
      subclass, e.g. ``np.int64``, ``np.float32``, ``np.bool_``) are converted
      with ``.item()``.
    - ``np.ndarray`` is flattened (C order) into a tuple. Object-dtype arrays
      have each element converted recursively so nested lists/dicts become
      hashable too; other dtypes keep their NumPy scalar elements (hash/eq
      compatible with the equivalent Python scalars).
    - lists/tuples become tuples, dicts become sorted ``(key, value)`` tuples,
      both recursively.

    Raises:
        TypeError: for any other type (including ``None``).
    """
    if isinstance(x, (int, float, str, bool)):
        return x
    elif isinstance(x, np.generic):
        # Covers np.integer/np.floating as before, plus np.bool_ etc.,
        # which previously fell through to TypeError.
        return x.item()
    elif isinstance(x, np.ndarray):
        flat = x.ravel()
        if flat.dtype == object:
            # Elements may be lists/dicts/arrays: convert them so the result
            # is actually hashable.
            return tuple(to_hashable(e) for e in flat)
        return tuple(flat)
    elif isinstance(x, (list, tuple)):
        return tuple(to_hashable(e) for e in x)
    elif isinstance(x, dict):
        return tuple(sorted((k, to_hashable(v)) for k, v in x.items()))
    else:
        raise TypeError(f"Unsupported type: {type(x)}")


def nested_defaultdict():
    """Build an empty ``defaultdict`` whose missing keys default to a fresh ``list``.

    Serves as the factory for the inner level of ``ComplexGraph.graph``.
    NOTE(review): a named module-level function (instead of a lambda) is
    presumably used so the structure stays picklable — confirm.
    """
    inner_factory = list
    return defaultdict(inner_factory)


class ComplexGraph:
    """Directed multigraph of environment states built from rollout transitions.

    ``graph[u][v]`` holds one attribute dict per observed transition ``u -> v``,
    so parallel transitions (e.g. from different trajectories) are all kept.
    Node keys are normalised with :func:`to_hashable`, and ``nodes`` contains
    exactly the normalised endpoints passed to :meth:`add_edge`.
    """

    def __init__(self):
        # graph[u][v] = list of edge dicts (e.g. {"action":..., "traj_uid":...})
        self.graph = defaultdict(nested_defaultdict)
        self.nodes = set()  # normalised (to_hashable) node keys
        self.uids = set()  # "uid" values seen on edges
        self.traj_to_uid = {}  # traj_uid -> uid
        self.episode_rewards = None  # later for reward-based analysis
        self.final_state = None  # distance target; set by from_data
        self.first_state = None  # start state; set by from_data
        self.shortest_path_to_final = None  # cached (distances, paths) from calculate_path()

    # ------------------------------------------------------------------
    # Basic graph API
    # ------------------------------------------------------------------

    def add_edge(self, u, v, **attrs):
        """Record one transition ``u -> v`` carrying ``attrs`` as its edge data.

        Both endpoints are normalised with :func:`to_hashable` before being
        stored in ``graph`` *and* ``nodes``, so node-membership checks agree
        with the adjacency keys (storing the raw values failed for lists and
        arrays and could disagree with the hashed keys).
        """
        hu, hv = to_hashable(u), to_hashable(v)
        self.graph[hu][hv].append(attrs)
        self.nodes.update((hu, hv))
        if "uid" in attrs:
            self.uids.add(attrs["uid"])
            if "traj_uid" in attrs:
                self.traj_to_uid[attrs["traj_uid"]] = attrs["uid"]

    def neighbors(self, u):
        """Return the distinct successors of ``u`` (``[]`` if it has no out-edges).

        Uses ``dict.get`` so querying a node without out-edges does not insert
        an empty entry into the ``defaultdict`` adjacency.
        """
        return list(self.graph.get(to_hashable(u), {}).keys())

    def edges(self, data=False):
        """Yield ``(u, v)`` once per distinct edge, or ``(u, v, attr)`` per stored transition if ``data``."""
        for u, targets in self.graph.items():
            for v, edge_list in targets.items():
                if data:
                    for attr in edge_list:
                        yield (u, v, attr)
                else:
                    yield (u, v)

    def has_edge(self, u, v):
        """Return True if at least one transition ``u -> v`` has been recorded (both ends normalised)."""
        return to_hashable(v) in self.graph.get(to_hashable(u), {})

    def get_traj_path(self, traj_uid):
        """Return the ``(u, action, v)`` steps of one trajectory, ordered by ``data_id``.

        Raises:
            ValueError: if no edge carries this ``traj_uid``.
        """
        traj_edges = [
            (attr.get("data_id", 0), u, attr.get("action"), v)
            for u, targets in self.graph.items()
            for v, edge_list in targets.items()
            for attr in edge_list
            if attr.get("traj_uid") == traj_uid
        ]
        if not traj_edges:
            raise ValueError(f"Trajectory {traj_uid} not found in graph.")
        traj_edges.sort(key=lambda x: x[0])
        return [(u, action, v) for _, u, action, v in traj_edges]

    def get_edges(self, u, v):
        """Return the list of edge-attribute dicts for ``u -> v``.

        For an existing edge this is the stored list itself; for a missing edge
        a fresh empty list is returned *without* creating the edge (indexing
        the defaultdict used to create a phantom edge visible to
        :meth:`neighbors` / :meth:`has_edge`).
        """
        return self.graph.get(to_hashable(u), {}).get(to_hashable(v), [])

    def all_simple_paths(self, source, target, path=None, visited=None):
        """Yield every simple path (list of node keys) from ``source`` to ``target`` via DFS.

        The number of simple paths can grow exponentially with graph size.
        """
        source = to_hashable(source)
        target = to_hashable(target)
        if path is None:
            path = [source]
        if visited is None:
            visited = {source}
        if source == target:
            yield list(path)
            return
        for neighbor in self.neighbors(source):
            if neighbor not in visited:
                yield from self.all_simple_paths(neighbor, target, path + [neighbor], visited | {neighbor})

    def __repr__(self):
        return f"ComplexGraph(num_nodes={len(self.nodes)}, num_edges={sum(len(v) for u in self.graph.values() for v in u.values())})"

    # ------------------------------------------------------------------
    # Construction from a rollout batch
    # ------------------------------------------------------------------

    @classmethod
    def from_data(cls, data, invalid_action_weight=1.0):
        """Build one graph per task ``uid`` from a rollout batch.

        Args:
            data: batch whose ``non_tensor_batch`` provides per-step arrays
                ``uid``, ``anchor_obs``, ``next_obs``, ``text_actions``,
                ``traj_uid``, ``episode_rewards`` and ``is_action_valid``.
            invalid_action_weight: edge ``weight`` for steps whose action was
                invalid (valid steps always get 1.0). The default 1.0 keeps the
                previous behaviour, where both branches assigned 1.0.

        Returns:
            dict mapping each task ``uid`` to its graph. ``final_state`` is the
            ``next_obs`` of the last step (in batch order) with
            ``episode_rewards > 0``; ``first_state`` is the ``anchor_obs`` of
            the first step of that uid in batch order.
        """
        ntb = data.non_tensor_batch
        uids = ntb['uid']
        Gs = {}
        # dict.fromkeys: deterministic first-appearance order (a set is not).
        for uid in dict.fromkeys(uids):
            g = cls()
            idx = np.where(uids == uid)[0]  # this task's rows, in batch order
            for i in idx:
                i = int(i)
                weight = 1.0 if ntb['is_action_valid'][i] else invalid_action_weight
                next_state = to_hashable(ntb['next_obs'][i])
                g.add_edge(
                    to_hashable(ntb['anchor_obs'][i]),
                    next_state,
                    action=ntb['text_actions'][i],
                    traj_uid=ntb['traj_uid'][i],
                    uid=ntb['uid'][i],
                    data_id=i,
                    episode_rewards=ntb['episode_rewards'][i],
                    weight=weight,
                )
                # Later qualifying rows overwrite earlier ones.
                if ntb['episode_rewards'][i] > 0:
                    g.final_state = next_state
            g.first_state = to_hashable(ntb['anchor_obs'][idx[0]])
            Gs[uid] = g
        return Gs

    @staticmethod
    def find_success_states_per_task(data):
        """Collect the visited state sequence of every successful trajectory.

        A trajectory counts as successful if any of its rows has
        ``episode_rewards > 0``. Its sequence is all of its ``anchor_obs`` (in
        batch order) followed by the ``next_obs`` of its last row.

        Returns:
            defaultdict(list) keyed by ``(uid, traj_uid)``, in first-appearance
            order. The order matters: :meth:`search_golden_candidate` folds an
            order-dependent LCS over these sequences, and iterating a ``set``
            of str ids gave a process-dependent order.
        """
        ntb = data.non_tensor_batch
        success_states_per_task = defaultdict(list)
        success_idx = ntb['episode_rewards'] > 0
        for traj_uid in dict.fromkeys(ntb['traj_uid'][success_idx]):
            traj_uid_ids = np.where(ntb['traj_uid'] == traj_uid)[0]
            uid = ntb['uid'][traj_uid_ids[0]]
            states = success_states_per_task[(uid, traj_uid)]
            states.extend(ntb['anchor_obs'][traj_uid_ids])
            states.append(ntb['next_obs'][traj_uid_ids[-1]])
        return success_states_per_task

    @staticmethod
    def lcs(seq1, seq2):
        """Return a longest common subsequence of ``seq1`` and ``seq2`` as a list.

        Classic O(n*m) dynamic program; each DP cell stores the full
        subsequence list, so memory also scales with the subsequence length.
        Elements are compared with ``==``.
        """
        n, m = len(seq1), len(seq2)
        dp = [[[] for _ in range(m + 1)] for _ in range(n + 1)]
        for i in range(n):
            for j in range(m):
                if seq1[i] == seq2[j]:
                    dp[i + 1][j + 1] = dp[i][j] + [seq1[i]]
                else:
                    dp[i + 1][j + 1] = max(dp[i][j + 1], dp[i + 1][j], key=len)
        return dp[n][m]

    @staticmethod
    def Calculate_all_shortest_path(Gs: dict):
        """Run :meth:`calculate_path` on every graph in ``Gs`` (in place)."""
        for g in Gs.values():
            g.calculate_path()

    @classmethod
    def search_golden_candidate(cls, success_states_per_task):
        """Per task uid, fold :meth:`lcs` over its successful state sequences.

        Args:
            success_states_per_task: mapping ``(uid, traj_uid) -> state list``
                (see :meth:`find_success_states_per_task`).

        Returns:
            defaultdict(list) mapping uid -> common subsequence (possibly empty).
        """
        trajs_by_uid = defaultdict(list)
        for (uid, _), states in success_states_per_task.items():
            trajs_by_uid[uid].append(states)

        golden_candidate_per_task = defaultdict(list)
        for uid, trajs in trajs_by_uid.items():
            common_seq = list(trajs[0])
            for t in trajs[1:]:
                common_seq = cls.lcs(common_seq, t)
                if not common_seq:
                    break
            golden_candidate_per_task[uid] = common_seq
        return golden_candidate_per_task

    # ------------------------------------------------------------------
    # Golden-state selection
    # ------------------------------------------------------------------

    def _reachable(self, source, target, blocked=frozenset()):
        """BFS: can ``target`` be reached from ``source`` without entering any node in ``blocked``?

        Arguments must already be normalised with :func:`to_hashable`.
        """
        from collections import deque

        if source == target:
            return True
        seen = {source}
        queue = deque([source])
        while queue:
            node = queue.popleft()
            for nxt in self.graph.get(node, {}):
                if nxt in blocked or nxt in seen:
                    continue
                if nxt == target:
                    return True
                seen.add(nxt)
                queue.append(nxt)
        return False

    def is_mandatory(self, s_from, s_mid, s_to):
        """Return True if every path from ``s_from`` to ``s_to`` passes through ``s_mid``.

        Same result as checking every simple path. A simple path avoiding
        ``s_mid`` exists iff ``s_to`` is still reachable once ``s_mid`` is
        removed, so this runs in O(V + E) instead of exponential time. All
        three states are normalised, so raw observations compare correctly
        against graph nodes.

        Raises:
            ValueError: if ``s_to`` is not reachable from ``s_from``.
        """
        s_from, s_mid, s_to = to_hashable(s_from), to_hashable(s_mid), to_hashable(s_to)
        if not self._reachable(s_from, s_to):
            raise ValueError("no path")
        if s_mid == s_from or s_mid == s_to:
            return True  # every path contains its own endpoints
        return not self._reachable(s_from, s_to, blocked=frozenset([s_mid]))

    @classmethod
    def get_golden_states(cls, golden_candidate, Gs):
        """Keep the first/last candidate plus every interior candidate that is mandatory in the task graph.

        Args:
            golden_candidate: mapping uid -> ordered candidate states.
            Gs: mapping uid -> :class:`ComplexGraph`.

        Returns:
            defaultdict(list) mapping uid -> filtered golden states. Lists with
            at most two candidates are returned unchanged.
        """
        final_golden_states = defaultdict(list)
        for uid, candidates in golden_candidate.items():
            g = Gs[uid]
            if len(candidates) <= 2:
                final_golden_states[uid] = candidates
                continue
            start, end = candidates[0], candidates[-1]
            middle = [c for c in candidates[1:-1] if g.is_mandatory(start, c, end)]
            final_golden_states[uid] = [start, *middle, end]
        return final_golden_states

    # ------------------------------------------------------------------
    # Shortest paths
    # ------------------------------------------------------------------

    def shortest_path(self, source, target):
        """Unweighted (BFS) shortest path as a list of nodes, or None if unreachable.

        Raises:
            ValueError: if either endpoint is not a known node.
        """
        from collections import deque

        source = to_hashable(source)
        target = to_hashable(target)
        if source not in self.nodes or target not in self.nodes:
            raise ValueError("source or target not in graph")

        queue = deque([[source]])
        visited = {source}
        while queue:
            path = queue.popleft()
            node = path[-1]
            if node == target:
                return path
            for neighbor in self.neighbors(node):
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append(path + [neighbor])
        return None

    def shortest_path_weighted(self, source, target, weight_key="weight", unreachable_cost=51):
        """Dijkstra shortest path from ``source`` to ``target``.

        Parallel edges count with their minimum ``weight_key`` value; a missing
        key counts as 1.

        Returns:
            ``(path, cost)``, or ``(None, unreachable_cost)`` if ``target``
            cannot be reached. NOTE(review): the default 51 keeps the previous
            hard-coded sentinel; its meaning is not visible here.
        """
        import heapq

        source = to_hashable(source)
        target = to_hashable(target)
        pq = [(0, source, [source])]
        visited = {}
        while pq:
            cost, node, path = heapq.heappop(pq)
            if node == target:
                return path, cost
            if node in visited and visited[node] <= cost:
                continue
            visited[node] = cost
            for neighbor, edge_list in self.graph.get(node, {}).items():
                w = min(e.get(weight_key, 1) for e in edge_list)
                heapq.heappush(pq, (cost + w, neighbor, path + [neighbor]))
        return None, unreachable_cost

    def all_to_target_shortest_paths(self, target, weight_key="weight"):
        """Dijkstra on the reversed graph: distance and path from every node to ``target``.

        Every stored transition is an edge with weight ``attr.get(weight_key, 1)``.

        Returns:
            ``(distances, paths)``. ``distances`` has a key per node, plus
            ``target`` itself (distance 0). Nodes that cannot reach ``target``
            get ``max(0, largest finite distance) + 1`` instead of infinity.
            ``paths[n]`` is the node list from ``n`` to ``target`` (``[]`` if
            unreachable; ``[target]`` for the target itself).
        """
        import heapq

        if target is not None:
            target = to_hashable(target)

        # reverse_graph[v] = [(u, w), ...] for every stored transition u -> v
        reverse_graph = defaultdict(list)
        for u, targets in self.graph.items():
            for v, edge_list in targets.items():
                for attr in edge_list:
                    reverse_graph[v].append((u, attr.get(weight_key, 1)))

        # ---------- Dijkstra ----------
        inf = float("inf")
        distances = {n: inf for n in self.nodes}
        paths = {n: [] for n in self.nodes}
        distances[target] = 0
        if target in paths:
            paths[target] = [target]
        pq = [(0, target, [target])]
        while pq:
            cost, node, path = heapq.heappop(pq)
            if cost > distances[node]:
                continue  # stale queue entry
            for neighbor, w in reverse_graph.get(node, ()):
                new_cost = cost + w
                if new_cost < distances[neighbor]:
                    distances[neighbor] = new_cost
                    new_path = [neighbor] + path
                    paths[neighbor] = new_path
                    heapq.heappush(pq, (new_cost, neighbor, new_path))

        # Replace "unreachable" with a finite value just beyond the farthest reachable node.
        far_distance = max([0] + [d for d in distances.values() if d != inf])
        for node, d in distances.items():
            if d == inf:
                distances[node] = far_distance + 1
        return distances, paths

    def calculate_path(self, weight_key="weight"):
        """Cache ``all_to_target_shortest_paths(self.final_state)`` in ``shortest_path_to_final``.

        ``weight_key`` is now forwarded; previously it was ignored and
        ``"weight"`` was always used.
        """
        self.shortest_path_to_final = self.all_to_target_shortest_paths(self.final_state, weight_key=weight_key)

    # ------------------------------------------------------------------
    # Debug printing
    # ------------------------------------------------------------------

    def print_shortest_path(self, source, target, weighted: bool = False, weight_key: str = "weight"):
        """Print the (optionally weighted) shortest path from ``source`` to ``target`` with edge details."""
        source_hash = to_hashable(source)
        target_hash = to_hashable(target)
        if source_hash not in self.nodes:
            print(f"❌  {source}")
            return
        if target_hash not in self.nodes:
            print(f"❌ {target}")
            return

        if weighted:
            path, _total_cost = self.shortest_path_weighted(source_hash, target_hash, weight_key)
        else:
            path = self.shortest_path(source_hash, target_hash)
        if path is None:
            print(f"❌")
            return
        print(f"✅")

        self._print_path_details(path, weighted, weight_key)

    def _print_path_details(self, path: list, weighted: bool, weight_key: str):
        """Print each hop of ``path`` with the action, weight and traj uid of one representative edge."""
        for i in range(len(path) - 1):
            u = path[i]
            v = path[i + 1]
            edges = self.get_edges(u, v)

            if weighted and edges:
                edge = min(edges, key=lambda x: x.get(weight_key, 1))
            else:
                edge = edges[0] if edges else {}

            action = edge.get("action", "")
            weight = edge.get(weight_key, 1.0) if weighted else "-"
            traj_uid = edge.get("traj_uid", "")
            print(f"  {i+1}: {u}")
            print(f"    ↳ : {action} | : {weight} | UID: {traj_uid}")
        print(f"  {len(path)}: {path[-1]}")

    def print_all_uid_shortest_paths(self, weighted: bool = False, weight_key: str = "weight"):
        """Print the shortest path from ``first_state`` to ``final_state``.

        Uses ``is None`` checks so legitimate falsy states (``0``, ``""``,
        ``()``) are not mistaken for missing ones.
        """
        if self.first_state is None:
            print("❌")
            return
        if self.final_state is None:
            print("❌")
            return

        self.print_shortest_path(
            source=self.first_state,
            target=self.final_state,
            weighted=weighted,
            weight_key=weight_key
        )





import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch
from collections import defaultdict

def visualize_complex_graph_improved(G, step_rewards=None, show_info=False, figsize=(14,6),
                                     save_path='text.png', dpi=800):
    """Draw a ComplexGraph with one colour per trajectory and save it as an image.

    Each trajectory is laid out on its own row. A state shared by several
    trajectories has its position averaged with each new occurrence. The
    start state (``G.first_state``) is drawn red and the final state
    (``G.final_state``) blue.

    Args:
        G: graph exposing ``graph``, ``first_state`` and ``final_state``.
        step_rewards: unused; kept for call-site compatibility.
        show_info: annotate nodes with their integer id and mark start ('S') /
            final ('F') states.
        figsize: matplotlib figure size in inches.
        save_path: output image path (default keeps the former hard-coded
            ``'text.png'``).
        dpi: resolution for both the canvas and the saved image.

    Returns:
        dict mapping the integer node id shown in the figure to its state.
    """
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    try:
        ax.set_aspect('equal')
        ax.axis('off')

        # Every state that appears as an edge endpoint, plus the final state.
        all_states = set()
        for u, targets in G.graph.items():
            all_states.add(u)
            all_states.update(targets.keys())
        if G.final_state is not None:
            all_states.add(G.final_state)

        # NOTE(review): sorted() requires all states to be mutually comparable.
        state_id_map = {state: idx + 1 for idx, state in enumerate(sorted(all_states))}
        id_state_map = {idx: state for state, idx in state_id_map.items()}

        # Group transitions by trajectory uid.
        traj_edges = defaultdict(list)
        for u, targets in G.graph.items():
            for v, edges in targets.items():
                for attr in edges:
                    traj_edges[attr.get("traj_uid", "unknown")].append((u, v, attr))

        traj_uids = list(traj_edges.keys())
        colors = plt.cm.tab10.colors
        traj_color_map = {traj_uid: colors[i % len(colors)] for i, traj_uid in enumerate(traj_uids)}

        # One row per trajectory, steps ordered by data_id.
        node_positions = {}
        y_step = 2.0
        for i, traj_uid in enumerate(traj_uids):
            edges = traj_edges[traj_uid]
            edges.sort(key=lambda x: x[2].get("data_id", 0))
            states = [edges[0][0]] + [v for _, v, _ in edges]
            x_positions = np.linspace(0, len(states) * 3.0, len(states))
            y_position = -i * y_step
            for x, s in zip(x_positions, states):
                if s not in node_positions:
                    node_positions[s] = [x, y_position]
                else:
                    node_positions[s][0] = (node_positions[s][0] + x) / 2
                    node_positions[s][1] = (node_positions[s][1] + y_position) / 2
        scaling_x = 3.0
        scaling_y = 5.0
        for pos in node_positions.values():
            pos[0] *= scaling_x
            pos[1] *= scaling_y

        for node, (x, y) in node_positions.items():
            node_id = state_id_map[node]
            if G.final_state is not None and node == G.final_state:
                ax.scatter(x, y, s=300, facecolors='blue', edgecolors='black', zorder=3, alpha=0.8)
                if show_info:
                    ax.text(x, y+0.5, 'F', fontsize=10, ha='center', va='center', zorder=4, color='red', fontweight='bold')
            elif G.first_state is not None and node == G.first_state:
                ax.scatter(x, y, s=300, facecolors='red', edgecolors='black', zorder=3, alpha=0.8)
                if show_info:
                    ax.text(x, y+0.5, 'S', fontsize=10, ha='center', va='center', zorder=4, color='red', fontweight='bold')
            else:
                ax.scatter(x, y, s=200, facecolors='white', edgecolors='black', zorder=3)

            if show_info:
                ax.text(x, y, str(node_id), fontsize=4, ha='center', va='center', zorder=4, fontweight='bold')

        # Repeated u->v edges (and reverse edges) get increasing curvature so they stay distinguishable.
        edge_count = defaultdict(int)
        for traj_uid, edges in traj_edges.items():
            color = traj_color_map[traj_uid]
            for u, v, attr in edges:
                key = (u, v)
                edge_count[key] += 1
                rad = 0.2 * (edge_count[key] - 1)
                if (v, u) in edge_count:
                    rad += 0.2

                arrow = FancyArrowPatch(node_positions[u], node_positions[v],
                                        connectionstyle=f"arc3,rad={rad}",
                                        arrowstyle='-|>',
                                        color=color,
                                        lw=1,
                                        alpha=0.8,
                                        mutation_scale=10,
                                        shrinkA=5,
                                        shrinkB=5,
                                        zorder=2)
                ax.add_patch(arrow)

        # An empty graph has no positions; max([]) would raise.
        if node_positions:
            xs = [pos[0] for pos in node_positions.values()]
            ys = [pos[1] for pos in node_positions.values()]
            padding_x = (max(xs) - min(xs)) * 0.5 + 2.0
            padding_y = (max(ys) - min(ys)) * 0.5 + 2.0
            ax.set_xlim(min(xs)-padding_x*2, max(xs)+padding_x*2)
            ax.set_ylim(min(ys)-padding_y*2, max(ys)+padding_y*2)

        # Proxy artists only: plt.scatter([], []) would add a stray artist to the axes.
        legend_elements = [plt.Line2D([], [], color=color, lw=2, label=traj_uid)
                           for traj_uid, color in traj_color_map.items()]
        if G.first_state is not None:
            legend_elements.append(plt.Line2D([], [], marker='o', linestyle='None', markersize=10,
                                              markerfacecolor='red', markeredgecolor='black',
                                              alpha=0.8, label='Start State'))
        if G.final_state is not None:
            # Final state is drawn blue above; the legend previously showed red.
            legend_elements.append(plt.Line2D([], [], marker='o', linestyle='None', markersize=10,
                                              markerfacecolor='blue', markeredgecolor='black',
                                              alpha=0.8, label='Final State'))
        if legend_elements:
            ax.legend(handles=legend_elements, title="Trajectory UID / State Map", bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)

        fig.subplots_adjust(left=0.05, right=0.85, top=0.95, bottom=0.05)
        fig.savefig(save_path, bbox_inches='tight', pad_inches=0.8, dpi=dpi)
    finally:
        # Always release the (large) figure, even if drawing fails.
        plt.close(fig)

    return id_state_map






def compute_graph_step_discounted_returns(batch: "DataProto", golden_states_dict: dict, gamma: float,
                                          golden_reward: float = 10.0):
    """
    Compute per-step discounted returns shaped by golden states. (Eq. 5 in the paper)

    For a trajectory whose task ``uid`` has golden states, the first golden
    state is skipped and the rest are targeted in order. When ``next_obs[t]``
    equals the current target:

    - that step's reward is replaced by ``golden_reward``;
    - discounted returns are recomputed from zero over the steps since the
      previous match;
    - the target advances. The last target is kept, so it can match again.

    Steps after the last match keep a return of 0. Trajectories with no entry
    in ``golden_states_dict``, or with fewer than two golden states, use plain
    discounted returns. Indexing ``golden_states[0]`` on such lists used to
    raise ``IndexError``; an empty list is what ``search_golden_candidate``
    stores for an empty LCS.

    Args:
        batch (DataProto): Input batch. Reads ``non_tensor_batch`` keys
            ``rewards``, ``traj_uid``, ``uid``, ``active_masks``, ``next_obs``,
            and ``batch['input_ids']`` (for the output device only).
        golden_states_dict (dict): task uid -> ordered golden states.
        gamma (float): Discount factor.
        golden_reward (float): Reward assigned to a step that reaches a golden state.

    Returns:
        torch.Tensor: float32 returns in the original batch order.
    """
    ntb = batch.non_tensor_batch
    rewards = ntb['rewards'].astype(np.float32)
    traj_uids = ntb['traj_uid']
    uids = ntb['uid']
    active_masks = ntb['active_masks'].astype(np.float32)
    next_obs_all = ntb['next_obs']

    all_returns = np.zeros_like(rewards)
    for traj_uid in np.unique(traj_uids):
        traj_indices = np.where(traj_uids == traj_uid)[0]
        task_uid = uids[traj_indices[0]]
        traj_rewards = rewards[traj_indices]  # fancy indexing -> copy, safe to overwrite
        new_obs = next_obs_all[traj_indices]
        assert active_masks[traj_indices].all(), "active_masks should be all 1s for the same trajectory"

        traj_returns = np.zeros_like(traj_rewards)
        running_return = 0.0
        golden = golden_states_dict[task_uid][1:] if task_uid in golden_states_dict else []

        if len(golden) > 0:
            golden_state, pending = golden[0], golden[1:]
            seg_start = 0
            for t in range(len(traj_rewards)):
                if new_obs[t] == golden_state:
                    if len(pending) > 0:
                        golden_state, pending = pending[0], pending[1:]
                    traj_rewards[t] = golden_reward
                    running_return = 0.0
                    for rt in reversed(range(seg_start, t + 1)):
                        running_return = traj_rewards[rt] + gamma * running_return
                        traj_returns[rt] = running_return
                    seg_start = t + 1
        else:
            for t in reversed(range(len(traj_rewards))):
                running_return = traj_rewards[t] + gamma * running_return
                traj_returns[t] = running_return

        # Scatter straight back into batch order (replaces an O(N^2) lookup pass).
        all_returns[traj_indices] = traj_returns

    return torch.tensor(all_returns, dtype=torch.float32, device=batch.batch['input_ids'].device)


def compute_graph_path_returns(batch: "DataProto", Gs: dict, gamma: float, normalize_distance: bool,
                               return_scale: float = 10.0):
    """
    Compute per-step returns from graph distances to each task's final state.

    With ``d(.)`` = ``distances`` from the task graph's ``shortest_path_to_final``,
    ``s`` = ``anchor_obs[t]`` and ``s'`` = ``next_obs[t]``:

    - ``normalize_distance=True``:  ``return_scale * gamma ** (d(s') - d(s) + 1)``
    - ``normalize_distance=False``: ``return_scale * gamma ** d(s')``

    Args:
        batch (DataProto): Input batch. Reads ``non_tensor_batch`` keys
            ``rewards`` (shape/dtype template), ``traj_uid``, ``uid``,
            ``active_masks``, ``next_obs``, ``anchor_obs``, and
            ``batch['input_ids']`` (for the output device only).
        Gs (dict): task uid -> ComplexGraph. If a graph's distances were never
            computed, ``calculate_path()`` is called on it here (this mutates
            the graph); previously this crashed unpacking ``None``.
        gamma (float): Discount factor.
        normalize_distance (bool): Use the distance *difference* across the
            step instead of the absolute distance of the next state.
        return_scale (float): Multiplier applied to every return.

    Returns:
        torch.Tensor: float32 returns in the original batch order.
    """
    ntb = batch.non_tensor_batch
    rewards = ntb['rewards'].astype(np.float32)
    traj_uids = ntb['traj_uid']
    uids = ntb['uid']
    active_masks = ntb['active_masks'].astype(np.float32)
    next_obs_all = ntb['next_obs']
    anchor_obs_all = ntb['anchor_obs']
    step_weight = 1.0  # per-step cost inside the exponent (invalid actions currently weigh the same)

    all_returns = np.zeros_like(rewards)
    for traj_uid in np.unique(traj_uids):
        traj_indices = np.where(traj_uids == traj_uid)[0]
        assert active_masks[traj_indices].all(), "active_masks should be all 1s for the same trajectory"

        g = Gs[uids[traj_indices[0]]]
        if g.shortest_path_to_final is None:
            g.calculate_path()
        distances_dict, _paths = g.shortest_path_to_final

        traj_returns = np.zeros(len(traj_indices), dtype=rewards.dtype)
        for t, idx in enumerate(traj_indices):
            d_next = distances_dict[to_hashable(next_obs_all[idx])]
            if normalize_distance:
                d_prev = distances_dict[to_hashable(anchor_obs_all[idx])]
                traj_returns[t] = return_scale * gamma ** (d_next - d_prev + step_weight)
            else:
                traj_returns[t] = return_scale * gamma ** (d_next + step_weight - 1)

        # Scatter straight back into batch order (replaces an O(N^2) lookup pass).
        all_returns[traj_indices] = traj_returns

    return torch.tensor(all_returns, dtype=torch.float32, device=batch.batch['input_ids'].device)