import random
from typing import List, Tuple, Dict, Any
import numpy as np
import torch
# import torch.nn.functional as F # Not explicitly used now
import json
from tqdm import tqdm
from collections import Counter
import argparse
import math
import os
import csv # For CSV output
from sentence_transformers import SentenceTransformer
# Heads import (ensure this path is correct in your environment)
from heads import get_matching_head

SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed every RNG source the script may draw from so repeated runs agree:
# Python's `random` (used for block shuffling and token sampling), NumPy and torch.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)
del _seed_rng
if DEVICE == "cuda":
    # CUDA generators are seeded separately from the CPU generator.
    torch.cuda.manual_seed_all(SEED)


class MatchingInference:
    """Embedding model plus trained matching head, both loaded from one directory.

    ``__init__`` reads from ``model_dir``:
      * ``embedding_model/``           -- SentenceTransformer checkpoint
      * ``matching_head.pt``           -- state dict for the ``"cos_sim"`` head
      * ``tokenid_embedding_cache.pt`` -- optional; built and written on first use

    The model, head and cached embeddings are placed on the module-level ``DEVICE``.
    """

    def __init__(self, model_dir):
        self.embedding_model = SentenceTransformer(f"{model_dir}/embedding_model", trust_remote_code=True, device=DEVICE)
        self.embedding_model.eval()

        embedding_dim = self.embedding_model.get_sentence_embedding_dimension()
        self.matching_head = get_matching_head("cos_sim", embedding_dim)
        # NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
        self.matching_head.load_state_dict(torch.load(f"{model_dir}/matching_head.pt", map_location=DEVICE))
        self.matching_head = self.matching_head.to(DEVICE)
        self.matching_head.eval()

        # token id -> embedding of that token's text (used for token-level scoring).
        self.tokenid2emb = self._build_token_embedding_cache(model_dir)

    def _build_token_embedding_cache(self, model_dir):
        """Return a ``{token_id: embedding}`` dict covering the tokenizer vocabulary.

        If ``tokenid_embedding_cache.pt`` exists in ``model_dir`` it is loaded.
        Otherwise every vocabulary entry is encoded as its own text, skipping
        tokens that start with "[" (presumably bracketed specials such as
        "[CLS]") and whitespace-only tokens. The result is saved to that path.
        Returned embeddings live on ``DEVICE``.

        NOTE(review): the cache is reused whenever the file exists; it is not
        invalidated if the embedding model in ``model_dir`` changes.
        """
        cache_path = os.path.join(model_dir, "tokenid_embedding_cache.pt")
        if os.path.exists(cache_path):
            print(f"📦 Loading token ID embedding cache from {cache_path} ...")
            # NOTE: unpickles the cache file -- trusted sources only.
            tokenid2emb_raw = torch.load(cache_path, map_location=DEVICE)
            return {int(token_id): emb.to(DEVICE) for token_id, emb in tokenid2emb_raw.items()}
        else:
            print("⚙️ Building token ID embedding index from tokenizer vocab...")
            tokenizer = self.embedding_model.tokenizer
            vocab = tokenizer.get_vocab()
            filtered_items = [(tok, idx) for tok, idx in vocab.items() if not tok.startswith("[") and tok.strip()]

            tokens = [x[0] for x in filtered_items]
            ids = [x[1] for x in filtered_items]

            token_embs = self.embedding_model.encode(tokens, convert_to_tensor=True, show_progress_bar=True, device=DEVICE)
            tokenid2emb = {int(i): emb for i, emb in zip(ids, token_embs)}

            torch.save(tokenid2emb, cache_path)
            print(f"✅ Cached embeddings for {len(tokenid2emb)} token ids to {cache_path}")
            return {int(k): v.to(DEVICE) for k, v in tokenid2emb.items()}

    def encode(self, text: str) -> torch.Tensor:
        """Encode one text with the embedding model and return a tensor on ``DEVICE``.

        NOTE(review): unlike ``predict_batch``, this does not pass
        ``normalize_embeddings=True``. That is harmless if the "cos_sim" head is
        scale-invariant; confirm against ``heads.get_matching_head``.
        """
        return self.embedding_model.encode(text, convert_to_tensor=True, device=DEVICE)

    def score(self, emb_a: torch.Tensor, emb_b: torch.Tensor) -> float:
        """Return ``sigmoid(logit)`` from the matching head for one embedding pair.

        Each embedding gets a leading batch dimension of size 1 before the head
        is called. ``.item()`` requires the head to emit exactly one logit, so
        the inputs are presumably single unbatched embeddings.
        """
        emb_a = emb_a.to(DEVICE)
        emb_b = emb_b.to(DEVICE)
        features = {
            "embedding_a": emb_a.unsqueeze(0),
            "embedding_b": emb_b.unsqueeze(0)
        }
        with torch.no_grad():
            logits = self.matching_head(features)["logits"]
            return torch.sigmoid(logits).item()

    @torch.no_grad()
    def predict_batch(self, answers, reasons, batch_size=32):
        """Score aligned (answer, reason) text pairs in mini-batches.

        Args:
            answers: texts encoded as ``embedding_a``.
            reasons: texts encoded as ``embedding_b``; same length as ``answers``.
            batch_size: number of pairs encoded per forward pass (must be >= 1).

        Returns:
            One float per pair, in input order: sigmoid of the head's logits.

        Raises:
            ValueError: if ``answers`` and ``reasons`` differ in length, or
                ``batch_size`` < 1.
        """
        if len(answers) != len(reasons):
            raise ValueError(
                f"answers and reasons must have equal length, got {len(answers)} and {len(reasons)}"
            )
        if batch_size < 1:
            # A non-positive step would either crash range() or silently return [].
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        all_probs = []
        for idx in range(0, len(answers), batch_size):
            batch_answers = answers[idx:idx + batch_size]
            batch_reasons = reasons[idx:idx + batch_size]

            emb_a = self.embedding_model.encode(batch_answers, convert_to_tensor=True, normalize_embeddings=True, device=DEVICE)
            emb_b = self.embedding_model.encode(batch_reasons, convert_to_tensor=True, normalize_embeddings=True, device=DEVICE)

            features = {"embedding_a": emb_a, "embedding_b": emb_b}
            logits = self.matching_head(features)["logits"]
            # Flatten with reshape(-1), not squeeze(-1). If the head emits shape
            # (B,), squeeze(-1) on a one-pair batch (as score_pair sends) yields
            # a 0-d tensor whose .tolist() is a bare float, and list.extend(float)
            # raises TypeError. reshape(-1) yields (B,) for both (B,) and (B, 1).
            probs = torch.sigmoid(logits).reshape(-1)
            all_probs.extend(probs.tolist())
        return all_probs

    def score_pair(self, reason: str, answer: str) -> float:
        """Score a single pair via ``predict_batch``.

        ``answer`` becomes ``embedding_a`` and ``reason`` becomes ``embedding_b``,
        so the argument order is the reverse of ``predict_batch``'s.
        """
        probs = self.predict_batch([answer], [reason], batch_size=1)
        return probs[0]


