import os
import hashlib
from typing import Dict, List, Tuple, Optional, Set, Union
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForMaskedLM
from tqdm import tqdm

# Human-readable label for the synthetic "tail" state.
# NOTE(review): not referenced anywhere else in the visible part of this file.
TAIL_TOKEN = "<TAIL>"
# Sentinel id of the tail state. CurvatureComputer._project_to_candidates stores
# the probability mass that falls outside the candidate set under this key, and
# the curvature routines append it to the list of states.
TAIL_TOKEN_ID = -1

@dataclass
class CurvatureConfig:
    """Settings consumed by ``CurvatureComputer`` and the module-level helpers."""

    # Name handed to AutoTokenizer / AutoModelForMaskedLM.from_pretrained.
    model_name: str = "distilroberta-base"
    # Number of tokens taken to the left / right of the masked position.
    L: int = 8
    R: int = 8
    # Upper bounds on L / R; the batched and sparse code paths use
    # min(L, L_max) and min(R, R_max).
    L_max: int = 128
    R_max: int = 128
    # Top-k tokens kept from each of the left and right beliefs when the
    # candidate set is built.
    support_k: int = 50
    # Iteration count and regularisation passed to the Sinkhorn routine.
    sinkhorn_iters: int = 2
    sinkhorn_reg: float = 0.1
    # Evaluated once, when this class body is executed (i.e. at import time).
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # Directory for cached results; a falsy value disables caching.
    cache_dir: Optional[str] = "./cache/curvature"
    # Number of masked inputs sent through the model per forward pass in the
    # batched path.
    batch_size: int = 32
    # Numerical floor used for probabilities and denominators.
    eps: float = 1e-12

    # Default spacing between sampled positions in compute_sparse.
    stride: int = 32
    # Whether compute_sparse also samples caller-supplied boundary positions.
    include_boundaries: bool = True
    # Whether compute_sparse adds a second pass around the positions with the
    # largest |kappa|.
    adaptive_refine: bool = True
    # Fraction of first-pass positions selected for refinement (at least one).
    refine_top_fraction: float = 0.2
    # Step between refinement offsets around each selected position.
    refine_stride: int = 8
    # compute() uses the batched implementation when True, otherwise it
    # evaluates positions one at a time.
    use_batched_compute: bool = True

