import torch
# Report the PyTorch build and the CUDA toolkit version it was compiled
# against, one "label value" line each.
for _label, _value in (
    ("PyTorch version:", torch.__version__),
    ("CUDA version used by PyTorch:", torch.version.cuda),
):
    print(_label, _value)

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing
from torch_geometric.utils import add_self_loops, degree
import numpy as np
from scipy.spatial.transform import Rotation as R
from sklearn.model_selection import train_test_split
import time
import matplotlib.pyplot as plt
import pandas as pd

import utils 
from models import TransformerModel
from datasets import AugmentedMD17Dataset

import argparse
import os

#data_dir = os.environ['DATA_DIR']

# These variables are read when CUDA is first initialised in this process, so
# they have to be exported BEFORE the first torch.cuda query below; setting
# them afterwards can be silently ignored.
# NOTE(review): the GPU index is hard-coded; on a host with fewer than four
# GPUs no device will be visible and the script falls back to the CPU.
os.environ['CUDA_VISIBLE_DEVICES'] = "3"
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Module-level default device: CUDA when a device is visible, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def getArgumentParser():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: a parser with a single option, spelled
        ``-dataset`` or ``--dataset``, stored under ``dataset`` and
        defaulting to ``'revised aspirin'``.
    """
    # NOTE(review): the description text does not mention the training done in
    # this file and looks inherited from another script -- confirm and update.
    cli = argparse.ArgumentParser(
        description="Script to map two symmetrically different charge/magnetic densities and learn order parameters."
    )
    dataset_option = {
        'dest': 'dataset',
        'help': 'dataset to run',
        'default': 'revised aspirin',
    }
    cli.add_argument('-dataset', '--dataset', **dataset_option)
    return cli

def prepare_data_for_rotation(data_batch, num_atom_types):
    """Build per-node features from atom positions and one-hot atom types.

    Args:
        data_batch: object exposing ``pos`` and ``z`` tensors
            (presumably ``pos`` is (num_nodes, 3) and ``z`` holds atomic
            numbers starting at 1 -- confirm against the dataset).
        num_atom_types: number of classes used for the one-hot encoding.

    Returns:
        Tensor made of ``pos`` followed by the float one-hot encoding of
        ``z - 1``, concatenated along the last dimension.
    """
    # Shift by one so that the value 1 in ``z`` maps to class index 0.
    class_index = (data_batch.z - 1).to(torch.long)
    one_hot_types = F.one_hot(class_index, num_classes=num_atom_types).to(torch.float32)

    # Coordinates come first, the one-hot columns follow.
    return torch.cat([data_batch.pos, one_hot_types], dim=-1)

def save_plot(x, y1, y2, xlabel, ylabel, labels, title, filename):
    """Plot two curves over a shared x-axis and write the figure to a file.

    Args:
        x: x-axis values shared by both curves.
        y1: y-values of the first curve.
        y2: y-values of the second curve.
        xlabel: label for the x-axis.
        ylabel: label for the y-axis.
        labels: sequence whose first two entries are the legend labels for
            ``y1`` and ``y2`` respectively.
        title: figure title.
        filename: destination passed to ``Figure.savefig``.
    """
    # Work on an explicit figure/axes pair instead of pyplot's implicit
    # "current figure", so nothing else that touches pyplot can interfere.
    fig, ax = plt.subplots()
    try:
        ax.plot(x, y1, label=labels[0])
        ax.plot(x, y2, label=labels[1])
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        fig.savefig(filename)
    finally:
        # Release the figure even if plotting or saving raised; otherwise it
        # stays registered with pyplot for the lifetime of the process.
        plt.close(fig)

def _load_split_indices(csv_path):
    """Return the first column of a header-less CSV file as a Python list."""
    return pd.read_csv(csv_path, header=None)[0].tolist()


def _unpack_batch(batch, device):
    """Split a ``(data, expected)`` batch and move both parts to ``device``.

    The target is cast to float and reshaped to ``(-1, 1)`` before it is
    compared with the model output.
    """
    data, expected = batch
    return data.to(device), expected.to(device).float().view(-1, 1)


def _train_one_epoch(model, loader, optimizer, criterion, device,
                     dataset_type, num_atom_types):
    """Run one optimisation pass over ``loader``; return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for batch in loader:
        data, expected = _unpack_batch(batch, device)
        optimizer.zero_grad()
        outputs = utils.apply_fwd_pass(model, data, dataset_type, num_atom_types)
        loss = criterion(outputs, expected)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(loader)


