import concurrent.futures
import hashlib
import logging
import os
import pickle
import re
import time
import warnings
from multiprocessing import Pool

import numpy as np
import torch
from omegaconf import OmegaConf

from colbert import Indexer, Searcher
from colbert.infra import ColBERTConfig, Run, RunConfig
from colbert.modeling.colbert import AugmentationMixin

from tqdm import tqdm

# from torch_scatter import scatter

import faiss

from multiset import Multiset as multiset

from src.cmuvera import FdeLateInteractionModel
from src.dataloader import get_dataloader
from src.embedder import ColBERTEmbedder
from src.greedymethods import BaseE2E
from src.utils import partial_chamfer_sim, partial_chamfer_sim_batched_with_rerank, rowwise_union_padded, save

logger = logging.getLogger(__name__)


##############################################################################
# Base class that uses colbert/other MV engine for searching
class MVBaseE2E(BaseE2E):
    """Base class for pipelines that use ColBERT (or another multi-vector engine) for search.

    Shared ColBERT plumbing lives here: the experiment name under which indices
    are stored, ``_init_searcher`` (opens the searcher(s)) and ``index`` (builds
    the index/indices). Subclasses implement ``search`` and ``run``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.dataloader = get_dataloader(self.config.data)
        self.embedder = ColBERTEmbedder(config.embedder)
        assert config.embedder.mv_type in ["colbertv2-base","colbertv2-plaid","xtr-base","colbertv2-plaid-fresh_hp"]
        # ``mv_type`` becomes part of the experiment directory. "/dblnorm" is only
        # used for augmented runs with double normalisation (norm applied twice
        # during augmentation); everything else uses "/norm", the default ColBERT
        # behaviour of normalising embeddings before the inner product.
        if self.config.augment and self.config.dbl_norm:
            self.mv_type = config.embedder.mv_type + "/dblnorm"
        else:
            self.mv_type = config.embedder.mv_type + "/norm"

    def _experiment_name(self):
        """Return the ColBERT ``Run`` experiment name shared by indexing and searching."""
        return f"{self.config.data.dataset_name}/{self.config.embedder.type}/{self.mv_type}"

    def _rh_file(self, i):
        """Return the path of the RH file used by augmentation ``i``."""
        return f"./experiments/{self.config.data.dataset_name}/RH.{self.config.embedder.emb_dim}.{i}.pt"

    def _init_searcher(self):
        """Open the ColBERT ``Searcher``(s) used by ``search``.

        Sets ``self.searcher`` to one of:
          * a single ``Searcher`` over the ``noaug`` index when ``config.augment`` is off;
          * a list with one ``Searcher`` per RH augmentation otherwise;
          * an empty list when ``config.parallelization == "proc"``, because in that
            mode the searchers are built inside the worker processes
            (``search_worker_for_mp``).
        """
        # If the index is generated on server A and copied to server B, the
        # configuration saved with the index on A is loaded on B alongside the
        # current config. One example is the searcher's ``collection`` property,
        # which is taken from the index if left unspecified here (mainly a problem
        # for the LoTTE dataloader), so it is passed explicitly.
        corpus_tsv_filename, _ = self.dataloader.get_tsv()

        if not self.config.augment:  # single ColBERT index
            with Run().context(RunConfig(nranks=1, experiment=self._experiment_name())):
                colbert_config = ColBERTConfig(
                    nbits=self.config.colbert.nbits,
                    root="./colbert_beir_expts/",
                    augment=self.config.augment,
                    dim=self.config.embedder.emb_dim,
                )
                self.searcher = Searcher(
                    checkpoint="ColBERT/colbertv2.0",
                    index=f"nbits={self.config.colbert.nbits}.noaug",
                    collection=corpus_tsv_filename,
                    config=colbert_config,
                )
            return

        self.searcher = []
        if self.config.parallelization == "proc":
            return

        for i in range(self.config.num_rh_augment):
            with Run().context(RunConfig(nranks=1, experiment=self._experiment_name())):
                colbert_config = ColBERTConfig(
                    nbits=self.config.colbert.nbits,
                    root="./colbert_beir_expts/",
                    augment=self.config.augment,
                    # Augmented embeddings are twice the base embedding dimension.
                    dim=2 * self.config.embedder.emb_dim,
                    dbl_norm=self.config.dbl_norm,
                    RH_file=self._rh_file(i),
                    generate_new_rh=False,  # the RH file is supposed to be generated by ``index``
                )
                logger.debug(f"Opening searcher for experiment {self._experiment_name()}, RH {i}")
                self.searcher.append(Searcher(
                    checkpoint="ColBERT/colbertv2.0",
                    index=f"nbits={self.config.colbert.nbits}.aug{self.config.num_rh_augment}.{i}",
                    config=colbert_config,
                    collection=corpus_tsv_filename,
                ))

    def search(self, qembs, k):
        """Search the index; implemented by subclasses (which also take ``qmasks``)."""
        pass

    def run(self):
        """Run the end-to-end experiment; implemented by subclasses."""
        pass

    def index(self):
        """Build the ColBERT index/indices required for search (names/paths come from config).

        When running on the patched ColBERT:
          * documents are broken into chunks (original ColBERT behaviour);
          * 128-dim embeddings are generated and dumped to disk by the patch
            (skipped if the files are already present);
          * the last dimension is zeroed out / full augmentation is applied (the
            RH file is created or loaded, so set ``generate_new_rh`` in the config
            appropriately) and the index is built on the modified embeddings.

        With ``config.augment`` one index is built per RH id in ``config.rh_list``:
        only ``config.rh_num`` when it is >= 0 (so each RH augmentation can be
        indexed by a separate, parallel job), otherwise all ``num_rh_augment`` of
        them. An ``AssertionError`` while building one RH index is logged and that
        RH id is skipped.
        """
        if self.config.augment:
            assert self.config.num_rh_augment > 0
            if self.config.rh_num >= 0:
                # Valid RH ids are 0 .. num_rh_augment-1; the searchers never open
                # index ``.{num_rh_augment}``, so ``rh_num == num_rh_augment`` is invalid.
                assert self.config.rh_num < self.config.num_rh_augment, \
                    f"rh_num={self.config.rh_num} must be < num_rh_augment={self.config.num_rh_augment}"
                self.config.rh_list = [self.config.rh_num]
            else:
                self.config.rh_list = list(range(self.config.num_rh_augment))

        corpus_tsv_filename, _ = self.dataloader.get_tsv()

        if not self.config.augment:  # single ColBERT index
            with Run().context(RunConfig(nranks=1, experiment=self._experiment_name())):
                colbert_config = ColBERTConfig(
                    nbits=self.config.colbert.nbits,
                    root="./colbert_beir_expts/",
                    augment=self.config.augment,
                    dim=self.config.embedder.emb_dim
                )
                assert colbert_config.lin_dim == colbert_config.dim
                indexer = Indexer(checkpoint="ColBERT/colbertv2.0", config=colbert_config)
                indexer.index(name=f"nbits={self.config.colbert.nbits}.noaug",
                              collection=corpus_tsv_filename,
                              overwrite=self.config.overwrite_index
                              )
            return

        # The RH files live in this directory; create it before ColBERT reads/writes them.
        os.makedirs(f"./experiments/{self.config.data.dataset_name}/", exist_ok=True)
        for i in self.config.rh_list:
            try:
                with Run().context(RunConfig(nranks=1, experiment=self._experiment_name())):
                    colbert_config = ColBERTConfig(
                        nbits=self.config.colbert.nbits,
                        root="./colbert_beir_expts/",
                        augment=self.config.augment,
                        dim=2 * self.config.embedder.emb_dim,
                        dbl_norm=self.config.dbl_norm,
                        RH_file=self._rh_file(i),
                        generate_new_rh=self.config.generate_new_rh
                    )
                    assert 2 * colbert_config.lin_dim == colbert_config.dim
                    indexer = Indexer(checkpoint="ColBERT/colbertv2.0", config=colbert_config)
                    indexer.index(name=f"nbits={self.config.colbert.nbits}.aug{self.config.num_rh_augment}.{i}",
                                  collection=corpus_tsv_filename,
                                  overwrite=self.config.overwrite_index
                                  )
            except AssertionError as e:
                # Best effort: report and continue with the remaining RH ids.
                logger.error(f"Assertion error while indexing RH {i}: {e}")
                continue

### ColBERT's searcher requires a Queries object. Here I use a dummy query object that defines only the two methods the searcher calls (keys and provenance).
class DummyQueryForColbert:
    """Placeholder for ColBERT's query container.

    Only the two members the searcher touches are implemented: ``keys()`` and
    ``provenance()``. The actual query embeddings are passed to the searcher
    separately.
    """

    def __init__(self, size):
        # Number of queries this placeholder stands for.
        self.size = size

    def provenance(self):
        """Return a tag that records how many queries this placeholder represents."""
        return f"dummy_provenance_sized{self.size}"

    def keys(self):
        """Return the query ids ``0 .. size-1`` as a ``range``."""
        return range(0, self.size)


class MuveraTopK(MVBaseE2E):
    """MUVERA baseline: FDE single-vector search with FAISS, then Chamfer re-ranking.

    Corpus FDEs precomputed on disk are stored in a flat inner-product FAISS
    index. Queries are encoded with ``FdeLateInteractionModel`` and searched
    against it.
    """

    def __init__(self, config):
        super().__init__(config)

        # Load the index
        self.index_path = f"./experiments/{self.config.data.dataset_name}/muvera_index_{self.config.embedder.emb_dim}.index"
        # NOTE: the loaded FAISS index is stored on ``self.index``, which shadows the
        # ``index()`` method on this instance. This only happens when
        # ``config.index`` is false (presumably search-only runs).
        if not self.config.index:
            try:
                self.index = faiss.read_index(self.index_path)
            except Exception as e:
                logger.error(f"Failed to read index from {self.index_path}: {e}")
            else:
                logger.info(f"Loaded index from {self.index_path} successfully.")

        self.muvera_query_gen = FdeLateInteractionModel(self.config.muvera.num_repetitions,
                                                        self.config.muvera.num_simhash_projections,
                                                        self.config.muvera.projection_dimension,
                                                        self.config.muvera.final_projection_dimension)

    def index(self):
        """Build a flat inner-product FAISS index from the on-disk corpus FDEs and write it to ``self.index_path``."""
        self.index = faiss.IndexFlatIP(self.config.muvera.final_projection_dimension)

        # Load embeddings directly from disk
        parent_path = f"./experiments/{self.config.data.dataset_name}/BERT/corpus/compressed_muvera_full_{self.config.embedder.emb_dim}"
        ffs = [x for x in os.listdir(parent_path) if x.endswith(".pkl")]

        # Sort using batch_id and minibatch_id extracted from the filename
        sorted_files = sorted(
            ffs,
            key=lambda name: tuple(map(int, re.findall(r'\d+', name)))
        )

        embs_all = []
        for ff in tqdm(sorted_files, desc="Loading Muvera embeddings"):
            file_path = os.path.join(parent_path, ff)
            # weights_only=False unpickles arbitrary objects: only load files this pipeline produced.
            embed_dict = torch.load(file_path, map_location='cpu', weights_only=False)
            embs = embed_dict['embs_muvera']

            assert embs.shape[1] == self.config.muvera.final_projection_dimension, \
                f"Embedding dimension mismatch: {embs.shape[1]} != {self.config.muvera.final_projection_dimension}"

            embs_all.append(embs)

        # Safe to do even for larger datasets as the embeddings are single-vector
        embs_all = np.vstack(embs_all)
        self.index.add(embs_all)

        faiss.write_index(self.index, self.index_path)

    def run(self):
        """Retrieve candidates with MUVERA, re-rank them by Chamfer similarity and save the results.

        Returns ``(result_ids, result_scores)``, which are also saved with ``save``.
        In disk mode each row of ``result_ids`` is re-ordered by summed partial
        Chamfer similarity. ``result_scores[q, j]`` is the sum over query tokens of
        the running max over the first ``j + 1`` re-ranked documents. In ``mem``
        mode re-ranking is still a TODO and ``result_scores`` is an empty list.
        """
        # XXX: mirrors ColBERTBaseline.run; refactor later.
        result_file_path = f"./pickles/results/{self.config.embedder.type}/muvera_iid_{self.config.data.dataset_name}_k{self.config.k}.pkl"

        self.embedder.embed_full_dataset(self.dataloader, mode=self.config.embedder.mode)
        qembs, qmasks = self.embedder.qembs, self.embedder.qmasks

        result_ids = self.search(qembs, qmasks, k=self.config.k)
        chamfer_scores = []
        print(f"result ids shape : {result_ids.shape}")

        if self.config.embedder.mode == "mem":
            # TODO: compute Chamfer scores here; the fetched corpus is not used yet.
            cembs, cmasks = self.embedder.get_corpus(result_ids)
        else:
            corpus = self.embedder.get_corpus(result_ids)
            logger.info("All required documents are loaded")
            logger.info(f"Corpus shape : {corpus.idx.shape}")
            num_candidates = result_ids.shape[1]
            for q_id in tqdm(range(qembs.shape[0]), desc="Processing queries"):
                cemb, cmask = corpus[
                    q_id * torch.ones(num_candidates, dtype=torch.long, device=corpus.device),
                    torch.arange(num_candidates, dtype=torch.long, device=corpus.device),
                ]

                # Partial Chamfer similarities of this query's valid tokens against its
                # candidates (assumed layout: (query tokens, candidates) — TODO confirm).
                out = partial_chamfer_sim(qembs[q_id][qmasks[q_id].bool()], cemb, cmask, device=qembs.device, bs=1024)
                sorted_inds = torch.argsort(out.sum(dim=0), descending=True)
                chamfer_scores.append(torch.cummax(out[:, sorted_inds], dim=1)[0].sum(dim=0))

                # Move the permutation to result_ids' device before indexing it.
                result_ids[q_id] = result_ids[q_id][sorted_inds.to(result_ids.device)]
            # Per-query candidate embeddings are deliberately not kept: nothing below
            # uses them, and stacking them would hold the whole retrieved corpus in memory.
            chamfer_scores = torch.stack(chamfer_scores)

        result_scores = chamfer_scores

        save((result_ids, result_scores), result_file_path)
        return result_ids, result_scores

    def search(self, qembs, qmasks, k):
        """Encode queries as MUVERA FDEs and return the top-``k`` FAISS ids per query.

        Query FDEs are cached in ``./pickles/muvera_query_embs_{dataset}.npy``. The
        cache is reused only when it has one row per query in ``qembs`` (it is not
        otherwise checked against the query contents).

        Returns a LongTensor with ``k`` ids per query. Some entries may be -1 if
        FAISS could not find enough neighbours.
        """
        cache_path = f"./pickles/muvera_query_embs_{self.config.data.dataset_name}.npy"
        qembs_muvera = None
        try:
            qembs_muvera = np.load(cache_path)
        except Exception as e:
            logger.warning(f"Failed to load Muvera query embeddings: {e}")
        else:
            if qembs_muvera.shape[0] != len(qembs):
                # Stale cache, e.g. written with a different embedder.num_queries.
                logger.warning(f"Cached Muvera query embeddings have {qembs_muvera.shape[0]} rows "
                               f"but there are {len(qembs)} queries; regenerating.")
                qembs_muvera = None

        if qembs_muvera is None:
            logger.info("Generating Muvera query embeddings...")
            fdes = []
            for qidx, qemb in enumerate(tqdm(qembs, desc="Converting query embeddings to muvera")):
                # Boolean-mask out padded tokens. ``.bool()`` matters: 0/1 integer masks
                # would otherwise be used as row indices.
                fdes.append(self.muvera_query_gen.encode_single_item(qemb[qmasks[qidx].bool()]))

            qembs_muvera = np.vstack(fdes)
            logger.info("Muvera query embeddings generated successfully. Saving to file...")
            os.makedirs(os.path.dirname(cache_path), exist_ok=True)
            np.save(cache_path, qembs_muvera)

        logger.info("Searching Muvera index for top-k results...")
        start = time.perf_counter()

        batch_size = 512
        num_batches = int(np.ceil(len(qembs_muvera) / batch_size))
        all_I = []
        logger.info(f"k = {k}")

        for b, i in enumerate(range(0, len(qembs_muvera), batch_size)):
            batch = qembs_muvera[i:i + batch_size]
            _, I = self.index.search(batch.astype(np.float32), k)
            all_I.append(I)
            logger.info(f"Processed batch {b + 1} / {num_batches}")

        ids = torch.from_numpy(np.vstack(all_I))
        logger.info(f"Muvera search completed in {time.perf_counter() - start:.2f} seconds.")

        # XXX: some rows in ids can have -1 values in case FAISS was not able to find enough neighbors
        return ids


## Just topk with no augmentation
class ColBERTBaseline(MVBaseE2E):
    """Plain ColBERT top-k retrieval (no augmentation) followed by Chamfer re-ranking."""

    def __init__(self, config):
        super().__init__(config)
        self.variety = f"{self.mv_type}_base_n{self.config.colbert.nbits}_d{self.config.embedder.emb_dim}"

    def search(self, qembs, qmasks, k):
        """Retrieve the top-``k`` document ids per query with the ColBERT searcher.

        Returns an int64 tensor of shape ``(len(qembs), max_len)``, where ``max_len``
        is the longest per-query result list. Each row holds that query's ids in
        ascending numeric order; shorter rows are padded by repeating their first
        (smallest) id.
        """
        queries = DummyQueryForColbert(len(qembs))
        ranking = self.searcher.search_all_modified(queries, Q=qembs, M=qmasks, k=k, normalize_q=self.config.colbert.normalize_q).todict()

        # Each ranking entry is a list of triples whose first element is the document id.
        res = []
        for i in range(len(qembs)):
            ids, _, _ = zip(*ranking[i])
            res.append(ids)

        max_len = max(len(val) for val in res)
        logger.info(f"Max length of retrieved ids comes out to: {max_len}. Average length of merged ids is {sum([len(val) for val in res])/len(res)}. Otherwise would have been {self.config.colbert_topk}")
        # NOTE(review): sorting discards ColBERT's rank order. ``run`` re-ranks by
        # Chamfer similarity in disk mode, so the order only surfaces in mem mode.
        result_tensor = torch.zeros((len(qembs), max_len), dtype=torch.int64)
        for i, ids in enumerate(res):
            current = sorted(ids)
            current.extend([current[0]] * (max_len - len(current)))
            result_tensor[i] = torch.tensor(current)

        return result_tensor

    def run(self):
        """Search with ColBERT, re-rank by Chamfer similarity and save the results.

        Returns ``(result_ids, result_scores)``, which are also saved with ``save``.
        In disk mode each row of ``result_ids`` is re-ordered by summed partial
        Chamfer similarity. ``result_scores[q, j]`` is the sum over query tokens of
        the running max over the first ``j + 1`` re-ranked documents. In ``mem``
        mode re-ranking is still a TODO and ``result_scores`` is an empty list.
        """
        self._init_searcher()
        # ``self.suffix`` is not defined in this file's visible classes (presumably BaseE2E).
        result_file_path = f"./pickles/results/{self.config.embedder.type}/{self.variety}_{self.config.data.dataset_name}_k{self.config.k}{self.suffix}.pkl"

        self.embedder.embed_full_dataset(self.dataloader, mode=self.config.embedder.mode)
        qembs, qmasks = self.embedder.qembs, self.embedder.qmasks
        # YYYY TODO : add an assert here perhaps
        # YYYY TODO 2 : we might need batching for larger query sizes

        result_ids = self.search(qembs, qmasks, k=self.config.k)
        chamfer_scores = []
        print(f"result ids shape : {result_ids.shape}")
        if self.config.embedder.mode == "mem":
            # TODO: compute Chamfer scores here; the fetched corpus is not used yet.
            cembs, cmasks = self.embedder.get_corpus(result_ids)
        else:
            corpus = self.embedder.get_corpus(result_ids)
            logger.info("All required documents are loaded")
            num_candidates = result_ids.shape[1]
            for q_id in tqdm(range(qembs.shape[0]), desc="Processing queries"):
                cemb, cmask = corpus[
                    q_id * torch.ones(num_candidates, dtype=torch.long, device=corpus.device),
                    torch.arange(num_candidates, dtype=torch.long, device=corpus.device),
                ]
                # Partial Chamfer similarities of this query's valid tokens against its
                # candidates (assumed layout: (query tokens, candidates) — TODO confirm).
                out = partial_chamfer_sim(qembs[q_id][qmasks[q_id].bool()], cemb, cmask, device=qembs.device, bs=1024)
                sorted_inds = torch.argsort(out.sum(dim=0), descending=True)
                chamfer_scores.append(torch.cummax(out[:, sorted_inds], dim=1)[0].sum(dim=0))

                # Move the permutation to result_ids' device before indexing it.
                result_ids[q_id] = result_ids[q_id][sorted_inds.to(result_ids.device)]
            # Per-query candidate embeddings are deliberately not kept: nothing below
            # uses them, and stacking them would hold the whole retrieved corpus in memory.
            chamfer_scores = torch.stack(chamfer_scores)

        result_scores = chamfer_scores

        save((result_ids, result_scores), result_file_path)
        return result_ids, result_scores


def search_muvera_worker_for_th(qembs_muvera, k, index, rh, batch_size=512):
    """Run a batched top-``k`` FAISS search over one RH augmentation's query FDEs.

    Args:
        qembs_muvera: query FDEs, one row per query (presumably a 2-D numpy array;
            each batch is cast to float32 before searching).
        k: number of neighbours to retrieve per query.
        index: FAISS index, or any object with ``search(x, k) -> (D, I)``.
        rh: augmentation id; only used in log messages.
        batch_size: number of queries sent to ``index.search`` per call.

    Returns:
        ``np.ndarray`` of neighbour ids with ``len(qembs_muvera)`` rows and ``k``
        columns. FAISS reports -1 where it could not find enough neighbours.
    """
    logger.info(f"Searching Muvera index {rh} for top-k results...")
    start = time.perf_counter()
    logger.info(f"k = {k}")

    num_queries = len(qembs_muvera)
    if num_queries == 0:
        # np.vstack([]) raises, so return an empty id matrix directly.
        return np.empty((0, k), dtype=np.int64)

    num_batches = int(np.ceil(num_queries / batch_size))
    all_I = []
    for b, i in enumerate(range(0, num_queries, batch_size)):
        batch = qembs_muvera[i:i + batch_size]
        _, I = index.search(batch.astype(np.float32), k)
        all_I.append(I)
        logger.info(f"Processed batch {b + 1} / {num_batches}")

    ids = np.vstack(all_I)
    logger.info(f"Muvera search on index {rh} completed in {time.perf_counter() - start:.2f} seconds.")
    return ids

## MUVERA on augmented embeddings
class MuveraAugmented(MVBaseE2E, AugmentationMixin):
    """MUVERA search over RH-augmented embeddings, driven by a greedy loop in ``run``.

    One flat inner-product FAISS index is kept per RH augmentation
    (``self.indices[rh]``). ``search`` augments the queries with every RH file,
    encodes them as MUVERA FDEs, searches each per-RH index and merges the per-RH
    id lists with ``rowwise_union_padded``.
    """

    def __init__(self, config):
        super().__init__(config)
        # XXX: not used directly
        self.variety = f"{self.mv_type}_aug_n{self.config.colbert.nbits}_d{self.config.embedder.emb_dim}_rh{self.config.num_rh_augment}"
        # Hacks to use AugmentationMixin: its query augmentation keeps state in
        # ``self.RH``, which must start out unset.
        self.RH = None

        self.indices = {i: None for i in range(self.config.num_rh_augment)}
        for i in tqdm(range(self.config.num_rh_augment), desc="Loading Muvera aug indices"):
            index_path = self._get_index_path(i)

            try:
                # Loading many indices can be very slow, so memory-map them (IO_FLAG_MMAP).
                self.indices[i] = faiss.read_index(index_path, faiss.IO_FLAG_MMAP)
            except Exception as e:
                logger.error(f"Failed to read index from {index_path}: {e}")
                logger.info("Please run the indexing step for MuveraAugmented first.")
                break

        self.muvera_query_gen = FdeLateInteractionModel(self.config.muvera.num_repetitions,
                                                        self.config.muvera.num_simhash_projections,
                                                        self.config.muvera.projection_dimension,
                                                        self.config.muvera.final_projection_dimension)

    def _get_index_path(self, rh):
        """Return the on-disk path of the FAISS index for RH augmentation ``rh``."""
        return f"./experiments/{self.config.data.dataset_name}/muvera_index_{self.config.embedder.emb_dim}_rh{rh}.index"

    def _query_fingerprint(self, augmented_qembs, qmasks):
        """Return a SHA-1 hex digest of the augmented queries, their masks and the MUVERA settings.

        Used to decide whether the on-disk query-FDE cache was produced from
        exactly these inputs.
        """
        muvera = self.config.muvera
        digest = hashlib.sha1(repr((
            muvera.num_repetitions,
            muvera.num_simhash_projections,
            muvera.projection_dimension,
            muvera.final_projection_dimension,
        )).encode())
        for tensor in (augmented_qembs, qmasks.bool()):
            arr = tensor.detach().cpu().contiguous().numpy()
            digest.update(repr((arr.shape, arr.dtype.str)).encode())
            digest.update(arr.tobytes())
        return digest.hexdigest()

    def augment_and_muvera_encode(self, qembs, qmasks):
        """Augment the queries with every RH file and encode them as MUVERA FDEs.

        Args:
            qembs: query token embeddings (3-D; L2-normalised over dim 2 after augmentation).
            qmasks: per-token validity masks, one row per query.

        Returns:
            ``np.ndarray`` with ``num_rh_augment * len(qembs)`` rows in RH-major
            order: rows ``[rh * Q, (rh + 1) * Q)`` belong to augmentation ``rh``.
            The result is cached in ``./pickles``. The cache is reused only when a
            sidecar fingerprint matches the current augmented queries and masks, so
            ``run`` (which changes the queries every iteration) never reads stale FDEs.
        """
        num_rh = self.config.num_rh_augment
        with torch.no_grad():
            aug_qembs = []
            for i in range(num_rh):
                # Clear the mixin's RH state so augmentation ``i`` really uses RH file ``i``
                # instead of whatever an earlier iteration left in ``self.RH``.
                self.RH = None
                RH_file = f"./experiments/{self.config.data.dataset_name}/RH.{self.config.embedder.emb_dim}.{i}.pt"
                augmented_qembs = self._RH_augmentation_query(qembs, RH_file=RH_file, generate_new_rh=False)
                augmented_qembs = torch.nn.functional.normalize(augmented_qembs, p=2, dim=2)
                aug_qembs.append(augmented_qembs.half())

            # (num_rh * num_queries) rows, RH-major.
            augmented_qembs = torch.cat(aug_qembs, dim=0)
            # Masks are the single source of truth for the validity of query tokens,
            # so validity can't be asserted by checking for fully zeroed embeddings
            # (that check is done on the corpus side instead).
            qmasks = qmasks.repeat(num_rh, 1)
            assert qmasks.shape[0] == augmented_qembs.shape[0], \
                f"qmasks shape {qmasks.shape} does not match augmented qembs shape {augmented_qembs.shape}"

        cache_path = f"./pickles/muvera_aug_query_embs_{self.config.data.dataset_name}.npy"
        fingerprint_path = cache_path + ".sha1"
        fingerprint = self._query_fingerprint(augmented_qembs, qmasks)

        qembs_muvera = None
        try:
            with open(fingerprint_path, encoding="utf-8") as f:
                cache_is_fresh = f.read().strip() == fingerprint
            if cache_is_fresh:
                qembs_muvera = np.load(cache_path)
        except Exception as e:
            logger.info(f"No usable Muvera augmented query embedding cache: {e}")

        if qembs_muvera is not None and qembs_muvera.shape[0] != augmented_qembs.shape[0]:
            qembs_muvera = None

        if qembs_muvera is None:
            logger.info("Generating Muvera augmented query embeddings...")
            fdes = []
            for qidx, qemb in enumerate(tqdm(augmented_qembs, desc="Converting augmented query embeddings to muvera")):
                # Boolean-mask out padded tokens. ``.bool()`` matters: 0/1 integer masks
                # would otherwise be used as row indices.
                fdes.append(self.muvera_query_gen.encode_single_item(qemb[qmasks[qidx].bool()]))

            qembs_muvera = np.vstack(fdes)

            logger.info("Muvera augmented query embeddings generated successfully. Saving to file...")
            os.makedirs(os.path.dirname(cache_path), exist_ok=True)
            np.save(cache_path, qembs_muvera)
            with open(fingerprint_path, "w", encoding="utf-8") as f:
                f.write(fingerprint)

        return qembs_muvera

    def index(self):
        """Build and write one flat inner-product FAISS index per RH from the on-disk augmented FDEs.

        Each ``.pkl`` file holds ``embs_muvera_aug`` with the rows of all RHs
        stacked RH-major, so rows ``[rh * n, (rh + 1) * n)`` belong to RH ``rh``.
        """
        num_rh = self.config.num_rh_augment
        for i in range(num_rh):
            self.indices[i] = faiss.IndexFlatIP(self.config.muvera.final_projection_dimension)

        # Load embeddings directly from disk
        parent_path = f"./experiments/{self.config.data.dataset_name}/BERT/corpus/compressed_muvera_aug_{self.config.embedder.emb_dim}"
        ffs = [x for x in os.listdir(parent_path) if x.endswith(".pkl")]

        # Sort using batch_id and minibatch_id extracted from the filename
        sorted_files = sorted(
            ffs,
            key=lambda name: tuple(map(int, re.findall(r'\d+', name)))
        )

        embs_map = {i: [] for i in range(num_rh)}
        for ff in tqdm(sorted_files, desc="Loading Muvera embeddings"):
            file_path = os.path.join(parent_path, ff)
            # weights_only=False unpickles arbitrary objects: only load files this pipeline produced.
            embed_dict = torch.load(file_path, map_location='cpu', weights_only=False)
            embs = embed_dict['embs_muvera_aug']

            assert embs.shape[1] == self.config.muvera.final_projection_dimension, \
                f"Embedding dimension mismatch: {embs.shape[1]} != {self.config.muvera.final_projection_dimension}"
            assert embs.shape[0] % num_rh == 0, \
                f"{ff}: {embs.shape[0]} rows cannot be split evenly across {num_rh} RH augmentations"

            embs_per_rh = embs.shape[0] // num_rh
            for rh in range(num_rh):
                embs_map[rh].append(embs[rh * embs_per_rh:(rh + 1) * embs_per_rh])

        # Index and write
        for key in tqdm(embs_map, desc="Writing indices"):
            index_path = self._get_index_path(key)
            merged_embs = np.vstack(embs_map[key])
            self.indices[key].add(merged_embs)
            faiss.write_index(self.indices[key], index_path)
            logger.info(f"Index for RH {key} written to {index_path}")

    def run(self):
        """Greedy augmented MUVERA retrieval; returns and saves ``(opt_inds, opt_scores)``.

        Each of the ``config.k`` iterations:
          1. writes ``optvec`` into the last query dimension;
          2. retrieves ``config.colbert_topk`` candidates per query with ``search``;
          3. lets ``partial_chamfer_sim_batched_with_rerank`` pick one document per
             query, storing its id and score in column ``i`` of ``opt_inds`` and
             ``opt_scores``;
          4. raises ``optvec`` to the element-wise max with the returned partial
             similarities.
        """
        # XXX: Copied. Refactor later.
        result_file_path = f"./pickles/results/{self.config.embedder.type}/muvera_aug_{self.config.data.dataset_name}_k{self.config.k}.pkl"

        self.embedder.embed_full_dataset(self.dataloader, mode=self.config.embedder.mode)

        qembs, qmasks = self.embedder.qembs, self.embedder.qmasks
        optvec = -torch.ones_like(qmasks, dtype=qembs.dtype) * 2  # Just to be safe
        opt_scores = torch.zeros((qmasks.size(0), self.config.k), dtype=torch.float32)
        opt_inds = -torch.ones((qmasks.size(0), self.config.k), dtype=torch.int64)

        # you can add batching here
        with torch.inference_mode():
            for i in tqdm(range(self.config.k)):
                logger.info(f"Running iteration {i+1}/{self.config.k}")
                qembs[:, :, -1] = optvec

                inds = self.search(qembs, qmasks, k=self.config.colbert_topk)
                logger.info(f"Search Done - iter {i+1}/{self.config.k}")

                if self.config.embedder.mode == "mem":
                    cembs, cmasks = self.embedder.get_corpus(inds)
                    max_sim_partial, max_sim_indices, max_sim_scores = partial_chamfer_sim_batched_with_rerank(
                        qembs, qmasks, optvec.unsqueeze(-1), cembs, cmasks
                    )
                    real_indices = inds[torch.arange(inds.size(0)), max_sim_indices.cpu()]
                    optvec = torch.maximum(optvec, max_sim_partial.to(optvec.device))
                    opt_scores[:, i].copy_(max_sim_scores)
                    opt_inds[:, i].copy_(real_indices)
                else:
                    corpus = self.embedder.get_corpus(inds)
                    logger.info("All required documents are loaded")
                    for q_id in tqdm(range(qembs.shape[0]), desc="Processing queries"):
                        cemb, cmask = corpus[
                            (q_id * torch.ones(inds.shape[1], dtype=torch.long, device=corpus.device)),
                            torch.arange(inds.shape[1], dtype=torch.long, device=corpus.device)
                        ]

                        max_sim_partial, max_sim_indices, max_sim_scores = partial_chamfer_sim_batched_with_rerank(
                            qembs[q_id].unsqueeze(0), qmasks[q_id].unsqueeze(0), optvec.unsqueeze(-1)[q_id].unsqueeze(0), cemb.unsqueeze(0), cmask.unsqueeze(0)
                        )
                        real_indices = inds[q_id, max_sim_indices.cpu()]
                        optvec[q_id] = torch.maximum(optvec[q_id], max_sim_partial.squeeze(0).to(optvec.device))
                        opt_scores[q_id, i] = max_sim_scores
                        opt_inds[q_id, i] = real_indices

        save((opt_inds, opt_scores), result_file_path)
        return opt_inds, opt_scores

    def search(self, qembs, qmasks, k):
        """Search every per-RH index with its augmented query FDEs and merge the results.

        Returns a tensor built by ``rowwise_union_padded`` from the per-RH id lists.
        The lists are passed in RH order (0 .. num_rh_augment-1) regardless of the
        parallelization mode.

        Raises:
            ValueError: if ``config.parallelization`` is neither "thread" nor "none".
        """
        qembs_muvera = self.augment_and_muvera_encode(qembs, qmasks)

        num_rh = self.config.num_rh_augment
        assert qembs_muvera.shape[0] % num_rh == 0, \
            f"{qembs_muvera.shape[0]} query FDE rows cannot be split evenly across {num_rh} RH augmentations"
        qembs_per_rh = qembs_muvera.shape[0] // num_rh
        # Augmented FDEs are RH-major: rows [rh * Q, (rh + 1) * Q) belong to RH ``rh``.
        per_rh_queries = [qembs_muvera[rh * qembs_per_rh:(rh + 1) * qembs_per_rh] for rh in range(num_rh)]

        if self.config.parallelization == "thread":
            with concurrent.futures.ThreadPoolExecutor() as executor:
                futures = [
                    executor.submit(search_muvera_worker_for_th, per_rh_queries[rh], k, self.indices[rh], rh)
                    for rh in range(num_rh)
                ]
                # Collect in submission (RH) order, not completion order, so the merge is deterministic.
                all_ids = [future.result() for future in futures]
        elif self.config.parallelization == "none":
            all_ids = [search_muvera_worker_for_th(per_rh_queries[rh], k, self.indices[rh], rh) for rh in range(num_rh)]
        else:
            raise ValueError(f"Unknown parallelization method: {self.config.parallelization}")

        result = rowwise_union_padded(all_ids, self.config.num_rh_augment, self.config.colbert_topk)
        return torch.from_numpy(result)


### helper worker functions for multiprocessing/multithreading
def search_worker_for_mp(i, config, queries, qembs, qmasks, k):
    """Build a ColBERT ``Searcher`` for RH augmentation ``i`` and search all queries with it.

    Intended to run in a worker process, which is why the searcher is constructed
    here rather than passed in.

    Returns:
        The searcher's ranking converted with ``.todict()``.
    """
    # Must match MVBaseE2E.mv_type for augmented runs ("<mv_type>/dblnorm" or
    # "<mv_type>/norm"): the indices are built under that experiment directory.
    mv_type = config.embedder.mv_type + ("/dblnorm" if config.dbl_norm else "/norm")
    with torch.inference_mode():
        with Run().context(RunConfig(nranks=1, experiment=f"{config.data.dataset_name}/{config.embedder.type}/{mv_type}")):
            colbert_config = ColBERTConfig(
                nbits=config.colbert.nbits,
                root="./colbert_beir_expts/",
                augment=True,
                # Augmented embeddings are twice the base embedding dimension.
                dim=2 * config.embedder.emb_dim,
                dbl_norm=config.dbl_norm,
                RH_file=f"./experiments/{config.data.dataset_name}/RH.{config.embedder.emb_dim}.{i}.pt",
                generate_new_rh=False  # RH_file is supposed to be generated by the index function
            )
            return Searcher(
                checkpoint="./ColBERT/colbertv2.0",
                index=f"nbits={config.colbert.nbits}.aug{config.num_rh_augment}.{i}",
                config=colbert_config,
            ).search_all_modified(queries, qembs, qmasks, k, normalize_q=config.colbert.normalize_q).todict()
            
def search_worker_for_th(searcher, queries, qembs, qmasks, k, normalize_q=None):
    """Run ``searcher.search_all_modified`` inside a worker thread.

    ``torch.inference_mode`` is entered here because it is thread-local and is
    not inherited from the thread that submitted the job.

    Args:
        searcher: an already-initialised searcher shared with the caller.
        queries: query container forwarded to the searcher.
        qembs: query embeddings forwarded to the searcher.
        qmasks: query masks forwarded to the searcher.
        k: number of results requested per query.
        normalize_q: forwarded as ``normalize_q``. Callers should pass
            ``config.colbert.normalize_q`` explicitly. When omitted, the legacy
            behaviour of reading the module-level name ``config`` is used; that
            global only exists when this file is run as a script.

    Returns:
        The ``.todict()`` form of the search result.

    Raises:
        RuntimeError: if ``normalize_q`` is omitted and no module-level
            ``config`` is available.
    """
    if normalize_q is None:
        # Legacy fallback: the original implementation read the global `config`,
        # which is a NameError when this module is imported rather than executed.
        legacy_config = globals().get("config")
        if legacy_config is None:
            raise RuntimeError(
                "search_worker_for_th: pass normalize_q explicitly "
                "(no module-level `config` is defined)"
            )
        normalize_q = legacy_config.colbert.normalize_q
    with torch.inference_mode():
        return searcher.search_all_modified(queries, qembs, qmasks, k, normalize_q=normalize_q).todict()

### Augmented implementation
class LatePoolDISCo(MVBaseE2E):
    """Late-pooling DISCo over ``config.num_rh_augment`` augmented ColBERT indices.

    :meth:`search` queries every augmented index independently and unions the
    per-index candidate ids for each query. :meth:`run` then greedily selects one
    document per round for ``config.k`` rounds.
    """

    def __init__(self, config):
        super().__init__(config)
        self.variety = f"{self.mv_type}_aug_n{self.config.colbert.nbits}_d{self.config.embedder.emb_dim}_rh{self.config.num_rh_augment}"

    def _search_augmented_indices(self, queries, qembs, qmasks, k):
        """Search every augmented index and return one ``todict()`` result per index.

        Dispatch follows ``config.parallelization``:
          * ``"proc"``: one worker process per index, each building its own Searcher.
          * ``"thread"``: the searchers in ``self.searcher`` are shared with threads.
          * ``"none"``: indices are searched sequentially.

        In the parallel modes, results come back in completion order, not index order.

        Raises:
            ValueError: if ``config.parallelization`` is not one of the above.
        """
        num_indices = self.config.num_rh_augment
        parallelization = self.config.parallelization

        if parallelization == "proc":
            with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor:
                futures = [
                    executor.submit(search_worker_for_mp, i, self.config, queries, qembs, qmasks, k)
                    for i in range(num_indices)
                ]
                return [future.result() for future in concurrent.futures.as_completed(futures)]

        if parallelization == "thread":
            # Read normalize_q from self.config (which includes CLI overrides) rather
            # than from a module-level global.
            normalize_q = self.config.colbert.normalize_q

            def _search_one(searcher):
                # inference_mode is thread-local, so it must be entered in the worker thread.
                with torch.inference_mode():
                    return searcher.search_all_modified(
                        queries, qembs, qmasks, k, normalize_q=normalize_q
                    ).todict()

            with concurrent.futures.ThreadPoolExecutor() as executor:
                futures = [executor.submit(_search_one, self.searcher[i]) for i in range(num_indices)]
                return [future.result() for future in concurrent.futures.as_completed(futures)]

        if parallelization == "none":
            return [
                self.searcher[i].search_all_modified(
                    queries, qembs, qmasks, k, normalize_q=self.config.colbert.normalize_q
                ).todict()
                for i in range(num_indices)
            ]

        raise ValueError(f"Unknown parallelization method: {self.config.parallelization}")

    def search(self, qembs, qmasks, k):
        """Return, per query, the union of the top-``k`` ids from every augmented index.

        Args:
            qembs: query embeddings, one row per query.
            qmasks: query masks, one row per query.
            k: number of results requested from each index.

        Returns:
            ``torch.Tensor`` with one row per query. Each row holds that query's
            sorted candidate ids, right-padded to the longest row by repeating
            its smallest id.
        """
        queries = DummyQueryForColbert(len(qembs))
        all_ids = self._search_augmented_indices(queries, qembs, qmasks, k)

        merged_ids = {}
        for key in range(len(qembs)):
            merged_ids[key] = set()
            for ids_dict in all_ids:
                # Entries of ids_dict[key] are 3-tuples; only the first field (the id) is kept.
                ids, _, _ = zip(*ids_dict[key])
                merged_ids[key].update(ids)

        max_len = max([len(val) for val in merged_ids.values()])
        logger.info(f"Max length of merged ids comes out to: {max_len}. Average length of merged ids is {sum([len(val) for val in merged_ids.values()])/len(merged_ids)}. Otherwise would have been {self.config.colbert_topk*self.config.num_rh_augment}")

        result = []
        for key in range(len(qembs)):
            current_set = sorted(merged_ids[key])
            current_set.extend([current_set[0]] * (max_len - len(current_set)))
            result.append(current_set)
        return torch.tensor(result)

    def run(self):
        """Greedily select ``config.k`` documents per query and save the result.

        Each round does the following:
          1. Writes the current ``optvec`` into the last channel of ``qembs``.
          2. Retrieves candidates with :meth:`search`.
          3. Reranks them with ``partial_chamfer_sim_batched_with_rerank``.
          4. Stores the chosen document id and score in column ``i`` of
             ``opt_inds`` / ``opt_scores``.
          5. Raises ``optvec`` to the element-wise maximum with the chosen
             document's partial similarities.

        Candidate ids of every round are pickled under
        ``./colbert_ids_{dbl,single}_norm/``. They are only read back when
        ``load_state`` is switched to True.

        Returns:
            ``(opt_inds, opt_scores)``, each of shape ``(num_queries, config.k)``.
            The same pair is also written to ``result_file_path``.
        """
        self._init_searcher()
        result_file_path = f"./pickles/results/{self.config.embedder.type}/{self.variety}_{self.config.data.dataset_name}_k{self.config.k}_rerankperh{self.config.colbert_topk}{self.suffix}.pkl"
        # Set to True to reuse candidate ids cached by a previous run.
        load_state = False

        self.embedder.embed_full_dataset(self.dataloader, mode=self.config.embedder.mode)

        qembs, qmasks = self.embedder.qembs, self.embedder.qmasks
        # TODO: add a shape assert; batching may be needed for larger query sets.
        # optvec starts at -2 so the first element-wise maximum always takes the new values.
        optvec = -torch.ones_like(qmasks, dtype=qembs.dtype) * 2
        opt_scores = torch.zeros((qmasks.size(0), self.config.k), dtype=torch.float32)
        opt_inds = -torch.ones((qmasks.size(0), self.config.k), dtype=torch.int64)

        cache_dir = "./colbert_ids_dbl_norm" if self.config.dbl_norm else "./colbert_ids_single_norm"
        # The cache directory was never created before, so the first dump could fail.
        os.makedirs(cache_dir, exist_ok=True)

        with torch.inference_mode():
            for i in tqdm(range(self.config.k)):
                logger.info(f"Running iteration {i+1}/{self.config.k}")
                qembs[:, :, -1] = optvec

                file_path = f"{cache_dir}/colbert_ids_{self.config.data.dataset_name}_{i}.pkl"
                logger.debug(f"Candidate id cache: {file_path}")

                if load_state and os.path.exists(file_path):
                    logger.info("Loading indices from file...")
                    with open(file_path, "rb") as f:
                        inds = pickle.load(f)
                else:
                    logger.info("Running search...")
                    inds = self.search(qembs, qmasks, k=self.config.colbert_topk)
                    with open(file_path, "wb") as f:
                        pickle.dump(inds, f)

                logger.info(f"Search Done - iter {i+1}/{self.config.k}")

                if self.config.embedder.mode == "mem":
                    cembs, cmasks = self.embedder.get_corpus(inds)
                    max_sim_partial, max_sim_indices, max_sim_scores = partial_chamfer_sim_batched_with_rerank(
                        qembs, qmasks, optvec.unsqueeze(-1), cembs, cmasks
                    )
                    # inds is a CPU tensor; move the index tensor there, as the disk branch does.
                    real_indices = inds[torch.arange(inds.size(0)), max_sim_indices.cpu()]
                    optvec = torch.maximum(optvec, max_sim_partial.to(optvec.device))
                    opt_scores[:, i].copy_(max_sim_scores)
                    opt_inds[:, i].copy_(real_indices)
                else:
                    corpus = self.embedder.get_corpus(inds)
                    logger.info("All required documents are loaded")
                    for q_id in tqdm(range(qembs.shape[0]), desc="Processing queries"):
                        cemb, cmask = corpus[
                            (q_id * torch.ones(inds.shape[1], dtype=torch.long, device=corpus.device)),
                            torch.arange(inds.shape[1], dtype=torch.long, device=corpus.device)
                        ]

                        max_sim_partial, max_sim_indices, max_sim_scores = partial_chamfer_sim_batched_with_rerank(
                            qembs[q_id].unsqueeze(0), qmasks[q_id].unsqueeze(0), optvec.unsqueeze(-1)[q_id].unsqueeze(0), cemb.unsqueeze(0), cmask.unsqueeze(0)
                        )
                        real_indices = inds[q_id, max_sim_indices.cpu()]
                        optvec[q_id] = torch.maximum(optvec[q_id], max_sim_partial.squeeze(0).to(optvec.device))
                        opt_scores[q_id, i] = max_sim_scores
                        opt_inds[q_id, i] = real_indices

        save((opt_inds, opt_scores), result_file_path)
        return opt_inds, opt_scores
    
## LatePool that makes use of threshold based selection at the pooling step.
class LatePoolDISCo_thresholded(LatePoolDISCo):
    """LatePoolDISCo variant that pools candidates by agreement between indices.

    A candidate id is kept for a query only if at least ``config.threshold`` of
    the augmented indices retrieved it for that query.
    """

    def __init__(self, config):
        super().__init__(config)
        self.variety = f"{self.mv_type}_aug_n{self.config.colbert.nbits}_d{self.config.embedder.emb_dim}_rh{self.config.num_rh_augment}_threshold{self.config.threshold}"
        self.threshold = self.config.threshold

    def search(self, qembs, qmasks, k):
        """Return, per query, the candidate ids retrieved by at least ``threshold`` indices.

        If no candidate of a query reaches ``threshold``, the candidates with the
        highest retrieval count are kept instead (and a warning is logged).
        Otherwise that query's row would be empty and the padding step would fail.

        Args:
            qembs: query embeddings, one row per query.
            qmasks: query masks, one row per query.
            k: number of results requested from each index.

        Returns:
            ``torch.Tensor`` with one row per query. Each row holds sorted ids,
            right-padded to the longest row by repeating its smallest id.

        Raises:
            ValueError: if ``config.parallelization`` is unknown.
        """
        queries = DummyQueryForColbert(len(qembs))
        num_indices = self.config.num_rh_augment

        if self.config.parallelization == "proc":
            with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor:
                futures = [
                    executor.submit(search_worker_for_mp, i, self.config, queries, qembs, qmasks, k)
                    for i in range(num_indices)
                ]
                all_ids = [future.result() for future in concurrent.futures.as_completed(futures)]
        elif self.config.parallelization == "thread":
            # Use self.config (which includes CLI overrides), not a module-level global.
            normalize_q = self.config.colbert.normalize_q

            def _search_one(searcher):
                # inference_mode is thread-local, so it must be entered in the worker thread.
                with torch.inference_mode():
                    return searcher.search_all_modified(
                        queries, qembs, qmasks, k, normalize_q=normalize_q
                    ).todict()

            with concurrent.futures.ThreadPoolExecutor() as executor:
                futures = [executor.submit(_search_one, self.searcher[i]) for i in range(num_indices)]
                all_ids = [future.result() for future in concurrent.futures.as_completed(futures)]
        elif self.config.parallelization == "none":
            all_ids = [
                self.searcher[i].search_all_modified(
                    queries, qembs, qmasks, k, normalize_q=self.config.colbert.normalize_q
                ).todict()
                for i in range(num_indices)
            ]
        else:
            raise ValueError(f"Unknown parallelization method: {self.config.parallelization}")

        # For each query, count how many indices retrieved each id. set() means each
        # index contributes at most 1 to an id's count.
        merged_ids = {}
        for key in range(len(qembs)):
            counts = multiset({})
            for ids_dict in all_ids:
                ids, _, _ = zip(*ids_dict[key])
                counts.update(set(ids))
            merged_ids[key] = counts
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f"Query {key}: Unique elements = {len(merged_ids[key].distinct_elements())}, Total elements = {sum(merged_ids[key].multiplicities())}, Maximum count = {max(merged_ids[key].multiplicities())}")
        if self.threshold == 1:
            # TODO: save stats differently?
            save(merged_ids, f"{self.variety}_{self.config.data.dataset_name}_b{self.config.colbert_topk}.pkl")

        max_len = 0
        fallback_queries = 0
        for key in range(len(qembs)):
            counts = merged_ids[key]
            # Equal to self.threshold whenever some candidate reaches it. Otherwise
            # lower it to the best available count so the query keeps at least one id.
            best_count = max(counts.multiplicities())
            if best_count < self.threshold:
                fallback_queries += 1
            effective_threshold = min(self.threshold, best_count)
            merged_ids[key] = [elem for elem, count in counts.items() if count >= effective_threshold]
            max_len = max(max_len, len(merged_ids[key]))
        if fallback_queries:
            logger.warning(f"{fallback_queries} queries had no candidate retrieved by >= {self.threshold} indices; kept their most frequently retrieved candidates instead")
        average_len = sum([len(val) for val in merged_ids.values()]) / len(merged_ids)

        logger.debug(f"Max length of merged ids after threshold {self.threshold} comes out to: {max_len}. Average length of merged ids after threshold {self.threshold} is {average_len}. Otherwise would have been {self.config.colbert_topk*self.config.num_rh_augment}")

        result = []
        for key in range(len(qembs)):
            current_set = sorted(merged_ids[key])
            current_set.extend([current_set[0]] * (max_len - len(current_set)))
            result.append(current_set)
        return torch.tensor(result)
    
class DISCo(LatePoolDISCo):
    """DISCo with cross-index candidate aggregation.

    The pipeline has three stages:
      1. Candidates come from every augmented index via ``gen_candidates``.
      2. They are merged, and optionally pruned using aggregated
         centroid-score estimates.
      3. When ``config.disco.rerank_internal`` is set, they are reranked
         against the base (non-augmented) index.

    :meth:`run` then performs the greedy selection.
    """

    def __init__(self, config):
        super().__init__(config)
        self.variety = f"{self.mv_type}_int_n{self.config.colbert.nbits}_d{self.config.embedder.emb_dim}_rh{self.config.num_rh_augment}_int{config.disco.rerank_internal}_ext{config.disco.rerank_external}"
        assert config.disco.rerank_external or config.disco.rerank_internal

    def _init_base_searcher(self):
        """Create ``self.base_searcher`` over the non-augmented (``noaug``) index."""
        mv = self.mv_type.split("/")[0] + "/norm"

        # If the index is generated on a server A and copied over to server B, then
        # during search, the configuration that was saved with the index on server A
        # will be loaded on server B, along with the current config being passed.
        # One example is the `collection` property of the searcher, which may come from
        # the index if left unspecified here. This is a problem mainly for the LoTTE
        # dataloader, so it is specified manually here to avoid cross-server issues.
        corpus_tsv_filename, _ = self.dataloader.get_tsv()
        with Run().context(RunConfig(nranks=1, experiment=f"{self.config.data.dataset_name}/{self.config.embedder.type}/{mv}")):
            # NOTE(review): augment is taken from config here even though this is the
            # non-augmented index -- confirm this is intended.
            colbert_config = ColBERTConfig(
                nbits=self.config.colbert.nbits,
                root="./colbert_beir_expts/",
                augment=self.config.augment,
                dim=self.config.embedder.emb_dim,
            )
            self.base_searcher = Searcher(
                checkpoint="ColBERT/colbertv2.0",
                index=f"nbits={self.config.colbert.nbits}.noaug",
                collection=corpus_tsv_filename,
                config=colbert_config,
            )

    def run(self):
        """Greedily select ``config.k`` documents per query and save the result.

        Returns:
            ``(opt_inds, opt_scores)``, each of shape ``(num_queries, config.k)``.
            The same pair is also written to ``result_file_path``.
        """
        self._init_searcher()  # searchers over the augmented indices
        self._init_base_searcher()  # searcher over the base index
        result_file_path = f"./pickles/results/{self.config.embedder.type}/{self.variety}_{self.config.data.dataset_name}_k{self.config.k}_rerankperh{self.config.colbert_topk}{self.suffix}.pkl"

        self.embedder.embed_full_dataset(self.dataloader, mode=self.config.embedder.mode)

        qembs, qmasks = self.embedder.qembs, self.embedder.qmasks
        # TODO: add a shape assert; batching may be needed for larger query sets.
        # optvec starts at -2 so the first element-wise maximum always takes the new values.
        optvec = -torch.ones_like(qmasks, dtype=qembs.dtype) * 2
        opt_scores = torch.zeros((qmasks.size(0), self.config.k), dtype=torch.float32)
        opt_inds = -torch.ones((qmasks.size(0), self.config.k), dtype=torch.int64)
        with torch.inference_mode():
            for i in tqdm(range(self.config.k)):
                logger.info(f"Running iteration {i+1}/{self.config.k}")
                # qembs itself is not modified here: _search_helper writes opt_vec into a clone.
                inds = self.search(qembs, k=self.config.colbert_topk, opt_vec=optvec)
                logger.info(f"Search Done - iter {i+1}/{self.config.k}")

                if self.config.embedder.mode == "mem":
                    cembs, cmasks = self.embedder.get_corpus(inds)
                else:
                    corpus = self.embedder.get_corpus(inds)
                    logger.info("All required documents are loaded")
                    cembs = []
                    cmasks = []
                    for q_id in tqdm(range(qembs.shape[0]), desc="Processing queries"):
                        cemb, cmask = corpus[
                            (q_id * torch.ones(inds.shape[1], dtype=torch.long, device=corpus.device)),
                            torch.arange(inds.shape[1], dtype=torch.long, device=corpus.device)
                        ]
                        cembs.append(cemb)
                        cmasks.append(cmask)
                    cembs = torch.stack(cembs)
                    cmasks = torch.stack(cmasks)

                max_sim_partial, max_sim_indices, max_sim_scores = partial_chamfer_sim_batched_with_rerank(
                    qembs, qmasks, optvec.unsqueeze(-1), cembs, cmasks
                )
                real_indices = inds[torch.arange(inds.size(0)), max_sim_indices.cpu()]
                optvec = torch.maximum(optvec, max_sim_partial.to(optvec.device))
                opt_scores[:, i].copy_(max_sim_scores)
                opt_inds[:, i].copy_(real_indices)

        save((opt_inds, opt_scores), result_file_path)
        return opt_inds, opt_scores

    def search(self, qembs, k, opt_vec):
        """Return candidate ids for every query as a padded tensor.

        Each query is processed on its own via :meth:`_search_helper`. Rows are
        right-padded to the longest row by repeating their first id.
        """
        all_scored_pids = [
            list(self._search_helper(qembs[query_idx:query_idx + 1], opt_vec[query_idx:query_idx + 1], k))
            for query_idx in tqdm(range(len(qembs)))
        ]
        # Element 0 of each helper result holds the ids. Average over those lengths,
        # not over the length of the (2- or 3-element) result tuple.
        id_counts = [len(scored[0]) for scored in all_scored_pids]
        max_len = max(id_counts)
        logger.info(f"Max length of merged ids comes out to: {max_len}. Average length of merged ids is {sum(id_counts)/len(all_scored_pids)}")
        result_ids = []
        for scored in all_scored_pids:
            current_set = list(scored[0])
            current_set.extend([current_set[0]] * (max_len - len(current_set)))
            result_ids.append(current_set)
        return torch.tensor(result_ids)

    def _search_helper(self, query, opt_vec, k):
        """Generate and (optionally) rerank candidates for a single query.

        ``opt_vec`` is written into the last channel of a clone of ``query``, and
        that clone is used only for candidate generation on the augmented indices.

        Returns:
            If ``rerank_internal`` is False, the 2-tuple
            ``(aggregated_ids, aggregated_scores)`` from
            :meth:`_aggregate_over_indices`. Otherwise the 3-tuple
            ``(pids[:k], [1..k], scores[:k])`` from ``base_searcher.rank_modified``.
        """
        query2 = query.clone()
        query2[:, :, -1] = opt_vec
        aggregated_ids, aggregated_scores = self._aggregate_over_indices(
            self._parallel_search(query2, k),
            prune_candidates=self.config.disco.prune_candidates,
        )

        if not self.config.disco.rerank_internal:
            return aggregated_ids, aggregated_scores
        # Rerank on the base index using the original (unmodified) query.
        pids, scores = self.base_searcher.rank_modified(query, opt_vec, pids=aggregated_ids)

        return pids[:k], list(range(1, k + 1)), scores[:k]

    def _aggregate_over_indices(self, args, prune_candidates):
        """Merge the candidates produced by every augmented index.

        Args:
            args: one ``(ids, id_centroid_scores, codes_lengths)`` triple per index,
                as returned by ``gen_candidates``. ``id_centroid_scores`` holds
                ``codes_lengths[j]`` consecutive rows for the j-th id.
            prune_candidates: if False, return all unique ids together with the
                raw concatenated centroid scores. These scores are NOT aligned
                with the unique ids.

        Returns:
            ``(unique_ids, scores)``. When pruning, ``unique_ids`` keeps at most
            ``colbert_topk * num_rh_augment`` ids, chosen by their aggregated
            centroid-score estimate, and ``scores`` is ``None``.
        """
        ids, id_centroid_scores, codes_lengths = zip(*args)
        ids = torch.cat(ids, dim=0)
        id_centroid_scores = torch.cat(id_centroid_scores, dim=0)

        # Sort by id so that duplicates coming from different indices become adjacent.
        argsort_for_ids = torch.argsort(ids, dim=0)
        sorted_ids = ids[argsort_for_ids]

        unique_ids, inverse_unique = torch.unique(sorted_ids, return_inverse=True)

        logger.debug(f"Unique ids: {unique_ids.size(0)}")
        if not prune_candidates:
            # No pruning, so the aggregated score estimates are not needed.
            return unique_ids, id_centroid_scores

        codes_lengths = torch.cat(codes_lengths, dim=0).to(self.device)
        new_code_lengths = codes_lengths[argsort_for_ids]
        cum_code_lengths = torch.cumsum(codes_lengths, dim=0)
        sorted_cum_code_lengths = torch.cumsum(new_code_lengths, dim=0)

        # Number of score rows each unique id owns after merging its duplicates.
        aggregated_code_lengths = torch.zeros_like(unique_ids, dtype=new_code_lengths.dtype)
        aggregated_code_lengths.scatter_add_(0, inverse_unique, new_code_lengths)

        # Re-lay out the score rows in sorted-id order, so every unique id owns one
        # contiguous run of rows: [new_start, new_end) <- [old_start, old_end).
        all_scores = torch.empty_like(id_centroid_scores)
        new_start = sorted_cum_code_lengths - new_code_lengths
        old_end = cum_code_lengths[argsort_for_ids]
        old_start = old_end - new_code_lengths

        for i in range(len(sorted_cum_code_lengths)):  # TODO: vectorise this copy
            all_scores[new_start[i]:sorted_cum_code_lengths[i]] = id_centroid_scores[old_start[i]:old_end[i]]

        approx_scores_strided = StridedTensor(all_scores, aggregated_code_lengths, use_gpu=True)
        approx_scores_padded, approx_scores_mask = approx_scores_strided.as_padded_tensor()
        approx_scores = colbert_score_reduce(approx_scores_padded, approx_scores_mask, self.searcher[0].config)
        # Heuristic budget, similar to the other ColBERT modification. Read it from
        # self.config rather than a module-level global, which is undefined when
        # this module is imported and ignores CLI overrides.
        threshold_ndocs = self.config.colbert_topk * self.config.num_rh_augment

        if threshold_ndocs < len(approx_scores):
            unique_ids = unique_ids[torch.topk(approx_scores, k=threshold_ndocs).indices]

        logger.debug(f"Unique ids after pruning: {unique_ids.size(0)}")
        # TODO: scores are unused downstream; consider changing the return signature.
        return unique_ids, None

    def _parallel_search(self, qembs, k):
        """Run ``gen_candidates`` on every augmented index and return the per-index results.

        Raises:
            ValueError: if ``config.parallelization`` is unknown.
        """
        if self.config.parallelization == "proc":
            with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor:
                futures = [
                    executor.submit(candidate_worker_for_mp, i, self.config, qembs, k, self.config.disco.prune_candidates)
                    for i in range(self.config.num_rh_augment)
                ]
                all_ids = [future.result() for future in concurrent.futures.as_completed(futures)]
        elif self.config.parallelization == "thread":
            with concurrent.futures.ThreadPoolExecutor() as executor:
                futures = [
                    executor.submit(candidate_worker_for_th, self.searcher[i], qembs, k, self.config.disco.prune_candidates)
                    for i in range(self.config.num_rh_augment)
                ]
                all_ids = [future.result() for future in concurrent.futures.as_completed(futures)]
        elif self.config.parallelization == "none":
            all_ids = [
                self.searcher[i].gen_candidates(qembs, k, self.config.disco.prune_candidates)
                for i in range(self.config.num_rh_augment)
            ]
        else:
            raise ValueError(f"Unknown parallelization method: {self.config.parallelization}")
        return all_ids
        
def candidate_worker_for_mp(i, config, qembs, k, prune_candidates):
    """Generate candidates from the ``i``-th augmented ColBERT index in a worker process.

    A fresh ``Searcher`` for index ``i`` is built inside the worker. The result of
    its ``gen_candidates(qembs, k, prune_candidates)`` call is returned as is.
    """
    dataset = config.data.dataset_name
    emb_dim = config.embedder.emb_dim
    nbits = config.colbert.nbits
    suffix = "_dblnorm" if config.dbl_norm else "_norm"
    # NOTE(review): MVBaseE2E builds mv_type with "/norm" or "/dblnorm", whereas this
    # experiment path uses "_norm"/"_dblnorm" -- confirm it matches the index location.
    experiment = f"{dataset}/{config.embedder.type}/{config.embedder.mv_type + suffix}"

    with torch.inference_mode(), Run().context(RunConfig(nranks=1, experiment=experiment)):
        searcher = Searcher(
            checkpoint="./ColBERT/colbertv2.0",
            index=f"nbits={nbits}.aug{config.num_rh_augment}.{i}",
            config=ColBERTConfig(
                nbits=nbits,
                root="./colbert_beir_expts/",
                augment=True,
                dim=2 * emb_dim,
                dbl_norm=config.dbl_norm,
                # The RH file is expected to have been generated by the indexing step.
                RH_file=f"./experiments/{dataset}/RH.{emb_dim}.{i}.pt",
                generate_new_rh=False,
            ),
        )
        return searcher.gen_candidates(qembs, k, prune_candidates)
            
def candidate_worker_for_th(searcher, qembs, k, prune_candidates):
    """Call ``searcher.gen_candidates`` from a worker thread.

    ``torch.inference_mode`` is thread-local, so it is entered here rather than
    inherited from the thread that submitted the job.
    """
    with torch.inference_mode():
        candidates = searcher.gen_candidates(qembs, k, prune_candidates)
    return candidates
  
################################################################################

if __name__ == "__main__":
    import time

    warnings.filterwarnings("ignore", category=FutureWarning)

    config = OmegaConf.load("configs/colbert.yaml")
    cliconfig = OmegaConf.from_cli()

    conf = OmegaConf.merge(config, cliconfig)
    # Some code in this file may read the module-level name `config`. Rebind it to
    # the merged YAML + CLI configuration so command-line overrides reach that code too.
    config = conf

    os.makedirs("logs/colbert", exist_ok=True)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(process)d - %(message)s',
        handlers=[
            logging.FileHandler(f'logs/colbert/{conf.method}_{conf.data.dataset_name}_{conf.retriever.type}_pid:{os.getpid()}.log'),
            logging.StreamHandler()
        ]
    )
    logger.info(conf)

    # IMPORTANT: switching between the plaid and base ColBERT implementations.
    # These imports bind module-level names used by the classes and workers above.
    if conf.embedder.mv_type == "colbertv2-plaid":
        from colbert.infra import Run, RunConfig, ColBERTConfig
        from colbert import Indexer, Searcher
    elif conf.embedder.mv_type == "colbertv2-base":
        from colbert_base.infra import Run, RunConfig, ColBERTConfig
        from colbert_base import Indexer, Searcher
    else:
        logger.warning(f"No ColBERT backend (Run/Searcher) imported for mv_type={conf.embedder.mv_type!r}")

    if conf.method == "muvera_iid":
        assert not conf.augment
        logger.info("Running MUVERA iid on ColBERT embeddings with FAISS")
        obj = MuveraTopK(conf)
    elif conf.method == "muvera_augmented":
        assert conf.augment
        logger.info("Running MUVERA augmented on ColBERT embeddings with FAISS")
        obj = MuveraAugmented(conf)
    elif conf.method == "baseline":
        assert not conf.augment
        logger.info("Running ColBERT with baseline")
        obj = ColBERTBaseline(conf)
    elif conf.method == "latepool":
        assert conf.augment
        logger.info("Running LatePool DISCo")
        obj = LatePoolDISCo_thresholded(conf)
    elif conf.method == "disco":
        assert conf.augment
        logger.info("Running DISCo")
        obj = DISCo(conf)
    else:
        # Previously fell through and failed later with a NameError on `obj`.
        raise ValueError(f"Unknown method: {conf.method}")

    start = time.perf_counter()
    if conf.index:
        print("Starting Indexing")
        obj.index()
    else:
        print("Starting Run")
        obj.run()
    end = time.perf_counter()
    print("Done in ", end - start)