from dataclasses import dataclass, field
from typing import Optional
import numpy as np
from scipy.stats import mannwhitneyu, ttest_ind
from sklearn.metrics import roc_auc_score
import warnings

from .collector import FeatureActivations


@dataclass
class FeatureStats:
    feature_index: int
    
    mean_reasoning: float
    mean_nonreasoning: float
    std_reasoning: float
    std_nonreasoning: float
    
    cohens_d: float
    log_fold_change: float
    
    mannwhitney_u: float
    mannwhitney_p: float
    ttest_t: float
    ttest_p: float
    
    roc_auc: float
    freq_active_reasoning: float
    freq_active_nonreasoning: float
    
    reasoning_score: float = field(default=0.0)
    
    _score_weights: dict = field(default_factory=lambda: {
        "auc": 0.3,
        "effect": 0.25,
        "pvalue": 0.25,
        "freq": 0.2,
    }, repr=False)
    
    def __post_init__(self):
        self._compute_reasoning_score()
    
    def _compute_reasoning_score(self):
        direction = 1 if self.mean_reasoning > self.mean_nonreasoning else -1
        auc_contrib = abs(self.roc_auc - 0.5) * 2
        effect_contrib = min(abs(self.cohens_d), 3.0) / 3.0
        p_contrib = min(-np.log10(self.mannwhitney_p + 1e-300), 50) / 50
        freq_ratio = (self.freq_active_reasoning + 0.01) / (self.freq_active_nonreasoning + 0.01)
        freq_contrib = min(np.log2(freq_ratio + 1) / 5, 1.0) if freq_ratio > 1 else 0
        
        weights = self._score_weights
        self.reasoning_score = direction * (
            weights["auc"] * auc_contrib +
            weights["effect"] * effect_contrib +
            weights["pvalue"] * p_contrib +
            weights["freq"] * freq_contrib
        )
    
    def set_score_weights(
        self,
        auc_weight: float = 0.3,
        effect_weight: float = 0.25,
        pvalue_weight: float = 0.25,
        freq_weight: float = 0.2,
    ):
        self._score_weights = {
            "auc": auc_weight,
            "effect": effect_weight,
            "pvalue": pvalue_weight,
            "freq": freq_weight,
        }
        self._compute_reasoning_score()
    
    def is_reasoning_feature(
        self,
        min_auc: float = 0.6,
        max_p_value: float = 0.01,
        min_effect_size: float = 0.3,
    ) -> bool:
        return (
            self.roc_auc >= min_auc and
            self.mannwhitney_p <= max_p_value and
            abs(self.cohens_d) >= min_effect_size and
            self.mean_reasoning > self.mean_nonreasoning
        )
    
    def to_dict(self) -> dict:
        return {
            "feature_index": self.feature_index,
            "mean_reasoning": self.mean_reasoning,
            "mean_nonreasoning": self.mean_nonreasoning,
            "std_reasoning": self.std_reasoning,
            "std_nonreasoning": self.std_nonreasoning,
            "cohens_d": self.cohens_d,
            "log_fold_change": self.log_fold_change,
            "mannwhitney_u": self.mannwhitney_u,
            "mannwhitney_p": self.mannwhitney_p,
            "ttest_t": self.ttest_t,
            "ttest_p": self.ttest_p,
            "roc_auc": self.roc_auc,
            "freq_active_reasoning": self.freq_active_reasoning,
            "freq_active_nonreasoning": self.freq_active_nonreasoning,
            "reasoning_score": self.reasoning_score,
        }


