import os
import json
import torch
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Tuple, Optional
from tqdm import tqdm
from collections import Counter, defaultdict
from dataclasses import dataclass
from scipy.stats import entropy as scipy_entropy
from scipy.spatial.distance import cosine as cosine_distance
import logging
import re

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

# Configure the root logger at import time so INFO-level messages are emitted,
# then create the module-level logger used by the experiment classes below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class TokenCategory:
    """String labels for the token categories used in promotion analysis.

    The members are plain (un-annotated) class attributes, so ``@dataclass``
    registers no fields for them; the class is used as a namespace of string
    constants (e.g. ``TokenCategory.PUNCTUATION``).
    """
    CONTENT_WORD = "content_word"  # Nouns, verbs, adjectives, adverbs
    FUNCTION_WORD = "function_word"  # Articles, prepositions, conjunctions
    PUNCTUATION = "punctuation"
    NUMBER = "number"
    SUBWORD = "subword"  # Tokens starting with ## or similar
    RARE_WORD = "rare_word"  # Low frequency in training data
    FOREIGN = "foreign"  # Non-English tokens
    SPECIAL = "special"  # Special tokens like <eos>, <pad>
    OTHER = "other"


class TokenCategoryAnalyzer:
    """Analyze what types of tokens get promoted during hyperfitting.

    Each token id is mapped to one of the string labels defined on
    ``TokenCategory`` using surface heuristics: regexes for punctuation and
    numbers, the tokenizer's raw token string for prefix markers, and a fixed
    list of English function words.
    """

    # Compiled once: categorize_token runs once per promoted token.
    _PUNCTUATION_RE = re.compile(r'^[^\w\s]+$')
    _NUMBER_RE = re.compile(r'^[\d.,]+$')
    # Raw-token prefixes that trigger the sub-word branch, and the characters
    # stripped from the front of such tokens before the function-word lookup.
    _PREFIX_MARKERS = ('##', 'Ġ', '▁', 'Ċ')
    _PREFIX_CHARS = '#Ġ▁Ċ'

    def __init__(self, tokenizer):
        """Store the tokenizer and precompute lookup tables.

        Args:
            tokenizer: Object providing ``decode``, ``convert_ids_to_tokens``,
                ``all_special_ids`` and ``vocab_size``.
        """
        self.tokenizer = tokenizer

        # Common function words in English
        self.function_words = {
            'the', 'a', 'an', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
            'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could',
            'should', 'may', 'might', 'must', 'shall', 'can', 'need', 'dare',
            'ought', 'used', 'to', 'of', 'in', 'for', 'on', 'with', 'at', 'by',
            'from', 'as', 'into', 'through', 'during', 'before', 'after', 'above',
            'below', 'between', 'under', 'again', 'further', 'then', 'once',
            'and', 'but', 'or', 'nor', 'so', 'yet', 'both', 'either', 'neither',
            'not', 'only', 'own', 'same', 'than', 'too', 'very', 'just', 'also',
            'now', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'each',
            'every', 'few', 'more', 'most', 'other', 'some', 'such', 'no',
            'any', 'this', 'that', 'these', 'those', 'i', 'me', 'my', 'myself',
            'we', 'our', 'ours', 'ourselves', 'you', 'your', 'yours', 'yourself',
            'he', 'him', 'his', 'himself', 'she', 'her', 'hers', 'herself', 'it',
            'its', 'itself', 'they', 'them', 'their', 'theirs', 'themselves',
            'what', 'which', 'who', 'whom', 'whose',
        }

        # Snapshot the special ids once so the per-token check is an O(1) set
        # lookup instead of re-reading the tokenizer attribute for every token.
        self._special_ids = frozenset(tokenizer.all_special_ids)

        # Build token frequency from tokenizer vocab if possible
        self.token_frequencies = self._estimate_token_frequencies()

    def _estimate_token_frequencies(self) -> Dict[int, float]:
        """Return a rough per-id frequency proxy, ``1 / log(id + 2)``.

        The token id itself is used as the rank, so lower ids receive larger
        values. No corpus statistics are involved.
        """
        vocab_size = self.tokenizer.vocab_size
        return {i: 1.0 / np.log(i + 2) for i in range(vocab_size)}

    def categorize_token(self, token_id: int) -> str:
        """Categorize a single token.

        Args:
            token_id: Vocabulary id of the token.

        Returns:
            One of the ``TokenCategory`` string labels. Ids that cannot be
            decoded, and tokens that decode to whitespace only, are reported
            as ``TokenCategory.OTHER``.
        """
        try:
            token_str = self.tokenizer.decode([token_id]).strip().lower()
        except Exception:
            # Best effort: an undecodable id is not worth aborting the analysis.
            return TokenCategory.OTHER

        # Check for special tokens
        if token_id in self._special_ids:
            return TokenCategory.SPECIAL

        # Whitespace-only tokens are empty after strip(). A substring test
        # such as `"" in ".,!?"` is always true, which would label them as
        # punctuation, so they are handled explicitly here.
        if not token_str:
            return TokenCategory.OTHER

        # Check for punctuation (one or more non-word, non-space characters)
        if self._PUNCTUATION_RE.match(token_str):
            return TokenCategory.PUNCTUATION

        # Check for numbers
        if self._NUMBER_RE.match(token_str):
            return TokenCategory.NUMBER

        # Check for prefix-marked tokens (##, Ġ, ▁, Ċ).
        # NOTE(review): in GPT-2-style BPE / SentencePiece vocabularies `Ġ` and
        # `▁` presumably mark a leading space (word start) rather than a
        # continuation, so most whole words end up in SUBWORD here — confirm
        # this is the intended labelling.
        raw_token = self.tokenizer.convert_ids_to_tokens([token_id])[0]
        if raw_token and raw_token.startswith(self._PREFIX_MARKERS):
            clean_token = raw_token.lstrip(self._PREFIX_CHARS).lower()
            if clean_token in self.function_words:
                return TokenCategory.FUNCTION_WORD
            return TokenCategory.SUBWORD

        # Check for non-ASCII (likely foreign)
        if not token_str.isascii():
            return TokenCategory.FOREIGN

        # Check for function words
        if token_str in self.function_words:
            return TokenCategory.FUNCTION_WORD

        # Check for rare words (low frequency proxy).
        # NOTE(review): 1 / log(id + 2) never drops below 0.001 for any
        # realistic vocabulary, so this only fires for ids missing from the
        # table (>= vocab_size). The threshold likely needs revisiting.
        if self.token_frequencies.get(token_id, 0) < 0.001:
            return TokenCategory.RARE_WORD

        # Default to content word
        return TokenCategory.CONTENT_WORD

    def analyze_promoted_tokens(
        self,
        promoted_tokens: List[Dict],
    ) -> Dict:
        """
        Analyze the categories of promoted tokens

        Args:
            promoted_tokens: List of dicts with 'token_id', 'rank_improvement',
                'original_rank' and 'hyperfitted_rank' (optionally 'token_str').

        Returns:
            Dict with the total token count, per-category count/percentage,
            per-category mean/std/max rank improvement, and up to 10 example
            tokens per category.
        """
        category_counts = Counter()
        category_improvements = defaultdict(list)
        category_examples = defaultdict(list)

        for token_info in promoted_tokens:
            token_id = token_info['token_id']
            category = self.categorize_token(token_id)

            category_counts[category] += 1
            category_improvements[category].append(token_info['rank_improvement'])

            examples = category_examples[category]
            if len(examples) < 10:
                examples.append({
                    'token': token_info.get('token_str', self.tokenizer.decode([token_id])),
                    'original_rank': token_info['original_rank'],
                    'hyperfitted_rank': token_info['hyperfitted_rank'],
                    'improvement': token_info['rank_improvement'],
                })

        # Compute statistics
        total_tokens = len(promoted_tokens)
        results = {
            'total_promoted_tokens': total_tokens,
            'category_distribution': {},
            'category_mean_improvement': {},
            'category_examples': dict(category_examples),
        }

        for category, count in category_counts.items():
            results['category_distribution'][category] = {
                'count': count,
                'percentage': count / total_tokens * 100 if total_tokens > 0 else 0,
            }
            improvements = category_improvements[category]
            results['category_mean_improvement'][category] = {
                'mean': float(np.mean(improvements)) if improvements else 0,
                'std': float(np.std(improvements)) if improvements else 0,
                'max': max(improvements) if improvements else 0,
            }

        return results


