import os
import ujson
import torch
import numpy as np
import tqdm

from colbert.search.index_loader import IndexLoader
from colbert.indexing.index_saver import IndexSaver
from colbert.indexing.collection_encoder import CollectionEncoder

from colbert.utils.utils import lengths2offsets, print_message, dotdict, flatten
from colbert.indexing.codecs.residual import ResidualCodec
from colbert.indexing.utils import optimize_ivf
from colbert.search.strided_tensor import StridedTensor
from colbert.modeling.checkpoint import Checkpoint
from colbert.utils.utils import print_message, batch
from colbert.data import Collection
from colbert.indexing.codecs.residual_embeddings import ResidualEmbeddings
from colbert.indexing.codecs.residual_embeddings_strided import (
    ResidualEmbeddingsStrided,
)
from colbert.indexing.utils import optimize_ivf

# For testing writing into new chunks, can set DEFAULT_CHUNKSIZE smaller (e.g. 1 or 2)
DEFAULT_CHUNKSIZE = 25000


class IndexUpdater:

    """
    IndexUpdater takes in a searcher and adds/removes passages to/from the searcher.
    A checkpoint for passage-encoding must be provided for adding passages.
    IndexUpdater can also persist the change of passages to the index on disk.

    Sample usage:

        index_updater = IndexUpdater(config, searcher, checkpoint)

        added_pids = index_updater.add(passages) # all passages added to searcher with their pids returned
        index_updater.remove(pids) # all pid within pids removed from searcher

        searcher.search() # the search now reflects the added & removed passages

        index_updater.persist_to_disk() # added & removed passages persisted to index on disk
        searcher.Searcher(index, config) # if we reload the searcher now from disk index, the changes we made persist

    """

    def __init__(self, config, searcher, checkpoint=None):
        """
        Input:
            config: configuration passed on to Checkpoint and CollectionEncoder
            searcher: searcher whose ranker state is updated in place
            checkpoint: optional checkpoint used to encode passages;
                        when it is missing (or falsy), add() raises ValueError
        """
        self.config = config
        self.searcher = searcher
        self.index_path = searcher.index

        # The encoder is only constructed when a checkpoint was supplied
        self.has_checkpoint = bool(checkpoint)
        if self.has_checkpoint:
            self.checkpoint = Checkpoint(checkpoint, config)
            self.encoder = CollectionEncoder(config, self.checkpoint)

        self._load_disk_ivf()

        # Bookkeeping for pending changes: pids removed so far, plus the pid and
        # embedding position at which newly added passages start
        initial_doclens = self.searcher.ranker.doclens
        self.removed_pids = []
        self.first_new_pid = len(initial_doclens)
        self.first_new_emb = torch.sum(initial_doclens).item()

    def remove(self, pids):
        """
        Input:
            pids: iterable(int), ids of the passages to remove
        Return: None

        Removes the given pids from the searcher,
        these pids will no longer appear in future searches with this searcher.
        To erase passage data from the index, call persist_to_disk() after calling remove().
        """
        # The pids are consumed twice below (IVF update, then bookkeeping), so
        # materialize them once as plain ints; a one-shot iterable would otherwise
        # be exhausted before it is recorded for persist_to_disk()
        pids = [int(pid) for pid in pids]

        print_message(f"#> Removing pids: {pids}...")
        self._remove_pid_from_ivf(pids)
        self.removed_pids.extend(pids)

    def add(self, passages):
        """
        Input:
            passages: list(string)
        Output:
            passage_ids: list(int)

        Adds new passages to the searcher,
        to add passages to the index, call persist_to_disk() after calling add().
        An empty list of passages changes nothing and returns an empty list.

        Raises ValueError if no checkpoint was provided at initialization.
        """
        if not self.has_checkpoint:
            raise ValueError(
                "No checkpoint was provided at IndexUpdater initialization."
            )

        # Nothing to encode or index for an empty batch
        if len(passages) == 0:
            return []

        ranker = self.searcher.ranker

        # NOTE: For codes and residuals, the tensors end with padding of this length,
        # hence the new embeddings are spliced in front of the padding
        padding = 512

        # Find pid for the first added passage
        start_pid = len(ranker.doclens)

        # Encode and compress the new passages
        embs, doclens = self.encoder.encode_passages(passages)
        compressed_embs = ranker.codec.compress(embs)

        # Extend codes, residuals and doclens of the searcher
        embeddings = ranker.embeddings
        embeddings.codes = torch.cat(
            (
                embeddings.codes[:-padding],
                compressed_embs.codes,
                embeddings.codes[-padding:],
            )
        )
        embeddings.residuals = torch.cat(
            (
                embeddings.residuals[:-padding],
                compressed_embs.residuals,
                embeddings.residuals[-padding:],
            ),
            dim=0,
        )
        ranker.doclens = torch.cat((ranker.doclens, torch.tensor(doclens)))

        # Build partitions for each pid and update IndexUpdater's current ivf
        start = 0
        for curr_pid, doclen in enumerate(doclens, start=start_pid):
            end = start + doclen
            codes = compressed_embs.codes[start:end]
            partitions, _ = self._build_passage_partitions(codes)
            self._add_pid_to_ivf(partitions, curr_pid)
            start = end

        assert start == sum(doclens)

        # Update new ivf in searcher
        new_ivf_tensor = StridedTensor(
            self.curr_ivf, self.curr_ivf_lengths, use_gpu=False
        )
        assert new_ivf_tensor != ranker.ivf
        ranker.ivf = new_ivf_tensor

        # Rebuild StridedTensor within searcher
        ranker.embeddings_strided = ResidualEmbeddingsStrided(
            ranker.codec,
            ranker.embeddings,
            ranker.doclens,
        )

        print_message(f"#> Added {len(passages)} passages from pid {start_pid}.")
        new_pids = list(range(start_pid, start_pid + len(passages)))
        return new_pids

    def persist_to_disk(self):
        """
        Persist all pending changes in IndexUpdater to the index on disk.

        Pending changes are all calls to IndexUpdater.remove() and IndexUpdater.add()
        made before persist_to_disk() is called and not yet written by an earlier
        call. After a successful write the pending state is reset, so calling
        persist_to_disk() again without new changes does not write the same
        passages a second time.
        """

        print_message("#> Persisting index changes to disk")

        ranker = self.searcher.ranker
        first_new_pid = self.first_new_pid
        first_new_emb = self.first_new_emb

        # Passages added since the last persist are not on disk yet, so removing
        # them from disk has to wait until they have been written below
        removed_on_disk = [pid for pid in self.removed_pids if pid < first_new_pid]
        removed_after_write = [
            pid for pid in self.removed_pids if pid >= first_new_pid
        ]

        # Propagate removed passages that already exist on disk
        self._load_metadata()
        for pid in removed_on_disk:
            self._remove_passage_from_disk(pid)

        # Propagate all added passages to disk
        # Rationale: keep record of all added passages in IndexUpdater.searcher,
        # divide passages into chunks and create / write chunks here

        self._load_metadata()  # Reload after removal

        # Calculate avg number of passages per chunk.
        # Integer division keeps every pid bound below an int (float bounds cannot
        # be used to slice); max(1, ...) guarantees the chunking loop advances.
        curr_num_chunks = self.metadata["num_chunks"]
        last_chunk_metadata = self._load_chunk_metadata(curr_num_chunks - 1)
        if curr_num_chunks == 1:
            avg_chunksize = DEFAULT_CHUNKSIZE
        else:
            avg_chunksize = max(
                1, last_chunk_metadata["passage_offset"] // (curr_num_chunks - 1)
            )
        print_message(f"#> Current average chunksize is: {avg_chunksize}.")

        # Calculate number of additional passages we can write to the last chunk
        last_chunk_capacity = max(
            0, avg_chunksize - last_chunk_metadata["num_passages"]
        )
        print_message(
            f"#> The last chunk can hold {last_chunk_capacity} additional passages."
        )

        # Find the first and last passages to be persisted
        pid_start = first_new_pid
        emb_start = first_new_emb
        pid_last = len(ranker.doclens)
        num_new_embs = torch.sum(ranker.doclens[pid_start:]).item()
        emb_last = emb_start + num_new_embs

        # First populate the last chunk (skipped when there is nothing to add)
        if pid_start < pid_last and last_chunk_capacity > 0:
            pid_end = min(pid_last, pid_start + last_chunk_capacity)
            emb_end = emb_start + torch.sum(ranker.doclens[pid_start:pid_end]).item()

            # Write to last chunk
            self._write_to_last_chunk(pid_start, pid_end, emb_start, emb_end)
            pid_start = pid_end
            emb_start = emb_end

        # Then create new chunks to hold the remaining added passages
        while pid_start < pid_last:
            pid_end = min(pid_last, pid_start + avg_chunksize)
            emb_end = emb_start + torch.sum(ranker.doclens[pid_start:pid_end]).item()

            # Write new chunk with id = curr_num_chunks
            self._write_to_new_chunk(
                curr_num_chunks, pid_start, pid_end, emb_start, emb_end
            )

            curr_num_chunks += 1
            pid_start = pid_end
            emb_start = emb_end

        assert pid_start == pid_last
        assert emb_start == emb_last

        # Update metadata.
        # self.metadata was reloaded after the removals above, so its count already
        # excludes removed embeddings; only the newly written ones are added.
        # (The in-memory doclens still include removed passages, so their sum
        # would over-count.)
        print_message("#> Updating metadata for added passages...")
        self.metadata["num_chunks"] = curr_num_chunks
        self.metadata["num_embeddings"] += num_new_embs
        metadata_path = os.path.join(self.index_path, "metadata.json")
        with open(metadata_path, "w") as output_metadata:
            ujson.dump(self.metadata, output_metadata)

        # Save current IVF to disk
        optimized_ivf_path = os.path.join(self.index_path, "ivf.pid.pt")
        torch.save((self.curr_ivf, self.curr_ivf_lengths), optimized_ivf_path)
        print_message(f"#> Persisted updated IVF to {optimized_ivf_path}")

        # Everything added so far is on disk now; later calls start from here
        self.first_new_pid = pid_last
        self.first_new_emb = emb_last

        # Removals of passages that were only just written can be applied now
        for pid in removed_after_write:
            self._remove_passage_from_disk(pid)
        self.removed_pids = []

    # HELPER FUNCTIONS BELOW

    def _load_disk_ivf(self):
        """
        Load the IVF and its per-section lengths from the index directory into
        self.curr_ivf and self.curr_ivf_lengths.

        "ivf.pid.pt" is used when it exists; otherwise "ivf.pt" must exist and
        its contents are passed through optimize_ivf first.
        """
        print_message("#> Loading IVF...")

        optimized_path = os.path.join(self.index_path, "ivf.pid.pt")
        if not os.path.exists(optimized_path):
            raw_path = os.path.join(self.index_path, "ivf.pt")
            assert os.path.exists(raw_path)
            raw_ivf, raw_lengths = torch.load(raw_path, map_location="cpu")
            loaded = optimize_ivf(raw_ivf, raw_lengths, self.index_path)
        else:
            loaded = torch.load(optimized_path, map_location="cpu")

        self.curr_ivf, self.curr_ivf_lengths = loaded

    def _load_metadata(self):
        """Read the index-wide metadata.json from disk into self.metadata."""
        metadata_path = os.path.join(self.index_path, "metadata.json")
        with open(metadata_path) as metadata_file:
            self.metadata = ujson.load(metadata_file)

    def _load_chunk_doclens(self, chunk_idx):
        """
        Input:
            chunk_idx: int, index of the chunk
        Output:
            tensor built from the contents of doclens.{chunk_idx}.json
        """
        print_message("#> Loading doclens...")

        doclens_path = os.path.join(self.index_path, f"doclens.{chunk_idx}.json")
        with open(doclens_path) as doclens_file:
            chunk_doclens = list(ujson.load(doclens_file))

        return torch.tensor(chunk_doclens)

    def _load_chunk_codes(self, chunk_idx):
        """Load {chunk_idx}.codes.pt from the index directory onto the CPU."""
        return torch.load(
            os.path.join(self.index_path, f"{chunk_idx}.codes.pt"),
            map_location="cpu",
        )

    def _load_chunk_residuals(self, chunk_idx):
        """Load {chunk_idx}.residuals.pt from the index directory onto the CPU."""
        return torch.load(
            os.path.join(self.index_path, f"{chunk_idx}.residuals.pt"),
            map_location="cpu",
        )

    def _load_chunk_metadata(self, chunk_idx):
        """Return the parsed contents of {chunk_idx}.metadata.json."""
        chunk_metadata_path = os.path.join(
            self.index_path, f"{chunk_idx}.metadata.json"
        )
        with open(chunk_metadata_path) as metadata_file:
            return ujson.load(metadata_file)

    def _get_chunk_idx(self, pid):
        """
        Input:
            pid: int, passage id
        Output:
            index of the first chunk whose passage range
            [passage_offset, passage_offset + num_passages) contains pid

        Raises ValueError if no chunk contains pid.
        """
        for chunk_idx in range(self.metadata["num_chunks"]):
            meta = self._load_chunk_metadata(chunk_idx)
            first_pid = meta["passage_offset"]
            if first_pid <= pid < first_pid + meta["num_passages"]:
                return chunk_idx

        raise ValueError("Passage ID out of range")

    def _remove_pid_from_ivf(self, pids):
        """
        Helper function for IndexUpdater.remove()

        Input:
            pids: iterable(int), passage ids to drop from the ivf
        Output: None

        Drops every occurrence of the given pids from the current ivf, shrinks
        the per-section lengths accordingly, and installs the resulting ivf in
        the searcher.
        """
        # Explicit integer dtype so that an empty pid set does not turn into a
        # float tensor
        pids_tensor = torch.tensor(list(set(pids)), dtype=torch.long)

        # Construct mask of where pids to be removed appear in ivf
        mask = torch.isin(self.curr_ivf, pids_tensor)

        # removed_before[k] = number of masked entries among the first k ivf entries
        removed_before = torch.cat(
            (torch.zeros(1, dtype=torch.long), torch.cumsum(mask.long(), dim=0))
        )

        # End-indices (exclusive) of each centroid section in ivf
        section_ends = torch.cumsum(self.curr_ivf_lengths.long(), dim=0)

        # Number of pids removed from each section: the difference of the running
        # count between consecutive section ends
        removed_len = torch.diff(
            removed_before[section_ends],
            prepend=torch.zeros(1, dtype=torch.long),
        )

        # Update changes
        new_ivf = torch.masked_select(self.curr_ivf, ~mask)
        new_ivf_lengths = self.curr_ivf_lengths - removed_len

        new_ivf_tensor = StridedTensor(new_ivf, new_ivf_lengths, use_gpu=False)
        assert new_ivf_tensor != self.searcher.ranker.ivf
        self.searcher.ranker.ivf = new_ivf_tensor

        self.curr_ivf = new_ivf
        self.curr_ivf_lengths = new_ivf_lengths

    def _build_passage_partitions(self, codes):
        """
        Helper function for IndexUpdater.add()

        Input:
            codes: 1-D tensor of centroid ids, one per embedding of a passage
        Output:
            (partitions, counts): the distinct centroid ids in ascending order,
            and how many times each one occurs in codes
        """
        sorted_codes = torch.sort(codes).values
        partitions, counts = torch.unique_consecutive(
            sorted_codes, return_counts=True
        )
        return partitions, counts

    def _add_pid_to_ivf(self, partitions, pid):
        """
        Helper function for IndexUpdater.add()

        Input:
            partitions: ascending centroid ids of the passage
            pid: int, passage id
        Output: None

        Rebuilds the current ivf with pid appended to the end of every
        section whose centroid id appears in partitions.
        """
        ivf_entries = self.curr_ivf.tolist()
        section_sizes = self.curr_ivf_lengths.tolist()

        merged_ivf = []
        merged_sizes = []
        next_partition = 0  # position of the next centroid id still to be matched
        cursor = 0  # start of the current section within ivf_entries

        for centroid_id, size in enumerate(section_sizes):
            # Carry over the pids already stored in this section
            section_end = cursor + size
            merged_ivf.extend(ivf_entries[cursor:section_end])
            cursor = section_end

            # Append the new pid when this section is one of the passage's partitions
            if (
                next_partition < len(partitions)
                and centroid_id == partitions[next_partition]
            ):
                merged_ivf.append(pid)
                size += 1
                next_partition += 1

            merged_sizes.append(size)

        assert cursor == len(ivf_entries)
        assert sum(merged_sizes) == len(merged_ivf)

        # Replace the current ivf with the merged one
        self.curr_ivf = torch.tensor(merged_ivf)
        self.curr_ivf_lengths = torch.tensor(merged_sizes)

    def _write_to_last_chunk(self, pid_start, pid_end, emb_start, emb_end):
        """
        Helper function for IndexUpdater.persist_to_disk()

        Input:
            pid_start, pid_end: range of in-memory passage ids to append
            emb_start, emb_end: range of in-memory embedding positions to append
        Output: None

        Appends the given passages to the last chunk on disk and updates that
        chunk's embeddings, doclens and metadata files.
        """
        num_added_passages = pid_end - pid_start
        print_message(f"#> Writing {num_added_passages} passages to the last chunk...")

        last_idx = self.metadata["num_chunks"] - 1
        ranker = self.searcher.ranker

        # Append the new codes and residuals to the stored embeddings
        last_embs = ResidualEmbeddings.load(self.index_path, last_idx)
        added_codes = ranker.embeddings.codes[emb_start:emb_end]
        added_residuals = ranker.embeddings.residuals[emb_start:emb_end]
        last_embs.codes = torch.cat((last_embs.codes, added_codes))
        last_embs.residuals = torch.cat((last_embs.residuals, added_residuals))
        last_embs.save(os.path.join(self.index_path, f"{last_idx}"))

        # Append the new doclens to the stored doclens
        merged_doclens = self._load_chunk_doclens(last_idx).tolist()
        merged_doclens += ranker.doclens.tolist()[pid_start:pid_end]
        doclens_path = os.path.join(self.index_path, f"doclens.{last_idx}.json")
        with open(doclens_path, "w") as output_doclens:
            ujson.dump(merged_doclens, output_doclens)

        # Bump the passage and embedding counts in the chunk metadata
        chunk_metadata = self._load_chunk_metadata(last_idx)
        chunk_metadata["num_passages"] += num_added_passages
        chunk_metadata["num_embeddings"] += emb_end - emb_start
        chunk_metadata_path = os.path.join(
            self.index_path, f"{last_idx}.metadata.json"
        )
        with open(chunk_metadata_path, "w") as output_chunk_metadata:
            ujson.dump(chunk_metadata, output_chunk_metadata)

    def _write_to_new_chunk(self, chunk_idx, pid_start, pid_end, emb_start, emb_end):
        """
        Helper function for IndexUpdater.persist_to_disk()

        Input:
            chunk_idx: int, index of the chunk to create
            pid_start, pid_end: range of in-memory passage ids to write
            emb_start, emb_end: range of in-memory embedding positions to write
        Output: None

        Creates the embeddings, doclens and metadata files of a new chunk.
        """
        # Save embeddings to new chunk.
        # The slices are views into the full in-memory tensors; clone them so that
        # only this chunk's rows are handed to save().
        # NOTE(review): presumably save() uses torch.save, which serializes the
        # whole storage backing a view - confirm.
        embeddings = self.searcher.ranker.embeddings
        curr_embs = ResidualEmbeddings(
            embeddings.codes[emb_start:emb_end].clone(),
            embeddings.residuals[emb_start:emb_end].clone(),
        )
        path_prefix = os.path.join(self.index_path, f"{chunk_idx}")
        curr_embs.save(path_prefix)

        # Create doclen json file for new chunk
        curr_doclens = self.searcher.ranker.doclens.tolist()[pid_start:pid_end]
        doclens_path = os.path.join(self.index_path, f"doclens.{chunk_idx}.json")
        with open(doclens_path, "w+") as output_doclens:
            ujson.dump(curr_doclens, output_doclens)

        # On-disk embedding offsets shrink when passages are removed from disk,
        # while emb_start indexes the in-memory tensors, which still contain the
        # removed embeddings. Continue from where the previous chunk ends on disk
        # so the offsets of consecutive chunks stay contiguous.
        if chunk_idx > 0:
            prev_metadata = self._load_chunk_metadata(chunk_idx - 1)
            embedding_offset = (
                prev_metadata["embedding_offset"] + prev_metadata["num_embeddings"]
            )
        else:
            embedding_offset = emb_start

        # Create metadata json file for new chunk
        chunk_metadata = {
            "passage_offset": pid_start,
            "num_passages": pid_end - pid_start,
            "embedding_offset": embedding_offset,
            "num_embeddings": emb_end - emb_start,
        }
        chunk_metadata_path = os.path.join(
            self.index_path, f"{chunk_idx}.metadata.json"
        )
        with open(chunk_metadata_path, "w+") as output_chunk_metadata:
            ujson.dump(chunk_metadata, output_chunk_metadata)

    def _remove_passage_from_disk(self, pid):
        """
        Helper function for IndexUpdater.persist_to_disk()

        Input:
            pid: int, passage id
        Output: None

        Deletes the embeddings of the passage from its chunk on disk, sets its
        doclen to 0, and adjusts the embedding counts and offsets in the chunk
        metadata files and in the overall metadata.

        Raises ValueError (from _get_chunk_idx) if no chunk contains pid.
        """
        chunk_idx = self._get_chunk_idx(pid)

        chunk_metadata = self._load_chunk_metadata(chunk_idx)
        i = pid - chunk_metadata["passage_offset"]

        # Work on plain ints so the offsets below are computed without a
        # Python-level loop over tensor elements
        doclens = self._load_chunk_doclens(chunk_idx).tolist()
        doclen_to_remove = doclens[i]

        # A passage with no embeddings (e.g. one removed earlier) leaves every
        # file unchanged, so skip the rewrites
        if doclen_to_remove == 0:
            return

        # Remove embeddings from codes and residuals
        start = sum(doclens[:i])
        end = start + doclen_to_remove
        codes = self._load_chunk_codes(chunk_idx)
        residuals = self._load_chunk_residuals(chunk_idx)
        codes = torch.cat((codes[:start], codes[end:]))
        residuals = torch.cat((residuals[:start], residuals[end:]))

        codes_path = os.path.join(self.index_path, f"{chunk_idx}.codes.pt")
        residuals_path = os.path.join(self.index_path, f"{chunk_idx}.residuals.pt")

        torch.save(codes, codes_path)
        torch.save(residuals, residuals_path)

        # Change doclen for passage to 0
        doclens[i] = 0
        doclens_path = os.path.join(self.index_path, f"doclens.{chunk_idx}.json")
        with open(doclens_path, "w") as output_doclens:
            ujson.dump(doclens, output_doclens)

        # Modify chunk_metadata['num_embeddings'] for chunk_idx
        chunk_metadata["num_embeddings"] -= doclen_to_remove
        chunk_metadata_path = os.path.join(
            self.index_path, f"{chunk_idx}.metadata.json"
        )
        with open(chunk_metadata_path, "w") as output_chunk_metadata:
            ujson.dump(chunk_metadata, output_chunk_metadata)

        # Modify chunk_metadata['embedding_offset'] for all later chunks (minus num_embs_removed)
        for idx in range(chunk_idx + 1, self.metadata["num_chunks"]):
            metadata = self._load_chunk_metadata(idx)
            metadata["embedding_offset"] -= doclen_to_remove
            metadata_path = os.path.join(self.index_path, f"{idx}.metadata.json")
            with open(metadata_path, "w") as output_chunk_metadata:
                ujson.dump(metadata, output_chunk_metadata)

        # Modify num_embeddings in overall metadata (minus num_embs_removed)
        self.metadata["num_embeddings"] -= doclen_to_remove
        metadata_path = os.path.join(self.index_path, "metadata.json")
        with open(metadata_path, "w") as output_metadata:
            ujson.dump(self.metadata, output_metadata)
