import json
import math
from pathlib import Path
from typing import List, Dict, Any, Set, Tuple

# =========================
# 1. Config
# =========================


# Path of the JSON file that is passed to load_json() when the script runs.
INPUT_PATH = "agenttrek-with-importance-diversity-score.json"

# How compute_T0() chooses the per-segment selection size:
#   "fixed"      -> always T0_FIXED
#   "percentage" -> ceil(segment_length * T0_PERCENTAGE), but at least 1
# Any other value makes compute_T0() raise ValueError.
T0_MODE = "fixed"
# Selection size used when T0_MODE == "fixed"; also embedded in the output file name.
T0_FIXED = 3
# Fraction of the segment length used when T0_MODE == "percentage".
T0_PERCENTAGE = 0.25

LAMBDA = 1.0    # weight on diversity term D


# =========================
# 2. Load JSON
# =========================

def load_json(path: str) -> List[Dict[str, Any]]:
    """Read the UTF-8 encoded JSON file at *path* and return the parsed content."""
    raw_text = Path(path).read_text(encoding="utf-8")
    return json.loads(raw_text)

# Load every record once up front; `items` is what the main loop further down iterates over.
items = load_json(INPUT_PATH)
print(f"Loaded {len(items)} goal records")


# =========================
# 3. Helper: build D(i,j) from sims
# =========================

def make_D_from_sims(
    sims_states: List[List[float]],
    sims_responses: List[List[float]],
):
    """
    Build the pairwise distance function D(i, j) from two similarity matrices.

    Both arguments are indexed as matrix[i][j]. The returned closure yields
    0.0 when i == j (the stored diagonal is never read) and otherwise the
    larger of the two dissimilarities:
        max(1 - sims_states[i][j], 1 - sims_responses[i][j]).
    """
    def D(i: int, j: int) -> float:
        if i != j:
            state_gap = 1.0 - sims_states[i][j]
            response_gap = 1.0 - sims_responses[i][j]
            return max(state_gap, response_gap)
        # An element is at distance zero from itself.
        return 0.0

    return D


def compute_T0(segment_length: int) -> int:
    """
    Return the selection size T0 for a segment of the given length.

    In "fixed" mode the length is ignored and T0_FIXED is returned. In
    "percentage" mode the result is ceil(segment_length * T0_PERCENTAGE),
    never less than 1. Any other T0_MODE raises ValueError.
    """
    if T0_MODE not in ("fixed", "percentage"):
        raise ValueError(f"Unsupported T0_MODE={T0_MODE}")

    if T0_MODE == "fixed":
        return T0_FIXED

    scaled = math.ceil(segment_length * T0_PERCENTAGE)
    return scaled if scaled > 1 else 1


def selected_indices_output_path() -> str:
    """
    Return the output file name for the selected indices.

    The name embeds T0_FIXED when T0_MODE is "fixed"; for every other mode it
    embeds T0_PERCENTAGE expressed as a rounded whole-number percentage.
    """
    if T0_MODE != "fixed":
        percent = int(round(T0_PERCENTAGE * 100))
        return f"full_selected_dataset_indices_T0_{percent}pct_agenttrek.json"
    return f"full_selected_dataset_indices_T0_{T0_FIXED}_agenttrek.json"


# =========================
# 4. Greedy algorithm
# =========================

def greedy_select_indices(
    phi: List[float],
    D_func,
    T0: int,
) -> List[int]:
    """
    Greedy algorithm over local indices 0..(N-1).

    - phi[i]       = Φ(i), scalar score for state i
    - D_func(i,j)  = D(s_i, s_j)
    - T0           = target size |A|

    Selection rules:
      * T0 <= 0      -> nothing is selected.
      * N <= T0      -> every index is kept.
      * T0 == 1      -> the single index with the largest Φ (first one on ties);
                        there is no pair, so the diversity term plays no role.
      * otherwise    -> seed with the pair maximising Φ(i) + Φ(j) + λ·D(i,j),
                        then repeatedly add the k maximising
                        Φ(k) + λ·Σ_{i∈A} D(k, i) until |A| == T0.

    Returns:
        sorted list of selected local indices (subset of {0,...,N-1}),
        never longer than max(T0, 0).
    """
    N = len(phi)
    if T0 <= 0:
        # A non-positive budget selects nothing (and must not fall through to
        # the pair seeding below, which would return two indices).
        return []
    if N <= T0:
        # If we don't even have T0 elements, keep all of them.
        return list(range(N))
    if T0 == 1:
        # The pair seeding would overshoot a budget of one element.
        return [max(range(N), key=phi.__getitem__)]

    # ----- 1. Initialization: select the best-scoring pair -----
    # Here N > T0 >= 2, so at least one pair exists.
    best_pair: Tuple[int, int] = (0, 1)
    best_value = float("-inf")

    for i in range(N):
        for j in range(i + 1, N):
            value = phi[i] + phi[j] + LAMBDA * D_func(i, j)
            # Strict '>' keeps the first pair encountered on ties.
            if value > best_value:
                best_value = value
                best_pair = (i, j)

    A: Set[int] = set(best_pair)

    # ----- 2. Greedy growth until |A| == T0 -----
    while len(A) < T0:
        best_k = None
        best_delta = float("-inf")

        for k in range(N):
            if k in A:
                continue
            # Δ_k = Φ(k) + λ * Σ_{i∈A} D(s_k, s_i)
            diversity_sum = sum(D_func(k, i) for i in A)
            delta_k = phi[k] + LAMBDA * diversity_sum

            if delta_k > best_delta:
                best_delta = delta_k
                best_k = k

        if best_k is None:
            # No candidate beat -inf (e.g. all gains are NaN/-inf); stop early.
            break

        A.add(best_k)

    return sorted(A)


