import random
from dataclasses import dataclass
from StateActionTracker import StateActionStack
from utils import hash_observation


@dataclass
class FixCandidates:
    """Record describing one state whose action could be fixed."""

    # State the candidate refers to (passed through unchanged by the creator).
    state: list
    # Action indices considered interchangeable for this state.
    actions: list
    # Round index at which the state was observed.
    round: int
    # Count supplied by the creator; fix_close_prob passes the number of
    # entries in state[1:] that differ from the "lost" marker.
    alive_nodes: int
    # Identifier of the setting the state came from.
    setting_id: int

    def __str__(self):
        """Return a short description showing only the state and actions."""
        summary = f"state={self.state}, actions={self.actions}"
        return f"FixCandidates({summary})"


def find_similar_top_k(prob_vector, threshold, max_diff, dominance_margin):
    """
    Pick the largest run of near-equal probabilities close to the maximum.

    Entries strictly greater than ``threshold`` are kept, then narrowed to
    those within ``dominance_margin`` of the largest kept value and ordered
    by descending value. Each entry in turn acts as an anchor: the entries
    that follow it are added to its run for as long as they stay within
    ``max_diff`` of the anchor. The longest run wins; on equal length the
    run with the higher anchor is kept.

    Args:
        prob_vector (list of float): The input probability vector.
        threshold (float): Only values strictly greater than this are valid.
        max_diff (float): Maximum allowed distance from a run's anchor value.
        dominance_margin (float): Maximum allowed distance from the largest
            valid value.

    Returns:
        tuple of (index, value) pairs ordered by descending value, or None
        when fewer than two values pass the threshold or the best run has
        fewer than two members.
    """
    above_threshold = [(idx, prob) for idx, prob in enumerate(prob_vector) if prob > threshold]
    if len(above_threshold) < 2:
        return None

    top_prob = max(prob for _, prob in above_threshold)

    # Keep only entries close enough to the maximum, highest value first.
    near_top = sorted(
        ((idx, prob) for idx, prob in above_threshold if top_prob - prob <= dominance_margin),
        key=lambda pair: pair[1],
        reverse=True,
    )

    best_run = []
    for start, anchor in enumerate(near_top):
        run = [anchor]
        for follower in near_top[start + 1:]:
            within_reach = abs(follower[1] - anchor[1]) <= max_diff
            if not within_reach:
                # Values only get smaller from here on, so the run is over.
                break
            run.append(follower)
        if len(run) > len(best_run):
            best_run = run

    if len(best_run) < 2:
        return None
    return tuple(best_run)


# Current design: fix actions with close probabilities
def fix_close_prob(
    dfs_tracker: StateActionStack,
    state,
    policy_output,
    round,
    lost_value,
    setting_id,
    *,
    threshold=0.2,
    max_diff=0.1,
    dominance_margin=0.1,
):
    """
    Build a fix candidate when several actions have close probabilities.

    Args:
        dfs_tracker: Not used by this function; kept so existing callers
            keep working.
        state: State the policy output belongs to. Entries of ``state[1:]``
            that differ from ``lost_value`` are counted as alive nodes.
        policy_output: Probability vector handed to ``find_similar_top_k``.
        round: Round index stored on the candidate.
        lost_value: Marker compared (with ``!=``) against the state entries.
        setting_id: Setting identifier stored on the candidate.
        threshold: Minimum probability (exclusive) for an action to count.
            Defaults to the previously hard-coded 0.2.
        max_diff: Allowed spread inside the selected group. Defaults to the
            previously hard-coded 0.1.
        dominance_margin: Allowed distance from the top probability.
            Defaults to the previously hard-coded 0.1.

    Returns:
        A ``FixCandidates`` instance whose actions are the selected indices
        in the order returned by ``find_similar_top_k``, or None when no
        group of similar actions is found.
    """
    similar = find_similar_top_k(
        policy_output,
        threshold=threshold,
        max_diff=max_diff,
        dominance_margin=dominance_margin,
    )
    if similar is None:
        return None

    actions = [index for index, _ in similar]
    alive_nodes = sum(1 for val in state[1:] if val != lost_value)
    return FixCandidates(state, actions, round, alive_nodes, setting_id)