class ContextualPromotionAnalyzer:
    """Analyze WHY tokens get promoted in specific contexts.

    Compares the next-token predictions of the original and hyperfitted
    models on the same input and records the positions where the hyperfitted
    model's top choice was ranked low by the original model.
    """

    def __init__(
        self,
        original_model,
        hyperfitted_model,
        tokenizer,
        device: str = "cuda",
    ):
        """Store the two models, the tokenizer and the target device."""
        self.original_model = original_model
        self.hyperfitted_model = hyperfitted_model
        self.tokenizer = tokenizer
        self.device = device

    @staticmethod
    def _entropy(logits_row: torch.Tensor) -> float:
        """Shannon entropy (nats) of the softmax distribution of one logit row."""
        log_probs = F.log_softmax(logits_row.float(), dim=-1)
        return -torch.sum(log_probs.exp() * log_probs).item()

    def analyze_promotion_contexts(
        self,
        input_ids: torch.Tensor,
        top_k_promoted: int = 10,
        min_original_rank: int = 10,
        context_window: int = 10,
        repetition_window: int = 5,
    ) -> List[Dict]:
        """
        For each position where a token is promoted, analyze the context
        to understand why the promotion might improve generation.

        The logits at index ``pos`` are treated as the prediction for the
        token at ``pos + 1`` (the usual causal-LM convention), so the context
        and repetition windows end at, and include, the token at ``pos``.

        Args:
            input_ids: Token ids, indexed as (batch, sequence).
            top_k_promoted: At most ``top_k_promoted * batch_size`` records
                are returned.
            min_original_rank: A promotion is recorded only when the
                original model ranked the promoted token strictly below this
                0-based rank.
            context_window: Number of most recent tokens decoded as context.
            repetition_window: Number of most recent tokens inspected for the
                repetition-avoidance flag.

        Returns:
            List of dicts describing each promotion, sorted by the promoted
            token's original rank (largest first).
        """
        input_ids = input_ids.to(self.device)

        self.original_model.eval()
        self.hyperfitted_model.eval()

        with torch.no_grad():
            orig_logits = self.original_model(input_ids).logits
            hyper_logits = self.hyperfitted_model(input_ids).logits

        batch_size, seq_len, vocab_size = orig_logits.shape

        # Only the top-1 ids are needed everywhere; a full-vocabulary sort per
        # position is avoided and ranks are computed on demand below.
        orig_top1_ids = orig_logits.argmax(dim=-1).tolist()
        hyper_top1_ids = hyper_logits.argmax(dim=-1).tolist()
        token_rows = input_ids.tolist()

        promotion_analyses = []

        for b in range(batch_size):
            tokens = token_rows[b]
            for pos in range(1, seq_len):  # Skip position 0
                hyper_top1 = hyper_top1_ids[b][pos]
                orig_top1 = orig_top1_ids[b][pos]
                if hyper_top1 == orig_top1:
                    continue

                # 0-based rank of the promoted token under the original model:
                # the number of tokens with a strictly larger logit.
                orig_row = orig_logits[b, pos]
                if hyper_top1 < vocab_size:
                    orig_rank = int((orig_row > orig_row[hyper_top1]).sum().item())
                else:
                    orig_rank = vocab_size

                if orig_rank <= min_original_rank:  # Not a significant promotion
                    continue

                # Everything up to and including `pos` is what the models
                # conditioned on when producing this prediction.
                prefix_end = pos + 1
                context_ids = tokens[max(0, prefix_end - context_window):prefix_end]
                context_text = self.tokenizer.decode(context_ids)

                # Get the promoted token and original top-1
                promoted_token = self.tokenizer.decode([hyper_top1])
                original_choice = self.tokenizer.decode([orig_top1])

                # Token that actually follows in the input, if there is one.
                # Compared against None so that token id 0 is not dropped.
                actual_next = tokens[prefix_end] if prefix_end < seq_len else None
                actual_next_token = (
                    self.tokenizer.decode([actual_next])
                    if actual_next is not None else None
                )

                # Check if promotion avoids repetition
                recent_tokens = tokens[max(0, prefix_end - repetition_window):prefix_end]
                avoids_repetition = (
                    orig_top1 in recent_tokens and hyper_top1 not in recent_tokens
                )

                # Compute entropy change
                orig_entropy = self._entropy(orig_row)
                hyper_entropy = self._entropy(hyper_logits[b, pos])

                promotion_analyses.append({
                    'position': pos,
                    'context': context_text,
                    'promoted_token': promoted_token,
                    'promoted_token_id': hyper_top1,
                    'original_choice': original_choice,
                    'original_choice_id': orig_top1,
                    'original_rank_of_promoted': orig_rank,
                    'actual_next_token': actual_next_token,
                    'avoids_repetition': avoids_repetition,
                    'original_entropy': orig_entropy,
                    'hyperfitted_entropy': hyper_entropy,
                    'entropy_reduction': orig_entropy - hyper_entropy,
                })

        # Sort by rank improvement and return top-k
        promotion_analyses.sort(key=lambda x: x['original_rank_of_promoted'], reverse=True)
        return promotion_analyses[:top_k_promoted * batch_size]


