from tkinter.constants import TRUE
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing
from torch_geometric.utils import add_self_loops, degree
import numpy as np
from scipy.spatial.transform import Rotation as R
from sklearn.model_selection import train_test_split
import time
from torch_geometric.utils import to_dense_batch, subgraph
import py3Dmol
from rdkit import Chem
import torch_geometric
import torch_geometric.datasets
import matplotlib.pyplot as plt
from datetime import datetime
import copy
import random
import math
import os

# Module-wide compute device: the first CUDA GPU when PyTorch can see one,
# otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def generate_random_quaternion():
    """
    Draw one random SO(3) rotation and return it as a quaternion.

    Returns:
        numpy.ndarray: shape (4,), in SciPy's scalar-last (x, y, z, w) order.
    """
    return R.random().as_quat()

def generate_many_random_quaternions(num):
    """
    Draw `num` independent random quaternions.

    Args:
        num (int): number of quaternions to draw.

    Returns:
        numpy.ndarray: float array of shape (num, 4), one quaternion per row
        (as produced by generate_random_quaternion).
    """
    quats = np.zeros((num, 4))
    # Iterating a 2-D array yields row views, so writing into `row` fills `quats`.
    for row in quats:
        row[:] = generate_random_quaternion()
    return quats

def sample_random_rotations(n_rotations):
    """
    Sample random 3x3 rotation matrices.

    Args:
        n_rotations (int): number of matrices to draw.

    Returns:
        torch.Tensor: float32 tensor of shape (n_rotations, 3, 3).
    """
    from scipy.spatial.transform import Rotation

    matrices = Rotation.random(n_rotations).as_matrix()
    return torch.tensor(matrices, dtype=torch.float32)

def random_rotation_matrix():
    """
    Generate a uniformly random 3x3 rotation matrix in PyTorch.

    A Gaussian 3x3 matrix is QR-factorized. The QR factorization is only unique
    up to the signs of R's diagonal, and `torch.linalg.qr` does not fix them. The
    columns of Q are therefore rescaled so that diag(R) is positive; this makes Q
    uniformly (Haar) distributed over O(3). Matrices with det = -1 then get their
    third column negated, which maps them uniformly onto SO(3).

    Returns:
        torch.Tensor: A float32 3x3 rotation matrix (orthonormal, det = +1).
    """
    gaussian = torch.randn(3, 3)
    q, r = torch.linalg.qr(gaussian)

    # Fix QR's sign ambiguity so the distribution of q is uniform, not biased
    # by the implementation's sign convention.
    diag = torch.diagonal(r)
    signs = torch.where(diag < 0, -torch.ones_like(diag), torch.ones_like(diag))
    q = q * signs  # broadcasts over rows, i.e. scales each column

    # Ensure the determinant is +1 (a proper, right-handed rotation).
    if torch.det(q) < 0:
        q[:, 2] *= -1

    return q

def _gram_schmidt_frame(v1, v2):
    """Return the 3x3 matrix with columns (u1, u2, u1 x u2) obtained by Gram-Schmidt on v1, v2."""
    norm_v1 = torch.norm(v1)
    if norm_v1 < 1e-5:
        print(f'Error: norm of v1 is too small {norm_v1}')
    u1 = v1 / norm_v1
    # Component of v2 orthogonal to u1; computed once and reused.
    residual = v2 - torch.dot(v2, u1) * u1
    norm_residual = torch.norm(residual)
    if norm_residual < 1e-5:
        print(f'Error: v1 and v2 are linearly dependent {norm_residual}')
    u2 = residual / norm_residual
    u3 = torch.linalg.cross(u1, u2)  # orthogonal to u1 and u2, right-handed
    return torch.stack([u1, u2, u3], dim=1)  # each column is a basis vector


def canonicalize_points(points, method='closest'):
    """
    Canonicalize an N x 3 tensor of points: center them on their centroid, then
    express them in a data-derived orthonormal frame.

    Args:
        points (torch.Tensor): An N x 3 tensor containing points.
        method (str): How the frame is built:
            - 'closest': Gram-Schmidt on the two (centered) points with the
              largest norms.
            - 'smooth_sample': Gram-Schmidt on two distinct points sampled with
              probability proportional to their norm.
            - 'smooth_proper_sample': the first point is sampled proportional to
              its norm; the second is sampled proportional to the norm of its
              projection onto the first (the first point itself is excluded).
            - any method containing 'PCA': covariance eigenvectors ordered by
              decreasing eigenvalue. If the name also contains 'rand' or 'flip'
              (e.g. 'PCA_rand_flip'), each axis's sign is flipped at random.
              NOTE: the PCA frame is not forced to det = +1, so it may be a
              reflection.

    Returns:
        torch.Tensor: The canonicalized N x 3 tensor (centered points multiplied
        on the right by the frame matrix).

    Raises:
        ValueError: if `points` is not N x 3 or `method` is not recognised.
    """
    if points.dim() != 2 or points.size(1) != 3:
        raise ValueError("Input tensor must have dimensions N x 3.")

    # Center so the frame does not depend on translation.
    centers = torch.sum(points, dim=0) / points.shape[0]
    points = points - centers

    if method == 'closest' or 'smooth' in method:
        norms = torch.norm(points, dim=1)

        if method == 'closest':
            top_indices = torch.topk(norms, 2).indices
            v1, v2 = points[top_indices[0]], points[top_indices[1]]
        elif method == 'smooth_sample':
            sampled = torch.multinomial(norms, num_samples=2, replacement=False)
            v1, v2 = points[sampled[0]], points[sampled[1]]
        elif method == 'smooth_proper_sample':
            first = torch.multinomial(norms, num_samples=1)
            v1 = points[first][0, :]
            u1 = v1 / torch.norm(v1)
            # u1 has unit length, so |<p, u1>| is the norm of p's projection onto u1.
            proj_norms = torch.abs(points @ u1).detach()
            proj_norms[first] = 0  # never pick the first point again
            second = torch.multinomial(proj_norms, num_samples=1)
            v2 = points[second][0, :]
            if first == second:
                print('error: sampled the same ind')
        else:
            raise ValueError(f"Unknown method: {method}")

        frame = _gram_schmidt_frame(v1, v2)
    elif 'PCA' in method:
        # Guard N == 1; the scale does not affect the eigenvectors anyway.
        denom = max(points.shape[0] - 1, 1)
        cov_matrix = points.T @ points / denom
        eigenvalues, eigenvectors = torch.linalg.eigh(cov_matrix)
        # eigh returns ascending eigenvalues; use descending order.
        frame = eigenvectors[:, torch.argsort(eigenvalues, descending=True)]

        if 'rand' in method or 'flip' in method:
            for col in range(frame.shape[1]):
                if random.choice([True, False]):
                    frame[:, col] *= -1
    else:
        raise ValueError(f"Unknown method: {method}")

    return points @ frame

def get_fixed_random_rotation_matrix(data):
    """
    Build a pseudo-random 3x3 rotation that is a deterministic function of the
    molecule's raw position and atomic-number bytes.

    The first 4 bytes of SHA-256(pos bytes + z bytes), read big-endian, seed a
    numpy RandomState. A Gaussian 3x3 matrix from that state is QR-factorized,
    and Q's third column is negated if needed so that det = +1.

    Args:
        data (torch_geometric.data.Data): needs `pos` (N x 3) and `z` (N atomic numbers).

    Returns:
        torch.Tensor: 3x3 orthonormal matrix with det +1, matching data.pos in
        dtype and device.
    """
    import hashlib

    payload = data.pos.cpu().numpy().tobytes() + data.z.cpu().numpy().tobytes()
    digest = hashlib.sha256(payload).digest()
    seed = int.from_bytes(digest[:4], byteorder='big', signed=False)

    gaussian = np.random.RandomState(seed).randn(3, 3)
    q, _ = np.linalg.qr(gaussian)
    if np.linalg.det(q) < 0:
        q[:, 2] *= -1  # make it a proper rotation

    return torch.tensor(q, dtype=data.pos.dtype, device=data.pos.device)

def align_vectors(v_from, v_to):
    """
    Return a 3x3 rotation matrix R such that R @ v_from points along v_to.

    Both inputs are normalized first. There are three cases:
      - parallel (cross product ~ 0 and dot > 0): the identity is returned;
      - antiparallel: a 180-degree rotation is used, about an axis orthogonal to
        v_from;
      - otherwise: Rodrigues' formula is applied about v_from x v_to.

    Args:
        v_from (torch.Tensor): shape (3,), non-zero.
        v_to (torch.Tensor): shape (3,), non-zero, same device/dtype as v_from.

    Returns:
        torch.Tensor: 3x3 rotation matrix with the dtype and device of v_from.
    """
    v_from = v_from / v_from.norm()
    v_to = v_to / v_to.norm()
    eye = torch.eye(3, dtype=v_from.dtype, device=v_from.device)

    cross = torch.linalg.cross(v_from, v_to)
    dot = torch.dot(v_from, v_to)
    if torch.allclose(cross, torch.zeros_like(cross), atol=1e-6):
        if dot > 0:
            return eye
        # Antiparallel: rotate 180 degrees about an axis orthogonal to v_from.
        # Cross v_from with the coordinate axis it is least aligned with, which
        # is never (anti)parallel to it. Crossing with a fixed x-axis would give
        # a zero axis (NaN) when v_from = -x.
        helper = torch.zeros_like(v_from)
        helper[torch.argmin(v_from.abs())] = 1.0
        axis = torch.linalg.cross(v_from, helper)
        axis = axis / axis.norm()
        # A rotation by pi about unit axis a is 2 a a^T - I.
        return 2.0 * torch.outer(axis, axis) - eye

    axis = cross / cross.norm()
    angle = torch.acos(torch.clamp(dot, -1.0, 1.0))
    # Rodrigues: R = I + sin(t) K + (1 - cos(t)) K^2, with K the cross-product
    # matrix of the axis. It is built with stack so device/dtype are preserved.
    zero = torch.zeros((), dtype=axis.dtype, device=axis.device)
    K = torch.stack([
        torch.stack([zero, -axis[2], axis[1]]),
        torch.stack([axis[2], zero, -axis[0]]),
        torch.stack([-axis[1], axis[0], zero]),
    ])
    return eye + torch.sin(angle) * K + (1 - torch.cos(angle)) * (K @ K)

def rotation_matrix(axis, angle):
    """
    Rotation matrix about `axis` by `angle` radians (Rodrigues' formula).

    R = I + sin(angle) K + (1 - cos(angle)) K @ K, where K is the cross-product
    (skew-symmetric) matrix of the normalized axis.

    Args:
        axis (torch.Tensor): shape (3,), non-zero; it is normalized internally.
        angle (torch.Tensor or float): rotation angle in radians.

    Returns:
        torch.Tensor: 3x3 rotation matrix with the dtype and device of `axis`.
        It is built with `torch.stack`, so autograd history through `axis` is kept.
    """
    axis = axis / axis.norm()
    angle = torch.as_tensor(angle, dtype=axis.dtype, device=axis.device)
    zero = torch.zeros((), dtype=axis.dtype, device=axis.device)
    K = torch.stack([
        torch.stack([zero, -axis[2], axis[1]]),
        torch.stack([axis[2], zero, -axis[0]]),
        torch.stack([-axis[1], axis[0], zero]),
    ])
    I = torch.eye(3, dtype=axis.dtype, device=axis.device)
    return I + torch.sin(angle) * K + (1 - torch.cos(angle)) * (K @ K)

def is_lex_smaller(a, b):
    """
    Return True if `a` comes strictly before `b` in lexicographic order of their
    flattened (row-major) values.

    Args:
        a, b: array-likes, e.g. np.ndarray of shape (N_atoms, 3).

    Returns:
        bool: True iff, at the first position where the values differ, a's is
        smaller. Returns False for equal arrays, and when one flattened array is
        a prefix of the other (only the common length is compared).
    """
    a_vals = np.asarray(a).ravel()
    b_vals = np.asarray(b).ravel()
    for lhs, rhs in zip(a_vals, b_vals):
        if lhs < rhs:
            return True
        if lhs > rhs:
            return False
    return False

