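'''
Parallel Alignment of QM7X
==========================
This script canonicalizes ("normalizes") the molecules of the QM7X dataset using the CategoricalPointCloud frame (``CatFrame``).
For each molecule it computes the Wasserstein distance between the original and the canonicalized point cloud, checks that inverting the frame transform reconstructs the original coordinates, and accumulates the distances per point group.
Note: despite the title, molecules are processed one at a time in a sequential loop, and no random rotation or translation is applied; the "transformed" point cloud is the canonicalized one.
'''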

# Start up
# -------
import argparse
import sys
import logging

from time import time

import numpy as np
from scipy.spatial.transform import Rotation as R
from scipy.stats import wasserstein_distance_nd

import torch
from torch_geometric.loader import DataLoader

from pointgroup import PointGroup

from torch_canon.E3Global.CategoricalPointCloud import CatFrame as Frame

from loaders.qm7x import QM7X


# Setup
# -----
# Command-line arguments. The random seed is fixed in the script, not passed here.
parser = argparse.ArgumentParser()
parser.add_argument('--n_data', type=int, default=100,
                    help='Number of molecules to process, taken from the start of the dataset')
parser.add_argument('--frq_log', type=int, default=10,
                    help='Progress-logging frequency, in molecules')
args = parser.parse_args()

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
# Taken before the dataset is loaded, so the reported run time includes loading.
start_time = time()

qm7x = QM7X('data/qm7x')

# Canonicalization frame; `tol` and `save` are forwarded to CatFrame
# (see torch_canon for their exact semantics).
frame = Frame(tol=1e-2, save='all')

# Element symbols for atomic numbers 1, 6, 7, 8, 9; used to build the
# PointGroup input and a fallback molecule label.
atomic_number_to_symbol = {
    1: 'H', 6: 'C', 7: 'N', 8: 'O', 9: 'F'
    }
# Running totals accumulated over the processed molecules.
loss = 0
recon_loss = 0

# Helper Functions
# ----------------
def compute_loss(data, data_transformed):
    loss = wasserstein_distance_nd(data, data_transformed)
    return loss


np.random.seed(42)

# Point-group label -> position in the per-group accumulators below.
pg_map = {
    label: position
    for position, label in enumerate((
        "C1", "C1h", "C1v", "C2", "C2d", "C2h", "C2v",
        "C3", "C3h", "C3v", "C4", "Cs", "Cinfv", "Ci",
        "Dinfh", "D2", "D2d", "D2h", "D3", "D3d", "D3h", "D6h",
        "Oh", "Td", "S2", "S4",
    ))
}

# One zero-initialised slot per point group (summed loss and molecule count).
pg_losses = [0] * len(pg_map)
pg_counts = [0] * len(pg_map)

# Main Loop
# ---------
# Process the first `args.n_data` molecules. Slicing also caps the run at the
# dataset size, so progress and averages use the number of molecules actually
# handled rather than the requested count.
subset = qm7x[:args.n_data]
n_total = len(subset)
log_every = max(1, args.frq_log)  # guard against zero/negative frequencies
n_done = 0  # molecules that passed every check and were accumulated

for data in subset:
    pc_data = data.pos
    cat_data = data.z.numpy()
    normalized_data, frame_R, frame_t = frame.get_frame(pc_data, cat_data)

    # Sanity check on data.hVDIP: abort the run if the norm of its mean over
    # dim 0 exceeds 1e-4. Exit non-zero so callers see the failure.
    vdip_mean = data.hVDIP.mean(dim=0)
    if torch.linalg.norm(vdip_mean) > 1e-4:
        print(f'VDIP: {vdip_mean}')
        sys.exit(1)

    symbols = [atomic_number_to_symbol[cat] for cat in cat_data]
    # Samples without a SMILES attribute are labelled by their element sequence.
    smiles = getattr(data, 'smiles', None)
    if smiles is None:
        smiles = ''.join(symbols)

    try:
        pg = PointGroup(normalized_data, symbols).get_point_group()
    except Exception as err:
        # Best effort: molecules whose symmetry cannot be determined count as C1,
        # but the failure is no longer silent.
        logging.warning(f'Point-group detection failed for {smiles}: {err}; using C1')
        pg = 'C1'
    print(f'{smiles}: {pg}')
    print(f'Symmetric Elements: {frame.symmetric_elements}')
    print(f'Simple ASU: {frame.simple_asu}')

    # Distance moved by canonicalization; computed once and reused, since each
    # call solves an optimal-transport problem.
    move = compute_loss(pc_data, normalized_data)

    # Undo the frame transform and verify the original coordinates come back.
    inv_R = torch.linalg.inv(frame_R)
    recon_data = (inv_R @ normalized_data.T).T + frame_t
    r_loss = compute_loss(pc_data, recon_data)
    if r_loss > 1e-4:
        logging.warning(f'Loss {smiles}: {r_loss:.4f}')
        logging.warning(f'Reconstruction failed; stopping after {n_done}/{n_total} molecules.')
        break

    # Point groups missing from the predefined table get a new slot instead
    # of raising KeyError.
    if pg not in pg_map:
        pg_map[pg] = len(pg_losses)
        pg_losses.append(0)
        pg_counts.append(0)

    # Accumulate only after every check passed, so all totals share one count.
    loss += move
    recon_loss += r_loss
    pg_losses[pg_map[pg]] += move
    pg_counts[pg_map[pg]] += 1
    n_done += 1

    if n_done % log_every == 0 or n_done == n_total:
        logging.info(f"Completed {n_done}/{n_total} iterations.")

# Average over the molecules actually accumulated (at least 1 to avoid
# dividing by zero when nothing was processed).
n_avg = max(n_done, 1)
logging.info(f'Average move {loss/n_avg:.4f}')
logging.info(f'Reconstruction loss {recon_loss/n_avg:.4f}')
logging.info(f'PG Losses: {pg_losses}')
logging.info(f'PG Counts: {pg_counts}')
logging.info(f'Time: {time()-start_time:.4f}')
logging.info('Done!')
