from loguru import logger
import json
from collections import defaultdict
from statistics import mean
from time import perf_counter
from typing import List, Any, Dict, Tuple
from uuid import uuid4
import time
import numpy as np
from pydantic import BaseModel, Field
from sklearn.cluster import KMeans
from typing import NamedTuple
import jsonlines
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import math
import pandas as pd

class ScoredPoint(NamedTuple):
    """An embedded document chunk: its vector plus descriptive metadata.

    Instances are plain immutable tuples ``(vector, payload)``.
    """

    # Embedding of the chunk. Annotated as list[float]; NOTE(review): the
    # encoder in this module stores whatever ``encode`` returns (possibly a
    # numpy array) — confirm before relying on list semantics.
    vector: list[float]
    # Free-form metadata; convert_records_to_scored_points_with_encoder fills
    # "content", "title", "id" and "chunk_id".
    payload: dict


def convert_records_to_scored_points_with_encoder(
    records: List[Dict[str, Any]],
    embedding_model,
) -> List[ScoredPoint]:
    """
    Embed each record's text and wrap the result as a ``ScoredPoint``.

    ``embedding_model.encode`` is called once per record, in input order, on
    the record's ``"text"`` value (``""`` when the key is missing).

    Args:
        records: Dicts read via ``.get`` for ``"text"``, ``"title"`` and
            ``"id"``; a missing title or id becomes ``None``.
        embedding_model: Any object with an ``encode(str)`` method; its return
            value is stored unchanged as the point's ``vector``.

    Returns:
        One ``ScoredPoint`` per record. The payload holds ``content`` (the
        text), ``title``, ``id`` and a 1-based ``chunk_id`` equal to the
        record's position in ``records``.
    """

    def _as_point(position: int, record: Dict[str, Any]) -> ScoredPoint:
        content = record.get("text", "")
        vector = embedding_model.encode(content)
        return ScoredPoint(
            vector=vector,
            payload={
                "content": content,
                "title": record.get("title"),
                "id": record.get("id"),
                "chunk_id": position,
            },
        )

    return [
        _as_point(position, record)
        for position, record in enumerate(records, start=1)
    ]


def multi_perspective_sampling(
    k: int,
    retrieved_points: List[ScoredPoint],
    num_subsets: int,
    seed: int = 1399,
) -> List[List[str]]:
    """
    Draw ``num_subsets`` document subsets that each cover every semantic cluster.

    After performing k-means clustering on the semantic vectors,
    we conduct random sampling with replacement within each cluster for
    num_subsets times, generating num_subsets subsets. Each subset holds one
    document drawn from each distinct cluster.

    Args:
        k: Requested number of clusters. Clamped to ``len(retrieved_points)``
            because k-means cannot form more clusters than there are samples.
        retrieved_points: Points whose ``vector`` is clustered and whose
            ``payload["title"]`` / ``payload["content"]`` are rendered into
            the returned strings.
        num_subsets: Number of subsets to draw.
        seed: Seed for both k-means and the per-cluster sampling. The sampling
            uses a private RNG, so the global NumPy RNG is left untouched.

    Returns:
        ``num_subsets`` lists of strings, one string per cluster (ordered by
        cluster label). Each string is a leading space, the title, a newline,
        then the content.

    Raises:
        ValueError: if ``retrieved_points`` is empty (nothing to cluster).
    """
    if not retrieved_points:
        raise ValueError("retrieved_points must contain at least one point")

    # KMeans raises when n_clusters > n_samples; small retrievals are common.
    n_clusters = min(k, len(retrieved_points))
    algo = KMeans(n_clusters=n_clusters, random_state=seed)
    vectors = [point.vector for point in retrieved_points]
    clusters = algo.fit_predict(X=vectors)

    # cluster label -> indices of the points assigned to it
    cluster_dict: defaultdict[int, List[int]] = defaultdict(list)
    for idx, c in enumerate(clusters):
        cluster_dict[c].append(idx)

    m = num_subsets
    logger.info("Will create {m} subsets.", m=m)

    # A private RandomState gives exactly the draws that np.random.seed(seed)
    # followed by np.random.choice would, without resetting the global RNG.
    rng = np.random.RandomState(seed)
    # Explicit ascending label order fixes the per-subset document order.
    unique_clusters = sorted(cluster_dict)
    subsets: List[List[str]] = []
    for _ in range(m):
        subset_idxs = [rng.choice(cluster_dict[c]) for c in unique_clusters]
        subset_docs = [
            f" {retrieved_points[i].payload['title']}\n{retrieved_points[i].payload['content']}"
            for i in subset_idxs
        ]
        subsets.append(subset_docs)

    return subsets