def molecule_signature(mol):
    """
    Build an orientation-dependent signature used to pick a canonical symmetry op.

    The sites of `mol` are sorted by species string and then by coordinates
    rounded to 6 decimals. The sort is stable. The original (unrounded)
    coordinates are returned in that order. The molecule is assumed to be
    centered already.

    Args:
        mol: object with `.sites`, each having `.coords` and `.species_string`
            (e.g. a pymatgen Molecule).

    Returns:
        np.ndarray: sorted coordinates, one row per site.
    """
    positions = np.array([site.coords for site in mol.sites])
    labels = [site.species_string for site in mol.sites]

    def _sort_key(entry):
        label, xyz = entry
        return (label, *np.round(xyz, 6))

    ordered = sorted(zip(labels, positions), key=_sort_key)
    return np.array([xyz for _, xyz in ordered])

def _qm7_atomic_masses(z, device):
    """Return a float32 tensor of atomic masses (looked up in ase) for the atomic numbers `z`."""
    from ase.data import atomic_masses

    return torch.tensor([atomic_masses[int(i)] for i in z], dtype=torch.float32, device=device)


def _principal_axes_frame(points, masses):
    """Return a proper (det=+1) 3x3 matrix whose columns are the principal axes of inertia, ascending moment."""
    masses = masses.to(dtype=points.dtype, device=points.device)
    eye = torch.eye(3, dtype=points.dtype, device=points.device)
    # Inertia tensor I = sum_i m_i (|r_i|^2 Id - r_i r_i^T), vectorised over atoms.
    inertia = eye * torch.sum(masses * (points * points).sum(dim=1)) - (masses[:, None] * points).T @ points
    eigvals, eigvecs = torch.linalg.eigh(inertia)
    frame = eigvecs[:, torch.argsort(eigvals)]
    if torch.det(frame) < 0:
        frame[:, 2] *= -1
    return frame


def _lex_min_pointgroup_rotation(coords, z, dtype, device):
    """Return the rotation part of the point-group op whose image has the lexicographically smallest signature."""
    from pymatgen.core.structure import Molecule
    from pymatgen.symmetry.analyzer import PointGroupAnalyzer
    from pymatgen.core.operations import SymmOp
    from pymatgen.core.periodic_table import Element

    species = [Element.from_Z(int(z_i)) for z_i in z]
    mol = Molecule(species, coords)
    sym_ops = PointGroupAnalyzer(mol).get_symmetry_operations()

    min_signature = None
    best_op = SymmOp.from_rotation_and_translation(np.eye(3), np.zeros(3))
    for op in sym_ops:
        transformed = mol.copy()
        transformed.apply_operation(op)
        sig = molecule_signature(transformed)
        if min_signature is None or is_lex_smaller(sig, min_signature):
            min_signature = sig
            best_op = op
    return torch.tensor(best_op.rotation_matrix, dtype=dtype, device=device)


def get_canonical_rotation_matrix_qm7(data, method='closest'):
    """
    Get the rotation matrix R that canonicalizes a molecule's orientation.

    R acts on column vectors (p -> R @ p). For an N x 3 position array it is
    applied as ``data.pos @ R.T``, which is how canonicalize_qm7 applies it. For
    frame-based methods, R therefore has the frame axes as its ROWS. The data is
    assumed to be already centered (on the center of nuclear charge), so no
    re-centering happens here.

    Args:
        data (torch_geometric.data.Data): needs `pos` (N x 3) and `z` (atomic
            numbers). 'dipole' also reads `dipole` (row 0 is used).
        method (str):
            - 'closest': Gram-Schmidt frame from the two atoms farthest from the
              origin. After rotation, the farthest atom lies on +x and the
              second one in the xy-plane. Falls back to the identity if the two
              are collinear.
            - 'dipole': rotate the dipole direction onto +z.
            - 'inertia': map the principal axes of inertia (ascending moment)
              onto x, y, z. Needs `ase` for atomic masses.
            - 'pointgroup': the pymatgen point-group operation whose image has
              the lexicographically smallest `molecule_signature`. Needs
              `pymatgen`. NOTE(review): point-group ops may include improper
              ones (reflections); confirm whether those should be excluded.
            - 'inertia_pointgroup': 'inertia', then 'pointgroup' on the
              inertia-aligned coordinates.
            - 'hash': 'inertia', followed by `get_fixed_random_rotation_matrix(data)`.

    Returns:
        torch.Tensor: 3x3 matrix on data.pos.device.

    Raises:
        ValueError: if `pos` is not N x 3 or `method` is unknown.
    """
    points = data.pos
    if points.dim() != 2 or points.size(1) != 3:
        raise ValueError("Input tensor must have dimensions N x 3.")
    device, dtype = points.device, points.dtype
    z = data.z.cpu().numpy()

    if method == 'closest':
        norms = torch.norm(points, dim=1)
        top_indices = torch.topk(norms, 2).indices
        v1, v2 = points[top_indices[0]], points[top_indices[1]]

        u1 = v1 / torch.norm(v1)
        delta = v2 - torch.dot(v2, u1) * u1
        norm_delta = torch.norm(delta)
        if norm_delta < 1e-6:
            # v1 and v2 are (nearly) collinear: no unique frame, use identity.
            R = torch.eye(3, device=device)
        else:
            u2 = delta / norm_delta
            u3 = torch.linalg.cross(u1, u2)  # right-handed, so det = +1
            # Axes as ROWS so that R @ p gives p's coordinates in (u1, u2, u3);
            # this makes the result invariant to the input orientation.
            R = torch.stack([u1, u2, u3], dim=0)

    elif method == 'dipole':
        dipole_vector = data.dipole[0, :]
        dipole_norm = torch.norm(dipole_vector)
        if dipole_norm < 1e-8:
            R = torch.eye(3, device=device)
        else:
            dipole_vector = dipole_vector / dipole_norm
            z_axis = torch.tensor([0.0, 0.0, 1.0], dtype=dtype, device=device)
            if torch.allclose(dipole_vector, z_axis, atol=1e-8):
                R = torch.eye(3, device=device, dtype=dtype)
            elif torch.allclose(dipole_vector, -z_axis, atol=1e-8):
                R = torch.diag(torch.tensor([1., -1., -1.], device=device))
            else:
                R = align_vectors(dipole_vector, z_axis)

    elif method == 'inertia':
        # Transpose: the frame's columns are the axes; R must map them to x, y, z.
        R = _principal_axes_frame(points, _qm7_atomic_masses(z, device)).T

    elif method == 'pointgroup':
        # Note: for C1 molecules (a large part of the dataset) this is the identity.
        R = _lex_min_pointgroup_rotation(points.detach().cpu().numpy(), z, dtype, device)

    elif method == 'inertia_pointgroup':
        # Align with the inertia frame first, then disambiguate with the point group.
        R = _principal_axes_frame(points, _qm7_atomic_masses(z, device)).T
        points_aligned = points @ R.T
        R_pg = _lex_min_pointgroup_rotation(points_aligned.detach().cpu().numpy(), z, dtype, device)
        R = R_pg @ R  # inertia alignment first, then the point-group op

    elif method == 'hash':
        # Inertia alignment, then a deterministic hash-seeded rotation.
        R = _principal_axes_frame(points, _qm7_atomic_masses(z, device)).T
        R = get_fixed_random_rotation_matrix(data) @ R

    else:
        raise ValueError(f"Unknown method: {method}")

    return R

def canonicalize_qm7(data, method="closest"):
    """
    Canonicalize a QM7 molecule in place by rotating all of its geometric quantities.

    The rotation R comes from get_canonical_rotation_matrix_qm7. It is applied to:
      - positions and dipole, as row vectors: x -> x @ R.T;
      - the polarizability (`alpha_tensor`) and `quadrupole`, stored in 6-component
        symmetric form, via rotate_symmetric_tensor.
    The new values are detached copies and are written back onto `data`.

    Args:
        data: object with `pos`, `dipole`, `alpha_tensor`, `quadrupole` (and
            whatever the chosen method needs, e.g. `z`).
        method (str): forwarded to get_canonical_rotation_matrix_qm7.

    Returns:
        The same `data` object, mutated.
    """
    dev = data.pos.device
    R = get_canonical_rotation_matrix_qm7(data, method=method)
    # as_tensor accepts a tensor (or array) without torch.tensor's copy-construct warning.
    R = torch.as_tensor(R, dtype=torch.float32, device=dev)

    rotated_positions = data.pos @ R.T
    rotated_dipole = data.dipole @ R.T
    rotated_alpha = rotate_symmetric_tensor(data.alpha_tensor, R, dev)
    rotated_quadrupole = rotate_symmetric_tensor(data.quadrupole, R, dev)

    data.pos = rotated_positions.detach().clone()
    data.dipole = rotated_dipole.detach().clone()
    data.alpha_tensor = rotated_alpha.detach().clone()
    data.quadrupole = rotated_quadrupole.detach().clone()
    return data

def canonicalize_everything(MetaAugmentedDataset, method='closest'):
    """
    Canonicalize every molecule of a dataset with canonicalize_qm7.

    For each item, item[0] (the molecule) is canonicalized. Then item[0] is set
    to the result and item[1] to the first row of the rotated dipole.
    NOTE(review): these item assignments only persist if indexing the dataset
    returns a stored, mutable container; confirm with the dataset class.

    Args:
        MetaAugmentedDataset: indexable, iterable dataset of (molecule, label) items.
        method (str): forwarded to canonicalize_qm7.

    Returns:
        The same dataset object.
    """
    for position, entry in enumerate(MetaAugmentedDataset):
        canonical = canonicalize_qm7(entry[0], method=method)
        MetaAugmentedDataset[position][0] = canonical
        MetaAugmentedDataset[position][1] = canonical.dipole[0, :]
    return MetaAugmentedDataset

# Apply rotation to a molecule
def rotate_molecule(data, quaternion, mode='quaternion'):
    """
    Rotate `data.pos` in place and return `data`.

    Args:
        data: object with an (N, 3) `pos` tensor.
        quaternion: a scipy-style quaternion (x, y, z, w) when mode == 'quaternion';
            for any other mode it is treated as a 3x3 rotation matrix.
        mode (str): 'quaternion' or anything else (matrix).

    Returns:
        The same `data`, with `pos` replaced by a detached copy of pos @ R^T.
    """
    if mode == 'quaternion':
        rotation = R.from_quat(quaternion)
    else:
        rotation = R.from_matrix(quaternion)

    target_device = data.pos.device
    rot = torch.tensor(rotation.as_matrix(), dtype=torch.float32, device=target_device)
    rotated = torch.einsum('ij,nj->ni', rot, data.pos)
    data.pos = rotated.clone().detach()
    return data

def rotate_symmetric_tensor(tensor, R, dev):
    """
    Rotate a symmetric 3x3 tensor stored in compact 6-component form.

    The components are ordered (xx, yy, zz, xy, xz, yz). The full matrix X is
    rebuilt by an index gather, rotated as X' = R X R^T, and packed back in the
    same order. Gathering (rather than rebuilding with torch.tensor element by
    element) avoids per-element host syncs on GPU and keeps autograd history.

    Args:
        tensor (torch.Tensor): the 6 components; any shape that squeezes to (6,),
            e.g. (1, 6).
        R (torch.Tensor): 3x3 rotation matrix, float32 on `dev`.
        dev: device for the intermediate and returned tensors.

    Returns:
        torch.Tensor: float32 tensor of shape (1, 6), order (xx, yy, zz, xy, xz, yz).

    Raises:
        ValueError: if `tensor` does not squeeze to exactly 6 components.
    """
    comps = tensor.squeeze()
    if comps.shape != (6,):
        raise ValueError(
            f"Expected 6 components (xx, yy, zz, xy, xz, yz), got shape {tuple(tensor.shape)}"
        )
    comps = comps.to(device=dev, dtype=torch.float32)

    # full[i, j] = comps[full_index[i, j]]
    full_index = torch.tensor([[0, 3, 4], [3, 1, 5], [4, 5, 2]], device=dev)
    full = comps[full_index]
    rotated = R @ full @ R.T

    rows = torch.tensor([0, 1, 2, 0, 0, 1], device=dev)
    cols = torch.tensor([0, 1, 2, 1, 2, 2], device=dev)
    return rotated[rows, cols].unsqueeze(0)