class ReasoningFeatureDetector:
    """Compute per-feature statistics separating reasoning from non-reasoning samples.

    Activations are first aggregated per sample ("max", "mean" or "sum"), then
    each feature column is compared between the two groups given by
    ``activations.get_reasoning_mask()``.

    NOTE(review): the aggregated activations and the mask are presumably CPU
    torch tensors (the code relies on ``.numpy()``, ``.item()`` and
    ``sum(dim=1)``) -- confirm against ``FeatureActivations``.
    """

    def __init__(
        self,
        activations: FeatureActivations,
        aggregation: str = "max",
        score_weights: Optional[dict] = None,
    ):
        """Aggregate activations and cache the group mask and group sizes.

        Args:
            activations: Source of the activations and the reasoning mask.
            aggregation: One of ``"max"``, ``"mean"`` or ``"sum"``.
            score_weights: Weights passed to every ``FeatureStats``; when
                falsy the default weighting is used.

        Raises:
            ValueError: If ``aggregation`` is not a supported mode.
        """
        self.activations = activations
        self.aggregation = aggregation
        self.score_weights = score_weights or {
            "auc": 0.3,
            "effect": 0.25,
            "pvalue": 0.25,
            "freq": 0.2,
        }

        if aggregation == "max":
            self.agg_acts = activations.get_max_activations()
        elif aggregation == "mean":
            self.agg_acts = activations.get_mean_activations()
        elif aggregation == "sum":
            self.agg_acts = activations.activations.sum(dim=1)
        else:
            raise ValueError(f"Unknown aggregation: {aggregation}")

        self.reasoning_mask = activations.get_reasoning_mask()
        self.n_reasoning = self.reasoning_mask.sum().item()
        self.n_nonreasoning = (~self.reasoning_mask).sum().item()

        # Cache of the full (all-features) statistics, filled lazily.
        self._feature_stats: Optional[list[FeatureStats]] = None
        # True once the cached statistics carry Bonferroni-corrected p-values.
        self._bonferroni_applied = False

    def compute_feature_stats(self, feature_idx: int) -> FeatureStats:
        """Compute all statistics for a single feature column.

        Undefined results are replaced by neutral values: statistic ``0.0``
        and p-value ``1.0`` for the two tests, ``0.5`` for the ROC AUC, and
        ``0.0`` for Cohen's d when there are too few samples to pool a
        variance.
        """
        # Convert the mask once and reuse it for splitting and as AUC labels.
        mask = self.reasoning_mask.numpy()
        acts = self.agg_acts[:, feature_idx].numpy()
        reasoning_acts = acts[mask]
        nonreasoning_acts = acts[~mask]

        mean_r = float(np.mean(reasoning_acts))
        mean_nr = float(np.mean(nonreasoning_acts))
        std_r = float(np.std(reasoning_acts))
        std_nr = float(np.std(nonreasoning_acts))

        # Pooled standard deviation for Cohen's d. np.std is the population
        # std (ddof=0), so n * std**2 is the group's sum of squared
        # deviations -- the numerator the pooled-variance formula needs.
        dof = self.n_reasoning + self.n_nonreasoning - 2
        if dof > 0:
            pooled_std = np.sqrt(
                (self.n_reasoning * std_r**2 + self.n_nonreasoning * std_nr**2) / dof
            )
            cohens_d = float((mean_r - mean_nr) / (pooled_std + 1e-10))
        else:
            # Not enough samples to estimate a variance: no effect evidence.
            cohens_d = 0.0

        log_fc = float(np.log2((mean_r + 1e-10) / (mean_nr + 1e-10)))

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                u_stat, u_pval = mannwhitneyu(
                    reasoning_acts, nonreasoning_acts, alternative="two-sided"
                )
            except ValueError:
                u_stat, u_pval = 0.0, 1.0
            try:
                t_stat, t_pval = ttest_ind(
                    reasoning_acts, nonreasoning_acts, equal_var=False
                )
            except ValueError:
                t_stat, t_pval = 0.0, 1.0

        try:
            roc_auc = float(roc_auc_score(mask.astype(int), acts))
        except ValueError:
            roc_auc = 0.5

        # A sample counts as "active" above 1% of the feature's maximum.
        threshold = 0.01 * max(acts.max(), 1e-10)
        freq_r = (reasoning_acts > threshold).mean()
        freq_nr = (nonreasoning_acts > threshold).mean()

        return FeatureStats(
            feature_index=feature_idx,
            mean_reasoning=mean_r,
            mean_nonreasoning=mean_nr,
            std_reasoning=std_r,
            std_nonreasoning=std_nr,
            cohens_d=cohens_d,
            log_fold_change=log_fc,
            # NaN test results are mapped to the same neutral values for both
            # tests so they cannot leak into the reasoning score.
            mannwhitney_u=float(u_stat) if not np.isnan(u_stat) else 0.0,
            mannwhitney_p=float(u_pval) if not np.isnan(u_pval) else 1.0,
            ttest_t=float(t_stat) if not np.isnan(t_stat) else 0.0,
            ttest_p=float(t_pval) if not np.isnan(t_pval) else 1.0,
            roc_auc=roc_auc,
            freq_active_reasoning=float(freq_r),
            freq_active_nonreasoning=float(freq_nr),
            _score_weights=self.score_weights,
        )

    def compute_all_stats(
        self,
        verbose: bool = True,
        feature_indices: Optional[list[int]] = None,
    ) -> list[FeatureStats]:
        """Compute statistics for all features, or for ``feature_indices`` only.

        The all-features result is cached and returned on later calls; results
        for an explicit ``feature_indices`` are recomputed on every call and
        never cached. With ``verbose`` a tqdm progress bar is shown.
        """
        if feature_indices is None:
            if self._feature_stats is not None:
                return self._feature_stats
            indices = range(self.activations.n_features)
        else:
            indices = feature_indices

        iterator = indices
        if verbose:
            import tqdm
            iterator = tqdm.tqdm(indices, desc="Computing feature statistics")

        stats = [self.compute_feature_stats(i) for i in iterator]

        if feature_indices is None:
            self._feature_stats = stats
            # A freshly built cache holds uncorrected p-values.
            self._bonferroni_applied = False
        return stats

    @staticmethod
    def _select_reasoning_features(
        all_stats: list,
        min_auc: float,
        max_p_value: float,
        min_effect_size: float,
        top_k: Optional[int] = None,
    ) -> list:
        """Filter stats by the thresholds, sort by score descending, truncate to top_k."""
        selected = [
            s for s in all_stats
            if s.is_reasoning_feature(min_auc, max_p_value, min_effect_size)
        ]
        selected.sort(key=lambda s: s.reasoning_score, reverse=True)
        if top_k is not None:
            selected = selected[:top_k]
        return selected

    def get_reasoning_features(
        self,
        min_auc: float = 0.6,
        max_p_value: float = 0.01,
        min_effect_size: float = 0.3,
        top_k: Optional[int] = None,
        feature_indices: Optional[list[int]] = None,
    ) -> list[FeatureStats]:
        """Return features passing all thresholds, best ``reasoning_score`` first.

        Args:
            min_auc: Minimum ROC AUC.
            max_p_value: Maximum Mann-Whitney p-value.
            min_effect_size: Minimum absolute Cohen's d.
            top_k: Keep only the first ``top_k`` results when given.
            feature_indices: Restrict the analysis to these features.
        """
        all_stats = self.compute_all_stats(feature_indices=feature_indices)
        return self._select_reasoning_features(
            all_stats, min_auc, max_p_value, min_effect_size, top_k
        )

    def get_top_features_by_score(
        self,
        top_k: int = 100,
        feature_indices: Optional[list[int]] = None,
    ) -> list[FeatureStats]:
        """Return the ``top_k`` features by ``reasoning_score``, without threshold filtering."""
        all_stats = self.compute_all_stats(feature_indices=feature_indices)
        sorted_stats = sorted(all_stats, key=lambda x: x.reasoning_score, reverse=True)
        return sorted_stats[:top_k]

    def apply_bonferroni_correction(
        self,
        feature_indices: Optional[list[int]] = None,
    ) -> list[FeatureStats]:
        """Bonferroni-correct the Mann-Whitney and t-test p-values in place.

        Each p-value is multiplied by the number of features tested and capped
        at 1.0, and ``reasoning_score`` is recomputed so it reflects the
        corrected p-value. For the cached all-features statistics the
        correction is applied only once, so repeated calls do not compound it.

        Returns:
            The corrected statistics (the cached list when ``feature_indices``
            is None, otherwise a freshly computed list).
        """
        all_stats = self.compute_all_stats(feature_indices=feature_indices)
        uses_cache = feature_indices is None
        if uses_cache and self._bonferroni_applied:
            return all_stats

        n_tests = len(all_stats)
        for stat in all_stats:
            stat.mannwhitney_p = min(stat.mannwhitney_p * n_tests, 1.0)
            stat.ttest_p = min(stat.ttest_p * n_tests, 1.0)
            # The score depends on mannwhitney_p; keep it in sync.
            stat._compute_reasoning_score()

        if uses_cache:
            self._bonferroni_applied = True
        return all_stats

    def summary(
        self,
        feature_indices: Optional[list[int]] = None,
    ) -> dict:
        """Return headline counts and averages using the default thresholds."""
        # Compute once and filter locally: results for explicit
        # feature_indices are not cached, so a second call would redo the work.
        all_stats = self.compute_all_stats(feature_indices=feature_indices)
        reasoning_features = self._select_reasoning_features(
            all_stats, min_auc=0.6, max_p_value=0.01, min_effect_size=0.3
        )

        return {
            "total_features": len(all_stats),
            "reasoning_features_count": len(reasoning_features),
            "percentage_reasoning": len(reasoning_features) / len(all_stats) * 100 if all_stats else 0,
            "top_10_features": [s.feature_index for s in reasoning_features[:10]],
            "top_10_scores": [s.reasoning_score for s in reasoning_features[:10]],
            "mean_auc_reasoning_features": np.mean([s.roc_auc for s in reasoning_features]) if reasoning_features else 0,
            "mean_cohens_d_reasoning_features": np.mean([s.cohens_d for s in reasoning_features]) if reasoning_features else 0,
        }
