import json
import argparse
from collections import Counter
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import os


def load_jsonl(path):
    """Yield one parsed JSON object per non-blank line of a JSONL file.

    Blank or whitespace-only lines (e.g. a trailing empty line) are skipped
    instead of being handed to ``json.loads``, which would raise on them.

    Args:
        path: Path to a UTF-8 encoded JSON Lines file.

    Yields:
        The decoded JSON value of each non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            yield json.loads(line)


def _first_image_path(item):
    """Return the item's image path, taking the first entry when it is a list.

    Returns None when 'image_path' is absent or is an empty list, so the
    caller can treat the item as having no usable image.
    """
    path = item.get('image_path')
    if isinstance(path, list):
        path = path[0] if path else None
    return path


def _load_shot_pool(test_path):
    """Load the candidate (shot) pool, keeping only items whose image exists.

    Returns a tuple ``(questions, answers, image_paths, evidences)``. All four
    lists are filtered together, so index ``i`` refers to the same test item
    in each of them. ``evidences`` is empty when the first test item has no
    'evidence' field; otherwise each entry is "Title: <wikipedia_title>\\n"
    followed by the evidence text.
    """
    test_data = list(load_jsonl(test_path))
    # Evidence availability is decided from the first record only.
    has_evidence = bool(test_data) and 'evidence' in test_data[0]

    questions, answers, image_paths, evidences = [], [], [], []
    for item in test_data:
        img = _first_image_path(item)
        # `not img` covers None / "" before os.path.exists (which rejects None).
        if not img or not os.path.exists(img):
            continue
        questions.append(item['question'])
        answers.append(item.get('answer', item.get('answer_eval')))
        image_paths.append(img)
        if has_evidence:
            evidences.append(f"Title: {item['wikipedia_title']}\n" + item['evidence'])
    return questions, answers, image_paths, evidences


def _select_diverse(sim_scores, pool_embs, top_k, similarity_threshold):
    """Greedily pick up to ``top_k`` pool indices, most similar first.

    A candidate is skipped when its cosine similarity to any already-selected
    candidate is >= ``similarity_threshold`` (near-duplicate filtering).
    A stable sort keeps tie order deterministic (original pool order).
    """
    chosen = []
    for i in np.argsort(-sim_scores, kind='stable'):
        if len(chosen) >= top_k:
            break
        if chosen:
            max_sim = cosine_similarity([pool_embs[i]], [pool_embs[j] for j in chosen]).max()
            if max_sim >= similarity_threshold:
                continue
        chosen.append(i)
    return chosen


def main(val_path, test_path, output_path, model_name='all-MiniLM-L6-v2', top_k=5, similarity_threshold=0.9):
    """Write retrieved few-shot examples for every validation question.

    The test file is used as the pool of candidate shots. Only test items
    whose image file exists on disk are eligible. For each validation question,
    candidates are ranked by cosine similarity of their SentenceTransformer
    question embeddings. Near-duplicate candidates are dropped before up to
    ``top_k`` are kept.

    Args:
        val_path: JSONL file of query items; each needs a 'question' key.
        test_path: JSONL file of candidate items with 'question',
            'image_path' (string or list), 'answer' or 'answer_eval', and
            optionally 'evidence' + 'wikipedia_title'.
        output_path: Destination JSONL. Line i corresponds to line i of the
            validation file and holds 'questions', 'answers', 'image_paths'
            and 'evidence' lists. 'evidence' is empty when the pool carries
            no evidence.
        model_name: SentenceTransformer model used for the embeddings.
        top_k: Maximum number of shots per validation question.
        similarity_threshold: Candidates at least this similar to an
            already-selected shot are skipped.
    """
    val_data = list(load_jsonl(val_path))
    val_questions = [item['question'] for item in val_data]

    test_questions, test_answers, test_image_paths, test_evidences = _load_shot_pool(test_path)
    if not test_questions:
        print("Warning: no test items with an existing image file; writing empty few-shot lists.")

    # Only embed when both sides are non-empty; cosine_similarity fails on
    # empty inputs. Every validation row is still written so output lines
    # stay aligned with the validation file.
    sims = None
    test_embs = None
    if val_questions and test_questions:
        model = SentenceTransformer(model_name)
        val_embs = model.encode(val_questions, convert_to_tensor=False)
        test_embs = model.encode(test_questions, convert_to_tensor=False)
        sims = cosine_similarity(val_embs, test_embs)

    with open(output_path, 'w', encoding='utf-8') as fout:
        for idx in range(len(val_data)):
            if sims is None:
                chosen = []
            else:
                chosen = _select_diverse(sims[idx], test_embs, top_k, similarity_threshold)
            out = {
                'questions': [test_questions[i] for i in chosen],
                'answers': [test_answers[i] for i in chosen],
                'image_paths': [test_image_paths[i] for i in chosen],
                'evidence': [test_evidences[i] for i in chosen] if test_evidences else [],
            }
            fout.write(json.dumps(out, ensure_ascii=False) + '\n')

    print(f"Few-shot examples for validation saved to {output_path}")


if __name__ == '__main__':
    # The sentence is passed as `description`; passed positionally it became
    # `prog` and replaced the program name in the usage line.
    parser = argparse.ArgumentParser(
        description="Create few-shot examples for validation using test questions as pool"
    )
    parser.add_argument('--val', type=str, required=True, help='Path to validation JSONL file')
    parser.add_argument('--test', type=str, required=True, help='Path to test JSONL file (shots pool)')
    parser.add_argument('--output', type=str, required=True, help='Output JSONL path')
    parser.add_argument('--model', type=str, default='all-MiniLM-L6-v2', help='SentenceTransformer model name')
    # NOTE: these CLI defaults (top_k=4, threshold=0.95) differ from main()'s
    # keyword defaults (top_k=5, similarity_threshold=0.9).
    parser.add_argument('--top_k', type=int, default=4, help='Number of similar examples to retrieve')
    parser.add_argument('--similarity_threshold', type=float, default=0.95, help='Threshold to filter redundant questions')
    args = parser.parse_args()
    main(
        args.val,
        args.test,
        args.output,
        model_name=args.model,
        top_k=args.top_k,
        similarity_threshold=args.similarity_threshold,
    )