class CurvatureComputer:

    def __init__(self, config: Optional[CurvatureConfig] = None):

        self.config = config or CurvatureConfig()
        self.device = torch.device(self.config.device)

        print(f"[Curvature] Loading model: {self.config.model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(self.config.model_name)
        self.model.to(self.device)
        self.model.eval()

        self.cls_id = self.tokenizer.cls_token_id
        self.sep_id = self.tokenizer.sep_token_id
        self.mask_id = self.tokenizer.mask_token_id
        self.pad_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
        self.vocab_size = self.tokenizer.vocab_size

        if self.config.cache_dir:
            os.makedirs(self.config.cache_dir, exist_ok=True)

        print(f"[Curvature] Model loaded on {self.device}")

        self._prior_belief_cache = None

    def _get_prior_belief(self) -> np.ndarray:

        if self._prior_belief_cache is None:
            self._prior_belief_cache = self._get_belief([], [])
        return self._prior_belief_cache

    def _get_cache_path(self, cache_key: str) -> str:

        hash_key = hashlib.md5(cache_key.encode()).hexdigest()[:16]
        return os.path.join(self.config.cache_dir, f"kappa_{hash_key}.pt")

    def _load_from_cache(self, cache_key: str) -> Optional[Dict[int, float]]:

        if not self.config.cache_dir:
            return None

        cache_path = self._get_cache_path(cache_key)
        if os.path.exists(cache_path):
            try:
                data = torch.load(cache_path, weights_only=True)
                return data
            except Exception:
                return None
        return None

    def _save_to_cache(self, cache_key: str, kappa_dict: Dict[int, float]):

        if not self.config.cache_dir:
            return

        cache_path = self._get_cache_path(cache_key)
        torch.save(kappa_dict, cache_path)

    def tokenize(self, text: str) -> List[int]:

        import warnings
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message=".*Token indices sequence length.*")
            return self.tokenizer.encode(text, add_special_tokens=False)

    def _build_masked_input(
        self,
        left_tokens: List[int],
        right_tokens: List[int]
    ) -> Tuple[List[int], int]:

        input_ids = [self.cls_id] + left_tokens + [self.mask_id] + right_tokens + [self.sep_id]
        mask_position = 1 + len(left_tokens)
        return input_ids, mask_position

    def _get_belief(
        self,
        left_tokens: List[int],
        right_tokens: List[int]
    ) -> np.ndarray:

        input_ids, mask_pos = self._build_masked_input(left_tokens, right_tokens)
        input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)

        with torch.no_grad():
            outputs = self.model(input_tensor)
            logits = outputs.logits[0, mask_pos]
            probs = F.softmax(logits, dim=-1)

        return probs.cpu().numpy()

    def _get_beliefs_batch(
        self,
        inputs: List[Tuple[List[int], List[int]]]
    ) -> List[np.ndarray]:

        if not inputs:
            return []

        all_input_ids = []
        all_mask_positions = []

        for left_tokens, right_tokens in inputs:
            input_ids, mask_pos = self._build_masked_input(left_tokens, right_tokens)
            all_input_ids.append(input_ids)
            all_mask_positions.append(mask_pos)

        max_len = max(len(ids) for ids in all_input_ids)
        padded = []
        attention_masks = []

        for input_ids in all_input_ids:
            padding = [self.pad_id] * (max_len - len(input_ids))
            padded.append(input_ids + padding)
            attention_masks.append([1] * len(input_ids) + [0] * len(padding))

        input_tensor = torch.tensor(padded, dtype=torch.long, device=self.device)
        attention_tensor = torch.tensor(attention_masks, dtype=torch.long, device=self.device)

        with torch.no_grad():
            outputs = self.model(input_tensor, attention_mask=attention_tensor)
            logits = outputs.logits

        results = []
        for i, mask_pos in enumerate(all_mask_positions):
            probs = F.softmax(logits[i, mask_pos], dim=-1)
            results.append(probs.cpu().numpy())

        return results

    def _build_candidate_set(
        self,
        left_belief: np.ndarray,
        right_belief: np.ndarray,
        k: int
    ) -> Set[int]:

        top_left = set(np.argsort(left_belief)[-k:].tolist())
        top_right = set(np.argsort(right_belief)[-k:].tolist())
        return top_left | top_right

    def _project_to_candidates(
        self,
        full_probs: np.ndarray,
        candidate_set: Set[int],
        eps: float = 1e-12
    ) -> Dict[int, float]:

        mu = {}
        prob_sum = 0.0

        for token_id in candidate_set:
            p = float(full_probs[token_id])
            mu[token_id] = p
            prob_sum += p

        mu[TAIL_TOKEN_ID] = max(1.0 - prob_sum, 0.0)

        total = sum(mu.values())
        if total > 0:
            for k in mu:
                mu[k] = max(mu[k] / total, eps)
        else:
            n = len(mu)
            for k in mu:
                mu[k] = 1.0 / n

        total = sum(mu.values())
        for k in mu:
            mu[k] /= total

        return mu

    def _compute_cost_matrix(
        self,
        mu_L: Dict[int, float],
        mu_R: Dict[int, float],
        states: List[int]
    ) -> np.ndarray:

        n = len(states)
        C = np.zeros((n, n))
        eps = self.config.eps

        for i, s in enumerate(states):
            for j, t in enumerate(states):
                pL_s = max(mu_L.get(s, eps), eps)
                pR_s = max(mu_R.get(s, eps), eps)
                pL_t = max(mu_L.get(t, eps), eps)
                pR_t = max(mu_R.get(t, eps), eps)

                C[i, j] = -np.log(np.sqrt(pL_s * pR_t) + eps) - np.log(np.sqrt(pL_t * pR_s) + eps)

        return C

    def _sinkhorn_midpoint(
        self,
        mu_L: Dict[int, float],
        mu_R: Dict[int, float],
        states: List[int],
        n_iters: int = 2,
        reg: float = 0.1
    ) -> Tuple[Dict[int, float], float]:

        n = len(states)
        eps = self.config.eps

        p = np.array([mu_L.get(s, eps) for s in states])
        q = np.array([mu_R.get(s, eps) for s in states])

        p = p / (p.sum() + eps)
        q = q / (q.sum() + eps)

        C = self._compute_cost_matrix(mu_L, mu_R, states)

        K = np.exp(-C / reg)

        u = np.ones(n)
        v = np.ones(n)

        for _ in range(n_iters):
            u = p / (K @ v + eps)
            v = q / (K.T @ u + eps)

        Pi = np.diag(u) @ K @ np.diag(v)

        transport_cost = np.sum(Pi * C)

        midpoint = np.zeros(n)
        for i in range(n):
            for j in range(n):
                midpoint[i] += Pi[i, j] * np.sqrt(p[i] * q[j])
                midpoint[j] += Pi[i, j] * np.sqrt(p[i] * q[j])

        midpoint = midpoint / (midpoint.sum() + eps)

        gamma = {states[i]: float(midpoint[i]) for i in range(n)}

        return gamma, transport_cost

    def _compute_free_energy(
        self,
        mu: Dict[int, float],
        reference: Dict[int, float],
        states: List[int]
    ) -> float:

        eps = self.config.eps
        kl = 0.0

        for s in states:
            mu_s = mu.get(s, eps)
            pi_s = reference.get(s, eps)
            if mu_s > eps:
                kl += mu_s * (np.log(max(mu_s, eps)) - np.log(max(pi_s, eps)))

        return max(kl, 0.0)

    def _compute_curvature_at_position(
        self,
        tokens: List[int],
        position: int
    ) -> float:

        L = self.config.L
        R = self.config.R
        k = self.config.support_k
        eps = self.config.eps

        left_start = max(0, position - L)
        left_tokens = tokens[left_start:position]

        right_end = min(len(tokens), position + 1 + R)
        right_tokens = tokens[position + 1:right_end]

        left_belief_full = self._get_belief(left_tokens, [])

        right_belief_full = self._get_belief([], right_tokens)

        prior_belief_full = self._get_belief([], [])

        joint_belief_full = self._get_belief(left_tokens, right_tokens)

        candidate_set = self._build_candidate_set(left_belief_full, right_belief_full, k)
        states = list(candidate_set) + [TAIL_TOKEN_ID]

        mu_L = self._project_to_candidates(left_belief_full, candidate_set, eps)
        mu_R = self._project_to_candidates(right_belief_full, candidate_set, eps)
        pi = self._project_to_candidates(prior_belief_full, candidate_set, eps)
        mu_joint = self._project_to_candidates(joint_belief_full, candidate_set, eps)

        gamma, transport_cost = self._sinkhorn_midpoint(
            mu_L, mu_R, states,
            n_iters=self.config.sinkhorn_iters,
            reg=self.config.sinkhorn_reg
        )

        F_gamma = self._compute_free_energy(gamma, pi, states)
        F_L = self._compute_free_energy(mu_L, pi, states)
        F_R = self._compute_free_energy(mu_R, pi, states)

        kappa = 8.0 * (F_gamma - 0.5 * F_L - 0.5 * F_R)

        return kappa

    def _compute_curvature_batched(
        self,
        tokens: List[int],
        positions: List[int]
    ) -> Dict[int, float]:

        if not positions:
            return {}

        L = min(self.config.L, self.config.L_max)
        R = min(self.config.R, self.config.R_max)
        k = self.config.support_k
        eps = self.config.eps
        batch_size = self.config.batch_size

        prior_belief_full = self._get_prior_belief()

        left_inputs = []
        right_inputs = []

        for pos in positions:
            left_start = max(0, pos - L)
            left_tokens = tokens[left_start:pos]

            right_end = min(len(tokens), pos + 1 + R)
            right_tokens = tokens[pos + 1:right_end]

            left_inputs.append((left_tokens, []))
            right_inputs.append(([], right_tokens))

        left_beliefs = []
        for i in range(0, len(left_inputs), batch_size):
            batch = left_inputs[i:i + batch_size]
            left_beliefs.extend(self._get_beliefs_batch(batch))

        right_beliefs = []
        for i in range(0, len(right_inputs), batch_size):
            batch = right_inputs[i:i + batch_size]
            right_beliefs.extend(self._get_beliefs_batch(batch))

        kappa_dict = {}
        for idx, pos in enumerate(positions):
            left_belief_full = left_beliefs[idx]
            right_belief_full = right_beliefs[idx]

            candidate_set = self._build_candidate_set(left_belief_full, right_belief_full, k)
            states = list(candidate_set) + [TAIL_TOKEN_ID]

            mu_L = self._project_to_candidates(left_belief_full, candidate_set, eps)
            mu_R = self._project_to_candidates(right_belief_full, candidate_set, eps)
            pi = self._project_to_candidates(prior_belief_full, candidate_set, eps)

            gamma, transport_cost = self._sinkhorn_midpoint(
                mu_L, mu_R, states,
                n_iters=self.config.sinkhorn_iters,
                reg=self.config.sinkhorn_reg
            )

            F_gamma = self._compute_free_energy(gamma, pi, states)
            F_L = self._compute_free_energy(mu_L, pi, states)
            F_R = self._compute_free_energy(mu_R, pi, states)

            kappa = 8.0 * (F_gamma - 0.5 * F_L - 0.5 * F_R)
            kappa_dict[pos] = kappa

        return kappa_dict

    def compute(
        self,
        text: str,
        positions: List[int],
        cache_key: Optional[str] = None
    ) -> Dict[int, float]:

        if cache_key:
            cached = self._load_from_cache(cache_key)
            if cached is not None:
                return {p: cached.get(p, 0.0) for p in positions}

        tokens = self.tokenize(text)

        valid_positions = [p for p in positions if 0 <= p < len(tokens)]

        if self.config.use_batched_compute:
            kappa_dict = self._compute_curvature_batched(tokens, valid_positions)
        else:
            kappa_dict = {}
            for pos in tqdm(valid_positions, desc="Computing curvature", leave=False):
                kappa_dict[pos] = self._compute_curvature_at_position(tokens, pos)

        if cache_key:
            self._save_to_cache(cache_key, kappa_dict)

        return kappa_dict

    def compute_sparse(
        self,
        text: str,
        stride: Optional[int] = None,
        include_boundaries: Optional[bool] = None,
        boundary_positions: Optional[List[int]] = None,
        cache_key: Optional[str] = None,
        adaptive_refine: Optional[bool] = None
    ) -> Tuple[List[int], List[float]]:

        if stride is None:
            stride = self.config.stride
        if include_boundaries is None:
            include_boundaries = self.config.include_boundaries
        if adaptive_refine is None:
            adaptive_refine = self.config.adaptive_refine

        tokens = self.tokenize(text)
        n_tokens = len(tokens)

        L = min(self.config.L, self.config.L_max)
        R = min(self.config.R, self.config.R_max)

        positions = set(range(0, n_tokens, stride))

        if include_boundaries and boundary_positions:
            positions.update(p for p in boundary_positions if 0 <= p < n_tokens)

        valid_positions = sorted([
            p for p in positions
            if L <= p < n_tokens - R
        ])

        kappa_dict = self.compute(text, valid_positions, cache_key)

        if adaptive_refine and len(kappa_dict) > 0:
            sorted_by_kappa = sorted(
                kappa_dict.items(),
                key=lambda x: abs(x[1]),
                reverse=True
            )

            n_refine = max(1, int(len(sorted_by_kappa) * self.config.refine_top_fraction))
            top_positions = [pos for pos, _ in sorted_by_kappa[:n_refine]]

            refine_positions = set()
            refine_stride = self.config.refine_stride
            refine_radius = stride

            for pos in top_positions:
                for offset in range(-refine_radius, refine_radius + 1, refine_stride):
                    new_pos = pos + offset
                    if L <= new_pos < n_tokens - R and new_pos not in kappa_dict:
                        refine_positions.add(new_pos)

            if refine_positions:
                refine_list = sorted(refine_positions)
                refine_kappa = self._compute_curvature_batched(tokens, refine_list)
                kappa_dict.update(refine_kappa)

        positions_out = sorted(kappa_dict.keys())
        values_out = [kappa_dict[p] for p in positions_out]

        return positions_out, values_out