def _evaluate(model, loader, criterion, device, dataset_type, num_atom_types):
    """Return the mean batch loss over ``loader`` without updating the model."""
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for batch in loader:
            data, expected = _unpack_batch(batch, device)
            outputs = utils.apply_fwd_pass(model, data, dataset_type, num_atom_types)
            running_loss += criterion(outputs, expected).item()
    return running_loss / len(loader)


def main():
    """Train ``TransformerModel`` on the selected molecule and plot the losses.

    The molecule name comes from the ``--dataset`` command-line option. Train
    and validation subsets are taken from the index files under ``splits/``.
    After training, the per-epoch train/validation losses are written to
    ``<molecule>_energy_loss_plot.pdf``.

    Raises:
        ValueError: if either split file yields no indices.
    """
    options = getArgumentParser().parse_args()
    datadir = '/data/NFS/potato/username/MD17'
    molecule = options.dataset

    # Prepare datasets
    dataset = AugmentedMD17Dataset(root=datadir, molecule=molecule)
    print(dataset[0][0].energy)

    test_indices = _load_split_indices('splits/index_test_01.csv')
    train_indices = _load_split_indices('splits/index_train_01.csv')
    # Fail early: an empty split would otherwise only surface later as an
    # IndexError or a ZeroDivisionError when averaging the losses.
    if not train_indices or not test_indices:
        raise ValueError("train and test split files must each contain at least one index")

    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    # The "test" index file is what this script validates against.
    val_dataset = torch.utils.data.Subset(dataset, test_indices)

    batch_size = 128
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Model dimensions are derived from the first training sample only.
    # NOTE(review): assumes every sample has the same atoms -- confirm for
    # datasets that mix molecules.
    reference_z = train_dataset[0][0].z
    num_atoms = len(reference_z)
    num_atom_types = reference_z.max().item()
    num_node_features = int(num_atom_types + 3)  # 3 position columns + one-hot atom type
    max_nodes = int(num_atoms)
    hidden_dim = 128
    num_heads = 4
    num_layers = 4
    output_dim = 1  # the original comments describe the target as a scalar energy
    dataset_type = "MD17"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TransformerModel(
        num_node_features=num_node_features,
        hidden_dim=hidden_dim,
        num_heads=num_heads,
        num_layers=num_layers,
        output_dim=output_dim,
        max_nodes=max_nodes,
        use_pos_embedding=True,
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    # Mean squared error, since this is a regression task.
    criterion = nn.MSELoss()

    # Training loop
    num_epochs = 301
    train_losses = []
    val_losses = []
    # Tracked for reporting; the code that persisted/printed it is currently
    # disabled, so the value is not emitted anywhere.
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        epoch_train_loss = _train_one_epoch(
            model, train_loader, optimizer, criterion, device,
            dataset_type, num_atom_types,
        )
        train_losses.append(epoch_train_loss)

        epoch_val_loss = _evaluate(
            model, val_loader, criterion, device, dataset_type, num_atom_types,
        )
        val_losses.append(epoch_val_loss)

        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss

        # Print progress every tenth epoch (epochs are reported 1-based).
        if epoch % 10 == 0:
            print(f"Epoch {epoch + 1}, Train Loss: {epoch_train_loss:.4f}")
            print(f"Validation Loss: {epoch_val_loss:.4f}")

    # Plot loss
    epochs = list(range(1, num_epochs + 1))
    save_plot(
        epochs, train_losses, val_losses,
        xlabel="Epoch", ylabel="Loss",
        labels=["Train Loss", "Validation Loss"],
        title=f"Loss Plot for {molecule}",
        filename=f"{molecule}_energy_loss_plot.pdf"
    )

# Run the training entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
