"""
Core Experiments Module

Implements the key experiments for the paper:
"Beyond Temperature: Understanding How Hyperfitting Changes Token Rankings in LLMs"
TODO: Consider a better title for the paper.

Experiments:
1. Temperature Matching - Prove that hyperfitting ≠ temperature scaling
2. Rank Analysis - Show which tokens change rank
3. Synthetic Hyperfitting - Try to replicate the effect without training
4. Representation Analysis - Find where the changes happen (i.e., which layers of the model change)
"""

import json
import logging
import os
from dataclasses import dataclass
from typing import Dict, List, Optional

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

from metrics import (
    DistributionAnalyzer,
    RankAnalyzer,
    GenerationMetrics,
)

# Module-level logging setup.
# NOTE(review): logging.basicConfig configures the root logger as a side effect of
# importing this module (and is a no-op if the root logger already has handlers);
# consider moving it to the script entry point so importers control logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class ExperimentResult:
    """Bundle one experiment's outputs together with the settings that produced them."""

    # Short identifier of the experiment, e.g. "temperature_matching".
    experiment_name: str
    # Label for the model (or model pair) the experiment was run on.
    model_name: str
    # Experiment-specific result payload.
    results: dict
    # Run configuration such as input counts and generation lengths.
    metadata: dict


class Experiment1_TemperatureMatching:
    """
    Experiment 1: Temperature Matching

    Goal: Prove that hyperfitting is NOT equivalent to temperature scaling

    Method:
    1. Measure entropy of hyperfitted model's predictions
    2. Find temperature T such that original model with T has same entropy
    3. Compare generation quality: if hyperfitting = temperature, quality should match

    Expected Result: Even with matched entropy, hyperfitted model produces better generations

    Note: dividing logits by any T > 0 never changes their argmax, so the
    matched-temperature condition must *sample* from softmax(logits / T).
    Greedy decoding at any temperature is identical to plain greedy decoding.
    """

    def __init__(
        self,
        original_model,
        hyperfitted_model,
        tokenizer,
        device: str = "cuda",
        max_seq_length: int = 256,
    ):
        """
        Args:
            original_model: Base model (used for forward passes and ``.generate``).
            hyperfitted_model: Hyperfitted counterpart of ``original_model``.
            tokenizer: Tokenizer shared by both models.
            device: Device that generation prompts are moved to.
            max_seq_length: Evaluation texts are truncated to this many tokens
                before entropy measurement (previously hard-coded to 256).
        """
        self.original_model = original_model
        self.hyperfitted_model = hyperfitted_model
        self.tokenizer = tokenizer
        self.device = device
        self.max_seq_length = max_seq_length

        self.orig_analyzer = DistributionAnalyzer(original_model, tokenizer, device)
        self.hyper_analyzer = DistributionAnalyzer(hyperfitted_model, tokenizer, device)

    def _encode_truncated(self, text: str) -> torch.Tensor:
        """Tokenize *text* and keep at most ``self.max_seq_length`` tokens.

        Slicing is a no-op for shorter inputs, so no explicit length check is needed.
        """
        input_ids = self.tokenizer.encode(text, return_tensors="pt")
        return input_ids[:, : self.max_seq_length]

    @staticmethod
    def _aggregate_generation_metrics(metrics_list: List[Dict]) -> Dict:
        """Summarize per-prompt metric dicts: mean/std TTR, mean repetition rates, mean length."""

        def column(key):
            return [m[key] for m in metrics_list]

        return {
            "mean_ttr": np.mean(column("ttr")),
            "std_ttr": np.std(column("ttr")),
            "mean_bigram_rep": np.mean(column("bigram_repetition")),
            "mean_trigram_rep": np.mean(column("trigram_repetition")),
            "mean_length": np.mean(column("length")),
        }

    def measure_hyperfitted_entropy(self, eval_texts: List[str]) -> float:
        """
        Measure the average entropy reported for the hyperfitted model.

        Args:
            eval_texts: Texts to evaluate; each is truncated to ``max_seq_length`` tokens.

        Returns:
            Mean of ``analyze_distribution(...)["entropy"]`` over ``eval_texts``.

        Raises:
            ValueError: If ``eval_texts`` is empty (the mean is undefined).
        """
        if not eval_texts:
            raise ValueError("eval_texts must contain at least one text")

        total_entropy = 0.0
        for text in tqdm(eval_texts, desc="Measuring hyperfitted entropy"):
            result = self.hyper_analyzer.analyze_distribution(self._encode_truncated(text))
            total_entropy += result["entropy"]

        return total_entropy / len(eval_texts)

    def find_matching_temperature(
        self,
        eval_texts: List[str],
        target_entropy: float,
        num_calibration_texts: int = 10,
    ) -> float:
        """
        Find a temperature for the original model whose entropy matches ``target_entropy``.

        A matching temperature is computed per text on the first
        ``num_calibration_texts`` texts, and the per-text temperatures are averaged.

        NOTE(review): the mean of per-text temperatures is not, in general, the
        temperature whose *mean* entropy equals ``target_entropy``. ``run`` reports
        the achieved entropy so the mismatch is visible.

        Returns:
            The mean per-text matching temperature (NaN if there are no calibration texts).
        """
        calibration_texts = eval_texts[:num_calibration_texts]

        temperatures = [
            self.orig_analyzer.find_matching_temperature(self._encode_truncated(text), target_entropy)
            for text in tqdm(calibration_texts, desc="Finding matching temperature")
        ]
        return np.mean(temperatures)

    def generate_and_compare(
        self,
        prompts: List[str],
        matched_temperature: float,
        max_new_tokens: int = 224,
    ) -> Dict:
        """
        Generate a continuation of every prompt under three decoding conditions and compare them.

        Conditions:
        - ``original_greedy``: original model, greedy decoding.
        - ``original_matched_temp``: original model, sampling from
          softmax(logits / matched_temperature) with top-k/top-p truncation disabled.
        - ``hyperfitted_greedy``: hyperfitted model, greedy decoding.

        The previous implementation took ``argmax(logits / T)``. That is identical
        to greedy decoding for every T > 0, so the temperature condition could never
        differ from ``original_greedy``.

        Args:
            prompts: Prompts to continue.
            matched_temperature: Sampling temperature for the original model (must be > 0).
            max_new_tokens: Maximum number of new tokens per prompt.

        Returns:
            Dict with the per-prompt metrics (``individual_results``), per-condition
            summaries (``aggregated``), and the ``matched_temperature`` used.
        """
        results = {
            "original_greedy": [],
            "original_matched_temp": [],
            "hyperfitted_greedy": [],
        }

        self.original_model.eval()
        self.hyperfitted_model.eval()

        for prompt in tqdm(prompts, desc="Generating and comparing"):
            input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
            prompt_len = input_ids.shape[1]

            with torch.no_grad():
                greedy_original = self.original_model.generate(
                    input_ids,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    pad_token_id=self.tokenizer.eos_token_id,
                )
                # top_k=0 and top_p=1.0 disable truncation, so temperature is the only
                # change applied to the distribution being sampled.
                sampled_original = self.original_model.generate(
                    input_ids,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=float(matched_temperature),
                    top_k=0,
                    top_p=1.0,
                    pad_token_id=self.tokenizer.eos_token_id,
                )
                greedy_hyperfitted = self.hyperfitted_model.generate(
                    input_ids,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    pad_token_id=self.tokenizer.eos_token_id,
                )

            for key, output in (
                ("original_greedy", greedy_original),
                ("original_matched_temp", sampled_original),
                ("hyperfitted_greedy", greedy_hyperfitted),
            ):
                # Score only the newly generated tokens, not the prompt.
                new_tokens = output[0].tolist()[prompt_len:]
                results[key].append(GenerationMetrics.compute_all_metrics(new_tokens))

        aggregated = {
            model_type: self._aggregate_generation_metrics(metrics_list)
            for model_type, metrics_list in results.items()
        }

        return {
            "individual_results": results,
            "aggregated": aggregated,
            "matched_temperature": matched_temperature,
        }

    def run(
        self,
        eval_texts: List[str],
        prompts: List[str],
        max_new_tokens: int = 224,
        num_verification_texts: int = 20,
    ) -> ExperimentResult:
        """
        Run the full temperature matching experiment.

        Args:
            eval_texts: Texts used for entropy measurement and temperature calibration.
            prompts: Prompts used for the generation comparison.
            max_new_tokens: Maximum number of new tokens per prompt.
            num_verification_texts: Number of leading ``eval_texts`` used to verify
                the entropy match.

        Raises:
            ValueError: If ``eval_texts`` is empty.
        """
        logger.info("=" * 60)
        logger.info("Experiment 1: Temperature Matching")
        logger.info("=" * 60)

        # Compute hyperfitted entropy
        logger.info("Measuring hyperfitted model entropy")
        hyper_entropy = self.measure_hyperfitted_entropy(eval_texts)
        logger.info(f"Hyperfitted model average entropy: {hyper_entropy:.4f}")

        # Find matching temperature
        logger.info("Finding matching temperature for original model")
        matched_temp = self.find_matching_temperature(eval_texts, hyper_entropy)
        logger.info(f"Matched temperature: {matched_temp:.4f}")

        # Verify entropy match
        logger.info("Verifying entropy match")
        verify_texts = eval_texts[:num_verification_texts]
        orig_entropy_default = 0.0
        orig_entropy_matched = 0.0

        for text in verify_texts:
            input_ids = self._encode_truncated(text)
            default_result = self.orig_analyzer.analyze_distribution(input_ids, temperature=1.0)
            matched_result = self.orig_analyzer.analyze_distribution(input_ids, temperature=matched_temp)
            orig_entropy_default += default_result["entropy"]
            orig_entropy_matched += matched_result["entropy"]

        # Average over the texts actually used; the original always divided by 20,
        # which under-reported entropy when fewer texts were available.
        n_verify = max(len(verify_texts), 1)
        orig_entropy_default /= n_verify
        orig_entropy_matched /= n_verify

        logger.info(f"Original model entropy (temp=1.0): {orig_entropy_default:.4f}")
        logger.info(f"Original model entropy (temp={matched_temp:.3f}): {orig_entropy_matched:.4f}")
        logger.info(f"Hyperfitted model entropy: {hyper_entropy:.4f}")

        # Generate and compare
        logger.info("Step 4: Generating and comparing outputs...")
        comparison = self.generate_and_compare(prompts, matched_temp, max_new_tokens)

        # Summary
        logger.info("\n" + "=" * 60)
        logger.info("RESULTS SUMMARY")
        logger.info("=" * 60)
        for model_type, metrics in comparison["aggregated"].items():
            logger.info(f"\n{model_type}:")
            logger.info(f"  Mean TTR: {metrics['mean_ttr']:.4f} (±{metrics['std_ttr']:.4f})")
            logger.info(f"  Mean Bigram Repetition: {metrics['mean_bigram_rep']:.4f}")
            logger.info(f"  Mean Trigram Repetition: {metrics['mean_trigram_rep']:.4f}")

        return ExperimentResult(
            experiment_name="temperature_matching",
            model_name=str(type(self.original_model).__name__),
            results={
                "hyperfitted_entropy": hyper_entropy,
                "matched_temperature": matched_temp,
                "original_entropy_default": orig_entropy_default,
                "original_entropy_matched": orig_entropy_matched,
                "generation_comparison": comparison,
            },
            metadata={
                "num_eval_texts": len(eval_texts),
                "num_prompts": len(prompts),
                "max_new_tokens": max_new_tokens,
                "num_verification_texts": len(verify_texts),
            },
        )


