#!/usr/bin/env python3
"""
Analyze molecule movement CSV.

Computes:
- percentage of "inside_corridor"
- mean of "movement_travelled"
- sum of "reached" and "reached_orientation_at_goal"
- sum of "crashed"
- prints the first non-null value of "final_orientations" (expected to be a single value)
- for "final_distances": parse values (including list-like strings),
  look at all values > 0.3 and report:
    - total count
    - mean of those values
- row counts per identical (goal_x, goal_y) pair
"""

import argparse
import ast
import re

import pandas as pd


def to_numeric_series(series, extra_boolish=False):
    """
    Return ``series`` coerced to numbers, with unparseable entries as NaN.

    When ``extra_boolish`` is true, text such as "true"/"t"/"yes"/"y"/"1"
    becomes 1 and "false"/"f"/"no"/"n"/"0" becomes 0 before the numeric
    conversion. Matching is case-insensitive and ignores surrounding
    whitespace. The dtype of the result is whatever ``pd.to_numeric`` infers.
    """
    working = series.copy()
    if not extra_boolish:
        return pd.to_numeric(working, errors="coerce")

    lookup = dict.fromkeys(("true", "t", "yes", "y", "1"), 1)
    lookup.update(dict.fromkeys(("false", "f", "no", "n", "0"), 0))

    def normalise(cell):
        # Missing cells pass through untouched so they coerce to NaN below.
        if pd.isna(cell):
            return cell
        # Unrecognised values are returned as-is for pd.to_numeric to judge.
        return lookup.get(str(cell).strip().lower(), cell)

    return pd.to_numeric(working.map(normalise), errors="coerce")


def _iter_floats(obj):
    """Yield every float-convertible leaf of a (possibly nested) list/tuple."""
    if isinstance(obj, (list, tuple)):
        for item in obj:
            yield from _iter_floats(item)
        return
    try:
        number = float(obj)
    except (TypeError, ValueError, OverflowError):
        # Non-numeric leaves (strings, dicts, ...) are skipped, not fatal.
        return
    yield number


def parse_distance_cell(val):
    """
    Parse one cell from 'final_distances' into a flat list of floats.

    Handles:
    - plain numbers: "0.42"
    - Python-literal lists/tuples, including nested ones:
      "[0.1, 0.5]", "(0.2, 0.4)", "[[0.1, 0.2], [0.3]]"
    - separated values: "0.1,0.5", "0.1;0.5", "0.1 0.5"
    - NumPy-style array text without commas: "[0.1 0.5]",
      "[[0.1 0.2]\\n [0.3 0.4]]"
    Tokens that are not numbers are skipped. Returns [] if nothing valid
    (including NaN/None cells and empty strings).
    """
    if pd.isna(val):
        return []

    s = str(val).strip()
    if not s:
        return []

    # Try a Python literal first: [0.1, 0.5], (0.2, 0.4), 0.3, etc.
    # ast.literal_eval only evaluates literals, never arbitrary code.
    try:
        parsed = ast.literal_eval(s)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        parsed = None

    if isinstance(parsed, (int, float, list, tuple)):
        return list(_iter_floats(parsed))

    # Fallback: split on delimiters AND brackets/parentheses. Splitting on
    # brackets is what lets NumPy's space-separated repr ("[0.1 0.5]") parse.
    # Plain whitespace splitting would leave tokens like "[0.1" that float()
    # rejects. A string with no separators yields itself as the single token.
    values = []
    for token in re.split(r"[,;\s\[\]()]+", s):
        if not token:
            continue
        try:
            values.append(float(token))
        except ValueError:
            continue
    return values


def _report_numeric(df, col, label, describe, *, boolish=False):
    """
    Print one summary line for numeric column ``col`` of ``df``.

    ``describe`` receives the column as converted by ``to_numeric_series``
    and returns the text printed after ``"<label>: "``. A missing column is
    reported under the raw column name. A column without any parseable value
    is reported under ``label``.
    """
    if col not in df.columns:
        print(f"{col}: column not found")
        return
    values = to_numeric_series(df[col], extra_boolish=boolish)
    if values.notna().any():
        print(f"{label}: {describe(values)}")
    else:
        print(f"{label}: no usable numeric data")


