import numpy as np
from pprint import pprint
import argparse
import os

def knn_evaluate(
    D: np.ndarray,
    y_train: np.ndarray,
    y_test: np.ndarray,
    k: int = 1,
):
    """Score k-nearest-neighbour classification from a precomputed distance matrix.

    Each test sample is assigned the majority label among the ``k`` training
    samples with the smallest entries in its row of ``D``. Ties between labels
    are broken in favour of the smallest label (``np.bincount(...).argmax()``).

    Args:
        D: Distance/loss matrix of shape ``(len(y_test), len(y_train))``.
            ``D[i, j]`` compares test sample ``i`` with training sample ``j``;
            smaller means closer.
        y_train: Training labels; must be non-negative integers (required by
            ``np.bincount``).
        y_test: Ground-truth labels of the test samples.
        k: Number of neighbours that vote, ``1 <= k <= len(y_train)``.

    Returns:
        ``(acc, per_class_acc)``: the overall accuracy as a float, and a dict
        mapping every class present in ``y_test`` to the accuracy on that class.

    Raises:
        ValueError: If ``D`` does not have shape ``(len(y_test), len(y_train))``
            or ``k`` is outside ``[1, len(y_train)]``.
    """
    D = np.asarray(D)
    y_train = np.asarray(y_train)
    y_test = np.asarray(y_test)

    expected_shape = (len(y_test), len(y_train))
    if D.shape != expected_shape:
        raise ValueError(
            f"D must have shape (len(y_test), len(y_train)) = {expected_shape}, "
            f"got {D.shape}"
        )
    if not 1 <= k <= D.shape[1]:
        raise ValueError(f"k must be in [1, {D.shape[1]}], got {k}")

    # kth=k-1 guarantees columns [:k] hold the k smallest distances (in no
    # particular order); kth=k would raise when k == len(y_train).
    idx = np.argpartition(D, k - 1, axis=1)[:, :k]
    y_pred = np.apply_along_axis(lambda x: np.bincount(x).argmax(), 1, y_train[idx])

    acc = float((y_pred == y_test).mean())

    # Convert numpy scalars to plain Python values so the result prints cleanly.
    per_class_acc = {
        c.item(): float((y_pred[y_test == c] == c).mean())
        for c in np.unique(y_test)
    }
    return acc, per_class_acc

parser = argparse.ArgumentParser(description='Compute pairwise losses with selectable method.')
parser.add_argument('--method', type=str, default='IQDOT', choices=['IQDOT', 'SGW'],
    help='Distance')
parser.add_argument('--point_step', type=int, help='sample points', default=1)
args = parser.parse_args()
point_step = args.point_step

path = os.path.join('losses', f'Loss-{args.method.lower()}-{int(2048 / point_step)}.npy')
loss = np.load(path)

y_train = np.load('modelnet40_7cls_labels_7cls_remap.npy')
y_test = np.load('shapenetpart_7cls_labels_7cls_remap.npy')
print(y_train.shape)
print(y_test.shape)
print(loss.shape)

k = 1
print("="*5+"Modelnet40→ShapeNetPart"+"="*5)
pprint(knn_evaluate(loss, y_train, y_test, k))
print("="*5+"ShapeNetPart→Modelnet40"+"="*5)
pprint(knn_evaluate(loss.T, y_test, y_train, k))