def get_e3nn_matrix(tensor, dev, to_irreps=True):
    """
    Convert a symmetric 3x3 tensor between compact Cartesian form and e3nn irreps.

    The compact Cartesian form is [xx, yy, zz, xy, xz, yz]. The irreps form is the
    coefficient vector in the change of basis of
    ``e3nn.o3.ReducedTensorProducts('ij', i='1o', j='1o')``, restricted to the 0e
    and 2e outputs.

    Args:
        tensor (torch.Tensor): when `to_irreps`, compact Cartesian components that
            squeeze to (6,) or (batch, 6). Otherwise, irreps coefficients with
            trailing dimension equal to the basis size, for a single sample or a
            batch.
        dev: device for the change of basis and the results.
        to_irreps (bool): direction of the conversion.

    Returns:
        torch.Tensor: when `to_irreps`, irreps coefficients, one vector per input
        tensor. Otherwise, compact Cartesian components of shape (1, 6) for a
        single input or (batch, 6) for a batch. Each sample is converted
        separately; the batch is not summed.
    """
    import e3nn

    tp_target = e3nn.o3.ReducedTensorProducts(
        'ij', i='1o', j='1o',
        filter_ir_out=[e3nn.o3.Irrep("0e"), e3nn.o3.Irrep("2e")],
    )
    change_of_basis = tp_target.change_of_basis.to(dev)  # (n_coeffs, 3, 3)
    full_index = torch.tensor([[0, 3, 4], [3, 1, 5], [4, 5, 2]], device=dev)

    if to_irreps:
        comps = tensor.squeeze().to(device=dev, dtype=change_of_basis.dtype)
        mats = comps[..., full_index]  # (..., 3, 3) symmetric matrices
        return torch.einsum('lij,...ij->...l', change_of_basis, mats)

    # Convert from irreps back to the compact Cartesian form. Keeping '...' in the
    # output converts each batch element on its own instead of summing over the batch.
    coeffs = tensor.to(device=dev, dtype=change_of_basis.dtype)
    mats = torch.einsum('lij,...l->...ij', change_of_basis, coeffs)
    rows = torch.tensor([0, 1, 2, 0, 0, 1], device=dev)
    cols = torch.tensor([0, 1, 2, 1, 2, 2], device=dev)
    comps = mats[..., rows, cols]
    if comps.dim() == 1:
        comps = comps.unsqueeze(0)
    return comps


def rotate_qm7_quantities(data, quaternion, mode='quaternion'):
    """
    Rotate all QM7 geometric quantities of `data` in place and return it.

    Positions and dipole are rotated as row vectors (x -> x @ R^T). The
    polarizability (`alpha_tensor`) and `quadrupole` use rotate_symmetric_tensor.
    The original author noted that how these tensors should be rotated still
    needs checking.

    Args:
        data: object with `pos`, `dipole`, `alpha_tensor`, `quadrupole`.
        quaternion: scipy-style quaternion (x, y, z, w) when mode == 'quaternion',
            otherwise a 3x3 rotation matrix.
        mode (str): 'quaternion' or anything else (matrix).
    """
    if mode == 'quaternion':
        rotation = R.from_quat(quaternion)
    else:
        rotation = R.from_matrix(quaternion)

    target_device = data.pos.device
    rot = torch.tensor(rotation.as_matrix(), dtype=torch.float32, device=target_device)

    new_pos = torch.einsum('ij,nj->ni', rot, data.pos)
    new_dipole = torch.einsum('ij,nj->ni', rot, data.dipole)
    new_alpha = rotate_symmetric_tensor(data.alpha_tensor, rot, target_device)
    new_quadrupole = rotate_symmetric_tensor(data.quadrupole, rot, target_device)

    data.pos = new_pos.clone().detach()
    data.dipole = new_dipole.clone().detach()
    data.alpha_tensor = new_alpha.clone().detach()
    data.quadrupole = new_quadrupole.clone().detach()
    return data


def canonicalize_molecule(data, method='closest'):
    """
    Replace `data.pos` with canonicalize_points(data.pos, method) and return `data`.

    Non-tensor results are converted to a float32 tensor. `isinstance` is used so
    that tensor subclasses are accepted as-is.
    """
    rotated_positions = canonicalize_points(data.pos, method=method)
    if isinstance(rotated_positions, torch.Tensor):
        data.pos = rotated_positions
    else:
        data.pos = torch.tensor(rotated_positions, dtype=torch.float32)
    return data

### should actually make this go in the label operator??
def get_canon_with_network(canon_model, data, dataset_type):
    """
    Predict a canonicalizing 3x3 matrix with a learned model.

    The model output is assumed (TODO confirm) to be the irreps coefficients
    1x0e + 1x1e + 1x2e of a general 3x3 matrix, in the basis of e3nn's
    ``ReducedTensorProducts('ij', i='1o', j='1o')``. It is mapped back to a 3x3
    matrix and orthogonalized by QR. If the orthogonal factor has det < 0, its
    third column is negated.

    Args:
        canon_model: callable applied to the preprocessed data.
        data: raw input, passed through ``preprocess(data, dataset_type=...)``.
        dataset_type: forwarded to ``preprocess``.

    Returns:
        torch.Tensor: 3x3 orthogonal matrix with det +1 (up to numerical error).
    """
    import e3nn  # not imported at module level in this file

    data = preprocess(data, dataset_type=dataset_type)
    outputs = canon_model(data)
    tp_target = e3nn.o3.ReducedTensorProducts('ij', i='1o', j='1o')
    change_of_basis = tp_target.change_of_basis.to(device=outputs.device, dtype=outputs.dtype)
    mat_output = torch.einsum('l,lij -> ij', outputs, change_of_basis)

    # Orthogonalize the predicted matrix.
    q, r = torch.linalg.qr(mat_output)

    if torch.det(q) < 0:
        # Flip out of place: modifying the QR output in place would invalidate
        # the tensors autograd saved for the QR backward pass.
        flip = torch.ones(3, dtype=q.dtype, device=q.device)
        flip[2] = -1
        q = q * flip

    return q



def canonicalize_toy_circle(data, method="unit"):
    """
    Normalize each 2-D point to unit norm, then apply a scaling method.

    Parameters:
    - data (tuple of length 2):
        first is a torch.Tensor of shape (2,) for a single (x, y) point or
        (batch, 2) for a batch; second is a label or batch of labels, returned
        unchanged.
    - method (str): Scaling method. Options:
        - "unit": keep unit norm.
        - "random": scale by a deterministic pseudo-random factor in [0, 1),
          the fractional part of sin(31.337 * (x + y)) on the unit point.
        - "adversarial": scale by 0.5 where x < 0 and by 1.0 where x >= 0,
          for a single point and per row for a batch.

    Returns:
    - (torch.Tensor, labels): scaled points with the input's shape, and the labels.

    Raises:
    - ValueError: for an unknown method.
    """
    pts, labels = data
    norm = torch.norm(pts, dim=-1, keepdim=True)  # Euclidean norm per point
    unit_pts = pts / norm  # project onto the unit circle

    if method == "unit":
        return unit_pts, labels

    elif method == "random":
        hashed = torch.sin((unit_pts * 31.337).sum(dim=-1, keepdim=True))
        scale = hashed - hashed.floor()  # fractional part, in [0, 1)
        return unit_pts * scale, labels

    elif method == "adversarial":
        # x is component 0 of the last dim, so this works for (2,) and (batch, 2);
        # indexing row 0 with a Python `if` fails for batches.
        x = unit_pts[..., 0:1]
        scale = torch.where(x < 0, torch.full_like(x, 0.5), torch.ones_like(x))
        return unit_pts * scale, labels

    else:
        raise ValueError(f"Unknown method: {method}")

def transform_swiss_roll_batch(data, idx):
    """
    Overwrite column 2 of a (batch, 3) tensor with fresh random bits in {0, 1}.

    The tensor is mutated in place and returned; `idx` is unused. Marked as out of
    date (no label handling) and currently unused.
    """
    data[:, 2] = torch.randint(0, 2, (data.shape[0],), device=data.device)
    return data

def transform_swiss_roll(data, idx):
    """
    Set the z-coordinate of a single 3-D point to a fresh random bit in {0, 1}.

    Not memoized: each call draws a new bit. `data` is (coords, label); `coords` is
    mutated in place. `idx` is unused.

    Returns:
        (coords, label)
    """
    coords, label = data[0], data[1]
    coords[2] = torch.randint(0, 2, [1]).to(coords.device)
    return coords, label


def canonicalize_swiss_roll(data, idx, method=None):
    """
    Canonicalize the z-column of swiss-roll points in place.

    Args:
        data: (points, labels), with points of shape (batch, 3).
        idx: tensor whose float value is written into z for method 'label'.
        method (str): 'vanilla' (z := 0) or 'label' (z := idx.float()).

    Returns:
        (points, labels)

    Raises:
        ValueError: for any other method.
    """
    points, labels = data
    if method not in ('vanilla', 'label'):
        raise ValueError(f"Unknown method: {method}")
    points[:, 2] = 0.0 if method == 'vanilla' else idx.float()
    return points, labels


def init_weights(module):
    """
    Re-initialize an nn.Linear module in place; other modules are left untouched.

    Weights use Kaiming-uniform with a=sqrt(5). The bias, if present, is drawn
    uniformly from [-1/sqrt(fan_in), 1/sqrt(fan_in)].

    NOTE: another `init_weights` is defined further down in this file; at import
    time that later definition replaces this one.
    """
    if not isinstance(module, nn.Linear):
        return
    nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
    if module.bias is None:
        return
    fan_in = nn.init._calculate_fan_in_and_fan_out(module.weight)[0]
    limit = 1 / math.sqrt(fan_in)
    nn.init.uniform_(module.bias, -limit, limit)




def plot_decision_boundary(model, dataloader, savename, resolution=200, title=None):
    """
    Plot a 2-D classifier's decision regions on [-1.2, 1.2]^2, together with the
    dataloader's points (coloured by predicted label) and the unit circle.

    Args:
        model: classifier mapping an (M, 2) float tensor on the module-level
            `device` to (M, num_classes) scores.
        dataloader: iterable of (inputs, labels) batches; inputs presumably (B, 2).
        savename: path to save the figure to, or None to show it interactively.
        resolution (int): grid samples per axis.
        title (str | None): plot title; a default is used when None.

    Note: leaves `model` in eval mode.
    """
    print('savename', savename)
    model.eval()

    x_min, x_max = -1.2, 1.2
    y_min, y_max = -1.2, 1.2

    xx, yy = np.meshgrid(np.linspace(x_min, x_max, resolution),
                         np.linspace(y_min, y_max, resolution))
    grid_tensor = torch.tensor(np.c_[xx.ravel(), yy.ravel()], dtype=torch.float32)

    all_inputs = []
    all_preds = []
    with torch.no_grad():
        grid_preds = model(grid_tensor.to(device)).argmax(dim=1).cpu().numpy()
        for inputs, _ in dataloader:
            # Keep a CPU copy so .numpy() below works even for GPU-resident batches.
            all_inputs.append(inputs.detach().cpu())
            all_preds.append(model(inputs.to(device)).argmax(dim=1).cpu().numpy())
    zz = grid_preds.reshape(xx.shape)

    plt.figure()
    plt.contourf(xx, yy, zz, alpha=0.3, cmap=plt.cm.coolwarm)

    # Skip the scatter for an empty dataloader (np.concatenate([]) would raise).
    if all_inputs:
        X = torch.cat(all_inputs, 0).numpy()
        data_preds = np.concatenate(all_preds, 0)
        plt.scatter(X[:, 0], X[:, 1], c=data_preds, cmap=plt.cm.coolwarm, edgecolors='k')

    theta = np.linspace(0, 2 * np.pi, 1000)
    plt.plot(np.cos(theta), np.sin(theta), color='red', linewidth=2, label='Unit Circle')

    plt.xlabel("X-axis")
    plt.ylabel("Y-axis")
    if title is None:
        plt.title("Learned Decision Boundary (Points Colored by Predicted Labels)")
    else:
        plt.title(title)
    if savename is None:
        plt.show()
    else:
        plt.savefig(savename)
        plt.close()



