import re
import os
import glob
import csv
import math
import statistics




def parse_filename(fn):
    """Extract run metadata from a cliff-result CSV filename.

    Expected basename form::

        id_rcliff_{data}-{modality}-{stage}_n{n_samples}_..._totalthreshold{threshold}.csv

    Parameters
    ----------
    fn : str
        Path to the file; only the basename is inspected.

    Returns
    -------
    dict or None
        A dict with keys ``'data'``, ``'modality'``, ``'stage'``, ``'n'`` and
        ``'threshold'``. ``n`` is left as a digit string; ``threshold`` is
        converted to float. Returns None when the basename does not match the
        pattern or the threshold text is not a valid number.
    """
    base = os.path.basename(fn)
    # fullmatch: the whole basename must match, so e.g. '....csv.bak' is rejected.
    m = re.fullmatch(
        r'id_rcliff_(?P<data>[^-]+)-(?P<modality>[^-]+)-(?P<stage>[^_]+)_n(?P<n>\d+).*_totalthreshold(?P<threshold>[0-9\.eE+-]+)\.csv',
        base,
    )
    if m is None:
        return None
    info = m.groupdict()
    try:
        # The character class above also admits strings like 'e', '-' or
        # '1.2.3' that float() rejects; treat those as non-matching names
        # instead of letting ValueError abort the caller's scan.
        info['threshold'] = float(info['threshold'])
    except ValueError:
        return None
    return info


def _read_last_row(fn):
    """Return the last data row of CSV file *fn* as a dict (None if it has no data rows)."""
    with open(fn, 'r', newline='') as fh:
        last_row = None
        for last_row in csv.DictReader(fh):
            pass  # only the final row is kept
        return last_row


def _parse_outcome(row):
    """Return ``row['latent_dim']`` as a float, or None if missing, blank or non-numeric."""
    if row is None:
        return None
    raw = row.get('latent_dim')
    # DictReader fills missing trailing fields with None (its default restval),
    # so guard against None as well as an empty cell.
    if raw is None or not raw.strip():
        return None
    try:
        return float(raw)
    except ValueError:
        return None


def collect_results(results_dir='03_results/reports/cliff'):
    """Collect one record per cliff-result CSV found in *results_dir*.

    Files matching ``id_rcliff_*.csv`` are visited in sorted order. Files
    whose names ``parse_filename`` rejects are skipped silently; files that
    cannot be read are reported on stdout and skipped.

    Parameters
    ----------
    results_dir : str
        Directory to scan. The default is a relative path, so it is resolved
        against the current working directory.

    Returns
    -------
    list of dict
        Each dict has keys ``'file'``, ``'data'``, ``'modality'``, ``'stage'``,
        ``'n'`` (int), ``'threshold'`` (float) and ``'outcome'``. The outcome
        is the ``latent_dim`` value of the file's last row as a float, or None
        when the file has no rows or the value is missing/non-numeric.
    """
    pattern = os.path.join(results_dir, 'id_rcliff_*.csv')
    records = []
    for fn in sorted(glob.glob(pattern)):
        info = parse_filename(fn)
        if info is None:
            continue
        try:
            last_row = _read_last_row(fn)
        except (OSError, UnicodeDecodeError, csv.Error) as e:
            # Best effort: one unreadable file should not abort the whole scan.
            print(f"Failed to read {fn}: {e}")
            continue

        records.append({
            'file': fn,
            'data': info['data'],
            'modality': info['modality'],
            'stage': info['stage'],
            'n': int(info['n']),
            'threshold': info['threshold'],
            'outcome': _parse_outcome(last_row),
        })

    return records


def summarize(records, group_by=('modality', 'threshold')):
    """Group *records* and report mean ± SEM of their ``'outcome'`` values.

    Parameters
    ----------
    records : iterable of dict
        Each record must contain every key named in *group_by* plus
        ``'outcome'`` (a number, or None when no outcome could be read).
    group_by : sequence of str, optional
        Record keys to group on. The default, ``('modality', 'threshold')``,
        reproduces the original grouping. A single string is treated as a
        one-element sequence.

    Returns
    -------
    list of dict
        One dict per group, sorted by group key. Each dict holds the group-by
        fields plus ``'mean'``, ``'sem'`` and ``'count'``:

        - ``count`` is the number of non-None outcomes.
        - ``mean`` is None when count is 0.
        - ``sem`` is None when count < 2, because the sample standard
          deviation (and hence the SEM) is undefined for fewer than two
          values; reporting 0.0 would falsely suggest zero uncertainty.
    """
    if isinstance(group_by, str):
        group_by = (group_by,)
    group_by = tuple(group_by)

    groups = {}
    for rec in records:
        key = tuple(rec[field] for field in group_by)
        groups.setdefault(key, []).append(rec['outcome'])

    out = []
    for key in sorted(groups):
        vals = [v for v in groups[key] if v is not None]
        count = len(vals)
        mean = statistics.mean(vals) if count else None
        sem = statistics.stdev(vals) / math.sqrt(count) if count >= 2 else None
        out.append({**dict(zip(group_by, key)), 'mean': mean, 'sem': sem, 'count': count})
    return out


def main():
    """Scan the cliff-results directory and print a mean ± SEM summary table.

    The directory is resolved relative to this file (two directory levels up,
    then ``03_results/reports/cliff``), so the result does not depend on the
    current working directory.
    """
    results_dir = os.path.abspath(
        os.path.join(os.path.dirname(__file__), '..', '..', '03_results', 'reports', 'cliff')
    )
    print(f"Scanning {results_dir} for cliff results...")
    records = collect_results(results_dir)
    if not records:
        print('No results found.')
        return
    summary = summarize(records)
    print('\nSummary (mean \u00b1 SEM of outcome):')
    # The threshold column is 9 wide so the 9-character header lines up with the values.
    print(f"{'modality':<10} {'threshold':>9} {'mean':>8} {'sem':>8} {'n':>4}")
    for row in summary:
        mean_s = f"{row['mean']:.3f}" if row['mean'] is not None else '   -  '
        sem_s = f"{row['sem']:.3f}" if row['sem'] is not None else '   -  '
        # 'g' instead of '.3f': with '.3f', small thresholds such as 1e-4 and
        # 2e-4 would both print as 0.000 and be indistinguishable.
        print(f"{row['modality']:<10} {row['threshold']:>9g} {mean_s:>8} {sem_s:>8} {row['count']:4d}")


if __name__ == "__main__":
    # Only print the summary when run as a script; importing stays side-effect free.
    main()
