import torch
# Report the installed PyTorch version and the CUDA version exposed by
# torch.version.cuda, as a quick environment sanity check at start-up.
print("PyTorch version:", torch.__version__)
print("CUDA version used by PyTorch:", torch.version.cuda)

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing
from torch_geometric.utils import add_self_loops, degree
import numpy as np
from scipy.spatial.transform import Rotation as R
from sklearn.model_selection import train_test_split
import time
import matplotlib.pyplot as plt
from torch_geometric.datasets import MD17
import pandas as pd

import utils 
from models import TransformerModel
from datasets import DetectRotatedMD17Dataset

# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

import os
from ase import Atoms
# Map atomic numbers to element symbols
atom_type_map = {
    1: 'H',   # Hydrogen
    6: 'C',   # Carbon
    7: 'N',   # Nitrogen
    8: 'O',   # Oxygen
    16: 'S',  # Sulfur (common in organic molecules)
    9: 'F',   # Fluorine (if applicable)
    17: 'Cl', # Chlorine (if applicable)
    35: 'Br', # Bromine (if applicable)
    53: 'I'   # Iodine (if applicable)
}

# Map atomic numbers to atomic masses (presumably in unified atomic mass
# units -- TODO confirm). Covers the same elements as atom_type_map.
atomic_masses = {
    1: 1.00784,   # Hydrogen
    6: 12.0107,   # Carbon
    7: 14.007,    # Nitrogen
    8: 15.999,    # Oxygen
    16: 32.065,   # Sulfur
    9: 18.998,    # Fluorine
    17: 35.45,    # Chlorine
    35: 79.904,   # Bromine
    53: 126.904   # Iodine
}


def compute_moments_of_inertia(positions, masses):
    """
    Compute the principal moments of inertia of a molecule.

    The coordinates are shifted to the center of mass, the 3x3 inertia
    tensor is assembled, and its eigenvalues are returned.

    Parameters:
    - positions: (N, 3) array-like of Cartesian coordinates of the atoms.
      Any numeric dtype (or nested list) is accepted; the input is never
      modified.
    - masses: (N,) array-like of atomic masses.

    Returns:
    - moments: (3,) float64 array of principal moments of inertia
      (I1, I2, I3) in ascending order.

    Raises:
    - ValueError: if positions is not (N, 3) or masses is not (N,).
    """
    # np.array (rather than np.asarray) always copies, so the caller's data is
    # left untouched. Promoting to float64 lets integer and float32 input be
    # centered without casting errors or loss of precision.
    positions = np.array(positions, dtype=np.float64)
    masses = np.asarray(masses, dtype=np.float64)

    if positions.ndim != 2 or positions.shape[1] != 3:
        raise ValueError(
            f"positions must have shape (N, 3), got {positions.shape}"
        )
    if masses.shape != (positions.shape[0],):
        raise ValueError(
            f"masses must have shape ({positions.shape[0]},), got {masses.shape}"
        )

    # Center the positions relative to the center of mass
    center_of_mass = (masses @ positions) / masses.sum()
    positions -= center_of_mass

    # Inertia tensor: I = sum_i m_i * (|r_i|^2 * Id - r_i r_i^T).
    # The diagonal becomes e.g. sum m*(y^2 + z^2); off-diagonals -sum m*x*y.
    weighted = positions * masses[:, None]
    inertia = np.sum(weighted * positions) * np.eye(3) - weighted.T @ positions

    # eigvalsh already returns the eigenvalues of a symmetric matrix in
    # ascending order, so no extra sort is required.
    return np.linalg.eigvalsh(inertia)

