# FILE: cdiffusion/datasets/get_general_data.py

import torch
import numpy as np
import os
import argparse
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import Normalize
from torch.func import vmap
import pickle

# Import necessary libraries
try:
    from torch_geometric.datasets import QM9
    from scipy.spatial.distance import pdist, squareform
    from sklearn.manifold import TSNE
    from rdkit import Chem
    RDKIT_AVAILABLE = True
except ImportError:
    print("One or more required libraries are not installed. Please install them to use this script fully.")
    print("pip install torch_geometric scipy scikit-learn rdkit-pypi")
    RDKIT_AVAILABLE = False


import sys

# Add the parent directory to the path to allow importing 'constraints'
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from constraints import get_constraint_functions

# Atomic number -> element symbol, used when writing .xyz blocks.
# Only H, C, N, O and F are listed; mol_from_torch_geometric falls back to
# the placeholder symbol 'X' for any atomic number missing from this table.
ATOM_MAP = {1: 'H', 6: 'C', 7: 'N', 8: 'O', 9: 'F'}

def mol_from_torch_geometric(data):
    """
    Creates a canonical SMILES string from a torch_geometric Data object.

    The atomic numbers (`data.z`) and coordinates (`data.pos`) are written to
    an in-memory .xyz block and parsed by RDKit. Parsing an .xyz block only
    yields atoms and a conformer (no bonds), so connectivity and bond orders
    are then perceived from the 3D geometry with `rdDetermineBonds`.

    Args:
        data: Object exposing `pos` (per-atom x, y, z coordinates) and `z`
            (atomic numbers) as tensors that support `.numpy()`.

    Returns:
        The canonical SMILES string, or None when RDKit is unavailable, the
        .xyz block cannot be parsed, or bond perception / sanitization fails.

    Raises:
        ImportError: If the installed RDKit does not ship `rdDetermineBonds`.
    """
    if not RDKIT_AVAILABLE:
        return None

    # Imported here (not at module level) so the module still loads when
    # RDKit is missing; a too-old RDKit fails loudly instead of silently
    # producing bond-less molecules.
    from rdkit.Chem import rdDetermineBonds

    positions = data.pos.numpy()
    atom_types = data.z.numpy()
    num_atoms = len(atom_types)

    # .xyz layout: atom count, (empty) comment line, then one line per atom.
    atom_lines = [
        # Atomic numbers missing from ATOM_MAP become the placeholder 'X'.
        f"{ATOM_MAP.get(int(atomic_number), 'X')} {x: .8f} {y: .8f} {z: .8f}"
        for atomic_number, (x, y, z) in zip(atom_types, positions)
    ]
    xyz_string = "\n".join([str(num_atoms), "", *atom_lines]) + "\n"

    mol = Chem.MolFromXYZBlock(xyz_string)
    if mol is None:
        return None

    try:
        # Perceive connectivity and bond orders from the coordinates.
        # NOTE(review): assumes a net-neutral molecule (charge=0) - confirm
        # this holds for every dataset passed through this function.
        rdDetermineBonds.DetermineBonds(mol, charge=0)
        Chem.SanitizeMol(mol)
        return Chem.MolToSmiles(mol, canonical=True)
    except Exception:
        # RDKit may fail to assign bonds or sanitize if the geometry is
        # highly unreasonable; treat such molecules as unprocessable.
        return None


def sample_on_sphere(n_samples, dim=3):
    """
    Draws random points uniformly from the surface of the unit sphere.

    Standard-normal vectors are sampled and each row is divided by its own
    Euclidean norm.

    Args:
        n_samples: Number of points (rows) to draw.
        dim: Dimension of the ambient space (columns).

    Returns:
        Tensor of shape (n_samples, dim) whose rows have unit norm.
    """
    gaussian = torch.randn(n_samples, dim)
    radii = torch.linalg.norm(gaussian, dim=1, keepdim=True)
    return gaussian / radii

# =================================================================================
# --- DATA GENERATION FUNCTIONS ---
# =================================================================================