class ExperimentA_QualitativeAnalysis:
    """
    Full qualitative analysis experiment

    Goal: Explain WHY promoted tokens improve generation quality
    """

    def __init__(
        self,
        original_model,
        hyperfitted_model,
        tokenizer,
        device: str = "cuda",
    ):
        """Store the models/tokenizer and build the two helper analyzers."""
        self.original_model = original_model
        self.hyperfitted_model = hyperfitted_model
        self.tokenizer = tokenizer
        self.device = device

        self.category_analyzer = TokenCategoryAnalyzer(tokenizer)
        self.context_analyzer = ContextualPromotionAnalyzer(
            original_model, hyperfitted_model, tokenizer, device
        )

    def run(
        self,
        eval_texts: List[str],
        promoted_tokens: List[Dict],
        max_seq_length: int = 256,
        max_texts: int = 20,
    ) -> Dict:
        """
        Run full qualitative analysis

        Args:
            eval_texts: Raw texts to analyze; only the first ``max_texts``
                are used.
            promoted_tokens: Promoted-token records passed to the category
                analyzer.
            max_seq_length: Token sequences longer than this are truncated.
            max_texts: Number of texts from ``eval_texts`` to process.

        Returns:
            Dict with 'category_analysis', 'context_analyses' (at most the
            first 100 records) and 'summary'.
        """
        logger.info("=" * 60)
        logger.info("Experiment A: Qualitative Analysis of Token Promotion")
        logger.info("=" * 60)

        # Part 1: Categorize promoted tokens
        logger.info("Part 1: Categorizing promoted tokens...")
        category_results = self.category_analyzer.analyze_promoted_tokens(promoted_tokens)

        # Part 2: Analyze promotion contexts
        logger.info("Part 2: Analyzing promotion contexts...")
        all_context_analyses = []

        for text in tqdm(eval_texts[:max_texts], desc="Analyzing contexts"):
            input_ids = self.tokenizer.encode(text, return_tensors="pt")
            if input_ids.shape[1] > max_seq_length:
                input_ids = input_ids[:, :max_seq_length]

            # The context analysis starts at position 1, so sequences with
            # fewer than two tokens cannot yield a record; skipping them also
            # keeps empty sequences away from the models.
            if input_ids.shape[1] < 2:
                continue

            analyses = self.context_analyzer.analyze_promotion_contexts(input_ids)
            all_context_analyses.extend(analyses)

        # Part 3: Summarize findings
        logger.info("Part 3: Summarizing findings...")

        # Count repetition avoidance
        repetition_avoidance_count = sum(
            1 for a in all_context_analyses if a['avoids_repetition']
        )
        total_analyses = len(all_context_analyses)

        # Entropy statistics
        entropy_reductions = [a['entropy_reduction'] for a in all_context_analyses]

        summary = {
            'total_promotions_analyzed': total_analyses,
            'repetition_avoidance_rate': repetition_avoidance_count / total_analyses if total_analyses > 0 else 0,
            'mean_entropy_reduction': np.mean(entropy_reductions) if entropy_reductions else 0,
            'std_entropy_reduction': np.std(entropy_reductions) if entropy_reductions else 0,
        }

        # Print key findings
        logger.info("\n" + "=" * 60)
        logger.info("KEY FINDINGS")
        logger.info("=" * 60)

        logger.info("\nToken Category Distribution:")
        for cat, stats in sorted(
            category_results['category_distribution'].items(),
            key=lambda x: x[1]['percentage'],
            reverse=True
        ):
            logger.info(f"  {cat}: {stats['percentage']:.1f}% ({stats['count']} tokens)")

        logger.info(f"\nRepetition Avoidance Rate: {summary['repetition_avoidance_rate']*100:.1f}%")
        logger.info(f"Mean Entropy Reduction: {summary['mean_entropy_reduction']:.4f}")

        logger.info("\nExample Promotions (showing context):")
        for i, analysis in enumerate(all_context_analyses[:5]):
            logger.info(f"\n  Example {i+1}:")
            logger.info(f"    Context: '...{analysis['context'][-50:]}'")
            logger.info(f"    Original would choose: '{analysis['original_choice']}'")
            logger.info(f"    Hyperfitted chooses: '{analysis['promoted_token']}' (was rank {analysis['original_rank_of_promoted']})")
            logger.info(f"    Avoids repetition: {analysis['avoids_repetition']}")

        return {
            'category_analysis': category_results,
            # First 100 records in processing order (per-text lists are
            # concatenated; there is no global re-ranking across texts).
            'context_analyses': all_context_analyses[:100],
            'summary': summary,
        }



class ExperimentB_FineTuningComparison:
    """
    Compare layer-wise changes between hyperfitting and normal fine-tuning

    Goal: Show that hyperfitting's terminal explosion is qualitatively different
    from normal fine-tuning dynamics
    """

    def __init__(
        self,
        original_model,
        hyperfitted_model,
        normal_finetuned_model,  # Model fine-tuned with early stopping
        tokenizer,
        device: str = "cuda",
    ):
        """Store the three models, the tokenizer and the target device."""
        self.original_model = original_model
        self.hyperfitted_model = hyperfitted_model
        self.normal_finetuned_model = normal_finetuned_model
        self.tokenizer = tokenizer
        self.device = device

    @staticmethod
    def _collect_hidden_states(model, input_ids: torch.Tensor):
        """Run ``model`` in eval mode without gradients and return its hidden states."""
        model.eval()
        with torch.no_grad():
            return model(input_ids, output_hidden_states=True).hidden_states

    @staticmethod
    def _compare_hidden_states(states_a, states_b) -> List[Dict]:
        """Compute per-layer cosine similarity, L2 distance and relative change."""
        results = []
        for layer_idx, (h_a, h_b) in enumerate(zip(states_a, states_b)):
            # Cosine similarity. reshape (not view) so that non-contiguous
            # hidden states are accepted as well.
            a_flat = h_a.reshape(-1, h_a.shape[-1])
            b_flat = h_b.reshape(-1, h_b.shape[-1])
            cos_sim = F.cosine_similarity(a_flat, b_flat, dim=-1).mean().item()

            # L2 distance
            l2_dist = torch.norm(h_a - h_b, dim=-1).mean().item()

            # Relative change (normalized by original norm)
            orig_norm = h_a.norm(dim=-1).mean().item()
            relative_change = l2_dist / (orig_norm + 1e-8)

            results.append({
                'layer': layer_idx,
                'cosine_similarity': cos_sim,
                'l2_distance': l2_dist,
                'relative_change': relative_change,
            })
        return results

    @staticmethod
    def _aggregate_layers(per_text_changes: List[List[Dict]]) -> List[Dict]:
        """Average the per-text layer metrics into one record per layer."""
        aggregated = []
        for layer_idx, layer_records in enumerate(zip(*per_text_changes)):
            aggregated.append({
                'layer': layer_idx,
                'mean_l2_dist': np.mean([d['l2_distance'] for d in layer_records]),
                'mean_relative_change': np.mean([d['relative_change'] for d in layer_records]),
                'mean_cos_sim': np.mean([d['cosine_similarity'] for d in layer_records]),
            })
        return aggregated

    @staticmethod
    def _explosion_ratio(aggregated: List[Dict]) -> float:
        """Return last-layer L2 change divided by the mean change of the other layers.

        NaN is returned when fewer than two layers are available, because the
        denominator is then undefined.
        """
        if len(aggregated) < 2:
            return float('nan')
        last_layer_change = aggregated[-1]['mean_l2_dist']
        other_layers_change = np.mean([a['mean_l2_dist'] for a in aggregated[:-1]])
        return last_layer_change / (other_layers_change + 1e-8)

    def compute_layer_changes(
        self,
        model_a,
        model_b,
        input_ids: torch.Tensor,
    ) -> List[Dict]:
        """Compute layer-wise changes between two models

        Args:
            model_a: Reference model; its hidden-state norm is the
                denominator of ``relative_change``.
            model_b: Model compared against ``model_a``.
            input_ids: Token ids fed to both models.

        Returns:
            One dict per hidden-state entry with 'layer',
            'cosine_similarity', 'l2_distance' and 'relative_change'.
        """
        input_ids = input_ids.to(self.device)
        states_a = self._collect_hidden_states(model_a, input_ids)
        states_b = self._collect_hidden_states(model_b, input_ids)
        return self._compare_hidden_states(states_a, states_b)

    def run(
        self,
        eval_texts: List[str],
        max_seq_length: int = 256,
        max_texts: int = 30,
    ) -> Dict:
        """
        Compare hyperfitting vs normal fine-tuning layer changes

        Args:
            eval_texts: Raw texts to evaluate; only the first ``max_texts``
                are used.
            max_seq_length: Token sequences longer than this are truncated.
            max_texts: Number of texts from ``eval_texts`` to process.

        Returns:
            Dict with the aggregated per-layer changes for both comparisons
            and the two terminal explosion ratios. When no text could be
            evaluated the layer lists are empty and both ratios are NaN.
        """
        logger.info("=" * 60)
        logger.info("Experiment B: Hyperfitting vs Normal Fine-tuning")
        logger.info("=" * 60)

        hyper_changes_all = []
        normal_changes_all = []

        for text in tqdm(eval_texts[:max_texts], desc="Comparing layer changes"):
            input_ids = self.tokenizer.encode(text, return_tensors="pt")
            if input_ids.shape[1] > max_seq_length:
                input_ids = input_ids[:, :max_seq_length]
            if input_ids.shape[1] == 0:
                # Nothing to compare for an empty token sequence.
                continue
            input_ids = input_ids.to(self.device)

            # The original model is the reference in both comparisons, so its
            # forward pass is done once and the hidden states are reused.
            original_states = self._collect_hidden_states(self.original_model, input_ids)

            # Compare original vs hyperfitted
            hyper_states = self._collect_hidden_states(self.hyperfitted_model, input_ids)
            hyper_changes_all.append(
                self._compare_hidden_states(original_states, hyper_states)
            )

            # Compare original vs normal fine-tuned
            normal_states = self._collect_hidden_states(self.normal_finetuned_model, input_ids)
            normal_changes_all.append(
                self._compare_hidden_states(original_states, normal_states)
            )

        if not hyper_changes_all:
            logger.warning("Experiment B: no non-empty evaluation texts; nothing to compare")
            return {
                'hyperfitting_layer_changes': [],
                'normal_finetuning_layer_changes': [],
                'hyperfitting_explosion_ratio': float('nan'),
                'normal_finetuning_explosion_ratio': float('nan'),
            }

        # Aggregate results
        hyper_aggregated = self._aggregate_layers(hyper_changes_all)
        normal_aggregated = self._aggregate_layers(normal_changes_all)

        # Compute terminal explosion ratio
        # (change in last layer / average change in other layers)
        hyper_explosion_ratio = self._explosion_ratio(hyper_aggregated)
        normal_explosion_ratio = self._explosion_ratio(normal_aggregated)

        # Print findings
        logger.info("\n" + "=" * 60)
        logger.info("FINDINGS")
        logger.info("=" * 60)

        logger.info("\nTerminal Explosion Ratio (last layer / avg other layers):")
        logger.info(f"  Hyperfitting: {hyper_explosion_ratio:.2f}x")
        logger.info(f"  Normal Fine-tuning: {normal_explosion_ratio:.2f}x")

        if hyper_explosion_ratio > normal_explosion_ratio * 1.5:
            logger.info("\n  → CONCLUSION: Hyperfitting shows significantly stronger terminal explosion!")
        else:
            logger.info("\n  → WARNING: Terminal explosion may not be unique to hyperfitting")

        return {
            'hyperfitting_layer_changes': hyper_aggregated,
            'normal_finetuning_layer_changes': normal_aggregated,
            'hyperfitting_explosion_ratio': hyper_explosion_ratio,
            'normal_finetuning_explosion_ratio': normal_explosion_ratio,
        }


