from abc import ABC, abstractmethod
import argparse
import numpy as np
from typing import Tuple
from tqdm import tqdm
from sklearn.decomposition import PCA
import time, wandb
import random
import os
from src.embeddings.memmap_dataset import MultiMemmapDatasetLoader
from src.models.nonlinear_mapping import nonlinear_mapping
from src.models.procrustes_mapping import procrustes_mapping_torch, procrustes_no_norm_scale_with_param, procrustes_pca_mapping
from loguru import logger
class VectorMapper(ABC):
    """Abstract interface for mappers that transform embedding vectors.

    Subclasses must implement :meth:`fit` and :meth:`transform`. Streaming
    training (:meth:`fit_multi`) is optional, and :meth:`stream_transform`
    provides a generic batch-wise transform that writes its result to disk.
    """

    def __init__(self):
        pass

    @abstractmethod
    def fit(self, source_embeddings: np.ndarray, target_embeddings: np.ndarray,
              reference_indices: np.ndarray) -> None:
        """Fit the mapper from in-memory arrays.

        NOTE(review): ``reference_indices`` presumably selects the paired rows
        used for fitting -- confirm against the concrete subclasses.
        """
        pass

    @abstractmethod
    def transform(self, embeddings: np.ndarray) -> np.ndarray:
        """Return the mapped version of ``embeddings``."""
        pass

    def fit_multi(self, multi_dataloader: MultiMemmapDatasetLoader) -> None:
        """Fit from a streaming loader; unsupported unless a subclass overrides it.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support streaming training via fit_multi(). "
            f"Either implement fit_multi() in the subclass or use fit() with numpy arrays."
        )

    def stream_transform(
        self, 
        dataloader: MultiMemmapDatasetLoader, 
        cache_path: str,
        save_expert_ids: bool = False,
        use_majority_expert: bool = False
    ) -> None:
        """Transform every batch from ``dataloader`` and save the result as ``.npy``.

        Batches are written into a temporary memmap (``cache_path + '.tmp'``)
        sized ``(dataloader.total_samples, output_dim)``, which is then
        converted with ``np.save`` and removed.

        Args:
            dataloader: Iterable yielding ``(src_batch, tgt_batch)`` pairs and
                exposing ``total_samples``. ``src_batch`` must support
                ``.cpu().numpy()``.
            cache_path: Destination of the transformed embeddings.
            save_expert_ids: Also save per-sample expert IDs next to
                ``cache_path`` (suffix ``_expert_ids.npy``). Only honoured when
                the mapper defines ``get_expert_ids``.
            use_majority_expert: Route every sample through the single most
                frequently selected expert. Only honoured when the mapper
                defines ``transform_with_majority_expert``; that attribute is
                used purely as a capability probe here, the transform itself
                goes through ``self.experts[...]``.
        """
        logger.info(f"Stream transforming embeddings to {cache_path}")
        temp_file = cache_path + '.tmp'
        output_memmap = None
        expert_ids_memmap = None
        expert_ids_file = None
        expert_ids_temp_file = None
        processed_samples = 0
        total_samples = dataloader.total_samples
        output_dim = None
        output_dtype = None
        supports_expert_ids = save_expert_ids and hasattr(self, 'get_expert_ids')
        supports_majority_expert = use_majority_expert and hasattr(self, 'transform_with_majority_expert')
        if supports_expert_ids:
            # Derive the sibling file from the trailing extension only. A plain
            # str.replace leaves the name unchanged when '.npy' is absent, and
            # np.save would then write the IDs over the embeddings file.
            stem = cache_path[:-len('.npy')] if cache_path.endswith('.npy') else cache_path
            expert_ids_file = stem + '_expert_ids.npy'
            expert_ids_temp_file = expert_ids_file + '.tmp'
            logger.info(f"Will save expert IDs to {expert_ids_file}")
        majority_expert_id = None
        if supports_majority_expert:
            # First pass: collect the naturally selected expert of every sample.
            logger.info("=" * 60)
            logger.info("[MAJORITY EXPERT MODE] Pre-analyzing data to find majority expert...")
            logger.info("=" * 60)
            all_expert_ids = []
            for src_batch, _ in dataloader:
                src_np = src_batch.cpu().numpy()
                batch_expert_ids = self.get_expert_ids(src_np)
                all_expert_ids.append(batch_expert_ids)
            all_expert_ids = np.concatenate(all_expert_ids, axis=0)
            expert_counts = np.bincount(all_expert_ids, minlength=self.num_experts)
            majority_expert_id = int(np.argmax(expert_counts))
            majority_count = expert_counts[majority_expert_id]
            majority_ratio = majority_count / len(all_expert_ids)
            logger.info("\n📊 Expert Distribution Analysis:")
            logger.info(f"{'Expert ID':<12} {'Count':<15} {'Percentage':<12}")
            logger.info("-" * 45)
            for exp_id in range(self.num_experts):
                count = expert_counts[exp_id]
                percentage = count / len(all_expert_ids) * 100
                marker = " ⭐ SELECTED" if exp_id == majority_expert_id else ""
                logger.info(f"{exp_id:<12} {count:<15} {percentage:>6.2f}%{marker}")
            logger.info("-" * 45)
            logger.info(f"✓ Majority Expert ID: {majority_expert_id}")
            logger.info(f"✓ Coverage: {majority_count}/{len(all_expert_ids)} ({majority_ratio*100:.2f}%)")
            logger.info(f"✓ Will override: {len(all_expert_ids) - majority_count} samples ({(1-majority_ratio)*100:.2f}%)")
            logger.info("=" * 60)
            if supports_expert_ids:
                # In majority mode the IDs are already complete in memory.
                expert_ids_memmap = all_expert_ids
        for src_batch, _ in dataloader:
            src_np = src_batch.cpu().numpy()
            batch_expert_ids = None
            if supports_majority_expert:
                if majority_expert_id in self.experts:
                    transformed_batch = self.experts[majority_expert_id].transform(src_np)
                else:
                    logger.warning(f"Majority expert {majority_expert_id} not found, using expert 0")
                    transformed_batch = self.experts[0].transform(src_np)
            else:
                transformed_batch = self.transform(src_np)
                if supports_expert_ids:
                    # Expert tracking is best-effort. Only the ID lookup is
                    # guarded, so a failing transform() surfaces directly
                    # instead of being reported as an expert-ID problem.
                    try:
                        batch_expert_ids = self.get_expert_ids(src_np)
                    except Exception as e:
                        logger.warning(f"Failed to get expert IDs: {e}, continuing without expert tracking")
                        supports_expert_ids = False
                        if expert_ids_memmap is not None:
                            # Drop the partially written IDs so the temp file
                            # is not left orphaned on disk.
                            expert_ids_memmap = None
                            if os.path.exists(expert_ids_temp_file):
                                os.remove(expert_ids_temp_file)
            if output_memmap is None:
                # Output shape and dtype are only known after the first batch.
                output_dim = transformed_batch.shape[1]
                output_dtype = transformed_batch.dtype
                output_memmap = np.memmap(
                    temp_file,
                    dtype=output_dtype,
                    mode='w+',
                    shape=(total_samples, output_dim)
                )
                logger.info(f"Created temporary memmap file: shape=({total_samples}, {output_dim}), dtype={output_dtype}")
                if supports_expert_ids and not supports_majority_expert:
                    expert_ids_memmap = np.memmap(
                        expert_ids_temp_file,
                        dtype=np.int32,
                        mode='w+',
                        shape=(total_samples,)
                    )
                    logger.info(f"Created expert IDs memmap file: shape=({total_samples},), dtype=int32")
            batch_size = transformed_batch.shape[0]
            # Clamp so a final over-sized batch cannot write past the memmap.
            end_idx = min(processed_samples + batch_size, total_samples)
            if end_idx > processed_samples:
                output_memmap[processed_samples:end_idx] = transformed_batch[:end_idx - processed_samples]
                if supports_expert_ids and not supports_majority_expert:
                    expert_ids_memmap[processed_samples:end_idx] = batch_expert_ids[:end_idx - processed_samples]
                processed_samples = end_idx
            if processed_samples >= total_samples:
                break
        if output_memmap is not None and processed_samples < total_samples:
            # The memmap was pre-sized to total_samples; unwritten rows stay zero.
            logger.warning(
                f"Dataloader yielded only {processed_samples} of {total_samples} samples; "
                f"the remaining rows of the output are left as zeros"
            )
        if output_memmap is not None:
            output_memmap.flush()
            del output_memmap
            logger.info(f"Converting temporary memmap to standard .npy format...")
            temp_memmap = np.memmap(temp_file, mode='r', shape=(total_samples, output_dim), dtype=output_dtype)
            np.save(cache_path, temp_memmap)
            if os.path.exists(temp_file):
                os.remove(temp_file)
            logger.info(f"✓ Saved as standard .npy file: {cache_path}, shape=({total_samples}, {output_dim})")
        if supports_expert_ids and expert_ids_memmap is not None:
            if supports_majority_expert:
                logger.info(f"Saving natural expert IDs (before majority override)...")
                np.save(expert_ids_file, expert_ids_memmap)
                logger.info(f"✓ Saved natural expert IDs: {expert_ids_file}, shape=({len(expert_ids_memmap)},)")
            else:
                expert_ids_memmap.flush()
                del expert_ids_memmap
                logger.info(f"Converting expert IDs memmap to standard .npy format...")
                expert_ids_temp_memmap = np.memmap(expert_ids_temp_file, mode='r', shape=(total_samples,), dtype=np.int32)
                np.save(expert_ids_file, expert_ids_temp_memmap)
                if os.path.exists(expert_ids_temp_file):
                    os.remove(expert_ids_temp_file)
                logger.info(f"✓ Saved expert IDs: {expert_ids_file}, shape=({total_samples},)")
        if supports_majority_expert:
            logger.info("=" * 60)
            logger.info(f"✓ Majority Expert Transform completed")
            logger.info(f"  All {processed_samples} samples transformed using Expert {majority_expert_id}")
            logger.info("=" * 60)
        logger.info(f"Stream transform completed. Saved {processed_samples} samples to {cache_path}")
class BaseMapper:
    """Shared state and helpers for the cluster-wise mapping strategies.

    Stores the two corpus embedding matrices, the run arguments and the
    optional PCA objects. Concrete strategies override :meth:`map`.
    """

    def __init__(self, corpus_emb_1: np.ndarray, corpus_emb_2: np.ndarray, args, pca_corpus_emb_1=None, pca_corpus_emb_2=None):
        self.corpus_emb_1, self.corpus_emb_2 = corpus_emb_1, corpus_emb_2
        self.args = args
        self.pca_corpus_emb_1, self.pca_corpus_emb_2 = pca_corpus_emb_1, pca_corpus_emb_2

    def map(self, overlap_ids, bound_ids):
        """Map the rows selected by ``bound_ids``; must be provided by subclasses."""
        raise NotImplementedError("This method should be overridden by subclasses")

    def calculate_beta(self, transformed_candidate, bound_ids):
        """Return the per-row L2 distance to ``corpus_emb_2[bound_ids]``."""
        residual = transformed_candidate - self.corpus_emb_2[bound_ids]
        return np.linalg.norm(residual, axis=1)

    def calculate_alpha(self, overlap_ids, alpha):
        """Return a corpus-length vector holding ``alpha`` at ``overlap_ids``, zero elsewhere."""
        row_count = self.corpus_emb_1.shape[0]
        per_row_alpha = np.zeros(row_count)
        per_row_alpha[overlap_ids] = alpha
        return per_row_alpha

    def check_replacement(self, all_replaced_ids, D1):
        """Raise ``ValueError`` unless the replaced IDs are exactly the IDs in ``D1``."""
        if set(all_replaced_ids) == set(D1):
            return
        raise ValueError("Not all rows of corpus_emb_1_transformed have been replaced.")
class NonlinearMapper(BaseMapper):
    """Mapping strategy that delegates to ``nonlinear_mapping``."""

    def map(self, overlap_ids, bound_ids):
        """Return whatever ``nonlinear_mapping`` yields for this cluster."""
        bound_source = self.corpus_emb_1[bound_ids]
        bound_target = self.corpus_emb_2[bound_ids]
        return nonlinear_mapping(
            self.corpus_emb_1,
            self.corpus_emb_2,
            overlap_ids,
            bound_source,
            bound_target,
        )
class CombMapper(BaseMapper):
    """Identity strategy: bound rows of the first corpus are returned unchanged."""

    def map(self, overlap_ids, bound_ids):
        """Return ``(corpus_emb_1[bound_ids], None)``; ``overlap_ids`` is ignored."""
        untouched_rows = self.corpus_emb_1[bound_ids]
        return untouched_rows, None
class ProcrustesMapper(BaseMapper):
    """Mapping strategy that delegates to ``procrustes_mapping_torch``."""

    def map(self, overlap_ids, bound_ids):
        """Return ``(procrustes_mapping_torch(...), None)`` for this cluster.

        NOTE(review): ``OursMapper`` unpacks the same helper's result as a
        2-tuple and passes the keyword as ``apporximate``; here the raw result
        is wrapped in another tuple and the keyword is ``approximate``. One of
        the two call sites is presumably wrong -- confirm against the helper's
        signature and return type.
        """
        options = self.args
        solver_result = procrustes_mapping_torch(
            self.corpus_emb_1,
            self.corpus_emb_2,
            overlap_ids,
            self.corpus_emb_1[bound_ids],
            self.corpus_emb_2[bound_ids],
            approximate=options.approximate,
            q=options.q,
            with_rotation=options.with_rotation,
        )
        return solver_result, None
class TranslationOnlyMapper(BaseMapper):
    """Strategy that shifts bound rows by the offset of a single reference row."""

    def map(self, overlap_ids, bound_ids):
        """Translate ``corpus_emb_1[bound_ids]`` by the offset measured at ``overlap_ids[0]``.

        The offset is ``corpus_emb_2[ref] - corpus_emb_1[ref]`` where ``ref`` is
        the first overlap ID; the assertion message states that this is the
        nearest reference only under the "ours" cluster method.
        """
        assert self.args.cluster_method == "ours", "only for our cluster method, the first overlap embedding is the nearest to the bound embedding"
        anchor = overlap_ids[0]
        offset = self.corpus_emb_2[anchor] - self.corpus_emb_1[anchor]
        return self.corpus_emb_1[bound_ids] + offset, None
class OursMapper(BaseMapper):
    """Procrustes-based strategy with optional "outer" or "inner" PCA reduction."""

    def map(self, overlap_ids, bound_ids):
        """Map the bound rows with the Procrustes variant selected by ``self.args``.

        * ``procrustes_pca_type == "outer"`` with ``reduced_dim > 0``: both
          corpora are projected with the pre-fitted PCA objects first and the
          result is mapped back with ``pca_corpus_emb_2.inverse_transform``.
        * ``procrustes_pca_type == "inner"``: ``procrustes_pca_mapping`` is used
          and receives ``reduced_dim``.
        * otherwise ``use_norm`` picks ``procrustes_mapping_torch`` or
          ``procrustes_no_norm_scale_with_param``.

        Returns:
            ``(transformed, None)``.
        """
        options = self.args
        # The keyword spelling "apporximate" is kept exactly as the helpers are called.
        solver_kwargs = {
            "apporximate": getattr(options, 'approximate', True),
            "q": getattr(options, 'q', 1500),
            "with_rotation": getattr(options, 'with_rotation', True),
        }

        uses_outer_pca = options.procrustes_pca_type == "outer" and options.reduced_dim > 0
        if uses_outer_pca:
            source = self.pca_corpus_emb_1.transform(self.corpus_emb_1)
            target = self.pca_corpus_emb_2.transform(self.corpus_emb_2)
        else:
            source, target = self.corpus_emb_1, self.corpus_emb_2

        if options.procrustes_pca_type == "inner":
            solver = procrustes_pca_mapping
            solver_kwargs["reduced_dim"] = options.reduced_dim
        elif options.use_norm:
            solver = procrustes_mapping_torch
        else:
            solver = procrustes_no_norm_scale_with_param

        transformed, _ = solver(
            source,
            target,
            overlap_ids,
            source[bound_ids],
            target[bound_ids],
            **solver_kwargs,
        )

        if uses_outer_pca:
            # Bring the result back from the reduced space of the second corpus.
            transformed = self.pca_corpus_emb_2.inverse_transform(transformed)
        return transformed, None
class OursOfflineMapper(BaseMapper):
    """Procrustes strategy whose reference sets and parameters are precomputed.

    On construction a FAISS L2 index is built over ``corpus_emb_1[d0]``, and
    two on-disk ``SqliteDict`` stores are filled: one with the nearest
    reference set of every ID in ``d0``, one with the mapping parameters
    computed for it. Call :meth:`close` to release the stores and delete the
    temporary directory.
    """

    def __init__(self, corpus_emb_1, corpus_emb_2, args, pca_corpus_emb_1=None, pca_corpus_emb_2=None, d0=None, d1=None, d2=None):
        super().__init__(corpus_emb_1, corpus_emb_2, args, pca_corpus_emb_1, pca_corpus_emb_2)
        # Imported here so the module loads without these packages when this
        # strategy is not used. sqlitedict was previously never imported, so
        # constructing this class failed with NameError.
        import faiss
        import sqlitedict
        self.d0, self.d1, self.d2 = d0, d1, d2
        self.index = faiss.IndexFlatL2(self.corpus_emb_1.shape[1])
        self.index.add(self.corpus_emb_1[d0])
        # Despite the ".sqlite" suffix this is a directory holding both stores.
        temp_sqlitedict_path = f'sqlitedict_{random.randint(0, 1000000)}.sqlite'
        os.makedirs(temp_sqlitedict_path, exist_ok=True)
        self.temp_sqlitedict_path = temp_sqlitedict_path
        self.index2ref_set = sqlitedict.SqliteDict(os.path.join(temp_sqlitedict_path, 'index2ref_set.sqlite'), autocommit=True)
        self.index2params = sqlitedict.SqliteDict(os.path.join(temp_sqlitedict_path, 'index2params.sqlite'), autocommit=True)
        self.precalculate_cluster(overlap_ids=d0, k=args.reduced_dim, index=self.index)
        self.precalculate_map_function(self.index2ref_set)

    def map(self, overlap_ids, bound_ids):
        """Map each bound row using the parameters stored for its nearest ``d0`` entry.

        ``overlap_ids`` is ignored; the reference set comes from the store.

        Returns:
            ``(transformed, None)`` where ``transformed`` has one row per
            entry of ``bound_ids``.
        """
        # Allocate only the rows that are returned rather than a corpus-sized
        # buffer on every call.
        transformed_rows = np.zeros(
            (len(bound_ids),) + self.corpus_emb_1.shape[1:],
            dtype=self.corpus_emb_1.dtype,
        )
        for row, bound_id in enumerate(bound_ids):
            _, nearest_index = self.index.search(self.corpus_emb_1[bound_id].reshape(1, -1), k=1)
            # The FAISS position indexes into d0; the stores are keyed by the ID as text.
            store_key = str(self.d0[nearest_index.flatten()[0]])
            nearest_ref_set = self.index2ref_set[store_key]
            params = self.index2params[store_key]
            transformed_rows[row], _ = procrustes_no_norm_scale_with_param(
                self.corpus_emb_1, 
                self.corpus_emb_2, 
                nearest_ref_set, 
                self.corpus_emb_1[bound_id], 
                self.corpus_emb_2[bound_id], 
                params=params
            )
        return transformed_rows, None

    def precalculate_cluster(self, overlap_ids, k: int=10, index=None):
        """Store, for every overlap ID, the IDs of its ``k`` nearest overlap rows."""
        _, nearest_index = index.search(self.corpus_emb_1[overlap_ids], k)
        for i, single_overlap_id in enumerate(overlap_ids):
            # nearest_index holds positions within overlap_ids; translate to corpus IDs.
            self.index2ref_set[int(single_overlap_id)] = list(overlap_ids[nearest_index[i]])

    def precalculate_map_function(self, index2ref_set):
        """Compute and store mapping parameters for every key of ``index2ref_set``.

        NOTE(review): the stored ``ref_set`` is not used; the parameters are
        fitted with ``[single_overlap_id]`` as the only reference and the full
        corpora as the rows to map. Confirm this is intended.
        """
        for single_overlap_id, ref_set in tqdm(index2ref_set.items(), desc="Precalculating mapping function", total=len(index2ref_set)):
            # Keys may come back from the store as text.
            if not isinstance(single_overlap_id, int):
                single_overlap_id = int(single_overlap_id)
            _, params = procrustes_no_norm_scale_with_param(
                self.corpus_emb_1, 
                self.corpus_emb_2, 
                [single_overlap_id], 
                self.corpus_emb_1, 
                self.corpus_emb_2, 
            )
            self.index2params[single_overlap_id] = params
            self.index2params.commit()

    def close(self):
        """Close both stores and delete the temporary directory that holds them."""
        import shutil
        self.index2ref_set.close()
        self.index2params.close()
        # The path is a directory, which os.remove cannot delete.
        if os.path.isdir(self.temp_sqlitedict_path):
            shutil.rmtree(self.temp_sqlitedict_path)
def get_mapper(corpus_emb_1, corpus_emb_2, args, pca_corpus_emb_1=None, pca_corpus_emb_2=None, d0=None, d1=None, d2=None):
    """Build the mapper selected by ``args.mapping_method``.

    The PCA objects are forwarded only to the "ours", "ours_mlp",
    "ours_offline" and "diffusion" strategies; ``d0``/``d1``/``d2`` only to
    "ours_offline".

    Raises:
        ValueError: if ``args.mapping_method`` names no known strategy.
    """
    method = args.mapping_method

    if method == "nonlinear":
        return NonlinearMapper(corpus_emb_1, corpus_emb_2, args)
    if method == "comb":
        return CombMapper(corpus_emb_1, corpus_emb_2, args)
    if method == "procrustes":
        return ProcrustesMapper(corpus_emb_1, corpus_emb_2, args)
    if method in ("ours", "ours_mlp"):
        return OursMapper(corpus_emb_1, corpus_emb_2, args, pca_corpus_emb_1, pca_corpus_emb_2)
    if method == "translation_only":
        return TranslationOnlyMapper(corpus_emb_1, corpus_emb_2, args)
    if method == "ours_offline":
        return OursOfflineMapper(corpus_emb_1, corpus_emb_2, args, pca_corpus_emb_1, pca_corpus_emb_2, d0, d1, d2)
    if method == "diffusion":
        # Imported on demand so the other strategies do not need this module.
        from .diffusion_mapper import DiffusionBaseMapper
        return DiffusionBaseMapper(corpus_emb_1, corpus_emb_2, args, pca_corpus_emb_1, pca_corpus_emb_2)

    raise ValueError(f"Invalid mapping method: {args.mapping_method}")
def initialize_pca(corpus_emb_1, corpus_emb_2, args):
    """Fit one PCA per corpus when "outer" PCA reduction is requested.

    Returns:
        ``(pca_1, pca_2)`` fitted on the respective corpus when
        ``args.procrustes_pca_type == "outer"`` and ``args.reduced_dim > 0``;
        ``(None, None)`` otherwise. The number of components is capped at the
        embedding width of each corpus.
    """
    outer_pca_requested = args.procrustes_pca_type == "outer" and args.reduced_dim > 0
    if not outer_pca_requested:
        return None, None

    corpora = (corpus_emb_1, corpus_emb_2)
    models = [
        PCA(n_components=min(args.reduced_dim, corpus.shape[1]), svd_solver='randomized')
        for corpus in corpora
    ]
    for model, corpus in zip(models, corpora):
        model.fit(corpus)
    return models[0], models[1]
def mapping_by_ref_list(corpus_emb_1, corpus_emb_2, query_emb_1, query_emb_2, ref_set_list, args, d0, d1, d2, approximate=False, q=None, only_one_cluster=False):
    """Map ``corpus_emb_1`` cluster by cluster with the mapper chosen by ``args``.

    Each item of ``ref_set_list`` must expose ``ref_index`` (reference IDs)
    and ``bound_index`` (IDs of the rows to map); clusters with no bound rows
    are skipped. ``query_emb_1``, ``query_emb_2``, ``approximate``, ``q`` and
    ``only_one_cluster`` are accepted but not used by this function.

    Returns:
        A 6-tuple ``(corpus_emb_1_transformed, 0, 0, beta_array,
        zeros_like(beta_array), alpha_array)``; the two ``0`` entries and the
        zero array are fixed placeholders.

    Raises:
        ValueError: via ``mapper.check_replacement`` when the mapped IDs do
            not equal ``d1``.
    """
    pca_corpus_emb_1, pca_corpus_emb_2 = initialize_pca(corpus_emb_1, corpus_emb_2, args)
    mapper = get_mapper(corpus_emb_1, corpus_emb_2, args, pca_corpus_emb_1, pca_corpus_emb_2, d0, d1, d2)
    corpus_emb_1_transformed = np.zeros_like(corpus_emb_1)
    beta_array = np.zeros(corpus_emb_1.shape[0])
    alpha_array = np.zeros(corpus_emb_1.shape[0])
    all_replaced_ids = set()
    start_time = time.time()
    for cluster_data in tqdm(ref_set_list, desc="Mapping by reference list", total=len(ref_set_list)):
        overlap_ids = cluster_data.ref_index
        bound_ids = cluster_data.bound_index
        if len(bound_ids) == 0:
            continue
        transformed_candidate, alpha = mapper.map(overlap_ids, bound_ids)
        corpus_emb_1_transformed[bound_ids] = transformed_candidate
        beta_array[bound_ids] = mapper.calculate_beta(transformed_candidate, bound_ids)
        alpha_array += mapper.calculate_alpha(overlap_ids, alpha)
        all_replaced_ids.update(bound_ids)
    elapsed = time.time() - start_time
    logger.info(f"Mapping by reference list time: {elapsed} seconds")
    # wandb.log raises when no run has been initialised; skip the metric
    # rather than lose the finished mapping to a logging error.
    if wandb.run is not None:
        wandb.log({"mapping_time": elapsed})
    mapper.check_replacement(all_replaced_ids, d1)
    return corpus_emb_1_transformed, 0, 0, beta_array, np.zeros_like(beta_array), alpha_array
