import random
from typing import List, Tuple, Dict, Any
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
# import torch.nn.functional as F
import json
from tqdm import tqdm
from collections import Counter
import argparse
import math
import os
import torch.nn as nn
import csv # For CSV output

class RNNClassifier(nn.Module):
    """Recurrent binary classifier over a padded sequence of scalar scores.

    A single-layer GRU or LSTM reads the sequence one scalar per step; the
    final hidden state is fed to a two-layer MLP followed by a sigmoid.
    """

    def __init__(self, hidden_dim=256, rnn_type='GRU'):
        """Build the recurrent encoder and the classification head.

        Args:
            hidden_dim: Size of the recurrent state and of the MLP hidden layer.
            rnn_type: 'LSTM' (case-insensitive) selects an LSTM; every other
                value falls back to a GRU.
        """
        super().__init__()
        self.rnn_type = rnn_type.upper()
        recurrent_cls = nn.LSTM if self.rnn_type == 'LSTM' else nn.GRU
        self.rnn = recurrent_cls(input_size=1, hidden_size=hidden_dim, batch_first=True)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x, lengths):
        """Return one probability per sequence in the batch.

        Args:
            x: Padded batch laid out as (batch, seq_len, 1), as implied by
                ``batch_first=True`` and ``input_size=1``.
            lengths: Number of valid steps of each sequence; copied to CPU
                before packing.

        Returns:
            Tensor of sigmoid outputs with the trailing size-1 axis removed.
        """
        packed = nn.utils.rnn.pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, state = self.rnn(packed)
        # An LSTM returns (h_n, c_n); a GRU returns h_n directly.
        hidden = state[0] if self.rnn_type == 'LSTM' else state
        logit = self.classifier(hidden[-1])
        return torch.sigmoid(logit).squeeze(-1)


class DeepSetClassifier(nn.Module):
    """Order-independent binary classifier over a padded set of scalar scores.

    Each scalar is embedded by ``phi``, the embeddings of the valid positions
    are mean-pooled, and ``rho`` maps the pooled vector to one probability.
    """

    def __init__(self, hidden_dim=512):
        """Build the per-element encoder ``phi`` and the set-level head ``rho``.

        Args:
            hidden_dim: Width of every hidden layer in ``phi`` and ``rho``.
        """
        super().__init__()
        self.phi = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.rho = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x, lengths):
        """Return one probability per set in the batch.

        Args:
            x: Padded batch laid out as (batch, max_len, 1); the last axis
                must be 1 because ``phi`` starts with ``nn.Linear(1, ...)``.
            lengths: Number of valid elements per row. May live on any
                device; it is moved next to ``x`` before use.

        Returns:
            Tensor of sigmoid outputs with the trailing size-1 axis removed.
        """
        # Callers may keep `lengths` on the CPU (pack_padded_sequence wants
        # that for the RNN verifier); comparing or dividing across devices
        # would raise, so align it with `x` first.
        lengths = lengths.to(x.device)

        phi_x = self.phi(x)
        positions = torch.arange(x.size(1), device=x.device).unsqueeze(0)
        valid = (positions < lengths.unsqueeze(1)).unsqueeze(-1)
        # masked_fill (rather than multiplying by the mask) keeps NaN/inf
        # sitting in padded slots from contaminating the sum.
        phi_x = phi_x.masked_fill(~valid, 0.0)

        # Clamp so an empty set yields a zero pooled vector, not 0/0 = NaN.
        denom = lengths.clamp(min=1).unsqueeze(-1)
        agg = phi_x.sum(dim=1) / denom
        out = self.rho(agg)
        return torch.sigmoid(out).squeeze(-1)

# Module-wide verifier instance. It is None at import time; the __main__ block
# assigns an RNNClassifier or DeepSetClassifier to it, and
# _check_sufficiency_with_verifier() reads it (raising ValueError while it is
# still None).
verifier_model_global: nn.Module = None # None until the entry point assigns a model
# Device used for every model and tensor in this file: CUDA when available,
# otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Project-local module providing the matching-head factory; it must be
# importable (e.g. on PYTHONPATH) when this file is loaded.
from heads import get_matching_head # used by MatchingInference.__init__

