import torch
import numpy as np
import matplotlib.pyplot as plt
from manifolds.SOn import Manifold_SOn
from utils import set_seed_everywhere
import argparse
import os

def _build_centers(manifold, n, well_num):
    """Return the ``well_num`` well centers as a tensor of shape (well_num, n, n).

    Center ``i`` starts as the n x n identity with a 2x2 clockwise rotation by
    pi/3 written into rows/cols ``2i:2i+2``, and is then conjugated as
    ``P_i^T @ C_i @ P_i`` where ``P_i`` comes from ``manifold.uniform_sample``
    (presumably uniform on SO(n) -- confirm against Manifold_SOn).
    """
    cos_a, sin_a = np.cos(np.pi / 3), np.sin(np.pi / 3)
    rot_block = np.array([[cos_a, sin_a],
                          [-sin_a, cos_a]])
    mat_list = []
    for i in range(well_num):
        mat = np.eye(n)
        # NOTE(review): if 2 * (i + 1) > n the rotation block does not fit, so
        # this center stays the identity; for orthogonal P, P^T I P = I, so
        # several wells can coincide (e.g. n=3 with 3 wells). Kept as-is so
        # datasets generated with a given seed remain reproducible.
        if 2 * (i + 1) <= n:
            mat[2 * i:2 * (i + 1), 2 * i:2 * (i + 1)] = rot_block
        mat_list.append(mat)
    center_ori = torch.tensor(np.stack(mat_list), dtype=torch.float)

    mat_p = manifold.uniform_sample(well_num).reshape(well_num, n, n)
    return torch.bmm(mat_p.transpose(1, 2), torch.bmm(center_ori, mat_p))


def _sample_around_centers(manifold, center, std_list, n, samples_num, batch_size):
    """Draw ``samples_num`` points around the well centers, ``batch_size`` at a time.

    For each sample a well index is drawn uniformly at random. Gaussian noise,
    scaled by that well's entry in ``std_list``, is projected onto the tangent
    space at the well center and then passed through ``manifold.exp`` at that
    center. Returns the concatenated batches on the CPU, one row per sample
    (presumably flattened to n*n entries -- depends on ``manifold.exp``).
    """
    flat_center = center.reshape(-1, n * n)
    well_num = flat_center.shape[0]
    batches = []
    for start in range(0, samples_num, batch_size):
        size = min(batch_size, samples_num - start)
        # Keep the RNG call order (randint, then randn) fixed: reordering it
        # changes the dataset produced for a given seed.
        random_idx = torch.randint(low=0, high=well_num, size=(size,))
        base_point = flat_center[random_idx]
        noise = torch.randn(size, n * n) * std_list[random_idx].reshape(-1, 1)
        tangent = manifold.project_onto_tangent_space(noise, base_point=base_point)
        batch = manifold.exp(tangent, base_point=base_point)
        # detach() so the later .numpy() works even if the ops track gradients.
        batches.append(batch.detach().cpu())
        print(f"  Generated {start + size} / {samples_num} samples...")
    return torch.cat(batches, dim=0)


def _plot_trace_histograms(data, n, batch_size, fig_path, power_list=(1, 2, 4, 5)):
    """Save histograms of tr(X^k) over all samples X, one panel per k in ``power_list``."""
    per_power = [[] for _ in power_list]
    for start in range(0, data.shape[0], batch_size):
        samples = data[start:start + batch_size].reshape(-1, n, n)
        for chunks, power in zip(per_power, power_list):
            powered = torch.linalg.matrix_power(samples, power)
            traces = powered.diagonal(dim1=-2, dim2=-1).sum(dim=-1)
            chunks.append(traces.detach().cpu().numpy())
    # Shape (num_samples, len(power_list)): column j holds tr(X^power_list[j]).
    # Stacking along axis=1 keeps the powers in separate columns; an
    # hstack(...).reshape(-1, len(power_list)) would interleave them so that
    # every panel mixed traces of different powers.
    statistics = np.stack([np.concatenate(chunks) for chunks in per_power], axis=1)

    ncols = 2
    nrows = (len(power_list) + ncols - 1) // ncols
    fig = plt.figure(figsize=(10, 10))
    try:
        for j, power in enumerate(power_list):
            ax = fig.add_subplot(nrows, ncols, j + 1)
            ax.hist(statistics[:, j], bins=200, alpha=1., density=True)
            ax.set_title(f"Trace of matrix power {power}")
        fig.savefig(fig_path, dpi=300)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