class Experiment2_RankAnalysis:
    """
    Experiment 2: Rank Analysis

    Goal: Show that hyperfitting changes WHICH tokens are top-ranked, not just distribution shape

    Method:
    1. For each position, compare top-k rankings between original and hyperfitted
    2. Compute rank correlation
    3. Identify tokens that get significantly promoted/demoted

    Expected Result: Significant rank changes, not just probability reweighting

    As a control, ``compare_temperatures`` runs the same comparisons on the
    original model at two temperatures. Dividing logits by a positive temperature
    preserves their order, so that control should show (near-)perfect agreement.
    Any deviation can only come from ties introduced by floating-point rounding.
    """

    def __init__(
        self,
        original_model,
        hyperfitted_model,
        tokenizer,
        device: str = "cuda",
    ):
        self.rank_analyzer = RankAnalyzer(
            original_model,
            hyperfitted_model,
            tokenizer,
            device,
        )
        self.tokenizer = tokenizer
        self.device = device

    def _encode_truncated(self, text: str, max_seq_length: int) -> torch.Tensor:
        """Tokenize *text* and keep at most ``max_seq_length`` tokens (no-op when shorter)."""
        input_ids = self.tokenizer.encode(text, return_tensors="pt")
        return input_ids[:, :max_seq_length]

    def _forward_logits(self, input_ids: torch.Tensor, model) -> torch.Tensor:
        """Run one no-grad forward pass of *model* on *input_ids* (moved to ``self.device``)."""
        model.eval()
        with torch.no_grad():
            return model(input_ids.to(self.device)).logits

    def compare_temperatures(self, eval_texts, temperature1=1.0, temperature2=0.59, max_seq_length=256):
        """
        Compare the original model's token rankings at two temperatures.

        Args:
            eval_texts: Texts to analyze; each is truncated to ``max_seq_length`` tokens.
            temperature1: Reference temperature.
            temperature2: Comparison temperature.
            max_seq_length: Maximum tokens per text.

        Returns:
            Dict with averaged top-1 agreement statistics (``top1_comparison``),
            the mean/std of per-text Spearman correlations (``rank_correlation``),
            and both temperatures.
        """
        logger.info(f"Comparing T={temperature1} vs T={temperature2}")
        all_top1_agreements = []
        all_rank_correlations = []

        original_model = self.rank_analyzer.original_model

        for text in tqdm(eval_texts, desc="Temperature comparison"):
            input_ids = self._encode_truncated(text, max_seq_length)
            # One forward pass feeds both comparisons (previously the model ran twice per text).
            logits = self._forward_logits(input_ids, original_model)

            all_top1_agreements.append(
                self._compare_top1_at_temperatures(
                    input_ids, temperature1, temperature2, original_model, logits=logits
                )
            )
            all_rank_correlations.append(
                self._compute_rank_correlation_at_temperatures(
                    input_ids, temperature1, temperature2, original_model, logits=logits
                )
            )

        return {
            "top1_comparison": {
                "top1_agreement": np.mean([r["top1_agreement"] for r in all_top1_agreements]),
                "t2_top1_in_t1_top5": np.mean([r["t2_top1_in_t1_top5"] for r in all_top1_agreements]),
                "t2_top1_in_t1_top10": np.mean([r["t2_top1_in_t1_top10"] for r in all_top1_agreements]),
            },
            "rank_correlation": {
                "mean_rank_correlation": np.mean([r["correlation"] for r in all_rank_correlations]),
                "std_rank_correlation": np.std([r["correlation"] for r in all_rank_correlations]),
            },
            "temperature1": temperature1,
            "temperature2": temperature2,
        }

    def _compare_top1_at_temperatures(self, input_ids, temp1, temp2, model, logits=None):
        """
        Compare top-1 predictions of the same logits at two temperatures.

        Every position except the last is scored, for every batch row.

        Args:
            input_ids: Token ids; only used when ``logits`` is not supplied.
            temp1: Reference temperature.
            temp2: Comparison temperature.
            model: Model to run when ``logits`` is not supplied.
            logits: Optional precomputed logits (batch, seq, vocab), which skips the forward pass.

        Returns:
            Dict with the fraction of positions where the T2 top-1 equals the T1 top-1,
            lies in the T1 top-5, or lies in the T1 top-10, plus ``total_positions``.
        """
        if logits is None:
            logits = self._forward_logits(input_ids, model)

        vocab_size = logits.shape[-1]
        position_logits = logits[:, :-1, :]
        probs_t1 = F.softmax(position_logits / temp1, dim=-1)
        probs_t2 = F.softmax(position_logits / temp2, dim=-1)

        top1_t1 = probs_t1.argmax(dim=-1)
        top1_t2 = probs_t2.argmax(dim=-1)
        # min(...) keeps topk valid for vocabularies smaller than k.
        top5_t1 = probs_t1.topk(k=min(5, vocab_size), dim=-1).indices
        top10_t1 = probs_t1.topk(k=min(10, vocab_size), dim=-1).indices
        top1_t2_col = top1_t2.unsqueeze(-1)

        results = {
            "top1_agreement": int((top1_t1 == top1_t2).sum().item()),
            "t2_top1_in_t1_top5": int((top5_t1 == top1_t2_col).any(dim=-1).sum().item()),
            "t2_top1_in_t1_top10": int((top10_t1 == top1_t2_col).any(dim=-1).sum().item()),
            "total_positions": top1_t1.numel(),
        }

        total = results["total_positions"]
        if total > 0:
            for key in ("top1_agreement", "t2_top1_in_t1_top5", "t2_top1_in_t1_top10"):
                results[key] /= total

        return results

    def _compute_rank_correlation_at_temperatures(self, input_ids, temp1, temp2, model, top_k=100, logits=None):
        """
        Compute the Spearman correlation between rankings of the same logits at two temperatures.

        For each position except the last, the ``top_k`` tokens ranked under ``temp1``
        are located in the full ranking under ``temp2``. Their T1 ranks (0..k-1) are
        then correlated with their T2 ranks.

        Args:
            input_ids: Token ids; only used when ``logits`` is not supplied.
            temp1: Reference temperature.
            temp2: Comparison temperature.
            model: Model to run when ``logits`` is not supplied.
            top_k: Number of leading T1 tokens compared (clamped to the vocabulary size).
            logits: Optional precomputed logits (batch, seq, vocab).

        Returns:
            Dict with the mean correlation over positions (0.0 if none were valid)
            and ``num_positions``, the number of positions that contributed.
        """
        from scipy.stats import spearmanr

        if logits is None:
            logits = self._forward_logits(input_ids, model)

        batch_size, seq_len, vocab_size = logits.shape
        # Clamp so the reference ranks and looked-up ranks always have equal length.
        k = min(top_k, vocab_size)
        reference_ranks = np.arange(k)
        correlations = []

        for b in range(batch_size):
            for pos in range(seq_len - 1):
                position_logits = logits[b, pos, :]

                ranking_t1 = torch.argsort(F.softmax(position_logits / temp1, dim=-1), descending=True)[:k]
                ranking_t2 = torch.argsort(F.softmax(position_logits / temp2, dim=-1), descending=True)

                # argsort of a permutation is its inverse: rank_in_t2[token] is the
                # token's position in ranking_t2. This replaces an O(vocab) search per token.
                rank_in_t2 = torch.argsort(ranking_t2)
                ranks_in_t2 = rank_in_t2[ranking_t1].cpu().numpy()

                # A constant sequence has an undefined Spearman correlation.
                if len(np.unique(ranks_in_t2)) > 1:
                    corr, _ = spearmanr(reference_ranks, ranks_in_t2)
                    if not np.isnan(corr):
                        correlations.append(corr)

        return {
            "correlation": np.mean(correlations) if correlations else 0.0,
            "num_positions": len(correlations),
        }

    def run(
        self,
        eval_texts: List[str],
        max_seq_length: int = 256,
        compare_temperature: float = 0.59,
    ) -> ExperimentResult:
        """
        Run the rank analysis experiment.

        Args:
            eval_texts: Texts to analyze.
            max_seq_length: Maximum tokens per text.
            compare_temperature: Second temperature for the temperature-scaling
                control (previously hard-coded to 0.59; e.g. pass the matched
                temperature found by Experiment 1).
        """
        logger.info("=" * 60)
        logger.info("Experiment 2: Rank Analysis")
        logger.info("=" * 60)

        all_top1_comparisons = []
        all_rank_correlations = []
        all_promoted_tokens = []
        last_input_ids = None

        for text in tqdm(eval_texts, desc="Analyzing ranks"):
            input_ids = self._encode_truncated(text, max_seq_length)
            last_input_ids = input_ids

            # Compare top-1 predictions
            all_top1_comparisons.append(self.rank_analyzer.compare_top1_predictions(input_ids))

            # Compute rank correlations
            all_rank_correlations.append(self.rank_analyzer.compute_rank_correlation(input_ids, top_k=100))

            # Find promoted tokens; keep the first 20 per text
            promoted = self.rank_analyzer.find_promoted_tokens(input_ids, threshold=50)
            all_promoted_tokens.extend(promoted[:20])

        aggregated_top1 = {
            "top1_agreement": np.mean([c["top1_agreement"] for c in all_top1_comparisons]),
            "hyper_top1_in_orig_top5": np.mean([c["hyper_top1_in_orig_top5"] for c in all_top1_comparisons]),
            "hyper_top1_in_orig_top10": np.mean([c["hyper_top1_in_orig_top10"] for c in all_top1_comparisons]),
            "hyper_top1_in_orig_top50": np.mean([c["hyper_top1_in_orig_top50"] for c in all_top1_comparisons]),
            "hyper_top1_in_orig_top100": np.mean([c["hyper_top1_in_orig_top100"] for c in all_top1_comparisons]),
        }

        aggregated_corr = {
            "mean_rank_correlation": np.mean([c["mean_rank_correlation"] for c in all_rank_correlations]),
            "std_rank_correlation": np.std([c["mean_rank_correlation"] for c in all_rank_correlations]),
        }

        # Sort promoted tokens by rank improvement and keep the top 100
        all_promoted_tokens.sort(key=lambda x: x["rank_improvement"], reverse=True)
        top_promoted = all_promoted_tokens[:100]

        # Summary
        logger.info("\n" + "=" * 60)
        logger.info("RESULTS SUMMARY")
        logger.info("=" * 60)
        logger.info(f"\nTop-1 Agreement Rate: {aggregated_top1['top1_agreement']:.4f}")
        logger.info(f"Hyperfitted top-1 in Original top-5: {aggregated_top1['hyper_top1_in_orig_top5']:.4f}")
        logger.info(f"Hyperfitted top-1 in Original top-10: {aggregated_top1['hyper_top1_in_orig_top10']:.4f}")
        logger.info(f"Hyperfitted top-1 in Original top-50: {aggregated_top1['hyper_top1_in_orig_top50']:.4f}")
        logger.info(f"Hyperfitted top-1 in Original top-100: {aggregated_top1['hyper_top1_in_orig_top100']:.4f}")
        logger.info(f"\nMean Rank Correlation (top-100): {aggregated_corr['mean_rank_correlation']:.4f}")

        logger.info("\nTop 10 Most Promoted Tokens:")
        for i, token_info in enumerate(top_promoted[:10]):
            logger.info(f"  {i+1}. '{token_info['token_str']}': rank {token_info['original_rank']} -> {token_info['hyperfitted_rank']} (+{token_info['rank_improvement']})")
        logger.info("收集排名分布数据...")

        # NOTE(review): only the LAST analyzed text is passed here. The original code
        # relied on the loop variable leaking out of the loop, which raised NameError
        # for empty eval_texts. The single-text behavior is kept so the shape of
        # ``rank_distribution`` is unchanged; consider aggregating over all texts.
        rank_distribution = (
            self.rank_analyzer.collect_rank_distribution(last_input_ids)
            if last_input_ids is not None
            else None
        )

        # Temperature-scaling control
        logger.info("\n[NEW] Computing temperature comparison...")
        temperature_comparison = self.compare_temperatures(
            eval_texts=eval_texts,
            temperature1=1.0,
            temperature2=compare_temperature,
            max_seq_length=max_seq_length,
        )

        # Print the temperature comparison results
        logger.info(f"\n[Temperature Comparison Results]")
        logger.info(f"Top-1 Agreement: {temperature_comparison['top1_comparison']['top1_agreement']:.4f}")
        logger.info(f"Rank Correlation: {temperature_comparison['rank_correlation']['mean_rank_correlation']:.4f}")

        return ExperimentResult(
            experiment_name="rank_analysis",
            model_name="comparison",
            results={
                "top1_comparison": aggregated_top1,
                "rank_correlation": aggregated_corr,
                "promoted_tokens": top_promoted,
                'rank_distribution': rank_distribution,
                "temperature_comparison": temperature_comparison,
            },
            metadata={
                "num_eval_texts": len(eval_texts),
                "max_seq_length": max_seq_length,
            },
        )


