from loguru import logger
import json
from collections import defaultdict
from statistics import mean
from time import perf_counter
from typing import List, Any, Dict, Tuple
from uuid import uuid4
import time
import numpy as np
from pydantic import BaseModel, Field
from sklearn.cluster import KMeans
from typing import NamedTuple
import jsonlines
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import math
import pandas as pd

class ScoredPoint(NamedTuple):
    """A document chunk paired with its embedding vector and metadata payload."""
    # Embedding of the chunk text. NOTE(review): annotated as list[float], but in
    # this file the value stored is whatever `embedding_model.encode(text)` returns
    # (possibly an array type) -- confirm against the encoder in use.
    vector: list[float]
    # Chunk metadata; this file fills the keys "content", "title", "id", "chunk_id".
    payload: dict


def convert_records_to_scored_points_with_encoder(
    records: List[Dict[str, Any]],
    embedding_model,
) -> List[ScoredPoint]:
    """Embed each record's text and wrap it, with its metadata, in a ScoredPoint.

    Args:
        records: Dicts read via ``.get``. The keys used are ``"text"``
            (``""`` when absent), ``"title"`` and ``"id"`` (``None`` when
            absent).
        embedding_model: Object exposing ``encode(text)``; its return value
            is stored unchanged as the point's ``vector``.

    Returns:
        One ScoredPoint per record, in input order. The payload's
        ``"chunk_id"`` is the record's 1-based position in ``records``.
    """

    def _to_point(position: int, record: Dict[str, Any]) -> ScoredPoint:
        # Records are encoded one at a time, in order.
        content = record.get("text", "")
        return ScoredPoint(
            vector=embedding_model.encode(content),
            payload={
                "content": content,
                "title": record.get("title"),
                "id": record.get("id"),
                "chunk_id": position,
            },
        )

    return [
        _to_point(position, record)
        for position, record in enumerate(records, start=1)
    ]


def multi_perspective_sampling(
    k: int,
    retrieved_points: List[ScoredPoint],
    num_subsets: int,
    seed: int = 1399,
) -> List[List[str]]:
    """Build document subsets by drawing one document from every k-means cluster.

    The point vectors are clustered with k-means. Then, ``num_subsets`` times,
    one document is drawn at random from each non-empty cluster. Draws are
    independent, so the same document may appear in several subsets.

    Args:
        k: Requested number of clusters. If it exceeds the number of points it
            is reduced to ``len(retrieved_points)`` and a warning is logged.
        retrieved_points: Points whose ``vector`` is clustered and whose
            payload ``"title"`` / ``"content"`` are used to render the text.
        num_subsets: Number of subsets to generate.
        seed: Seed for both k-means and the sampling RNG.

    Returns:
        ``num_subsets`` lists of document strings, each string formatted as
        ``" {title}\\n{content}"``. A subset holds one string per non-empty
        cluster, ordered by ascending cluster label.

    Raises:
        ValueError: If ``retrieved_points`` is empty.
    """
    if not retrieved_points:
        raise ValueError("retrieved_points must contain at least one point.")

    # KMeans refuses n_clusters > n_samples; degrade gracefully instead of crashing.
    n_clusters = k
    if k > len(retrieved_points):
        logger.warning(
            "Requested {k} clusters but only {n} points are available; using {n} clusters.",
            k=k,
            n=len(retrieved_points),
        )
        n_clusters = len(retrieved_points)

    algo = KMeans(n_clusters=n_clusters, random_state=seed)
    vectors = [point.vector for point in retrieved_points]
    clusters = algo.fit_predict(X=vectors)

    # Map each cluster label to the indices of the points assigned to it.
    cluster_dict: defaultdict[int, List[int]] = defaultdict(list)
    for idx, c in enumerate(clusters):
        cluster_dict[int(c)].append(idx)

    m = num_subsets
    logger.info("Will create {m} subsets.", m=m)

    # A private RandomState yields the same legacy stream that
    # np.random.seed(seed) would, without resetting the process-wide NumPy RNG
    # for every other user of np.random.
    rng = np.random.RandomState(seed)

    # Explicit label order instead of depending on set iteration order.
    cluster_labels = sorted(cluster_dict)

    subsets: List[List[str]] = []
    for _ in range(m):
        subset_idxs = [rng.choice(cluster_dict[c]) for c in cluster_labels]
        subset_docs = [
            f" {retrieved_points[i].payload['title']}\n{retrieved_points[i].payload['content']}"
            for i in subset_idxs
        ]
        subsets.append(subset_docs)

    return subsets




def select_final_chunk_by_centrality(
    embedding_model,
    candidate_answers: List[str],
    weights: List[float],
) -> Tuple[int, str]:
    """Pick the candidate with the highest weighted similarity to all candidates.

    The candidates are embedded and an m x m cosine-similarity matrix U is
    computed between them. Row i is scored as ``sum_j U[i][j] * weights[j]``.

    Args:
        embedding_model: Object exposing
            ``encode(texts, convert_to_tensor=True)``.
        candidate_answers: The m candidate answer strings.
        weights: One weight per candidate, applied to the matching column
            of U.

    Returns:
        ``(index, answer)`` for the row with the highest score.
    """
    # The supplied model is used directly; the message text is kept verbatim.
    print("[Centrality] Loading embedding model...")
    encoder = embedding_model

    print("[Centrality] Encoding candidate answers...")
    answer_embeddings = encoder.encode(candidate_answers, convert_to_tensor=True)

    print("[Centrality] Computing similarity matrix...")
    similarity = util.cos_sim(answer_embeddings, answer_embeddings)

    print("[Centrality] Applying weights and scoring...")
    # Match dtype/device of the similarity matrix so the product broadcasts in place.
    column_weights = torch.tensor(
        weights,
        dtype=similarity.dtype,
        device=similarity.device,
    )
    row_scores = torch.sum(similarity * column_weights[None, :], dim=1)

    winner = int(torch.argmax(row_scores).item())
    winner_score = float(row_scores[winner].item())
    print(f"[Centrality] Best candidate index: {winner}, score: {winner_score:.4f}")

    return winner, candidate_answers[winner]