def generate_sphere_mog_data(h_func, g_func, n_samples=5000, dim=3, n_clusters=10, std=0.15):
    """
    Generates a Mixture of Gaussians dataset on a sphere.

    Cluster centers are drawn on the unit sphere, Gaussian noise is added
    around them in ambient space, and the noisy points are projected back
    onto the sphere. Points violating the inequality constraint are rejected
    and sampling continues until `n_samples` points have been accepted.

    Args:
        h_func: Accepted for a uniform generator signature; not used here.
        g_func: Batched inequality constraint, or None for no constraint.
            A point is feasible when every value returned for it is <= 0.
            The batched call must return a tensor with a constraint axis at
            dim=1 (it is reduced with `torch.all(..., dim=1)`).
        n_samples: Number of samples to return.
        dim: Dimension of the ambient space.
        n_clusters: Number of mixture components.
        std: Standard deviation of the ambient Gaussian noise.

    Returns:
        NumPy array of shape (n_samples, dim). When `n_samples` is not
        positive an empty (0, dim) array is returned.

    Note:
        Rejection sampling has no iteration cap: if `g_func` rejects
        (almost) every point, this function does not terminate.
    """
    print("Generating 'sphere_mog' data...")
    centers = torch.zeros(n_clusters, dim)
    for i in range(n_clusters):
        while True:
            center_candidate = sample_on_sphere(1, dim)
            # With no constraint every candidate is feasible. (Requiring
            # `g_func is not None` here would loop forever for g_func=None.)
            if g_func is None or torch.all(g_func(center_candidate) <= 0):
                centers[i] = center_candidate
                break
    print("Successfully sampled cluster centers within the constrained region.")

    accepted_chunks = []
    num_accepted = 0
    with torch.no_grad():
        while num_accepted < n_samples:
            num_needed = n_samples - num_accepted
            # Oversample by 20% so rejections rarely force another round.
            batch_size = int(num_needed * 1.2) + 1
            cluster_assignments = torch.randint(0, n_clusters, (batch_size,))
            noise = torch.randn(batch_size, dim) * std
            samples_ambient = centers[cluster_assignments] + noise
            samples_on_sphere = samples_ambient / torch.linalg.norm(samples_ambient, dim=1, keepdim=True)
            if g_func is not None:
                inequality_mask = torch.all(g_func(samples_on_sphere) <= 0, dim=1)
                samples_constrained = samples_on_sphere[inequality_mask]
            else:
                samples_constrained = samples_on_sphere
            # Keep tensors instead of round-tripping through Python lists.
            accepted_chunks.append(samples_constrained.cpu())
            num_accepted += samples_constrained.shape[0]

    if accepted_chunks:
        final_samples_tensor = torch.cat(accepted_chunks, dim=0)[:n_samples]
    else:
        final_samples_tensor = torch.empty(0, dim)
    print(f"Successfully generated exactly {final_samples_tensor.shape[0]} samples.")
    return final_samples_tensor.numpy()


def generate_robot_arm_data(h_func, g_func, **kwargs):
    """
    Stub for the 'robot_arm' dataset generator.

    All arguments are accepted and ignored. A notice is printed and the call
    always fails.

    Raises:
        NotImplementedError: Always; the generator has not been written yet.
    """
    notice = "Data generation for 'robot_arm' is not yet implemented."
    print(notice)
    raise NotImplementedError

def generate_mol_gen_data(h_func, g_func, n_samples=2000, num_atoms_exact=9):
    """
    Builds a molecular dataset from QM9 restricted to an exact atom count.

    QM9 is scanned in order; every molecule with exactly `num_atoms_exact`
    atoms for which a SMILES string can be produced contributes its pairwise
    distance matrix, its atomic numbers and its SMILES. Scanning stops once
    `n_samples` molecules have been collected.

    Args:
        h_func: Accepted for a uniform generator signature; not used here.
        g_func: Accepted for a uniform generator signature; not used here.
        n_samples: Maximum number of molecules to collect.
        num_atoms_exact: Required number of atoms per molecule.

    Returns:
        A tuple `(distance_matrices, atom_types, smiles_list)`: two NumPy
        arrays and a list of SMILES strings. If no molecule qualifies,
        `(None, None, None)` is returned after printing an error.

    Raises:
        ImportError: If the optional dependencies failed to import.
        Exception: Any error raised during loading or processing is
            reported on stdout and re-raised.
    """
    if not RDKIT_AVAILABLE:
        raise ImportError("RDKit is required for the 'mol_gen' dataset generation.")

    print(f"Generating 'mol_gen' data for molecules with exactly {num_atoms_exact} atoms...")
    try:
        print("Loading QM9 dataset...")
        qm9 = QM9(root='./data/QM9')
        print("QM9 dataset loaded successfully.")

        # One (distance matrix, atomic numbers, SMILES) triple per kept molecule.
        records = []
        for mol in qm9:
            if len(records) >= n_samples:
                break
            if mol.num_nodes != num_atoms_exact:
                continue

            # Molecules RDKit cannot turn into SMILES are left out entirely.
            smiles = mol_from_torch_geometric(mol)
            if smiles is None:
                continue

            coords = mol.pos.numpy()
            records.append((squareform(pdist(coords)), mol.z.numpy(), smiles))

        if not records:
            print(f"Error: No molecules found with exactly {num_atoms_exact} atoms. Please try a different number.")
            return None, None, None

        num_found = len(records)
        if num_found < n_samples:
            print(f"Warning: Found only {num_found} molecules with exactly {num_atoms_exact} atoms.")

        matrices, atomic_numbers, smiles_strings = zip(*records)
        final_data = np.array(matrices)
        final_atom_types = np.array(atomic_numbers)

        print(f"Successfully generated {final_data.shape[0]} samples with shape {final_data.shape[1:]}.")

        return final_data, final_atom_types, list(smiles_strings)

    except NameError:
        print("Could not generate 'mol_gen' data. Required libraries might be missing.")
        raise
    except Exception as e:
        print(f"An error occurred during 'mol_gen' data generation: {e}")
        raise