class SemanticDiversityMetrics:
    """Compute diversity metrics over a set of texts.

    Provides embedding-based pairwise cosine distance plus two surface
    metrics computed on whitespace-split, lower-cased tokens (an n-gram
    overlap score and Distinct-N).
    """

    def __init__(self, device: str = "cuda"):
        """Remember the device; the embedding model is loaded separately."""
        self.device = device
        self.embedding_model = None
        self.embedding_tokenizer = None

    def load_embedding_model(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """Load sentence embedding model

        If ``sentence_transformers`` cannot be imported, ``embedding_model``
        stays ``None`` and ``get_sentence_embeddings`` uses its TF-IDF
        fallback instead.
        """
        try:
            from sentence_transformers import SentenceTransformer
            self.embedding_model = SentenceTransformer(model_name, device=self.device)
            logger.info(f"Loaded embedding model: {model_name}")
        except ImportError:
            logger.warning("sentence-transformers not installed. Using fallback method.")
            # Fallback: TF-IDF vectors, built in get_sentence_embeddings
            self.embedding_model = None

    def get_sentence_embeddings(self, texts: List[str]) -> np.ndarray:
        """Get embeddings for a list of texts

        Args:
            texts: Texts to embed.

        Returns:
            2-D array with one row per text. An empty input yields an empty
            ``(0, 0)`` array.
        """
        if not texts:
            return np.empty((0, 0))

        if self.embedding_model is not None:
            return self.embedding_model.encode(texts, convert_to_numpy=True)

        # Fallback: simple TF-IDF-like representation
        from sklearn.feature_extraction.text import TfidfVectorizer
        vectorizer = TfidfVectorizer(max_features=512)
        try:
            return vectorizer.fit_transform(texts).toarray()
        except ValueError:
            # Raised when no vocabulary can be built (e.g. every text is
            # empty). All-zero rows are skipped by compute_pairwise_diversity.
            logger.warning("TF-IDF fallback found no vocabulary; returning zero embeddings.")
            return np.zeros((len(texts), 1))

    def compute_pairwise_diversity(self, embeddings: np.ndarray) -> Dict:
        """
        Compute pairwise cosine distances between embeddings
        Higher distance = more diverse

        Rows with a zero or non-finite norm have no defined cosine distance
        and are left out of every pair.

        Args:
            embeddings: One embedding per row.

        Returns:
            Dict with 'mean_distance', 'std_distance', 'min_distance' and
            'max_distance' over all usable pairs (all zero if there are none).
        """
        empty_result = {'mean_distance': 0, 'std_distance': 0, 'min_distance': 0, 'max_distance': 0}

        n = len(embeddings)
        if n < 2:
            return empty_result

        vectors = np.asarray(embeddings, dtype=np.float64).reshape(n, -1)
        norms = np.linalg.norm(vectors, axis=1)
        usable = np.isfinite(norms) & (norms > 0)
        unit_vectors = vectors[usable] / norms[usable, None]
        if len(unit_vectors) < 2:
            return empty_result

        # One matrix product yields every pairwise cosine similarity; the
        # strict upper triangle holds each unordered pair exactly once.
        similarity = unit_vectors @ unit_vectors.T
        pair_index = np.triu_indices(len(unit_vectors), k=1)
        # Clip away rounding noise so distances stay within [0, 2].
        distances = np.clip(1.0 - similarity[pair_index], 0.0, 2.0)

        return {
            'mean_distance': float(np.mean(distances)),
            'std_distance': float(np.std(distances)),
            'min_distance': float(np.min(distances)),
            'max_distance': float(np.max(distances)),
        }

    @staticmethod
    def _ngrams(text: str, n: int) -> List[tuple]:
        """Return the n-grams of ``text`` after lower-casing and whitespace splitting."""
        tokens = text.lower().split()
        return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

    def compute_self_bleu(self, texts: List[str], n_gram: int = 4) -> float:
        """
        Compute Self-BLEU score (lower = more diverse)

        Self-BLEU measures how similar generated texts are to each other.
        This implementation reports the mean clipped precision at a single
        n-gram order: each text is the hypothesis in turn and all other texts
        are its references. No brevity penalty and no averaging across
        n-gram orders is applied.

        Args:
            texts: Generated texts to compare with each other.
            n_gram: N-gram order used for the precision.

        Returns:
            Mean clipped n-gram precision, or 0.0 when fewer than two texts
            are given or no text has an n-gram of the requested order.
        """
        if len(texts) < 2:
            return 0.0

        # Tokenize every text once instead of once per hypothesis.
        per_text_counts = [Counter(self._ngrams(text, n_gram)) for text in texts]

        # For every n-gram keep the largest count in any single text, the
        # index of that text, and the runner-up count. The maximum count among
        # the *other* texts is then a constant-time lookup.
        best = {}
        for idx, counts in enumerate(per_text_counts):
            for gram, count in counts.items():
                top, owner, second = best.get(gram, (0, -1, 0))
                if count > top:
                    best[gram] = (count, idx, top)
                elif count > second:
                    best[gram] = (top, owner, count)

        precisions = []
        for idx, hyp_counts in enumerate(per_text_counts):
            if not hyp_counts:
                continue

            clipped = 0
            for gram, count in hyp_counts.items():
                top, owner, second = best[gram]
                # Clip by the maximum count in any single reference: pooling
                # counts over references would let many short references
                # justify arbitrary repetition in the hypothesis.
                max_ref_count = second if owner == idx else top
                clipped += min(count, max_ref_count)

            precisions.append(clipped / sum(hyp_counts.values()))

        return float(np.mean(precisions)) if precisions else 0.0

    def compute_distinct_n(self, texts: List[str], n: int = 2) -> float:
        """
        Compute Distinct-N score (higher = more diverse)

        Ratio of unique n-grams to total n-grams across all texts

        Args:
            texts: Texts whose n-grams are pooled.
            n: N-gram order.

        Returns:
            Unique / total n-gram ratio, or 0.0 when there are no n-grams.
        """
        all_ngrams = []
        for text in texts:
            all_ngrams.extend(self._ngrams(text, n))

        if not all_ngrams:
            return 0.0

        return len(set(all_ngrams)) / len(all_ngrams)


class ExperimentC_SemanticDiversity:
    """
    Semantic diversity evaluation

    Goal: Show hyperfitting improves semantic diversity, not just surface metrics
    """

    def __init__(
        self,
        original_model,
        hyperfitted_model,
        tokenizer,
        device: str = "cuda",
    ):
        """Keep references to both models and prepare the diversity metrics helper."""
        self.original_model = original_model
        self.hyperfitted_model = hyperfitted_model
        self.tokenizer = tokenizer
        self.device = device

        self.diversity_metrics = SemanticDiversityMetrics(device)
        self.diversity_metrics.load_embedding_model()

    def generate_texts(
        self,
        model,
        prompts: List[str],
        max_new_tokens: int = 224,
    ) -> List[str]:
        """Greedily generate one continuation per prompt and return only the new text."""
        model.eval()
        completions = []

        for prompt in tqdm(prompts, desc="Generating"):
            prompt_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)

            with torch.no_grad():
                sequences = model.generate(
                    prompt_ids,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    pad_token_id=self.tokenizer.eos_token_id,
                )

            # Drop the prompt tokens so that only the continuation is decoded.
            prompt_length = prompt_ids.shape[1]
            continuation = sequences[0][prompt_length:]
            completions.append(
                self.tokenizer.decode(continuation, skip_special_tokens=True)
            )

        return completions

    def run(
        self,
        prompts: List[str],
        max_new_tokens: int = 224,
    ) -> Dict:
        """
        Generate from both models and compare their diversity metrics.

        Returns a dict keyed by 'original' and 'hyperfitted'; each entry holds
        the pairwise embedding diversity, Self-BLEU, Distinct-1/2 and the
        first five generated texts.
        """
        banner = "=" * 60
        metrics = self.diversity_metrics

        def summarize(texts: List[str]) -> Dict:
            # All metrics for one model's generations, in the reported key order.
            embeddings = metrics.get_sentence_embeddings(texts)
            return {
                'pairwise_diversity': metrics.compute_pairwise_diversity(embeddings),
                'self_bleu': metrics.compute_self_bleu(texts),
                'distinct_1': metrics.compute_distinct_n(texts, n=1),
                'distinct_2': metrics.compute_distinct_n(texts, n=2),
                'sample_texts': texts[:5],
            }

        logger.info(banner)
        logger.info("Experiment C: Semantic Diversity Evaluation")
        logger.info(banner)

        # Generation happens for both models before any metric is computed.
        logger.info("Generating from original model...")
        original_texts = self.generate_texts(
            self.original_model, prompts, max_new_tokens
        )

        logger.info("Generating from hyperfitted model...")
        hyperfitted_texts = self.generate_texts(
            self.hyperfitted_model, prompts, max_new_tokens
        )

        logger.info("Computing semantic diversity metrics...")
        results = {
            'original': summarize(original_texts),
            'hyperfitted': summarize(hyperfitted_texts),
        }
        report_rows = (
            ("Original", results['original']),
            ("Hyperfitted", results['hyperfitted']),
        )

        # Report
        logger.info("\n" + banner)
        logger.info("SEMANTIC DIVERSITY RESULTS")
        logger.info(banner)

        logger.info("\nPairwise Semantic Diversity (embedding cosine distance, higher=more diverse):")
        for label, stats in report_rows:
            pairwise = stats['pairwise_diversity']
            logger.info(f"  {label}: {pairwise['mean_distance']:.4f} ± {pairwise['std_distance']:.4f}")

        logger.info("\nSelf-BLEU (lower=more diverse):")
        for label, stats in report_rows:
            logger.info(f"  {label}: {stats['self_bleu']:.4f}")

        logger.info("\nDistinct-1/Distinct-2 (higher=more diverse):")
        for label, stats in report_rows:
            logger.info(f"  {label}: {stats['distinct_1']:.4f} / {stats['distinct_2']:.4f}")

        return results