# =========================
# 5. Main loop over items & trajectory_groups
# =========================
# Keep selections grouped per trajectory for easier inspection.
full_selected_dataset_indices: List[List[int]] = []

for item_idx, item in enumerate(items):
    # Per-goal arrays: global dataset indices and the Φ score stored at the same position.
    dataset_indices_all: List[int] = item.get("dataset_indices", [])
    phi_all: List[float] = item.get("bert_scores_obs_history_norm", [])

    if not dataset_indices_all or not phi_all:
        continue

    if len(dataset_indices_all) != len(phi_all):
        print(f"[WARN] item {item_idx}: len(dataset_indices) != len(phi)")
    # Map global dataset index -> local position in this goal's arrays
    idx_to_pos = {idx: pos for pos, idx in enumerate(dataset_indices_all)}
    phi_count = len(phi_all)

    # trajectory_groups is expected to be a list of trajectory dicts
    traj_groups = item.get("trajectory_groups", [])
    if not isinstance(traj_groups, list):
        continue

    for tg in traj_groups:
        seg_dataset_indices: List[int] = tg.get("dataset_indices", [])
        sims_states: List[List[float]] = tg.get("sims_states", [])
        sims_responses: List[List[float]] = tg.get("sims_responses", [])

        if not seg_dataset_indices or not sims_states or not sims_responses:
            # Nothing to select from: the record keeps the whole segment, and
            # the segment is not added to full_selected_dataset_indices.
            tg["selected_dataset_indices"] = seg_dataset_indices
            continue

        # Build Φ for this trajectory segment (local order = as given in seg_dataset_indices)
        phi_segment: List[float] = []
        for di in seg_dataset_indices:
            pos = idx_to_pos.get(di)
            if pos is None:
                print(f"[WARN] item {item_idx}: dataset index {di} not in idx_to_pos")
                phi_segment.append(0.0)
            elif pos >= phi_count:
                # phi_all is shorter than dataset_indices (warned above); fall
                # back to a neutral score instead of raising IndexError.
                print(f"[WARN] item {item_idx}: dataset index {di} has no phi score")
                phi_segment.append(0.0)
            else:
                phi_segment.append(phi_all[pos])

        segment_length = len(seg_dataset_indices)
        T0_value = compute_T0(segment_length)
        # Make D function for this segment
        D_func = make_D_from_sims(sims_states, sims_responses)
        # Run greedy selection over local indices 0..(L-1)
        local_selected = greedy_select_indices(phi_segment, D_func, T0=T0_value)

        # Map back to global dataset indices
        selected_dataset_indices = [seg_dataset_indices[i] for i in local_selected]
        full_selected_dataset_indices.append(selected_dataset_indices)
        # Store result on the in-memory record as well
        tg["selected_dataset_indices"] = selected_dataset_indices

        print(
            f"item {item_idx}, traj with dataset_indices={seg_dataset_indices} "
            f"(T0={T0_value}) -> selected {selected_dataset_indices}"
        )


# Only save the full_selected_dataset_indices
selected_indices_path = selected_indices_output_path()
# Write with an explicit UTF-8 encoding: it matches how load_json reads files
# and keeps ensure_ascii=False independent of the platform default encoding.
with open(selected_indices_path, "w", encoding="utf-8") as f:
    json.dump(full_selected_dataset_indices, f, ensure_ascii=False, indent=2)

# Print the total number of elements across all lists in full_selected_dataset_indices
total_length = sum(len(indices) for indices in full_selected_dataset_indices)
print(f"Total length of all selected indices: {total_length}")

