# filename: codebase/step5_visualization_interpretation.py
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import datetime
import tensorly as tl
from tensorly.decomposition import tensor_train
import networkx as nx
from sklearn.manifold import TSNE  # Using TSNE as UMAP might not be installed
                                  # If umap is available, can switch to: from umap import UMAP

# Required for unpickling data from previous steps
from collections import defaultdict
from torch_geometric.data import Data

# Pipeline artifacts read by this step, and the directory plots are written to
PROCESSED_MERGER_TREES_PATH = 'data/processed_merger_trees.pt'  # Step 1 output
FINAL_PROCESSED_DATA_PATH = 'data/final_processed_data.pt'       # Step 2 output
QITT_PROCESSED_DATA_PATH = 'data/qitt_processed_data.pt'         # Step 3 output
OUTPUT_DIR = 'data'

# Matplotlib defaults applied to every figure produced by this script
plt.rcParams.update({
    'text.usetex': False,
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'font.family': 'serif',  # a font family that is commonly available
})

# Running index embedded in saved plot filenames; incremented after each save
plot_counter = 1


def get_timestamp_str():
    """Return the current local time as ``YYYYMMDD_HHMMSS`` for use in filenames."""
    now = datetime.datetime.now()
    # datetime's format spec is passed to strftime, so this is the same pattern.
    return f"{now:%Y%m%d_%H%M%S}"


def _plot_path(prefix, timestamp):
    """Return the output path ``<OUTPUT_DIR>/<prefix>_<plot_counter>_<timestamp>.png``."""
    return os.path.join(OUTPUT_DIR, prefix + "_" + str(plot_counter) + "_" + str(timestamp) + ".png")


def _load_step_outputs():
    """Load the saved outputs of Steps 1-3 from disk onto the CPU.

    Returns:
        A ``(data_step1, data_step2, data_step3)`` tuple, or ``None`` if any
        file is missing or fails to load (the error is printed).
    """
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only point these paths at files produced by this pipeline.
    try:
        data_step1 = torch.load(PROCESSED_MERGER_TREES_PATH, map_location='cpu', weights_only=False)
        data_step2 = torch.load(FINAL_PROCESSED_DATA_PATH, map_location='cpu', weights_only=False)
        data_step3 = torch.load(QITT_PROCESSED_DATA_PATH, map_location='cpu', weights_only=False)
    except FileNotFoundError as e:
        print("Error: A required data file was not found. " + str(e))
        return None
    except Exception as e:
        print("Error loading data: " + str(e))
        import traceback
        traceback.print_exc()
        return None
    return data_step1, data_step2, data_step3


def _collect_topological_embeddings(train_trees, train_tensors, max_n_sub, num_physical_features):
    """Gather per-substructure embeddings and the value used to color them.

    For every tree that has a ``substructures`` attribute, the trailing columns
    (from ``num_physical_features`` onward) of each substructure's row in the
    matching Step 2 tensor are taken as its embedding. The color value is
    ``physical_features[9]`` (described upstream as the number of halos in the
    branch), or 0 when that entry is unavailable.

    NOTE(review): ``train_tensors[tree_idx]`` assumes Step 2 kept exactly one
    tensor per Step 1 tree, in the same order -- confirm against Step 2.

    Returns:
        ``(embeddings, colors)`` as two parallel Python lists.
    """
    embeddings = []
    colors = []
    for tree_idx, tree in enumerate(train_trees):
        if not hasattr(tree, 'substructures'):
            continue

        tree_tensor = train_tensors[tree_idx]  # rows correspond to substructures
        # Rows beyond max_n_sub were not kept in the Step 2 tensor.
        num_rows = min(len(tree.substructures), max_n_sub)

        for sub_idx in range(num_rows):
            sub = tree.substructures[sub_idx]
            embeddings.append(tree_tensor[sub_idx, num_physical_features:].numpy())

            if hasattr(sub, 'physical_features') and len(sub.physical_features) > 9:
                colors.append(sub.physical_features[9].item())
            else:  # Should not happen if data is consistent
                colors.append(0)
    return embeddings, colors