class DecodingBaselines:
    """Decoding strategies used as baselines for comparison.

    Each ``generate_*`` method runs ``model.generate`` with one fixed strategy
    and returns that call's output unchanged.

    Args:
        model: Causal LM exposing ``eval()`` and a Hugging Face style
            ``generate()``.
        tokenizer: Tokenizer; its ``eos_token_id`` doubles as the pad id.
        device: Kept for callers; nothing in this class moves tensors.
    """

    def __init__(self, model, tokenizer, device: str = "cuda"):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    def _generate(self, input_ids: torch.Tensor, **strategy_kwargs) -> torch.Tensor:
        """Call ``model.generate`` in eval mode with gradients disabled.

        Args:
            input_ids: Prompt token ids; assumed to contain no padding.
            **strategy_kwargs: Strategy-specific ``generate`` arguments.

        Returns:
            Whatever ``model.generate`` returns.
        """
        self.model.eval()
        with torch.no_grad():
            return self.model.generate(
                input_ids,
                # Pad and EOS share one id, so a mask cannot be derived from
                # the ids themselves; state "attend to everything" explicitly
                # rather than leaving the library to fall back on a guess.
                attention_mask=torch.ones_like(input_ids),
                pad_token_id=self.tokenizer.eos_token_id,
                **strategy_kwargs,
            )

    def generate_greedy(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 224,
    ) -> torch.Tensor:
        """Standard greedy decoding."""
        return self._generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    def generate_nucleus_sampling(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 224,
        top_p: float = 0.9,
        temperature: float = 1.0,
        top_k: int = 0,
    ) -> torch.Tensor:
        """Nucleus (top-p) sampling.

        Args:
            input_ids: Prompt token ids.
            max_new_tokens: Generation budget.
            top_p: Cumulative probability mass kept at each step.
            temperature: Softmax temperature.
            top_k: Top-k cut-off applied alongside top-p. The default of 0
                switches top-k filtering off; without it the library's
                documented default (``top_k=50``) would be layered on top of
                the nucleus filter. Pass 50 to get the previous behaviour.
        """
        return self._generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
        )

    def generate_with_repetition_penalty(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 224,
        repetition_penalty: float = 1.2,
    ) -> torch.Tensor:
        """Greedy decoding with a repetition penalty."""
        return self._generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            repetition_penalty=repetition_penalty,
        )

    def generate_contrastive_search(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 224,
        top_k: int = 4,
        penalty_alpha: float = 0.6,
    ) -> torch.Tensor:
        """Contrastive search decoding."""
        return self._generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            top_k=top_k,
            penalty_alpha=penalty_alpha,
            # NOTE(review): presumably required because recent transformers
            # releases load contrastive search as Hub-hosted generation code;
            # confirm against the installed version.
            trust_remote_code=True,
        )

    def generate_typical_sampling(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 224,
        typical_p: float = 0.95,
        temperature: float = 1.0,
        top_k: int = 0,
    ) -> torch.Tensor:
        """Typical sampling (locally typical sampling).

        Args:
            input_ids: Prompt token ids.
            max_new_tokens: Generation budget.
            typical_p: Probability mass kept by the typical filter.
            temperature: Softmax temperature.
            top_k: Top-k cut-off applied alongside the typical filter; 0 (the
                default) disables it so the baseline is not combined with the
                library's default ``top_k=50``.
        """
        return self._generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            typical_p=typical_p,
            top_k=top_k,
            temperature=temperature,
        )