def build_arc_rag_prompt_from_fields(question, labels, choices, ctxs):
    """
    Render the multiple-choice RAG prompt sent to the drafting model.

    Parameters:
    - question: the question text
    - labels: choice labels, e.g. ["A", "B", "C", "D"]
    - choices: choice texts, paired with labels by position; pairing stops
      at the shorter of the two sequences
    - ctxs: evidence text, interpolated into the prompt as-is (the caller in
      this file passes a pre-numbered string of documents)

    Returns:
    - prompt: str, the fully formatted [INST] prompt
    """
    # One "<label>. <text>" line per choice, under a "Choices:" header.
    choice_lines = "".join(
        f"{label}. {text}\n" for label, text in zip(labels, choices)
    )
    choices_block = "Choices:\n" + choice_lines

    # The template (trailing spaces and the doubled "the" included) is emitted
    # verbatim so that prompts stay identical across runs.
    return f"""<s>[INST]
### Instruction: 
Use the evidence documents to answer the following question.  
Choose the the best answer choice from the given choices (A, B, C, or D).
---

**Question:  
{question}  

{choices_block}  

---
**Evidence Documents: 

{ctxs}
---
**Response:
[/INST]
"""





def rag_drafting_generator_local(
    num_cluster,
    num_max_new_token,
    embedding_model,
    draft_model,
    draft_tokenizer,
    question: str,
    docs: List[Dict[str, Any]],
    labels,
    choices,
    num_subsets: int,
) -> Dict[str, Any]:
    """Generate one draft answer per sampled document subset and pick the most central one.

    Pipeline: embed the documents, cluster and sample ``num_subsets`` subsets,
    prompt the draft model once per subset with greedy decoding, then select
    the final answer by weighted similarity among the drafts (uniform weights).

    Args:
        num_cluster: Number of k-means clusters passed to the sampler.
        num_max_new_token: ``max_new_tokens`` for each generation call.
        embedding_model: Encoder used for both documents and draft answers.
        draft_model: Model exposing ``generate`` and ``device``.
        draft_tokenizer: Tokenizer matching ``draft_model``.
        question: Question text placed in the prompt.
        docs: Retrieved records to embed and sample from.
        labels: Choice labels for the prompt.
        choices: Choice texts for the prompt.
        num_subsets: Number of subsets, and therefore drafts, to produce.

    Returns:
        Dict with ``"responses"`` (one entry per draft), ``"best_index"``,
        ``"best_answer"`` and ``"select_time"``. Each response's ``"time"`` is
        the sampling time plus that draft's own generation time.
    """
    # Step 1: embed every retrieved document.
    points = convert_records_to_scored_points_with_encoder(docs, embedding_model)

    # Step 2: cluster and sample; only this call counts towards the k-means time.
    sampling_started = perf_counter()
    sampled_subsets = multi_perspective_sampling(
        k=num_cluster,
        retrieved_points=points,
        num_subsets=num_subsets,
        seed=1399,
    )
    kmeans_time = perf_counter() - sampling_started

    # Render each subset as a numbered list, one document per entry.
    contexts = []
    for subset_docs in sampled_subsets:
        numbered = [
            f"{position}. {doc}"
            for position, doc in enumerate(subset_docs, start=1)
        ]
        contexts.append("\n".join(numbered))

    # Step 3: one greedy generation per rendered subset.
    responses = []
    for draft_no, context in enumerate(contexts, start=1):
        print(f"--- Generating draft {draft_no}/{len(contexts)} ---")
        draft_started = perf_counter()

        prompt = build_arc_rag_prompt_from_fields(
            question=question,
            labels=labels,
            choices=choices,
            ctxs=context,
        )
        encoded = draft_tokenizer(prompt, return_tensors="pt")
        input_ids = encoded.input_ids.to(draft_model.device)

        with torch.no_grad():
            output_ids = draft_model.generate(
                input_ids=input_ids,
                do_sample=False,
                max_new_tokens=num_max_new_token,
                repetition_penalty=1.1,
                eos_token_id=draft_tokenizer.eos_token_id,
            )

        # Drop the echoed prompt so only the continuation is decoded.
        prompt_length = len(input_ids[0])
        new_tokens = output_ids[0][prompt_length:]
        new_text = draft_tokenizer.decode(new_tokens, skip_special_tokens=True)

        draft_elapsed = perf_counter() - draft_started
        responses.append(
            {
                "subset_idx": draft_no,
                "subset": context,
                "time": kmeans_time + draft_elapsed,
                "new_text": new_text,
                "new_tokens": new_tokens,
                "candidate_answer": new_text,
            }
        )

    # Step 4: choose the draft most similar to the others (uniform weights).
    print("--- Selecting final chunk by centrality ---")
    candidate_answers = [entry["candidate_answer"] for entry in responses]
    weights = [1.0] * len(candidate_answers)

    selection_started = perf_counter()
    best_idx, best_answer = select_final_chunk_by_centrality(
        embedding_model=embedding_model,
        candidate_answers=candidate_answers,
        weights=weights,
    )
    select_time = perf_counter() - selection_started
    print(f"Final selected draft index: {best_idx}")

    return {
        "responses": responses,
        "best_index": best_idx,
        "best_answer": best_answer,
        "select_time": select_time,
    }