def _make_tsne(n_samples):
    """Build the 2-D t-SNE estimator, portable across scikit-learn versions.

    scikit-learn 1.5 renamed ``n_iter`` to ``max_iter`` and later removed the
    old name, so the keyword is chosen from the installed signature.
    """
    import inspect

    params = inspect.signature(TSNE).parameters
    iter_kwarg = "max_iter" if "max_iter" in params else "n_iter"
    return TSNE(
        n_components=2,
        random_state=42,
        perplexity=min(30, n_samples - 1),  # perplexity must stay below n_samples
        verbose=0,
        **{iter_kwarg: 300},
    )


def _plot_tsne_embeddings(embeddings, colors, timestamp):
    """Run t-SNE on the embeddings and save a scatter plot colored by ``colors``.

    At most 10000 points are embedded; larger inputs are subsampled with a
    fixed seed so that the figure is reproducible. Any failure is reported and
    swallowed so the remaining plots can still be produced.
    """
    global plot_counter
    from matplotlib.colors import LogNorm

    embeddings_np = np.array(embeddings)
    colors_np = np.array(colors)

    print("Total number of substructure embeddings extracted: " + str(len(embeddings_np)))

    # Subsample if too many points for t-SNE, for performance
    num_points_for_tsne = 10000
    if len(embeddings_np) > num_points_for_tsne:
        print("Subsampling " + str(num_points_for_tsne) + " points for t-SNE plot.")
        # Seeded to match the seeded t-SNE; an unseeded draw made runs differ.
        rng = np.random.default_rng(42)
        indices = rng.choice(len(embeddings_np), num_points_for_tsne, replace=False)
        embeddings_for_tsne = embeddings_np[indices]
        colors_for_tsne = colors_np[indices]
    else:
        embeddings_for_tsne = embeddings_np
        colors_for_tsne = colors_np

    print("Running t-SNE on " + str(len(embeddings_for_tsne)) + " embeddings...")
    fig = None
    try:
        # Constructed inside the try so an invalid argument cannot abort the script.
        tsne = _make_tsne(len(embeddings_for_tsne))
        embeddings_2d = tsne.fit_transform(embeddings_for_tsne)

        fig = plt.figure(figsize=(10, 8))

        # Use a log color scale only when it is well defined (vmin <= vmax);
        # otherwise fall back to matplotlib's default linear normalization.
        norm_val = None
        positive = colors_for_tsne[colors_for_tsne > 0]
        if positive.size > 0:
            vmin = max(1, np.min(positive))
            vmax = np.max(colors_for_tsne)
            if vmax >= vmin:
                norm_val = LogNorm(vmin=vmin, vmax=vmax)

        scatter = plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1],
                              c=colors_for_tsne, cmap='viridis', alpha=0.6, s=10, norm=norm_val)

        cbar = plt.colorbar(scatter)
        cbar_label = 'Number of Halos in Substructure Branch'
        if norm_val is not None:
            cbar_label += ' (log scale)'
        cbar.set_label(cbar_label)
        plt.title('t-SNE of Substructure Topological Embeddings')
        plt.xlabel('t-SNE Dimension 1')
        plt.ylabel('t-SNE Dimension 2')
        plt.grid(True, linestyle='--', alpha=0.5)
        plt.tight_layout()

        plot_filename = _plot_path("tsne_topological_embeddings", timestamp)
        plt.savefig(plot_filename)
        print("Saved plot: " + str(plot_filename))
        print("  Description: t-SNE visualization of " + str(len(embeddings_for_tsne)) + " substructure topological embeddings, colored by the number of halos in the substructure branch.")
        print("  Number of halos range: Min=" + str(np.min(colors_for_tsne)) + ", Max=" + str(np.max(colors_for_tsne)) + ", Median=" + str(np.median(colors_for_tsne)))
        plot_counter += 1
    except Exception as e:
        print("Error during t-SNE plotting: " + str(e))
        print("  This might be due to too few points or issues with perplexity.")
    finally:
        # Release the figure even when plotting failed part-way.
        if fig is not None:
            plt.close(fig)