# -------- Model adapter interface (for score_a calculation) --------
def make_model_a_scorer(infer_a: MatchingInference):
    """Build a token-ids-vs-sentence scorer backed by ``infer_a``.

    Returns:
        A tuple ``(model_a_score, tokenizer)``:
          * ``model_a_score(token_ids, sentence_text) -> float`` averages the
            cached embeddings (``infer_a.tokenid2emb``) of ``token_ids``,
            ignoring ids with no cached entry. It encodes ``sentence_text`` with
            ``infer_a.encode`` and returns ``infer_a.score`` for the pair
            (sigmoid of the matching-head logit). It returns 0.0 when none of
            the ids has a cached embedding.
          * ``tokenizer``: ``infer_a``'s tokenizer, so callers can produce ids
            in the same vocabulary as the cache.
    """
    tokenizer = infer_a.embedding_model.tokenizer  # fetched once, returned to the caller

    def model_a_score(token_ids: List[int], sentence_text: str) -> float:
        cache = infer_a.tokenid2emb
        valid_embs = [cache[tid] for tid in token_ids if tid in cache]
        if not valid_embs:
            # None of the sampled ids has a cached embedding (e.g. filtered specials).
            return 0.0
        token_emb = torch.mean(torch.stack(valid_embs), dim=0)
        sent_emb = infer_a.encode(sentence_text)
        # MatchingInference.score performs the device move, batch-of-1 unsqueeze,
        # no_grad forward pass and sigmoid -- reuse it instead of duplicating it.
        return infer_a.score(token_emb, sent_emb)

    return model_a_score, tokenizer

