import torch

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing
from torch_geometric.utils import add_self_loops, degree
import numpy as np
from scipy.spatial.transform import Rotation as R
from sklearn.model_selection import train_test_split
import time
import matplotlib.pyplot as plt
from torch_geometric.datasets import MD17
import pandas as pd
from sklearn.neighbors import KernelDensity
from mpl_toolkits.mplot3d import Axes3D

import utils 
#from models import TransformerModel
from datasets import DetectRotatedMD17Dataset
import matplotlib.pyplot as plt
import seaborn as sns

# Select the compute device once at import time: CUDA when a GPU is visible, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

import os
from ase import Atoms
# Map atomic numbers to element symbols (used to build ase.Atoms symbol lists).
atom_type_map = {
    1: 'H',   # Hydrogen
    6: 'C',   # Carbon
    7: 'N',   # Nitrogen
    8: 'O',   # Oxygen
    16: 'S',  # Sulfur (common in organic molecules)
    9: 'F',   # Fluorine (if applicable)
    17: 'Cl', # Chlorine (if applicable)
    35: 'Br', # Bromine (if applicable)
    53: 'I'   # Iodine (if applicable)
}

# Map atomic numbers to approximate standard atomic weights (atomic mass units).
# Covers the same elements as atom_type_map.
atomic_masses = {
    1: 1.00784,   # Hydrogen
    6: 12.0107,   # Carbon
    7: 14.007,    # Nitrogen
    8: 15.999,    # Oxygen
    16: 32.065,   # Sulfur
    9: 18.998,    # Fluorine
    17: 35.45,    # Chlorine
    35: 79.904,   # Bromine
    53: 126.904   # Iodine
}


def compute_moments_of_inertia(positions, masses, normalize_by_mass=False):
    """
    Compute the principal moments and principal axes of inertia of a molecule.

    Parameters:
    - positions: (N, 3) array-like of Cartesian atom coordinates. Not modified.
    - masses: (N,) array-like of atomic masses.
    - normalize_by_mass: if True, divide the inertia tensor by the total mass
      before diagonalising (the axes are unchanged, the moments are scaled).

    Returns:
    - eigvals: (3,) principal moments of inertia in ascending order (I1 <= I2 <= I3).
    - eigvecs: (3, 3) array whose column k is the principal axis belonging to
      eigvals[k]. Each axis is only defined up to sign.

    Raises:
    - ValueError: if the total mass is not positive (e.g. no atoms).
    """
    # np.array always copies, so the caller's array is never modified; casting to
    # float also lets integer coordinates be shifted by the (float) centre of mass.
    positions = np.array(positions, dtype=float)
    masses = np.asarray(masses, dtype=float)
    total_mass = masses.sum()
    if not total_mass > 0:
        raise ValueError("total mass must be positive to compute moments of inertia")

    # Express coordinates relative to the centre of mass.
    positions -= masses @ positions / total_mass

    # Inertia tensor I = sum_k m_k * (|r_k|^2 * E - r_k r_k^T).
    # second_moment = sum_k m_k r_k r_k^T, and its trace is sum_k m_k |r_k|^2.
    second_moment = (positions * masses[:, None]).T @ positions
    inertia = np.trace(second_moment) * np.eye(3) - second_moment

    if normalize_by_mass:
        inertia /= total_mass

    # eigh returns eigenvalues in ascending order with the matching eigenvectors
    # as columns, so no extra sort is needed.
    eigvals, eigvecs = np.linalg.eigh(inertia)
    return eigvals, eigvecs

