import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from .explainer_utils import create_edge_embeds, sample_graph

# Shared plot styling for the score-distribution histograms in this module.
# NOTE(review): despite its name, ORANGE holds pure red (#FF0000) — confirm
# whether that is intentional before renaming or changing the value.
ORANGE = "#FF0000"  # colour of the 'GT=1 (Important)' histograms
BLUE = "#0072B2"  # colour of the 'GT=0 (Unimportant)' histograms
GREY = "#4D4D4D"  # not referenced by DistributionAnalyzer; may be used elsewhere
LABEL_FONTSIZE = 36  # axis labels and subplot titles
TICK_FONTSIZE = 30  # major tick labels
LEGEND_FONTSIZE = 20  # legend entries


class DistributionAnalyzer:
    """Static utilities for analysing edge-mask score distributions.

    Offers a histogram-based KL-divergence estimate and a side-by-side
    histogram figure that compares the scores of a base explainer with those
    of a later-round explainer, split by ground-truth edge label.
    """

    @staticmethod
    def calculate_kl_divergence(p_scores, q_scores, bins=50, eps=1e-8):
        """Estimate KL(P || Q) from two score samples using shared histogram bins.

        Both samples are binned over the same range (the min/max of their
        union). Empty bins are replaced by ``eps`` so the logarithm stays
        finite; the probabilities are not renormalised afterwards, so the
        result is a smoothed estimate rather than an exact divergence.

        Args:
            p_scores: Array-like of samples from the reference distribution P.
            q_scores: Array-like of samples from the comparison distribution Q.
            bins: Number of equal-width histogram bins.
            eps: Probability substituted for empty bins.

        Returns:
            The estimated divergence as a Python ``float``. ``inf`` when
            either sample is empty; ``0.0`` when every value in both samples
            is identical.

        Raises:
            ValueError: If the samples contain NaN or infinite values.
        """
        p_values = np.asarray(p_scores).ravel()
        q_values = np.asarray(q_scores).ravel()

        if p_values.size == 0 or q_values.size == 0:
            return float('inf')

        pooled = np.concatenate([p_values, q_values])
        lo, hi = pooled.min(), pooled.max()

        if not (np.isfinite(lo) and np.isfinite(hi)):
            raise ValueError("scores must be finite to build a shared histogram range")

        # Both samples are point masses at the same value: the distributions
        # coincide, and a zero-width range cannot be binned meaningfully.
        if lo == hi:
            return 0.0

        # Use raw counts instead of density * bin_width: the result is the same
        # per-bin probability without the floating-point round trip.
        p_counts, _ = np.histogram(p_values, bins=bins, range=(lo, hi))
        q_counts, _ = np.histogram(q_values, bins=bins, range=(lo, hi))

        p_prob = p_counts / p_values.size
        q_prob = q_counts / q_values.size

        # Avoid log(0) and division by zero for empty bins.
        p_prob = np.where(p_prob == 0, eps, p_prob)
        q_prob = np.where(q_prob == 0, eps, q_prob)

        return float(np.sum(p_prob * np.log(p_prob / q_prob)))

    @staticmethod
    def _concat_chunks(chunks):
        """Join per-batch 1-D score arrays into one array (empty if no chunks)."""
        if not chunks:
            return np.empty(0)
        return np.concatenate(chunks)

    @staticmethod
    def _plot_score_panel(ax, positive_scores, negative_scores, title, bins):
        """Draw the GT=1 / GT=0 score histograms and styling on a single axes."""
        series = (
            (positive_scores, ORANGE, 'GT=1 (Important)'),
            (negative_scores, BLUE, 'GT=0 (Unimportant)'),
        )
        for scores, color, label in series:
            if len(scores) > 0:
                ax.hist(
                    scores,
                    bins=bins,
                    color=color,
                    alpha=0.5,
                    label=label,
                    density=True,
                    edgecolor='black',
                )

        ax.set_xlabel('Mask Score', fontsize=LABEL_FONTSIZE)
        ax.set_ylabel('Density', fontsize=LABEL_FONTSIZE)
        ax.set_title(title, fontsize=LABEL_FONTSIZE)
        ax.tick_params(axis='both', which='major', labelsize=TICK_FONTSIZE)

        # Only build a legend when something was drawn; an empty legend just
        # triggers a matplotlib warning.
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend(fontsize=LEGEND_FONTSIZE)
        ax.grid(True, alpha=0.3)

    @staticmethod
    def compare_base_vs_round1_distributions(
        test_loader, model, explainer, explainer_round1, device, dataset, explainer_name,
        phase_label: str = "Phase 2 (Guided Explainer)",
        save_suffix: str = "",
        phase_tag: str = "Phase 2",
        abl=None,
    ):
        """Plot base vs. later-round mask-score histograms split by ground truth.

        Runs every batch of ``test_loader`` through ``model`` and both
        explainers under ``torch.no_grad()``, gathers the per-edge scores for
        edges whose ``batch.edge_gt`` equals 1 and 0, and saves a two-panel
        figure to ``visualization/<dataset>/distribution_<save_suffix>.png``
        (or ``visualization_<abl>/<dataset>/...`` when ``abl`` is given).

        Side effects: ``model``, ``explainer`` and ``explainer_round1`` are
        switched to eval mode and left that way; the output directory is
        created if missing; the saved path is printed.

        Args:
            test_loader: Iterable of batches exposing ``to(device)``,
                ``edge_index`` and ``edge_gt``.
            model: Callable returning a 3-tuple whose last element is used as
                the node embeddings.
            explainer: Base explainer, called on the edge embeddings.
            explainer_round1: Later-round explainer, called on the batch itself.
            device: Device the batches are moved to.
            dataset: Dataset name used in the output directory.
            explainer_name: Currently unused; kept for interface compatibility.
            phase_label: Currently unused; kept for interface compatibility.
            save_suffix: Suffix of the output file name. The right panel is
                titled 'Round 1' when this is ``"first"`` and 'Round 2'
                otherwise.
            phase_tag: Currently unused; kept for interface compatibility.
            abl: Optional ablation tag appended to the output directory name.

        Returns:
            None.
        """
        base_pos_chunks, base_neg_chunks = [], []
        round_pos_chunks, round_neg_chunks = [], []

        explainer.eval()
        explainer_round1.eval()
        model.eval()

        with torch.no_grad():
            for batch in test_loader:
                batch = batch.to(device)

                _, _, node_embeds = model(batch)

                edge_embeds = create_edge_embeds(batch.edge_index, node_embeds).unsqueeze(dim=0)
                sampling_weights = explainer(edge_embeds).squeeze(-1)
                base_mask_scores = sample_graph(sampling_weights, device, training=False)

                # Flatten with reshape(-1) rather than squeeze(): squeeze()
                # collapses a single-edge batch to 0-d, which breaks the
                # boolean indexing below.
                # NOTE(review): assumes all three hold one value per edge in
                # the same order — confirm against sample_graph / the explainer.
                base_scores = base_mask_scores.detach().cpu().numpy().reshape(-1)
                round1_scores = explainer_round1(batch).detach().cpu().numpy().reshape(-1)
                gt_labels = batch.edge_gt.detach().cpu().numpy().reshape(-1)

                positive_mask = gt_labels == 1
                negative_mask = gt_labels == 0

                if positive_mask.any():
                    base_pos_chunks.append(base_scores[positive_mask])
                    round_pos_chunks.append(round1_scores[positive_mask])
                if negative_mask.any():
                    base_neg_chunks.append(base_scores[negative_mask])
                    round_neg_chunks.append(round1_scores[negative_mask])

        concat = DistributionAnalyzer._concat_chunks
        base_positive_scores = concat(base_pos_chunks)
        base_negative_scores = concat(base_neg_chunks)
        round1_positive_scores = concat(round_pos_chunks)
        round1_negative_scores = concat(round_neg_chunks)

        if abl is None:
            viz_dir = f"visualization/{dataset}"
        else:
            viz_dir = f"visualization_{abl}/{dataset}"
        os.makedirs(viz_dir, exist_ok=True)

        fig, (base_ax, round_ax) = plt.subplots(1, 2, figsize=(20, 8))
        # Shared bins over [0, 1] keep both panels directly comparable; values
        # outside that interval are not drawn.
        common_bins = np.linspace(0.0, 1.0, 51)

        DistributionAnalyzer._plot_score_panel(
            base_ax, base_positive_scores, base_negative_scores, 'Base', common_bins
        )
        round_title = 'Round 1' if save_suffix == "first" else 'Round 2'
        DistributionAnalyzer._plot_score_panel(
            round_ax, round1_positive_scores, round1_negative_scores, round_title, common_bins
        )

        fig.tight_layout()
        save_path_individual = f"{viz_dir}/distribution_{save_suffix}.png"
        fig.savefig(save_path_individual, dpi=300, bbox_inches='tight')
        # Close this specific figure so repeated calls do not leak figures.
        plt.close(fig)

        print("\nVisualization saved to:")
        print(f"  - {save_path_individual}")
