"""
Implements the logic for RQ1: Analyzing compliance with Benford's Law.
"""
import torch
import numpy as np
import logging
import os
from scipy.stats import chisquare
from typing import Dict, List, Any
from datasets import load_dataset
from transformers import AutoTokenizer

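# Optional: uncomment and edit the line below to point the Hugging Face cache
# (HF_HOME) at a custom directory.
# NOTE(review): the Hugging Face libraries typically read HF_HOME when they are
# imported, so this may need to be set before the imports above -- confirm.
# os.environ["HF_HOME"] = "/change/if/needed"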

# Module-level logger named after this module. No handlers or levels are
# configured here.
logger = logging.getLogger(__name__)

class BenfordAnalyzer:
    """
    Analyzes the distribution of first significant digits in model parameters
    and activations to test for compliance with Benford's Law.

    Configuration keys read by this class:

    * ``config["analysis"]["parameter_types"]``: container of strings; the
      values ``"weights"`` and ``"activations"`` enable the respective analysis.
    * ``config["analysis"]["layers"]`` (optional, default ``"all"``): either
      ``"all"`` or a collection of indices into ``model.modules()`` selecting
      the leaf modules whose outputs are hooked.
    * ``config["analysis"]["calibration_data"]``: mapping with ``dataset``,
      ``subset``, ``num_samples`` (default 1) and ``text_column``
      (default ``"text"``).
    * ``config["model"]["name"]`` and ``config["model"]["trust_remote_code"]``
      (default ``False``): used to load the tokenizer.
    """

    # Number of elements converted to the working float dtype at a time in
    # ``get_first_significant_digit``; bounds the temporary memory needed for
    # very large tensors.
    _CHUNK_ELEMENTS = 1 << 22

    def __init__(self, model: torch.nn.Module, config: Dict[str, Any]):
        """
        Initializes the Benford Analyzer.

        Args:
            model (torch.nn.Module): The Hugging Face model to analyze.
            config (Dict[str, Any]): The analysis configuration dictionary.
        """
        self.model = model
        self.config = config
        # Benford's Law: P(d) = log10(1 + 1/d) for d = 1..9.
        self.benford_dist = np.array([np.log10(1 + 1/d) for d in range(1, 10)])

    def get_first_significant_digit(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Extracts the first significant digit from each element in a tensor.

        The computation is carried out in float64 (float32 on MPS devices,
        which do not support float64) regardless of the input dtype, so that
        low-precision inputs such as bfloat16/float16 are classified by their
        exact stored value rather than by a low-precision logarithm.

        Args:
            tensor (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: An int8 tensor of the same shape (and on the same
                          device) containing the first significant digits
                          (1-9). Zeros are returned for input values that are
                          zero, NaN or infinite.

        Note:
            Magnitudes so small that ``10 ** |exponent|`` overflows the working
            dtype (float64 subnormals) may be misclassified; the result is
            still clamped to 1-9.
        """
        work_dtype = torch.float32 if tensor.device.type == "mps" else torch.float64

        result = torch.zeros(tensor.shape, dtype=torch.int8, device=tensor.device)
        flat_values = tensor.detach().reshape(-1)
        # ``result`` is freshly allocated and contiguous, so this is a view and
        # writes below land in ``result``.
        flat_result = result.view(-1)

        # Process in chunks so the float64 temporaries stay small even for
        # tensors with hundreds of millions of elements.
        for start in range(0, flat_values.numel(), self._CHUNK_ELEMENTS):
            stop = start + self._CHUNK_ELEMENTS
            flat_result[start:stop] = self._leading_digits(
                flat_values[start:stop], work_dtype
            )
        return result

    @staticmethod
    def _leading_digits(values: torch.Tensor, work_dtype: torch.dtype) -> torch.Tensor:
        """Returns the int8 leading digit (0 for zero/NaN/Inf) of each element of a 1-D tensor."""
        magnitudes = values.to(work_dtype).abs()
        valid = torch.isfinite(magnitudes) & (magnitudes > 0)
        # Replace unusable entries by 1 so log10 stays finite; they are zeroed
        # again at the end.
        safe = torch.where(valid, magnitudes, torch.ones_like(magnitudes))

        # floor(log10(x)) is the decimal exponent; scaling by it maps x to [1, 10).
        exponents = torch.floor(torch.log10(safe))
        # Always raise 10 to a non-negative power: small positive powers of ten
        # are exactly representable, negative ones are not.
        scale = torch.pow(10.0, exponents.abs())
        mantissas = torch.where(exponents < 0, safe * scale, safe / scale)

        # log10 may be off by one ulp right at a power of ten, leaving the
        # mantissa marginally outside [1, 10); shift it back by one decade.
        mantissas = torch.where(mantissas >= 10.0, mantissas / 10.0, mantissas)
        mantissas = torch.where(mantissas < 1.0, mantissas * 10.0, mantissas)

        digits = torch.floor(mantissas).clamp_(1, 9).to(torch.int8)
        return torch.where(valid, digits, torch.zeros_like(digits))

    def analyze(self) -> Dict[str, Any]:
        """
        Runs the full analysis pipeline on the model.

        Analyzes weights and, if configured, activations. The weight analysis
        is multi-level, providing both aggregated and per-layer statistics.

        Returns:
            Dict[str, Any]: A nested dictionary containing analysis statistics.
                ``"weights"`` maps to ``{"per_layer_analysis": ...,
                "aggregated_analysis": ...}`` and ``"activations"`` maps
                directly to one statistics dictionary; each key is present
                only if the corresponding type is enabled in the config.
        """
        parameter_types = self.config["analysis"]["parameter_types"]
        results = {}

        if "weights" in parameter_types:
            logger.info("Analyzing model weights...")
            all_param_digits = self._get_all_parameter_digits()
            results["weights"] = {
                "per_layer_analysis": self._analyze_per_layer(all_param_digits),
                "aggregated_analysis": self._analyze_aggregated(all_param_digits),
            }

        if "activations" in parameter_types:
            logger.info("Analyzing model activations...")
            results["activations"] = self._analyze_activations()

        return results

    def _get_all_parameter_digits(self) -> Dict[str, torch.Tensor]:
        """
        Extracts first significant digits for all model parameters.

        Returns:
            Dict[str, torch.Tensor]: Parameter name -> 1-D int8 tensor of the
                digits (1-9) of its usable (non-zero, finite) entries.
                Zero-dimensional parameters are skipped.
        """
        param_digits = {}
        for name, param in self.model.named_parameters():
            if param.dim() < 1:  # Skip scalars
                continue
            logger.debug("Extracting digits from layer: %s, shape: %s", name, param.shape)
            first_digits = self.get_first_significant_digit(param.detach().cpu())
            # Digit 0 marks entries without a significant digit; drop them.
            param_digits[name] = first_digits[first_digits != 0].flatten()
        return param_digits

    def _analyze_per_layer(self, all_param_digits: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """
        Analyzes the weights of each layer individually.

        Args:
            all_param_digits: Parameter name -> tensor of first digits.

        Returns:
            Dict[str, Any]: Parameter name -> statistics dictionary. Parameters
                without any digits are omitted (a warning is logged).
        """
        logger.info("Running per-layer analysis...")
        per_layer_stats = {}
        for name, digits_tensor in all_param_digits.items():
            if digits_tensor.numel() == 0:
                logger.warning(f"Layer {name} has no non-zero values to analyze.")
                continue
            per_layer_stats[name] = self._compute_distribution_stats(digits_tensor.numpy())
        return per_layer_stats

    def _analyze_aggregated(self, all_param_digits: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """
        Performs an aggregated analysis across all model weights.

        Per-parameter digit counts are summed instead of concatenating all
        digit tensors, which yields the same statistics without duplicating
        the data in memory.

        Args:
            all_param_digits: Parameter name -> tensor of first digits.

        Returns:
            Dict[str, Any]: Statistics over all parameters, or ``{}`` if no
                parameter contributed any value.
        """
        logger.info("Running aggregated analysis...")
        counts = np.zeros(9, dtype=np.int64)
        found_values = False
        for digits_tensor in all_param_digits.values():
            if digits_tensor.numel() == 0:
                continue
            found_values = True
            counts += self._digit_counts(digits_tensor.numpy())

        if not found_values:
            logger.warning("No parameters found to analyze.")
            return {}
        return self._stats_from_counts(counts)

    def _input_device(self) -> torch.device:
        """Returns the device calibration inputs are moved to (that of the model's first parameter)."""
        first_param = next(self.model.parameters(), None)
        if first_param is None or first_param.device.type == "meta":
            # No usable parameter to ask; fall back to the best available device.
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return first_param.device

    def _analyze_activations(self) -> Dict[str, Any]:
        """
        Analyzes activation tensors using forward hooks and calibration data.

        Forward hooks are registered on leaf modules (modules without
        children), the calibration texts are fed through the model one at a
        time, and the first digits of every hooked output tensor are counted.
        Only the nine running counters are kept, not the activations
        themselves. The model is switched to eval mode, and the hooks are
        always removed again, even if the forward pass fails.

        Returns:
            Dict[str, Any]: Statistics over all recorded activations.

        Raises:
            ValueError: If no calibration dataset is configured.
        """
        logger.info("Running activation analysis.")

        analysis_cfg = self.config["analysis"]
        calib_cfg = analysis_cfg.get("calibration_data", {})
        dataset_name = calib_cfg.get("dataset")
        if not dataset_name:
            raise ValueError(
                "Activation analysis requires config['analysis']['calibration_data']['dataset']."
            )
        subset = calib_cfg.get("subset")
        num_samples = calib_cfg.get("num_samples", 1)
        text_column = calib_cfg.get("text_column", "text")

        # Pass cache_dir=... to load_dataset here if a custom cache location is needed.
        dataset = load_dataset(dataset_name, subset, split=f"train[:{num_samples}]")
        texts = list(dataset[text_column])

        tokenizer = AutoTokenizer.from_pretrained(
            self.config["model"]["name"],
            trust_remote_code=self.config["model"].get("trust_remote_code", False),
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        digit_counts = np.zeros(9, dtype=np.int64)

        def hook(_module, _inputs, output):
            # A module may return one tensor or a list/tuple; only direct
            # tensor members are inspected.
            outputs = output if isinstance(output, (list, tuple)) else (output,)
            for item in outputs:
                if isinstance(item, torch.Tensor):
                    digits = self.get_first_significant_digit(item.detach().cpu())
                    np.add(digit_counts, self._digit_counts(digits.numpy()), out=digit_counts)

        target_layers = analysis_cfg.get("layers", "all")
        select_all = isinstance(target_layers, str) and target_layers == "all"
        wanted_indices = None if select_all else set(target_layers)

        device = self._input_device()
        self.model.eval()
        batch_size = 1
        handles = []
        try:
            for idx, module in enumerate(self.model.modules()):
                is_leaf = next(module.children(), None) is None
                if is_leaf and (select_all or idx in wanted_indices):
                    handles.append(module.register_forward_hook(hook))

            for start in range(0, len(texts), batch_size):
                batch_texts = texts[start:start + batch_size]
                inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True)

                # Skip empty inputs
                if inputs["input_ids"].shape[0] == 0 or inputs["input_ids"].shape[1] == 0:
                    continue

                inputs = inputs.to(device)
                with torch.no_grad():
                    self.model(**inputs)

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
        finally:
            # Never leave hooks behind on the caller's model.
            for handle in handles:
                handle.remove()

        if digit_counts.sum() == 0:
            logger.warning("No non-zero activations to analyze.")
        return self._stats_from_counts(digit_counts)

    @staticmethod
    def _digit_counts(digits: np.ndarray) -> np.ndarray:
        """Counts occurrences of the digits 1-9 in an array; other values are ignored."""
        flat = np.asarray(digits).reshape(-1)
        flat = flat[(flat >= 1) & (flat <= 9)]
        # bincount only accepts integer input; the cast also covers float arrays.
        return np.bincount(flat.astype(np.int64), minlength=10)[1:10]

    def _stats_from_counts(self, counts: np.ndarray) -> Dict[str, Any]:
        """
        Builds the statistics dictionary from per-digit counts.

        Args:
            counts: Array of length 9; ``counts[i]`` is the number of
                occurrences of digit ``i + 1``.

        Returns:
            Dict[str, Any]: See :meth:`_compute_distribution_stats`.
        """
        counts = np.asarray(counts, dtype=np.int64)
        total = int(counts.sum())
        benford = {d: float(p) for d, p in zip(range(1, 10), self.benford_dist)}

        if total == 0:
            return {
                "observed_dist": {},
                "benford_dist": benford,
                "chi2_stat": None,
                "p_value": None,
                "total_count": 0,
            }

        observed_dist = counts / total
        chi2_stat, p_value = chisquare(f_obs=counts, f_exp=self.benford_dist * total)
        abs_deviation = np.abs(observed_dist - self.benford_dist)

        return {
            "observed_dist": {d: float(p) for d, p in zip(range(1, 10), observed_dist)},
            "benford_dist": benford,
            "mean_absolute_deviation": float(np.mean(abs_deviation)),
            "std_absolute_deviation": float(np.std(abs_deviation)),
            "max_absolute_deviation": float(np.max(abs_deviation)),
            "min_absolute_deviation": float(np.min(abs_deviation)),
            "digit_deviation_sum": float(np.sum(abs_deviation)),
            "chi2_stat": float(chi2_stat),
            "p_value": float(p_value),
            "total_count": total,
        }

    def _compute_distribution_stats(self, digits: np.ndarray) -> Dict[str, Any]:
        """
        Computes the distribution of digits and the Chi-squared test statistic. Also computes Mean Absolute Deviation (MAD).

        After experiments note: Chi-squared was not a good choice. It does not work well in this scenario. We don't use it in the paper.

        Args:
            digits: Array of first digits. Values outside 1-9 are ignored;
                non-integer values are truncated towards zero.

        Returns:
            Dict[str, Any]: ``observed_dist`` and ``benford_dist`` (digit ->
                probability), the mean/std/max/min/sum of the absolute
                per-digit deviations, ``chi2_stat``, ``p_value`` and
                ``total_count``. If no usable digit is present,
                ``observed_dist`` is empty, ``chi2_stat``/``p_value`` are
                ``None``, ``total_count`` is 0 and the deviation keys are
                absent.
        """
        return self._stats_from_counts(self._digit_counts(digits))

    def plot_results(self, stats: Dict[str, Any], output_dir: str):
        """
        Generates and saves bar plots comparing observed vs. Benford distributions.

        Entries without an observed distribution (e.g. an analysis that
        recorded no values) are skipped instead of raising.

        Args:
            stats (Dict[str, Any]): Nested analysis statistics produced by
                :meth:`analyze`.
            output_dir (str): Directory where plots will be saved. It is
                created if it does not exist.
        """
        # Imported lazily so matplotlib is only required when plotting.
        import matplotlib.pyplot as plt

        os.makedirs(output_dir, exist_ok=True)

        digits = np.arange(1, 10)
        default_expected = {d: float(p) for d, p in zip(range(1, 10), self.benford_dist)}

        def _value(dist: Dict[Any, float], digit: int) -> float:
            # Statistics reloaded from JSON carry string keys ("1".."9").
            return dist.get(digit, dist.get(str(digit), 0.0))

        def _plot(entry: Any, title: str, filename: str) -> bool:
            """Plots one statistics entry; returns False if it has nothing to plot."""
            observed = entry.get("observed_dist") if isinstance(entry, dict) else None
            if not observed:
                return False
            expected = entry.get("benford_dist") or default_expected

            obs_vals = [_value(observed, int(d)) for d in digits]
            exp_vals = [_value(expected, int(d)) for d in digits]

            width = 0.4
            fig, ax = plt.subplots()
            ax.bar(digits - width / 2, obs_vals, width, label="Observed")
            ax.bar(digits + width / 2, exp_vals, width, label="Benford")
            ax.set_xticks(digits)
            ax.set_xlabel("Digit")
            ax.set_ylabel("Probability")
            ax.set_title(title)
            ax.legend()
            fig.tight_layout()
            fig.savefig(os.path.join(output_dir, filename))
            plt.close(fig)
            return True

        for param_type, param_stats in stats.items():
            label = param_type.capitalize()

            # Direct stats (e.g., activations)
            if "observed_dist" in param_stats:
                if not _plot(param_stats, label, f"{param_type}.png"):
                    logger.warning("No digits recorded for %s; skipping plot.", param_type)
                continue

            # Nested stats (e.g., weights)
            for stat_name, stat_value in param_stats.items():
                if stat_name == "per_layer_analysis":
                    for layer_name, layer_stats in stat_value.items():
                        safe_name = layer_name.replace(".", "_").replace("/", "_")
                        _plot(
                            layer_stats,
                            f"{label} - {layer_name}",
                            f"{param_type}_{safe_name}.png",
                        )
                else:
                    _plot(
                        stat_value,
                        f"{label} - {stat_name.replace('_', ' ').title()}",
                        f"{param_type}_{stat_name}.png",
                    )