# Seed every random number generator used here so that the sentence shuffling
# and token sampling done during preprocessing are repeatable across runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)


class MatchingInference:
    """Inference wrapper around a sentence-embedding model and a matching head.

    Loads ``<model_dir>/embedding_model`` and ``<model_dir>/matching_head.pt``
    onto ``DEVICE`` and keeps a per-token-id embedding cache for token-level
    scoring.
    """

    def __init__(self, model_dir):
        """Load the embedding model, the matching head and the token cache.

        Args:
            model_dir: Directory containing ``embedding_model/`` and
                ``matching_head.pt``; ``tokenid_embedding_cache.pt`` is read
                from, or created in, the same directory.
        """
        self.embedding_model = SentenceTransformer(f"{model_dir}/embedding_model", trust_remote_code=True, device=DEVICE)
        self.embedding_model.eval()

        embedding_dim = self.embedding_model.get_sentence_embedding_dimension()
        self.matching_head = get_matching_head("cos_sim", embedding_dim)
        self.matching_head.load_state_dict(torch.load(f"{model_dir}/matching_head.pt", map_location=DEVICE))
        self.matching_head = self.matching_head.to(DEVICE)
        self.matching_head.eval()

        self.tokenid2emb = self._build_token_embedding_cache(model_dir)

    def _build_token_embedding_cache(self, model_dir):
        """Return a ``{token_id: embedding}`` dict with tensors on ``DEVICE``.

        The mapping is loaded from ``tokenid_embedding_cache.pt`` when that
        file exists. Otherwise every vocabulary entry that is non-blank and
        does not start with "[" is encoded as a standalone text, and the
        result is saved to that file for later runs.
        """
        cache_path = os.path.join(model_dir, "tokenid_embedding_cache.pt")
        if os.path.exists(cache_path):
            print(f"📦 Loading token ID embedding cache from {cache_path} ...")
            # NOTE: torch.load unpickles the file; only use trusted model dirs.
            tokenid2emb_raw = torch.load(cache_path, map_location=DEVICE)
            return {int(token_id): emb.to(DEVICE) for token_id, emb in tokenid2emb_raw.items()}

        print("⚙️ Building token ID embedding index from tokenizer vocab...")
        vocab = self.embedding_model.tokenizer.get_vocab()
        # Skip bracketed entries (presumably special tokens such as "[CLS]" —
        # TODO confirm for the tokenizer in use) and whitespace-only entries.
        filtered_items = [(tok, idx) for tok, idx in vocab.items() if not tok.startswith("[") and tok.strip()]

        tokens = [tok for tok, _ in filtered_items]
        ids = [idx for _, idx in filtered_items]

        token_embs = self.embedding_model.encode(tokens, convert_to_tensor=True, show_progress_bar=True, device=DEVICE)
        tokenid2emb = {int(i): emb for i, emb in zip(ids, token_embs)}

        torch.save(tokenid2emb, cache_path)
        print(f"✅ Cached embeddings for {len(tokenid2emb)} token ids to {cache_path}")
        return {int(k): v.to(DEVICE) for k, v in tokenid2emb.items()}

    def encode(self, text: str) -> torch.Tensor:
        """Embed ``text`` with the sentence model and return a tensor on ``DEVICE``."""
        return self.embedding_model.encode(text, convert_to_tensor=True, device=DEVICE)

    def score(self, emb_a: torch.Tensor, emb_b: torch.Tensor) -> float:
        """Return the sigmoid of the matching-head logit for one embedding pair.

        Both embeddings get a leading batch axis before being passed to the
        head, so they are expected to be unbatched vectors.
        """
        emb_a = emb_a.to(DEVICE)
        emb_b = emb_b.to(DEVICE)
        features = {
            "embedding_a": emb_a.unsqueeze(0),
            "embedding_b": emb_b.unsqueeze(0)
        }
        with torch.no_grad():
            logits = self.matching_head(features)["logits"]
            return torch.sigmoid(logits).item()

    @torch.no_grad()
    def predict_batch(self, answers, reasons, batch_size=32):
        """Score aligned (answer, reason) text pairs in mini-batches.

        Args:
            answers: Texts fed to the head as ``embedding_a``.
            reasons: Texts fed to the head as ``embedding_b``; must have the
                same length as ``answers``.
            batch_size: Number of pairs encoded per forward pass.

        Returns:
            A flat list with one probability (float) per pair, in input order.

        Raises:
            ValueError: If ``answers`` and ``reasons`` differ in length.
        """
        if len(answers) != len(reasons):
            raise ValueError(
                f"answers and reasons must have the same length "
                f"(got {len(answers)} and {len(reasons)})"
            )

        all_probs = []
        for start in range(0, len(answers), batch_size):
            batch_answers = answers[start:start + batch_size]
            batch_reasons = reasons[start:start + batch_size]

            emb_a = self.embedding_model.encode(batch_answers, convert_to_tensor=True, normalize_embeddings=True, device=DEVICE)
            emb_b = self.embedding_model.encode(batch_reasons, convert_to_tensor=True, normalize_embeddings=True, device=DEVICE)

            logits = self.matching_head({"embedding_a": emb_a, "embedding_b": emb_b})["logits"]
            # Drop a trailing size-1 axis only when one exists. Squeezing
            # unconditionally would collapse (1,)-shaped logits of a
            # single-pair batch to a 0-d tensor, whose tolist() is a bare
            # float that list.extend() cannot consume.
            if logits.dim() > 1:
                logits = logits.squeeze(-1)
            all_probs.extend(torch.sigmoid(logits).tolist())
        return all_probs

    def score_pair(self, reason: str, answer: str) -> float:
        """Return the matching probability for a single (reason, answer) pair."""
        return self.predict_batch([answer], [reason], batch_size=1)[0]