def store_candidates(
    num_round, dfs_tracker, players, states, policy_outputs, actions, lost_value, setting_id, candidates: list[FixCandidates]
):
    """
    Collect fix candidates from every (round, player) slot, in random order.

    Slots whose action is the ``lost_value`` object are skipped. For each
    remaining slot ``fix_close_prob`` is consulted; every candidate it
    returns is printed and appended to ``candidates`` in place.

    :param num_round: number of rounds to scan (first index of the tables)
    :param dfs_tracker: tracker forwarded to ``fix_close_prob``
    :param players: number of players per round (second index of the tables)
    :param states: states indexed as ``states[round][player]``
    :param policy_outputs: policy vectors indexed as ``policy_outputs[round][player]``
    :param actions: actions indexed as ``actions[round][player]``
    :param lost_value: marker identifying slots to skip
    :param setting_id: identifier forwarded to ``fix_close_prob``
    :param candidates: output list, extended in place
    """
    slots = [(r_idx, p_idx) for r_idx in range(num_round) for p_idx in range(players)]
    random.shuffle(slots)

    for r_idx, p_idx in slots:
        # NOTE(review): identity comparison - this only skips the slot when
        # the stored action is the very same object as lost_value; confirm
        # that is intended rather than an equality check.
        if actions[r_idx][p_idx] is lost_value:
            continue

        found = fix_close_prob(
            dfs_tracker,
            states[r_idx][p_idx],
            policy_outputs[r_idx][p_idx],
            r_idx,
            lost_value,
            setting_id,
        )
        if not found:
            continue
        print(f"Found candidate: {found}")
        candidates.append(found)


def select_candidate_to_fix(candidates: list[FixCandidates], dfs_tracker: StateActionStack):
    """
    Push the highest-priority candidate onto the tracker and fix its action.

    ``candidates`` is sorted in place by ``(round, alive_nodes)``, both
    descending, and only the first entry is tried. The entry is not removed
    from the list.

    Returns:
        False when ``candidates`` is empty or ``push_state`` reports failure;
        True after the candidate's first action has been set as the fixing
        action and the state has been recorded with its setting id.
    """
    if not candidates:
        return False

    # Latest round first; more alive nodes breaks ties.
    candidates.sort(key=lambda cand: (cand.round, cand.alive_nodes), reverse=True)
    chosen = candidates[0]

    if not dfs_tracker.push_state(chosen.state, set(chosen.actions)):
        return False

    dfs_tracker.set_fixing_action(chosen.state, chosen.actions[0])
    dfs_tracker.record_state(chosen.state, chosen.setting_id)
    return True


def unfix(num_round, dfs_tracker: StateActionStack, players, states, rewards, actions, lost_value):
    """
    Advance or drop one fixed state that was visited by a losing player.

    A slot qualifies when its action is not the ``lost_value`` object, the
    player's reward is negative, and the tracker reports the state as fixed
    and not locked. The hashed observations of the qualifying states are
    collected; one of them (the only one, or the one chosen by
    ``find_latest_key``) has its current fixing action marked as explored.
    The next unexplored action then becomes the fixing action, or the key is
    popped from the tracker when none is left.

    Args:
        num_round: Number of rounds to scan.
        dfs_tracker: Tracker holding the fixed states.
        players: Number of players per round.
        states: States indexed as ``states[round][player]``.
        rewards: Rewards indexed by player.
        actions: Actions indexed as ``actions[round][player]``.
        lost_value: Marker identifying slots to skip (identity check).

    Returns:
        False when no state qualifies, True otherwise.
    """
    losing_keys = set()
    for r_idx in range(num_round):
        for p_idx in range(players):
            if actions[r_idx][p_idx] is lost_value:
                continue
            if not rewards[p_idx] < 0:
                continue
            observed = states[r_idx][p_idx]
            if not dfs_tracker.is_fixed(observed):
                continue
            if dfs_tracker.is_locked(observed):
                continue
            losing_keys.add(hash_observation(observed))

    # The looked-up actions are only used to detect the "nothing to unfix" case.
    fixed_actions = [dfs_tracker.get_fixing_action(key) for key in losing_keys]

    # IMPORTANT: this condition is hard coded.
    if not fixed_actions:
        return False

    key_list = list(losing_keys)
    if len(key_list) == 1:
        # A single fixed state can be unfixed directly.
        unfix_key = key_list[0]
    else:
        unfix_key = dfs_tracker.find_latest_key(key_list)

    # Mark the action that was being tried as explored before moving on.
    current_action = dfs_tracker.get_fixing_action(unfix_key)
    dfs_tracker.mark_explored(unfix_key, current_action)

    next_action = dfs_tracker.pop_next_unexplored_action(unfix_key)
    if next_action is None:
        dfs_tracker.pop_specific_key(unfix_key)
    else:
        dfs_tracker.set_fixing_action(unfix_key, next_action)
    return True


def fix_specific(dfs_tracker: StateActionStack, state, action):
    """
    Register ``state`` on the tracker with ``action`` as its only option and
    make that action the fixing action.

    The return value of ``push_state`` is ignored; the fixing action is set
    unconditionally.
    """
    only_option = {action}
    dfs_tracker.push_state(state, only_option)
    dfs_tracker.set_fixing_action(state, action)


def find_neighbors(state):
    """
    Placeholder for finding the neighbors of a specific state.

    TODO: not implemented yet; currently does nothing and returns None.
    """
    return None
