"""mcts.py – spectrum‑only Monte‑Carlo Tree Search with terminal predicates
=======================================================================
Pure‑function, single‑file MCTS that

* supports **mixed actions** (discrete family + continuous θ with
  progressive widening),
* keeps the search tree in flat NumPy arrays, JIT‑friendly, and
* handles **terminal states** via a user‑supplied predicate on spectra.

Edge‑reward prefix back‑prop correctly gives each node the expected *future*
time (negative reward) from that node onward.
"""

from __future__ import annotations

import sys
from sorcerun.git_utils import get_repo

ROOT = get_repo().working_dir
if ROOT not in sys.path:
    # add the repo to the path if it is not already there
    sys.path.append(ROOT)

from globals import SIGN, SQRT, INV

import math
from dataclasses import dataclass
from typing import Callable, Sequence, List, Tuple, Any
from richerator import richerator
from sorcerun.sacred_utils import DummyRun
from make_algorithm.actions import ACTIONS, Action, ActionInput, MatrixActionInput
from make_algorithm.losses import LOSSES

import numpy as np


# ------------------------------------------------------------------
#  Internal helpers
# ------------------------------------------------------------------
def _ensure_capacity(arrs_with_fill, need: int):
    """Grow every arena array to **at least** `need` rows and
    initialise ***newly allocated*** rows with the supplied *fill value*.

    Parameters
    ----------
    arrs_with_fill : Iterable[tuple[np.ndarray, Any]]
        Tuples ``(array, fill)`` where *fill* is the value that should
        be written to *new* rows when the array is resized. One‑dim and
        multi‑dim arrays are supported (only the first axis is grown).
    need : int
        Required capacity in rows (exclusive upper bound).
    """
    for a, fill in arrs_with_fill:
        cap = a.shape[0]
        if need <= cap:
            continue
        new_cap = cap
        while new_cap < need:
            new_cap *= 2
        old_shape = a.shape
        if a.ndim == 1:
            a.resize(new_cap, refcheck=False)
            a[cap:new_cap] = fill
        else:
            a.resize((new_cap, *a.shape[1:]), refcheck=False)
            # broadcast fill only along first axis
            if np.isscalar(fill):
                a[cap:new_cap, ...] = fill
            else:  # assume array/sequence matching trailing dims
                a[cap:new_cap, ...] = np.asarray(fill)


def _ucb(w: float, n: int, logp: float, c: float):
    return w / (n + 1e-9) + c * np.sqrt(logp / (n + 1e-9))


