from typing import Any, Iterable, List, Sequence
import os
import re
import json
import time
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from transformers.modeling_outputs import BaseModelOutput
from transformers import AutoTokenizer, BartForConditionalGeneration, LogitsProcessor

# =========================
#   Surrogate Encoder
# =========================
def surrogate_encode(texts, embedding_name, surrogate_model, batch_size, device):
    """Embed *texts* with the surrogate model and return a detached copy.

    Args:
        texts: Iterable of input strings.
        embedding_name: Name of the embedding model. When it equals
            ``"intfloat/e5-base-v2"`` every text is prefixed with
            ``"query: "``; any other value leaves the texts untouched.
        surrogate_model: Object exposing ``encode(texts, batch_size=...,
            convert_to_tensor=..., device=..., normalize_embeddings=...)``.
        batch_size: Forwarded to ``encode``.
        device: Forwarded to ``encode``.

    Returns:
        The tensor produced by ``encode`` (requested with
        ``normalize_embeddings=True``), detached and cloned.
    """
    needs_query_prefix = embedding_name == "intfloat/e5-base-v2"
    prepared = (
        ["query: " + text for text in texts]
        if needs_query_prefix
        else list(texts)
    )

    embeddings = surrogate_model.encode(
        prepared,
        batch_size=batch_size,
        convert_to_tensor=True,
        device=device,
        normalize_embeddings=True,
    )
    # Detach + clone so callers own a tensor independent of the encoder's output.
    return embeddings.detach().clone()