class ExperimentD_BaselineComparison:
    """
    Compare hyperfitting with other repetition-mitigation methods

    Methods compared:
    1. Original model + decoding baselines
    2. Hyperfitted model + same decoding baselines
    """

    # Keys returned by compute_generation_metrics and averaged across runs.
    _METRIC_KEYS = ('ttr', 'bigram_rep', 'trigram_rep', 'length')
    # Method-name fragments that identify sampling (non-deterministic) methods.
    _STOCHASTIC_MARKERS = ('nucleus', 'typical')

    def __init__(
        self,
        original_model,
        hyperfitted_model,
        tokenizer,
        device: str = "cuda",
    ):
        self.original_model = original_model
        self.hyperfitted_model = hyperfitted_model
        self.tokenizer = tokenizer
        self.device = device

        self.original_baselines = DecodingBaselines(original_model, tokenizer, device)
        self.hyperfitted_baselines = DecodingBaselines(hyperfitted_model, tokenizer, device)

    def compute_generation_metrics(self, tokens: List[int]) -> Dict:
        """Compute all metrics for a generated sequence.

        Args:
            tokens: Generated token ids (prompt already removed by the caller).

        Returns:
            Dict with ``ttr``, ``bigram_rep``, ``trigram_rep`` and ``length``.
        """
        # Imported lazily from the project-local ``metrics`` module; the exact
        # definitions of these two scores live there, not in this file.
        from metrics import compute_ttr, compute_ngram_repetition

        return {
            'ttr': compute_ttr(tokens, window_size=96),
            'bigram_rep': compute_ngram_repetition(tokens, n=2),
            'trigram_rep': compute_ngram_repetition(tokens, n=3),
            'length': len(tokens),
        }

    @staticmethod
    def _methods_for(prefix: str, baselines, max_new_tokens: int) -> Dict:
        """Build the named generation callables for one model's baselines.

        ``baselines`` and ``max_new_tokens`` are parameters of this call, so
        each lambda closes over the right object (no late-binding surprises).
        """
        return {
            f'{prefix}_greedy': lambda inp: baselines.generate_greedy(inp, max_new_tokens),
            f'{prefix}_nucleus_p0.9': lambda inp: baselines.generate_nucleus_sampling(inp, max_new_tokens, top_p=0.9),
            f'{prefix}_nucleus_p0.95': lambda inp: baselines.generate_nucleus_sampling(inp, max_new_tokens, top_p=0.95),
            f'{prefix}_rep_penalty_1.1': lambda inp: baselines.generate_with_repetition_penalty(inp, max_new_tokens, repetition_penalty=1.1),
            f'{prefix}_rep_penalty_1.2': lambda inp: baselines.generate_with_repetition_penalty(inp, max_new_tokens, repetition_penalty=1.2),
            f'{prefix}_contrastive': lambda inp: baselines.generate_contrastive_search(inp, max_new_tokens),
            f'{prefix}_typical_p0.95': lambda inp: baselines.generate_typical_sampling(inp, max_new_tokens, typical_p=0.95),
        }

    def _build_methods(self, max_new_tokens: int) -> Dict:
        """Return ``{method_name: generate_fn}`` for both models, original first."""
        methods = {}
        methods.update(self._methods_for('original', self.original_baselines, max_new_tokens))
        methods.update(self._methods_for('hyperfitted', self.hyperfitted_baselines, max_new_tokens))
        return methods

    @classmethod
    def _average_metrics(cls, per_run: List[Dict]) -> Dict:
        """Average every metric key over the runs of one method on one prompt."""
        return {key: np.mean([m[key] for m in per_run]) for key in cls._METRIC_KEYS}

    @staticmethod
    def _conclusion_message(sorted_methods: List[Tuple[str, Dict]]) -> str:
        """Build the closing log line from methods sorted best-first.

        Works for an empty list too, which happens when no method produced
        any metrics.
        """
        hyper_rank = next(
            (rank for rank, (name, _) in enumerate(sorted_methods, start=1)
             if name.startswith('hyperfitted_')),
            None,
        )
        if hyper_rank is None:
            return "\n  → NOTE: No hyperfitted methods produced metrics"
        best_hyper_method = sorted_methods[hyper_rank - 1][0]
        if hyper_rank == 1:
            return f"\n  → CONCLUSION: {best_hyper_method} achieves BEST diversity!"
        return f"\n  → NOTE: Best hyperfitted method ({best_hyper_method}) ranks #{hyper_rank} in TTR"

    def run(
        self,
        prompts: List[str],
        max_new_tokens: int = 224,
        num_runs: int = 3,  # For stochastic methods, average over multiple runs
    ) -> Dict:
        """
        Run baseline comparison

        Args:
            prompts: Prompt strings; each is encoded and fed to every method.
            max_new_tokens: Generation budget passed to every method.
            num_runs: Repetitions per prompt for sampling-based methods;
                deterministic methods run once.

        Returns:
            Dict with ``method_results`` (per-method aggregates) and
            ``ranking`` (``(method, mean_ttr)`` pairs, best first). Both are
            empty when no method produced metrics.
        """
        logger.info("=" * 60)
        logger.info("Experiment D: Baseline Comparison")
        logger.info("=" * 60)

        methods = self._build_methods(max_new_tokens)
        results = {method: [] for method in methods}

        for prompt in tqdm(prompts, desc="Evaluating methods"):
            input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
            input_len = input_ids.shape[1]

            for method_name, generate_fn in methods.items():
                # For stochastic methods, average over multiple runs
                is_stochastic = any(marker in method_name for marker in self._STOCHASTIC_MARKERS)
                runs = num_runs if is_stochastic else 1

                method_metrics = []
                for _ in range(runs):
                    # Best effort: one failing method must not abort the sweep.
                    try:
                        output = generate_fn(input_ids)
                        generated_tokens = output[0][input_len:].tolist()
                        method_metrics.append(self.compute_generation_metrics(generated_tokens))
                    except Exception as e:
                        logger.warning(f"Error with {method_name}: {e}")

                if method_metrics:
                    results[method_name].append(self._average_metrics(method_metrics))

        # Aggregate results
        aggregated = {}
        for method_name, metrics_list in results.items():
            if metrics_list:
                aggregated[method_name] = {
                    'mean_ttr': np.mean([m['ttr'] for m in metrics_list]),
                    'std_ttr': np.std([m['ttr'] for m in metrics_list]),
                    'mean_bigram_rep': np.mean([m['bigram_rep'] for m in metrics_list]),
                    'mean_trigram_rep': np.mean([m['trigram_rep'] for m in metrics_list]),
                    'mean_length': np.mean([m['length'] for m in metrics_list]),
                }

        if not aggregated:
            logger.warning("No method produced any metrics; the results table below is empty")

        # Print results table
        logger.info("\n" + "=" * 60)
        logger.info("BASELINE COMPARISON RESULTS")
        logger.info("=" * 60)

        logger.info("\n{:<30} {:>10} {:>12} {:>12}".format(
            "Method", "TTR ↑", "Bigram Rep ↓", "Trigram Rep ↓"
        ))
        logger.info("-" * 66)

        # Sort by TTR
        sorted_methods = sorted(
            aggregated.items(),
            key=lambda x: x[1]['mean_ttr'],
            reverse=True
        )

        for method_name, metrics in sorted_methods:
            logger.info("{:<30} {:>10.4f} {:>12.4f} {:>12.4f}".format(
                method_name,
                metrics['mean_ttr'],
                metrics['mean_bigram_rep'],
                metrics['mean_trigram_rep'],
            ))

        # Highlight if any hyperfitted method is best
        logger.info(self._conclusion_message(sorted_methods))

        return {
            'method_results': aggregated,
            'ranking': [(m, r['mean_ttr']) for m, r in sorted_methods],
        }