# ------------------------------------------------------------------
#  Tree arena – flat arrays
# ------------------------------------------------------------------
@dataclass
class Tree:
    """Search-tree arena stored as parallel flat NumPy arrays.

    Node ``i`` is described by row ``i`` of every per-node array; node 0 is
    the root.  A discrete action and its continuous parameters are spread
    over a chain of nodes: the node created under a *completed* parent
    selects the action (``num_params == 0``), and each further node in the
    chain fixes one more entry of ``theta`` until ``num_params`` equals the
    action's ``k``.

    Parameters
    ----------
    capacity : int
        Initial number of rows allocated in every per-node array.
    state0 : tuple of np.ndarray
        State stored for the root node.
    term_pred : callable
        Predicate invoked as ``term_pred(state)``; its result is stored in
        ``terminal`` for every node.
    actions : list of Action
        Actions addressed by ``action_idx``.
    single_actions : list of Action
        Actions offered when a node whose ``coupled`` flag is 0 is expanded.
    rng : np.random.Generator, optional
        Random generator; a fresh ``np.random.default_rng()`` is created
        per instance when omitted.
    C_pw, alpha_pw : float
        Progressive-widening constants.

    Notes
    -----
    All array/list fields below are (re)built in ``__post_init__``; values
    passed for them to the constructor are overwritten.
    """

    capacity: int
    state0: Tuple[np.ndarray, ...]  # initial spectrum (root state)
    # terminal predicate, called with a single state: term_pred(state)
    term_pred: Callable[[Tuple[np.ndarray, ...]], np.bool_]
    actions: List[Action]
    single_actions: List[Action]
    # None -> a private generator is created in __post_init__ (a default
    # evaluated here would be shared by every Tree instance)
    rng: np.random.Generator = None

    C_pw: float = 2.0
    alpha_pw: float = 0.5

    # arrays (set in __post_init__)
    parent: np.ndarray = None
    first_child: np.ndarray = None
    next_sib: np.ndarray = None
    visits: np.ndarray = None
    value: np.ndarray = None
    best_value: np.ndarray = None  # initialised to -inf
    edge_reward: np.ndarray = None
    action_idx: np.ndarray = None
    theta: np.ndarray = None
    terminal: np.ndarray = None
    num_params: np.ndarray = None
    size: int = 1
    # depth of each node; monitor to see exploration/exploitation
    depth: np.ndarray = None
    coupled: np.ndarray = None
    couplingaction: Action = None

    states: List[Any] = None  # list of spectra (states) for each node

    # Python-side child cache for O(1) access
    children: List[List[int]] = None  # list of immediate-children lists
    child_count: np.ndarray = None  # number of children per node

    # rollout baselines
    rollout_baselines: List[Sequence[Tuple[Action, np.ndarray]]] = None
    single_rollout_baselines: List[Sequence[Tuple[Action, np.ndarray]]] = None

    def __post_init__(self):
        """Allocate the per-node arrays and initialise the root node."""
        cap = self.capacity
        if self.rng is None:
            self.rng = np.random.default_rng()

        self.parent = np.full(cap, -1, np.int32)
        self.first_child = np.full(cap, -1, np.int32)
        self.next_sib = np.full(cap, -1, np.int32)
        self.visits = np.zeros(cap, np.int32)
        self.value = np.zeros(cap, np.float64)
        self.best_value = np.full(cap, -np.inf, np.float64)
        self.edge_reward = np.zeros(cap, np.float64)
        self.action_idx = np.full(cap, -1, np.int16)
        self.theta = np.zeros((cap, 4), np.float32)  # k_max ≤ 4
        self.terminal = np.zeros(cap, np.bool_)
        self.depth = np.zeros(cap, np.int32)
        # contiguous children lists
        self.children = [[] for _ in range(cap)]
        self.child_count = np.zeros(cap, np.int32)
        self.num_params = np.zeros(cap, np.int32)
        # 1 means coupled, 0 means decoupled; the root starts coupled
        self.coupled = np.zeros(cap, np.int32)
        self.coupled[0] = 1

        self.states = [self.state0]
        self.terminal[0] = self.term_pred(self.state0)

        # Per-instance bookkeeping.  These must be created here: as
        # class-level lists they would be shared by all Tree instances.
        # nodes_each_depth[d] counts nodes at depth d; the root is the one
        # node at depth 0.
        self.nodes_each_depth = [1]
        self.best_path = []
        self.rollout_alg = []

        self.rollout_baselines = []
        self.rollout_baseline_indices = []  # just for easy random choice
        self.single_rollout_baselines = []
        self.single_rollout_baseline_indices = []  # for coupled rollouts

    def _action_completed(self, idx):
        """Return True when node ``idx`` has all parameters of its action.

        The root (index 0) carries no action and always counts as
        completed.
        """
        if idx == 0:
            return True
        action = self.actions[self.action_idx[idx]]
        return bool(self.num_params[idx] == action.k)

    # --------------------------------------------------------------
    #  child insertion
    # --------------------------------------------------------------
    def _add_child(
        self,
        parent_idx: int,
        action: Action,
        theta: np.ndarray,
        next_state: Tuple[np.ndarray, ...],
        reward: float,
    ) -> int:
        """Append a new node below ``parent_idx`` and return its index.

        Parameters
        ----------
        parent_idx : int
            Index of the parent node.
        action : Action
            Action attached to the new node; must be an element of
            ``self.actions``.
        theta
            Value of the parameter fixed by the new node.  Ignored when
            the parent's action is already complete, because the new node
            then starts a fresh action with no parameters set.
        next_state
            State stored for the new node; also passed to ``term_pred``.
        reward : float
            Reward of the edge parent -> child.
        """
        idx = self.size
        self.size += 1
        _ensure_capacity(
            (
                (self.parent, -1),
                (self.first_child, -1),
                (self.next_sib, -1),
                (self.visits, 0),
                (self.value, 0.0),
                (self.best_value, -np.inf),
                (self.edge_reward, 0.0),
                (self.action_idx, -1),
                (self.theta, 0.0),
                (self.terminal, False),
                (self.child_count, 0),
                (self.depth, 0),
                (self.num_params, 0),
                (self.coupled, 0),
            ),
            self.size,
        )
        # ensure children list long enough
        if idx >= len(self.children):
            self.children.extend([] for _ in range(idx - len(self.children) + 1))

        # sibling prepended (legacy pointers)
        self.next_sib[idx] = self.first_child[parent_idx]
        self.first_child[parent_idx] = idx

        # append to children cache
        self.children[parent_idx].append(idx)
        self.child_count[parent_idx] += 1

        self.parent[idx] = parent_idx
        self.edge_reward[idx] = reward
        self.action_idx[idx] = self.actions.index(action)

        # update depth and the per-depth node histogram
        node_depth = int(self.depth[parent_idx]) + 1
        self.depth[idx] = node_depth
        if node_depth >= len(self.nodes_each_depth):
            self.nodes_each_depth.append(1)
        else:
            self.nodes_each_depth[node_depth] += 1

        # update num_params / theta / coupled
        if self._action_completed(parent_idx):
            # The new node opens a fresh action: no parameter fixed yet.
            self.num_params[idx] = 0
            if self.actions[self.action_idx[idx]].num_matrices == 2:
                self.coupled[idx] = 1
        else:
            # The new node fixes the next parameter of the parent's action.
            self.num_params[idx] = self.num_params[parent_idx] + 1
            self.theta[idx] = self.theta[parent_idx].copy()
            self.theta[idx, self.num_params[idx] - 1] = theta
            self.coupled[idx] = self.coupled[parent_idx]

        self.visits[idx] = 0
        self.value[idx] = 0.0
        self.states.append(next_state)
        self.terminal[idx] = self.term_pred(next_state)
        self.best_value[idx] = -np.inf
        return idx


