import json
from dataclasses import dataclass
from typing import List, Tuple, Set
import torch
from sentence_transformers import SentenceTransformer

@dataclass
class RunConfig:
    """Parameters for a single metrics run.

    Attributes:
        run_id: Identifier substituted for ``{run_id}`` in the record-directory
            template. It is only used via ``str.format``, so any formattable
            value works even though it is annotated ``int``.
        step_size: Divisor used by ``slice_chunks_by_steps``; the number of
            leading steps kept for the "100" slice is ``int(100 / step_size)``.
        threshold: A retrieved chunk is matched to a corpus entry only when
            their similarity is strictly greater than this value.
    """

    run_id: int  # formatted into the record-directory path
    step_size: int  # controls how many steps form the "100" slice
    threshold: float  # exclusive lower bound on match similarity

def load_json(path: str):
    """Decode and return the JSON document stored (UTF-8 encoded) at *path*."""
    with open(path, mode="r", encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload

def normalize_chunk_texts(chunks: List[str]) -> List[str]:
    """Truncate each chunk at its first "Question_start", then at its first "Query".

    Text from each marker onward is discarded; chunks without the markers are
    returned unchanged. Output order matches input order.
    """

    def _cut_at_markers(text: str) -> str:
        # partition(...)[0] is the text before the first occurrence, or the
        # whole string when the marker is absent.
        before_question, _, _ = text.partition("Question_start")
        before_query, _, _ = before_question.partition("Query")
        return before_query

    return [_cut_at_markers(text) for text in chunks]

def build_flat_corpus(merged_corpus: List[List[str]]) -> List[str]:
    """Concatenate every group in *merged_corpus* into one list, keeping order."""
    return [text for group in merged_corpus for text in group]

def encode_texts(
    encoder: SentenceTransformer,
    texts: List[str],
    device: str,
    batch_size: int = 1,
) -> torch.Tensor:
    """Embed *texts* with *encoder* and return them as a torch tensor.

    Embeddings are L2-normalized (``normalize_embeddings=True``), so a dot
    product between two rows is their cosine similarity.

    Args:
        encoder: Sentence-transformers model used for encoding.
        texts: Strings to embed.
        device: Device passed through to ``encoder.encode``.
        batch_size: Encoding batch size.

    Returns:
        A detached copy of the embedding tensor.
    """
    embeddings = encoder.encode(
        texts,
        device=device,
        batch_size=batch_size,
        normalize_embeddings=True,
        convert_to_tensor=True,
    )
    # Drop any autograd history and copy into fresh storage owned by the caller.
    detached = embeddings.detach()
    return detached.clone()

def greedy_unique_match(
    retrieved_embs: torch.Tensor,
    corpus_embs: torch.Tensor,
    threshold: float,
) -> Tuple[torch.Tensor, List[Tuple[int, int, float]]]:
    """Greedily pair each retrieved embedding with a distinct corpus embedding.

    Retrieved rows are visited in order. Each one claims the not-yet-claimed
    corpus row with the highest dot-product similarity, provided that
    similarity is strictly greater than ``threshold``. A corpus row can be
    claimed at most once, so the result depends on the order of
    ``retrieved_embs``.

    Args:
        retrieved_embs: ``(n_retrieved, dim)`` embeddings. A single 1-D vector
            or an empty tensor (any shape with zero elements) is also accepted.
        corpus_embs: ``(n_corpus, dim)`` embeddings.
        threshold: Exclusive lower bound on similarity for a match.

    Returns:
        ``(used, matched_pairs)``. ``used`` is a bool tensor of length
        ``n_corpus`` marking claimed corpus rows. ``matched_pairs`` lists
        ``(retrieved_idx, corpus_idx, similarity)`` tuples in visit order.
    """
    n_corpus = corpus_embs.shape[0]
    used = torch.zeros(n_corpus, dtype=torch.bool, device=corpus_embs.device)
    matched_pairs: List[Tuple[int, int, float]] = []

    # Nothing to match. Encoding an empty list may produce a 1-D size-0 tensor,
    # which would break the matmul below; an empty corpus would break max().
    if retrieved_embs.numel() == 0 or n_corpus == 0:
        return used, matched_pairs
    if retrieved_embs.dim() == 1:
        retrieved_embs = retrieved_embs.unsqueeze(0)

    sim_matrix = torch.matmul(retrieved_embs, corpus_embs.T)

    for i, row in enumerate(sim_matrix):
        # Mask claimed rows with -inf rather than -1.0 so they can never pass
        # the threshold, even when threshold < -1.
        sims = row.masked_fill(used, float("-inf"))
        max_sim, max_idx = sims.max(dim=0)
        sim_value = float(max_sim.item())

        if sim_value > threshold:
            used[max_idx] = True
            matched_pairs.append((i, int(max_idx.item()), sim_value))

    return used, matched_pairs

def slice_chunks_by_steps(
    chunks: List[str],
    steps: List[List[int]],
    step_size: int
) -> Tuple[List[str], List[str], List[List[int]]]:
    """Cut *chunks* to the lengths recorded in *steps*.

    The last entry of the last step bounds the "full" slice. The first
    ``int(100 / step_size)`` steps form the effective steps, and the last entry
    of the last effective step bounds the "100" slice (a prefix of the full
    slice). Presumably each step's last value is a cumulative chunk count;
    TODO confirm against the producer of ``bb.json``.

    Returns:
        ``(chunks_full, chunks_100, steps_eff)``; all three are empty when
        *steps* is empty.
    """
    if not steps:
        return [], [], []

    full_limit = int(steps[-1][-1])
    chunks_full = chunks[:full_limit]

    step_count = int(100 / step_size)
    # A non-positive count must yield no steps; a negative slice bound would
    # otherwise keep all but the last few.
    steps_eff = [] if step_count <= 0 else steps[:step_count]

    chunks_100 = chunks_full[: int(steps_eff[-1][-1])] if steps_eff else []

    return chunks_full, chunks_100, steps_eff

def _count_greedy_matches(
    encoder: SentenceTransformer,
    chunks: List[str],
    device: str,
    corpus_embs: torch.Tensor,
    threshold: float,
) -> int:
    """Deduplicate *chunks*, embed them, and count corpus rows they greedily claim."""
    # dict.fromkeys keeps first-occurrence order. set() order for str depends
    # on the per-process hash seed, and greedy_unique_match is order-sensitive,
    # so using set() made the metric vary between identical runs.
    unique_chunks = list(dict.fromkeys(chunks))
    if not unique_chunks:
        # Nothing retrieved; skip encoding and matching an empty batch.
        return 0
    retrieved_embs = encode_texts(encoder, unique_chunks, device)
    used, _ = greedy_unique_match(retrieved_embs, corpus_embs, threshold)
    return int(used.sum().item())


def run_metrics_for_one_exp(
    cfg: RunConfig,
    encoder: SentenceTransformer,
    device: str,
    corpus_texts_flat: List[str],
    corpus_embs: torch.Tensor,
    record_dir_template: str,
    total_budget: int = 2000,
    eff_budget: int = 200,
) -> None:
    """Load one run's records, compute coverage/efficiency metrics, and print them.

    Reads ``chunks_iter.json``, ``bb.json`` and ``ids.json`` from the run's
    record directory. Reported metrics:

    * ``coverage_greedy``: corpus rows matched by all retrieved chunks, as a
      percentage of ``total_budget``.
    * ``coverage_ids``: distinct ids in ``ids.json`` as a percentage of the
      corpus size (0 when the corpus is empty).
    * ``efficiency``: corpus rows matched by the "100" slice of chunks, as a
      percentage of ``eff_budget``.

    ``corpus_texts_flat`` is accepted for interface compatibility but unused.
    """
    record_dir = record_dir_template.format(run_id=cfg.run_id)

    raw_chunks = load_json(f"{record_dir}/chunks_iter.json")
    raw_steps = load_json(f"{record_dir}/bb.json")
    raw_ids = load_json(f"{record_dir}/ids.json")

    cleaned_chunks = normalize_chunk_texts(raw_chunks)
    id_set: Set[int] = set(raw_ids)

    chunks_full, chunks_100, _ = slice_chunks_by_steps(
        cleaned_chunks, raw_steps, cfg.step_size
    )

    used_full_count = _count_greedy_matches(
        encoder, chunks_full, device, corpus_embs, cfg.threshold
    )
    used_100_count = _count_greedy_matches(
        encoder, chunks_100, device, corpus_embs, cfg.threshold
    )

    n_corpus = corpus_embs.shape[0]

    greedy_coverage = used_full_count / total_budget * 100.0
    id_coverage = len(id_set) / n_corpus * 100.0 if n_corpus else 0.0
    efficiency = used_100_count / eff_budget * 100.0

    print(f"run_id={cfg.run_id}")
    print(f"matched_chunks={used_full_count}")
    print(f"coverage_greedy={greedy_coverage:.2f}%")
    print(f"coverage_ids={id_coverage:.2f}%")
    print(f"efficiency={efficiency:.2f}%")
    print()

def main():
    """Encode the merged corpus once, then report metrics for each configured run.

    The paths and experiment identifiers below are placeholders to fill in
    before running.
    """
    # Fall back to CPU so the script still runs on hosts without CUDA.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    encoder_name = "sentence-transformers/clip-ViT-B-32"

    corpus_path = "/path/to/merged_corpus.json"
    record_dir_template = "/path/to/run_records/run_{run_id}"

    # Placeholders: replace with a real run id and an integer step size.
    # int() raises ValueError on the literal placeholder text below.
    # (Named exp_id rather than id to avoid shadowing the builtin.)
    exp_id = "Your_exp_id"
    exp_step = "Your_exp_step"
    experiments = [
        RunConfig(run_id=exp_id, step_size=int(exp_step), threshold=0.88),
    ]

    encoder = SentenceTransformer(encoder_name, device=device).eval()

    merged_corpus = load_json(corpus_path)
    corpus_texts_flat = build_flat_corpus(merged_corpus)
    # Encoded once and shared by every experiment below.
    corpus_embs = encode_texts(encoder, corpus_texts_flat, device)

    for cfg in experiments:
        run_metrics_for_one_exp(
            cfg=cfg,
            encoder=encoder,
            device=device,
            corpus_texts_flat=corpus_texts_flat,
            corpus_embs=corpus_embs,
            record_dir_template=record_dir_template,
        )


if __name__ == "__main__":
    main()