def make_model_a_scorer(infer_a: MatchingInference):
    """Create the token-level scoring closure for model A.

    Returns:
        A pair ``(model_a_score, tokenizer)`` where ``tokenizer`` is the
        tokenizer of ``infer_a``'s embedding model and ``model_a_score`` is
        the closure documented below.
    """

    def model_a_score(token_ids: List[int], sentence_text: str) -> float:
        """Score the mean cached embedding of ``token_ids`` against a sentence.

        Token ids missing from ``infer_a.tokenid2emb`` are ignored; when none
        of them is cached the score is 0.0.
        """
        cache = infer_a.tokenid2emb
        known = [cache[tid] for tid in token_ids if tid in cache]
        if len(known) == 0:
            return 0.0

        pooled = torch.stack(known).mean(dim=0)
        sentence_vec = infer_a.encode(sentence_text)
        # MatchingInference.score moves both vectors to DEVICE, adds the batch
        # axis, runs the matching head under no_grad and applies the sigmoid.
        return infer_a.score(pooled, sentence_vec)

    return model_a_score, infer_a.embedding_model.tokenizer


def _check_sufficiency_with_verifier(
    current_scores_a: List[float],
    current_scores_b: List[float],
    verification_threshold: float
) -> Tuple[bool, float]:
    global verifier_model_global
    if verifier_model_global is None:
        raise ValueError("Verifier model not initialized globally.")

    if not current_scores_a or not current_scores_b:
        return False, 0.0 
    
    if len(current_scores_a) != len(current_scores_b):
        print("Warning: Mismatch in score list lengths for verifier.")
        return False, 0.0

    interleaved_scores = []
    for score_a, score_b in zip(current_scores_a, current_scores_b):
        interleaved_scores.append(score_a)
        interleaved_scores.append(score_b)

    if not interleaved_scores:
        return False, 0.0

    x = torch.tensor(interleaved_scores, dtype=torch.float32).unsqueeze(0).unsqueeze(-1).to(DEVICE)
    lengths = torch.tensor([len(interleaved_scores)], dtype=torch.long) # lengths should be on CPU for pack_padded_sequence

    verifier_model_global.eval()
    with torch.no_grad():
        model_probability = verifier_model_global(x, lengths).item()

    sufficient = model_probability >= verification_threshold
    return sufficient, model_probability