# -------- Rule-based Sufficiency Check --------
def _check_sufficiency_rule_based(
    current_scores_a: List[float],
    current_scores_b: List[float],
    verification_threshold: float
) -> Tuple[bool, float, float]: # Returns: is_sufficient, avg_score_a, avg_score_b
    """
    Calculates the average scores from two lists of scores
    and checks if BOTH meet the verification threshold individually.
    """
    if not current_scores_a or not current_scores_b: # Handles cases with no valid scores after processing blocks
        return False, 0.0, 0.0
    
    if len(current_scores_a) != len(current_scores_b):
        # This case should ideally not happen if pre-processing logic is correct (pairs are always added)
        print(f"Warning: Mismatch in score list lengths for rule-based verifier. A: {len(current_scores_a)}, B: {len(current_scores_b)}")
        # Decide how to handle: could return False or try to average based on min length. Returning False is safer.
        return False, 0.0, 0.0

    avg_score_a = sum(current_scores_a) / len(current_scores_a)
    avg_score_b = sum(current_scores_b) / len(current_scores_b)
    
    sufficient = (avg_score_a >= verification_threshold and avg_score_b >= verification_threshold)
    
    return sufficient, avg_score_a, avg_score_b

# -------- Pre-computation and Validation Logic --------
def preprocess_samples_for_efficient_validation(
    data: List[Dict],
    infer_a: MatchingInference,
    infer_b: MatchingInference,
    token_ratio: float
) -> List[Dict]:
    """
    Score every reasoning block of every sample once, ahead of the parameter sweep.

    Each item is read for keys "R" (iterable of block texts), "A" (answer text)
    and optionally "label". For each item:
      1. Empty / whitespace-only blocks are dropped and the rest are shuffled
         with the module-global ``random`` state. The stored order is therefore
         fixed for every (threshold, ratio) combination evaluated later.
      2. Each block is tokenized with model A's tokenizer (no special tokens).
         ``max(1, int(n_tokens * token_ratio))`` ids are sampled without
         replacement, capped at ``n_tokens``. Two scores are then computed:
           * score_a -- model A's score of the sampled ids against the block;
           * score_b -- ``infer_b.score_pair(block, answer)``.
         Blocks that tokenize to nothing are skipped. Blocks whose scoring
         raises are skipped too (best effort); their number and the first error
         are reported once at the end instead of being silently dropped.

    Returns:
        One dict per input item:
          'label': ``item.get("label")`` (None when absent);
          'block_scores': list of ``{'a': score_a, 'b': score_b}`` for the
              successfully scored blocks, in shuffled order;
          'alpha': number of non-empty blocks before scoring, which may exceed
              ``len(block_scores)``.
    """
    print("🚀 Pre-processing samples to compute block scores (rule-based)...")
    preprocessed_data = []
    model_a_scorer, tokenizer_a = make_model_a_scorer(infer_a)
    failed_blocks = 0
    first_failure = None

    for item in tqdm(data, desc="Pre-calculating block scores"):
        R_sentences, A_text = item["R"], item["A"]

        filtered_sentences = [s for s in R_sentences if s and s.strip()]
        random.shuffle(filtered_sentences)
        alpha_count = len(filtered_sentences)

        current_sample_block_scores = []
        for block_text in filtered_sentences:
            encoding = tokenizer_a(block_text, add_special_tokens=False, return_tensors='pt')
            block_token_ids = encoding["input_ids"][0].tolist()
            if not block_token_ids:
                continue  # nothing to sample from

            sample_size = max(1, int(len(block_token_ids) * token_ratio))
            sample_size = min(sample_size, len(block_token_ids))
            selected_ids = random.sample(block_token_ids, sample_size)

            try:
                score_a = model_a_scorer(selected_ids, block_text)
                score_b = infer_b.score_pair(block_text, A_text)  # A_text is the answer for the item
            except Exception as e:  # best-effort per block, but never silently
                failed_blocks += 1
                if first_failure is None:
                    first_failure = f"{type(e).__name__}: {e} (block: {block_text[:30]!r})"
                continue
            current_sample_block_scores.append({'a': score_a, 'b': score_b})

        preprocessed_data.append({
            "label": item.get("label"),
            "block_scores": current_sample_block_scores,  # only successfully scored blocks
            "alpha": alpha_count  # non-empty blocks before scoring
        })

    if failed_blocks:
        # A systematic failure here would make every sample look "insufficient".
        print(f"⚠️ Skipped {failed_blocks} block(s) whose scoring raised an exception. First failure: {first_failure}")
    return preprocessed_data