# ------------------------------------------------------------------
#  Selection (progressive widening) – returns path ready for expansion/rollout
# ------------------------------------------------------------------
def select_pw(
    tree: Tree,
    root: int = 0,
    c_ucb: float = 1.4,
):
    """Walk down from ``root`` until a node to expand (or a terminal) is hit.

    At every node the walk stops when

    * the node is terminal, or
    * its action is incomplete and it has fewer children than the
      progressive-widening limit ``max(1, int(C_pw * visits ** alpha_pw))``,
      or
    * its action is complete and not every discrete action has a child yet
      (``len(actions) - 1`` candidates for coupled nodes,
      ``len(single_actions)`` for decoupled ones).

    Otherwise the child with the highest UCB score is followed.

    Returns
    -------
    path : list[int]
        Node indices from ``root`` to the node where the walk stopped.
    edge_rewards : list[float]
        Edge reward of every step taken (one entry fewer than ``path``).
    """
    trail: List[int] = [root]
    rewards: List[float] = []
    cur = root
    while not tree.terminal[cur]:
        widen_cap = max(1, int(tree.C_pw * (tree.visits[cur] ** tree.alpha_pw)))
        kids = tree.children[cur]
        n_kids = tree.child_count[cur]

        if tree._action_completed(cur):
            # discrete choice: one child per admissible action
            mode = tree.coupled[cur]
            has_room = (mode == 1 and n_kids < len(tree.actions) - 1) or (
                mode == 0 and n_kids < len(tree.single_actions)
            )
        else:
            # continuous parameter: progressive widening budget
            has_room = n_kids < widen_cap
        if has_room:
            break  # budget left → expand here

        # descend into the existing child with the best UCB score
        log_parent = math.log(tree.visits[cur] + 1)
        kid_visits = tree.visits[kids]
        scores = _ucb(tree.best_value[kids] * kid_visits, kid_visits, log_parent, c_ucb)
        chosen = kids[np.argmax(scores)]
        trail.append(chosen)
        rewards.append(tree.edge_reward[chosen])
        cur = chosen

    return trail, rewards