def generate_SOn_data_memory_safe(well_num, n, samples_num=50000, batch_size=1000, *,
                                  seed=777, std=0.05,
                                  output_dir="./data/SOn", fig_dir="./datasets/figs/SOn"):
    """Generate and save a mixture-of-wells dataset on SO(n), processing samples in batches.

    Files written (``name`` is ``SO{n}_{well_num}w``):
      * ``{output_dir}/{name}_center.npy`` -- the well centers, shape (well_num, n, n);
      * ``{output_dir}/{name}.npy`` -- the generated samples;
      * ``{fig_dir}/hist_{name}.png`` -- histograms of tr(X^k) for k in (1, 2, 4, 5),
        as a visual sanity check.

    Args:
        well_num: Number of wells (clusters); must be 3 or 5.
        n: Matrix size of the SO(n) group.
        samples_num: Total number of samples to generate; must be positive.
        batch_size: Number of samples generated/processed at a time; must be positive.
        seed: Seed passed to ``set_seed_everywhere`` before any sampling.
        std: Noise scale used for every well.
        output_dir: Directory for the ``.npy`` files (created if missing).
        fig_dir: Directory for the verification plot (created if missing).

    Raises:
        ValueError: If ``well_num`` is not 3 or 5, or ``samples_num`` / ``batch_size``
            is not positive.
    """
    # Explicit checks instead of assert, which is stripped under `python -O`.
    if well_num not in (3, 5):
        raise ValueError("well_num should be 3 or 5.")
    if samples_num <= 0:
        raise ValueError(f"samples_num must be positive, got {samples_num}.")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}.")

    set_seed_everywhere(seed)

    # --- Configuration ---
    std_list = torch.ones(well_num) * std
    manifold = Manifold_SOn(n)
    dataset_name = f'SO{n}_{well_num}w'
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(fig_dir, exist_ok=True)

    print(f"Generating dataset for {dataset_name} with n={n}...")

    # --- Cluster centers ---
    center = _build_centers(manifold, n, well_num)
    center_path = os.path.join(output_dir, f"{dataset_name}_center.npy")
    np.save(center_path, center.detach().cpu().numpy())
    print(f"Saved centers to {center_path}")

    # --- Samples, generated batch by batch to bound peak memory ---
    print(f"Generating {samples_num} samples in batches of {batch_size}...")
    final_data = _sample_around_centers(manifold, center, std_list, n, samples_num, batch_size)
    data_path = os.path.join(output_dir, f"{dataset_name}.npy")
    np.save(data_path, final_data.numpy())
    print(f"Saved data to {data_path}")

    # --- Verification plot ---
    print("Generating verification plot...")
    fig_path = os.path.join(fig_dir, f"hist_{dataset_name}.png")
    _plot_trace_histograms(final_data, n, batch_size, fig_path)
    print(f"Saved verification plot to {fig_path}")

if __name__ == '__main__':
    # Command-line entry point: every option's dest matches a keyword
    # parameter of generate_SOn_data_memory_safe, so the parsed namespace
    # can be forwarded directly.
    cli = argparse.ArgumentParser(description='Generate SO(n) data.')
    cli.add_argument('--n', type=int, required=True,
                     help='Dimension of the SO(n) group.')
    cli.add_argument('--well_num', type=int, required=True, choices=[3, 5],
                     help='Number of wells (clusters).')
    cli.add_argument('--samples_num', type=int, default=50000,
                     help='Total number of samples to generate.')
    cli.add_argument('--batch_size', type=int, default=1000,
                     help='Batch size for processing data to save memory.')
    opts = cli.parse_args()

    generate_SOn_data_memory_safe(**vars(opts))