#!/usr/bin/env python
# ============================================================
# Stochastic connectivity propagation + centrality comparison
# ============================================================
import math
import os
import numpy as np
import random

from . import brandes_centrality as bc
from .bottleneck_env import SimpleEnv
from .generate_state_transition_matrix import build_state_transition_matrix
from .utils import BottleneckVisualization


class StochasticConnectivityPropagation:
    """
    Toy stochastic connectivity propagation on a grid environment.

    Every free (non-wall) state ``m`` is used once as the sole target.  For
    that target, a uniformly random walk of fixed length is run from every
    free start state (``m`` itself included).  Stepping onto the target
    yields an intrinsic reward of 1.  During each walk the agent performs:

    * a tabular Q-learning update on ``self.Q``.  The walk itself is purely
      random (epsilon = 1.0), so Q does not steer it.  Q is reset after
      each target.
    * a TD(lambda) value update with accumulating eligibility traces on
      ``self.V``.  After each target, V is added into ``self.cumulative_V``
      and reset.
    * an exponential moving average, per state, of the squared one-step
      value difference, kept in ``self.abs_td``.  After each target it is
      added into ``self.cum_vps`` (the "VPS" statistic) and reset.

    Parameters
    ----------
    env
        Environment exposing ``width`` and ``height``.  It is also passed to
        ``build_state_transition_matrix``.
    gamma_v, gamma_q
        Discount factors of the value function and of the Q-function.
    lam
        Eligibility-trace decay (lambda) of the TD(lambda) value update.
    alpha
        Learning rate shared by the Q and value updates.
    td_ema_rate
        Step size of the per-state moving average accumulated in ``abs_td``.
    """

    def __init__(self, env, gamma_v=0.999, gamma_q=0.99, lam=0.9,
                 alpha=0.05, td_ema_rate=1e-4):
        self.width, self.height = env.width, env.height
        self.n_states = self.width * self.height

        # hyper-parameters
        self.gamma_v = gamma_v
        self.gamma_q = gamma_q
        self.lam = lam
        self.alpha = alpha
        self.td_ema_rate = td_ema_rate

        # transition_matrix[s, a] is the successor state of s under action a.
        # wall_mask is non-zero on wall cells.
        self.transition_matrix, self.wall_mask = build_state_transition_matrix(env)
        # np.where returns a tuple of index arrays; element 0 holds the flat
        # indices of the free cells.
        self.no_wall_indices = np.where(self.wall_mask.flatten() == 0)

        # TD / VPS related buffers
        self.V = np.zeros(self.n_states)
        self.cumulative_V = np.zeros(self.n_states)
        self.state_visit_count = np.zeros(self.n_states)
        self.eligibility_trace = np.zeros(self.n_states)

        # Q-table: 0,1,2,3 correspond to ↑ ↓ ← →
        self.Q = np.zeros((self.n_states, 4))

        # logs & accumulators
        self.episode_rewards = []
        self.abs_td = np.zeros(self.n_states)
        self.cum_vps = np.zeros(self.n_states)
        self.random_r = np.zeros(self.n_states)

    def choose_action(self, state, epsilon):
        """Return an epsilon-greedy action for ``state`` as a plain ``int``.

        With probability ``epsilon`` a uniformly random action is drawn.
        Otherwise one of the actions with maximal Q-value is picked, with
        ties broken uniformly at random.
        """
        if random.random() < epsilon:
            return int(np.random.choice([0, 1, 2, 3]))
        q_row = self.Q[state, :]
        candidates = np.flatnonzero(q_row == np.max(q_row))
        return int(np.random.choice(candidates))

    def update_with_q_learning(self, max_episode_num, max_episode_length):
        """
        Sweep every free state as the target.  For each target, run one
        random-walk episode from every free start state.

        Each episode lasts exactly ``max_episode_length`` steps.  After all
        start states for a target are done, that target's statistics are
        folded into the cumulative buffers, and Q, V and abs_td are reset.

        Parameters
        ----------
        max_episode_num
            Currently unused; kept for API compatibility.  Only a single
            episode is run per (target, start) pair.
        max_episode_length
            Number of environment steps per episode.
        """
        # Plain ints so logs print cleanly (NumPy 2 reprs scalars as np.int64(k)).
        free_states = [int(idx) for idx in self.no_wall_indices[0]]
        # No absorbing states are configured at the moment, so episodes never
        # restart early.  The hook is kept so absorbing states can be added.
        absorbing_states = frozenset()

        for target in free_states:
            target_states = frozenset((target,))
            print(f"Target State: {[target]}")

            for start in free_states:
                episode_reward = self._run_episode(
                    start, target_states, absorbing_states, max_episode_length
                )
                if episode_reward != 0:
                    self.episode_rewards.append(episode_reward)

            self._fold_target_statistics()

        print(f"The Total Reward collected over all runs: {np.sum(self.episode_rewards)}")

    def _run_episode(self, start_state, target_states, absorbing_states, max_episode_length):
        """Run one random-walk episode from ``start_state`` and return its total reward."""
        # Traces never carry over between episodes.
        self.eligibility_trace = np.zeros(self.n_states)
        episode_reward = 0
        s = start_state

        for _ in range(max_episode_length):
            action = self.choose_action(s, epsilon=1.0)  # purely random
            s_next = self.transition_matrix[s, action]

            self.state_visit_count[s] += 1
            self.eligibility_trace *= self.gamma_v * self.lam

            r = 1 if s_next in target_states else 0
            episode_reward += r
            is_absorbing = s in absorbing_states

            # ---------- Q-learning update (skipped in absorbing states) ----------
            delta_q = r + self.gamma_q * np.max(self.Q[s_next, :]) - self.Q[s, action]
            if not is_absorbing:
                self.Q[s, action] += self.alpha * delta_q

            # Accumulating eligibility trace, cleared in absorbing states.
            if is_absorbing:
                self.eligibility_trace = np.zeros(self.n_states)
            else:
                self.eligibility_trace[s] += 1

            # ---------- TD(lambda) value update ----------
            delta_v = r + self.gamma_v * self.V[s_next] - self.V[s]
            self.V += self.alpha * delta_v * self.eligibility_trace

            # Squared one-step value difference, measured after the update,
            # tracked as a per-state exponential moving average.  It is
            # already non-negative, so no abs() is needed.
            value_gap_sq = (self.V[s_next] - self.V[s]) ** 2
            self.abs_td[s] += self.td_ema_rate * (value_gap_sq - self.abs_td[s])

            # Absorbing states send the agent back to the start.
            s = start_state if is_absorbing else s_next

        # Leave a clean trace behind once the episode is over.
        self.eligibility_trace = np.zeros(self.n_states)
        return episode_reward

    def _fold_target_statistics(self):
        """Add this target's VPS and V into the running totals, then reset Q, V and abs_td."""
        self.cum_vps += self.abs_td
        self.abs_td = np.zeros(self.n_states)
        self.Q = np.zeros((self.n_states, 4))
        self.cumulative_V += self.V
        self.V = np.zeros(self.n_states)