# ------------------------------------------------------------------
#  Expansion – add ONE θ‑child
# ------------------------------------------------------------------


def expand(tree: Tree, parent_idx: int, EXPLORE_K: int):
    """Add one child below ``parent_idx`` and return ``(child_idx, reward)``.

    * If the parent's action is complete, the child selects the first
      discrete action that has no child yet.  For a non-root parent the
      completed action is applied to the parent's state and the edge
      reward is ``-completed_action.time``; for the root the state is
      carried over and the reward is 0.
    * Otherwise the child fixes the next parameter of the parent's action.
      While the parent has fewer than ``EXPLORE_K`` children the value is
      drawn with ``action.sample_theta``; afterwards it is jittered around
      the value of the most visited child with ``action.sample_jitter``.
      The state is carried over and the edge reward is 0.
    """
    if not tree._action_completed(parent_idx):
        # progressive widening on the continuous parameters
        pending = tree.actions[tree.action_idx[parent_idx]]
        siblings = np.asarray(tree.children[parent_idx])
        if len(siblings) >= EXPLORE_K:
            leader = siblings[np.argmax(tree.visits[siblings])]
            leader_theta = tree.theta[leader, tree.num_params[leader] - 1]
            theta = pending.sample_jitter(
                tree.num_params[leader],
                mean=leader_theta,
                # jitter scale shrinks as the parent gains children
                stddev_scale=1.0 / math.log(tree.child_count[parent_idx] + 2),
                rng=tree.rng,
            )
        else:
            # too few children so far → sample a fresh random theta
            theta = pending.sample_theta(tree.num_params[parent_idx] + 1, tree.rng)

        new_idx = tree._add_child(
            parent_idx, pending, theta, tree.states[parent_idx], 0
        )
        return new_idx, 0

    # parent action is complete → open the next untried discrete action
    tried = [tree.actions[tree.action_idx[kid]] for kid in tree.children[parent_idx]]
    if tree.coupled[parent_idx] == 0:
        untried = [cand for cand in tree.single_actions if cand not in tried]
    else:
        untried = [
            cand
            for cand in tree.actions
            if (cand not in tried and cand != tree.couplingaction)
        ]
    chosen: Action = untried[0]

    if parent_idx == 0:
        # the root has no action to apply
        next_spectra = tree.states[0]
        reward = 0
    else:
        finished = tree.actions[tree.action_idx[parent_idx]]
        reward = -finished.time
        next_spectra = finished.spectral_iteration(
            ActionInput(
                current_spectra=tree.states[parent_idx],
                a_spectrum=tree.state0[0],
                theta=tree.theta[parent_idx],
            )
        )

    new_idx = tree._add_child(parent_idx, chosen, 0, next_spectra, reward)
    return new_idx, reward


