import copy
import json
import logging
import os
from dataclasses import dataclass
from typing import Dict, List, Optional

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

from metrics import (
    DistributionAnalyzer,
    RankAnalyzer,
    GenerationMetrics,
)

# Configure the root logger when this module is imported so that the INFO-level
# progress messages emitted by the experiments below are shown.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by every experiment class and by save_results.
logger = logging.getLogger(__name__)


def _build_generation_config(model, max_length: int):
    generation_config = copy.deepcopy(model.generation_config)
    generation_config.max_new_tokens = None
    generation_config.max_length = max_length
    return generation_config


@dataclass
class ExperimentResult:
    """
    Container for the outcome of one experiment run.

    Instances are returned by each experiment's ``run`` method and can be
    written to disk as JSON with ``save_results``.
    """
    # Short identifier of the experiment; save_results uses it as the
    # prefix of the output file name.
    experiment_name: str
    # Label of the model(s) the results refer to (in this file: a model class
    # name or the literal "comparison").
    model_name: str
    # Experiment-specific measurements and aggregates.
    results: Dict
    # Run parameters recorded alongside the results (text counts, limits, ...).
    metadata: Dict


class Experiment1_TemperatureMatching:
    """
    Experiment 1: Temperature Matching

    Goal: Prove that hyperfitting is NOT equivalent to temperature scaling

    Method:
    1. Measure entropy of hyperfitted model's predictions
    2. Find temperature T such that original model with T has same entropy
    3. Compare generation quality: if hyperfitting = temperature, quality should match

    Expected Result: Even with matched entropy, hyperfitted model produces better generations
    """

    # Texts are cut to this many tokens before any entropy measurement.
    MAX_ANALYSIS_TOKENS = 256
    # Number of leading eval texts used to calibrate the temperature.
    NUM_CALIBRATION_TEXTS = 10
    # Number of leading eval texts used to verify the entropy match.
    NUM_VERIFICATION_TEXTS = 20

    def __init__(
        self,
        original_model,
        hyperfitted_model,
        tokenizer,
        device: str = "cuda",
    ):
        """
        Store both models and build one DistributionAnalyzer per model.

        Args:
            original_model: The base model.
            hyperfitted_model: The hyperfitted counterpart of the base model.
            tokenizer: Tokenizer shared by both models.
            device: Device on which prompts are placed for generation.
        """
        self.original_model = original_model
        self.hyperfitted_model = hyperfitted_model
        self.tokenizer = tokenizer
        self.device = device

        self.orig_analyzer = DistributionAnalyzer(original_model, tokenizer, device)
        self.hyper_analyzer = DistributionAnalyzer(hyperfitted_model, tokenizer, device)

    def _encode_truncated(self, text: str) -> torch.Tensor:
        """Tokenize ``text`` and keep at most MAX_ANALYSIS_TOKENS leading tokens."""
        input_ids = self.tokenizer.encode(text, return_tensors="pt")
        return input_ids[:, :self.MAX_ANALYSIS_TOKENS]

    def measure_hyperfitted_entropy(self, eval_texts: List[str]) -> float:
        """
        Measure average entropy of hyperfitted model.

        Args:
            eval_texts: Texts to analyse; each is truncated to 256 tokens.

        Returns:
            Mean of the per-text ``"entropy"`` value reported by the analyzer.

        Raises:
            ValueError: If ``eval_texts`` is empty (an average is undefined).
        """
        if not eval_texts:
            raise ValueError("eval_texts must contain at least one text")

        total_entropy = 0.0
        for text in tqdm(eval_texts, desc="Measuring hyperfitted entropy"):
            result = self.hyper_analyzer.analyze_distribution(self._encode_truncated(text))
            total_entropy += result["entropy"]

        return total_entropy / len(eval_texts)

    def find_matching_temperature(
        self,
        eval_texts: List[str],
        target_entropy: float,
    ) -> float:
        """
        Find temperature that produces similar entropy on original model.

        Only the first 10 texts are used for calibration; the per-text
        temperatures found by the analyzer are averaged.

        Args:
            eval_texts: Candidate calibration texts.
            target_entropy: Entropy the original model should reach.

        Returns:
            Mean matching temperature as a plain Python float.

        Raises:
            ValueError: If ``eval_texts`` is empty.
        """
        calibration_texts = eval_texts[:self.NUM_CALIBRATION_TEXTS]
        if not calibration_texts:
            raise ValueError("eval_texts must contain at least one text")

        temperatures = [
            self.orig_analyzer.find_matching_temperature(
                self._encode_truncated(text), target_entropy
            )
            for text in tqdm(calibration_texts, desc="Finding matching temperature")
        ]

        # np.mean yields a NumPy scalar; convert so the annotated type holds.
        return float(np.mean(temperatures))

    def _greedy_metrics(
        self,
        model,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        max_new_tokens: int,
    ) -> Dict:
        """Greedy-decode a continuation with ``model`` and score the new tokens."""
        prompt_len = input_ids.shape[1]
        with torch.no_grad():
            output = model.generate(
                input_ids,
                attention_mask=attention_mask,
                generation_config=_build_generation_config(model, prompt_len + max_new_tokens),
                do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Score only the continuation, not the prompt.
        return GenerationMetrics.compute_all_metrics(output[0].tolist()[prompt_len:])

    def _sample_with_temperature(
        self,
        input_ids: torch.Tensor,
        temperature: float,
        max_new_tokens: int,
    ) -> List[int]:
        """
        Sample a continuation from the original model at ``temperature``.

        Returns only the newly generated token ids (an end-of-sequence token,
        if produced, is included and ends generation).
        """
        current_input = input_ids.clone()
        new_tokens: List[int] = []
        eos_token_id = self.tokenizer.eos_token_id

        with torch.no_grad():
            for _ in range(max_new_tokens):
                logits = self.original_model(current_input).logits[:, -1, :]
                # Dividing by a positive temperature never changes the argmax,
                # so the scaled distribution has to be sampled for the
                # temperature to have any effect on the generated text.
                probs = F.softmax(logits.float() / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                token_id = next_token.item()
                new_tokens.append(token_id)
                current_input = torch.cat([current_input, next_token], dim=-1)

                if token_id == eos_token_id:
                    break

        return new_tokens

    def generate_and_compare(
        self,
        prompts: List[str],
        matched_temperature: float,
        max_new_tokens: int = 224,
    ) -> Dict:
        """
        Generate text with both models and compare quality

        Compares:
        - Original model (greedy, temp=1.0)
        - Original model (sampled at the matched temperature)
        - Hyperfitted model (greedy)

        The matched-temperature branch is stochastic; seed torch beforehand
        if reproducible numbers are needed.

        Args:
            prompts: Prompts to continue.
            matched_temperature: Temperature applied to the original model's
                logits before sampling; must be positive.
            max_new_tokens: Maximum number of tokens generated per prompt.

        Returns:
            Dict with ``"individual_results"`` (per-prompt metrics for each
            variant), ``"aggregated"`` (means/std per variant) and
            ``"matched_temperature"``.

        Raises:
            ValueError: If ``matched_temperature`` is not a positive number.
        """
        # ``not x > 0`` also rejects NaN, which would poison the softmax.
        if not matched_temperature > 0:
            raise ValueError(
                f"matched_temperature must be positive, got {matched_temperature}"
            )

        results = {
            "original_greedy": [],
            "original_matched_temp": [],
            "hyperfitted_greedy": [],
        }

        self.original_model.eval()
        self.hyperfitted_model.eval()

        for prompt in tqdm(prompts, desc="Generating and comparing"):
            encoding = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
            input_ids = encoding["input_ids"].to(self.device)
            attention_mask = encoding["attention_mask"].to(self.device)

            # Original model with greedy decoding
            results["original_greedy"].append(
                self._greedy_metrics(
                    self.original_model, input_ids, attention_mask, max_new_tokens
                )
            )

            # Original model sampled at the matched temperature
            sampled_tokens = self._sample_with_temperature(
                input_ids, matched_temperature, max_new_tokens
            )
            results["original_matched_temp"].append(
                GenerationMetrics.compute_all_metrics(sampled_tokens)
            )

            # Hyperfitted model with greedy decoding
            results["hyperfitted_greedy"].append(
                self._greedy_metrics(
                    self.hyperfitted_model, input_ids, attention_mask, max_new_tokens
                )
            )

        aggregated = {}
        for model_type, metrics_list in results.items():
            ttr_values = [m["ttr"] for m in metrics_list]
            aggregated[model_type] = {
                "mean_ttr": np.mean(ttr_values),
                "std_ttr": np.std(ttr_values),
                "mean_bigram_rep": np.mean([m["bigram_repetition"] for m in metrics_list]),
                "mean_trigram_rep": np.mean([m["trigram_repetition"] for m in metrics_list]),
                "mean_length": np.mean([m["length"] for m in metrics_list]),
            }

        return {
            "individual_results": results,
            "aggregated": aggregated,
            "matched_temperature": matched_temperature,
        }

    def run(
        self,
        eval_texts: List[str],
        prompts: List[str],
        max_new_tokens: int = 224,
    ) -> "ExperimentResult":
        """
        Run the full temperature matching experiment.

        Args:
            eval_texts: Texts used to measure and match entropy.
            prompts: Prompts used for the generation comparison.
            max_new_tokens: Maximum number of tokens generated per prompt.

        Returns:
            ExperimentResult holding the entropies, the matched temperature
            and the generation comparison.

        Raises:
            ValueError: If ``eval_texts`` is empty.
        """
        logger.info("=" * 60)
        logger.info("Experiment 1: Temperature Matching")
        logger.info("=" * 60)

        # Compute hyperfitted entropy
        logger.info("Measuring hyperfitted model entropy")
        hyper_entropy = self.measure_hyperfitted_entropy(eval_texts)
        logger.info(f"Hyperfitted model average entropy: {hyper_entropy:.4f}")

        # Find matching temperature
        logger.info("Finding matching temperature for original model")
        matched_temp = self.find_matching_temperature(eval_texts, hyper_entropy)
        logger.info(f"Matched temperature: {matched_temp:.4f}")

        # Verify entropy match
        logger.info("Verifying entropy match")
        orig_entropy_default = 0.0
        orig_entropy_matched = 0.0

        verification_texts = eval_texts[:self.NUM_VERIFICATION_TEXTS]
        for text in verification_texts:
            input_ids = self._encode_truncated(text)

            default_result = self.orig_analyzer.analyze_distribution(input_ids, temperature=1.0)
            matched_result = self.orig_analyzer.analyze_distribution(input_ids, temperature=matched_temp)

            orig_entropy_default += default_result["entropy"]
            orig_entropy_matched += matched_result["entropy"]

        # Average over the texts actually analysed; fewer than 20 may exist.
        # Non-empty here because measure_hyperfitted_entropy rejected empty input.
        num_verified = len(verification_texts)
        orig_entropy_default /= num_verified
        orig_entropy_matched /= num_verified

        logger.info(f"Original model entropy (temp=1.0): {orig_entropy_default:.4f}")
        logger.info(f"Original model entropy (temp={matched_temp:.3f}): {orig_entropy_matched:.4f}")
        logger.info(f"Hyperfitted model entropy: {hyper_entropy:.4f}")

        # Step 4: Generate and compare
        logger.info("Step 4: Generating and comparing outputs...")
        comparison = self.generate_and_compare(prompts, matched_temp, max_new_tokens)

        # Summary
        logger.info("\n" + "=" * 60)
        logger.info("RESULTS SUMMARY")
        logger.info("=" * 60)
        for model_type, metrics in comparison["aggregated"].items():
            logger.info(f"\n{model_type}:")
            logger.info(f"  Mean TTR: {metrics['mean_ttr']:.4f} (±{metrics['std_ttr']:.4f})")
            logger.info(f"  Mean Bigram Repetition: {metrics['mean_bigram_rep']:.4f}")
            logger.info(f"  Mean Trigram Repetition: {metrics['mean_trigram_rep']:.4f}")

        return ExperimentResult(
            experiment_name="temperature_matching",
            model_name=str(type(self.original_model).__name__),
            results={
                "hyperfitted_entropy": hyper_entropy,
                "matched_temperature": matched_temp,
                "original_entropy_default": orig_entropy_default,
                "original_entropy_matched": orig_entropy_matched,
                "generation_comparison": comparison,
            },
            metadata={
                "num_eval_texts": len(eval_texts),
                "num_prompts": len(prompts),
                "max_new_tokens": max_new_tokens,
            },
        )


class Experiment2_RankAnalysis:
    """
    Experiment 2: Rank Analysis

    Goal: demonstrate that hyperfitting alters WHICH tokens are ranked at the
    top, rather than merely reshaping the probability distribution.

    Method:
    1. At every position, compare the top-k rankings of the two models
    2. Compute the rank correlation
    3. Collect tokens whose rank moved up substantially

    Expected Result: large rank changes, not just probability reweighting
    """

    def __init__(
        self,
        original_model,
        hyperfitted_model,
        tokenizer,
        device: str = "cuda",
    ):
        """Build the RankAnalyzer that compares the two models."""
        self.rank_analyzer = RankAnalyzer(
            original_model,
            hyperfitted_model,
            tokenizer,
            device,
        )
        self.tokenizer = tokenizer
        self.device = device

    def run(
        self,
        eval_texts: List[str],
        max_seq_length: int = 256,
    ) -> ExperimentResult:
        """
        Run the rank analysis experiment.

        Each text is tokenized, truncated to ``max_seq_length`` tokens and
        passed to the rank analyzer; the per-text statistics are then averaged
        and the 100 largest rank promotions are kept.
        """
        banner = "=" * 60
        logger.info(banner)
        logger.info("Experiment 2: Rank Analysis")
        logger.info(banner)

        top1_stats = []
        correlation_stats = []
        promoted_pool = []

        for text in tqdm(eval_texts, desc="Analyzing ranks"):
            encoded = self.tokenizer.encode(text, return_tensors="pt")
            input_ids = (
                encoded
                if encoded.shape[1] <= max_seq_length
                else encoded[:, :max_seq_length]
            )

            # Top-1 agreement between the two models
            top1_stats.append(self.rank_analyzer.compare_top1_predictions(input_ids))

            # Rank correlation over the top-100 tokens
            correlation_stats.append(
                self.rank_analyzer.compute_rank_correlation(input_ids, top_k=100)
            )

            # Promoted tokens; only the first 20 per text are retained
            promoted = self.rank_analyzer.find_promoted_tokens(input_ids, threshold=50)
            promoted_pool.extend(promoted[:20])

        top1_keys = (
            "top1_agreement",
            "hyper_top1_in_orig_top5",
            "hyper_top1_in_orig_top10",
            "hyper_top1_in_orig_top50",
            "hyper_top1_in_orig_top100",
        )
        aggregated_top1 = {
            key: np.mean([entry[key] for entry in top1_stats]) for key in top1_keys
        }

        per_text_corr = [entry["mean_rank_correlation"] for entry in correlation_stats]
        aggregated_corr = {
            "mean_rank_correlation": np.mean(per_text_corr),
            "std_rank_correlation": np.std(per_text_corr),
        }

        # Largest rank improvements first; the pool itself is ordered in place
        # exactly as before, then capped at 100 entries.
        promoted_pool.sort(key=lambda entry: entry["rank_improvement"], reverse=True)
        top_promoted = promoted_pool[:100]

        # Summary
        logger.info("\n" + banner)
        logger.info("RESULTS SUMMARY")
        logger.info(banner)
        logger.info(f"\nTop-1 Agreement Rate: {aggregated_top1['top1_agreement']:.4f}")
        for k in (5, 10, 50, 100):
            share = aggregated_top1[f"hyper_top1_in_orig_top{k}"]
            logger.info(f"Hyperfitted top-1 in Original top-{k}: {share:.4f}")
        logger.info(f"\nMean Rank Correlation (top-100): {aggregated_corr['mean_rank_correlation']:.4f}")

        logger.info("\nTop 10 Most Promoted Tokens:")
        for position, info in enumerate(top_promoted[:10], start=1):
            logger.info(f"  {position}. '{info['token_str']}': rank {info['original_rank']} -> {info['hyperfitted_rank']} (+{info['rank_improvement']})")

        results = {
            "top1_comparison": aggregated_top1,
            "rank_correlation": aggregated_corr,
            "promoted_tokens": top_promoted,
        }
        metadata = {
            "num_eval_texts": len(eval_texts),
            "max_seq_length": max_seq_length,
        }
        return ExperimentResult(
            experiment_name="rank_analysis",
            model_name="comparison",
            results=results,
            metadata=metadata,
        )


class Experiment3_SyntheticHyperfitting:
    """
    Experiment 3: Synthetic Hyperfitting

    Goal: Can we achieve hyperfitting effects WITHOUT training?

    Method:
    1. Learn which tokens are systematically promoted by hyperfitting
    2. Create a simple logit adjustment function
    3. Apply to original model and test if generation improves

    Expected Result: If it works -> hyperfitting is about rank modification
                     If it doesn't -> something deeper changes (representations)
    """

    # Correction scales tried by run() when the caller supplies none.
    DEFAULT_SCALES = (0.01, 0.05, 0.1, 0.2, 0.5)
    # Number of top-ranked hyperfitted tokens inspected at every position.
    TOP_K = 100
    # A token needs at least this many observations to receive a correction.
    MIN_OBSERVATIONS = 5

    def __init__(
        self,
        original_model,
        hyperfitted_model,
        tokenizer,
        device: str = "cuda",
    ):
        """
        Store the models and tokenizer; no corrections are learned yet.

        Args:
            original_model: The base model.
            hyperfitted_model: The hyperfitted counterpart of the base model.
            tokenizer: Tokenizer shared by both models.
            device: Device on which inputs are placed.
        """
        self.original_model = original_model
        self.hyperfitted_model = hyperfitted_model
        self.tokenizer = tokenizer
        self.device = device

        # Populated by learn_rank_corrections: token_id -> mean rank improvement.
        self.rank_correction = None

    def learn_rank_corrections(
        self,
        calibration_texts: List[str],
        max_seq_length: int = 256,
    ) -> Dict[int, float]:
        """
        Learn token-level rank corrections from comparing models

        For every position of every calibration text, the top-100 tokens of
        the hyperfitted model are looked up in the original model's ranking;
        ``original_rank - hyperfitted_rank`` is recorded per token. Tokens
        observed fewer than 5 times are dropped.

        Args:
            calibration_texts: Texts on which both models are compared.
            max_seq_length: Texts are truncated to this many tokens.

        Returns:
            Dict mapping token_id -> average rank improvement (also stored in
            ``self.rank_correction``).
        """
        rank_improvements: Dict[int, List[int]] = {}  # token_id -> list of improvements

        self.original_model.eval()
        self.hyperfitted_model.eval()

        for text in tqdm(calibration_texts, desc="Learning corrections"):
            input_ids = self.tokenizer.encode(text, return_tensors="pt").to(self.device)
            if input_ids.shape[1] > max_seq_length:
                input_ids = input_ids[:, :max_seq_length]

            with torch.no_grad():
                # Index 0 selects the only sequence in the batch.
                orig_logits = self.original_model(input_ids).logits[0]
                hyper_logits = self.hyperfitted_model(input_ids).logits[0]

            orig_vocab = orig_logits.shape[-1]
            top_k = min(self.TOP_K, orig_vocab, hyper_logits.shape[-1])

            # Top-k token ids of the hyperfitted model at every position.
            hyper_top = torch.argsort(hyper_logits, dim=-1, descending=True)[:, :top_k]

            # Rank of every token in the original model. Argsorting a
            # permutation yields its inverse, i.e. token id -> rank, which
            # replaces a full-vocabulary scan per looked-up token.
            orig_order = torch.argsort(orig_logits, dim=-1, descending=True)
            orig_rank = torch.argsort(orig_order, dim=-1)

            # Token ids unknown to the original model are skipped below; they
            # are clamped here only so that gather stays in range.
            known = hyper_top < orig_vocab
            orig_rank_of_top = orig_rank.gather(-1, hyper_top.clamp(max=orig_vocab - 1))
            new_rank = torch.arange(top_k, device=hyper_top.device)
            improvement = orig_rank_of_top - new_rank

            # One host transfer per text instead of one per token.
            for token_id, gain, is_known in zip(
                hyper_top.flatten().tolist(),
                improvement.flatten().tolist(),
                known.flatten().tolist(),
            ):
                if is_known:
                    rank_improvements.setdefault(token_id, []).append(gain)

        # Average the improvements
        self.rank_correction = {
            token_id: float(np.mean(gains))
            for token_id, gains in rank_improvements.items()
            if len(gains) >= self.MIN_OBSERVATIONS
        }

        logger.info(f"Learned corrections for {len(self.rank_correction)} tokens")

        return self.rank_correction

    def apply_synthetic_correction(
        self,
        logits: torch.Tensor,
        scale: float = 0.1,
    ) -> torch.Tensor:
        """
        Apply learned rank corrections to logits.

        Args:
            logits: Logits whose last dimension indexes the vocabulary.
            scale: Factor applied to each token's mean rank improvement.

        Returns:
            New tensor equal to ``logits`` plus ``improvement * scale`` for
            every corrected token id that fits in the vocabulary dimension.

        Raises:
            ValueError: If learn_rank_corrections has not been called.
        """
        if self.rank_correction is None:
            raise ValueError("Must call learn_rank_corrections first")

        vocab_size = logits.shape[-1]
        token_ids = []
        offsets = []
        for token_id, improvement in self.rank_correction.items():
            if token_id < vocab_size:
                token_ids.append(token_id)
                offsets.append(improvement * scale)

        correction = torch.zeros_like(logits)
        if token_ids:
            index = torch.tensor(token_ids, dtype=torch.long, device=logits.device)
            values = torch.tensor(offsets, dtype=logits.dtype, device=logits.device)
            # One indexed assignment instead of one tensor write per token.
            correction[..., index] = values

        return logits + correction

    def generate_with_synthetic_correction(
        self,
        prompt: str,
        max_new_tokens: int = 224,
        scale: float = 0.1,
    ) -> List[int]:
        """
        Generate text using synthetic correction.

        Greedy decoding on the original model's logits after the learned
        correction has been added; stops early on the end-of-sequence token.

        Args:
            prompt: Text to continue.
            max_new_tokens: Maximum number of tokens to generate.
            scale: Correction scale passed to apply_synthetic_correction.

        Returns:
            Token ids of the prompt followed by the generated continuation.

        Raises:
            ValueError: If learn_rank_corrections has not been called.
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        generated = input_ids.clone()

        self.original_model.eval()

        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Keep the sequence dimension so the correction sees 3-D logits.
                logits = self.original_model(generated).logits[:, -1:, :]
                corrected_logits = self.apply_synthetic_correction(logits, scale)

                # Greedy selection
                next_token = corrected_logits[:, -1, :].argmax(dim=-1, keepdim=True)
                generated = torch.cat([generated, next_token], dim=-1)

                if next_token.item() == self.tokenizer.eos_token_id:
                    break

        return generated[0].tolist()

    def _greedy_baseline_metrics(
        self,
        model,
        prompts: List[str],
        max_new_tokens: int,
        desc: str,
    ) -> List[Dict]:
        """Greedy-decode every prompt with ``model`` and score each continuation."""
        collected = []
        for prompt in tqdm(prompts, desc=desc):
            encoding = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
            input_ids = encoding["input_ids"].to(self.device)
            attention_mask = encoding["attention_mask"].to(self.device)
            prompt_len = input_ids.shape[1]
            with torch.no_grad():
                out = model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    generation_config=_build_generation_config(model, prompt_len + max_new_tokens),
                    do_sample=False,
                    pad_token_id=self.tokenizer.eos_token_id,
                )
            # Score only the continuation, not the prompt.
            collected.append(
                GenerationMetrics.compute_all_metrics(out[0].tolist()[prompt_len:])
            )
        return collected

    @staticmethod
    def _summarize(metrics_list: List[Dict], include_length: bool = False) -> Dict:
        """Aggregate per-prompt metrics into mean/std summary statistics."""
        ttr_values = [m["ttr"] for m in metrics_list]
        summary = {
            "mean_ttr": np.mean(ttr_values),
            "std_ttr": np.std(ttr_values),
            "mean_bigram_rep": np.mean([m["bigram_repetition"] for m in metrics_list]),
        }
        if include_length:
            summary["mean_length"] = np.mean([m["length"] for m in metrics_list])
        return summary

    def run(
        self,
        calibration_texts: List[str],
        test_prompts: List[str],
        max_new_tokens: int = 224,
        scales: Optional[List[float]] = None,
    ) -> "ExperimentResult":
        """
        Run the synthetic hyperfitting experiment.

        Args:
            calibration_texts: Texts used to learn the rank corrections.
            test_prompts: Prompts used to evaluate generation.
            max_new_tokens: Maximum number of tokens generated per prompt.
            scales: Correction scales to evaluate; defaults to
                0.01, 0.05, 0.1, 0.2 and 0.5.

        Returns:
            ExperimentResult with both baselines, the per-scale results and
            the number of learned corrections.
        """
        # Copy so that neither the class default nor the caller's list is
        # aliased by the returned metadata.
        scales = list(self.DEFAULT_SCALES) if scales is None else list(scales)

        logger.info("=" * 60)
        logger.info("Experiment 3: Synthetic Hyperfitting")
        logger.info("=" * 60)

        # Learn corrections
        self.learn_rank_corrections(calibration_texts)

        # Test different scales
        results_by_scale = {}

        for scale in scales:
            logger.info(f"Testing scale={scale}")

            scale_metrics = []

            for prompt in tqdm(test_prompts, desc=f"Scale {scale}"):
                # Generate with synthetic correction
                tokens = self.generate_with_synthetic_correction(
                    prompt, max_new_tokens, scale
                )

                # Drop the prompt so only the continuation is scored.
                input_len = len(self.tokenizer.encode(prompt))
                scale_metrics.append(
                    GenerationMetrics.compute_all_metrics(tokens[input_len:])
                )

            results_by_scale[scale] = self._summarize(scale_metrics, include_length=True)

        # Compare with baselines
        logger.info("Comparing with baselines")

        original_metrics = self._greedy_baseline_metrics(
            self.original_model, test_prompts, max_new_tokens, "Original model"
        )
        hyperfitted_metrics = self._greedy_baseline_metrics(
            self.hyperfitted_model, test_prompts, max_new_tokens, "Hyperfitted model"
        )

        baselines = {
            "original": self._summarize(original_metrics),
            "hyperfitted": self._summarize(hyperfitted_metrics),
        }

        logger.info("\n" + "=" * 60)
        logger.info("RESULTS SUMMARY")
        logger.info("=" * 60)
        logger.info(f"\nOriginal Model: TTR={baselines['original']['mean_ttr']:.4f}")
        logger.info(f"Hyperfitted Model: TTR={baselines['hyperfitted']['mean_ttr']:.4f}")
        logger.info("\nSynthetic Corrections:")
        for scale, metrics in results_by_scale.items():
            logger.info(f"  Scale {scale}: TTR={metrics['mean_ttr']:.4f}")

        return ExperimentResult(
            experiment_name="synthetic_hyperfitting",
            model_name="comparison",
            results={
                "baselines": baselines,
                "synthetic_by_scale": results_by_scale,
                "num_corrections": len(self.rank_correction),
            },
            metadata={
                "num_calibration_texts": len(calibration_texts),
                "num_test_prompts": len(test_prompts),
                "scales_tested": scales,
            },
        )


class Experiment4_RepresentationAnalysis:
    """
    Experiment 4: Representation Analysis

    Goal: Find where in the model hyperfitting causes changes

    Method:
    1. Compare hidden states layer by layer
    2. Compute cosine similarity and effective dimensionality
    3. Identify which layers change most

    Expected Result: Identify critical layers for hyperfitting effect
    """

    def __init__(
        self,
        original_model,
        hyperfitted_model,
        tokenizer,
        device: str = "cuda",
    ):
        """
        Store the two models to compare, their tokenizer and the device.

        Args:
            original_model: The base model.
            hyperfitted_model: The hyperfitted counterpart of the base model.
            tokenizer: Tokenizer shared by both models.
            device: Device on which inputs are placed.
        """
        self.original_model = original_model
        self.hyperfitted_model = hyperfitted_model
        self.tokenizer = tokenizer
        self.device = device

    def compute_effective_dimensionality(self, hidden_states: torch.Tensor) -> float:
        """
        Compute participation ratio (effective dimensionality)

        PR = (sum of eigenvalues)^2 / (sum of eigenvalues^2)

        The eigenvalues are those of the covariance of the hidden vectors,
        where every leading dimension is flattened into one sample axis.

        Args:
            hidden_states: Tensor whose last dimension is the feature axis.

        Returns:
            The participation ratio, or 0.0 if the eigendecomposition fails.
        """
        # reshape (unlike view) also accepts non-contiguous tensors.
        h = hidden_states.reshape(-1, hidden_states.shape[-1]).float()
        h_centered = h - h.mean(dim=0)

        # Covariance matrix
        cov = torch.mm(h_centered.T, h_centered) / h.shape[0]

        try:
            # Clamp so round-off cannot produce negative eigenvalues.
            eigvals = torch.linalg.eigvalsh(cov).clamp(min=1e-10)
        except RuntimeError as err:
            # torch.linalg failures are RuntimeError subclasses. Stay
            # best-effort, but do not hide unrelated exceptions.
            logger.warning(
                "Eigendecomposition failed (%s); reporting effective dimensionality 0.0",
                err,
            )
            return 0.0

        pr = (eigvals.sum() ** 2) / (eigvals ** 2).sum()
        return pr.item()

    def compare_layers(
        self,
        input_ids: torch.Tensor,
    ) -> List[Dict]:
        """
        Compare hidden states at each layer.

        Both models are run on ``input_ids`` with ``output_hidden_states=True``
        and every pair of hidden states is compared.

        Args:
            input_ids: Token ids to feed to both models.

        Returns:
            One dict per hidden state, in model order, with cosine similarity,
            L2 distance, effective dimensionality and mean vector norm for
            both models, plus the hyperfitted-minus-original differences.
        """
        input_ids = input_ids.to(self.device)

        self.original_model.eval()
        self.hyperfitted_model.eval()

        with torch.no_grad():
            orig_out = self.original_model(input_ids, output_hidden_states=True)
            hyper_out = self.hyperfitted_model(input_ids, output_hidden_states=True)

        results = []

        for layer_idx, (orig_h, hyper_h) in enumerate(
            zip(orig_out.hidden_states, hyper_out.hidden_states)
        ):
            # Cosine similarity between matching hidden vectors
            orig_flat = orig_h.reshape(-1, orig_h.shape[-1])
            hyper_flat = hyper_h.reshape(-1, hyper_h.shape[-1])
            cos_sim = F.cosine_similarity(orig_flat, hyper_flat, dim=-1).mean().item()

            # L2 distance
            l2_dist = torch.norm(orig_h - hyper_h, dim=-1).mean().item()

            # Effective dimensionality
            orig_dim = self.compute_effective_dimensionality(orig_h)
            hyper_dim = self.compute_effective_dimensionality(hyper_h)

            # Norm
            orig_norm = orig_h.norm(dim=-1).mean().item()
            hyper_norm = hyper_h.norm(dim=-1).mean().item()

            results.append({
                "layer": layer_idx,
                "cosine_similarity": cos_sim,
                "l2_distance": l2_dist,
                "original_effective_dim": orig_dim,
                "hyperfitted_effective_dim": hyper_dim,
                "dim_change": hyper_dim - orig_dim,
                "original_norm": orig_norm,
                "hyperfitted_norm": hyper_norm,
                "norm_change": hyper_norm - orig_norm,
            })

        return results

    def run(
        self,
        eval_texts: List[str],
        max_seq_length: int = 256,
    ) -> "ExperimentResult":
        """
        Run the representation analysis experiment.

        Args:
            eval_texts: Texts on which the two models are compared.
            max_seq_length: Texts are truncated to this many tokens.

        Returns:
            ExperimentResult with per-layer averages and the five layers with
            the largest mean L2 distance.

        Raises:
            ValueError: If ``eval_texts`` is empty.
        """
        if not eval_texts:
            raise ValueError("eval_texts must contain at least one text")

        logger.info("=" * 60)
        logger.info("Experiment 4: Representation Analysis")
        logger.info("=" * 60)

        all_layer_results = []

        for text in tqdm(eval_texts, desc="Analyzing representations"):
            input_ids = self.tokenizer.encode(text, return_tensors="pt")
            if input_ids.shape[1] > max_seq_length:
                input_ids = input_ids[:, :max_seq_length]

            all_layer_results.append(self.compare_layers(input_ids))

        # Aggregate by layer
        num_layers = len(all_layer_results[0])
        aggregated = []

        for layer_idx in range(num_layers):
            layer_data = [r[layer_idx] for r in all_layer_results]
            cosine_values = [d["cosine_similarity"] for d in layer_data]

            aggregated.append({
                "layer": layer_idx,
                "mean_cosine_sim": np.mean(cosine_values),
                "std_cosine_sim": np.std(cosine_values),
                "mean_l2_dist": np.mean([d["l2_distance"] for d in layer_data]),
                "mean_dim_change": np.mean([d["dim_change"] for d in layer_data]),
                "mean_norm_change": np.mean([d["norm_change"] for d in layer_data]),
            })

        # Find most changed layers
        most_changed = sorted(aggregated, key=lambda x: x["mean_l2_dist"], reverse=True)[:5]

        logger.info("\n" + "=" * 60)
        logger.info("RESULTS SUMMARY")
        logger.info("=" * 60)
        logger.info("\nLayer-wise Analysis (first 5 and last 5 layers):")

        for layer_info in aggregated[:5]:
            logger.info(f"  Layer {layer_info['layer']}: cos_sim={layer_info['mean_cosine_sim']:.4f}, L2={layer_info['mean_l2_dist']:.4f}")
        logger.info("  ...")
        for layer_info in aggregated[-5:]:
            logger.info(f"  Layer {layer_info['layer']}: cos_sim={layer_info['mean_cosine_sim']:.4f}, L2={layer_info['mean_l2_dist']:.4f}")

        logger.info("\nMost Changed Layers (by L2 distance):")
        for layer_info in most_changed:
            logger.info(f"  Layer {layer_info['layer']}: L2={layer_info['mean_l2_dist']:.4f}, dim_change={layer_info['mean_dim_change']:.2f}")

        return ExperimentResult(
            experiment_name="representation_analysis",
            model_name="comparison",
            results={
                "layer_wise_analysis": aggregated,
                "most_changed_layers": most_changed,
            },
            metadata={
                "num_eval_texts": len(eval_texts),
                "num_layers": num_layers,
                "max_seq_length": max_seq_length,
            },
        )


def save_results(result: ExperimentResult, output_dir: str):
    """
    Write an experiment result to ``<output_dir>/<experiment_name>_results.json``.

    NumPy arrays and scalars anywhere inside ``result.results`` or
    ``result.metadata`` are converted to plain Python values first, and tuples
    are written as JSON lists.

    Args:
        result: The experiment result to persist.
        output_dir: Target directory; created if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)

    output_path = os.path.join(output_dir, f"{result.experiment_name}_results.json")

    # Convert numpy arrays and other non-serializable types
    def convert_to_serializable(obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.bool_):
            # np.bool_ is neither a Python bool nor an np.integer.
            return bool(obj)
        if isinstance(obj, dict):
            return {k: convert_to_serializable(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [convert_to_serializable(item) for item in obj]
        return obj

    # Metadata is converted as well: it can carry NumPy scalars just like
    # the results do.
    output = {
        "experiment_name": result.experiment_name,
        "model_name": result.model_name,
        "metadata": convert_to_serializable(result.metadata),
        "results": convert_to_serializable(result.results),
    }

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)

    logger.info(f"Results saved to {output_path}")