def _report_final_orientation(df):
    """Print the first non-null value of the 'final_orientations' column."""
    if "final_orientations" not in df.columns:
        print("final_orientations: column not found")
        return
    non_null = df["final_orientations"].dropna()
    if non_null.empty:
        print("final_orientations: no value (all NaN)")
    else:
        # The column is expected to hold a single value; only the first is shown.
        print(f"wrong oriented: {non_null.iloc[0]}")


def _report_final_distances(df, threshold):
    """Print count and mean of all parsed 'final_distances' values > threshold."""
    if "final_distances" not in df.columns:
        print("final_distances: column not found")
        return

    # A single cell may hold several distances; flatten them into one list.
    all_vals = [
        v for cell in df["final_distances"] for v in parse_distance_cell(cell)
    ]
    if not all_vals:
        print("final_distances: no numeric values parsed")
        return

    # NaN compares False, so NaN distances never count as exceeding threshold.
    over = [v for v in all_vals if v > threshold]
    if not over:
        print(f"wrong positioned: no values > {threshold}")
        return
    print(f"final_distances > {threshold}:")
    print(f"  total count = {len(over)}")
    print(f"  mean        = {sum(over) / len(over):.6f}")


def _report_goal_counts(df, min_rows, show_table):
    """Summarize row counts per (goal_x, goal_y) pair; optionally print the table."""
    missing = [c for c in ("goal_x", "goal_y") if c not in df.columns]
    if missing:
        print(f"(goal_x, goal_y) counts: missing columns: {', '.join(missing)}")
        return

    # groupby drops rows whose goal_x/goal_y is NaN (pandas default dropna=True).
    goal_counts = (
        df.groupby(["goal_x", "goal_y"])
          .size()
          .reset_index(name="row_count")
          .sort_values(["goal_x", "goal_y"])
    )
    if show_table:
        print("\nRow counts per (goal_x, goal_y):")
        print(goal_counts.to_string(index=False))

    num_big_pairs = int((goal_counts["row_count"] >= min_rows).sum())
    print(
        f"(goal_x, goal_y) pairs: {len(goal_counts)} distinct, "
        f"{num_big_pairs} with >= {min_rows} rows"
    )


def main(path, *, distance_threshold=0.3, min_goal_rows=1000,
         show_goal_counts=False):
    """
    Read the CSV at ``path`` and print summary statistics to stdout.

    Parameters
    ----------
    path : str or path-like
        CSV file handed to ``pd.read_csv``.
    distance_threshold : float, optional
        'final_distances' values strictly greater than this are counted and
        averaged (default 0.3).
    min_goal_rows : int, optional
        Report how many (goal_x, goal_y) pairs have at least this many rows
        (default 1000).
    show_goal_counts : bool, optional
        Also print the full per-(goal_x, goal_y) row-count table
        (default False).
    """
    df = pd.read_csv(path)

    print(f"File: {path}")
    print("=" * 60)

    # 1) percentage of "inside_corridor". Unparseable cells become NaN and
    #    are excluded from the mean (pandas skipna).
    _report_numeric(
        df, "inside_corridor", "inside_corridor",
        lambda v: f"{v.mean() * 100.0:.2f}%",
        boolish=True,
    )

    # 2) mean "movement_travelled"
    _report_numeric(
        df, "movement_travelled", "movement_travelled",
        lambda v: f"mean = {v.mean():.6f}",
    )

    # 3) sums of "reached" / "reached_orientation_at_goal" under readable labels
    #    4) sum of "crashed"
    counted = (
        ("reached", "reached target position"),
        ("reached_orientation_at_goal", "reached orientation at target position"),
        ("crashed", "crashed"),
    )
    for col, label in counted:
        _report_numeric(
            df, col, label, lambda v: f"sum = {v.sum():.0f}", boolish=True
        )

    # 5) final_orientations: print the (expected single) value
    _report_final_orientation(df)

    # 6) final_distances: count and mean of values above the threshold
    _report_final_distances(df, distance_threshold)

    # 7/8) row counts per (goal_x, goal_y) and number of well-populated pairs
    _report_goal_counts(df, min_goal_rows, show_goal_counts)


if __name__ == "__main__":
    # Command-line entry point: a single positional CSV path.
    cli = argparse.ArgumentParser(description="Analyze molecule movement CSV.")
    cli.add_argument("csv_path", help="Path to CSV file to analyze")
    main(cli.parse_args().csv_path)
