import numpy as np
import os
import sys
from scipy import linalg, dot
import json

# Number of neighbours kept per query in a trn->trn run.
TOP_K = 50


def image_id(example_path):
    """Return the stem of an example path: basename, whitespace-stripped, no extension."""
    return os.path.splitext(os.path.basename(example_path).strip())[0]


def parse_meta(meta_path):
    """Read a split file and map each image id to its categories.

    Each non-blank line is expected to look like ``<img_id>__<cat>``. A line
    that does not split into exactly two ``__`` fields is recorded under its
    first field with category ``"unknown"``. Categories are de-duplicated per
    image, keeping first-seen order.

    Returns:
        dict mapping image id (str) -> list of category strings.
    """
    meta = {}
    with open(meta_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines would otherwise be recorded as an empty image id.
                continue
            parts = line.split('__')
            if len(parts) == 2:
                img_id, cat = parts
            else:
                img_id, cat = parts[0], "unknown"
            cats = meta.setdefault(img_id, [])
            if cat not in cats:
                cats.append(cat)
    return meta


def load_features(path):
    """Load the ``examples`` and ``features`` arrays from an ``.npz`` file.

    The file handle is closed before returning. Features are cast to float32
    and flattened to one row per example.

    Returns:
        (examples, features): a Python list and a 2-D float32 array.

    Raises:
        OSError (e.g. FileNotFoundError) if the file cannot be opened.
    """
    with np.load(path) as data:
        examples = data["examples"].tolist()
        features = data["features"].astype(np.float32)
    return examples, features.reshape(features.shape[0], -1)


def cosine_similarity(a, b, eps=1e-12):
    """Return the ``(len(a), len(b))`` matrix of row-wise cosine similarities.

    Row norms are clamped to ``eps``. An all-zero row therefore scores 0
    instead of NaN. A NaN would sort as the *most* similar entry once the
    ranking is reversed.
    """
    a_norm = np.maximum(np.linalg.norm(a, axis=1, keepdims=True), eps)
    b_norm = np.maximum(np.linalg.norm(b, axis=1, keepdims=True), eps)
    return (a / a_norm) @ (b / b_norm).T


def rank_neighbours(source_examples, source_features, source_meta,
                    target_examples, target_features, target_meta,
                    trn_set, top_k=TOP_K):
    """Rank same-category target images for every ``<img_id>__<cat>`` source key.

    Images missing from a meta dict are treated as category ``"unknown"``.
    Candidates are ordered by descending cosine similarity; a stable sort keeps
    ties in target order. Duplicate keys are removed.

    Args:
        trn_set: when True (source and target are the same trn split), the
            query key is removed wherever it appears and the list is cut to
            ``top_k``. When False, the full ranked list is kept.
            NOTE(review): that matches the original script even though the
            output file is named "top_50" — confirm it is intended.
        top_k: minimum number of candidates required per key.

    Returns:
        dict mapping source key -> list of target keys.

    Raises:
        ValueError: if a key has fewer than ``top_k`` usable candidates.
    """
    target_ids = [image_id(ex) for ex in target_examples]
    n_targets = len(target_ids)

    # Build one boolean mask per category, once, instead of re-deriving each
    # target's categories inside the per-query loop.
    cat_masks = {}
    for col, tid in enumerate(target_ids):
        for cat in target_meta.get(tid, ["unknown"]):
            cat_masks.setdefault(cat, np.zeros(n_targets, dtype=bool))[col] = True
    no_match = np.zeros(n_targets, dtype=bool)

    similarity = cosine_similarity(source_features, target_features)

    result = {}
    for example, row in zip(source_examples, similarity):
        img_id = image_id(example)
        order = np.argsort(-row, kind="stable")  # most similar first
        for cat in source_meta.get(img_id, ["unknown"]):
            key = f"{img_id}__{cat}"
            ranked = order[cat_masks.get(cat, no_match)[order]]
            candidates = list(dict.fromkeys(f"{target_ids[j]}__{cat}" for j in ranked))
            if trn_set:
                # Drop the query wherever it lands; ties or duplicate images
                # can push it off position 0.
                candidates = [c for c in candidates if c != key]
            if len(candidates) < top_k:
                raise ValueError(
                    f"only {len(candidates)} candidates share category '{cat}' "
                    f"with {key}; need at least {top_k}"
                )
            result[key] = candidates[:top_k] if trn_set else candidates
    return result


def main(argv=None):
    """Compute per-fold same-category neighbour lists and save them as JSON.

    ``argv`` defaults to ``sys.argv[1:]`` and must contain
    ``FEATURES_NAME SOURCE_SPLIT TARGET_SPLIT``. Each
    ``folder{k}.npz`` (k = 0..3) that exists in both feature directories is
    processed. The result is written next to the source features as
    ``folder{k}_top_50-similarity.json``.
    """
    argv = sys.argv[1:] if argv is None else argv
    if len(argv) < 3:
        sys.exit(f"usage: {os.path.basename(sys.argv[0])} "
                 "FEATURES_NAME SOURCE_SPLIT TARGET_SPLIT")
    features_name, source_split, target_split = argv[:3]

    print(f"Processing {features_name} ...", flush=True)

    source_features_dir = f"./pascal-5i/VOC2012/{features_name}_{source_split}"
    target_features_dir = f"./pascal-5i/VOC2012/{features_name}_{target_split}"
    trn_set = source_split == 'trn' and target_split == 'trn'

    print("source_features_dir:", source_features_dir)
    print("target_features_dir:", target_features_dir)

    meta_root_source = f"./evaluate/splits/pascal/{source_split}"
    meta_root_target = f"./evaluate/splits/pascal/{target_split}"

    for foldid in range(4):
        feature_file = f"folder{foldid}.npz"
        print(f"Processing {feature_file} ...", flush=True)
        source_path = os.path.join(source_features_dir, feature_file)
        target_path = os.path.join(target_features_dir, feature_file)
        print("source_path:", source_path)
        try:
            source_examples, source_features = load_features(source_path)
            target_examples, target_features = load_features(target_path)
        except OSError as err:
            # A fold may legitimately be absent. Skip it, but report why.
            print(f"no folder {feature_file} ... ({err})", flush=True)
            continue

        print('source_features shape: ', source_features.shape)
        print('target_features shape: ', target_features.shape)

        source_meta = parse_meta(os.path.join(meta_root_source, f'fold{foldid}.txt'))
        target_meta = parse_meta(os.path.join(meta_root_target, f'fold{foldid}.txt'))

        neighbours = rank_neighbours(
            source_examples, source_features, source_meta,
            target_examples, target_features, target_meta,
            trn_set,
        )

        output_path = os.path.join(source_features_dir, f"folder{foldid}_top_50-similarity.json")
        with open(output_path, "w") as outfile:
            json.dump(neighbours, outfile)
        print(f"Saved results to {output_path}")


if __name__ == "__main__":
    main()