import os
import time
import hashlib
import warnings
from typing import Dict, List, Tuple, Optional, Set, Union, Any
from dataclasses import dataclass, field, asdict
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForMaskedLM
from tqdm import tqdm
import threading

from texture.curv_cache import (
    CurvatureCache, CurvatureCacheConfig, CurvatureCacheEntry,
    get_global_cache
)

# True when the environment variable TEXTURE_QUIET is exactly '1' at import
# time; _log() prints nothing while this is set.
QUIET_MODE = os.environ.get('TEXTURE_QUIET', '0') == '1'

def _log(msg: str):
    """Print ``msg`` to stdout unless ``QUIET_MODE`` is set."""
    if QUIET_MODE:
        return
    print(msg)

# Label for the catch-all "tail" bucket; not referenced elsewhere in the
# visible source.
TAIL_TOKEN = "<TAIL>"
# Sentinel id appended to every candidate list; the slot it occupies receives
# the probability mass not covered by the explicit candidates.
TAIL_TOKEN_ID = -1

@dataclass
class FastCurvatureConfig:
    """Settings consumed by ``FastCurvatureComputer``.

    Field order is part of the interface (positional construction), so new
    fields should be appended rather than inserted.
    """

    # --- model / tokenizer -------------------------------------------------
    model_name: str = "distilroberta-base"
    model_revision: Optional[str] = None
    tokenizer_revision: Optional[str] = None
    # Evaluated once, when this module is imported.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # --- belief / kappa parameters -----------------------------------------
    # Number of context tokens taken to the left / right of the masked slot.
    L: int = 8
    R: int = 8
    # Top-k tokens kept from each one-sided prediction when building the
    # candidate set.
    support_k: int = 50
    # Not read by the visible code other than in to_hash_dict().
    sinkhorn_iters: int = 2
    sinkhorn_reg: float = 0.1
    # Floor applied to projected probabilities and added inside logarithms.
    eps: float = 1e-12

    # Forwarded to the cache entry and included in the hash dict; the visible
    # window builders slice with L / R, not with these.
    belief_L_max: int = 128
    belief_R_max: int = 128
    # Model inputs longer than this are truncated by the window builders.
    model_max_len: int = 512

    # --- execution ----------------------------------------------------------
    batch_size: int = 64
    # Consulted by FastCurvatureComputer.__init__ when choosing the dtype.
    use_fp16: bool = True
    use_bf16: bool = False

    # --- position selection ---------------------------------------------------
    # Spacing of the coarse position grid.
    stride: int = 32
    # Spacing of the offsets probed around high-|kappa| positions.
    stride_fine: int = 8
    # Also evaluate tokens whose decoded text contains '.', '?', '!' or a
    # newline.
    include_boundaries: bool = True

    adaptive_refine: bool = True
    refine_top_fraction: float = 0.2
    max_positions_cap: int = 1200

    # --- query conditioning ---------------------------------------------------
    query_conditioned: bool = False
    query_max_tokens: int = 64

    # --- cache ----------------------------------------------------------------
    # A falsy cache_dir or cache_mode == "disabled" turns the cache off.
    cache_dir: Optional[str] = "./cache/curvature"
    cache_mode: str = "readwrite"

    # Not read anywhere in the visible source.
    enable_profiling: bool = True

    def to_hash_dict(self) -> Dict:
        """Return the subset of settings that is handed to the cache as ``config_dict``.

        Only the fields listed below are included; model identity, device,
        strides and cache settings are deliberately left out here.
        """
        hashed_fields = (
            "L",
            "R",
            "support_k",
            "sinkhorn_iters",
            "sinkhorn_reg",
            "belief_L_max",
            "belief_R_max",
            "query_conditioned",
            "query_max_tokens",
        )
        return {name: getattr(self, name) for name in hashed_fields}