class Experiment3_SyntheticHyperfitting:
    """
    Experiment 3: Synthetic Hyperfitting

    Goal: Can we achieve hyperfitting effects WITHOUT training?

    Revised Logic:
    1. Identify tokens that are systematically promoted in Rank.
    2. For these tokens, learn the average LOGIT increase (not rank difference).
    3. Apply this Logit Bias to the original model.
    """

    # Default logit-bias scales tested by ``run``. Scale 1.0 applies the learned bias
    # unchanged, i.e. the average logit gain observed in the hyperfitted model.
    DEFAULT_SCALES = (0.2, 0.5, 0.8, 1.0, 1.2)

    def __init__(
            self,
            original_model,
            hyperfitted_model,
            tokenizer,
            device: str = "cuda",
    ):
        self.original_model = original_model
        self.hyperfitted_model = hyperfitted_model
        self.tokenizer = tokenizer
        self.device = device

        # Learned per-token additive logit bias {token_id: mean logit gain}.
        # This is a logit offset, not an integer rank; None until learned.
        self.logit_bias = None

    def learn_rank_corrections(
            self,
            calibration_texts: List[str],
            max_seq_length: int = 256,
            top_k: int = 500,
            hyper_top_n: int = 20,
            min_rank_gain: int = 5,
            min_occurrences: int = 3,
    ) -> Dict[int, float]:
        """
        Learn a per-token additive logit bias from the hyperfitted model.

        At every position of every calibration text, each token in the hyperfitted
        model's top-``hyper_top_n`` is recorded when all of these hold:
        (a) it is not a special token;
        (b) the original model ranks it more than ``min_rank_gain`` places lower;
        (c) its logit is higher under the hyperfitted model.
        The recorded logit gains are averaged per token. Tokens recorded fewer than
        ``min_occurrences`` times are dropped, and only the ``top_k`` largest average
        gains are kept to suppress noise.

        Args:
            calibration_texts: Texts used to compare the two models.
            max_seq_length: Maximum tokens per text.
            top_k: Maximum number of tokens kept in the bias.
            hyper_top_n: Number of hyperfitted top-ranked tokens inspected per position.
            min_rank_gain: Required rank improvement (exclusive) for a token to count.
            min_occurrences: Minimum number of recorded gains for a token to be kept.

        Returns:
            The learned ``{token_id: mean logit gain}`` mapping, also stored in ``self.logit_bias``.
        """
        token_logit_diffs: Dict[int, List[float]] = {}

        # Exclude all special tokens (EOS, BOS, PAD, ...) so the bias cannot learn
        # to end generations early.
        special_tokens = set(self.tokenizer.all_special_ids)
        if self.tokenizer.eos_token_id is not None:
            special_tokens.add(self.tokenizer.eos_token_id)

        self.original_model.eval()
        self.hyperfitted_model.eval()

        for text in tqdm(calibration_texts, desc="Learning corrections"):
            input_ids = self.tokenizer.encode(text, return_tensors="pt").to(self.device)
            input_ids = input_ids[:, :max_seq_length]

            with torch.no_grad():
                orig_logits = self.original_model(input_ids).logits[0]
                hyper_logits = self.hyperfitted_model(input_ids).logits[0]

            n_top = min(hyper_top_n, hyper_logits.shape[-1])
            hyper_top = torch.argsort(hyper_logits, dim=-1, descending=True)[:, :n_top]

            # argsort of the original ranking (a permutation) is its inverse, so
            # orig_rank_of[pos, token] is the token's rank under the original model.
            # This replaces a per-token O(vocab) search.
            orig_rankings = torch.argsort(orig_logits, dim=-1, descending=True)
            orig_rank_of = torch.argsort(orig_rankings, dim=-1)
            orig_ranks = orig_rank_of.gather(-1, hyper_top)
            hyper_ranks = torch.arange(n_top, device=hyper_top.device)
            promoted = orig_ranks > hyper_ranks + min_rank_gain

            # Gather first, then upcast: float64 reproduces Python-float subtraction
            # of the individual logit values.
            gains = (
                hyper_logits.gather(-1, hyper_top).double()
                - orig_logits.gather(-1, hyper_top).double()
            )

            # Row-major flattening visits positions in order, then ranks within each
            # position, so gains are appended in a deterministic order.
            for token_id, gain, is_promoted in zip(
                    hyper_top.flatten().tolist(),
                    gains.flatten().tolist(),
                    promoted.flatten().tolist(),
            ):
                if is_promoted and gain > 0 and token_id not in special_tokens:
                    token_logit_diffs.setdefault(token_id, []).append(gain)

        avg_diffs = {
            token_id: np.mean(token_gains)
            for token_id, token_gains in token_logit_diffs.items()
            if len(token_gains) >= min_occurrences
        }

        # Keep only the top_k largest average gains.
        ranked = sorted(avg_diffs.items(), key=lambda item: item[1], reverse=True)
        self.logit_bias = dict(ranked[:top_k])

        logger.info(f"Learned logit biases for Top-{len(self.logit_bias)} tokens (Filtered specials & Top-K)")

        if self.logit_bias:
            avg_val = np.mean(list(self.logit_bias.values()))
            logger.info(f"Average correction magnitude: {avg_val:.4f}")

        return self.logit_bias

    def _build_correction_vector(self, vocab_size: int, scale: float, device) -> torch.Tensor:
        """
        Return a float32 vector of length ``vocab_size`` holding ``scale * logit_bias``.

        Entries for tokens without a learned bias are zero.
        """
        correction = torch.zeros(vocab_size, dtype=torch.float32, device=device)
        if self.logit_bias:
            indices = torch.tensor(list(self.logit_bias.keys()), dtype=torch.long, device=device)
            # Explicit float32: the learned values are numpy float64.
            values = torch.tensor(list(self.logit_bias.values()), dtype=torch.float32, device=device)
            correction[indices] = values * scale
        return correction

    def generate_with_synthetic_correction(
            self,
            prompt: str,
            max_new_tokens: int = 224,
            scale: float = 1.0,
    ) -> List[int]:
        """
        Greedily decode from the original model with the scaled logit bias added.

        Args:
            prompt: Prompt text.
            max_new_tokens: Maximum number of new tokens.
            scale: Multiplier applied to the learned bias.

        Returns:
            Prompt plus generated token ids. Generation stops after an EOS token is emitted.
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        generated = input_ids.clone()

        self.original_model.eval()
        correction_vector = None

        for _ in range(max_new_tokens):
            with torch.no_grad():
                next_token_logits = self.original_model(generated).logits[:, -1, :]

            if correction_vector is None:
                # Size from the actual logits width rather than config.vocab_size,
                # which is not guaranteed to match the LM-head output.
                correction_vector = self._build_correction_vector(
                    next_token_logits.shape[-1], scale, next_token_logits.device
                )

            # Adding float32 promotes half/bfloat16 logits to float32.
            corrected_logits = next_token_logits + correction_vector

            next_token = corrected_logits.argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=-1)

            if next_token.item() == self.tokenizer.eos_token_id:
                break

        return generated[0].tolist()

    def _greedy_metrics(self, model, prompts: List[str], max_new_tokens: int, desc: str) -> List[Dict]:
        """Greedy-decode each prompt with *model* and return generation metrics for the new tokens."""
        metrics = []
        for prompt in tqdm(prompts, desc=desc):
            input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
            with torch.no_grad():
                out = model.generate(
                    input_ids,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    pad_token_id=self.tokenizer.eos_token_id,
                )
            generated_tokens = out[0].tolist()[input_ids.shape[1]:]
            metrics.append(GenerationMetrics.compute_all_metrics(generated_tokens))
        return metrics

    @staticmethod
    def _summarize_baseline(metrics_list: List[Dict]) -> Dict:
        """Mean/std TTR and mean bigram repetition over per-prompt metrics."""
        return {
            "mean_ttr": np.mean([m["ttr"] for m in metrics_list]),
            "std_ttr": np.std([m["ttr"] for m in metrics_list]),
            "mean_bigram_rep": np.mean([m["bigram_repetition"] for m in metrics_list]),
        }

    def run(
            self,
            calibration_texts: List[str],
            test_prompts: List[str],
            max_new_tokens: int = 224,
            scales: Optional[List[float]] = None,
    ) -> ExperimentResult:
        """
        Run the synthetic hyperfitting experiment.

        Args:
            calibration_texts: Texts used to learn the logit bias.
            test_prompts: Prompts used for generation.
            max_new_tokens: Maximum number of new tokens per prompt.
            scales: Logit-bias scales to test (scale 1.0 reproduces the learned bias);
                defaults to ``DEFAULT_SCALES``. A fresh list is used on every call,
                so the returned metadata never aliases a shared default.
        """
        scales = list(self.DEFAULT_SCALES if scales is None else scales)

        logger.info("=" * 60)
        logger.info("Experiment 3: Synthetic Hyperfitting (Logit Bias Method)")
        logger.info("=" * 60)

        # Learn corrections
        self.learn_rank_corrections(calibration_texts)

        # Test different scales
        results_by_scale = {}

        for scale in scales:
            logger.info(f"Testing scale={scale}")

            scale_metrics = []
            for prompt in tqdm(test_prompts, desc=f"Scale {scale}"):
                tokens = self.generate_with_synthetic_correction(prompt, max_new_tokens, scale)

                input_len = len(self.tokenizer.encode(prompt))
                generated_tokens = tokens[input_len:]

                # Empty generation: substitute a placeholder token (kept from the
                # original code). NOTE(review): the placeholder is scored as if one
                # token (id 0) had been generated; confirm this is intended.
                if not generated_tokens:
                    generated_tokens = [0]

                scale_metrics.append(GenerationMetrics.compute_all_metrics(generated_tokens))

            results_by_scale[scale] = {
                "mean_ttr": np.mean([m["ttr"] for m in scale_metrics]),
                "std_ttr": np.std([m["ttr"] for m in scale_metrics]),
                "mean_bigram_rep": np.mean([m["bigram_repetition"] for m in scale_metrics]),
                "mean_length": np.mean([m["length"] for m in scale_metrics]),
            }

        # Compare with baselines
        logger.info("Comparing with baselines")

        original_metrics = self._greedy_metrics(
            self.original_model, test_prompts, max_new_tokens, "Original model"
        )
        hyperfitted_metrics = self._greedy_metrics(
            self.hyperfitted_model, test_prompts, max_new_tokens, "Hyperfitted model"
        )

        baselines = {
            "original": self._summarize_baseline(original_metrics),
            "hyperfitted": self._summarize_baseline(hyperfitted_metrics),
        }

        return ExperimentResult(
            experiment_name="synthetic_hyperfitting",
            model_name="comparison",
            results={
                "baselines": baselines,
                "synthetic_by_scale": results_by_scale,
                "num_corrections": len(self.logit_bias) if self.logit_bias else 0,
            },
            metadata={
                "num_calibration_texts": len(calibration_texts),
                "num_test_prompts": len(test_prompts),
                "scales_tested": scales,
            },
        )

class Experiment4_RepresentationAnalysis:
    """
    Experiment 4: Representation Analysis
    
    Goal: Find where in the model hyperfitting causes changes
    
    Method:
    1. Compare hidden states layer by layer
    2. Compute cosine similarity and effective dimensionality
    3. Identify which layers change most
    
    Expected Result: Identify critical layers for hyperfitting effect
    """
    
    def __init__(
        self,
        original_model,
        hyperfitted_model,
        tokenizer,
        device: str = "cuda",
    ):
        self.original_model = original_model
        self.hyperfitted_model = hyperfitted_model
        self.tokenizer = tokenizer
        self.device = device
    
    def compute_effective_dimensionality(self, hidden_states: torch.Tensor) -> float:
        """
        Compute participation ratio (effective dimensionality)
        
        PR = (sum of eigenvalues)^2 / (sum of eigenvalues^2)
        """
        h = hidden_states.view(-1, hidden_states.shape[-1]).float()
        h_centered = h - h.mean(dim=0)
        
        # Covariance matrix
        cov = torch.mm(h_centered.T, h_centered) / h.shape[0]
        
        # Eigenvalues
        try:
            eigvals = torch.linalg.eigvalsh(cov).clamp(min=1e-10)
            pr = (eigvals.sum() ** 2) / (eigvals ** 2).sum()
            return pr.item()
        except:
            return 0.0
    
    def compare_layers(
        self,
        input_ids: torch.Tensor,
    ) -> List[Dict]:
        """
        Compare hidden states at each layer. Probably the most important part of the experiment.
        """
        input_ids = input_ids.to(self.device)
        
        self.original_model.eval()
        self.hyperfitted_model.eval()
        
        with torch.no_grad():
            orig_out = self.original_model(input_ids, output_hidden_states=True)
            hyper_out = self.hyperfitted_model(input_ids, output_hidden_states=True)
        
        results = []
        
        for layer_idx, (orig_h, hyper_h) in enumerate(
            zip(orig_out.hidden_states, hyper_out.hidden_states)
        ):
            # Cosine similarity
            orig_flat = orig_h.view(-1, orig_h.shape[-1])
            hyper_flat = hyper_h.view(-1, hyper_h.shape[-1])
            
            cos_sim = F.cosine_similarity(orig_flat, hyper_flat, dim=-1).mean().item()
            
            # L2 distance
            l2_dist = torch.norm(orig_h - hyper_h, dim=-1).mean().item()
            
            # Effective dimensionality
            orig_dim = self.compute_effective_dimensionality(orig_h)
            hyper_dim = self.compute_effective_dimensionality(hyper_h)
            
            # Norm
            orig_norm = orig_h.norm(dim=-1).mean().item()
            hyper_norm = hyper_h.norm(dim=-1).mean().item()
            
            results.append({
                "layer": layer_idx,
                "cosine_similarity": cos_sim,
                "l2_distance": l2_dist,
                "original_effective_dim": orig_dim,
                "hyperfitted_effective_dim": hyper_dim,
                "dim_change": hyper_dim - orig_dim,
                "original_norm": orig_norm,
                "hyperfitted_norm": hyper_norm,
                "norm_change": hyper_norm - orig_norm,
            })
        
        return results
    
    def run(
        self,
        eval_texts: List[str],
        max_seq_length: int = 256,
    ) -> ExperimentResult:
        """
        Run the representation analysis experiment.
        """
        logger.info("=" * 60)
        logger.info("Experiment 4: Representation Analysis")
        logger.info("=" * 60)
        
        all_layer_results = []
        
        for text in tqdm(eval_texts, desc="Analyzing representations"):
            input_ids = self.tokenizer.encode(text, return_tensors="pt")
            if input_ids.shape[1] > max_seq_length:
                input_ids = input_ids[:, :max_seq_length]
            
            layer_results = self.compare_layers(input_ids)
            all_layer_results.append(layer_results)
        
        # Aggregate by layer
        num_layers = len(all_layer_results[0])
        aggregated = []
        
        for layer_idx in range(num_layers):
            layer_data = [r[layer_idx] for r in all_layer_results]
            
            aggregated.append({
                "layer": layer_idx,
                "mean_cosine_sim": np.mean([d["cosine_similarity"] for d in layer_data]),
                "std_cosine_sim": np.std([d["cosine_similarity"] for d in layer_data]),
                "mean_l2_dist": np.mean([d["l2_distance"] for d in layer_data]),
                "mean_dim_change": np.mean([d["dim_change"] for d in layer_data]),
                "mean_norm_change": np.mean([d["norm_change"] for d in layer_data]),
            })
        
        # Find most changed layers
        most_changed = sorted(aggregated, key=lambda x: x["mean_l2_dist"], reverse=True)[:5]
        
        logger.info("\n" + "=" * 60)
        logger.info("RESULTS SUMMARY")
        logger.info("=" * 60)
        logger.info("\nLayer-wise Analysis (first 5 and last 5 layers):")
        
        for layer_info in aggregated[:5]:
            logger.info(f"  Layer {layer_info['layer']}: cos_sim={layer_info['mean_cosine_sim']:.4f}, L2={layer_info['mean_l2_dist']:.4f}")
        logger.info("  ...")
        for layer_info in aggregated[-5:]:
            logger.info(f"  Layer {layer_info['layer']}: cos_sim={layer_info['mean_cosine_sim']:.4f}, L2={layer_info['mean_l2_dist']:.4f}")
        
        logger.info("\nMost Changed Layers (by L2 distance):")
        for layer_info in most_changed:
            logger.info(f"  Layer {layer_info['layer']}: L2={layer_info['mean_l2_dist']:.4f}, dim_change={layer_info['mean_dim_change']:.2f}")
        
        return ExperimentResult(
            experiment_name="representation_analysis",
            model_name="comparison",
            results={
                "layer_wise_analysis": aggregated,
                "most_changed_layers": most_changed,
            },
            metadata={
                "num_eval_texts": len(eval_texts),
                "num_layers": num_layers,
                "max_seq_length": max_seq_length,
            },
        )


def save_results(result: ExperimentResult, output_dir: str):
    """
    Write an ExperimentResult to ``<output_dir>/<experiment_name>_results.json``.

    Both ``results`` and ``metadata`` are converted to plain JSON types before
    writing. The conversion covers NumPy arrays and scalars (including
    ``np.bool_``), torch tensors, tuples, sets, and NumPy-typed dict keys.
    Otherwise ``json.dump`` raises ``TypeError`` on them.

    Args:
        result: The experiment result to persist.
        output_dir: Directory to write into; created if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)
    
    output_path = os.path.join(output_dir, f"{result.experiment_name}_results.json")
    
    # Convert numpy arrays, tensors and other non-serializable types
    def convert_to_serializable(obj):
        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu().tolist()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # np.floating / np.integer / np.bool_ -> native Python scalar
            return obj.item()
        if isinstance(obj, dict):
            # json only accepts str/int/float/bool/None keys; np.int64 is none of them.
            return {
                (k.item() if isinstance(k, np.generic) else k): convert_to_serializable(v)
                for k, v in obj.items()
            }
        if isinstance(obj, (list, tuple, set)):
            return [convert_to_serializable(item) for item in obj]
        return obj
    
    output = {
        "experiment_name": result.experiment_name,
        "model_name": result.model_name,
        # metadata may hold numpy values too (e.g. a list/array of scales).
        "metadata": convert_to_serializable(result.metadata),
        "results": convert_to_serializable(result.results),
    }
    
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)
    
    logger.info(f"Results saved to {output_path}")
