"""
Check feasible control points for REACH and CARRY phases.

This script analyzes the pre-computed valid control point indices
from stack_blocks_init.npz and reports overlap statistics.
"""

import numpy as np
import os

def main():
    """Print an overlap report of control points valid for REACH and CARRY.

    Loads ``stack_blocks_init.npz`` from the directory containing this script
    and prints, in order: validity counts and percentages, min/max/mean of the
    canonical parameters per category (BOTH / REACH only / CARRY only), the
    index lists, the number of (REACH, CARRY) index combinations, and the
    stored waypoint positions.

    If the init file does not exist, an error message is printed and the
    function returns without raising.

    Raises:
        KeyError: If the archive lacks one of the arrays read below.
    """
    init_file = os.path.join(os.path.dirname(__file__), "stack_blocks_init.npz")

    if not os.path.exists(init_file):
        print(f"ERROR: Init file not found: {init_file}")
        print("Please run stack_block_init.py first.")
        return

    print("=" * 60)
    print("FEASIBLE CONTROL POINTS ANALYSIS")
    print("=" * 60)

    # (label shown in the report, key in the archive)
    waypoints = (
        ("HOME", "home_pos"),
        ("Pregrasp", "pregrasp_pos"),
        ("Grasp", "grasp_pos"),
        ("Lift", "lift_pos"),
        ("Prerelease", "prerelease_pos"),
        ("Release", "release_pos"),
        ("Object", "object_pos"),
        ("Target", "target_pos"),
    )
    wanted = [
        "canonical_params",
        "reach_valid_indices",
        "carry_valid_indices",
        "control_point_radius",
        "jitter_threshold",
    ] + [key for _, key in waypoints]

    # NpzFile keeps the underlying file open until closed; read what we need
    # inside a context manager so the handle is released deterministically.
    with np.load(init_file) as archive:
        data = {key: archive[key] for key in wanted}

    def as_index_list(values):
        # Plain Python ints: they print as "3" rather than "np.int64(3)" on
        # NumPy >= 2, and stay usable as indices even if the array was saved
        # with a float dtype (e.g. an empty ``np.array([])`` placeholder).
        return [int(v) for v in np.asarray(values).ravel().tolist()]

    def as_scalar(value):
        # Accepts both 0-d and single-element arrays.
        return float(np.asarray(value).item())

    canonical_params = data['canonical_params']
    reach_valid = as_index_list(data['reach_valid_indices'])
    carry_valid = as_index_list(data['carry_valid_indices'])

    n_total = len(canonical_params)
    n_reach = len(reach_valid)
    n_carry = len(carry_valid)

    reach_set = set(reach_valid)
    carry_set = set(carry_valid)

    both_set = reach_set & carry_set
    reach_only_set = reach_set - carry_set
    carry_only_set = carry_set - reach_set
    neither_set = set(range(n_total)) - reach_set - carry_set

    def pct(count):
        # An empty grid would otherwise raise ZeroDivisionError.
        return 100 * count / n_total if n_total else 0.0

    print(f"\nCanonical control points: {n_total} (5x5x5 grid)")
    print("  - (angle, dist_frac, pos_frac) parameterization")
    print(f"  - Control point radius: {as_scalar(data['control_point_radius']):.3f}m")
    print(f"  - Jitter threshold: {as_scalar(data['jitter_threshold']):.3f}")

    print("\n--- Validity Summary ---")
    summary = (
        ("Valid for REACH:", n_reach),
        ("Valid for CARRY:", n_carry),
        ("Valid for BOTH:", len(both_set)),
        ("REACH only:", len(reach_only_set)),
        ("CARRY only:", len(carry_only_set)),
        ("Neither:", len(neither_set)),
    )
    for label, count in summary:
        print(f"  {label:<18}{count:3d} / {n_total} ({pct(count):.1f}%)")

    # Analyze the canonical params for each category
    print("\n--- Parameter Analysis ---")

    def analyze_params(indices, name):
        """Print min/max/mean of each parameter column for ``indices``."""
        if not indices:
            print(f"  {name}: (none)")
            return
        # Sorted so the row order (and thus the float mean) is deterministic.
        params = canonical_params[sorted(indices)]
        print(f"  {name} ({len(indices)} CPs):")
        columns = (("angle:", 0), ("dist_frac:", 1), ("pos_frac:", 2))
        for label, col in columns:
            values = params[:, col]
            print(
                f"    {label:<10} min={values.min():.2f}, "
                f"max={values.max():.2f}, mean={values.mean():.2f}"
            )

    analyze_params(both_set, "BOTH")
    analyze_params(reach_only_set, "REACH only")
    analyze_params(carry_only_set, "CARRY only")

    # Print actual indices
    print("\n--- Valid Indices ---")
    print(f"  BOTH ({len(both_set)}): {sorted(both_set)}")
    if reach_only_set:
        print(f"  REACH only ({len(reach_only_set)}): {sorted(reach_only_set)}")
    if carry_only_set:
        print(f"  CARRY only ({len(carry_only_set)}): {sorted(carry_only_set)}")

    # Combination analysis for independent sampling
    print("\n--- Independent Sampling Analysis ---")
    print("  If REACH and CARRY use SAME cp_idx:")
    print(f"    Available combinations: {len(both_set)}")
    print("  If REACH and CARRY use DIFFERENT cp_idx:")
    print(f"    Available combinations: {n_reach} x {n_carry} = {n_reach * n_carry}")

    # Show waypoint positions (first three components of each stored array)
    print("\n--- Waypoint Positions ---")
    for label, key in waypoints:
        coords = np.asarray(data[key]).ravel()[:3]
        formatted = ", ".join(f"{float(c):.4f}" for c in coords)
        print(f"  {label + ':':<11} [{formatted}]")

    print("\n" + "=" * 60)
    print("ANALYSIS COMPLETE")
    print("=" * 60)


# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