# ------------------------------------------------------------------
#  Roll‑out
# ------------------------------------------------------------------
def rollout(
    tree: Tree,
    node: int,
    maxiters=300,
):
    """Estimate the reward reachable from ``node`` by replaying a baseline.

    The action attached to ``node`` is applied first.  Parameters the node
    has not fixed yet are filled from the defaults reported by
    ``action.baseline_spectral_iteration``.  A randomly chosen entry of
    ``tree.rollout_baselines`` is then replayed step by step until
    ``tree.term_pred`` holds, the baseline is exhausted, or ``maxiters``
    baseline steps have been executed.  If the node's action has
    ``num_matrices == 1`` while the baseline starts with a
    ``num_matrices == 2`` action, ``tree.couplingaction`` is replayed
    first.

    Parameters
    ----------
    tree : Tree
        Search tree providing states, actions, baselines and the RNG.
    node : int
        Index of the node to roll out from.
    maxiters : int, default 300
        Maximum number of baseline steps.

    Returns
    -------
    reward : float
        Sum of ``-action.time`` over all applied actions, or ``-1e30``
        when ``maxiters`` was reached.
    rollout_alg_theta : list[tuple[Action, Any]]
        The baseline steps that were executed, each paired with the
        parameters reported by ``baseline_spectral_iteration``.  Always
        returned, so callers can unpack the result as a 2-tuple on every
        path (the failure path previously returned a bare float).
    """
    reward = 0
    k = 0
    cur = tree.states[node]

    # only used for the diagnostic printed when the rollout does not converge
    init_ = np.abs(cur[0].copy())

    # NOTE(review): for the root node action_idx is -1, so this picks
    # tree.actions[-1]; confirm that callers never roll out from node 0 or
    # that this is intended.
    action = tree.actions[tree.action_idx[node]]
    theta0 = tree.theta[node].copy()

    # fill the parameters this node has not fixed yet with baseline defaults
    inp = ActionInput(
        current_spectra=tree.states[node],
        a_spectrum=tree.state0[0],
        theta=[],
    )
    _, default_params = action.baseline_spectral_iteration(inp)
    for slot in range(tree.num_params[node], action.k):
        theta0[slot] = default_params[slot]

    # apply the node's own action
    reward += -action.time
    inp = ActionInput(
        current_spectra=tree.states[node],
        a_spectrum=tree.state0[0],
        theta=theta0,
    )
    cur = action.spectral_iteration(inp)

    rollout_alg_idx = tree.rng.choice(tree.rollout_baseline_indices)
    # copy into a list: baselines are typed as Sequence, and the stored
    # baseline must not be aliased by the list built below
    rollout_alg = list(tree.rollout_baselines[rollout_alg_idx])

    if action.num_matrices == 1 and rollout_alg[0][0].num_matrices == 2:
        rollout_alg = [(tree.couplingaction, [])] + rollout_alg

    rollout_alg_theta = []
    for step_action, _ in rollout_alg:
        if tree.term_pred(cur):
            return reward, rollout_alg_theta

        inp = ActionInput(
            current_spectra=cur,
            a_spectrum=tree.state0[0],
            theta=[],
        )
        cur, thetar = step_action.baseline_spectral_iteration(inp)
        reward += -step_action.time
        rollout_alg_theta.append((step_action, thetar))
        k += 1
        if k >= maxiters:
            print("Rollout failed to converge.")
            print(
                "Condition # of Init spectrum that failed:",
                init_.max() / init_.min(),
            )
            return -1e30, rollout_alg_theta

    return reward, rollout_alg_theta


# ------------------------------------------------------------------
#  Back‑prop with edge costs
# ------------------------------------------------------------------
def backprop(tree: Tree, path: Sequence[int], edge_rewards: Sequence[float], r: float, rollout_alg: Sequence[Tuple[Action, np.ndarray]]):
    """Propagate a simulation result along ``path``.

    Each node on the path receives the reward-to-go from that node: the
    sum of the edge rewards below it on the path plus the rollout reward
    ``r``.  Visit counts, summed values and per-node best values are
    updated.  When the root's reward-to-go beats the root's best value so
    far, ``path`` and ``rollout_alg`` are stored on the tree as the new
    best path.

    Returns
    -------
    The reward-to-go of the first node on ``path``.
    """
    # suffix sums of the edge rewards: entry i is the cost from path[i] on
    tail_sums = np.cumsum(edge_rewards[::-1])
    returns = np.concatenate(([0.0], tail_sums))[::-1] + r

    tree.visits[path] += 1
    tree.value[path] += returns

    root_return = returns[0]
    if tree.best_value[0] < root_return:
        print(returns)
        tree.best_path = path
        tree.rollout_alg = rollout_alg

    previous_best = tree.best_value[path]
    tree.best_value[path] = np.maximum(previous_best, returns)
    return root_return


