import os
import numpy as np
import matplotlib.pyplot as plt


class ChargeEvaluator:
    """Evaluates material samples by analyzing their charge distributions."""

    # Histogram colours, cycled through in dataset order.
    _COLORS = ('blue', 'red', 'green', 'purple', 'orange',
               'brown', 'pink', 'gray', 'olive', 'cyan')

    def __init__(self, save_dir='cache/charge_eval'):
        """
        Initialize the charge evaluator.

        The output directory is created immediately if it does not exist.

        Args:
            save_dir (str): Directory to save evaluation results.
        """
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)

    def extract_charge_values(self, samples):
        """
        Extract charge values from samples.

        Each sample's ``get_charge()`` is called exactly once; samples for
        which it returns ``None`` are skipped.

        Args:
            samples (iterable): Objects exposing a ``get_charge()`` method.

        Returns:
            np.ndarray: Array of the non-``None`` charge values, in input
                order (empty when no sample reports a charge).
        """
        reported = (sample.get_charge() for sample in samples)
        return np.array([charge for charge in reported if charge is not None])

    @staticmethod
    def _grid_shape(n_datasets):
        """Return the smallest near-square ``(rows, cols)`` grid with at least ``n_datasets`` cells."""
        # Always keep at least one cell so an empty mapping still yields a
        # valid (blank) figure instead of an invalid 0x0 grid.
        n_cells = max(int(n_datasets), 1)
        cols = int(np.ceil(np.sqrt(n_cells)))
        rows = -(-n_cells // cols)  # ceiling division
        return rows, cols

    @staticmethod
    def _draw_dataset(ax, dataset_name, charges, color, bins, alpha, density,
                      balance_tol):
        """Draw one dataset's histogram (or a 'no data' placeholder) on ``ax``."""
        if len(charges) == 0:
            ax.text(0.5, 0.5, f'{dataset_name}\nNo charge data',
                    transform=ax.transAxes, ha='center', va='center',
                    fontsize=12)
            ax.set_title(dataset_name)
            return

        ax.hist(charges, bins=bins, alpha=alpha, color=color,
                density=density, edgecolor='black', linewidth=0.5)

        # Summary statistics shown in the subplot title.
        mean_charge = np.mean(charges)
        std_charge = np.std(charges)
        # With the default balance_tol of 0.0 this is an exact ``== 0`` test.
        balanced_count = np.count_nonzero(np.abs(charges) <= balance_tol)
        balanced_proportion = balanced_count / len(charges) * 100

        # Reference line marking zero net charge.
        ax.axvline(0, color='red', linestyle='--', linewidth=2)

        ax.set_title(f'{dataset_name}\n(μ={mean_charge:.3f}, σ={std_charge:.3f}, {balanced_proportion:.2f}% balanced)')
        ax.set_xlabel('Charge')
        ax.set_ylabel('Density' if density else 'Count')
        ax.grid(True, alpha=0.3)

    def plot_charge_distributions(self, datasets_data, figsize=(15, 10),
                                  bins=50, alpha=0.7, density=True,
                                  balance_tol=0.0):
        """
        Plot charge distributions for multiple datasets in subplots.

        One subplot is drawn per dataset on a near-square grid that grows
        with the number of datasets, so no dataset is ever dropped. The
        figure is written to ``charge_distributions.png`` and
        ``charge_distributions.pdf`` inside ``self.save_dir`` and is closed
        afterwards, even if plotting or saving fails.

        Args:
            datasets_data (dict): Dictionary mapping dataset names to list of Sample objects.
            figsize (tuple): Figure size (width, height).
            bins (int): Number of histogram bins.
            alpha (float): Transparency of histograms.
            density (bool): Whether to normalize histograms to show density.
            balance_tol (float): A charge counts as "balanced" when its
                absolute value is at most this tolerance. The default of 0.0
                requires the charge to be exactly zero.
        """
        n_datasets = len(datasets_data)
        rows, cols = self._grid_shape(n_datasets)

        # squeeze=False always returns a 2-D array of axes, so a single
        # ravel() covers the 1x1, single-row and full-grid cases alike.
        fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
        try:
            axes = axes.ravel()

            for i, (dataset_name, samples) in enumerate(datasets_data.items()):
                self._draw_dataset(
                    axes[i],
                    dataset_name,
                    self.extract_charge_values(samples),
                    self._COLORS[i % len(self._COLORS)],
                    bins,
                    alpha,
                    density,
                    balance_tol,
                )

            # Hide grid cells that have no dataset assigned.
            for unused_ax in axes[n_datasets:]:
                unused_ax.set_visible(False)

            fig.tight_layout()

            # Save through the figure handle rather than pyplot's implicit
            # "current figure", which other code may have changed.
            fig.savefig(os.path.join(self.save_dir, "charge_distributions.png"), dpi=300, bbox_inches='tight')
            fig.savefig(os.path.join(self.save_dir, "charge_distributions.pdf"), bbox_inches='tight')
        finally:
            # Release the figure even when drawing or saving raised.
            plt.close(fig)

        print(f"Charge distribution plots saved to: {self.save_dir}")

    def evaluate_datasets(self, datasets_data, **plot_kwargs):
        """
        Evaluate charge distributions for multiple datasets.

        Currently this only forwards to ``plot_charge_distributions``.

        Args:
            datasets_data (dict): Dictionary mapping dataset names to list of Sample objects.
            **plot_kwargs: Additional keyword arguments for plotting.
        """
        self.plot_charge_distributions(datasets_data, **plot_kwargs)
