import numpy as np
import torch
import matplotlib.pyplot as plt
import os
import argparse

def _trace_statistics(samples, powers):
    """Return the traces of ``samples ** p`` for every exponent ``p`` in *powers*.

    Args:
        samples: Tensor of shape ``(N, n, n)`` holding a batch of square matrices.
        powers: Sequence of integer exponents passed to ``torch.linalg.matrix_power``.

    Returns:
        NumPy array of shape ``(N, len(powers))``; column ``j`` holds
        ``trace(samples[i] ** powers[j])`` for each sample ``i``.
    """
    traces = [
        torch.linalg.matrix_power(samples, p).diagonal(dim1=-2, dim2=-1).sum(dim=-1)
        for p in powers
    ]
    return torch.stack(traces, dim=1).cpu().numpy()


def visualize_son_samples(n, well_num, *, powers=(1, 2, 4, 5),
                          data_dir='./data/SOn', fig_dir='./data/figs/SOn'):
    """Generate and save histograms of trace statistics for SO(n) samples.

    Loads ``{data_dir}/SO{n}_{well_num}w.npy`` and reshapes it into a batch of
    ``n x n`` matrices. For each exponent ``p`` in *powers* it plots a density
    histogram of ``trace(M ** p)`` over the batch, two subplots per row. The
    figure is saved as ``{fig_dir}/SO{n}_{well_num}w_visualization.png``.

    Args:
        n: Matrix dimension; the loaded array is reshaped to ``(-1, n, n)``.
        well_num: Number of wells; only used to build the dataset file name.
        powers: Matrix exponents to plot, one subplot each.
        data_dir: Directory containing the ``.npy`` sample file.
        fig_dir: Directory the PNG is written to; created if missing.

    Returns:
        The path of the saved figure. Returns ``None`` if the data file was not
        found; an error message is printed in that case.

    Raises:
        ValueError: If *powers* is empty.
    """
    powers = list(powers)
    if not powers:
        raise ValueError("powers must contain at least one exponent")

    dataset_name = f'SO{n}_{well_num}w'
    data_path = os.path.join(data_dir, f'{dataset_name}.npy')
    output_path = os.path.join(fig_dir, f'{dataset_name}_visualization.png')

    print(f"Loading data from {data_path}...")
    try:
        data = np.load(data_path)
    except FileNotFoundError:
        print(f"Error: Data file not found at {data_path}")
        print("Please ensure you have generated the data and are running this script from the 'cdiffusion' directory.")
        return None

    # Create the output directory only once there is something to write,
    # so a missing data file does not leave an empty directory behind.
    os.makedirs(fig_dir, exist_ok=True)

    samples = torch.from_numpy(data).float().reshape(-1, n, n)

    print(f"Generating plot for {dataset_name}...")
    statistics = _trace_statistics(samples, powers)

    # Two subplots per row; the default four powers give a 2x2 grid at 12x12 in.
    ncols = 2
    nrows = (len(powers) + ncols - 1) // ncols
    fig, axes = plt.subplots(nrows, ncols, figsize=(12, 6 * nrows), squeeze=False)
    fig.suptitle(f'Trace Statistics for {dataset_name}', fontsize=16)

    flat_axes = axes.flatten()
    for ax, column, power in zip(flat_axes, statistics.T, powers):
        ax.hist(column, bins=200, alpha=1.0, density=True)
        ax.set_title(f'Trace of matrix power {power}')
        ax.set_xlabel('Trace Value')
        ax.set_ylabel('Density')
    # Hide the leftover grid cell when an odd number of powers is requested.
    for ax in flat_axes[len(powers):]:
        ax.set_visible(False)

    # Operate on `fig` explicitly rather than on pyplot's implicit current figure.
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    fig.savefig(output_path, dpi=300)
    plt.close(fig)
    print(f"Success! Saved visualization to {output_path}")
    return output_path

if __name__ == '__main__':
    # Command-line entry point: read the group dimension and well count, then plot.
    cli = argparse.ArgumentParser(description='Visualize SO(n) data samples.')
    cli.add_argument(
        '--n',
        type=int,
        required=True,
        help='Dimension of the SO(n) group.',
    )
    cli.add_argument(
        '--well_num',
        type=int,
        required=True,
        choices=[3, 5],
        help='Number of wells (clusters).',
    )
    opts = cli.parse_args()
    visualize_son_samples(n=opts.n, well_num=opts.well_num)