# Populate the tree with a particular algorithm
# given by a sequence of actions (could be a baseline for instance)
# run through these nodes once and backprop
def populate_tree(
    tree: Tree,
    algorithm: Sequence[Tuple[Action, np.ndarray]],
):
    """Insert a given algorithm into the tree as one chain of nodes.

    For every ``(action, theta)`` pair an action node is added, followed
    by one node per entry of ``theta``.  Insertion stops early when a
    terminal node is reached.  The last inserted node is then rolled out
    once and the result is back-propagated along the inserted chain.

    Parameters
    ----------
    tree : Tree
        The tree to populate.
    algorithm : Sequence[Tuple[Action, np.ndarray]]
        The sequence of actions (with their parameters) to insert; this
        could be a baseline, for instance.

    Returns
    -------
    r_roll : float
        The reward of the rollout.
    cumulative_reward : float
        The cumulative reward of the algorithm + rollout.
    """
    visited: List[int] = [0]
    costs: List[float] = []

    def _record(child: int, cost) -> int:
        # remember the new node and the reward of the edge leading to it
        visited.append(child)
        costs.append(cost)
        return child

    cur = 0
    for act, params in algorithm:
        if cur == 0:
            # first action: hangs directly below the root, state unchanged
            cur = _record(tree._add_child(cur, act, 0, tree.states[0], 0), 0)
        elif tree._action_completed(cur):
            # apply the finished action of the current node, then open `act`
            finished = tree.actions[tree.action_idx[cur]]
            step_reward = -finished.time
            after = finished.spectral_iteration(
                ActionInput(
                    current_spectra=tree.states[cur],
                    a_spectrum=tree.state0[0],
                    theta=tree.theta[cur],
                )
            )
            cur = _record(
                tree._add_child(cur, act, 0, after, step_reward),
                -finished.time,
            )

        # one node per parameter value; state and reward are unchanged
        for value in params:
            cur = _record(tree._add_child(cur, act, value, tree.states[cur], 0), 0)

        if tree.terminal[cur]:
            break

    # rollout the last node
    r_roll, rollout_alg = rollout(tree, cur)
    cumulative_reward = backprop(tree, visited, costs, r_roll, rollout_alg)

    return r_roll, cumulative_reward


# %%
def add_rollout_baseline_to_tree(
    tree: Tree,
    algorithm: Sequence[Tuple[Action, np.ndarray]],
):
    """Register ``algorithm`` as a rollout baseline on ``tree``.

    The algorithm is added to ``tree.rollout_baselines`` unless the name
    of its first action contains ``"visser"``.  Independently, it is added
    to ``tree.single_rollout_baselines`` when its first action has
    ``num_matrices == 1``.  The matching ``*_indices`` list receives the
    position of the new entry.
    """
    lead = algorithm[0][0]

    if "visser" not in lead.name:
        # position the new baseline will occupy
        tree.rollout_baseline_indices.append(len(tree.rollout_baselines))
        tree.rollout_baselines.append(algorithm)

    if lead.num_matrices == 1:
        tree.single_rollout_baseline_indices.append(
            len(tree.single_rollout_baselines)
        )
        tree.single_rollout_baselines.append(algorithm)