# Prompt template for a single RAG draft, filled with str.format:
#   {docs}     - the numbered document subset (built in rag_drafting_generator_local)
#   {question} - the user question
# Any literal "{" or "}" added to this template must be doubled for str.format.
# NOTE(review): the template starts with "\n<s>". If the draft tokenizer also
# prepends its own BOS token when called with default arguments, the model may
# see two BOS markers — confirm against the tokenizer in use.
rag_drafting_prompt = """
<s>[INST]
## Instruction: 
Use the evidence documents to answer the following question. If the documents do not provide enough information, try to answer with your own knowledge and clearly indicate that this is not directly supported by the documents.

---

**Documents:**
{docs}

### Question:
{question}

## Response:
[/INST]
"""


def calculate_perplexity(model, tokenizer, text: str, verify_device: str) -> float:
    """
    Return the perplexity of ``text`` under a causal language model.

    Perplexity is ``exp(-mean log p(token_t | tokens_<t))``. The mean is taken
    over every token except the first, which has no context to be predicted from.

    Args:
        model: Called as ``model(input_ids)``. Its output must expose ``.logits``
            (assumed shape ``(1, seq_len, vocab)``, which the slicing below relies on).
        tokenizer: Called as ``tokenizer(text, return_tensors="pt")``. The result
            must support ``.to(device)`` and ``["input_ids"]``.
        text: Text to score.
        verify_device: Device the tokenized inputs are moved to (presumably the
            device ``model`` lives on).

    Returns:
        The perplexity as a Python float.

    Raises:
        ValueError: if ``text`` tokenizes to fewer than two tokens. In that case
            there is no next-token prediction to average, and the result would be NaN.
    """
    inputs = tokenizer(text, return_tensors="pt").to(verify_device)
    ids = inputs["input_ids"]
    n_tokens = ids.shape[-1]
    if n_tokens < 2:
        raise ValueError(
            f"need at least 2 tokens to compute perplexity, got {n_tokens}"
        )
    with torch.no_grad():
        logits = model(ids).logits
    # Logit at position t predicts token t+1: drop the last logit and the first label.
    # Upcast first, since half-precision log_softmax loses accuracy.
    shifted_logits = logits[:, :-1, :].float()
    shifted_labels = ids[:, 1:]
    log_probs = torch.nn.functional.log_softmax(shifted_logits, dim=-1)
    tok_log = torch.gather(log_probs, dim=-1, index=shifted_labels.unsqueeze(-1)).squeeze(-1)
    avg_log = tok_log.mean(dim=-1).item()
    return math.exp(-avg_log)


def select_final_chunk_by_centrality(
    embedding_model,
    candidate_answers: List[str],
    weights: List[float],
) -> Tuple[int, str]:
    """
    Pick the candidate answer most similar to the weighted set of candidates.

    The candidates are embedded with ``embedding_model.encode`` and their
    pairwise cosine-similarity matrix U (m x m) is computed. Each candidate i
    is scored as

        score_i = sum over j of U[i][j] * weights[j]

    The sum includes j == i, i.e. each candidate's self-similarity weighted by
    its own weight.

    Args:
        embedding_model: Object with an ``encode(list_of_str, convert_to_tensor=True)``
            method (e.g. a ``SentenceTransformer``).
        candidate_answers: The m candidate texts. Must be non-empty.
        weights: One weight per candidate (length m).

    Returns:
        ``(best_idx, candidate_answers[best_idx])`` for the highest score.

    Raises:
        ValueError: if ``candidate_answers`` is empty, or if ``len(weights)``
            differs from ``len(candidate_answers)``.
    """
    n_candidates = len(candidate_answers)
    if n_candidates == 0:
        # torch.argmax on an empty score vector would fail with an opaque error.
        raise ValueError("candidate_answers must not be empty")
    if len(weights) != n_candidates:
        # Without this check, a length-1 weights list silently broadcasts across
        # every candidate, and other mismatches fail deep inside torch.
        raise ValueError(
            f"weights has {len(weights)} entries but there are {n_candidates} candidates"
        )

    print("[Centrality] Loading embedding model...")
    model = embedding_model
    print("[Centrality] Encoding candidate answers...")
    emb = model.encode(candidate_answers, convert_to_tensor=True)
    print("[Centrality] Computing similarity matrix...")
    sim_matrix = util.cos_sim(emb, emb)
    print("[Centrality] Applying weights and scoring...")
    weights_tensor = torch.tensor(weights, dtype=sim_matrix.dtype, device=sim_matrix.device)
    # Row-wise weighted sum: column j of U is scaled by weights[j].
    scores = (sim_matrix * weights_tensor.unsqueeze(0)).sum(dim=1)
    best_idx = int(torch.argmax(scores).item())
    best_score = float(scores[best_idx].item())
    print(f"[Centrality] Best candidate index: {best_idx}, score: {best_score:.4f}")
    return best_idx, candidate_answers[best_idx]