class ExperimentE_HumanEvalFramework:
    """
    Generate data for human evaluation study

    Creates pairwise comparisons for:
    - Fluency
    - Coherence
    - Diversity/Interest
    - Overall preference

    The JSON export carries ``label_a``/``label_b`` (which model wrote which
    text) and therefore acts as the answer key; the text export omits them.
    """

    _ANSWER_OPTIONS = ('A', 'B', 'Tie')
    # (aspect, question) pairs, in the order they are shown to annotators.
    _QUESTIONS = (
        ('fluency', 'Which text is more fluent and grammatically correct?'),
        ('coherence', 'Which text is more coherent and logically consistent?'),
        ('diversity', 'Which text is more interesting and uses more diverse vocabulary?'),
        ('repetition', 'Which text has LESS repetition?'),
        ('overall', 'Overall, which text do you prefer?'),
    )

    def __init__(
        self,
        original_model,
        hyperfitted_model,
        tokenizer,
        device: str = "cuda",
    ):
        self.original_model = original_model
        self.hyperfitted_model = hyperfitted_model
        self.tokenizer = tokenizer
        self.device = device

    @classmethod
    def _build_questions(cls) -> List[Dict]:
        """Return a fresh question list so pairs never share mutable state."""
        return [
            {'aspect': aspect, 'question': question, 'options': list(cls._ANSWER_OPTIONS)}
            for aspect, question in cls._QUESTIONS
        ]

    @staticmethod
    def _text_export_path(output_path: str) -> str:
        """Derive the human-readable export path from the JSON path.

        Only a trailing ``.json`` extension is swapped for ``.txt``; any other
        path gets ``.txt`` appended, so the text file can never land on top of
        the JSON file and directory names are left alone.
        """
        root, ext = os.path.splitext(output_path)
        if ext.lower() == '.json':
            return root + '.txt'
        return output_path + '.txt'

    def _continue_prompt(self, model, input_ids: torch.Tensor, max_new_tokens: int) -> str:
        """Greedily continue ``input_ids`` with ``model``; return only the new text."""
        model.eval()
        with torch.no_grad():
            output = model.generate(
                input_ids,
                # Single unpadded prompt with pad == eos: spell out the mask
                # instead of leaving the library to fall back on a guess.
                attention_mask=torch.ones_like(input_ids),
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        return self.tokenizer.decode(
            output[0][input_ids.shape[1]:],
            skip_special_tokens=True
        )

    def generate_evaluation_pairs(
        self,
        prompts: List[str],
        max_new_tokens: int = 224,
    ) -> List[Dict]:
        """
        Generate paired outputs for human evaluation

        Args:
            prompts: Prompt strings to continue with both models.
            max_new_tokens: Generation budget for each continuation.

        Returns:
            One dict per prompt with ``id``, ``prompt``, ``text_a``/``text_b``,
            ``label_a``/``label_b`` and the ``questions`` to ask.
        """
        pairs = []

        for i, prompt in enumerate(tqdm(prompts, desc="Generating evaluation pairs")):
            input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)

            orig_text = self._continue_prompt(self.original_model, input_ids, max_new_tokens)
            hyper_text = self._continue_prompt(self.hyperfitted_model, input_ids, max_new_tokens)

            # Randomize order to avoid position bias. This draws from NumPy's
            # global RNG, so call np.random.seed beforehand for a repeatable
            # assignment.
            if np.random.random() > 0.5:
                text_a, text_b = orig_text, hyper_text
                label_a, label_b = "original", "hyperfitted"
            else:
                text_a, text_b = hyper_text, orig_text
                label_a, label_b = "hyperfitted", "original"

            pairs.append({
                'id': i,
                'prompt': prompt,
                'text_a': text_a,
                'text_b': text_b,
                'label_a': label_a,
                'label_b': label_b,
                'questions': self._build_questions(),
            })

        return pairs

    def export_for_annotation(
        self,
        pairs: List[Dict],
        output_path: str,
    ):
        """Export pairs in a format suitable for annotation.

        Writes the full pair records as JSON to ``output_path`` and a
        label-free questionnaire to the matching ``.txt`` path.

        Args:
            pairs: Records produced by ``generate_evaluation_pairs``.
            output_path: Destination of the JSON file.
        """
        # Export as JSON
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(pairs, f, indent=2)

        # Also export as human-readable format. Generated text is written
        # verbatim, so fix the encoding instead of relying on the locale.
        txt_path = self._text_export_path(output_path)
        with open(txt_path, 'w', encoding='utf-8') as f:
            f.write("HUMAN EVALUATION STUDY\n")
            f.write("=" * 80 + "\n\n")
            f.write("Instructions:\n")
            f.write("For each prompt, you will see two text continuations (A and B).\n")
            f.write("Please answer the questions for each pair.\n")
            f.write("=" * 80 + "\n\n")

            for pair in pairs:
                f.write(f"--- Pair {pair['id'] + 1} ---\n\n")
                f.write(f"PROMPT: {pair['prompt']}\n\n")
                f.write(f"TEXT A:\n{pair['text_a']}\n\n")
                f.write(f"TEXT B:\n{pair['text_b']}\n\n")

                for q in pair['questions']:
                    f.write(f"Q ({q['aspect']}): {q['question']}\n")
                    f.write(f"   Options: {', '.join(q['options'])}\n")
                    f.write("   Your answer: _____\n\n")

                f.write("-" * 80 + "\n\n")

        logger.info(f"Exported {len(pairs)} pairs to {output_path} and {txt_path}")

    def run(
        self,
        prompts: List[str],
        output_dir: str,
        max_new_tokens: int = 224,
    ) -> Dict:
        """
        Generate human evaluation data

        Args:
            prompts: Prompt strings to build pairs from.
            output_dir: Directory for the exports (created if missing).
            max_new_tokens: Generation budget for each continuation.

        Returns:
            Dict with ``num_pairs``, ``output_path`` and ``sample_pair``
            (the first pair, or ``None`` when there are no prompts).
        """
        logger.info("=" * 60)
        logger.info("Experiment E: Human Evaluation Framework")
        logger.info("=" * 60)

        # Generate pairs
        pairs = self.generate_evaluation_pairs(prompts, max_new_tokens)

        # Export
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, 'human_eval_pairs.json')
        self.export_for_annotation(pairs, output_path)

        logger.info(f"\nGenerated {len(pairs)} evaluation pairs")
        logger.info(f"Export location: {output_path}")
        logger.info("\nNext steps:")
        logger.info("1. Recruit 3-5 annotators")
        logger.info("2. Have each annotator evaluate all pairs")
        logger.info("3. Compute inter-annotator agreement (Fleiss' kappa)")
        logger.info("4. Report win/loss/tie rates for each aspect")

        return {
            'num_pairs': len(pairs),
            'output_path': output_path,
            'sample_pair': pairs[0] if pairs else None,
        }



