from pathlib import Path
import json
import traceback
import numpy as np
from transformers import AutoModel, AutoTokenizer
from add_embeddings_to_documents import get_chunked_embeddings
from tqdm.auto import tqdm
from argparse import ArgumentParser
import logging

# Configure root logging at import time so the INFO-level summaries below are emitted.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)


class multi_vector_index:
    """Brute-force dot-product index over every chunk embedding of a set of documents.

    Each document may contribute several vectors (one per entry of its
    ``'code_embedding'`` list). A document's relevance to a query is the best
    score obtained by any of its vectors.
    """

    def __init__(self, documents):
        """Build the index from ``documents``.

        Args:
            documents: Iterable of dicts, each with an ``'id'`` and a
                ``'code_embedding'`` list holding one embedding per chunk.
                An embedding stored as a single-element list (``[[v0, v1, ...]]``)
                is unwrapped to ``[v0, v1, ...]``.
        """
        self.documents = documents
        self.index = []    # one entry per chunk embedding, converted to an array below
        self.idx2doc = []  # idx2doc[i] is the id of the document that owns index row i
        for doc in self.documents:
            for emb in doc['code_embedding']:
                # Drop one extra level of nesting around a single embedding.
                if len(emb) == 1:
                    emb = emb[0]
                self.index.append(emb)
                self.idx2doc.append(doc['id'])
        # NOTE(review): assumes every chunk embedding has the same dimension.
        self.index = np.array(self.index)

    def search(self, input_embs, k):
        """Return the top ``k`` distinct documents for the query embedding(s).

        Args:
            input_embs: Query embedding(s), array-like of shape ``(d,)``,
                ``(n, d)`` or ``(n, 1, d)``. With several query vectors, each
                chunk is scored by its best dot product over all of them.
            k: Maximum number of distinct documents to return.

        Returns:
            Up to ``k`` dicts ``{'docid': ..., 'score': float}`` in descending
            score order. Each document appears at most once, carrying the score
            of its best-matching chunk. Fewer than ``k`` are returned only when
            the index holds fewer than ``k`` distinct documents.
        """
        if k <= 0 or not self.idx2doc:
            return []
        queries = np.asarray(input_embs)
        if queries.size == 0:
            return []
        if queries.ndim == 3:
            queries = queries.squeeze(1)
        # Treat a single (d,) vector as a batch of one query.
        queries = np.atleast_2d(queries)
        # scores[q, c] is the dot product of query q with chunk c.
        scores = queries @ self.index.T
        chunk_best = scores.max(axis=0)
        # Descending by score; stable sort keeps ties in index order (deterministic).
        order = np.argsort(-chunk_best, kind='stable')
        results = []
        seen_docs = set()
        # Keep scanning past duplicate chunks of the same document until k
        # distinct documents are collected.
        for chunk in order:
            docid = self.idx2doc[chunk]
            if docid in seen_docs:
                continue
            seen_docs.add(docid)
            # Plain float so the hit is JSON-serializable.
            results.append({'docid': docid, 'score': float(chunk_best[chunk])})
            if len(results) == k:
                break
        return results


def _read_jsonl(path):
    """Parse a JSON Lines file into a list with one decoded object per line."""
    with open(path, 'r', encoding='utf-8') as fh:
        return [json.loads(line) for line in fh]


def main(instances_files, indexes_dir, output_dir, encoder, top_k=20):
    """Run retrieval for every instance in each instances file.

    For each instance, this:

    - loads the index at
      ``<indexes_dir>/<repo with '/' replaced by '__'>/<base_commit>/documents.jsonl``;
    - embeds the instance's ``problem_statement`` with ``get_chunked_embeddings``;
    - writes one JSON line ``{"instance_id": ..., "hits": [...]}`` to
      ``<output_dir>/<instances file stem>.retrieval.jsonl``.

    Args:
        instances_files: Paths to JSON Lines files whose records carry
            ``instance_id``, ``repo``, ``base_commit`` and ``problem_statement``.
        indexes_dir: Root directory holding the per-repo/per-commit indexes.
        output_dir: Directory for the retrieval results; created if missing.
        encoder: Model name or path given to ``AutoModel``/``AutoTokenizer``.
        top_k: Maximum number of documents requested from the index per instance.
    """
    model = AutoModel.from_pretrained(encoder)
    tokenizer = AutoTokenizer.from_pretrained(encoder)
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    for instances_file in instances_files:
        # Keyed by instance_id: a repeated id keeps only its last record.
        instances = {instance['instance_id']: instance for instance in _read_jsonl(instances_file)}
        output_file = Path(output_dir, Path(instances_file).stem + '.retrieval.jsonl')
        missing_indexes = []
        failed_indexes = []
        problem_files = []
        with open(output_file, 'w', encoding='utf-8') as f:
            for instance_id, instance in tqdm(
                instances.items(), total=len(instances), desc=f"Retrieving for {instances_file}"
            ):
                # Per-instance boundary: one bad instance must not abort the whole run.
                try:
                    documents_path = Path(
                        indexes_dir,
                        instance['repo'].replace('/', '__'),
                        instance['base_commit'],
                        'documents.jsonl',
                    )
                    if not documents_path.exists():
                        missing_indexes.append(instance_id)
                        continue
                    try:
                        documents = _read_jsonl(documents_path)
                    except Exception:
                        logger.exception("Could not read index %s", documents_path)
                        failed_indexes.append(instance_id)
                        problem_files.append(documents_path.as_posix())
                        continue
                    index = multi_vector_index(documents)
                    issue_embs = get_chunked_embeddings(instance['problem_statement'], model, tokenizer)
                    hits = index.search(issue_embs, k=top_k)
                    results = {'instance_id': instance_id, 'hits': hits}
                    # Flush each line so partial results survive an interrupted run.
                    print(json.dumps(results), file=f, flush=True)
                except Exception:
                    logger.exception("Retrieval failed for instance %s", instance_id)
                    failed_indexes.append(instance_id)
        missing = '\n\t'.join(missing_indexes)
        failed = '\n\t'.join(failed_indexes)
        logger.info(f"Missing indexes for {len(missing_indexes)} instances.\n{missing}")
        logger.info(f"Failed indexes for {len(failed_indexes)} instances.\n{failed}")
        logger.info(f"Saved retrieval results to {output_file}")
        # Unreadable index files, space-separated on stdout.
        print(' '.join(problem_files))


if __name__ == "__main__":
    # Command-line entry point; every option is a required string.
    cli = ArgumentParser()
    cli.add_argument(
        "--instances_files", type=str, nargs='+', required=True, help="File containing instances"
    )
    for flag, description in (
        ("--indexes_dir", "Directory where indexes are stored"),
        ("--output_dir", "Directory where retrieval results are stored"),
        ("--encoder", "Encoder to use"),
    ):
        cli.add_argument(flag, type=str, required=True, help=description)
    main(**vars(cli.parse_args()))