@dataclass
class ProfilingStats:
    """Counters and timings recorded for one curvature computation."""

    # --- counters -----------------------------------------------------------
    cache_hit: bool = False
    cache_hit_rate: float = 0.0
    num_positions_evaluated: int = 0
    num_coarse_positions: int = 0
    num_refined_positions: int = 0
    num_forward_passes: int = 0
    actual_belief_window_len: int = 0

    # --- timings, in milliseconds ---------------------------------------------
    tokenization_time_ms: float = 0.0
    position_selection_time_ms: float = 0.0
    forward_pass_time_ms: float = 0.0
    kappa_compute_time_ms: float = 0.0
    total_curvature_time_ms: float = 0.0
    cache_lookup_time_ms: float = 0.0

    def to_dict(self) -> Dict:
        """Return all fields as a plain ``{field_name: value}`` dictionary."""
        snapshot = asdict(self)
        return snapshot

class FastCurvatureComputer:
    """Scores token positions with a masked-LM based "kappa" value.

    For each evaluated position the model predicts the masked token once from
    left-side context and once from right-side context; the two predictions
    are reduced to a shared candidate set and combined into a single float.
    Results can be stored in, and served from, a ``CurvatureCache``.
    """

    # Instances created through get_instance(), keyed by "<model_name>_<device>".
    _instances: Dict[str, 'FastCurvatureComputer'] = {}
    # Held by get_instance() while it looks up or creates a registry entry.
    _lock = threading.Lock()

    @classmethod
    def get_instance(cls, config: Optional[FastCurvatureConfig] = None) -> 'FastCurvatureComputer':

        config = config or FastCurvatureConfig()
        key = f"{config.model_name}_{config.device}"

        with cls._lock:
            if key not in cls._instances:
                cls._instances[key] = cls(config)
            return cls._instances[key]

    def __init__(self, config: Optional[FastCurvatureConfig] = None):
        self.config = config or FastCurvatureConfig()
        self.device = torch.device(self.config.device)

        if self.config.use_bf16 and torch.cuda.is_available():
            try:
                if torch.cuda.is_bf16_supported():
                    self.dtype = torch.bfloat16
                    _log(f"[FastCurvature] Using BFloat16")
                else:
                    self.dtype = torch.float16
                    _log(f"[FastCurvature] Using Float16")
            except:
                self.dtype = torch.float16
                _log(f"[FastCurvature] Using Float16")
        elif self.config.use_fp16:
            self.dtype = torch.float16
            _log(f"[FastCurvature] Using Float16")
        else:
            self.dtype = torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_name,
            revision=self.config.tokenizer_revision
        )

        _log(f"[FastCurvature] Loading model: {self.config.model_name}")
        self.model = AutoModelForMaskedLM.from_pretrained(
            self.config.model_name,
            revision=self.config.model_revision,
            torch_dtype=self.dtype if self.dtype != torch.float32 else None
        )
        self.model.to(self.device)
        self.model.eval()

        self.cls_id = self.tokenizer.cls_token_id
        self.sep_id = self.tokenizer.sep_token_id
        self.mask_id = self.tokenizer.mask_token_id
        self.pad_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id or 0
        self.vocab_size = self.tokenizer.vocab_size

        if self.config.cache_dir and self.config.cache_mode != "disabled":
            cache_config = CurvatureCacheConfig(
                cache_dir=self.config.cache_dir,
                mode=self.config.cache_mode
            )
            self.cache = CurvatureCache(cache_config)
        else:
            self.cache = None

        self.last_profiling_stats: Optional[ProfilingStats] = None

        self._warmup()

        _log(f"[FastCurvature] Ready on {self.device} "
             f"(batch={self.config.batch_size}, L_max={self.config.belief_L_max}, "
             f"R_max={self.config.belief_R_max})")

    def _warmup(self):

        if self.device.type != 'cuda':
            return
        dummy = torch.randint(0, 1000, (4, 32), device=self.device)
        with torch.no_grad(), torch.inference_mode():
            for _ in range(3):
                self.model(dummy)
        torch.cuda.synchronize()

    def tokenize(self, text: str) -> List[int]:

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message=".*Token indices sequence length.*")
            return self.tokenizer.encode(text, add_special_tokens=False)

    def build_masked_window_input(
        self,
        tokens: List[int],
        pos: int,
        side: str
    ) -> Tuple[List[int], int]:

        L_max = self.config.belief_L_max
        R_max = self.config.belief_R_max
        L_belief = self.config.L
        R_belief = self.config.R

        left_start = max(0, pos - L_max)
        right_end = min(len(tokens), pos + 1 + R_max)

        if side == 'left':
            left_tokens = tokens[max(0, pos - L_belief):pos]
            right_tokens = []
        elif side == 'right':
            left_tokens = []
            right_tokens = tokens[pos + 1:min(len(tokens), pos + 1 + R_belief)]
        else:
            left_tokens = tokens[max(0, pos - L_belief):pos]
            right_tokens = tokens[pos + 1:min(len(tokens), pos + 1 + R_belief)]

        input_ids = [self.cls_id] + left_tokens + [self.mask_id] + right_tokens + [self.sep_id]
        mask_idx = 1 + len(left_tokens)

        if len(input_ids) > self.config.model_max_len:
            excess = len(input_ids) - self.config.model_max_len
            left_cut = excess // 2
            right_cut = excess - left_cut

            left_tokens = left_tokens[left_cut:]
            right_tokens = right_tokens[:len(right_tokens) - right_cut] if right_cut > 0 else right_tokens

            input_ids = [self.cls_id] + left_tokens + [self.mask_id] + right_tokens + [self.sep_id]
            mask_idx = 1 + len(left_tokens)

        return input_ids, mask_idx

    def build_masked_window_input_query_conditioned(
        self,
        context_tokens: List[int],
        query_tokens: List[int],
        pos: int,
        side: str
    ) -> Tuple[List[int], int, List[int]]:

        L_belief = self.config.L
        R_belief = self.config.R
        query_max = self.config.query_max_tokens

        query_tokens = query_tokens[:query_max]

        left_tokens = context_tokens[max(0, pos - L_belief):pos]
        right_tokens = context_tokens[pos + 1:min(len(context_tokens), pos + 1 + R_belief)]

        input_ids = (
            [self.cls_id] +
            query_tokens +
            [self.sep_id] +
            left_tokens +
            [self.mask_id] +
            right_tokens +
            [self.sep_id]
        )

        mask_idx = 1 + len(query_tokens) + 1 + len(left_tokens)

        attention_mask = [1] * len(input_ids)

        if side == 'left':
            right_start = mask_idx + 1
            right_end = len(input_ids) - 1
            for i in range(right_start, right_end):
                attention_mask[i] = 0
        elif side == 'right':
            left_start = 1 + len(query_tokens) + 1
            left_end = mask_idx
            for i in range(left_start, left_end):
                attention_mask[i] = 0

        if len(input_ids) > self.config.model_max_len:
            excess = len(input_ids) - self.config.model_max_len
            if side == 'left':
                left_tokens = left_tokens[excess:]
            else:
                right_tokens = right_tokens[:len(right_tokens) - excess] if excess < len(right_tokens) else []

            input_ids = (
                [self.cls_id] +
                query_tokens +
                [self.sep_id] +
                left_tokens +
                [self.mask_id] +
                right_tokens +
                [self.sep_id]
            )
            mask_idx = 1 + len(query_tokens) + 1 + len(left_tokens)

            attention_mask = [1] * len(input_ids)
            if side == 'left':
                right_start = mask_idx + 1
                right_end = len(input_ids) - 1
                for i in range(right_start, right_end):
                    attention_mask[i] = 0
            elif side == 'right':
                left_start = 1 + len(query_tokens) + 1
                left_end = mask_idx
                for i in range(left_start, left_end):
                    attention_mask[i] = 0

        return input_ids, mask_idx, attention_mask

    @torch.no_grad()
    def compute_boundary_beliefs_batch(
        self,
        tokens: List[int],
        positions: List[int],
        side: str,
        batch_size: Optional[int] = None
    ) -> Dict[int, np.ndarray]:

        batch_size = batch_size or self.config.batch_size

        all_inputs = []
        all_mask_indices = []
        valid_positions = []

        for pos in positions:
            if side == 'left' and pos < self.config.L:
                continue
            if side == 'right' and pos >= len(tokens) - self.config.R:
                continue

            input_ids, mask_idx = self.build_masked_window_input(tokens, pos, side)
            all_inputs.append(input_ids)
            all_mask_indices.append(mask_idx)
            valid_positions.append(pos)

        if not all_inputs:
            return {}

        results = {}

        for batch_start in range(0, len(all_inputs), batch_size):
            batch_end = min(batch_start + batch_size, len(all_inputs))
            batch_inputs = all_inputs[batch_start:batch_end]
            batch_mask_idx = all_mask_indices[batch_start:batch_end]
            batch_positions = valid_positions[batch_start:batch_end]

            max_len = max(len(x) for x in batch_inputs)
            padded = torch.full(
                (len(batch_inputs), max_len),
                self.pad_id,
                dtype=torch.long,
                device=self.device
            )
            attention_mask = torch.zeros(
                len(batch_inputs), max_len,
                dtype=torch.long,
                device=self.device
            )

            for i, ids in enumerate(batch_inputs):
                padded[i, :len(ids)] = torch.tensor(ids, dtype=torch.long)
                attention_mask[i, :len(ids)] = 1

            with torch.inference_mode():
                if self.dtype != torch.float32 and self.device.type == 'cuda':
                    with torch.cuda.amp.autocast(dtype=self.dtype):
                        outputs = self.model(padded, attention_mask=attention_mask)
                else:
                    outputs = self.model(padded, attention_mask=attention_mask)

            logits = outputs.logits
            for i, (pos, mask_idx) in enumerate(zip(batch_positions, batch_mask_idx)):
                mask_logits = logits[i, mask_idx]
                probs = F.softmax(mask_logits.float(), dim=-1)
                results[pos] = probs.cpu().numpy()

        return results

    @torch.no_grad()
    def compute_boundary_beliefs_batch_query_conditioned(
        self,
        context_tokens: List[int],
        query_tokens: List[int],
        positions: List[int],
        side: str,
        batch_size: Optional[int] = None
    ) -> Dict[int, np.ndarray]:

        batch_size = batch_size or self.config.batch_size

        all_inputs = []
        all_mask_indices = []
        all_attention_masks = []
        valid_positions = []

        for pos in positions:
            if side == 'left' and pos < self.config.L:
                continue
            if side == 'right' and pos >= len(context_tokens) - self.config.R:
                continue

            input_ids, mask_idx, attn_mask = self.build_masked_window_input_query_conditioned(
                context_tokens, query_tokens, pos, side
            )
            all_inputs.append(input_ids)
            all_mask_indices.append(mask_idx)
            all_attention_masks.append(attn_mask)
            valid_positions.append(pos)

        if not all_inputs:
            return {}

        results = {}

        for batch_start in range(0, len(all_inputs), batch_size):
            batch_end = min(batch_start + batch_size, len(all_inputs))
            batch_inputs = all_inputs[batch_start:batch_end]
            batch_mask_idx = all_mask_indices[batch_start:batch_end]
            batch_attn = all_attention_masks[batch_start:batch_end]
            batch_positions = valid_positions[batch_start:batch_end]

            max_len = max(len(x) for x in batch_inputs)
            padded = torch.full(
                (len(batch_inputs), max_len),
                self.pad_id,
                dtype=torch.long,
                device=self.device
            )
            attention_mask = torch.zeros(
                len(batch_inputs), max_len,
                dtype=torch.long,
                device=self.device
            )

            for i, (ids, attn) in enumerate(zip(batch_inputs, batch_attn)):
                padded[i, :len(ids)] = torch.tensor(ids, dtype=torch.long)
                attention_mask[i, :len(attn)] = torch.tensor(attn, dtype=torch.long)

            with torch.inference_mode():
                if self.dtype != torch.float32 and self.device.type == 'cuda':
                    with torch.cuda.amp.autocast(dtype=self.dtype):
                        outputs = self.model(padded, attention_mask=attention_mask)
                else:
                    outputs = self.model(padded, attention_mask=attention_mask)

            logits = outputs.logits
            for i, (pos, mask_idx) in enumerate(zip(batch_positions, batch_mask_idx)):
                mask_logits = logits[i, mask_idx]
                probs = F.softmax(mask_logits.float(), dim=-1)
                results[pos] = probs.cpu().numpy()

        return results

    def _build_candidate_set(
        self,
        left_probs: np.ndarray,
        right_probs: np.ndarray,
        k: int
    ) -> List[int]:

        top_left = set(np.argsort(left_probs)[-k:].tolist())
        top_right = set(np.argsort(right_probs)[-k:].tolist())
        candidates = list(top_left | top_right)
        candidates.append(TAIL_TOKEN_ID)
        return candidates

    def _project_to_candidates(
        self,
        probs: np.ndarray,
        candidates: List[int]
    ) -> np.ndarray:

        result = np.zeros(len(candidates))
        total_in_candidates = 0.0

        for i, c in enumerate(candidates):
            if c == TAIL_TOKEN_ID:
                continue
            result[i] = probs[c]
            total_in_candidates += probs[c]

        tail_idx = candidates.index(TAIL_TOKEN_ID)
        result[tail_idx] = max(1.0 - total_in_candidates, 0.0)

        result = np.maximum(result, self.config.eps)
        result /= result.sum()

        return result

    def _compute_kappa_single(
        self,
        mu_L: np.ndarray,
        mu_R: np.ndarray
    ) -> float:

        eps = self.config.eps

        H_L = -np.sum(mu_L * np.log(mu_L + eps))
        H_R = -np.sum(mu_R * np.log(mu_R + eps))

        mu_mid = np.sqrt(mu_L * mu_R)
        mu_mid = mu_mid / (mu_mid.sum() + eps)
        mu_mid = np.maximum(mu_mid, eps)
        mu_mid = mu_mid / mu_mid.sum()

        H_mid = -np.sum(mu_mid * np.log(mu_mid + eps))

        kappa = 8.0 * (H_mid - 0.5 * (H_L + H_R))

        return float(kappa)

    def compute_kappa_at_positions(
        self,
        tokens: List[int],
        positions: List[int],
        show_progress: bool = False,
        query_tokens: Optional[List[int]] = None
    ) -> Tuple[Dict[int, float], int]:

        valid_positions = [
            p for p in positions
            if self.config.L <= p < len(tokens) - self.config.R
        ]

        if not valid_positions:
            return {}, 0

        if query_tokens is not None and self.config.query_conditioned:
            left_beliefs = self.compute_boundary_beliefs_batch_query_conditioned(
                tokens, query_tokens, valid_positions, 'left'
            )
            right_beliefs = self.compute_boundary_beliefs_batch_query_conditioned(
                tokens, query_tokens, valid_positions, 'right'
            )
        else:
            left_beliefs = self.compute_boundary_beliefs_batch(
                tokens, valid_positions, 'left'
            )
            right_beliefs = self.compute_boundary_beliefs_batch(
                tokens, valid_positions, 'right'
            )

        num_forward = len(left_beliefs) + len(right_beliefs)

        kappa_dict = {}

        iterator = valid_positions
        if show_progress:
            iterator = tqdm(iterator, desc="Computing κ")

        for pos in iterator:
            if pos not in left_beliefs or pos not in right_beliefs:
                continue

            left_probs = left_beliefs[pos]
            right_probs = right_beliefs[pos]

            candidates = self._build_candidate_set(
                left_probs, right_probs, self.config.support_k
            )

            mu_L = self._project_to_candidates(left_probs, candidates)
            mu_R = self._project_to_candidates(right_probs, candidates)

            kappa_dict[pos] = self._compute_kappa_single(mu_L, mu_R)

        return kappa_dict, num_forward

    def select_coarse_positions(
        self,
        tokens: List[int],
        stride: Optional[int] = None
    ) -> List[int]:

        stride = stride or self.config.stride
        L, R = self.config.L, self.config.R
        n = len(tokens)

        positions = list(range(L, n - R, stride))

        if self.config.include_boundaries:
            for i, tok_id in enumerate(tokens):
                if L <= i < n - R:
                    tok = self.tokenizer.decode([tok_id])
                    if any(p in tok for p in ['.', '?', '!', '\n', '\n\n']):
                        positions.append(i)

        positions = sorted(set(positions))

        if len(positions) > self.config.max_positions_cap:
            step = len(positions) // self.config.max_positions_cap
            positions = positions[::step]

        return positions

    def select_refinement_positions(
        self,
        tokens: List[int],
        kappa_dict: Dict[int, float],
        top_fraction: Optional[float] = None
    ) -> List[int]:

        if not kappa_dict:
            return []

        top_fraction = top_fraction or self.config.refine_top_fraction
        stride_fine = self.config.stride_fine
        L, R = self.config.L, self.config.R
        n = len(tokens)

        sorted_positions = sorted(kappa_dict.keys(), key=lambda p: abs(kappa_dict[p]), reverse=True)
        n_top = max(1, int(len(sorted_positions) * top_fraction))
        top_positions = sorted_positions[:n_top]

        refine_positions = set()
        for pos in top_positions:
            for offset in range(-stride_fine * 2, stride_fine * 2 + 1, stride_fine):
                new_pos = pos + offset
                if L <= new_pos < n - R and new_pos not in kappa_dict:
                    refine_positions.add(new_pos)

        return sorted(refine_positions)

    def compute_sparse(
        self,
        text: str,
        stride: Optional[int] = None,
        include_boundaries: bool = True,
        cache_key: Optional[str] = None,
        adaptive_refine: Optional[bool] = None,
        dataset: str = "default",
        split: str = "default",
        example_id: str = "0",
        show_progress: bool = False,
        query: Optional[str] = None
    ) -> Tuple[List[int], List[float]]:

        t_start = time.time()
        stats = ProfilingStats()

        query_tokens = None
        query_hash = ""
        if query is not None and self.config.query_conditioned:
            query_tokens = self.tokenize(query)[:self.config.query_max_tokens]
            query_hash = hashlib.sha256(query.encode('utf-8')).hexdigest()[:16]

        t_cache_start = time.time()
        if self.cache is not None:
            config_dict = self.config.to_hash_dict()
            cache_text = text if not query_hash else f"{query_hash}||{text}"
            cached = self.cache.get(
                text=cache_text,
                config_dict=config_dict,
                model_name=self.config.model_name,
                tokenizer_name=self.config.model_name,
                model_revision=self.config.model_revision,
                tokenizer_revision=self.config.tokenizer_revision,
                dataset=dataset,
                split=split,
                example_id=example_id
            )
            if cached is not None:
                stats.cache_hit = True
                stats.cache_lookup_time_ms = (time.time() - t_cache_start) * 1000
                stats.num_positions_evaluated = len(cached.positions)
                stats.total_curvature_time_ms = stats.cache_lookup_time_ms
                self.last_profiling_stats = stats
                query_info = f" (query={query_hash[:8]}...)" if query_hash else ""
                _log(f"[FastCurvature] Cache HIT: {dataset}/{split}/{example_id}{query_info} ({len(cached.positions)} positions)")
                return cached.positions, cached.kappa_values
            else:
                query_info = f" (query={query_hash[:8]}...)" if query_hash else ""
                _log(f"[FastCurvature] Cache MISS: {dataset}/{split}/{example_id}{query_info}")
        else:
            _log(f"[FastCurvature] Cache DISABLED (cache={self.cache is not None})")

        stats.cache_lookup_time_ms = (time.time() - t_cache_start) * 1000

        t_tok_start = time.time()
        tokens = self.tokenize(text)
        stats.tokenization_time_ms = (time.time() - t_tok_start) * 1000

        t_pos_start = time.time()
        stride = stride or self.config.stride
        coarse_positions = self.select_coarse_positions(tokens, stride)
        stats.num_coarse_positions = len(coarse_positions)
        stats.position_selection_time_ms = (time.time() - t_pos_start) * 1000

        t_fwd_start = time.time()
        kappa_dict, num_fwd = self.compute_kappa_at_positions(
            tokens, coarse_positions, show_progress=show_progress,
            query_tokens=query_tokens
        )
        stats.num_forward_passes = num_fwd

        adaptive = adaptive_refine if adaptive_refine is not None else self.config.adaptive_refine
        if adaptive and kappa_dict:
            refine_positions = self.select_refinement_positions(tokens, kappa_dict)
            if refine_positions:
                refine_kappa, refine_fwd = self.compute_kappa_at_positions(
                    tokens, refine_positions, show_progress=show_progress,
                    query_tokens=query_tokens
                )
                kappa_dict.update(refine_kappa)
                stats.num_forward_passes += refine_fwd
                stats.num_refined_positions = len(refine_positions)

        stats.forward_pass_time_ms = (time.time() - t_fwd_start) * 1000
        stats.num_positions_evaluated = len(kappa_dict)

        sorted_positions = sorted(kappa_dict.keys())
        sorted_values = [kappa_dict[p] for p in sorted_positions]

        if self.cache is not None and sorted_positions:
            config_dict = self.config.to_hash_dict()
            cache_text = text if not query_hash else f"{query_hash}||{text}"
            self.cache.put(
                text=cache_text,
                config_dict=config_dict,
                positions=sorted_positions,
                kappa_values=sorted_values,
                model_name=self.config.model_name,
                tokenizer_name=self.config.model_name,
                computation_time_ms=stats.forward_pass_time_ms,
                num_forward_passes=stats.num_forward_passes,
                model_revision=self.config.model_revision,
                tokenizer_revision=self.config.tokenizer_revision,
                belief_window_L=self.config.belief_L_max,
                belief_window_R=self.config.belief_R_max,
                stride_used=stride,
                dataset=dataset,
                split=split,
                example_id=example_id
            )

        stats.total_curvature_time_ms = (time.time() - t_start) * 1000
        if self.cache:
            cache_stats = self.cache.get_stats()
            stats.cache_hit_rate = cache_stats['hit_rate']
        self.last_profiling_stats = stats

        return sorted_positions, sorted_values

    def compute(
        self,
        text: str,
        positions: List[int],
        cache_key: Optional[str] = None,
        **kwargs
    ) -> Dict[int, float]:

        tokens = self.tokenize(text)
        kappa_dict, _ = self.compute_kappa_at_positions(tokens, positions)
        return kappa_dict

    def get_profiling_stats(self) -> Optional[ProfilingStats]:

        return self.last_profiling_stats

    def print_profiling_stats(self, prefix: str = "[Curvature]"):

        if self.last_profiling_stats is None:
            _log(f"{prefix} No profiling stats available")
            return

        s = self.last_profiling_stats
        _log(f"{prefix} Cache: {'HIT' if s.cache_hit else 'MISS'} "
             f"(hit_rate={s.cache_hit_rate*100:.1f}%)")
        _log(f"{prefix} Positions: {s.num_positions_evaluated} "
             f"(coarse={s.num_coarse_positions}, refined={s.num_refined_positions})")
        _log(f"{prefix} Forward passes: {s.num_forward_passes}")
        _log(f"{prefix} Time: {s.total_curvature_time_ms:.1f}ms "
             f"(fwd={s.forward_pass_time_ms:.1f}ms, cache_lookup={s.cache_lookup_time_ms:.1f}ms)")