def _plot_tt_core_magnitudes(train_tensors, reshaped_dims, tt_rank, timestamp):
    """Decompose the first training tensor with TT and save per-core histograms.

    ``train_tensors[0]`` is reshaped to ``reshaped_dims`` and decomposed with
    ``tensor_train`` at rank ``tt_rank``. Failures are reported and swallowed.
    """
    global plot_counter

    # len() rather than truthiness: a stacked multi-element tensor has no bool value.
    if len(train_tensors) == 0:
        print("No training tensors available to plot TT-core magnitudes.")
        return

    fig = None
    try:
        # Reshape (max_N_sub, D_feat_combined) -> reshaped_dims from Step 3.
        example_3d = train_tensors[0].numpy().reshape(reshaped_dims)

        print("Decomposing example tree tensor of shape: " + str(example_3d.shape) + " with ranks " + str(tt_rank))
        cores = tensor_train(example_3d, rank=tt_rank)

        fig, axes = plt.subplots(1, len(cores), figsize=(5 * len(cores), 4))
        axes = np.atleast_1d(axes)  # a single subplot is returned as a bare Axes

        for i, core in enumerate(cores):
            core_values = core.flatten()
            ax = axes[i]
            ax.hist(core_values, bins=50, color='skyblue', edgecolor='black')
            ax.set_title("TT-Core " + str(i) + " (Shape: " + str(core.shape) + ")")
            ax.set_xlabel("Magnitude")
            ax.set_ylabel("Frequency")
            ax.grid(True, linestyle='--', alpha=0.5)
            print("  Core " + str(i) + " stats: Shape=" + str(core.shape) + ", MinVal=" + str(np.min(core_values)) + ", MaxVal=" + str(np.max(core_values)) + ", MeanVal=" + str(np.mean(core_values)))

        fig.suptitle("Distribution of TT-Core Magnitudes for an Example Tree Tensor", fontsize=14)
        plt.tight_layout(rect=[0, 0, 1, 0.96])  # Adjust for suptitle

        plot_filename = _plot_path("tt_core_magnitudes", timestamp)
        plt.savefig(plot_filename)
        print("Saved plot: " + str(plot_filename))
        print("  Description: Histograms of magnitudes for each TT-core from the decomposition of the first training tree's feature tensor.")
        plot_counter += 1
    except Exception as e:
        print("Error during TT-Core magnitude plotting: " + str(e))
    finally:
        if fig is not None:
            plt.close(fig)


def _plot_example_substructures(train_trees, feature_means, feature_stds, timestamp, num_examples=2):
    """Draw the first ``num_examples`` substructures of tree 0 as network graphs.

    Nodes are colored by column 3 of ``x`` after undoing the normalization with
    ``feature_means[3]`` / ``feature_stds[3]`` (labelled "Scale Factor").
    Failures are reported and swallowed, as in the other plotting sections.
    """
    global plot_counter

    has_enough = (
        train_trees
        and hasattr(train_trees[0], 'substructures')
        and len(train_trees[0].substructures) >= num_examples
    )
    if not has_enough:
        print("Not enough substructures in the first training tree to plot examples, or no training trees loaded.")
        return

    fig_subs = None
    try:
        fig_subs, axes_subs = plt.subplots(1, num_examples, figsize=(8 * num_examples, 7))
        axes_subs = np.atleast_1d(axes_subs)  # a single subplot is returned as a bare Axes

        sf_mean = feature_means[3].item()
        sf_std = feature_stds[3].item()

        for i in range(num_examples):
            sub_data = train_trees[0].substructures[i]
            ax = axes_subs[i]

            if sub_data.num_nodes == 0:
                ax.text(0.5, 0.5, "Empty Substructure", ha='center', va='center')
                ax.set_title("Example Substructure " + str(i+1) + " (Empty)")
                continue

            # Unnormalize scale factor for coloring
            node_colors = sub_data.x[:, 3].numpy() * sf_std + sf_mean

            # Undirected graph is sufficient for visualization purposes
            G = nx.Graph()
            G.add_nodes_from(range(sub_data.num_nodes))
            if hasattr(sub_data, 'edge_index') and sub_data.edge_index is not None:
                G.add_edges_from(sub_data.edge_index.t().tolist())

            # k controls node spacing in the spring layout
            n_nodes = G.number_of_nodes()
            pos = nx.spring_layout(G, seed=42, k=0.5/np.sqrt(n_nodes) if n_nodes > 0 else 1)

            nx.draw_networkx_nodes(G, pos, ax=ax, node_color=node_colors, cmap=plt.cm.coolwarm, node_size=50)
            nx.draw_networkx_edges(G, pos, ax=ax, alpha=0.5)

            ax.set_title("Example Substructure " + str(i+1) + " (Nodes: " + str(sub_data.num_nodes) + ", Edges: " + str(G.number_of_edges()) + ")")
            ax.axis('off')

            # Add colorbar for scale factor
            sf_min = np.min(node_colors)
            sf_max = np.max(node_colors)
            sm = plt.cm.ScalarMappable(cmap=plt.cm.coolwarm, norm=plt.Normalize(vmin=sf_min, vmax=sf_max))
            sm.set_array([])
            cbar = fig_subs.colorbar(sm, ax=ax, orientation='vertical', fraction=0.046, pad=0.04)
            cbar.set_label('Scale Factor')

            print("  Substructure " + str(i+1) + " plotted: Nodes=" + str(sub_data.num_nodes) + ", Edges=" + str(G.number_of_edges()) + ", SF range: " + str(sf_min) + "-" + str(sf_max))

        fig_subs.suptitle("Visualization of Example Substructures from Tree 0", fontsize=16)
        plt.tight_layout(rect=[0, 0, 1, 0.95])

        plot_filename = _plot_path("example_substructure_graphs", timestamp)
        plt.savefig(plot_filename)
        print("Saved plot: " + str(plot_filename))
        print("  Description: Network graph visualizations of " + str(num_examples) + " example substructures from the first training tree, nodes colored by scale factor.")
        plot_counter += 1
    except Exception as e:
        print("Error during substructure graph plotting: " + str(e))
    finally:
        if fig_subs is not None:
            plt.close(fig_subs)