# ------------------------------------------------------------------
#  Driver
# ------------------------------------------------------------------
def run_mcts(
    tree: Tree,
    c_ucb: float = 1.4,
    budget: int = 1000,
    print_every: int = 100,
    max_termination_count: int = 10,
    termination_epsilon: float = 1e-4,
    EXPLORE_K: int = 5,
    _run=DummyRun(),
):
    """Run the search on ``tree`` and extract the resulting policies.

    Simulations are executed in batches of ``print_every``.  One
    simulation is: select a node, expand it (unless it is terminal), roll
    out from the new child and back-propagate.  Before every batch except
    the first, the averages of the previous batch are printed and logged
    through ``_run.log_scalar``.  The search stops early once the batch
    average of the cumulative reward has changed by less than
    ``termination_epsilon`` more than ``max_termination_count`` times in
    a row.

    Parameters
    ----------
    tree : Tree
        Tree to search; modified in place.
    c_ucb : float
        Exploration constant passed to ``select_pw``.
    budget : int
        Simulation budget; determines the number of batches.
    print_every : int
        Batch size, i.e. number of simulations between reports.
    max_termination_count : int
        Number of consecutive stalled reports tolerated before stopping.
    termination_epsilon : float
        Threshold below which two consecutive batch averages count as equal.
    EXPLORE_K : int
        Passed to ``expand``.
    _run
        Object providing ``log_scalar(name, value[, step])``.

    Returns
    -------
    tuple
        ``(policy_by_visits, policy_by_value, policy_by_best_value,
        visits_reward, value_reward, best_value_reward)`` as produced by
        ``extract_policy`` for the three criteria.
    """
    # NOTE(review): with the defaults (budget=1000, print_every=100) this
    # evaluates to 9 batches, i.e. 900 simulations — confirm this is the
    # intended reading of `budget`.
    n_batches = (budget - print_every + 1) // print_every

    path = None
    batch_cumulative = 0.0
    mean_cumulative = 0.0
    batch_rollout = 0.0
    stalled_reports = 0

    for outer in richerator(
        range(n_batches),
        description="MCTS",
        refresh_per_second=2,
    ):
        it = outer * print_every
        print(f"=== Iteration {it} ===")
        print(tree.size, " nodes in tree")

        if path:
            # report the batch that has just finished
            previous_mean = mean_cumulative
            mean_cumulative = batch_cumulative / print_every
            _run.log_scalar("avg_cumulative_reward", mean_cumulative, it)

            mean_rollout = batch_rollout / print_every
            _run.log_scalar("avg_rollout_reward", mean_rollout, it)

            print(f"Path: {path}")
            print(f"Average cumulative reward:\t {mean_cumulative}")
            print(f"Average rollout reward:\t\t {mean_rollout}")
            print(f"Termination count: {stalled_reports}")
            print(f"Best value of root node: {tree.best_value[0]}")
            print(f"Best path: {tree.best_path}")

            # stop once the batch average has stalled for too many reports
            stalled = abs(mean_cumulative - previous_mean) < termination_epsilon
            if not stalled:
                stalled_reports = 0
            else:
                stalled_reports += 1
                if stalled_reports > max_termination_count:
                    print("Terminating search due to no improvement.")
                    break

        batch_cumulative = 0.0
        batch_rollout = 0.0
        for _ in range(print_every):
            path, edge_rewards = select_pw(tree, c_ucb=c_ucb)
            leaf = path[-1]

            if tree.terminal[leaf]:
                # nothing to expand; terminal bonus optional
                r_roll = 0.0
                batch_cumulative += backprop(tree, path, edge_rewards, r_roll, [])
            else:
                child, reward = expand(tree, leaf, EXPLORE_K)
                path.append(child)
                edge_rewards.append(reward)
                r_roll, rollout_alg = rollout(tree, child)

                batch_rollout += r_roll
                batch_cumulative += backprop(
                    tree, path, edge_rewards, r_roll, rollout_alg
                )

    # extract policy
    by_visits, visits_reward = extract_policy(tree, root=0, criterion="visits")
    by_value, value_reward = extract_policy(tree, root=0, criterion="value")
    by_best_value, best_value_reward = extract_policy(
        tree, root=0, criterion="best_value"
    )

    _run.log_scalar("BestValue", tree.best_value[0])
    return (
        by_visits,
        by_value,
        by_best_value,
        visits_reward,
        value_reward,
        best_value_reward,
    )