def preprocess_samples_for_efficient_validation(
    data: List[Dict],
    infer_a: MatchingInference,
    infer_b: MatchingInference,
    token_ratio: float,
) -> List[Dict]:
    """Compute the per-block model-A and model-B scores of every sample once.

    For each sample the non-blank sentences of ``item["R"]`` are shuffled and
    each sentence ("block") is scored twice: by model A on a random subset of
    its token ids against the block text, and by model B on the block text
    against ``item["A"]``.

    Args:
        data: Samples with an ``"R"`` list of sentences and an ``"A"`` answer
            text; ``"label"`` is optional.
        infer_a: Engine providing the tokenizer and the token-level scorer.
        infer_b: Engine providing the block-versus-answer score.
        token_ratio: Fraction of a block's tokens sampled for model A; at
            least one token and at most all of them are used.

    Returns:
        One dict per sample with keys ``"label"``, ``"block_scores"`` (a list
        of ``{'a': ..., 'b': ...}`` in shuffled order) and ``"alpha"`` (the
        number of non-blank sentences). ``"alpha"`` can be larger than
        ``len(block_scores)`` because blocks that tokenize to nothing or fail
        to score are skipped.
    """
    print("🚀 Pre-processing samples to compute block scores...")
    preprocessed_data = []
    model_a_scorer, tokenizer_a = make_model_a_scorer(infer_a)

    for item in tqdm(data, desc="Pre-calculating block scores"):
        # Only "R" and "A" are consumed here; other keys are not required.
        R_sentences, A_text = item["R"], item["A"]
        label = item.get("label")

        filtered_sentences = [s for s in R_sentences if s and s.strip()]
        if not filtered_sentences:
            preprocessed_data.append({
                "label": label,
                "block_scores": [],
                "alpha": 0
            })
            continue

        # Uses the module-seeded `random` state, so block order is repeatable.
        random.shuffle(filtered_sentences)

        current_sample_block_scores = []
        for block_text in filtered_sentences:
            encoding = tokenizer_a(block_text, add_special_tokens=False, return_tensors='pt')
            block_token_ids = encoding["input_ids"][0].tolist()

            if not block_token_ids:
                continue

            # Clamp the sample size to [1, number of tokens in the block].
            sample_size = max(1, int(len(block_token_ids) * token_ratio))
            sample_size = min(sample_size, len(block_token_ids))
            selected_ids = random.sample(block_token_ids, sample_size)

            # Best effort: a block that cannot be scored is reported and
            # dropped so one bad block does not abort the whole run.
            try:
                score_a = model_a_scorer(selected_ids, block_text)
                score_b = infer_b.score_pair(block_text, A_text)
                current_sample_block_scores.append({'a': score_a, 'b': score_b})
            except Exception as e:
                print(f"Error scoring block '{block_text[:30]}...': {e}. Skipping block.")
                continue

        preprocessed_data.append({
            "label": label,
            "block_scores": current_sample_block_scores,
            "alpha": len(filtered_sentences)
        })
    return preprocessed_data


def run_validation_on_preprocessed(
    preprocessed_sample_data: Dict,
    probing_ratio: float,
    verification_threshold: float
) -> Tuple[bool, int]:
    """Run the early-stopping validation for one sample from cached scores.

    Blocks are consumed in their stored order. Once the probing quota has
    been consumed, the verifier is asked after every further block and
    validation stops at the first "sufficient" verdict.

    Args:
        preprocessed_sample_data: Dict with ``"block_scores"`` (list of
            ``{'a': ..., 'b': ...}``) and ``"alpha"`` (block count used to
            derive the probing quota).
        probing_ratio: Fraction of ``alpha`` to consume before the first
            verifier call; the quota is at least 1 and at most ``alpha``.
        verification_threshold: Probability threshold forwarded to the
            verifier check.

    Returns:
        ``(predicted_sufficiency, number_of_blocks_validated)``; ``(False, 0)``
        when the sample has no block scores.
    """
    block_scores = preprocessed_sample_data["block_scores"]
    alpha = preprocessed_sample_data["alpha"]

    if not block_scores:
        return False, 0

    num_initial_blocks_to_check_ideal = min(alpha, math.ceil(max(1, probing_ratio * alpha)))
    # "alpha" may exceed the number of scored blocks. Capping the quota makes
    # the last available block always trigger a verifier call, so no separate
    # post-loop check (which would repeat the last in-loop call on identical
    # input) is needed.
    first_check_at = min(num_initial_blocks_to_check_ideal, len(block_scores))

    accumulated_scores_a = []
    accumulated_scores_b = []

    for num_validated, score_pair in enumerate(block_scores, start=1):
        accumulated_scores_a.append(score_pair['a'])
        accumulated_scores_b.append(score_pair['b'])

        if num_validated < first_check_at:
            continue

        is_sufficient, _ = _check_sufficiency_with_verifier(
            accumulated_scores_a, accumulated_scores_b, verification_threshold
        )
        if is_sufficient:
            # Early stop: report how many blocks were needed.
            return True, num_validated

    return False, len(block_scores)


