from dataclasses import dataclass, field
from typing import Dict, Optional, List
from pyserini.output_writer import get_output_writer, OutputFormat
from GMR_reranker import GMRInferenceModel
from jina_reranker import JinaRerankerInferenceModel
from Mono_Qwen2vl_reranker import MonoQwenInferenceModel
from Qwen3_reranker import Qwen3RerankerInferenceModel
import numpy as np
import torch
import os


@dataclass
class SearchResult:
    """A single scored hit: a document id paired with its relevance score."""

    # Identifier of the candidate document.
    docid: str
    # Score assigned to the document; results are ordered by this value.
    score: float


class Ranker(object):
    """Rerank candidate documents for a batch of queries with a pluggable model.

    The concrete inference backend is chosen from ``model_name_or_path`` and
    stored on ``self.reranker``; it is expected to expose
    ``process(pairs, batch_size=...)`` and ``stop()``.
    """

    def __init__(self, model_name_or_path, inference_type='yes_or_no', max_length=3200):
        """Instantiate the backend matching ``model_name_or_path``.

        The name is matched case-insensitively, in this order: ``'jina'``,
        ``'mono'``, ``'qwen3'``; anything else falls back to the GMR model.

        Args:
            model_name_or_path: Model identifier or path, forwarded to the backend.
            inference_type: Forwarded to the MonoQwen and GMR backends only.
            max_length: Forwarded to the Jina, MonoQwen and GMR backends; the
                Qwen3 backend is built from the model name alone.
        """
        lowered_name = model_name_or_path.lower()
        if 'jina' in lowered_name:
            self.reranker = JinaRerankerInferenceModel(model_name_or_path, max_length=max_length)
        elif 'mono' in lowered_name:
            self.reranker = MonoQwenInferenceModel(
                model_name_or_path, inference_type=inference_type, max_length=max_length)
        elif 'qwen3' in lowered_name:
            self.reranker = Qwen3RerankerInferenceModel(model_name_or_path)
        else:
            self.reranker = GMRInferenceModel(
                model_name_or_path, inference_type=inference_type, max_length=max_length)

    def stop(self):
        """Delegate to the backend's ``stop()``."""
        self.reranker.stop()

    def rank(
        self,
        qids: List[str],
        queries: List[dict],
        dids_list: List[List[str]],
        docs_list: List[List[dict]],
        batch_size: int = 32,
        topk: int = 100,
        result_save_path: Optional[str] = None,
        skip_sameid=True,
    ):
        """Score the first ``topk`` candidates of every query and sort them.

        Args:
            qids: One id per query.
            queries: One dict per query; ``'text'`` and ``'image'`` keys are read.
            dids_list: For each query, the candidate document ids, aligned
                position-by-position with ``docs_list``.
            docs_list: For each query, the candidate documents as dicts with
                ``'text'`` and ``'image'`` keys.
            batch_size: Forwarded to the backend's ``process``.
            topk: Only the first ``topk`` candidates of each query are scored;
                also used as ``max_hits`` when results are saved.
            result_save_path: When not ``None``, results are also written to
                this path through ``save_result``.
            skip_sameid: Forwarded to ``save_result``.

        Returns:
            A list of ``(qid, hits)`` tuples in input order, where ``hits`` is a
            list of ``SearchResult`` sorted by descending score.

        Raises:
            AssertionError: If the four per-query inputs differ in length.
        """
        assert len(qids) == len(queries) == len(dids_list) == len(docs_list), (
            'qids, queries, dids_list and docs_list must have the same length, got '
            f'{len(qids)}, {len(queries)}, {len(dids_list)}, {len(docs_list)}'
        )

        pairs = []
        # Number of candidates actually scored per query. Taken from the very
        # slice that is sent to the model so the score offsets below can never
        # drift from the pairs, whatever value ``topk`` has.
        kept_counts = []
        for query, docs in zip(queries, docs_list):
            kept_docs = docs[:topk]
            kept_counts.append(len(kept_docs))
            pairs.extend(
                (query['text'], query['image'], doc['text'], doc['image'])
                for doc in kept_docs
            )

        # Nothing to score: do not invoke the backend on an empty batch.
        flat_scores = self.reranker.process(pairs, batch_size=batch_size) if pairs else []

        search_results = []
        offset = 0
        for qid, dids, count in zip(qids, dids_list, kept_counts):
            query_scores = flat_scores[offset:offset + count]
            offset += count
            # zip stops at the shorter input, so ids beyond the scored
            # candidates are dropped.
            hits = [SearchResult(did, score) for did, score in zip(dids, query_scores)]
            hits.sort(key=lambda hit: hit.score, reverse=True)
            search_results.append((qid, hits))

        if result_save_path is not None:
            self.save_result(search_results, result_save_path, qids, topk, skip_sameid=skip_sameid)
        return search_results

    def save_result(self, search_results, result_save_path: str, qids: list, topk: int = 100, skip_sameid=True):
        """Write ranked results to ``result_save_path`` using pyserini's TREC writer.

        Args:
            search_results: Iterable of ``(topic, hits)`` pairs as returned by ``rank``.
            result_save_path: Destination file path, opened in ``'w'`` mode by the writer.
            qids: Topic ids handed to the writer as ``topics``.
            topk: Passed to the writer as ``max_hits``.
            skip_sameid: When true, hits whose ``docid`` equals the topic id are
                removed before writing.
        """
        output_writer = get_output_writer(result_save_path, OutputFormat(OutputFormat.TREC.value), 'w',
                                          max_hits=topk, tag='Faiss', topics=qids,
                                          use_max_passage=False,
                                          max_passage_delimiter='#',
                                          max_passage_hits=1000)
        with output_writer:
            for topic, hits in search_results:
                if skip_sameid:
                    # Drop the hit that is the query itself.
                    hits = [hit for hit in hits if hit.docid != topic]
                output_writer.write(topic, hits)
    
