import argparse
import os
import sys

import numpy as np
import torch
from procrustes.permutation import permutation_2sided

def distance_W(a: np.ndarray, b: np.ndarray) -> float:
    """Return the permutation-aligned mismatch between weight matrices *a* and *b*.

    Delegates to ``procrustes.permutation.permutation_2sided`` with
    ``single=True`` (one shared permutation) and the ``"approx-normal1"``
    method. Returns the ``error`` attribute of the resulting result object.
    """
    # NOTE(review): the exact definition of `.error` (Frobenius norm vs. its
    # square) comes from the procrustes library — confirm before comparing
    # values produced with different library versions.
    return permutation_2sided(a, b, single=True, method="approx-normal1").error

def get_W(d, keys=('J', 'recurrent.weight_hh_l0')):
    """Return the recurrent weight tensor stored in checkpoint dict *d*.

    The value of the first key in *keys* that is present (and not ``None``)
    is returned. An explicit ``is None`` test is used instead of ``a or b``
    because evaluating the truth value of a multi-element tensor raises
    ``RuntimeError``.

    Raises:
        KeyError: if none of *keys* is present in *d*.
    """
    for key in keys:
        W = d.get(key)
        if W is not None:
            return W
    raise KeyError(
        f"none of {list(keys)} found in checkpoint (available keys: {list(d)})"
    )


def load_W(fp):
    """Load the checkpoint at *fp* onto the CPU and return its recurrent weight as a NumPy array."""
    # detach() so that .numpy() also works when the stored tensor has requires_grad=True.
    return get_W(torch.load(fp, map_location='cpu')).detach().cpu().numpy()


def main(argv=None):
    """Compare the weights of seed ``seeds[idx]`` with every earlier seed and save the distances.

    Loads ``<model_dir>/<task>/weights/seed_<seed>.pt`` for the target seed and for
    each earlier seed in ``range(start_seed, end_seed)``. It then computes
    ``distance_W`` for each pair and writes the list of distances to
    ``<save_dir>/<task>/<idx>.npy``. Earlier checkpoints that are missing are
    skipped with a warning on stderr.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    p = argparse.ArgumentParser(description="Compute pairwise weight distances")
    p.add_argument('--task',        type=str, required=True, help="Name of the task subfolder under model_dir/weights")
    p.add_argument('--model_dir',   type=str, default='models', help="Root directory containing <task>/weights/seed_*.pt")
    p.add_argument('--save_dir',    type=str, default='RNN-degeneracy/degeneracy/data/WD', help="Where to dump the .npy distance files")
    p.add_argument('--start_seed',  type=int, default=0, help="First seed index")
    p.add_argument('--end_seed',    type=int, default=50, help="One past the last seed index")
    p.add_argument('--idx',         type=int, required=True, help="Index into the seed array to compare against previous seeds")
    args = p.parse_args(argv)

    seeds = np.arange(args.start_seed, args.end_seed)
    # Reject negative indices too: seeds[-k] would pick a late seed while
    # seeds[:-k] would compare it against all but the last k seeds.
    if not 0 <= args.idx < len(seeds):
        p.error(f"--idx must be in [0, {len(seeds)}), got {args.idx}")

    weights_dir = os.path.join(args.model_dir, args.task, 'weights')
    target_seed = seeds[args.idx]
    W_target = load_W(os.path.join(weights_dir, f'seed_{target_seed}.pt'))

    # Compute distances to all earlier seeds
    distances = []
    for prev_seed in seeds[:args.idx]:
        fp = os.path.join(weights_dir, f'seed_{prev_seed}.pt')
        if not os.path.isfile(fp):
            # Best-effort skip, but report it: the saved array then no longer
            # lines up one-to-one with seeds[:idx].
            print(f"Warning: {fp} not found; skipping seed {prev_seed}", file=sys.stderr)
            continue
        distances.append(distance_W(W_target, load_W(fp)))

    # Save
    out_folder = os.path.join(args.save_dir, args.task)
    os.makedirs(out_folder, exist_ok=True)
    out_file = os.path.join(out_folder, f"{args.idx}.npy")
    np.save(out_file, np.asarray(distances, dtype=float))

    print(f"Saved {len(distances)} distances to {out_file}")


if __name__ == "__main__":
    main()
