"""
Metrics Module for Hyperfitting Analysis

Includes:
- Type-Token Ratio (TTR)
- N-gram repetition rates
- Entropy metrics
- Rank analysis metrics
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import Dict, List
from scipy.stats import spearmanr
import logging

logger = logging.getLogger(__name__)


def compute_ttr(tokens: List[int], window_size: int = 96) -> float:
    """
    Type-Token Ratio (TTR) over the tail of a token sequence.

    TTR = distinct tokens / total tokens, measured on the last
    ``window_size`` entries of ``tokens`` (or all of them when the
    sequence is not longer than the window).

    Args:
        tokens: Token IDs in generation order
        window_size: How many trailing tokens to score

    Returns:
        Ratio between 0 and 1; 0.0 for an empty sequence
    """
    count = len(tokens)
    if count == 0:
        return 0.0

    window = tokens[-window_size:] if count > window_size else tokens
    return len(set(window)) / len(window)


def compute_ngram_repetition(tokens: List[int], n: int = 2) -> float:
    """
    Fraction of n-gram occurrences that repeat an earlier n-gram.

    Args:
        tokens: Token IDs in generation order
        n: Length of each n-gram

    Returns:
        1 - (distinct n-grams / total n-grams); 0.0 when the sequence is
        shorter than ``n``
    """
    if len(tokens) < n:
        return 0.0

    windows = [
        tuple(tokens[start:start + n])
        for start in range(len(tokens) - n + 1)
    ]
    total = len(windows)
    if not total:
        return 0.0

    distinct = len(set(windows))
    return 1 - distinct / total


def compute_sequence_entropy(probs: torch.Tensor) -> float:
    """
    Compute the Shannon entropy (in nats) of probability distributions.

    Args:
        probs: Probability distribution tensor [..., vocab_size]; each slice
            along the last dim is one distribution

    Returns:
        Entropy averaged over all leading positions
    """
    # Work in at least float32: a clamp(min=1e-10) guard is a no-op in
    # float16 (1e-10 rounds to 0), which turned 0 * log(0) into NaN.
    probs = probs.to(torch.promote_types(probs.dtype, torch.float32))
    # xlogy(p, p) = p * log(p) with the convention 0 * log(0) = 0, so no
    # epsilon is needed and zero-probability entries contribute exactly 0.
    entropy = -torch.xlogy(probs, probs).sum(dim=-1)
    return entropy.mean().item()


def compute_top_k_probability(logits: torch.Tensor, k: int = 1) -> float:
    """
    Compute cumulative probability of the top-k tokens.

    Args:
        logits: Logits tensor [..., vocab_size]
        k: Number of top tokens; values larger than vocab_size are clamped
            to vocab_size (the result is then 1.0 up to rounding)

    Returns:
        Mean, over all leading positions, of the summed softmax probability
        of the k most likely tokens
    """
    # Softmax in at least float32 so half-precision logits do not lose mass
    # to rounding across a large vocabulary.
    logits = logits.to(torch.promote_types(logits.dtype, torch.float32))
    probs = F.softmax(logits, dim=-1)
    # torch.topk raises if k exceeds the size of the dimension.
    k = min(k, probs.size(-1))
    top_k_probs, _ = torch.topk(probs, k, dim=-1)
    return top_k_probs.sum(dim=-1).mean().item()


def compute_perplexity(
    model,
    input_ids: torch.Tensor,
    device: str = "cuda",
) -> float:
    """
    Compute perplexity of ``model`` on ``input_ids`` with teacher forcing.

    Each position t's logits are scored against token t+1; the mean
    cross-entropy over every batch row and position is exponentiated.
    No padding / attention mask is applied, so every token in ``input_ids``
    counts toward the score.

    Note: puts ``model`` into eval mode and leaves it there.

    Args:
        model: Language model; calling it on ``input_ids`` must return an
            object whose ``.logits`` is [batch, seq_len, vocab_size]
        input_ids: Input token IDs [batch, seq_len]
        device: Device to move ``input_ids`` to

    Returns:
        Perplexity score

    Raises:
        ValueError: if ``input_ids`` has fewer than two positions; there is
            nothing to predict and the mean loss would be NaN.
    """
    if input_ids.size(-1) < 2:
        raise ValueError(
            f"need at least 2 tokens per sequence to compute perplexity, "
            f"got shape {tuple(input_ids.shape)}"
        )

    model.eval()
    input_ids = input_ids.to(device)

    with torch.no_grad():
        logits = model(input_ids).logits

        # Shift for next-token prediction: logits at t predict token t+1.
        shift_logits = logits[:, :-1, :]
        shift_labels = input_ids[:, 1:]

        # Upcast half-precision logits: log-softmax over a large vocabulary
        # loses precision in fp16, and exp() of the loss amplifies the error.
        shift_logits = shift_logits.to(
            torch.promote_types(shift_logits.dtype, torch.float32)
        )

        loss = F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
        )

        perplexity = torch.exp(loss).item()

    return perplexity


class DistributionAnalyzer:
    """Analyze the next-token prediction distributions of a single model."""

    def __init__(self, model, tokenizer, device: str = "cuda"):
        self.model = model
        self.tokenizer = tokenizer  # kept for callers; not used by the methods below
        self.device = device

    def get_logits(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Run the model on ``input_ids`` without gradients and return ``.logits``.

        Puts the model into eval mode (and leaves it there).
        """
        self.model.eval()
        input_ids = input_ids.to(self.device)

        with torch.no_grad():
            outputs = self.model(input_ids)
            return outputs.logits

    @staticmethod
    def _probs_at_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
        """Return softmax(logits / temperature) over the last dim, in at least float32."""
        # Half-precision softmax underflows many vocabulary entries to exactly
        # 0 and sums with low precision, which distorts entropy and top-k mass.
        logits = logits.to(torch.promote_types(logits.dtype, torch.float32))
        return F.softmax(logits / temperature, dim=-1)

    def analyze_distribution(
        self,
        input_ids: torch.Tensor,
        temperature: float = 1.0,
    ) -> Dict:
        """
        Analyze the prediction distribution at a given sampling temperature.

        Returns dict with:
        - entropy: Average entropy
        - top1_prob: Average probability of top-1 token
        - top3_prob: Average cumulative probability of top-3
        - top5_prob: Average cumulative probability of top-5
        - top10_prob: Average cumulative probability of top-10
        (k is clamped to the vocabulary size for tiny vocabularies.)

        Raises:
            ValueError: if ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        probs = self._probs_at_temperature(self.get_logits(input_ids), temperature)

        # One descending top-k pass; the running sum at index k-1 is the
        # probability mass of the top-k tokens.
        vocab_size = probs.size(-1)
        top_mass = torch.topk(probs, min(10, vocab_size), dim=-1).values.cumsum(dim=-1)

        def mean_top_k(k: int) -> float:
            return top_mass[..., min(k, vocab_size) - 1].mean().item()

        return {
            "entropy": compute_sequence_entropy(probs),
            "top1_prob": mean_top_k(1),
            "top3_prob": mean_top_k(3),
            "top5_prob": mean_top_k(5),
            "top10_prob": mean_top_k(10),
        }

    def find_matching_temperature(
        self,
        input_ids: torch.Tensor,
        target_entropy: float,
        tolerance: float = 0.01,
        max_iterations: int = 50,
        low: float = 0.01,
        high: float = 5.0,
    ) -> float:
        """
        Find temperature that produces target entropy.

        Uses binary search on [low, high] to find temperature T such that
        entropy(softmax(logits/T)) ≈ target_entropy. The search assumes
        entropy grows with T (higher temperature = flatter distribution).

        Args:
            input_ids: Model input
            target_entropy: Desired average entropy
            tolerance: Accept T once |entropy - target| < tolerance
            max_iterations: Upper bound on bisection steps
            low, high: Initial temperature bracket (0 < low < high)

        Returns:
            The matching temperature, or the midpoint of the final bracket
            if the tolerance was not met (a warning is logged in that case,
            e.g. when the target lies outside the bracket's entropy range).

        Raises:
            ValueError: if the bracket is not 0 < low < high.
        """
        if not 0 < low < high:
            raise ValueError(f"require 0 < low < high, got low={low}, high={high}")

        logits = self.get_logits(input_ids)
        # Upcast once so every iteration's softmax runs in at least float32.
        logits = logits.to(torch.promote_types(logits.dtype, torch.float32))

        for _ in range(max_iterations):
            mid = (low + high) / 2
            current_entropy = compute_sequence_entropy(
                self._probs_at_temperature(logits, mid)
            )

            if abs(current_entropy - target_entropy) < tolerance:
                return mid

            if current_entropy > target_entropy:
                high = mid  # Need lower temperature (sharper)
            else:
                low = mid   # Need higher temperature (flatter)

        logger.warning(
            "find_matching_temperature: target entropy %.4f not reached within "
            "tolerance %.4g after %d iterations (final bracket [%.4f, %.4f])",
            target_entropy, tolerance, max_iterations, low, high,
        )
        return (low + high) / 2


class RankAnalyzer:
    """
    Analyze token rankings between baseline and hyperfitted models.

    Both models are run on the same ``input_ids`` and their per-position
    vocabulary rankings (descending by logit) are compared.
    """

    def __init__(
        self,
        original_model,
        hyperfitted_model,
        tokenizer,
        device: str = "cuda",
    ):
        self.original_model = original_model
        self.hyperfitted_model = hyperfitted_model
        self.tokenizer = tokenizer  # used to decode promoted token ids
        self.device = device

    def get_rankings(self, model, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Get token rankings for each position.

        Puts ``model`` into eval mode (and leaves it there).

        Returns:
            Tensor shaped like the model's logits where ``rankings[..., r]``
            is the token id with the r-th highest logit (r = 0 is the top-1).
        """
        model.eval()
        input_ids = input_ids.to(self.device)

        with torch.no_grad():
            logits = model(input_ids).logits
            return torch.argsort(logits, dim=-1, descending=True)

    @staticmethod
    def _ranks_of(rankings: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
        """
        Return the 0-based rank of each id in ``token_ids`` under ``rankings``.

        ``token_ids`` must have the same number of dims as ``rankings`` and
        matching leading dims; the lookup is along the last dim.
        """
        # Each row of ``rankings`` is a permutation of the vocabulary, so its
        # argsort is the inverse permutation: inverse[..., tok] == rank of tok.
        # This replaces per-position Python list.index / np.where scans.
        inverse = torch.argsort(rankings, dim=-1)
        return torch.gather(inverse, -1, token_ids)

    def compare_top1_predictions(self, input_ids: torch.Tensor) -> Dict:
        """
        Compare top-1 predictions between original and hyperfitted model.

        Returns:
            Dict with
            - top1_agreement: fraction of positions where both models' top-1 match
            - hyper_top1_in_orig_top{5,10,50,100}: fraction of positions where the
              hyperfitted top-1 token is within the original model's top-k
            - total_positions: number of (batch, position) pairs compared
            All fractions are 0.0 when there are no positions.
        """
        orig_rankings = self.get_rankings(self.original_model, input_ids)
        hyper_rankings = self.get_rankings(self.hyperfitted_model, input_ids)

        orig_top1 = orig_rankings[..., 0]
        hyper_top1 = hyper_rankings[..., 0]
        total = orig_top1.numel()

        results = {
            "top1_agreement": 0.0,
            "hyper_top1_in_orig_top5": 0.0,
            "hyper_top1_in_orig_top10": 0.0,
            "hyper_top1_in_orig_top50": 0.0,
            "hyper_top1_in_orig_top100": 0.0,
            "total_positions": total,
        }
        if total == 0:
            # Nothing to compare; avoid dividing by zero.
            return results

        hyper_top1_orig_rank = self._ranks_of(
            orig_rankings, hyper_top1.unsqueeze(-1)
        ).squeeze(-1)

        # Count as integers and divide in Python so each value is the exact
        # count/total ratio rather than a float32 mean.
        results["top1_agreement"] = int((orig_top1 == hyper_top1).sum().item()) / total
        for k in (5, 10, 50, 100):
            hits = int((hyper_top1_orig_rank < k).sum().item())
            results[f"hyper_top1_in_orig_top{k}"] = hits / total

        return results

    def compute_rank_correlation(
        self,
        input_ids: torch.Tensor,
        top_k: int = 100,
    ) -> Dict:
        """
        Compute Spearman rank correlation for the original model's top-k tokens.

        For each position, take the original model's ``top_k`` highest-ranked
        tokens (fewer if the vocabulary is smaller) and correlate their
        original ranks 0..k-1 with the ranks the hyperfitted model gives them.

        Returns:
            Dict with mean_rank_correlation and std_rank_correlation over the
            positions with a defined correlation (0.0 if none), and
            num_positions, the number of positions that contributed.
        """
        orig_rankings = self.get_rankings(self.original_model, input_ids)
        hyper_rankings = self.get_rankings(self.hyperfitted_model, input_ids)

        orig_top = orig_rankings[..., :top_k]
        # spearmanr needs equal-length inputs, so pair against the number of
        # tokens actually taken (smaller than top_k for small vocabularies).
        k_eff = orig_top.shape[-1]

        correlations = []
        if k_eff > 1:  # a correlation needs at least two points
            hyper_ranks = (
                self._ranks_of(hyper_rankings, orig_top)
                .cpu()
                .numpy()
                .reshape(-1, k_eff)
            )
            reference = np.arange(k_eff)
            for ranks in hyper_ranks:
                corr, _ = spearmanr(reference, ranks)
                if not np.isnan(corr):
                    correlations.append(corr)

        return {
            "mean_rank_correlation": float(np.mean(correlations)) if correlations else 0.0,
            "std_rank_correlation": float(np.std(correlations)) if correlations else 0.0,
            "num_positions": len(correlations),
        }

    def find_promoted_tokens(
        self,
        input_ids: torch.Tensor,
        threshold: int = 50,
        top_n: int = 10,
    ) -> List[Dict]:
        """
        Find tokens that are significantly promoted in the hyperfitted model.

        A token at a position is reported when it is in the hyperfitted
        model's top-``top_n`` but NOT in the original model's
        top-``threshold``, i.e. its 0-based original rank is >= ``threshold``.

        Args:
            input_ids: Model input [batch, seq_len]
            threshold: Size of the original model's top-k that disqualifies a token
            top_n: Size of the hyperfitted model's top-k to search (default 10)

        Returns:
            List of dicts with token_id, token_str, original_rank,
            hyperfitted_rank, rank_improvement (original_rank -
            hyperfitted_rank), position and batch, sorted by rank_improvement
            (largest first). Ranks are 0-based.

        Raises:
            ValueError: if ``threshold`` is negative.
        """
        if threshold < 0:
            raise ValueError(f"threshold must be non-negative, got {threshold}")

        orig_rankings = self.get_rankings(self.original_model, input_ids)
        hyper_rankings = self.get_rankings(self.hyperfitted_model, input_ids)

        # Column j of hyper_top is the token the hyperfitted model ranks j-th.
        hyper_top = hyper_rankings[..., :top_n]
        orig_rank = self._ranks_of(orig_rankings, hyper_top)

        batch_idx, pos_idx, hyper_rank = torch.nonzero(
            orig_rank >= threshold, as_tuple=True
        )
        token_ids = hyper_top[batch_idx, pos_idx, hyper_rank].tolist()
        orig_ranks = orig_rank[batch_idx, pos_idx, hyper_rank].tolist()

        promoted_tokens = [
            {
                "token_id": token_id,
                "token_str": self.tokenizer.decode([token_id]),
                "original_rank": o_rank,
                "hyperfitted_rank": h_rank,
                "rank_improvement": o_rank - h_rank,
                "position": pos,
                "batch": b,
            }
            for b, pos, h_rank, token_id, o_rank in zip(
                batch_idx.tolist(),
                pos_idx.tolist(),
                hyper_rank.tolist(),
                token_ids,
                orig_ranks,
            )
        ]

        # Sort by rank improvement
        promoted_tokens.sort(key=lambda x: x["rank_improvement"], reverse=True)

        return promoted_tokens


class GenerationMetrics:
    """Static helpers that score a generated token sequence."""

    @staticmethod
    def compute_all_metrics(
        generated_tokens: List[int],
        ttr_window: int = 96,
    ) -> Dict:
        """
        Compute all generation metrics

        Args:
            generated_tokens: List of generated token IDs
            ttr_window: Window size for TTR calculation

        Returns:
            Dictionary with length, ttr (over the last ``ttr_window`` tokens),
            bigram/trigram/4gram repetition rates, unique_tokens and
            unique_ratio (unique / total over the whole sequence; 0 if empty)
        """
        metrics = {
            "length": len(generated_tokens),
            "ttr": compute_ttr(generated_tokens, window_size=ttr_window),
            "bigram_repetition": compute_ngram_repetition(generated_tokens, n=2),
            "trigram_repetition": compute_ngram_repetition(generated_tokens, n=3),
            "4gram_repetition": compute_ngram_repetition(generated_tokens, n=4),
            "unique_tokens": len(set(generated_tokens)),
            "unique_ratio": len(set(generated_tokens)) / len(generated_tokens) if generated_tokens else 0,
        }

        return metrics

    @staticmethod
    def _matching_ngrams(generated, training_tokens, n: int) -> set:
        """Return the length-``n`` n-grams of ``generated`` that occur in any training sequence."""
        wanted = {tuple(generated[i:i + n]) for i in range(len(generated) - n + 1)}
        found = set()
        if not wanted:
            return found
        for seq in training_tokens:
            for i in range(len(seq) - n + 1):
                gram = tuple(seq[i:i + n])
                if gram in wanted:
                    found.add(gram)
            if len(found) == len(wanted):
                break  # every generated n-gram already matched
        return found

    @staticmethod
    def compute_dataset_overlap(
        generated_tokens: List[int],
        training_tokens: List[List[int]],
        min_n: int = 3,
        max_n: int = 19,
    ) -> Dict:
        """
        Compute overlap between generated text and training data

        Only n-grams of the generated sequence are stored, so memory stays
        proportional to the generated length rather than the training corpus.

        Args:
            generated_tokens: Generated token IDs
            training_tokens: Training sequences of token IDs
            min_n: Shortest n-gram length that counts as overlap (default 3)
            max_n: Longest n-gram length searched (default 19)

        Returns:
        - max_overlap_length: Length of the longest generated n-gram, with
          min_n <= n <= max_n, that also occurs in a training sequence
          (0 if none)
        - overlap_rate: Fraction of generated tokens covered by at least one
          such overlapping n-gram (0.0 for empty input)

        Raises:
            ValueError: if ``min_n`` < 1.
        """
        if min_n < 1:
            raise ValueError(f"min_n must be >= 1, got {min_n}")

        generated = list(generated_tokens)
        covered = [False] * len(generated)
        max_overlap = 0

        # Matches are downward-closed: any matching n-gram contains matching
        # shorter n-grams. Searching upward and stopping at the first length
        # with no match therefore finds the same maximum as a top-down scan.
        for n in range(min_n, max_n + 1):
            matched = GenerationMetrics._matching_ngrams(generated, training_tokens, n)
            if not matched:
                break
            max_overlap = n
            if n == min_n:
                # Every longer match is covered by its min_n-length pieces, so
                # the shortest length alone determines token coverage.
                for i in range(len(generated) - n + 1):
                    if tuple(generated[i:i + n]) in matched:
                        covered[i:i + n] = [True] * n

        overlap_rate = sum(covered) / len(generated) if generated else 0.0

        return {
            "max_overlap_length": max_overlap,
            "overlap_rate": overlap_rate,
        }