def _generate_draft(
    draft_model,
    draft_tokenizer,
    prompt: str,
    num_max_new_token: int,
) -> Tuple[Any, str]:
    """Greedy-decode one draft for ``prompt``; return (new token ids, decoded text)."""
    encoded = draft_tokenizer(prompt, return_tensors="pt").to(draft_model.device)
    input_ids = encoded["input_ids"]
    with torch.no_grad():
        output_ids = draft_model.generate(
            input_ids=input_ids,
            # transformers' generate() guidance: pass the attention mask explicitly
            # rather than letting it be inferred from pad/eos token ids.
            attention_mask=encoded.get("attention_mask"),
            do_sample=False,
            max_new_tokens=num_max_new_token,
            repetition_penalty=1.1,
            eos_token_id=draft_tokenizer.eos_token_id,
        )
    # generate() returns prompt + continuation; keep only the continuation.
    new_tokens = output_ids[0][input_ids.shape[-1]:]
    new_text = draft_tokenizer.decode(new_tokens, skip_special_tokens=True)
    return new_tokens, new_text


def rag_drafting_generator_local(
    num_cluster,
    num_max_new_token,
    embedding_model,
    draft_model,
    draft_tokenizer,
    question: str,
    docs: List[Dict[str, Any]],
    num_subsets: int,
) -> Dict[str, Any]:
    """
    Generate one draft answer per document subset and select the most central one.

    Pipeline:
      1. Embed ``docs`` into ``ScoredPoint`` objects with ``embedding_model``.
      2. Cluster them into (up to) ``num_cluster`` groups and draw ``num_subsets``
         subsets with one document per cluster (seed 1399).
      3. For each subset, greedy-decode an answer with ``draft_model`` from
         ``rag_drafting_prompt``.
      4. Select the answer whose embedding is most similar to the others
         (uniform weights).

    Args:
        num_cluster: Number of k-means clusters requested.
        num_max_new_token: ``max_new_tokens`` passed to ``generate`` for each draft.
        embedding_model: Encoder used for both the documents and the candidate answers.
        draft_model: Model exposing ``.device`` and ``.generate``.
        draft_tokenizer: Tokenizer paired with ``draft_model``.
        question: Question inserted into the prompt.
        docs: Records with ``"text"`` and optionally ``"title"`` / ``"id"``.
        num_subsets: Number of subsets, and therefore drafts, to generate.

    Returns:
        Dict with:
          - ``"responses"``: one dict per draft with these keys:
              - ``subset_idx``: 1-based index of the draft.
              - ``subset``: the formatted documents.
              - ``time``: clustering time plus this draft's generation time.
                Document embedding is not timed.
              - ``new_text``, ``new_tokens``.
              - ``candidate_answer``: same as ``new_text``.
          - ``"best_index"``: 0-based index into ``responses`` of the chosen draft.
          - ``"best_answer"``: its text.
          - ``"select_time"``: seconds spent in the centrality selection.

    Raises:
        ValueError: if ``num_subsets`` < 1, since there would be no draft to select.
    """
    # Fail before any embedding or clustering work is spent on an empty selection.
    if num_subsets < 1:
        raise ValueError(f"num_subsets must be >= 1, got {num_subsets}")

    # 1) convert docs to ScoredPoint
    retrieved_points = convert_records_to_scored_points_with_encoder(docs, embedding_model)

    # 2) cluster and sample (timed once; added to every draft's time below)
    start_time = perf_counter()
    subsets = multi_perspective_sampling(
        k=num_cluster,
        retrieved_points=retrieved_points,
        num_subsets=num_subsets,
        seed=1399,
    )
    kmeans_time = perf_counter() - start_time

    # 3) number the documents inside each subset for the prompt
    formatted_subsets = [
        "\n".join(f"{i+1}. {doc}" for i, doc in enumerate(subset))
        for subset in subsets
    ]

    # 4) one greedy draft per subset
    responses = []
    for idx, subset in enumerate(formatted_subsets, 1):
        print(f"--- Generating draft {idx}/{len(formatted_subsets)} ---")
        itr_start = perf_counter()
        prompt = rag_drafting_prompt.format(
            question=question,
            docs=subset,
        )
        new_tokens, new_text = _generate_draft(
            draft_model, draft_tokenizer, prompt, num_max_new_token
        )
        itr_time = perf_counter() - itr_start
        responses.append({
            "subset_idx": idx,
            "subset": subset,
            "time": kmeans_time + itr_time,
            "new_text": new_text,
            "new_tokens": new_tokens,
            "candidate_answer": new_text,
        })

    # 5) pick the most central draft
    print("--- Selecting final chunk by rerank ---")
    candidate_answers = [r["candidate_answer"] for r in responses]
    weights = [1.0] * len(candidate_answers)
    select_start_time = perf_counter()

    best_idx, best_answer = select_final_chunk_by_centrality(
        embedding_model=embedding_model, candidate_answers=candidate_answers, weights=weights
    )
    select_time = perf_counter() - select_start_time
    # best_idx is 0-based, whereas subset_idx in responses is 1-based.
    print(f"Final selected draft index: {best_idx}")

    return {
        "responses": responses,
        "best_index": best_idx,
        "best_answer": best_answer,
        "select_time": select_time,
    }
