"""
Implements the logic for analyzing the distribution of values in model
weights and activations.
"""
import torch
import numpy as np
import logging
from scipy.stats import describe
from typing import Dict, List, Any
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer

# Module-level logger named after this module. No handlers or levels are
# configured here.
logger = logging.getLogger(__name__)

class ValueDistributionAnalyzer:
    """
    Analyzes the distribution of values in model parameters and activations.

    Weights are summarised per parameter tensor and in aggregate; activations
    are captured with forward hooks while calibration text is run through the
    model. Plot helpers write histograms and first-significant-digit
    (Benford) comparisons to disk.
    """

    # Number of elements converted to float64 at a time. Statistics and digit
    # counts are accumulated chunk by chunk so that a large low-precision
    # array never has to be duplicated in float64 as a whole.
    _CHUNK_SIZE = 1 << 20

    # Powers of ten 10**0 .. 10**308 as correctly rounded float64 values,
    # used to scale magnitudes into [1, 10) without calling a pow routine.
    _POW10 = np.array([float(10 ** k) for k in range(309)], dtype=np.float64)

    def __init__(self, model: torch.nn.Module, config: Dict[str, Any]):
        """
        Initializes the Value Distribution Analyzer.

        Args:
            model (torch.nn.Module): The Hugging Face model to analyze.
            config (Dict[str, Any]): The analysis configuration dictionary.
        """
        self.model = model
        self.config = config

    def analyze(self) -> Dict[str, Any]:
        """
        Runs the full analysis pipeline on the model.

        Analyzes weights and, if configured, activations. The analysis is
        multi-level, providing both aggregated and per-layer statistics.

        Returns:
            Dict[str, Any]: A nested dictionary containing analysis statistics.
            It has a "weights" entry (with "per_layer_analysis" and
            "aggregated_analysis") and/or an "activations" entry, depending on
            ``config["analysis"]["parameter_types"]``.
        """
        parameter_types = self.config["analysis"]["parameter_types"]
        results = {}

        if "weights" in parameter_types:
            logger.info("Analyzing model weights...")
            all_param_values = self._get_all_parameter_values()

            per_layer_analysis = self._analyze_per_layer(all_param_values)
            aggregated_analysis = self._analyze_aggregated(all_param_values)

            results["weights"] = {
                "per_layer_analysis": per_layer_analysis,
                "aggregated_analysis": aggregated_analysis,
            }

        if "activations" in parameter_types:
            logger.info("Analyzing model activations...")
            results["activations"] = self._analyze_activations()

        return results

    def _get_all_parameter_values(self) -> Dict[str, torch.Tensor]:
        """
        Extracts values for all model parameters.

        Returns:
            Dict[str, torch.Tensor]: Parameter name -> flattened float16 CPU
            copy of the parameter. Zero-dimensional parameters are skipped.
        """
        param_values = {}
        for name, param in self.model.named_parameters():
            if param.dim() < 1:  # Skip scalars
                continue
            logger.debug("Extracting values from layer: %s, shape: %s", name, param.shape)
            # NOTE(review): the float16 cast halves host memory, but values
            # outside the float16 range become inf or zero -- confirm this
            # trade-off is intended for float32/bfloat16 checkpoints.
            param_values[name] = param.detach().to(torch.float16).cpu().flatten()
        return param_values

    def _analyze_per_layer(self, all_param_values: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """
        Analyzes the weights of each layer individually.

        Args:
            all_param_values: Parameter name -> flattened CPU tensor.

        Returns:
            Dict[str, Any]: Parameter name -> statistics dictionary. Empty
            tensors are skipped with a warning.
        """
        logger.info("Running per-layer analysis...")
        per_layer_stats = {}
        for name, values_tensor in tqdm(all_param_values.items()):
            if values_tensor.numel() == 0:
                logger.warning(f"Layer {name} has no values to analyze.")
                continue
            per_layer_stats[name] = self._compute_distribution_stats(values_tensor.numpy())
        return per_layer_stats

    def _analyze_aggregated(self, all_param_values: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """
        Performs an aggregated analysis across all model weights.

        Args:
            all_param_values: Parameter name -> flattened CPU tensor.

        Returns:
            Dict[str, Any]: Statistics over the concatenation of all
            non-empty tensors, or an empty dict when there are none.
        """
        logger.info("Running aggregated analysis...")
        non_empty = [t for t in all_param_values.values() if t.numel() > 0]
        if not non_empty:
            logger.warning("No parameters found to analyze.")
            return {}
        return self._compute_distribution_stats(torch.cat(non_empty).numpy())

    def _analyze_activations(self) -> Dict[str, Any]:
        """
        Analyzes activation tensors using forward hooks and calibration data.

        Calibration text is loaded and tokenized first; hooks are then
        registered on leaf modules (all of them, or those whose index in
        ``model.modules()`` is listed in ``config["analysis"]["layers"]``),
        the samples are run through the model one at a time, and the hooks
        are always removed afterwards. The model is left in eval mode.

        Returns:
            Dict[str, Any]: Statistics over every captured floating-point
            activation value.

        Raises:
            ValueError: If no calibration dataset name is configured.
        """
        logger.info("Running activation analysis.")

        analysis_cfg = self.config["analysis"]
        calib_cfg = analysis_cfg.get("calibration_data", {})
        dataset_name = calib_cfg.get("dataset")
        subset = calib_cfg.get("subset")
        num_samples = calib_cfg.get("num_samples", 1)
        text_column = calib_cfg.get("text_column", "text")

        if not dataset_name:
            raise ValueError(
                "config['analysis']['calibration_data']['dataset'] must be set "
                "to analyze activations."
            )

        dataset = load_dataset(dataset_name, subset, split=f"train[:{num_samples}]")
        texts = list(dataset[text_column])

        tokenizer = AutoTokenizer.from_pretrained(
            self.config["model"]["name"],
            trust_remote_code=self.config["model"].get("trust_remote_code", False),
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Hugging Face models expose ``.device``; a plain nn.Module does not,
        # so fall back to the device of its first parameter.
        device = getattr(self.model, "device", None)
        if device is None:
            first_param = next(self.model.parameters(), None)
            device = first_param.device if first_param is not None else torch.device("cpu")

        activation_values = []

        def _capture(tensor):
            # Only floating-point tensors are recorded. They are stored as
            # float32 so bfloat16/float16/float32 outputs can be concatenated
            # and converted to NumPy (bfloat16 has no NumPy equivalent).
            if isinstance(tensor, torch.Tensor) and tensor.is_floating_point():
                activation_values.append(tensor.detach().to(torch.float32).cpu().flatten())

        def hook(_, __, output):
            if isinstance(output, (list, tuple)):
                for out in output:
                    _capture(out)
            else:
                _capture(output)

        target_layers = analysis_cfg.get("layers", "all")
        handles = []
        try:
            for idx, module in enumerate(self.model.modules()):
                is_leaf = next(module.children(), None) is None
                if is_leaf and (target_layers == "all" or idx in target_layers):
                    handles.append(module.register_forward_hook(hook))

            self.model.eval()
            batch_size = 1
            for i in range(0, len(texts), batch_size):
                batch_texts = texts[i:i + batch_size]
                inputs = tokenizer(
                    batch_texts, return_tensors="pt", padding=True, truncation=True
                ).to(device)
                if inputs["input_ids"].shape[0] == 0 or inputs["input_ids"].shape[1] == 0:
                    continue
                with torch.no_grad():
                    self.model(**inputs)

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
        finally:
            # Remove the hooks even if the forward pass fails, so the model
            # is not left with capture hooks attached.
            for handle in handles:
                handle.remove()

        if not activation_values:
            logger.warning("No activations to analyze.")
            return self._compute_distribution_stats(np.array([]))

        aggregated_values = torch.cat(activation_values).numpy()
        return self._compute_distribution_stats(aggregated_values)

    def _compute_distribution_stats(self, values: np.ndarray) -> Dict[str, Any]:
        """
        Computes descriptive statistics for the given values.

        The statistics follow the defaults of ``scipy.stats.describe``
        (variance with ddof=1, biased skewness, Fisher kurtosis) but are
        accumulated in float64 over fixed-size chunks. Low-precision inputs
        such as float16 therefore do not limit the accuracy of the result,
        and no full-size float64 copy of the input is made.

        Args:
            values: Array of values; it is flattened before analysis.

        Returns:
            Dict[str, Any]: A dictionary of plain Python values with the keys
            "nobs", "minmax", "mean", "variance", "std", "skewness" and
            "kurtosis". For empty input every statistic is None. Statistics
            that are undefined for the data (for example the variance of a
            single value) are NaN.
        """
        flat = np.asarray(values).ravel()
        nobs = int(flat.size)
        if nobs == 0:
            return {"nobs": 0, "minmax": (None, None), "mean": None, "variance": None, "std":None, "skewness": None, "kurtosis": None}

        chunk = self._CHUNK_SIZE
        # Undefined results are reported as NaN/inf values rather than as
        # NumPy runtime warnings.
        with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
            # First pass: mean with a float64 accumulator.
            total = np.float64(0.0)
            for start in range(0, nobs, chunk):
                total = total + flat[start:start + chunk].sum(dtype=np.float64)
            mean = total / np.float64(nobs)

            # Second pass: sums of the 2nd, 3rd and 4th central powers.
            m2 = np.float64(0.0)
            m3 = np.float64(0.0)
            m4 = np.float64(0.0)
            for start in range(0, nobs, chunk):
                dev = flat[start:start + chunk].astype(np.float64) - mean
                sq = dev * dev
                m2 = m2 + sq.sum()
                m3 = m3 + (sq * dev).sum()
                m4 = m4 + (sq * sq).sum()

            variance = m2 / np.float64(nobs - 1)
            pop_var = m2 / np.float64(nobs)
            skewness = (m3 / np.float64(nobs)) / pop_var ** 1.5
            kurtosis = (m4 / np.float64(nobs)) / (pop_var * pop_var) - 3.0
            std = np.sqrt(variance)

        # JSON serializable format
        return {
            "nobs": nobs,
            "minmax": (float(flat.min()), float(flat.max())),
            "mean": float(mean),
            "variance": float(variance),
            "std": float(std),
            "skewness": float(skewness),
            "kurtosis": float(kurtosis),
        }

    @classmethod
    def _first_digit_distribution(cls, values: np.ndarray) -> np.ndarray:
        """
        Computes the relative frequency of first significant digits 1-9.

        Zeros and non-finite values (inf, NaN) are ignored. The work is
        vectorised and done in float64 over fixed-size chunks.

        Args:
            values: Array of values; it is flattened before analysis.

        Returns:
            np.ndarray: Nine frequencies for the digits 1..9 that sum to 1,
            or nine zeros when no usable value is present.
        """
        flat = np.asarray(values).ravel()
        pow10 = cls._POW10
        counts = np.zeros(9, dtype=np.int64)

        for start in range(0, flat.size, cls._CHUNK_SIZE):
            mags = np.abs(flat[start:start + cls._CHUNK_SIZE].astype(np.float64))
            mags = mags[np.isfinite(mags) & (mags > 0)]
            if mags.size == 0:
                continue

            exponent = np.floor(np.log10(mags)).astype(np.int64)
            # Scale into [1, 10) using table lookups. Magnitudes below 1 are
            # multiplied (in two steps when the exponent is below -308, so
            # the factor itself cannot overflow); others are divided.
            up_first = np.clip(-exponent, 0, 308)
            up_second = np.clip(-exponent - 308, 0, 308)
            down = np.clip(exponent, 0, 308)
            scaled = mags * pow10[up_first] * pow10[up_second] / pow10[down]

            # log10 can be off by one right at a power of ten; renormalise.
            scaled = np.where(scaled >= 10.0, scaled / 10.0, scaled)
            scaled = np.where(scaled < 1.0, scaled * 10.0, scaled)

            digits = np.clip(scaled.astype(np.int64), 1, 9)
            counts += np.bincount(digits, minlength=10)[1:10]

        total = counts.sum()
        return counts / total if total > 0 else np.zeros(9)

    def plot_results(self, stats: Dict[str, Any], output_dir: str):
        """
        Generates and saves histograms of the value distributions.

        Args:
            stats: Result of :meth:`analyze`; only the presence of the
                "weights" and "activations" keys is checked.
            output_dir: Directory for the PNG files (created if missing).
        """
        import os
        import matplotlib.pyplot as plt

        os.makedirs(output_dir, exist_ok=True)

        def _plot_histogram(values: np.ndarray, title: str, filename: str, log_scale: bool = False):
            # Non-finite values make the histogram range undefined; drop them.
            finite = np.isfinite(values)
            if not finite.all():
                values = values[finite]
            if values.size == 0:
                logger.warning(f"Skipping plot for '{title}' as there is no data.")
                return

            fig, ax = plt.subplots()
            ax.hist(values, bins=self.config["analysis"].get("histogram_bins", 100))
            ax.set_title(title)
            ax.set_xlabel("Value")
            ax.set_ylabel("Frequency")

            if log_scale:
                # Weights are centred on zero, so a logarithmic x-axis would
                # discard every non-positive value. A logarithmic frequency
                # axis keeps the whole distribution and exposes the tails.
                ax.set_yscale("log")

            fig.tight_layout()
            fig.savefig(os.path.join(output_dir, filename))
            plt.close(fig)

        logger.info("Generating plots...")

        if "weights" in stats:
            logger.info("Plotting weights distribution...")
            all_param_values = self._get_all_parameter_values()

            # Plot aggregated (torch.cat rejects an empty list, so guard it)
            if all_param_values:
                aggregated_values = torch.cat(list(all_param_values.values())).numpy()
                _plot_histogram(
                    aggregated_values,
                    "Aggregated Weights Distribution",
                    "weights_aggregated.png",
                    log_scale=True,
                )
            else:
                logger.warning("No parameters found to analyze.")

            # Plot per-layer
            for name, values in all_param_values.items():
                safe_name = name.replace(".", "_").replace("/", "_")
                _plot_histogram(values.numpy(), f"Weights - {name}", f"weights_{safe_name}.png")

        if "activations" in stats:
            logger.warning("Plotting activation distributions requires running the forward pass again and is currently not implemented in the plotting function.")

    def plot_benford_distributions(self, output_dir: str):
        """
        Plots graphs comparing first significant digit distribution with benford's law layer by layer.

        Two figures are written to ``output_dir``: one for parameters whose
        name contains "attn" or "mlp" (and not "norm"), and one for all
        remaining parameters. Each parameter is drawn as a faint gray line
        against the ideal Benford curve.

        Args:
            output_dir: Directory for the PNG files (created if missing).
        """
        import os
        import matplotlib.pyplot as plt

        os.makedirs(output_dir, exist_ok=True)
        logger.info("Starting Benford distribution plots...")

        # Benford probabilities
        benford_probs = [np.log10(1 + 1/d) for d in range(1, 10)]
        digits = np.arange(1, 10)

        all_param_values = self._get_all_parameter_values()

        # Layer groups
        attention_mlp_layers = {k: v for k, v in all_param_values.items()
                                if ("norm" not in k.lower()) and any(x in k.lower() for x in ["attn", "mlp"])}
        other_layers = {k: v for k, v in all_param_values.items()
                        if k not in attention_mlp_layers}

        def _plot_group(layers: Dict[str, torch.Tensor], line_alpha: float, title: str, filename: str):
            # One gray line per parameter tensor plus the Benford reference.
            fig, ax = plt.subplots()
            for values in layers.values():
                dist = self._first_digit_distribution(values.numpy())
                ax.plot(digits, dist, marker="o", color="gray", alpha=line_alpha)
            ax.plot(digits, benford_probs, marker="o", color="blue", label="Benford Ideal")
            ax.set_xlabel("First Digit")
            ax.set_ylabel("Probability")
            ax.set_title(title)
            ax.legend()
            ax.grid(True, linestyle="--", alpha=0.6)
            fig.tight_layout()
            fig.savefig(os.path.join(output_dir, filename))
            plt.close(fig)

        # --- Plot 1: Attention + MLP ---
        _plot_group(
            attention_mlp_layers,
            0.4,
            "Benford Distribution - Attention + MLP",
            "benford_attention_mlp.png",
        )

        # --- Plot 2: Other layers ---
        _plot_group(
            other_layers,
            0.1,
            "Benford Distribution - Other Layers",
            "benford_other_layers.png",
        )

        logger.info(f"Output path: {output_dir}")