def main():
    """Compare centrality baselines with the stochastic connectivity heatmaps.

    Plots, in order:

    * shortest-path betweenness centrality;
    * current-flow betweenness centrality, skipped with a warning if it
      cannot be computed;
    * the accumulated VPS map and the accumulated value function.

    The two learned maps are also saved as ``.npy`` files under
    ``./connectivity_results``.
    """
    env = SimpleEnv(render_mode=None)
    env.reset()
    visualizer = BottleneckVisualization(env)

    sp_bc = bc.compute_betweenness_centrality(env)
    visualizer.plot_2d_heatmap(sp_bc, topk=0, title="Shortest-path Betweenness Centrality")

    # The current-flow baseline is optional: report a failure and carry on.
    # Only the computation is guarded; plotting errors still propagate.
    try:
        cf_bc = bc.compute_current_flow_centrality(env)
    except Exception as e:
        print(
            f"[Warning] current_flow_betweenness_centrality failed: {e}. "
            "Skipping current-flow centrality baseline."
        )
    else:
        if cf_bc is not None:
            visualizer.plot_2d_heatmap(cf_bc, topk=0, title="Current-flow Betweenness Centrality")

    scp = StochasticConnectivityPropagation(env)
    # NOTE: update_with_q_learning does not read max_episode_num; it runs
    # one episode per (target, start) pair.
    scp.update_with_q_learning(max_episode_num=200, max_episode_length=3000)

    output_dir = "./connectivity_results"
    os.makedirs(output_dir, exist_ok=True)
    # Save the learned maps before plotting, so they survive even if a plot
    # window blocks or fails.  They can be reloaded without re-running the sweep.
    np.save(os.path.join(output_dir, "cum_vps.npy"), scp.cum_vps)
    np.save(os.path.join(output_dir, "cumulative_V.npy"), scp.cumulative_V)

    visualizer.plot_2d_heatmap(scp.cum_vps, topk=0, title="Value-Power Strength")
    visualizer.plot_2d_heatmap(scp.cumulative_V, topk=0, title="Value Function")


if __name__ == "__main__":
    # Script entry point.  The relative imports at the top of this file mean
    # it must be run as a module (``python -m <package>.<module>``); running
    # the file directly raises an ImportError.
    main()