def run_validation_on_preprocessed_rule_based(
    preprocessed_sample_data: Dict,
    probing_ratio: float,
    verification_threshold: float
) -> bool:
    """
    Decide sufficiency for one pre-scored sample without any model calls.

    With ``alpha`` = the sample's count of non-empty blocks, the probing point is
    ``probe = min(alpha, ceil(max(1, probing_ratio * alpha)))``, or 0 when alpha is 0.
    Scored blocks are consumed in their stored order. After each block, once at
    least ``probe`` blocks have been consumed (and probe > 0), the running means
    of score_a and score_b are checked with ``_check_sufficiency_rule_based``.
    The first passing check returns True.

    If no check passes, the result of one final check over all scored blocks is
    returned. This covers the case where fewer than ``probe`` blocks were scored.
    A sample with no scored blocks is never sufficient.
    """
    scored = preprocessed_sample_data["block_scores"]
    alpha = preprocessed_sample_data["alpha"]

    if not scored:
        return False

    # How many scored blocks to consume before the first early-stop probe.
    probe_at = min(alpha, math.ceil(max(1.0, probing_ratio * alpha)))
    if alpha == 0:
        probe_at = 0

    running_a = []
    running_b = []
    for consumed, pair in enumerate(scored, start=1):
        running_a.append(pair['a'])
        running_b.append(pair['b'])
        if consumed >= probe_at and probe_at > 0:
            passed, _, _ = _check_sufficiency_rule_based(running_a, running_b, verification_threshold)
            if passed:
                return True

    # No early stop fired: judge on every scored block.
    passed, _, _ = _check_sufficiency_rule_based(running_a, running_b, verification_threshold)
    return passed