def main():
    """
    Main function for Step 5: Visualization and Interpretation.

    Loads the outputs of Steps 1-3 and writes up to three PNG figures into
    ``OUTPUT_DIR``: a t-SNE map of substructure embeddings, histograms of the
    TT-core values of the first training tensor, and network drawings of
    example substructures. Returns ``None``; returns early (after printing the
    error) if the input files cannot be loaded.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    current_timestamp = get_timestamp_str()
    tl.set_backend("numpy")

    # --- Load data from previous steps ---
    print("Loading data from previous steps...")
    loaded = _load_step_outputs()
    if loaded is None:
        return
    data_step1, data_step2, data_step3 = loaded

    # Read every required key up front so a missing key fails before any plotting.
    train_trees_step1 = data_step1['train_data']
    feature_means_step1 = data_step1['feature_means']
    feature_stds_step1 = data_step1['feature_stds']

    train_tensors_step2 = data_step2['train_tensors']
    gnn_embedding_dim = data_step2['gnn_config']['embedding_dim']
    max_n_sub_step2 = data_step2['max_N_sub']
    d_feat_combined_step2 = data_step2['D_feat_combined']
    # Physical features occupy the leading columns; the embedding follows.
    num_physical_features = d_feat_combined_step2 - gnn_embedding_dim

    optimal_tt_rank_full_step3 = data_step3['optimal_tt_rank_full']
    reshaped_dims_step3 = data_step3['reshaped_tensor_dimensions']

    # --- 1. UMAP/t-SNE Visualization of Topological Embeddings ---
    print("\n--- Generating t-SNE plot of topological embeddings ---")
    embeddings, colors = _collect_topological_embeddings(
        train_trees_step1, train_tensors_step2, max_n_sub_step2, num_physical_features
    )
    if not embeddings:
        print("No topological embeddings found to plot.")
    else:
        _plot_tsne_embeddings(embeddings, colors, current_timestamp)

    # --- 2. TT-Core Magnitude Distribution ---
    print("\n--- Plotting TT-Core magnitude distributions for an example tree ---")
    _plot_tt_core_magnitudes(
        train_tensors_step2, reshaped_dims_step3, optimal_tt_rank_full_step3, current_timestamp
    )

    # --- 3. Example Substructure Graph Visualization ---
    print("\n--- Visualizing example substructure graphs ---")
    _plot_example_substructures(
        train_trees_step1, feature_means_step1, feature_stds_step1, current_timestamp
    )

    print("\nStep 5 (Visualization and Interpretation) complete.")


if __name__ == "__main__":
    # Run the visualization step only when executed as a script, not on import.
    main()