# =========================
#   Inverse Decoding Processor
# =========================
class SemanticGuidedLogitsProcessor(LogitsProcessor):
    """
    Re-rank the top-k next-token candidates by similarity to a target latent.

    Motivated by Eq.(1):  S(w) = logit_M(w|prefix) + λ cos(ES(prefix ◦ w), v*)

    As implemented here, for each of the ``topk_tokens`` highest-scoring
    tokens ``w`` the returned score is

        semantic_weight * cos(ES(prefix ◦ w), v*) + eps * logit_M(w|prefix)

        (``eps`` defaults to 1e-2), and every token outside the top-k is
        pushed to a large negative score so it is effectively never selected.

    Args:
        surrogate_model: Encoder passed to ``surrogate_encode``.
        tokenizer: Tokenizer used to decode ``prefix + [candidate]`` ids;
            its ``eos_token_id`` / ``pad_token_id`` are used as a fallback.
        target_latent_vec: Target embedding ``v*``; a 1-D vector is promoted
            to shape ``[1, D]`` and the result is L2-normalised.
        semantic_weight: The λ multiplier on the cosine term.
        topk_tokens: Number of candidate tokens re-scored per step.
        device: Device for the target and candidate embeddings; defaults to
            the device of ``target_latent_vec``.
        verbose: Stored on the instance; not read inside this class.
    """

    # Weight on the original logits when mixed with the semantic term.
    # Kept as a class attribute so it can still be overridden per instance.
    eps = 1e-2

    def __init__(
        self,
        surrogate_model,
        tokenizer,
        target_latent_vec,
        semantic_weight=2.0,
        topk_tokens=64,
        device=None,
        verbose=False,
    ):
        super().__init__()

        self.device = device if device is not None else target_latent_vec.device
        tgt = target_latent_vec.to(self.device)
        if tgt.dim() == 1:
            tgt = tgt.unsqueeze(0)
        self.target_latent = F.normalize(tgt, dim=-1).detach()

        self.tokenizer = tokenizer
        self.semantic_weight = float(semantic_weight)
        self.topk_tokens = int(topk_tokens)
        self.verbose = verbose
        self._step = 0
        self.surrogate_model = surrogate_model

    def _decode_candidate(self, ids):
        """Decode a list of token ids to text without special tokens."""
        return self.tokenizer.decode(
            ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def _fallback_token_id(self):
        """Return the EOS id, else the PAD id, else 0 (a valid id of 0 is kept)."""
        for attr in ("eos_token_id", "pad_token_id"):
            token_id = getattr(self.tokenizer, attr, None)
            if token_id is not None:
                return token_id
        return 0

    @torch.no_grad()
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        """Return re-scored logits of the same shape as ``scores`` ([B, V])."""
        B, V = scores.shape
        k = min(self.topk_tokens, V)
        _, topk_idx = torch.topk(scores, k=k, dim=1)

        # One bulk transfer per tensor instead of B*k scalar reads.
        prefixes = input_ids.tolist()
        candidates = topk_idx.tolist()
        candidate_texts = [
            self._decode_candidate(prefix_ids + [tok_id])
            for prefix_ids, row in zip(prefixes, candidates)
            for tok_id in row
        ]

        candidate_embs = surrogate_encode(
            candidate_texts, None, self.surrogate_model, 32, self.device
        )
        if not torch.is_tensor(candidate_embs):
            candidate_embs = torch.tensor(candidate_embs)

        candidate_embs = F.normalize(candidate_embs.to(self.device), dim=-1)
        # target_latent is [1, D], so the product is [B*k, 1] -> [B, k].
        cos_sim = (candidate_embs @ self.target_latent.T).view(B, k)
        cos_sim = torch.nan_to_num(cos_sim, nan=-1.0, posinf=1.0, neginf=-1.0)

        # scatter_ requires src to share dtype and device with the destination;
        # the surrogate may run in a different precision / on another device.
        guided_scores = (self.semantic_weight * cos_sim).to(
            device=scores.device, dtype=scores.dtype
        )

        # -1e9 is not representable in float16; clamp to the dtype's minimum.
        floor = max(-1e9, torch.finfo(scores.dtype).min)
        new_scores = torch.full_like(scores, floor)
        new_scores.scatter_(dim=1, index=topk_idx, src=guided_scores)

        scores = new_scores + self.eps * scores

        # Detect fully-degenerate rows BEFORE masking: once non-finite values
        # are replaced by `floor`, no -inf remains to be detected.
        not_finite = ~torch.isfinite(scores)
        row_all_bad = not_finite.all(dim=1)
        scores = scores.masked_fill(not_finite, floor)

        if row_all_bad.any():
            # Every token in the row was non-finite: force a single safe token.
            scores[row_all_bad, self._fallback_token_id()] = 0.0

        return scores

# =========================
#   Inverse Decoding Module
# =========================
def invert_latent_to_query(
    target_latent_vec,
    bart_tokenizer,
    bart_model,
    surrogate_model,
    min_new_tokens=79,
    max_new_tokens=80,
    num_beams=1,
    do_sample=True,
    temperature=1.0,
    top_p=0.9,
    semantic_weight=8.0,
    topk_tokens=256,
):
    """Decode a latent vector into text using semantically guided generation.

    The latent is fed to ``bart_model.generate`` as a precomputed encoder
    output, and a ``SemanticGuidedLogitsProcessor`` re-scores candidates at
    every step against the same latent.

    Args:
        target_latent_vec: Target latent. The code adds a leading batch axis,
            so a ``[S, D]`` tensor becomes ``[1, S, D]``; a 1-D ``[D]`` vector
            is promoted to ``[1, 1, D]``.
        bart_tokenizer: Tokenizer providing ``eos_token_id``,
            ``pad_token_id`` and ``decode``.
        bart_model: Model exposing ``generate``.
        surrogate_model: Encoder handed to the logits processor.
        min_new_tokens, max_new_tokens, num_beams, do_sample, temperature,
            top_p: Forwarded unchanged to ``generate``.
        semantic_weight: Weight of the cosine term in the logits processor
            (previously hard-coded to 8.0).
        topk_tokens: Number of candidates re-scored per step (previously
            hard-coded to 256).

    Returns:
        The first generated sequence decoded with special tokens skipped.
    """
    logits_processor = SemanticGuidedLogitsProcessor(
        surrogate_model=surrogate_model,
        tokenizer=bart_tokenizer,
        target_latent_vec=target_latent_vec,
        semantic_weight=semantic_weight,
        topk_tokens=topk_tokens,
        verbose=True,
    )

    # Encoder hidden states are expected as (batch, seq, hidden); a bare
    # vector needs both the batch and the sequence axis added.
    if target_latent_vec.dim() == 1:
        latent_prefix = target_latent_vec.unsqueeze(0).unsqueeze(0)
    else:
        latent_prefix = target_latent_vec.unsqueeze(0)

    encoder_outputs = BaseModelOutput(last_hidden_state=latent_prefix)
    # Attend to every position of the injected latent "sequence".
    attention_mask = torch.ones(latent_prefix.shape[:2], device=latent_prefix.device)

    gen_ids = bart_model.generate(
        encoder_outputs=encoder_outputs,
        attention_mask=attention_mask,
        min_new_tokens=min_new_tokens,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        temperature=temperature,
        do_sample=do_sample,
        top_p=top_p,
        eos_token_id=bart_tokenizer.eos_token_id,
        pad_token_id=bart_tokenizer.pad_token_id,
        logits_processor=[logits_processor],
        early_stopping=False,
    )

    return bart_tokenizer.decode(gen_ids[0], skip_special_tokens=True)

# =========================
#   Local Gaussian Perturbation
# =========================
def gaussian_perturb(
    z: torch.Tensor,
    sigma=0.1,
):
    """Add isotropic Gaussian noise to *z* and re-normalise to unit length.

    Args:
        z: A single vector ``[D]`` or a batch of row vectors ``[B, D]``.
        sigma: Standard deviation of the added noise.

    Returns:
        A 2-D tensor whose rows have unit L2 norm: ``[1, D]`` for a ``[D]``
        or ``[1, D]`` input, ``[B, D]`` for a batched input.
    """
    if z.dim() == 1:
        z = z.unsqueeze(0)
    z_noisy = z + torch.randn_like(z) * sigma
    # Normalise along the feature axis. Normalising along dim 0 is only
    # correct for a single vector and mixes rows together for a real batch.
    return F.normalize(z_noisy, dim=-1)

# =========================
#   Retrieved Context Parser
# =========================
def parse_retrieved_context(context):
    """Split a retrieved-context blob into a list of non-empty chunk strings.

    Non-string input is converted with ``str()`` (an empty list is returned
    if that conversion raises). An optional leading
    "... information ... retrieved:" header and surrounding quotes are
    removed, then the first strategy that applies wins:

    1. Inline numbering ``1. ... 2. ...`` (separators ``:``, ``.`` or ``)``)
       that starts at 1 and increases by exactly 1.
    2. Numbered lines, one item per leading number.
    3. Paragraphs separated by blank lines.
    4. The whole text as a single chunk.
    """
    if not isinstance(context, str):
        try:
            context = str(context)
        except Exception:
            return []

    text = context.strip()
    if not text:
        return []

    header_re = r'^\s*"?\s*(?:the\s+following\s+)?(?:relevant\s+)?information(?:\s+has\s+been)?\s+retrieved\s*:\s*'
    text = re.sub(header_re, "", text, flags=re.I).strip()
    text = text.strip().strip('"').strip("'").strip()
    if not text:
        return []

    # Strategy 1: numbered markers found after collapsing newlines to spaces.
    single_line = re.sub(r"\s*\n\s*", " ", text).strip()
    markers = list(
        re.finditer(r"(?:(?<=^)|(?<=\s))(\d{1,3})\s*([:.)])\s*", single_line)
    )
    if markers:
        labels = [int(marker.group(1)) for marker in markers]
        sequential = labels[0] == 1 and all(
            later == earlier + 1 for earlier, later in zip(labels, labels[1:])
        )
        if sequential:
            # Each item runs from the end of its marker to the next marker.
            stops = [marker.start() for marker in markers[1:]] + [len(single_line)]
            items = [
                single_line[marker.end():stop].strip()
                for marker, stop in zip(markers, stops)
            ]
            items = [item for item in items if item]
            if items:
                return items

    # Strategy 2: numbered lines in the original (non-flattened) text.
    numbered_line_re = r"(?m)^\s*\d+\s*(?:[.:)]\s*|\s+)(.*?)(?=\n\s*\d+\s*(?:[.:)]\s*|\s+)|\Z)"
    bodies = re.findall(numbered_line_re, text, flags=re.S)
    if bodies:
        # May legitimately be empty when every captured body is blank.
        return [body.strip() for body in bodies if body.strip()]

    # Strategy 3: blank-line separated paragraphs.
    if "\n\n" in text:
        return [para.strip() for para in re.split(r"\n\s*\n", text) if para.strip()]

    # Strategy 4: text is known to be non-empty at this point.
    return [text]

# =========================
#   Global Orthogonal Planner
# =========================
def synthesize_orthogonal_directions(
    context_latent,
    surrogate_name,
    num_directions=4,
    eps=1e-9,
    max_tries=10000,
):
    """Sample random directions orthogonal to *context_latent* and to each other.

    Random Gaussian vectors are orthogonalised (Gram-Schmidt) against the
    normalised context vector and against the unit directions accepted so
    far. The returned rows have norm 1 when ``surrogate_name`` is
    ``"BAAI/bge-base-en-v1.5"`` and norm 1.5 for any other name.

    Args:
        context_latent: Tensor of shape ``[D]`` or ``[B, D]`` (only row 0 of
            a 2-D input is used).
        surrogate_name: Selects the output scale as described above.
        num_directions: Number of directions requested (``k <= D - 1``).
        eps: Small constant added to norms before dividing.
        max_tries: Upper bound on sampling attempts.

    Returns:
        Tensor of shape ``[k', D]`` with ``k' <= num_directions``; fewer rows
        are returned only if ``max_tries`` is exhausted, and ``[0, D]`` when
        no direction was produced.

    Raises:
        ValueError: If the input is not 1-D/2-D or ``num_directions > D - 1``.
    """
    if context_latent.dim() == 2:
        context_latent = context_latent[0]
    elif context_latent.dim() != 1:
        raise ValueError(f"context_latent must be [D] or [B,D], got {tuple(context_latent.shape)}")
    D = context_latent.numel()
    if num_directions > D - 1:
        raise ValueError(f"Need k <= D-1. Got k={num_directions}, D={D}")
    context_latent = context_latent / (context_latent.norm() + eps)

    # The projection r - <r, u> u is only valid for unit-length u, so the
    # orthonormal basis is kept separate from the (possibly rescaled) output.
    unit_basis = []
    tries = 0

    while len(unit_basis) < num_directions and tries < max_tries:
        tries += 1
        r = torch.randn_like(context_latent)
        r = r - torch.dot(r, context_latent) * context_latent

        for u in unit_basis:
            r = r - torch.dot(r, u) * u
        n = r.norm()
        if n < 1e-6:
            # Sample was (numerically) inside the already-spanned subspace.
            continue

        unit_basis.append(r / (n + eps))

    if not unit_basis:
        # torch.stack rejects an empty list; return an empty [0, D] result.
        return context_latent.new_empty((0, D))

    directions = torch.stack(unit_basis, dim=0)
    if surrogate_name != "BAAI/bge-base-en-v1.5":
        directions = directions * 1.5
    return directions


def save_run_artifacts(
    chunk_list,
    coverage_curve,
    query_list,
    config_record,
    retrieved_chunks,
    retrieved_chunk_ids,
    out_dir,
):
    """Write the artifacts of one run as JSON files under *out_dir*.

    The directory is created if missing. Six UTF-8 files are written in
    order: ``chunks_iter.json``, ``coverage.json``, ``query_list.json``,
    ``config.json``, ``retrieved_chunks.json`` and ``retrieved_ids.json``.
    Each file keeps its own ``json.dump`` options (indent width,
    ``ensure_ascii``, ``default=str``) as listed in the table below.
    """
    os.makedirs(out_dir, exist_ok=True)

    # (file name, payload, json.dump keyword options) in write order.
    artifacts = (
        ("chunks_iter.json", chunk_list, {"ensure_ascii": False, "indent": 4}),
        ("coverage.json", coverage_curve, {"indent": 4}),
        ("query_list.json", query_list, {"indent": 4}),
        ("config.json", config_record, {"indent": 2, "default": str}),
        ("retrieved_chunks.json", retrieved_chunks, {"indent": 2, "default": str}),
        ("retrieved_ids.json", retrieved_chunk_ids, {"indent": 2}),
    )

    for file_name, payload, dump_options in artifacts:
        target_path = os.path.join(out_dir, file_name)
        with open(target_path, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, **dump_options)