def run_all_additional_experiments(
    original_model_name: str,
    hyperfitted_model_path: str,
    output_dir: str,
    eval_texts: List[str],
    prompts: List[str],
    promoted_tokens: List[Dict],
    normal_finetuned_model_path: Optional[str] = None,
    device: str = "cuda",
    torch_dtype: str = "bfloat16",
):
    """
    Run experiments A-E in sequence and persist their results.

    Args:
        original_model_name: Name/path of the base model; also used to load
            the tokenizer.
        hyperfitted_model_path: Path of the hyperfitted checkpoint.
        output_dir: Directory receiving the per-experiment JSON files and the
            combined ``all_additional_results.json``.
        eval_texts: Texts passed to experiments A and B.
        prompts: Prompts for experiments C (first 30), D (first 20) and
            E (first 50).
        promoted_tokens: Promoted-token records handed to experiment A.
        normal_finetuned_model_path: Optional checkpoint; experiment B runs
            only when this is provided.
        device: Device string forwarded to every experiment.
        torch_dtype: "bfloat16", "float16" or "float32"; any other value
            falls back to bfloat16.

    Returns:
        Dict mapping experiment keys to their result dicts.
    """
    rule = "=" * 80

    def banner(title, leading_blank=True):
        # Three-line section header; all but the first start with a blank line.
        logger.info(("\n" + rule) if leading_blank else rule)
        logger.info(title)
        logger.info(rule)

    def load_model(path):
        return AutoModelForCausalLM.from_pretrained(
            path, torch_dtype=dtype, device_map="auto"
        )

    def save_json(filename, payload):
        # default=str stringifies anything json cannot encode natively.
        with open(os.path.join(output_dir, filename), 'w') as handle:
            json.dump(payload, handle, indent=2, default=str)

    banner("RUNNING ALL ADDITIONAL EXPERIMENTS", leading_blank=False)

    # Load models
    dtype = {
        "float16": torch.float16,
        "float32": torch.float32,
    }.get(torch_dtype, torch.bfloat16)

    tokenizer = AutoTokenizer.from_pretrained(original_model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    original_model = load_model(original_model_name)
    hyperfitted_model = load_model(hyperfitted_model_path)

    os.makedirs(output_dir, exist_ok=True)
    all_results = {}
    shared_args = (original_model, hyperfitted_model, tokenizer, device)

    # Experiment A: Qualitative Analysis
    banner("EXPERIMENT A: Qualitative Analysis")
    qualitative = ExperimentA_QualitativeAnalysis(*shared_args).run(eval_texts, promoted_tokens)
    all_results['qualitative_analysis'] = qualitative
    save_json('experiment_a_qualitative.json', qualitative)

    # Experiment B: Fine-tuning Comparison (only with a normally fine-tuned model)
    if normal_finetuned_model_path:
        banner("EXPERIMENT B: Fine-tuning Comparison")
        normal_model = load_model(normal_finetuned_model_path)

        finetuning = ExperimentB_FineTuningComparison(
            original_model, hyperfitted_model, normal_model, tokenizer, device
        ).run(eval_texts)
        all_results['finetuning_comparison'] = finetuning
        save_json('experiment_b_finetuning.json', finetuning)

        # Drop the extra model before the remaining experiments run.
        del normal_model
        torch.cuda.empty_cache()

    # Experiment C: Semantic Diversity
    banner("EXPERIMENT C: Semantic Diversity")
    semantic = ExperimentC_SemanticDiversity(*shared_args).run(prompts[:30])
    all_results['semantic_diversity'] = semantic
    save_json('experiment_c_semantic.json', semantic)

    # Experiment D: Baseline Comparison
    banner("EXPERIMENT D: Baseline Comparison")
    baselines = ExperimentD_BaselineComparison(*shared_args).run(prompts[:20])
    all_results['baseline_comparison'] = baselines
    save_json('experiment_d_baselines.json', baselines)

    # Experiment E: Human Evaluation Framework (writes its own export files)
    banner("EXPERIMENT E: Human Evaluation Framework")
    human_eval = ExperimentE_HumanEvalFramework(*shared_args).run(prompts[:50], output_dir)
    all_results['human_eval_framework'] = human_eval

    # Save all results
    save_json('all_additional_results.json', all_results)

    banner("ALL ADDITIONAL EXPERIMENTS COMPLETE")
    logger.info(f"Results saved to: {output_dir}")

    return all_results


def load_promoted_tokens(results_path: Optional[str]) -> List[Dict]:
    """Read the promoted-token list saved by the rank-analysis experiment.

    Args:
        results_path: Path to ``rank_analysis_results.json``; may be ``None``
            or empty when the file was not supplied.

    Returns:
        The list stored under ``results -> promoted_tokens``, or ``[]`` when
        no path was given, the file does not exist, or the key is absent.
    """
    if not results_path:
        return []
    try:
        with open(results_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        # The path was given explicitly, so say that it is being ignored
        # instead of silently running with no promoted tokens.
        logger.warning(
            "Rank analysis results not found at %s; continuing without promoted tokens",
            results_path,
        )
        return []
    return data.get('results', {}).get('promoted_tokens', [])


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run additional experiments for ICML")
    parser.add_argument("--original_model", type=str, required=True)
    parser.add_argument("--hyperfitted_model", type=str, required=True)
    parser.add_argument("--normal_finetuned_model", type=str, default=None,
                       help="Path to normally fine-tuned model (for Experiment B)")
    parser.add_argument("--output_dir", type=str, default="./results/additional")
    parser.add_argument("--torch_dtype", type=str, default="bfloat16")
    parser.add_argument("--rank_analysis_results", type=str, default=None,
                       help="Path to rank_analysis_results.json for promoted tokens")

    args = parser.parse_args()

    # Load promoted tokens from previous experiment
    promoted_tokens = load_promoted_tokens(args.rank_analysis_results)

    # Load eval data. The tokenizer is loaded here only to build the eval
    # data; run_all_additional_experiments loads its own copy.
    from run_experiments import load_eval_data
    tokenizer = AutoTokenizer.from_pretrained(args.original_model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    eval_texts, prompts = load_eval_data(
        num_samples=100,
        context_length=32,
        sequence_length=256,
        tokenizer=tokenizer,
    )

    run_all_additional_experiments(
        original_model_name=args.original_model,
        hyperfitted_model_path=args.hyperfitted_model,
        output_dir=args.output_dir,
        eval_texts=eval_texts,
        prompts=prompts,
        promoted_tokens=promoted_tokens,
        normal_finetuned_model_path=args.normal_finetuned_model,
        torch_dtype=args.torch_dtype,
    )