def get_fast_curvature_computer(
    config: Optional[FastCurvatureConfig] = None
) -> FastCurvatureComputer:
    """Return the shared ``FastCurvatureComputer`` for ``config``.

    Thin module-level alias for ``FastCurvatureComputer.get_instance``.
    """
    shared_instance = FastCurvatureComputer.get_instance(config)
    return shared_instance

def compute_kappa_fast(
    text: str,
    positions: Optional[List[int]] = None,
    stride: int = 32,
    config: Optional[FastCurvatureConfig] = None,
    query: Optional[str] = None,
    **kwargs
) -> Tuple[List[int], List[float]]:
    """Score ``text`` with the shared computer's ``compute_sparse``.

    Args:
        text: Text to score.
        positions: Not read by this function; positions are always chosen by
            ``compute_sparse``.  NOTE(review): confirm whether explicit
            positions were meant to be honoured.
        stride: Coarse grid spacing forwarded to ``compute_sparse``; the
            default of 32 is passed explicitly, so it takes precedence over
            ``config.stride``.
        config: Used to obtain the shared computer.
        query: Optional query text forwarded to ``compute_sparse``.
        **kwargs: Forwarded unchanged to ``compute_sparse``.

    Returns:
        ``(positions, kappa_values)`` as returned by ``compute_sparse``.
    """
    shared_computer = get_fast_curvature_computer(config)
    forwarded = dict(kwargs, stride=stride, query=query)
    return shared_computer.compute_sparse(text, **forwarded)