def compute_kappa(
    sequence_text: str,
    positions: List[int],
    L: int = 8,
    R: int = 8,
    support_k: int = 50,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    cache_key: Optional[str] = None,
    **kwargs
) -> Dict[int, float]:
    """Build a ``CurvatureComputer`` and evaluate the given positions.

    A new computer (and therefore a new model load) is created on every call.

    Args:
        sequence_text: Text to analyse.
        positions: Token indices to evaluate.
        L: Left context size.
        R: Right context size.
        support_k: Top-k size used for the candidate set.
        device: Device string for the model.
        cache_key: Optional cache key forwarded to ``CurvatureComputer.compute``.
        **kwargs: Further ``CurvatureConfig`` fields.

    Returns:
        Mapping from position to curvature value.
    """
    settings = dict(L=L, R=R, support_k=support_k, device=device, **kwargs)
    computer = CurvatureComputer(CurvatureConfig(**settings))
    return computer.compute(sequence_text, positions, cache_key)

def compute_kappa_sparse(
    text: str,
    stride: int = 16,
    include_boundaries: bool = True,
    L: int = 8,
    R: int = 8,
    support_k: int = 50,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    cache_key: Optional[str] = None,
    **kwargs
) -> Tuple[List[int], List[float]]:
    """Build a ``CurvatureComputer`` and run its sparse computation.

    A new computer (and therefore a new model load) is created on every call.
    No boundary positions are supplied to the sparse routine.

    Args:
        text: Text to analyse.
        stride: Spacing of sampled positions.
        include_boundaries: Forwarded to ``compute_sparse``.
        L: Left context size.
        R: Right context size.
        support_k: Top-k size used for the candidate set.
        device: Device string for the model.
        cache_key: Optional cache key forwarded to ``compute_sparse``.
        **kwargs: Further ``CurvatureConfig`` fields.

    Returns:
        ``(positions, values)`` as returned by ``compute_sparse``.
    """
    settings = dict(L=L, R=R, support_k=support_k, device=device, **kwargs)
    computer = CurvatureComputer(CurvatureConfig(**settings))
    return computer.compute_sparse(
        text,
        stride=stride,
        include_boundaries=include_boundaries,
        boundary_positions=None,
        cache_key=cache_key,
    )