def plot_moment_hist(all_moments, title, filename, scale='log'):
    """
    Plot overlaid histograms of the three principal moments of inertia.

    Parameters:
    - all_moments: (n_frames, 3) array-like; column i holds the i-th
      principal moment for every frame.
    - title: Title placed on the figure.
    - filename: Path the figure is written to (passed to savefig).
    - scale: Transformation applied before plotting:
        'standardize' -> sklearn StandardScaler (per column),
        'minmax'      -> sklearn MinMaxScaler (per column),
        'log'         -> np.log1p (default),
        anything else -> values are plotted unscaled.

    Raises:
    - ValueError: if all_moments is not 2-D with at least three columns
      (for example when it is empty).
    """
    import matplotlib.pyplot as plt
    import numpy as np

    # Convert moments of inertia to a numpy array
    all_moments = np.asarray(all_moments, dtype=float)  # Shape: (n_frames, 3)
    if all_moments.ndim != 2 or all_moments.shape[1] < 3:
        raise ValueError(
            "all_moments must have shape (n_frames, 3), got "
            f"{all_moments.shape}"
        )

    # Apply scaling if needed. The axis label is chosen together with the
    # transformation so that it always describes what is actually plotted.
    # sklearn is imported lazily: the default 'log' path does not need it.
    if scale == 'standardize':
        from sklearn.preprocessing import StandardScaler
        all_moments = StandardScaler().fit_transform(all_moments)
        xlabel = "Moment of Inertia (standardized)"
    elif scale == 'minmax':
        from sklearn.preprocessing import MinMaxScaler
        all_moments = MinMaxScaler().fit_transform(all_moments)
        xlabel = "Moment of Inertia (min-max scaled)"
    elif scale == 'log':
        all_moments = np.log1p(all_moments)
        xlabel = "Moment of Inertia (log scaled)"
    else:
        xlabel = "Moment of Inertia"

    # Plot histograms for each principal moment
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        for i in range(3):
            ax.hist(all_moments[:, i], bins=30, alpha=0.6, label=f"I{i+1}", density=True)

        ax.set_xlabel(xlabel)
        # density=True normalizes each histogram, so the y-axis is a
        # probability density rather than a raw count.
        ax.set_ylabel("Probability density")
        ax.set_title(title)
        ax.legend()
        fig.savefig(filename)
    finally:
        # Release the figure even when saving fails, so repeated calls do
        # not accumulate open figures.
        plt.close(fig)

def main():
    """
    Plot inertia-moment histograms for the training frames of each molecule.

    For every molecule name in the list below the MD17 dataset is loaded,
    the frames listed in 'splits/index_train_01.csv' are selected, the
    principal moments of inertia of each selected frame are computed with
    compute_moments_of_inertia, and the resulting distribution is saved to
    '<molecule>_inertia_hist.pdf' via plot_moment_hist.
    """
    molecules = [
        "revised benzene",
        "revised uracil",
        "revised naphthalene",
        "revised aspirin",
        "revised salicylic acid",
        "revised malonaldehyde",
        "revised ethanol",
        "revised toluene",
        "revised paracetamol",
        "revised azobenzene"]
    data_dir = '/data/NFS/potato/username/MD17'

    # The same split file is used for every molecule, so read it once
    # instead of once per loop iteration.
    train_indices = pd.read_csv('splits/index_train_01.csv', header=None)[0].tolist()

    for molecule in molecules:
        print(f"plotting {molecule}")
        dataset = MD17(data_dir, name=molecule)

        # Atomic numbers of the first frame; the reshapes below assume every
        # frame of the dataset has this same number of atoms.
        first_z = dataset[0].z
        num_atoms = len(first_z)
        print(first_z)

        # Informational only: print the masses ASE assigns to these elements.
        # The computation below uses the module-level atomic_masses table.
        atomic_symbols = [atom_type_map[atom_num.item()] for atom_num in first_z]
        print("Atomic masses for current atom:", Atoms(symbols=atomic_symbols).get_masses())

        # dataset.z / dataset.pos are treated as flat per-atom tensors over
        # all frames (assumption implied by the reshape -- TODO confirm).
        # Select the training frames first, then look masses up once per
        # distinct element rather than once per atom of the whole dataset.
        z_frames = np.asarray(dataset.z).reshape(-1, num_atoms)[train_indices]
        unique_z, inverse = np.unique(z_frames, return_inverse=True)
        mass_per_element = np.array([atomic_masses[int(z)] for z in unique_z])
        masses = mass_per_element[inverse].reshape(z_frames.shape)

        pos = np.asarray(dataset.pos.reshape(-1, num_atoms, 3))[train_indices]

        all_moments = [compute_moments_of_inertia(x, m) for x, m in zip(pos, masses)]

        plot_moment_hist(all_moments, title=f"Distribution of Principal Components {molecule}", filename=f"{molecule}_inertia_hist.pdf")



  
# Run the plotting pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