# -------- Main entry point --------
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Rule-based validation with efficient parameter sweeping.")
    parser.add_argument("--model_a_dir", type=str, required=True, help="Directory for model A components.")
    parser.add_argument("--model_b_dir", type=str, required=True, help="Directory for model B components (used for R-A scoring).")
    parser.add_argument("--data_path", type=str, required=True, help="Path to the input JSON data file.")
    parser.add_argument("--csv_output_path", type=str, required=True, help="Path to save the CSV results.")
    parser.add_argument("--token_ratio", type=float, default=0.1, help="Theta: Ratio of tokens to sample within a block for model A's score_a.")

    parser.add_argument("--verification_thresholds", type=float, nargs='+', required=True, help="List of verification thresholds (for avg_score_a AND avg_score_b).")
    parser.add_argument("--probing_ratios", type=float, nargs='+', required=True, help="List of probing ratios (gamma) to test for early stopping.")
    parser.add_argument("--max_samples", type=int, default=None, help="Process a maximum number of samples for quick testing (e.g., 100). Default: all.")

    args = parser.parse_args()

    # Create the CSV's parent directory up front. A missing directory would
    # otherwise fail only at the very end, after all the expensive scoring.
    output_dir = os.path.dirname(args.csv_output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Using device: {DEVICE}")

    print("🚀 Initializing Inference Engines (MatchingInference)...")
    infer_a = MatchingInference(args.model_a_dir)
    infer_b = MatchingInference(args.model_b_dir)  # infer_b is used for R-A scoring (score_b)
    print("✅ Inference Engines Ready.")

    with open(args.data_path, "r", encoding="utf-8") as fin:
        all_data = json.load(fin)

    if args.max_samples is not None and args.max_samples > 0:
        print(f"🔪 Using a subset of {args.max_samples} samples for processing.")
        data_subset = all_data[:args.max_samples]
    else:
        data_subset = all_data

    # --- Stage 1: Pre-computation of scores for all samples ---
    preprocessed_samples = preprocess_samples_for_efficient_validation(
        data_subset, infer_a, infer_b, args.token_ratio
    )

    # Only labeled samples count toward accuracy. The set is the same for every
    # (threshold, ratio) combination, so select it once.
    labeled_samples = [s for s in preprocessed_samples if s["label"] is not None]
    total_labeled_samples_for_acc = len(labeled_samples)
    if not labeled_samples:
        # Otherwise a reported accuracy of 0.0 would be indistinguishable from "all wrong".
        print("\n⚠️ No labeled samples found; every reported accuracy will be 0.0 and is not meaningful.")

    results_for_csv = []

    print(f"\n⚙️ Starting rule-based multi-parameter validation for {len(args.verification_thresholds)} thresholds and {len(args.probing_ratios)} probing ratios...")

    # --- Stage 2: Iterate through parameter combinations and evaluate using pre-computed scores ---
    for v_thresh in tqdm(args.verification_thresholds, desc="Thresholds"):
        for p_ratio in tqdm(args.probing_ratios, desc="Probing Ratios", leave=False):
            # NOTE(review): `==` assumes labels are booleans (or 0/1); string
            # labels such as "True" would never match -- confirm against the data.
            correct_predictions = sum(
                1
                for sample_data in labeled_samples
                if run_validation_on_preprocessed_rule_based(sample_data, p_ratio, v_thresh) == sample_data["label"]
            )

            accuracy = (correct_predictions / total_labeled_samples_for_acc) if total_labeled_samples_for_acc > 0 else 0.0

            results_for_csv.append({
                "verification_threshold": v_thresh,
                "probing_ratio": p_ratio,
                "accuracy": f"{accuracy:.4f}",
                "correct_predictions": correct_predictions,
                "total_labeled_samples": total_labeled_samples_for_acc
            })

    # --- Stage 3: Output results to CSV ---
    if results_for_csv:
        fieldnames = list(results_for_csv[0].keys())
        with open(args.csv_output_path, "w", newline='', encoding="utf-8") as fout_csv:
            writer = csv.DictWriter(fout_csv, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(results_for_csv)
        print(f"\n✅ Accuracy results for different combinations saved to {args.csv_output_path}")
    else:
        # One row is produced per (threshold, ratio) pair, so this is reached only
        # when one of those lists is empty.
        print("\n⚠️ No results to save to CSV. Check --verification_thresholds and --probing_ratios.")

    print("\n----- Rule-Based Multi-Parameter Validation Summary -----")
    for res in results_for_csv:
        print(f"Threshold: {res['verification_threshold']:.3f}, Probing Ratio: {res['probing_ratio']:.2f}, Accuracy: {res['accuracy']} ({res['correct_predictions']}/{res['total_labeled_samples']})")