def plot_moment_hist(all_moments, title, filename, scale=None,
                     xlabel="Moment of Inertia (Normalized by Total Mass)"):
    """
    Save overlaid density histograms of the three principal moments of inertia.

    Parameters:
    - all_moments: (n_frames, 3) array-like of principal moments (I1, I2, I3).
    - title: figure title.
    - filename: output path passed to plt.savefig.
    - scale: None (raw values), 'standardize' (zero mean / unit variance per
      column), 'minmax' (each column rescaled to [0, 1]) or 'log' (log1p).
    - xlabel: x-axis label. The default assumes mass-normalised moments; pass a
      different label when the moments were not normalised.

    Raises:
    - ValueError: if scale is not one of the supported values.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.preprocessing import StandardScaler, MinMaxScaler

    all_moments = np.array(all_moments)  # Shape: (n_frames, 3)

    # Apply the requested scaling; reject unknown modes instead of silently
    # plotting unscaled data.
    if scale == 'standardize':
        all_moments = StandardScaler().fit_transform(all_moments)
    elif scale == 'minmax':
        all_moments = MinMaxScaler().fit_transform(all_moments)
    elif scale == 'log':
        all_moments = np.log1p(all_moments)
    elif scale is not None:
        raise ValueError(
            f"unknown scale {scale!r}; expected None, 'standardize', 'minmax' or 'log'"
        )

    plt.figure(figsize=(7, 3))
    for i in range(3):
        plt.hist(all_moments[:, i], bins=30, alpha=0.6, label=f"I{i+1}", density=True)

    plt.xlabel(xlabel)
    # density=True normalises each histogram to unit area, so the y-axis is a
    # probability density, not a count.
    plt.ylabel("Density")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.savefig(filename)
    plt.close()

def plot_moment_scatter(all_moments, title, filename, scale=None):
    """
    Save pairwise scatter plots (I1-I2, I2-I3, I1-I3) of principal moments.

    Parameters:
    - all_moments: (n_samples, 3) array-like of principal moments.
    - title: figure super-title.
    - filename: output path for the saved figure.
    - scale: 'log' applies log1p before plotting; any other value plots raw moments.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    moments = np.array(all_moments)
    if scale == 'log':
        moments = np.log1p(moments)

    fig, panels = plt.subplots(1, 3, figsize=(10, 4))
    names = ('I1', 'I2', 'I3')
    for panel, (a, b) in zip(panels, ((0, 1), (1, 2), (0, 2))):
        panel.scatter(moments[:, a], moments[:, b], alpha=0.5, s=10)
        panel.set_xlabel(names[a])
        panel.set_ylabel(names[b])
        panel.set_title(f"{names[a]} vs {names[b]}")

    fig.suptitle(title)
    fig.tight_layout()
    fig.savefig(filename)
    plt.close(fig)