def init_weights(module):
    """
    Re-initialize an nn.Linear module in place; other modules are left untouched.

    Weights use Kaiming-uniform with a=sqrt(5). The bias, if present, is drawn
    uniformly from [-1/sqrt(fan_in), 1/sqrt(fan_in)], or set to zero when
    fan_in == 0. This appears to mirror PyTorch's default nn.Linear
    initialization.
    """
    if isinstance(module, nn.Linear):
        nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
        if module.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
            # Guard zero-width layers, which would otherwise divide by zero.
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(module.bias, -bound, bound)

    # Add additional initializations if needed:
    # elif isinstance(module, nn.Conv2d): ...

def count_parameters(model):
    """Return the number of scalar entries across all parameters with requires_grad=True."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def normalize_quaternions(q):
    """Scale each row of a (batch, 4) tensor to unit Euclidean norm (broadcast division)."""
    return q / torch.norm(q, dim=1, keepdim=True)

def quaternion_loss(q1, q2):
    """
    Mean of arccos(|<q1, q2>|) over the batch, computed after normalizing both inputs.

    The absolute value makes q and -q count as the same rotation. The inner
    product is clamped to [0, 1] before arccos.
    NOTE(review): arccos has an infinite derivative at 1, so gradients can blow up
    when predictions match targets exactly.

    Args:
        q1, q2: (batch, 4) quaternion tensors.

    Returns:
        0-dim tensor.
    """
    unit1 = normalize_quaternions(q1)
    unit2 = normalize_quaternions(q2)
    cosines = (unit1 * unit2).sum(dim=1).abs().clamp(0, 1.0)
    return torch.arccos(cosines).mean()

def normalized_mse(a, b):
    """
    Relative error ||a - b|| / ||a||, using torch.linalg.norm's default norm.

    Despite the name, this is neither squared nor averaged.
    """
    diff_norm = torch.linalg.norm(a - b)
    return diff_norm / torch.linalg.norm(a)

def classification_loss(outputs, targets):
    """
    Return top-1 accuracy in percent (not a loss, despite the name).

    Args:
        outputs: (batch, num_classes) scores.
        targets: (batch,) integer class labels.
    """
    predicted = torch.max(outputs, 1)[1]
    n_correct = (predicted == targets).sum().item()
    return 100 * n_correct / targets.size(0)

def classification_loss_per_class(outputs, targets):
    """
    Return per-class top-1 accuracy in percent.

    Args:
        outputs: (batch, num_classes) scores.
        targets: (batch,) integer class labels.

    Returns:
        dict[int, float]: for each label present in `targets`, the percentage of
        its samples that were predicted correctly.
    """
    _, predicted = torch.max(outputs, 1)
    correct = predicted == targets

    acc_dict = {}
    for label in torch.unique(targets):
        # Labels come from `targets` itself, so the mask is never empty.
        mask = targets == label
        acc_dict[label.item()] = 100 * correct[mask].sum().item() / mask.sum().item()
    return acc_dict



def apply_canon_to_pair(canon_model, data):
    """Apply `canon_model` to the input of an (x, y) pair; return (canon_model(x), y)."""
    inputs, label = data
    return canon_model(inputs), label

def apply_canon_just_to_data(canon_model, data):
    """Apply `canon_model` to raw, unlabeled `data` and return the result."""
    canonicalized = canon_model(data)
    return canonicalized

# no longer used!!
def apply_fwd_pass(model, data, dataset_type="qm9", max_nodes=29, filter_mol=None):
    """
    Run `model` on `data`, preparing inputs according to the model's class name.

    Dispatch is by substring of str(type(model)), checked in this order:
    'Transformer', 'MLP', 'GNN', 'E3', 'Graphormer'; anything else gets `data` as-is.
    NOTE: the 'GNN' branch concatenates `pos` onto `data.x` in place.
    Flagged as no longer used.
    """
    model_kind = str(type(model))

    if 'Transformer' in model_kind:
        # x holds only node features; no edge information.
        x, mask = preprocess(data, dataset_type=dataset_type, max_nodes=max_nodes, filter_mol=filter_mol)
        return model(x, mask=mask)

    if 'MLP' in model_kind:
        x = preprocess_data_for_MLP(data)
        return model(x.view(x.shape[0], -1))

    if 'GNN' in model_kind:
        # TODO: data.z should probably be one-hot encoded.
        data.x = torch.cat([data.x, data.pos], dim=1)
        return model(data.x, data.edge_index, data.edge_attr, data.batch)

    if 'E3' in model_kind:
        data = preprocess(data, dataset_type=dataset_type, needs_edge_index=True)
        return model(data)

    if 'Graphormer' in model_kind:
        x, mask = preprocess(data, dataset_type=dataset_type, max_nodes=max_nodes, filter_mol=filter_mol)
        atoms = x[:, :, 0]
        pos = x[:, :, 1:]
        return model(atoms.long(), pos, mask)

    # Unknown model type; unclear whether `data` includes the label here.
    return model(data)

def ignore_qm9_weird_molecules(base_dataset):
    '''
    Filter out molecules that are not physical in the QM9 dataset.

    Not implemented yet (it may turn out not to be needed).

    Raises:
        NotImplementedError: always. Previously the exception object was
        *returned*, which callers could silently mistake for a dataset.
    '''
    raise NotImplementedError("This function is not implemented yet.")

def normalize_outputs(arr, means, stds):
    """Standardize ``arr`` column-wise as ``(arr - means) / stds``.

    As used in this file ``arr`` is (N, d) and ``means`` / ``stds`` are (1, d),
    broadcasting over the rows.
    """
    centered = arr - means
    return centered / stds

def unnormalize_outputs(arr, means, stds):
    """Invert ``normalize_outputs``: map ``arr`` back via ``arr * stds + means``.

    As used in this file ``arr`` is (N, d) and ``means`` / ``stds`` are (1, d).
    """
    rescaled = arr * stds
    return rescaled + means

def qm9_regression_mae(outputs, targets, means, stds, normalize=True):
    """Mean absolute error per output column, optionally in unnormalized units.

    Args:
        outputs: Predictions.
        targets: Ground truth, same shape as ``outputs``.
        means, stds: Normalization statistics. They are only used, and moved to
            ``outputs.device``, when ``normalize`` is True.
        normalize: If True, both tensors are mapped back with
            ``unnormalize_outputs`` before the error is taken.

    Returns:
        ``|outputs - targets|`` averaged over dimension 0.
    """
    if normalize:
        dev = outputs.device
        means, stds = means.to(dev), stds.to(dev)
        outputs = unnormalize_outputs(outputs, means, stds)
        targets = unnormalize_outputs(targets, means, stds)
    return (outputs - targets).abs().mean(0)


def qm7_regression_mae(outputs, targets, means, stds, from_irreps=False, target_type="dipole"):
    """
    MAE for QM7 regression, measured after undoing the label normalization.

    Scalar targets ("alpha_iso", "alpha_aniso") are unnormalized as
    ``x * stds + means``. Vector/tensor targets were only divided by a scalar
    std, so they are unnormalized as ``x * stds`` (no mean shift).

    Args:
        outputs: Network predictions (normalized).
        targets: True labels (normalized).
        means, stds: Statistics used for normalization; moved to ``outputs.device``.
        from_irreps: If True, both tensors are first converted out of the e3nn
            irreps basis with ``get_e3nn_matrix(..., to_irreps=False)``.
        target_type: One of "alpha_iso", "alpha_aniso", "dipole",
            "alpha_tensor", "quadrupole".

    Returns:
        Absolute error averaged over dimension 0.
    """
    dev = outputs.device
    means = means.to(dev)
    stds = stds.to(dev)

    if from_irreps:
        outputs = get_e3nn_matrix(outputs, outputs.device, to_irreps=False)
        targets = get_e3nn_matrix(targets, targets.device, to_irreps=False)

    pred = outputs * stds
    truth = targets * stds
    if target_type in ("alpha_iso", "alpha_aniso"):
        pred = pred + means
        truth = truth + means

    return torch.abs(pred - truth).mean(dim=0)


def is_rotation_matrix(R: torch.Tensor, atol=1e-5) -> bool:
    """
    Return True when ``R`` is a proper 3x3 rotation: ``R^T R == I`` and ``det(R) == 1``.

    Args:
        R: Candidate matrix. Any shape other than (3, 3) is rejected immediately.
        atol: Absolute tolerance passed to ``torch.allclose`` for both checks.
    Returns:
        A Python bool.
    """
    if R.shape != (3, 3):
        return False
    like_R = dict(dtype=R.dtype, device=R.device)
    gram = R.T @ R
    orthogonal = torch.allclose(gram, torch.eye(3, **like_R), atol=atol)
    unit_det = torch.allclose(torch.det(R), torch.tensor(1.0, **like_R), atol=atol)
    return orthogonal and unit_det

def batch_is_rotation_matrix(R_batch: torch.Tensor, atol=1e-5) -> bool:
    """
    Check whether *every* matrix in a batch is a proper rotation.

    This is an all-or-nothing check that returns a single Python bool for the
    whole batch. ``batch_is_rotation_matrix_elementwise`` gives a per-matrix answer.

    Args:
        R_batch: Tensor of shape (batch_size, 3, 3).
        atol: Absolute tolerance passed to ``torch.allclose``.
    Returns:
        True if all matrices satisfy ``R R^T == I`` and ``det(R) == 1``
        within tolerance, otherwise False.
    """
    I = torch.eye(3, dtype=R_batch.dtype, device=R_batch.device)
    is_orthogonal = torch.allclose(R_batch @ R_batch.transpose(1, 2), I[None, :, :], atol=atol)
    dets = torch.det(R_batch)
    has_det_one = torch.allclose(dets, torch.ones_like(dets), atol=atol)
    return is_orthogonal and has_det_one

# If you want to return per-matrix bool tensor, instead do:
def batch_is_rotation_matrix_elementwise(R_batch: torch.Tensor, atol=1e-5) -> torch.Tensor:
    """
    Per-matrix rotation check.

    Args:
        R_batch: Tensor of shape (batch_size, 3, 3).
        atol: Absolute tolerance (``rtol`` keeps the ``torch.isclose`` default).
    Returns:
        Bool tensor of shape (batch_size,). Entry i is True iff
        ``R_batch[i] @ R_batch[i].T`` is the identity and ``det(R_batch[i]) == 1``,
        both within tolerance.
    """
    I = torch.eye(3, dtype=R_batch.dtype, device=R_batch.device).expand(R_batch.shape[0], -1, -1)
    gram = R_batch @ R_batch.transpose(1, 2)
    # Reduce over each matrix separately. torch.allclose would collapse the
    # whole batch into a single bool, so one bad matrix would fail every entry.
    orthogonality_check = torch.isclose(gram, I, atol=atol).flatten(start_dim=1).all(dim=1)
    det_check = torch.isclose(torch.det(R_batch), torch.tensor(1.0, dtype=R_batch.dtype, device=R_batch.device), atol=atol)
    return orthogonality_check & det_check

def average_rotation_angle(rot1: torch.Tensor, rot2: torch.Tensor, to_float=True) -> "float | torch.Tensor":
    """
    Mean geodesic angle, in degrees, between corresponding rotation matrices.

    For each pair, the relative rotation ``rot1^T rot2`` has angle
    ``theta = arccos((trace - 1) / 2)``. The angles are converted to degrees
    and averaged over the batch.

    Args:
        rot1: Tensor of shape (B, 3, 3)
        rot2: Tensor of shape (B, 3, 3)
        to_float: If True return a Python float, otherwise a 0-d tensor
            (which keeps the autograd graph).

    Returns:
        Average angle in degrees.

    Raises:
        AssertionError: if the inputs are not matching B x 3 x 3 stacks.
    """
    assert rot1.shape == rot2.shape and rot1.shape[1:] == (3, 3), "Inputs must be B x 3 x 3 rotation matrices"

    R_rel = torch.matmul(rot1.transpose(1, 2), rot2)  # shape: (B, 3, 3)
    trace = R_rel[:, 0, 0] + R_rel[:, 1, 1] + R_rel[:, 2, 2]  # shape: (B,)
    # Clamp because round-off can push the cosine slightly outside [-1, 1],
    # where acos would return NaN.
    cos_theta = torch.clamp((trace - 1) / 2, -1.0, 1.0)
    angles = torch.acos(cos_theta)  # radians, shape: (B,)
    # Radians -> degrees is a factor of 180/pi. The previous 180/(2*pi) halved every angle.
    degree_angles = torch.rad2deg(angles)

    mn = degree_angles.mean()
    return mn.item() if to_float else mn

def rotation_error(outputs, targets, check_rot=False, to_float=True):
    """Average angular error between predicted and target rotations.

    Args:
        outputs: Flattened rotation matrices of shape (batch, 9), e.g. the
            output of a Gram-Schmidt head.
        targets: Target rotations, same layout as ``outputs``.
        check_rot: If True, print whether each whole batch passes
            ``batch_is_rotation_matrix``.
        to_float: Forwarded to ``average_rotation_angle``.

    Returns:
        The result of ``average_rotation_angle`` on the reshaped (batch, 3, 3) stacks.
    """
    n_batch, width = outputs.shape
    if width != 9:
        print('Error: outputs is shape', outputs.shape, 'but was expecting batch x 9')
    pred_rot = outputs.reshape(n_batch, 3, 3)
    true_rot = targets.reshape(n_batch, 3, 3)

    if check_rot:
        for message, mats in (('outputs are rotations?', pred_rot),
                              ('targets are rotations?', true_rot)):
            print(message, batch_is_rotation_matrix(mats))

    return average_rotation_angle(pred_rot, true_rot, to_float=to_float)


def check_equivariance(train_loader, model):
    """Print how much ``model``'s output changes under a random rotation (first 11 batches).

    Each ``(data, expected)`` batch is moved to ``device``. The model is run via
    ``apply_fwd_pass`` on the batch and on ``rotate_molecule(data, q)`` for a
    random quaternion ``q``. Then ``torch.linalg.norm`` of the output difference
    is printed. ``expected`` is moved to the device but otherwise unused.

    NOTE(review): comparing raw outputs measures *invariance*. For equivariant
    outputs, the rotated input's prediction should be compared with the rotated prediction.
    """
    for step, (data, expected) in enumerate(train_loader):
        data, expected = data.to(device), expected.to(device)
        outputs = apply_fwd_pass(model, data)

        quaternion = generate_random_quaternion()
        rot_data = rotate_molecule(data, quaternion, mode='quaternion')
        outputs_on_rot = apply_fwd_pass(model, rot_data)

        print(f'Equivariance Difference: {torch.linalg.norm(outputs - outputs_on_rot)}')
        if step >= 10:
            break

def normalize_outputs(arr, means, stds):
    """Return ``(arr - means) / stds``.

    Shapes as used here: ``arr`` is (N, d); ``means`` and ``stds`` are (1, d).
    """
    shifted = arr - means
    return shifted / stds

def unnormalize_outputs(arr, means, stds):
    """Return ``arr * stds + means``, the inverse of ``normalize_outputs``.

    Shapes as used here: ``arr`` is (N, d); ``means`` and ``stds`` are (1, d).
    """
    scaled = arr * stds
    return scaled + means
# NOTE: the triple-quoted block below is an inert module-level string (dead code).
# It preserves a legacy qm9_regression_mae that sliced column 7 (U0) out of the
# statistics. It is never executed; the active qm9_regression_mae is the function
# defined earlier in this file.
'''
def qm9_regression_mae(outputs, targets, means, stds):
    means = means.to(outputs.device)[:,7:8]
    stds = stds.to(outputs.device)[:,7:8]
    unnorm_outputs = unnormalize_outputs(outputs, means, stds)
    unnorm_targets = unnormalize_outputs(targets, means, stds)
    mae_per_batch = torch.abs(unnorm_outputs - unnorm_targets)
    return torch.mean(mae_per_batch, 0)
'''

def smush_for_MLP(tensor_list):
    """
    Flattens and concatenates a list/tuple of tensors along the batch dimension,
    preparing for input to an MLP.

    Args:
        tensor_list (list or tuple of torch.Tensor): Each tensor must have shape (B, ...),
            and all tensors must share the same batch size B.

    Returns:
        torch.Tensor of shape (B, D), where D is the sum of flattened dimensions.
    """
    if not isinstance(tensor_list, (list, tuple)):
        raise TypeError("Input must be a list or tuple of tensors.")

    batch_size = tensor_list[0].shape[0]
    flattened = []

    for t in tensor_list:
        if not isinstance(t, torch.Tensor):
            raise TypeError("All elements must be torch.Tensors.")
        if t.shape[0] != batch_size:
            raise ValueError("All tensors must have the same batch size in dim 0.")
        flattened.append(t.reshape(batch_size, -1))

    return torch.cat(flattened, dim=1)

# Data Preparation
def preprocess(data, dataset_type="qm9", max_nodes=29, filter_mol='None', needs_edge_index=False):
    """Convert a batched graph into the input format used for ``dataset_type``.

    Most branches return a zero-padded dense batch and its node mask from
    ``to_dense_batch``. The graph branch ("qm9_e3nn", or "qm9_atomic" with
    ``needs_edge_index``) instead returns ``data`` itself with a radius-5.0
    ``edge_index`` attached.

    Args:
        data: Batched graph. Reads ``pos`` and ``batch``, plus ``z`` or
            ``atomic_numbers``/``tags`` depending on the branch.
        dataset_type: One of "qm9_e3nn", "qm9_point_clouds", "qm9", "local_qm9",
            "qm9_atomic", "qm7", "md17", "oc20", "toycircle", "modelnet".
        max_nodes: Padding size for the QM9/QM7 variants. md17 and oc20 pass no
            cap; modelnet always uses 1024.
        filter_mol: For "oc20", keep only atoms whose ``tags`` equal this value.
            The string 'None' disables filtering.
        needs_edge_index: With "qm9_atomic", take the graph branch.

    Returns:
        ``(features, mask)`` for dense branches. Features are the atomic number
        followed by the 3 coordinates, or coordinates only for point clouds.
        ``data`` itself for the graph and "toycircle" branches.

    Raises:
        ValueError: for an unsupported ``dataset_type``.
    """
    wants_graph = dataset_type == "qm9_e3nn" or (dataset_type == "qm9_atomic" and needs_edge_index)
    if wants_graph:
        import e3tools
        data.edge_index = e3tools.radius_graph(data.pos, 5.0, data.batch)
        return data

    if dataset_type in ("qm9_point_clouds", "modelnet"):
        cap = 1024 if dataset_type == "modelnet" else max_nodes
        points, point_mask = to_dense_batch(data.pos, batch=data.batch, max_num_nodes=cap)
        return points, point_mask

    if dataset_type in ("qm9", "local_qm9", "qm9_atomic", "qm7", "md17"):
        atom_feats = torch.cat((data.z[:, None], data.pos), dim=1)
        # md17 is not capped: None is to_dense_batch's own default.
        cap = None if dataset_type == "md17" else max_nodes
        padded, node_mask = to_dense_batch(atom_feats, batch=data.batch, max_num_nodes=cap)
        return padded, node_mask

    if dataset_type == "oc20":
        pos, atom_types, batch = data.pos, data.atomic_numbers, data.batch
        if filter_mol != 'None':
            keep = data.tags == filter_mol
            pos, atom_types, batch = pos[keep], atom_types[keep], batch[keep]
        atom_feats = torch.cat((atom_types[:, None], pos), dim=1)
        padded, node_mask = to_dense_batch(atom_feats, batch=batch)
        return padded, node_mask

    if dataset_type == "toycircle":
        return data

    raise ValueError(f"unsupported dataset type: {dataset_type}")

def pad_features(data, max_nodes, input_dim):
    """
    Zero-pad ``data.x`` along the node axis to exactly ``max_nodes`` rows.

    Args:
        data: Object whose ``x`` is a (num_nodes, input_dim) feature matrix.
        max_nodes: Number of rows in the output. Must be >= num_nodes.
        input_dim: Number of columns in the output. Must match ``data.x``.

    Returns:
        Tensor of shape (max_nodes, input_dim), in the default float dtype, on
        ``data.x``'s device. Rows past ``num_nodes`` are zero.

    NOTE(review): intended for a single graph. A batched ``data.x`` (rows of
    all graphs stacked) is not handled yet.
    """
    num_nodes = data.x.size(0)  # Number of nodes in the graph
    # Allocate on data.x's device so GPU inputs don't fail on the copy below.
    padded_features = torch.zeros((max_nodes, input_dim), device=data.x.device)
    padded_features[:num_nodes, :] = data.x
    return padded_features

# Custom transform to prepare data for the MLP
def preprocess_data_for_MLP(data, max_nodes=29, input_dim=11):
    """
    Turn a QM9 sample into a fixed-size (max_nodes, input_dim) feature matrix.

    Only ``data.x`` is used, zero-padded by ``pad_features``. Edge information is dropped.
    """
    return pad_features(data, max_nodes=max_nodes, input_dim=input_dim)


def new_empty_dict(existing_dict, value=[]):
    """Build a dict with the same keys as ``existing_dict``, each mapped to ``value``.

    If ``value`` has a ``.copy()`` method (lists, dicts, sets, numpy arrays, ...),
    every key gets its own shallow copy. The entries, and the shared default
    list, are therefore never aliased. Otherwise ``value`` itself is stored.
    """
    def _fresh():
        return value.copy() if hasattr(value, 'copy') else value

    return {key: _fresh() for key, _ in existing_dict.items()}

def simple_loss_plot(dct, savename, show=False, xlabel='Epoch', ylabel='Loss'):
    """Plot one curve per key of ``dct`` and save the figure to ``savename``.

    Values are per-epoch sequences, and empty ones are skipped. For a list of
    tensors, single-element tensors are converted with ``.item()``. Multi-element
    tensors are reduced to their element at index 7; a warning is printed and the
    key skipped when they have 7 or fewer elements. An exception while plotting is
    printed, and whatever was drawn so far is still saved.

    Args:
        dct: Mapping from curve label to values.
        savename: Output path passed to ``plt.savefig``.
        show: Also display the figure with ``plt.show()``.
        xlabel, ylabel: Axis labels.
    """
    fig = plt.figure()
    val = None  # so the error handler below can always print it
    try:
        for ky, val in dct.items():
            if len(val) == 0:
                continue
            if isinstance(val, list) and isinstance(val[0], torch.Tensor):
                # Check if tensor has more than one element
                if val[0].numel() > 1:
                    # For multi-output case, choose the 7th if it exists
                    # This is how it was before, probably should change this
                    if val[0].numel() > 7:
                        val = [elt.cpu()[7].item() for elt in val]
                    else:
                        print(f"⚠️ Skipping key '{ky}' – tensor too short for index 7")
                        continue
                else:
                    val = [elt.item() for elt in val]
            plt.plot(val, label=ky)
    except Exception as e:
        print(val)
        print('plotting failed', e)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(savename)
    if show:
        plt.show()
    # Release the figure. This helper runs once per metric, and pyplot keeps
    # every unclosed figure alive for the life of the process.
    plt.close(fig)


def plot_results(results, save_prefix, aux_criteria, show=False):
    """Save train/val/test curves for the loss and for each auxiliary criterion.

    Args:
        results: Dict holding 'train_losses', 'val_losses', 'test_losses' and,
            for every key ``k`` of ``aux_criteria``, 'train_k', 'val_k', 'test_k'.
        save_prefix: Path prefix. Figures go to '{save_prefix}_losses.pdf' and
            '{save_prefix}_{k}.pdf'.
        aux_criteria: Mapping whose keys name the auxiliary metrics (values unused).
        show: Forwarded to ``simple_loss_plot`` to also display each figure.
            The parameter used to be accepted but silently ignored.
    """
    splits = ('train', 'val', 'test')
    loss_dict = {split: results[f'{split}_losses'] for split in splits}
    simple_loss_plot(loss_dict, savename=f'{save_prefix}_losses.pdf', show=show)

    for ky in aux_criteria.keys():
        metric_dict = {split: results[f'{split}_{ky}'] for split in splits}
        simple_loss_plot(metric_dict, savename=f'{save_prefix}_{ky}.pdf', ylabel=ky, show=show)

def get_timestamp_for_filename():
    """
    Return the current local date/time formatted for use in file names.

    Returns:
        str: ``YYYY-MM-DD_HH-MM-SS`` (no spaces or colons).
    """
    fmt = "%Y-%m-%d_%H-%M-%S"
    now = datetime.now()
    return now.strftime(fmt)

def get_qm9_ind(target = "U0"):
    """Return the column index of a named QM9 target (as used to index ``data.y``).

    Raises:
        ValueError: if ``target`` is not one of the 19 known names.
    """
    target_names = [
    "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve",
    "U0", "U", "H", "G", "Cv",
    "U0_atom", "U_atom", "H_atom", "G_atom", "A", "B", "C",
    ]
    if target not in target_names:
        raise ValueError(f"Unknown target: {target}. Available targets are: {target_names}")
    return target_names.index(target)
    
def get_qm9_means_stds(qm9dataset,target="All"):
    """Column means/stds of the QM9 regression targets in ``qm9dataset.y``.

    Args:
        qm9dataset: Object exposing ``y`` as a 2-D (N, n_targets) tensor;
            presumably n_targets == 19, as for PyG ``QM9``.
        target: "All" for every column, or one name accepted by ``get_qm9_ind``.

    Returns:
        (means, stds): shape (1, 19) for "All" or (1, 1) for a single target.
        If the statistics cannot be computed (e.g. the object has no ``y``), a
        warning is printed and zero tensors of those shapes are returned.
        NOTE(review): zero stds will divide by zero if used for normalization.

    Raises:
        ValueError: for an unknown ``target`` name. It was previously swallowed
            and turned into all-zero statistics.
    """
    # Resolve the name up front so a typo fails loudly.
    prop_idx = None if target == "All" else get_qm9_ind(target)
    try:
        if prop_idx is None:
            means = qm9dataset.y.mean(dim=0, keepdim=True)
            stds = qm9dataset.y.std(dim=0, keepdim=True)
        else:
            column = qm9dataset.y[:, prop_idx]
            means = column.mean(dim=0, keepdim=True).view(1, 1)
            stds = column.std(dim=0, keepdim=True).view(1, 1)
    except Exception as e:  # best-effort fallback, but no longer a bare except
        print(f"[Warning] Failed to compute stats: {e}")
        width = 19 if prop_idx is None else 1
        means = torch.zeros((1, width))
        stds = torch.zeros((1, width))
    return means, stds

def get_qm9_atomic_means_stds(qm9dataset, target="All"):
    """Target means/stds over the characterized molecules of an Atomic-QM9 dataset.

    Args:
        qm9dataset: Indexable dataset with an ``included_idxs`` attribute. Each
            item has a ``y`` tensor; 1-D rows are promoted to shape (1, n).
        target: "All" for every column, or one name accepted by ``get_qm9_ind``.

    Returns:
        (means, stds): shape (1, n_cols) for "All" or (1, 1) for a single target.
        If the statistics cannot be computed (e.g. no included indices), a
        warning is printed and zero tensors are returned. n_cols is the label
        width when the label matrix was built, else 19.

    Raises:
        ValueError: for an unknown ``target`` name.
    """
    included_idxs = qm9dataset.included_idxs
    prop_idx = None if target == "All" else get_qm9_ind(target)
    n_cols = 19  # fallback width if the label matrix is never built
    try:
        rows = []
        for ind in included_idxs:
            y = qm9dataset[ind].y
            if y.dim() == 1:
                y = y.view(1, -1)
            rows.append(y)
        all_y = torch.cat(rows, dim=0)  # shape: [N, 19] or [N, 20]
        n_cols = all_y.shape[1]

        if prop_idx is None:
            means = all_y.mean(dim=0, keepdim=True)
            stds = all_y.std(dim=0, keepdim=True)
        else:
            column = all_y[:, prop_idx]
            means = column.mean(dim=0, keepdim=True).view(1, 1)
            stds = column.std(dim=0, keepdim=True).view(1, 1)
    except Exception as e:
        print(f"[Warning] Failed to compute stats: {e}")
        # The old fallback read `all_y.shape` even when `all_y` was still a
        # Python list, which crashed inside the handler itself.
        width = n_cols if prop_idx is None else 1
        means = torch.zeros((1, width))
        stds = torch.zeros((1, width))

    return means, stds

def get_qm7_means_stds(qm7dataset, target="dipole"):
    """
    Mean and std for one QM7 target, computed over the whole dataset.

    - Scalar targets ("alpha_iso", "alpha_aniso"): statistics per component,
      returned with shape (1, D).
    - Vector/tensor targets ("dipole", "alpha_tensor", "quadrupole"): each
      sample is reduced to the Euclidean norm of all its components. The result
      is one scalar mean/std (shape (1,)), so scaling does not mix directions.

    Raises:
        ValueError: for any other ``target``.
    """
    supported = ("alpha_iso", "alpha_aniso", "alpha_tensor", "dipole", "quadrupole")
    if target not in supported:
        raise ValueError(f"Unsupported target: {target}")

    uses_norm = target in ("dipole", "alpha_tensor", "quadrupole")

    def _summarize(sample):
        raw = getattr(sample, target)
        if uses_norm:
            return raw.view(-1).norm()
        # Bring scalar targets to a 1-D row: 0-d -> (1,), (1, D) -> (D,).
        if raw.dim() == 0:
            return raw.unsqueeze(0)
        if raw.dim() == 2 and raw.shape[0] == 1:
            return raw.squeeze(0)
        return raw

    stacked = torch.stack([_summarize(sample) for sample in qm7dataset], dim=0)
    return stacked.mean(dim=0, keepdim=True), stacked.std(dim=0, keepdim=True)



def get_md17_means_stds(md17dataset):
    """Mean and std of ``md17dataset.data.energy`` over dim 0 (kept as a size-1 dim)."""
    energies = md17dataset.data.energy
    return energies.mean(dim=0, keepdim=True), energies.std(dim=0, keepdim=True)

def get_oc20_means_stds(oc20dataset):
    """Mean and std of ``oc20dataset.energy`` over dim 0 (kept as a size-1 dim)."""
    energies = oc20dataset.energy
    return energies.mean(dim=0, keepdim=True), energies.std(dim=0, keepdim=True)

def get_label_g_from_gx(data, transform_return_g_operator):
    """Transform ``data`` and use the applied group element as the label.

    ``transform_return_g_operator(data, -1)`` must return ``(transformed_data, g)``,
    where ``g`` provides ``as_matrix()`` (e.g. a scipy ``Rotation``). The index -1
    is passed because no sample index is available here, so memoized transforms
    are not supported; only randomized ones are.

    Returns:
        ``(transformed_data, label)``. The label is ``g``'s matrix flattened to
        1-D, as float32 on ``data.pos.device``.
    """
    transformed, g = transform_return_g_operator(data, -1)
    label = torch.tensor(g.as_matrix(), dtype=torch.float32, device=data.pos.device)
    return transformed, label.reshape(-1)

def get_label_from_qm9_element(data, means, stds, target = "All"):
    """Return ``(data, label)``, where the label is the standardized QM9 target.

    Uses the first row of ``data.y``: all columns for target "All", otherwise the
    single column given by ``get_qm9_ind(target)``. Standardized with
    ``(value - means[0, :]) / stds[0, :]``.
    """
    raw = data.y[0, :] if target == "All" else data.y[0, get_qm9_ind(target)]
    label = (raw - means[0, :]) / stds[0, :]
    return data, label

def get_label_from_modelnet_element(data):
    """Return ``(data, label)`` with the label being ``data.y`` with size-1 dims squeezed out."""
    return data, data.y.squeeze()


def get_label_from_qm9_element_atomic(data, means, stds, target="All"):
    """Return ``(data, label)`` with the Atomic-QM9 target standardized.

    For "All", the whole ``data.y`` is standardized with ``means``/``stds``.
    Otherwise the column ``get_qm9_ind(target)`` is taken and standardized with
    ``means[0, :]`` / ``stds[0, :]``.
    """
    if target == "All":
        return data, (data.y - means) / stds
    column = data.y[:, get_qm9_ind(target)]
    return data, (column - means[0, :]) / stds[0, :]

def get_label_from_qm7_element(data, means, stds, target="dipole", to_irreps=False):
    """
    Return ``(data, label)`` for a QM7 sample with equivariance-safe normalization.

    Scalar targets ("alpha_iso", "alpha_aniso") become ``((value - means) / stds)[0, :]``.
    Vector/tensor targets are flattened and divided by ``stds`` only (no mean
    shift), so they are rescaled without being translated. For "alpha_tensor" and
    "quadrupole" with ``to_irreps=True``, the label is then converted with
    ``get_e3nn_matrix(..., to_irreps=True)``.

    Args:
        data: a torch_geometric.data.Data object
        means, stds: statistics computed from the dataset
        target: "alpha_iso", "alpha_aniso", "alpha_tensor", "dipole", "quadrupole"
        to_irreps: whether to convert tensor labels to e3nn irreps

    Raises:
        ValueError: for an unsupported ``target``.
    """
    valid = ("alpha_iso", "alpha_aniso", "alpha_tensor", "dipole", "quadrupole")
    if target not in valid:
        raise ValueError(f"Unsupported target: {target}")

    raw = getattr(data, target)
    if target in ("alpha_iso", "alpha_aniso"):
        return data, ((raw - means) / stds)[0, :]

    label = raw.view(-1) / stds
    if to_irreps and target in ("alpha_tensor", "quadrupole"):
        label = get_e3nn_matrix(label, data.pos.device, to_irreps=True)
    return data, label


def get_label_from_md17_element(data, means, stds):
    """Return ``(data, (data.energy - means) / stds)``."""
    return data, (data.energy - means) / stds

def get_label_from_oc20_element(data, idx):
    """Return the container unchanged together with the energy of ``data[idx]``."""
    sample = data[idx]
    return data, sample.energy

def get_label_from_toy_circle_element(data):
    """Split a toy-circle sample ``(features, label, ...)`` into ``(features, label)``."""
    label, features = data[1], data[0]
    return features, label


def get_label_from_swiss_roll_element(data):
    """Split a swiss-roll sample ``(features, label, ...)`` into ``(features, label)``."""
    label, features = data[1], data[0]
    return features, label

###TODO maybe combine these rotation functions into one with equivariant argument
def memoized_rotation(data, idx, fixed_quaternions):
    """Rotate ``data`` by the quaternion pre-assigned to sample ``idx``.

    Sample ``idx`` always receives ``fixed_quaternions[idx]``, so repeated
    visits see the same orientation. Indices past the end of the table
    (``idx >= len(fixed_quaternions)``) fall back to a randomly chosen entry,
    using Python's ``random``. Negative indices wrap via ``%`` (e.g. -1 selects
    the last quaternion).
    """
    qlen = len(fixed_quaternions)
    # Valid positions are 0..qlen-1. The old `idx > qlen` test let idx == qlen
    # through, and `% qlen` then silently mapped it onto quaternion 0.
    if idx >= qlen:
        idx = random.randint(0, qlen - 1)
    new_data = rotate_molecule(data, fixed_quaternions[idx % qlen])
    return new_data

def randomized_rotation(data, idx, fixed_quaternions, return_g=False):
    """Rotate ``data`` by a quaternion drawn uniformly (via torch's RNG) from ``fixed_quaternions``.

    ``idx`` is accepted for signature compatibility with ``memoized_rotation``
    and is ignored.

    Returns:
        The rotated data. With ``return_g=True``, ``(rotated_data, g)``, where
        ``g`` is the scipy ``Rotation`` built from the chosen quaternion.
    """
    pick = int(torch.randint(len(fixed_quaternions), (1,)))
    quat = fixed_quaternions[pick]
    rotated = rotate_molecule(data, quat)
    return (rotated, R.from_quat(quat)) if return_g else rotated

def randomized_rotation_qm7(data, idx, fixed_quaternions):
    """Apply ``rotate_qm7_quantities`` with a quaternion drawn uniformly from ``fixed_quaternions``.

    ``idx`` is accepted for signature compatibility and ignored.
    """
    pick = int(torch.randint(len(fixed_quaternions), (1,)))
    return rotate_qm7_quantities(data, fixed_quaternions[pick])


def center_coords(coords):
    """Translate a point set so its centroid sits at the origin.

    Args:
        coords: (n, d) tensor of n points. Any dimension d is accepted; it is
            typically 3 here.

    Returns:
        ``coords`` minus the per-column mean, same shape.
    """
    # keepdim=True broadcasts over any feature dimension. The previous
    # view(1, 3) limited the helper to 3-D points.
    center = torch.sum(coords, 0, keepdim=True) / coords.shape[0]
    return coords - center

def compute_PCA(coords, center=True):
    # n x 3
    if center:
        coords = center_coords(coords)
    XXT = torch.matmul(coords.permute(1,0), coords)
    U,S,V = torch.svd(XXT)
    return U, S, V

def is_planar(coords, cutoff=1e-3):
    """Flag (near-)planar point sets.

    The points are centered and their scatter matrix is decomposed by
    ``compute_PCA``. The set counts as planar when the smallest singular value
    falls below ``cutoff``.

    Returns:
        ``(planar, S)``: a Python bool and the singular values.
    """
    _, S, _ = compute_PCA(coords)
    return bool(torch.min(S) < cutoff), S

def to_rdkit_molecule(data: torch_geometric.data.Data) -> Chem.Mol:
    """
    Convert a PyTorch Geometric graph to an RDKit molecule.

    Atoms come from ``data["z"]`` (atomic numbers), and one conformer is built
    from ``data["pos"]``. Every edge ``(a, b)`` of ``data.edge_index`` with
    ``a < b`` is added as a SINGLE bond; real bond orders are not recovered.
    NOTE(review): this assumes each undirected edge is stored in both
    directions. An edge stored only as ``a > b`` is skipped.

    Args:
        data: A PyTorch Geometric Data object containing:
            - z: atomic numbers, one per atom
            - pos: Node position coordinates (N x 3 tensor)
            - edge_index: (2, E) connectivity used to add bonds

    Returns:
        mol: RDKit editable molecule (``Chem.RWMol``) with one conformer.
    """
    # Create empty editable mol object
    mol = Chem.RWMol()

    # Add atoms; remember RDKit's index for each node so edges can be mapped.
    atomic_numbers = data["z"]
    atom_idxs = []
    for atomic_num in atomic_numbers:
        atom = Chem.Atom(int(atomic_num))
        idx = mol.AddAtom(atom)
        atom_idxs.append(idx)

    # Convert to non-editable molecule
    mol = mol.GetMol()

    # Create a conformer to store 3D positions
    conf = Chem.Conformer(mol.GetNumAtoms())
    for i, position in enumerate(data["pos"]):
        x, y, z = position.tolist()
        conf.SetAtomPosition(i, (float(x), float(y), float(z)))

    mol.AddConformer(conf)

    # Re-open for editing and add each graph edge as a single bond.
    # `a < b` keeps one direction of each edge pair.
    mol = Chem.RWMol(mol)
    num_edges = int(data.edge_index.shape[1])
    for e in range(num_edges):
        a, b = data.edge_index[0,e], data.edge_index[1,e]
        if a<b:
            mol.AddBond(atom_idxs[a], atom_idxs[b], Chem.BondType.SINGLE)


    return mol

def qm9_species_to_atomic_numbers(z):
    """Map QM9 species indices 0..4 to atomic numbers (H, C, N, O, F)."""
    species_table = np.asarray([1, 6, 7, 8, 9])
    return species_table[z]

def visualize_qm9_atomic_data(data, savename=None):
    """Render an Atomic-QM9 molecule with py3Dmol and report its formula.

    ``data.z`` holds QM9 species indices, which are mapped to atomic numbers.
    An XYZ string is built and converted to a mol block with Open Babel. RDKit
    parses that block to compute the molecular formula.

    Args:
        data: Object with ``z`` (species indices) and ``pos`` (N x 3 coordinates).
        savename: If given, write an HTML page (formula header + viewer) to this
            path. Otherwise print the formula and show the viewer.

    Raises:
        ValueError: if RDKit cannot parse the mol block produced by Open Babel.
    """
    from openbabel import pybel
    # rdMolDescriptors is not among this file's top-level imports; import it
    # here so the formula step works.
    from rdkit.Chem import rdMolDescriptors

    # note the QM9 atomic dataset has species instead of the raw atomic numbers
    atom_numbers = qm9_species_to_atomic_numbers(data.z)
    periodic_table = Chem.GetPeriodicTable()
    xyz_lines = [f"{len(data.z)}\n\n"]
    for z_val, coord in zip(atom_numbers, data.pos):
        sym = periodic_table.GetElementSymbol(int(z_val))
        xyz_lines.append(f"{sym} {coord[0]:.4f} {coord[1]:.4f} {coord[2]:.4f}\n")
    xyz_str = "".join(xyz_lines)

    mol_ob = pybel.readstring("xyz", xyz_str)
    mol_block_str = mol_ob.write("mol")
    mol = Chem.MolFromMolBlock(mol_block_str, removeHs=False)
    if mol is None:
        raise ValueError("RDKit could not parse the mol block generated by Open Babel.")

    formula = rdMolDescriptors.CalcMolFormula(mol)

    view = py3Dmol.view(
        data=Chem.MolToMolBlock(mol),
        style={"stick": {}, "sphere": {"scale": 0.3}}
    )
    view.zoomTo()

    if savename is not None:
        html = f"<h3>Molecular formula: {formula}</h3>\n" + view._make_html()
        with open(savename, "w") as f:
            f.write(html)
    else:
        print(f"Molecular formula: {formula}")
        view.show()

def visualize_qm9_data(data, savename=None):
    """Show a QM9 molecule with py3Dmol and optionally save the viewer as HTML.

    The molecule is built with ``to_rdkit_molecule``, so graph edges become bonds.

    Args:
        data: PyG ``Data`` accepted by ``to_rdkit_molecule``.
        savename: If given, the standalone viewer HTML is also written to this path.

    The previous save path called ``view.png().decode('base64')``. ``str.decode``
    does not exist in Python 3, and py3Dmol renders in the browser, so
    ``png()`` does not hand image bytes back to Python (TODO confirm for the
    installed py3Dmol version). Saving the HTML matches ``visualize_qm9_atomic_data``.
    """
    mol = to_rdkit_molecule(data)
    view = py3Dmol.view(
        data=Chem.MolToMolBlock(mol),  # Convert the RDKit molecule for py3Dmol
        style={"stick": {}, "sphere": {"scale": 0.3}}
    )
    view.zoomTo()
    view.show()

    if savename is not None:
        with open(savename, "w") as f:
            f.write(view._make_html())

def get_neighborhood_data(data, node_idx, idx_to_use=None, verbose=False, center=True):
    """Extract the 1-hop neighborhood of ``node_idx`` as a standalone ``Data`` graph.

    The neighborhood is every node that shares an edge with ``node_idx``, plus
    ``node_idx`` itself (an isolated node yields an empty graph). All edges among
    those nodes are kept and relabelled to match the row order of the returned
    ``x`` / ``pos`` / ``z``.

    Args:
        data: PyG ``Data`` with ``x``, ``edge_index``, ``edge_attr``, ``pos``, ``z``,
            ``smiles``, ``name``, ``idx`` and ``y`` (QM9-style).
        node_idx: Index of the center node.
        idx_to_use: Value stored as ``idx`` on the result; defaults to ``data.idx``.
        verbose: Print the neighbor ids and the sub-graph's edge_index.
        center: Subtract the centroid of the neighborhood's positions.

    Returns:
        A new ``torch_geometric.data.Data`` holding the neighborhood; ``y`` is copied unchanged.
    """
    # Nodes touching node_idx in either edge direction (unique() sorts them).
    incident = (data.edge_index[0] == node_idx) | (data.edge_index[1] == node_idx)
    neighbor_nodes = data.edge_index[:, incident].unique()

    # Induced subgraph. Passing edge_attr lets PyG filter it with the same edge
    # mask, so sub_edge_attr stays aligned with sub_edge_index even when two
    # neighbors are connected to each other. The previous version filtered
    # edge_attr by the *incident* mask, and its .squeeze() turned a single edge
    # into a 1-D row. It also built a CPU scratch mask, which broke when indexed
    # by CUDA node ids.
    sub_edge_index, sub_edge_attr = subgraph(
        neighbor_nodes, data.edge_index, edge_attr=data.edge_attr, relabel_nodes=True
    )

    sub_x = data.x[neighbor_nodes]
    sub_pos = data.pos[neighbor_nodes, :]
    if center:
        sub_pos = center_coords(sub_pos)
    if idx_to_use is None:
        idx_to_use = data.idx

    sub_data = torch_geometric.data.Data(
        x=sub_x,
        edge_index=sub_edge_index,
        edge_attr=sub_edge_attr,
        pos=sub_pos,
        z=data.z[neighbor_nodes],
        smiles=data.smiles,
        name=data.name,
        idx=idx_to_use,
        y=data.y  # Keep the same target label
    )

    if verbose:
        print("Neighbors:", neighbor_nodes.tolist())
        print("Subgraph Edges:\n", sub_edge_index)

    return sub_data

def extract_neighborhood_graphs(data, k=5):
    """Extract the ``k``-hop neighborhood subgraph of every node of a molecule graph.

    For each node, ``torch_geometric.utils.k_hop_subgraph`` collects the nodes
    within ``k`` hops. The resulting ``Data`` keeps their ``x`` rows, the induced
    edges (relabelled to 0..n_sub-1), and ``data.y`` unchanged.

    The previous version called ``subgraph`` on the center node alone, so no
    neighbors were ever included and ``k`` was ignored. It also indexed ``data.x``
    with the returned *edge* mask, which fails whenever the edge and node counts differ.

    Returns:
        List of ``torch_geometric.data.Data``, one per node, in node order.
    """
    graphs = []
    num_nodes = data.x.size(0)

    for node_idx in range(num_nodes):
        subset, edge_index, _, _ = torch_geometric.utils.k_hop_subgraph(
            node_idx, k, data.edge_index, relabel_nodes=True, num_nodes=num_nodes
        )
        sub_data = torch_geometric.data.Data(
            x=data.x[subset],
            edge_index=edge_index,
            y=data.y  # Keep the same target label
        )
        graphs.append(sub_data)

    return graphs


def get_neighborhood_data_atomic_datasets(data, node_idx, idx_to_use=None, verbose=False, center=True):
    """1-hop neighborhood of ``node_idx`` for datasets without ``x``/``edge_attr``.

    Like ``get_neighborhood_data``, but the result carries only ``edge_index``,
    ``pos``, ``z``, ``idx`` and ``y``.

    Args:
        data: PyG ``Data`` with ``edge_index``, ``pos``, ``z``, ``idx`` and ``y``.
        node_idx: Index of the center node.
        idx_to_use: Value stored as ``idx`` on the result; defaults to ``data.idx``.
        verbose: Print the neighbor ids and the sub-graph's edge_index.
        center: Subtract the centroid of the neighborhood's positions.

    Returns:
        A new ``torch_geometric.data.Data``; ``y`` is copied unchanged.
    """
    # Nodes touching node_idx in either edge direction (unique() sorts them).
    incident = (data.edge_index[0] == node_idx) | (data.edge_index[1] == node_idx)
    neighbor_nodes = data.edge_index[:, incident].unique()

    # subgraph returns (edge_index, edge_attr). No edge_attr is passed, so the
    # second value is None. The former CPU scratch mask and unused edge-mask
    # locals are gone; the mask broke when indexed by CUDA node ids.
    sub_edge_index, _ = subgraph(neighbor_nodes, data.edge_index, relabel_nodes=True)

    if idx_to_use is None:
        idx_to_use = data.idx
    sub_pos = data.pos[neighbor_nodes, :]
    if center:
        sub_pos = center_coords(sub_pos)

    sub_data = torch_geometric.data.Data(
        edge_index=sub_edge_index,
        pos=sub_pos,
        z=data.z[neighbor_nodes],
        idx=idx_to_use,
        y=data.y  # Keep the same target label
    )

    if verbose:
        print("Neighbors:", neighbor_nodes.tolist())
        print("Subgraph Edges:\n", sub_edge_index)

    return sub_data

def extract_neighborhood_graphs(data, k=5):
    """Extract the ``k``-hop neighborhood subgraph of every node of a molecule graph.

    For each node, ``torch_geometric.utils.k_hop_subgraph`` collects the nodes
    within ``k`` hops. The resulting ``Data`` keeps their ``x`` rows, the induced
    edges (relabelled to 0..n_sub-1), and ``data.y`` unchanged.

    The previous version called ``subgraph`` on the center node alone, so no
    neighbors were ever included and ``k`` was ignored. It also indexed ``data.x``
    with the returned *edge* mask, which fails whenever the edge and node counts differ.

    Returns:
        List of ``torch_geometric.data.Data``, one per node, in node order.
    """
    graphs = []
    num_nodes = data.x.size(0)

    for node_idx in range(num_nodes):
        subset, edge_index, _, _ = torch_geometric.utils.k_hop_subgraph(
            node_idx, k, data.edge_index, relabel_nodes=True, num_nodes=num_nodes
        )
        sub_data = torch_geometric.data.Data(
            x=data.x[subset],
            edge_index=edge_index,
            y=data.y  # Keep the same target label
        )
        graphs.append(sub_data)

    return graphs


def remove_uncharacterized_molecules(
    root_dir: str,
):
    """Compute which Atomic-QM9 (GDB-9) molecules are characterized.

    The GDB-9 list of uncharacterized molecules is downloaded into ``root_dir``,
    and its 1-based ids are converted to 0-based indices. Nothing is deleted from
    disk. The downloaded file is intentionally left in place, because parallel runs
    may share it and cleaning it up can hit file-lock errors.

    Returns:
        (included_idxs, excluded_idxs): a sorted numpy array of the indices in
        ``range(133885)`` that are not excluded, and the list of excluded indices.

    Raises:
        AssertionError: if the list does not contain exactly 3054 ids.
    """
    from atomic_datasets.utils import download_url

    def is_int(string: str) -> bool:
        # int() on a str only signals "not an integer" via ValueError.
        try:
            int(string)
        except ValueError:
            return False
        return True

    print("Dropping uncharacterized molecules.")
    gdb9_url_excluded = "https://springernature.figshare.com/ndownloader/files/3195404"
    gdb9_txt_excluded = download_url(gdb9_url_excluded, root_dir)

    # First, get list of excluded indices (first token of each non-blank line).
    with open(gdb9_txt_excluded) as f:
        excluded_strings = [line.split()[0] for line in f if line.split()]

    excluded_idxs = [int(idx) - 1 for idx in excluded_strings if is_int(idx)]

    assert len(excluded_idxs) == 3054, (
        f"There should be exactly 3054 excluded molecule. Found {len(excluded_idxs)}"
    )

    # Now, create a list of included indices.
    Ngdb9 = 133885
    included_idxs = np.array(sorted(set(range(Ngdb9)) - set(excluded_idxs)))
    return included_idxs, excluded_idxs

def parse_qm7_xyz_to_data(filepath):
    """
    Parse one CCSD-level ``.xyz`` file (https://www.nature.com/articles/s41597-019-0157-8#Tab2)
    into a ``torch_geometric.data.Data`` object.

    Expected layout:
        line 0          -- number of atoms N
        line 1          -- comma-separated properties; the first field is a tag
                           ("Properties") and is skipped. Of the remaining
                           floats, positions 0..16 are used as
                           alpha_iso, alpha_aniso,
                           alpha (xx, yy, zz, xy, xz, yz),
                           dipole (x, y, z),
                           quadrupole (xx, yy, zz, xy, xz, yz)
        lines 2..N+1    -- "<element symbol> <x> <y> <z> ..."

    Returns:
        Data with ``z`` (long atomic numbers), ``pos`` (float32, 3 coords per
        atom), ``alpha_iso`` / ``alpha_aniso`` (shape (1,)), ``alpha_tensor``
        and ``quadrupole`` (shape (1, 6)), ``dipole`` (shape (1, 3)), and
        ``name`` set to the file name without its extension.
    """
    from ase.data import atomic_numbers

    with open(filepath) as handle:
        content = handle.readlines()

    n_atoms = int(content[0])
    fields = content[1].strip().split(',')
    values = [float(field) for field in fields[1:]]  # drop the leading tag

    def as_row(chunk):
        # list of k floats -> (1, k) float32 tensor
        return torch.tensor(chunk, dtype=torch.float32).unsqueeze(0)

    # Scalar polarizabilities kept as shape (1,)
    iso = torch.tensor([values[0]], dtype=torch.float32)
    aniso = torch.tensor([values[1]], dtype=torch.float32)

    polarizability = as_row(values[2:8])   # [xx, yy, zz, xy, xz, yz]
    quadrupole = as_row(values[11:17])     # [xx, yy, zz, xy, xz, yz]
    dipole = as_row(values[8:11])          # [x, y, z]

    charges = []
    coordinates = []
    for row in content[2:2 + n_atoms]:
        tokens = row.strip().split()
        element = tokens[0]
        xyz = [float(token) for token in tokens[1:4]]
        charges.append(atomic_numbers[element])
        coordinates.append(xyz)

    stem, _ = os.path.splitext(os.path.basename(filepath))

    result = torch_geometric.data.Data(
        z=torch.tensor(charges, dtype=torch.long),
        pos=torch.tensor(coordinates, dtype=torch.float32),
        alpha_iso=iso,
        alpha_aniso=aniso,
        alpha_tensor=polarizability,
        dipole=dipole,
        quadrupole=quadrupole,
    )
    result.name = stem
    return result


import py3Dmol
import math
import base64
from rdkit import Chem
from io import BytesIO
from PIL import Image
import numpy as np
import pickle
import torch
import os
import utils
import wandb
from collections.abc import Mapping, Sequence

def mol_with_offset(mol, offset):
    """Return a MOL block of ``mol`` with every atom translated by ``offset``.

    The translation is applied to a copy, so the caller's molecule keeps its
    original coordinates.

    Args:
        mol: RDKit molecule with at least one conformer (the default
            conformer is the one translated).
        offset: three numbers ``(dx, dy, dz)`` added to every atom position.

    Returns:
        str: MOL block text of the translated copy.
    """
    # Chem.Mol(mol) is RDKit's copy constructor; it copies conformers too,
    # so editing positions below does not touch the caller's molecule.
    shifted = Chem.Mol(mol)
    conf = shifted.GetConformer()
    delta = Chem.rdGeometry.Point3D(*offset)  # loop-invariant
    for i in range(shifted.GetNumAtoms()):
        conf.SetAtomPosition(i, conf.GetAtomPosition(i) + delta)
    return Chem.MolToMolBlock(shifted)
    
def visualize_qm9_data_grid(data_list, savename=None, grid_cols=8, spacing=5.0, *, render_wait=10.0):
    """Render several molecules in one py3Dmol scene, laid out as a grid.

    Each element is converted with ``utils.to_rdkit_molecule`` and translated
    so element ``idx`` sits at column ``idx % grid_cols`` and row
    ``idx // grid_cols``. Rows go downward and grid positions are ``spacing``
    apart. Each grid cell is allotted 250 x 250 px of viewer size.

    Args:
        data_list: non-empty sequence of molecule graphs accepted by
            ``utils.to_rdkit_molecule``.
        savename: if given, the rendered image is saved to this path with PIL;
            otherwise the viewer is displayed with ``view.show()``.
        grid_cols: number of molecules per row.
        spacing: distance between neighbouring grid positions, in the same
            units as the molecule coordinates.
        render_wait: seconds to sleep before grabbing the PNG, to give the
            viewer time to render. Only used when ``savename`` is set.

    Returns:
        The last element of ``data_list``. This is kept for backward
        compatibility. NOTE(review): callers may have intended to get ``view``
        back; confirm before relying on it.

    Raises:
        AssertionError: if ``data_list`` is empty.
    """
    assert len(data_list) > 0, "data_list must contain at least one molecule"

    grid_rows = math.ceil(len(data_list) / grid_cols)
    view = py3Dmol.view(width=250 * grid_cols, height=250 * grid_rows)

    for idx, data in enumerate(data_list):
        mol = utils.to_rdkit_molecule(data)

        row, col = divmod(idx, grid_cols)
        offset = np.array([col * spacing, -row * spacing, 0.0])  # Negative row to go down visually

        view.addModel(mol_with_offset(mol, offset), "mol")
        view.setStyle({'model': idx}, {"stick": {}, "sphere": {"scale": 0.3}})

    view.zoomTo()
    view.setBackgroundColor("white")

    if savename is None:
        # Rendering happens client-side after show(); sleeping in Python
        # beforehand would only stall the caller, so no wait on this path.
        view.show()
    else:
        # Give the viewer time to finish rendering before capturing it.
        time.sleep(render_wait)
        # NOTE(review): this assumes view.png() returns base64-encoded PNG data.
        # Confirm for the installed py3Dmol version. Some versions publish an
        # HTML/JS snippet to the notebook instead, and b64decode would then fail.
        img_data = view.png()
        image = Image.open(BytesIO(base64.b64decode(img_data)))
        image.save(savename)
    return data