# -------- Main entry point --------
# Pipeline: parse arguments -> load the verifier -> load both matching models
# -> pre-compute block scores once -> sweep every (threshold, probing ratio)
# pair over the cached scores -> write two CSV reports and print a summary.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_a_dir", type=str, required=True)
    parser.add_argument("--model_b_dir", type=str, required=True)
    parser.add_argument("--verifier_model_path", type=str, required=True, help="Path to the verifier model (e.g., best_model.pt)")
    parser.add_argument("--verifier_model_type", type=str, default="DeepSet", choices=["RNN", "DeepSet"], help="Type of verifier model")
    parser.add_argument("--verifier_hidden_dim", type=int, default=256, help="Hidden dimension for verifier model")
    parser.add_argument("--data_path", type=str, required=True)
    parser.add_argument("--accuracy_csv_output_path", type=str, required=True, help="Path to save the accuracy CSV results") # one row per parameter combination
    parser.add_argument("--rounds_csv_output_path", type=str, required=True, help="Path to save the average validation rounds CSV results") # one row per parameter combination
    parser.add_argument("--token_ratio", type=float, default=0.1, help="Theta: Ratio of tokens to sample within a block for model A.")
    
    parser.add_argument("--verification_thresholds", type=float, nargs='+', required=True, help="List of verification thresholds to test.")
    parser.add_argument("--probing_ratios", type=float, nargs='+', required=True, help="List of probing ratios (gamma) to test.")
    parser.add_argument("--max_samples", type=int, default=None, help="Process a maximum number of samples for quick testing (e.g., 100).")

    args = parser.parse_args()

    print(f"Device: {DEVICE}")

    # Assign the module-level verifier read by _check_sufficiency_with_verifier().
    print("🧠 Initializing Verifier Model...")
    if args.verifier_model_type.upper() == "RNN":
        verifier_model_global = RNNClassifier(hidden_dim=args.verifier_hidden_dim, rnn_type='GRU').to(DEVICE)
    else:
        verifier_model_global = DeepSetClassifier(hidden_dim=args.verifier_hidden_dim).to(DEVICE)
    
    # A missing or unloadable checkpoint is not fatal: the script reports the
    # problem and continues with the randomly initialised weights.
    # NOTE(review): metrics from such a run come from an untrained verifier.
    try:
        verifier_model_global.load_state_dict(torch.load(args.verifier_model_path, map_location=DEVICE))
        print(f"Verifier model loaded successfully from '{args.verifier_model_path}'")
    except FileNotFoundError:
        print(f"Warning: Verifier model '{args.verifier_model_path}' not found. Using a randomly initialized verifier.")
    except Exception as e:
        print(f"Error loading verifier model: {e}. Using a randomly initialized verifier.")
    verifier_model_global.eval()

    print("🚀 Initializing Inference Engines...")
    infer_a = MatchingInference(args.model_a_dir)
    infer_b = MatchingInference(args.model_b_dir)
    print("✅ Inference Engines Ready.")

    # The data file is a JSON document that is sliced and iterated as a list
    # of sample dicts below.
    with open(args.data_path, "r", encoding="utf-8") as fin:
        all_data = json.load(fin)

    if args.max_samples is not None and args.max_samples > 0:
        print(f"🔪 Using a subset of {args.max_samples} samples for processing.")
        data_subset = all_data[:args.max_samples]
    else:
        data_subset = all_data
    
    # Block scores are computed once here and reused for every parameter
    # combination in the sweep below.
    preprocessed_samples = preprocess_samples_for_efficient_validation(
        data_subset, infer_a, infer_b, args.token_ratio
    )

    accuracy_results_for_csv = []
    avg_rounds_results_for_csv = [] # rows for the average-rounds CSV

    print(f"\n⚙️ Starting multi-parameter validation loop for {len(args.verification_thresholds)} thresholds and {len(args.probing_ratios)} probing ratios...")
    
    num_total_samples_processed = len(preprocessed_samples) # denominator of the average-rounds metric

    for v_thresh in tqdm(args.verification_thresholds, desc="Thresholds"):
        for p_ratio in tqdm(args.probing_ratios, desc="Probing Ratios", leave=False):
            correct_predictions = 0
            total_labeled_samples = 0
            total_validation_rounds_for_combo = 0 # blocks validated, summed over all samples
            
            for sample_data in preprocessed_samples: # labeled and unlabeled samples alike
                label = sample_data["label"]
                
                # Prediction plus the number of blocks validated for this sample
                pred_is_sufficient, num_rounds_this_sample = run_validation_on_preprocessed(
                    sample_data, p_ratio, v_thresh
                )
                
                total_validation_rounds_for_combo += num_rounds_this_sample # counted even for unlabeled samples

                # Accuracy only counts samples that carry a label
                if label is not None:
                    total_labeled_samples += 1
                    if pred_is_sufficient == label:
                        correct_predictions += 1
            
            accuracy = (correct_predictions / total_labeled_samples) if total_labeled_samples > 0 else 0.0
            
            accuracy_results_for_csv.append({
                "verification_threshold": v_thresh,
                "probing_ratio": p_ratio,
                "accuracy": f"{accuracy:.4f}",
                "correct_predictions": correct_predictions,
                "total_labeled_samples": total_labeled_samples
            })

            # Average number of validated blocks per sample for this combination
            # (samples without block scores contribute 0 rounds).
            avg_rounds_for_combo = (total_validation_rounds_for_combo / num_total_samples_processed) \
                                   if num_total_samples_processed > 0 else 0.0
            avg_rounds_results_for_csv.append({
                "verification_threshold": v_thresh,
                "probing_ratio": p_ratio,
                "avg_validation_rounds": f"{avg_rounds_for_combo:.2f}" # stored as a 2-decimal string
            })


    # --- Stage 3a: write the accuracy results to CSV ---
    if accuracy_results_for_csv:
        fieldnames = accuracy_results_for_csv[0].keys()
        with open(args.accuracy_csv_output_path, "w", newline='', encoding="utf-8") as fout_csv:
            writer = csv.DictWriter(fout_csv, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(accuracy_results_for_csv)
        print(f"\n✅ Accuracy results for different combinations saved to {args.accuracy_csv_output_path}")
    else:
        print("\n⚠️ No accuracy results to save to CSV. Check data or parameters.")

    # --- Stage 3b: write the average-rounds results to a separate CSV ---
    if avg_rounds_results_for_csv:
        fieldnames_rounds = avg_rounds_results_for_csv[0].keys()
        with open(args.rounds_csv_output_path, "w", newline='', encoding="utf-8") as fout_csv_rounds:
            writer = csv.DictWriter(fout_csv_rounds, fieldnames=fieldnames_rounds)
            writer.writeheader()
            writer.writerows(avg_rounds_results_for_csv)
        print(f"\n✅ Average validation rounds for different combinations saved to {args.rounds_csv_output_path}")
    else:
        print("\n⚠️ No average rounds results to save to CSV. Check data or parameters.")


    # Echo both result tables to stdout.
    print("\n----- Multi-Parameter Validation Summary (Accuracy) -----")
    for res in accuracy_results_for_csv:
        print(f"Threshold: {res['verification_threshold']}, Probing Ratio: {res['probing_ratio']}, Accuracy: {res['accuracy']} ({res['correct_predictions']}/{res['total_labeled_samples']})")
    
    print("\n----- Multi-Parameter Validation Summary (Average Rounds) -----")
    for res in avg_rounds_results_for_csv:
        print(f"Threshold: {res['verification_threshold']}, Probing Ratio: {res['probing_ratio']}, Avg. Rounds: {res['avg_validation_rounds']}")