def get_optimized_computer(config=None):
    """Return the optimized computer if importable, else the standard one.

    When ``texture.curvature_optimized`` can be imported, an
    ``OptimizedCurvatureConfig`` is built (copying the listed fields from
    ``config`` when one is given, with the fallbacks below for absent
    attributes) and the result of ``OptimizedCurvatureComputer.get_instance``
    is returned. On ``ImportError`` a message is printed and a
    ``CurvatureComputer(config)`` is returned instead.
    """
    try:
        from texture.curvature_optimized import OptimizedCurvatureComputer, OptimizedCurvatureConfig

        if config is not None:
            # (field name, value used when config lacks the attribute)
            fallbacks = (
                ('model_name', 'distilroberta-base'),
                ('L', 8),
                ('R', 8),
                ('support_k', 50),
                ('device', 'cuda'),
                ('cache_dir', './cache/curvature'),
                ('batch_size', 64),
            )
            overrides = {name: getattr(config, name, default) for name, default in fallbacks}
            opt_config = OptimizedCurvatureConfig(**overrides)
        else:
            opt_config = OptimizedCurvatureConfig()

        return OptimizedCurvatureComputer.get_instance(opt_config)
    except ImportError:
        print("[Curvature] Optimized version not available, using standard")
        return CurvatureComputer(config)