def extract_policy(tree: Tree, root: int = 0, criterion: str = "visits"):
    """Return the *sequence* of (Action, θ) that constitutes the found policy.

    For ``"visits"`` and ``"value"`` the tree is walked greedily from
    *root*, picking **one** child at each level according to *criterion*,
    until a terminal node or a node without children is reached.  For
    ``"best_value"`` the policy is reconstructed from ``tree.best_path``
    and ``tree.rollout_alg`` (the best simulation recorded during
    back-propagation) without walking the tree.

    Parameters
    ----------
    tree : Tree
        Finished search tree.
    root : int, default 0
        Index of the starting node (typically the root).  Ignored for
        ``"best_value"``.
    criterion : {"visits", "value", "best_value"}, default "visits"
        Metric for choosing the best child at each step.  Any other value
        stops the walk at the first node that has children.

    Returns
    -------
    algo : list[tuple[Action, np.ndarray]]
        Ordered list of *(action_family, θ)* pairs that constitute the
        discovered iteration algorithm.
    total_est_reward : float
        Sum of the edge rewards along the walked path (plus a rollout
        reward when the walk hits a non-terminal node without children).
        Always ``0`` for ``"best_value"``.
    """
    if criterion == "best_value":
        best_path = tree.best_path
        path_actions = []
        for node in best_path:
            if node != 0 and tree._action_completed(node):
                act = tree.actions[tree.action_idx[node]]
                path_actions.append((act, tree.theta[node, : act.k].copy()))

        # An empty best_path means no simulation was ever recorded; there
        # is then no last node to complete.
        if best_path and not tree._action_completed(best_path[-1]):
            # The last node fixed only part of its parameters: fill the
            # rest with the baseline defaults.
            node = best_path[-1]
            action = tree.actions[tree.action_idx[node]]
            theta0 = tree.theta[node].copy()
            inp = ActionInput(
                current_spectra=tree.states[node],
                a_spectrum=tree.state0[0],
                theta=[],
            )
            _, default_params = action.baseline_spectral_iteration(inp)
            for slot in range(tree.num_params[node], action.k):
                theta0[slot] = default_params[slot]
            path_actions.append((action, theta0))
        return path_actions + list(tree.rollout_alg), 0

    path_actions = []
    node = root
    total_est_reward = 0.0
    while not tree.terminal[node]:
        print(f"Node {node} is not terminal. Continue.")
        if tree.first_child[node] == -1:
            print("No children found. Stop.")
            total_est_reward += rollout(tree, node)[0]
            break  # dead-end; no terminal reachable

        kids = tree.children[node]
        if criterion == "visits":
            scores = tree.visits[kids]
        elif criterion == "value":
            kid_visits = tree.visits[kids]
            # Mean value per child.  Unvisited children are ranked last;
            # a plain 0/0 would give NaN, which argmax treats as maximal.
            scores = np.where(
                kid_visits > 0,
                tree.value[kids] / np.maximum(kid_visits, 1),
                -np.inf,
            )
        else:
            break  # safety: unknown criterion
        best = kids[int(np.argmax(scores))]

        # A completed non-root node contributes its action with the
        # parameters accumulated along its chain.
        if node != 0 and tree._action_completed(node):
            act = tree.actions[tree.action_idx[node]]
            path_actions.append((act, tree.theta[node, : act.k].copy()))

        total_est_reward += tree.edge_reward[best]
        node = best
        if tree.terminal[node]:
            print(f"Terminal node {node} reached. Stop.")

    return path_actions, total_est_reward


if __name__ == "__main__":
    # Demo: search for a sign-function iteration on a random spectrum and
    # print the policies selected by visit count and by mean value.
    #
    # NOTE(review): this script appears to be out of sync with the
    # definitions in this file and probably does not run as written:
    #   * Tree declares `single_actions` without a default, but it is not
    #     passed below;
    #   * Tree.__post_init__ calls `term_pred(state)` with one argument,
    #     while the lambda below takes two;
    #   * run_mcts has no `a_spectrum` parameter and returns six values,
    #     while four are unpacked below;
    #   * no rollout baseline is registered before the search starts,
    #     although rollout() draws from tree.rollout_baseline_indices.
    # Confirm and update against the current API.
    import torch  # NOTE(review): not referenced in this block
    from make_algorithm.actions import ACTIONS, estimate_relative_times

    spec0 = np.random.randn(1000)
    # sort by name so that the action order is deterministic
    actions = sorted(
        [
            ACTIONS["sign_ns"],
            ACTIONS["sign_newton"],
            ACTIONS["sign_newton_variant"],
        ],
        key=lambda a: a.name,
    )
    # presumably measures and stores each action's relative `time`, which
    # the search uses as (negative) reward — verify in make_algorithm.actions
    estimate_relative_times(
        actions,
        size=spec0.shape[0],
        repeats=10,
        device="cuda",
    )
    # a state counts as terminal once the loss drops below this threshold
    epsilon = 1e-4

    tree = Tree(
        capacity=2048,
        state0=spec0,
        term_pred=lambda x, y: LOSSES[SIGN](x, y) < epsilon,
        alpha_pw=0.3,
        actions=actions,
    )

    root_policy_visits, root_policy_value, visits_reward, value_reward = run_mcts(
        tree,
        a_spectrum=spec0,
        budget=int(1e6),
        print_every=int(1e3),
    )

    print("=== RESULT ===")
    print(f"By Visits: estimated reward: {visits_reward}")
    for act, theta in root_policy_visits:
        print(act.name, theta)

    print(f"By Value: estimated reward: {value_reward}")
    for act, theta in root_policy_value:
        print(act.name, theta)
    print("=== END ===")