# =================================================================================
# --- VISUALIZATION FUNCTIONS (Placeholders - can be filled as needed) ---
# =================================================================================
def visualize_sphere_mog(data, g_func, output_path):
    """
    Placeholder visualization hook for the sphere_mog dataset.

    No figure is produced: the arguments are ignored and only a notice is
    printed. Returns None.
    """
    notice = "Sphere_mog visualization is defined but will be skipped in this run."
    print(notice)
    return None

def visualize_mol_gen(data, atom_counts, output_path):
    """
    Placeholder visualization hook for the mol_gen dataset.

    No figure is produced: the arguments are ignored and only a notice is
    printed. Returns None.
    """
    notice = "Mol_gen visualization is defined but will be skipped in this run."
    print(notice)
    return None


# =================================================================================
# --- MAIN EXECUTION ---
# =================================================================================

def main():
    """
    Command-line entry point: generate one dataset and write it to disk.

    Output goes to `../data/general` relative to this script. Every file
    name is prefixed with `<dataset>_<num_atoms>atoms`, for all datasets:
      * `.npy`                  - the generated data array
      * `_atom_types.npy`       - atomic numbers (only when produced)
      * `_train_smiles.pkl`     - pickled set of SMILES (only when produced)
    Nothing is written when the generator returns no data.
    """
    cli = argparse.ArgumentParser(description="Generate data for constrained diffusion models.")
    cli.add_argument('--dataset', type=str, required=True, choices=['sphere_mog', 'robot_arm', 'mol_gen'])
    cli.add_argument('--n_samples', type=int, default=2000)
    cli.add_argument('--num_atoms', type=int, default=12, help='Exact number of atoms for mol_gen dataset.')
    cli.add_argument('--no_viz', action='store_true', help='Skip visualization after data generation.')
    args = cli.parse_args()

    # The output directory is resolved relative to this script, which is
    # assumed to live in 'cdiffusion/datasets/'.
    output_dir = os.path.join(os.path.dirname(__file__), '..', 'data', 'general')
    os.makedirs(output_dir, exist_ok=True)

    # The constraint functions are always loaded and wrapped with vmap, even
    # though only the sphere_mog generator receives the batched versions.
    h_func, g_func = get_constraint_functions(args.dataset)
    h_func_vmap = vmap(h_func)
    g_func_vmap = vmap(g_func)

    data = atom_types = train_smiles = None
    dataset = args.dataset
    if dataset == 'sphere_mog':
        data = generate_sphere_mog_data(h_func_vmap, g_func_vmap, n_samples=args.n_samples)
    elif dataset == 'robot_arm':
        data = generate_robot_arm_data(h_func, g_func)
    elif dataset == 'mol_gen':
        data, atom_types, train_smiles = generate_mol_gen_data(
            h_func, g_func, n_samples=args.n_samples, num_atoms_exact=args.num_atoms
        )

    if data is None:
        return

    def artifact_path(suffix):
        # All artifacts share the '<dataset>_<num_atoms>atoms' prefix.
        return os.path.join(output_dir, f"{args.dataset}_{args.num_atoms}atoms{suffix}")

    file_path = artifact_path(".npy")
    np.save(file_path, data)
    print(f"Data saved to {file_path}")

    if atom_types is not None:
        types_path = artifact_path("_atom_types.npy")
        np.save(types_path, atom_types)
        print(f"Atom types saved to {types_path}")

    if train_smiles is not None:
        smiles_path = artifact_path("_train_smiles.pkl")
        with open(smiles_path, 'wb') as f:
            # Stored as a set, so duplicate SMILES collapse to one entry.
            pickle.dump(set(train_smiles), f)
        print(f"Training SMILES saved to {smiles_path}")

    if not args.no_viz:
        # Visualization is not wired up yet: the target path is computed but
        # neither visualize_sphere_mog nor visualize_mol_gen is called.
        viz_path = artifact_path("_viz.png")

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()