def plot_moment_ratios(all_moments, title, filename, rel_tol=1e-8):
    """
    Save overlaid density histograms of the moment ratios I2/I1, I3/I1 and I3/I2.

    Samples whose denominator is (numerically) zero are skipped for that ratio.
    Linear molecules and single atoms have I1 ~ 0, and dividing by it would give
    inf/NaN or enormous values that break the histogram range.

    Parameters:
    - all_moments: (n_samples, 3) array-like of principal moments (I1, I2, I3).
    - title: figure title.
    - filename: output path passed to plt.savefig.
    - rel_tol: a denominator counts as zero when it is not larger than rel_tol
      times the largest absolute moment of the same sample.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    all_moments = np.asarray(all_moments, dtype=float)
    # Per-sample magnitude so the "effectively zero" test does not depend on units.
    sample_scale = np.abs(all_moments).max(axis=1)

    def _finite_ratio(num, den):
        # Keep only samples with a meaningful denominator and a finite numerator.
        keep = (den > rel_tol * sample_scale) & np.isfinite(num)
        return num[keep] / den[keep]

    ratios = [
        (_finite_ratio(all_moments[:, 1], all_moments[:, 0]), "I2/I1"),
        (_finite_ratio(all_moments[:, 2], all_moments[:, 0]), "I3/I1"),
        (_finite_ratio(all_moments[:, 2], all_moments[:, 1]), "I3/I2"),
    ]

    plt.figure(figsize=(8, 3))
    for values, label in ratios:
        if values.size:  # an all-skipped ratio has nothing to draw
            plt.hist(values, bins=30, alpha=0.6, label=label, density=True)

    plt.xlabel("Ratio")
    plt.ylabel("Density")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.savefig(filename)
    plt.close()


def plot_principal_axes_density(all_vecs, title, filename, axes_idx=0):
    """
    Visualize the distribution of a selected principal axis (eigenvector) using a 2D density plot.

    Only the x and y components of the selected axis are shown. Eigenvectors are
    defined only up to sign, so an axis and its negation appear at mirrored points.

    Parameters:
    - all_vecs: list or array of shape (n_samples, 3, 3), where each [i,:,:] is the 3 eigenvectors.
    - title: title of the plot.
    - filename: where to save the figure.
    - axes_idx: which eigenvector to visualize (0 for smallest, 2 for largest).
    """
    all_vecs = np.array(all_vecs)
    selected_axes = all_vecs[:, :, axes_idx]  # shape: (n_samples, 3)

    # Take the x, y components of the eigenvectors (for 2D visualization)
    x, y = selected_axes[:, 0], selected_axes[:, 1]

    plt.figure(figsize=(8, 6))
    # seaborn's kdeplot takes x/y as keyword-only arguments; the positional
    # form kdeplot(x, y) is rejected by newer seaborn releases.
    sns.kdeplot(x=x, y=y, cmap="Blues", fill=True, levels=10)
    plt.title(f"{title}\nPrincipal Axis {axes_idx+1}")
    plt.xlabel("X Component")
    plt.ylabel("Y Component")
    plt.tight_layout()
    plt.savefig(filename)
    plt.close()

def compute_angle(axis1, axis2):
    """
    Return the angle between two vectors, in degrees within [0, 180].

    The cosine is clipped to [-1, 1] so round-off never pushes arccos out of its domain.
    """
    denominator = np.linalg.norm(axis1) * np.linalg.norm(axis2)
    cosine = np.clip(np.dot(axis1, axis2) / denominator, -1.0, 1.0)
    return np.degrees(np.arccos(cosine))

def compute_angles(all_moments, canonical_axis=None, *, unsigned=False):
    """
    Compute, for each molecule, the angle between each principal axis and a reference axis.

    Parameters:
    - all_moments: array-like of shape (n_molecules, 3, k) of principal-axis
      matrices whose columns are the axes (e.g. the eigenvector matrices returned
      by compute_moments_of_inertia). Despite the name these are axes, not moments.
    - canonical_axis: reference vector; None (the default) means the x-axis [1, 0, 0].
    - unsigned: principal axes are only defined up to sign. If True, fold each
      angle into [0, 90] degrees so an axis and its negation give the same angle.
      The default False returns the signed angle in [0, 180].

    Returns:
    - (n_molecules, k) array of angles in degrees; an empty 1-D array for empty input.
    """
    if canonical_axis is None:
        canonical_axis = np.array([1, 0, 0])

    axes = np.asarray(all_moments, dtype=float)
    if axes.size == 0:
        return np.array([])
    reference = np.asarray(canonical_axis, dtype=float)

    # Dot product of every column axis with the reference vector: shape (n, k).
    dots = np.einsum('nij,i->nj', axes, reference)
    # Column norms (over the coordinate dimension) times the reference norm.
    norms = np.linalg.norm(axes, axis=1) * np.linalg.norm(reference)
    cos_theta = dots / norms
    if unsigned:
        cos_theta = np.abs(cos_theta)
    # Clip to guard arccos against round-off slightly outside [-1, 1].
    return np.degrees(np.arccos(np.clip(cos_theta, -1.0, 1.0)))

def plot_3d_heatmap_all_axes(all_vecs, title, filename, bandwidth=0.1):
    """
    Visualize the distribution of the principal axis directions in 3D space, colored by density.

    Each principal axis gets its own Gaussian KDE, evaluated at that axis's own
    sample points, and its own marker shape. All axes share one colour range, so
    the single colorbar is valid for every plotted point.

    Parameters:
    - all_vecs: shape (n_samples, 3, k); each [i, :, j] is the j-th eigenvector (axis).
    - title: plot title.
    - filename: output path passed to savefig.
    - bandwidth: Gaussian KDE bandwidth (default 0.1).

    Raises:
    - ValueError: if all_vecs is not a non-empty 3-D array.
    """
    all_vecs = np.asarray(all_vecs)
    if all_vecs.ndim != 3 or all_vecs.shape[0] == 0:
        raise ValueError("all_vecs must have shape (n_samples, 3, k) with n_samples > 0")

    # Fit every KDE first so that one colour range can be shared across axes.
    per_axis = []
    for i in range(all_vecs.shape[2]):
        axis_vectors = all_vecs[:, :, i]  # shape (n_samples, 3)
        kde = KernelDensity(kernel='gaussian', bandwidth=bandwidth).fit(axis_vectors)
        per_axis.append((axis_vectors, np.exp(kde.score_samples(axis_vectors))))
    vmin = min(density.min() for _, density in per_axis)
    vmax = max(density.max() for _, density in per_axis)

    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')

    # Distinct markers keep the axes distinguishable in the legend, since colour encodes density.
    markers = ['o', '^', 's']
    for i, (points, density) in enumerate(per_axis):
        scatter = ax.scatter(points[:, 0], points[:, 1], points[:, 2],
                             c=density, cmap='viridis', vmin=vmin, vmax=vmax,
                             marker=markers[i % len(markers)],
                             s=10, alpha=0.7, label=f'Axis {i+1}')

    cbar = fig.colorbar(scatter, ax=ax, shrink=0.5, aspect=5)
    cbar.set_label('Density')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title(title)
    ax.legend()
    fig.tight_layout()
    fig.savefig(filename)
    plt.close(fig)

def plot_angles_distribution(angles, filename="angle_dist.pdf"):
    """
    Save a density histogram of angles (degrees) between principal axes and a reference axis.

    Parameters:
    - angles: 1-D array of angles, or a 2-D (n_samples, k) array such as the output
      of compute_angles. In the 2-D case each column is drawn as its own labelled
      dataset ("Axis 1", "Axis 2", ...).
    - filename: output path (default "angle_dist.pdf").
    """
    angles = np.asarray(angles)
    fig = plt.figure(figsize=(8, 6))
    if angles.ndim == 2:
        labels = [f"Axis {i+1}" for i in range(angles.shape[1])]
        plt.hist(angles, bins=50, density=True, alpha=0.7, label=labels)
        plt.legend()
    else:
        plt.hist(angles, bins=50, density=True, alpha=0.7)
    plt.xlabel("Angle (degrees)")
    plt.ylabel("Density")
    plt.title("Distribution of Angles Between Principal Axes and Canonical Axis")
    plt.savefig(filename)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

import plotly.graph_objects as go
from mpl_toolkits.mplot3d import Axes3D

def plot_3d_voxel_histogram(all_vecs, title, filename = "moment_3d_dist.html", res=20):
    """
    Visualize the density of principal axis directions in 3D using voxels and Plotly.

    Steps:
    - Each axis direction is normalised to a unit vector and mapped from
      [-1, 1]^3 onto a res x res x res voxel grid.
    - Every occupied voxel is drawn as a marker in voxel-index coordinates.
    - Marker opacity grows with the voxel's count relative to that axis's
      busiest voxel.

    Plotly takes one marker opacity per trace, so voxels are grouped into 9
    opacity bands with one trace each. Each axis shows a single legend entry.

    Parameters:
    - all_vecs: array of shape (n_samples, 3, 3); column j of each sample is axis j.
    - title: plot title
    - filename: path to save HTML
    - res: number of voxels per dimension (default 20)

    Raises:
    - ValueError: if all_vecs is not a non-empty (n_samples, 3, 3) array.
    """
    all_vecs = np.asarray(all_vecs, dtype=float)
    if all_vecs.ndim != 3 or all_vecs.shape[0] == 0 or all_vecs.shape[1:] != (3, 3):
        raise ValueError("all_vecs must have shape (n_samples, 3, 3) with n_samples > 0")

    colors = ['red', 'green', 'blue']
    axis_labels = ['Axis 1', 'Axis 2', 'Axis 3']
    # Band edges 0, 1/9, ..., 1; a voxel in band b is drawn with opacity edges[b + 1].
    opacity_edges = np.linspace(0, 1, 10)
    n_bands = len(opacity_edges) - 1

    fig = go.Figure()

    for i in range(3):
        vecs = all_vecs[:, :, i]
        vecs = vecs / np.linalg.norm(vecs, axis=1, keepdims=True)

        # Map each unit-vector component from [-1, 1] to an integer index in [0, res - 1].
        grid = np.clip(((vecs + 1) / 2 * res).astype(int), 0, res - 1)
        voxels = np.zeros((res, res, res))
        # np.add.at accumulates repeated indices (a plain fancy-indexed += would not).
        np.add.at(voxels, (grid[:, 0], grid[:, 1], grid[:, 2]), 1)

        occupied = np.argwhere(voxels > 0)  # (n_occupied, 3) voxel coordinates
        vals = voxels[occupied[:, 0], occupied[:, 1], occupied[:, 2]] / voxels.max()

        # Band b satisfies edges[b] <= val < edges[b + 1]; the busiest voxel
        # (val == 1) is placed in the top band instead of being dropped.
        bands = np.minimum(np.searchsorted(opacity_edges, vals, side='right') - 1, n_bands - 1)

        show_in_legend = True
        for b in range(n_bands):
            mask = bands == b
            if not mask.any():
                continue  # avoid adding empty traces
            points = occupied[mask]
            fig.add_trace(go.Scatter3d(
                x=points[:, 0], y=points[:, 1], z=points[:, 2],
                mode='markers',
                marker=dict(
                    size=4,
                    color=colors[i],  # solid color
                    opacity=opacity_edges[b + 1]
                ),
                name=axis_labels[i],
                legendgroup=axis_labels[i],
                showlegend=show_in_legend
            ))
            show_in_legend = False

    fig.update_layout(
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z',
        ),
        title=title,
        margin=dict(l=0, r=0, b=0, t=40),
        legend=dict(x=0, y=1)
    )

    fig.write_html(filename)




def _plot_md17_inertia(data_dir, split_csv='configs/dataset/md17_splits/index_train_01.csv'):
    """Save a log-scaled principal-moment histogram for each revised MD17 molecule (training split)."""
    molecules = [
        "revised benzene",
        "revised uracil",
        "revised naphthalene",
        "revised aspirin",
        "revised salicylic acid",
        "revised malonaldehyde",
        "revised ethanol",
        "revised toluene",
        "revised paracetamol",
        "revised azobenzene"]
    # The same split file is used for every molecule, so read it once.
    train_indices = pd.read_csv(split_csv, header=None)[0].tolist()

    for molecule in molecules:
        print(f"plotting {molecule}")
        md17 = MD17(data_dir, name=molecule)

        # The flat z/pos arrays are reshaped per frame; this assumes every frame
        # has the same number of atoms as frame 0.
        num_atoms = len(md17[0].z)
        masses = np.array([atomic_masses[z] for z in md17.z.tolist()])
        masses = masses.reshape(-1, num_atoms)[train_indices]
        pos = np.asarray(md17.pos.reshape(-1, num_atoms, 3))[train_indices]

        # compute_moments_of_inertia returns (moments, axes); only the moments
        # are histogrammed, so keep element [0] rather than the whole tuple.
        all_moments = [compute_moments_of_inertia(x, m)[0] for x, m in zip(pos, masses)]

        plot_moment_hist(all_moments,
                         title=f"Distribution of Principal Components {molecule}",
                         filename=f"{molecule}_inertia_hist.pdf",
                         scale='log')


def _plot_qm9_principal_axes(data_dir):
    """Plot the principal-axis directions of every QM9 molecule (voxel histogram and angles to the x-axis)."""
    qm9 = QM9(root=data_dir)
    all_vecs = []
    for mol in qm9:
        symbols = [atom_type_map[z] for z in mol.z.tolist()]
        # ase looks up per-atom masses from the element symbols.
        masses = Atoms(symbols=symbols).get_masses()
        _, axes = compute_moments_of_inertia(np.array(mol.pos), masses, normalize_by_mass=False)
        all_vecs.append(axes)

    all_vecs = np.array(all_vecs)
    all_angles = compute_angles(all_vecs)
    plot_3d_voxel_histogram(all_vecs, title="Histogram of Principal Components")
    plot_angles_distribution(all_angles)


def main(dataset="QM9", data_dir=None):
    """
    Run the moment-of-inertia analysis for one dataset.

    Parameters:
    - dataset: "MD17" saves per-molecule principal-moment histograms. "QM9" saves
      the principal-axis voxel histogram and the angle distribution.
    - data_dir: dataset root directory; None uses the built-in default for that dataset.

    Raises:
    - ValueError: if dataset is neither "MD17" nor "QM9".
    """
    if dataset == "MD17":
        _plot_md17_inertia('/data/NFS/potato/username/md17' if data_dir is None else data_dir)
    elif dataset == "QM9":
        # NOTE(review): unlike the MD17 default this path is relative (no leading '/');
        # confirm whether '/data/NFS/potato/username/' was intended.
        _plot_qm9_principal_axes('data/NFS/potato/username/' if data_dir is None else data_dir)
    else:
        raise ValueError(f"unknown dataset {dataset!r}; expected 'MD17' or 'QM9'")

    



  
if __name__ == '__main__':
    # Script entry point: run the QM9 analysis.
    main(dataset="QM9")
