import os
import logging

from tqdm import tqdm

import torch
import ase.db
import numpy as np

from ase import Atoms
import schnetpack as spk
from schnetpack import properties
from schnetpack.interfaces.ase_interface import AtomsConverter
from schnetpack.interfaces import SpkCalculator
from schnetpack.datasets import QM7X
from ase.optimize.lbfgs import LBFGS
from schnetpack.diffusion.utils import rmsd
from schnetpack.diffusion.transforms import Diffuse
from schnetpack.diffusion.noise_schedule import PolynomialSchedule


# Route all log records of the root logger through one console handler that
# prefixes each message with timestamp, thread name and level.
logger = logging.getLogger()
logFormatter = logging.Formatter("%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s]  %(message)s")
consoleHandler = logging.StreamHandler()
consoleHandler.setFormatter(logFormatter)
# Replace the first pre-installed handler, if there is one. The root logger
# has no handlers unless something configured logging beforehand, so an
# unconditional ``logger.handlers[0]`` would raise IndexError.
if logger.handlers:
    logger.removeHandler(logger.handlers[0])
logger.addHandler(consoleHandler)

# --- paths: placeholders that have to be filled in before running ---
raw_path = None  # forwarded to QM7X as raw_data_path
db_path = None  # QM7X database; also opened to read atomrefs when update_postprocessor is set
target_path = None  # not referenced in the visible part of this script
denoised_db_path = None  # ASE database that receives the relaxed structures
split_path = None  # forwarded to QM7X as split_file
model_path = None  # model file handed to torch.load / SpkCalculator
continue_old = False  # True: append to an existing denoised_db_path instead of aborting


# --- settings ---
cutoff = 5.0  # cutoff passed to the MatScipy neighbor list
force_th = 5e-6  # fmax convergence threshold passed to the optimizer
max_steps = 1000  # upper bound on optimizer steps per structure
split_name = "val_idx"  # not referenced in the visible part of this script
n_struc_per_mol = 10  # not referenced in the visible part of this script
batch_size = 3  # not referenced in the visible part of this script (QM7X is built with batch_size=1)

update_postprocessor = False  # True: overwrite atomref/mean of model.postprocessors[1] and re-save the model
use_cpu = True  # True: run on CPU, otherwise on CUDA

        
def _rmsd_to_reference(atoms, reference_positions):
    """Return the RMSD between ``atoms`` and the same system at ``reference_positions``.

    A fresh copy of ``atoms`` receives the reference coordinates and is passed
    to ``rmsd`` as first argument, with ``atoms`` as second. If ``rmsd``
    raises, a warning is logged and NaN is returned so that a single failing
    structure does not abort the whole run.
    """
    reference = atoms.copy()
    reference.positions = reference_positions
    try:
        return rmsd(reference, atoms)
    except Exception:
        # Deliberately broad (the metric is best effort), but no longer a bare
        # ``except`` so that KeyboardInterrupt / SystemExit still propagate.
        logger.warning("RMSD calculation failed and rmsd set to NaN. Maybe very different structures !")
        return np.nan


if __name__ == "__main__":
    # Script entry point: load the QM7-X validation split with a Diffuse
    # transform in the pipeline, relax every structure with LBFGS on the
    # SchNetPack calculator and write the relaxed structure plus RMSD
    # statistics to ``denoised_db_path``.

    # The path settings default to None placeholders; fail with a clear
    # message instead of the TypeError raised by os.path.exists(None).
    if denoised_db_path is None:
        raise ValueError("denoised_db_path must be set before running this script.")

    if os.path.exists(denoised_db_path) and not continue_old:
        raise ValueError(f"Target path for denoised systems: {denoised_db_path} already exists.")

    # data pipeline: centre, diffuse, build neighbor list, cast to float32
    transforms = [
        spk.transform.SubtractCenterOfGeometry(),
        Diffuse(
            PolynomialSchedule(T=1000, s=1e-5),
            diffuse_z=False,
            diffuse_all=False,
            use_forces=True,
            t_train=250,
            include_t=None,
            exclude_eps_0=False,
        ),
        spk.transform.MatScipyNeighborList(cutoff=cutoff),
        spk.transform.CastTo32(),
    ]

    data_qm7 = QM7X(
        db_path,
        batch_size=1,
        raw_data_path=raw_path,
        split_file=split_path,
        only_equilibrium=False,
        num_train=30,
        num_val=3,
        transforms=transforms,
        load_properties=[QM7X.energy, QM7X.forces, QM7X.RMSD],
        property_units={QM7X.energy: "eV", QM7X.forces: "eV/Ang", QM7X.RMSD: "Ang"},
        num_workers=4,
    )
    data_qm7.prepare_data()
    data_qm7.setup()

    if update_postprocessor:
        # Overwrite atom references and mean of the second postprocessor with
        # values taken from this database / datamodule, then re-save the model
        # in place.
        with ase.db.connect(db_path) as conn:
            atomrefs = conn.metadata["atomrefs"]

        # NOTE(review): torch.load unpickles the whole model object; only use
        # model files from a trusted source.
        model = torch.load(model_path, map_location=torch.device('cpu') if use_cpu else None)

        model.postprocessors[1].atomref = torch.tensor(atomrefs[QM7X.energy])
        model.postprocessors[1].mean = data_qm7.get_stats(QM7X.energy, True, True)[0]

        torch.save(model, model_path)

    # build atoms converter and calculator (same pipeline, without Diffuse)
    transforms = [
        spk.transform.SubtractCenterOfGeometry(),
        spk.transform.MatScipyNeighborList(cutoff=cutoff),
        spk.transform.CastTo32(),
    ]

    calculator = SpkCalculator(
        model_file=model_path,
        converter=AtomsConverter,
        device=torch.device("cpu" if use_cpu else "cuda"),
        neighbor_list=None,
        transforms=transforms,
        energy_unit="eV",
        position_unit="Ang",
    )

    # When resuming, skip as many validation entries as rows already stored.
    if continue_old and os.path.exists(denoised_db_path):
        with ase.db.connect(denoised_db_path) as conn:
            current_idx = conn.count()
            logger.warning(f"Continuing from system with index {current_idx}")
    else:
        current_idx = 0

    for idx in tqdm(range(current_idx, len(data_qm7.val_dataset))):
        # load atoms; convert each tensor to numpy only once
        mol = data_qm7.val_dataset[idx]
        noisy_R = mol[properties.R].detach().cpu().numpy()
        original_R = mol['original_R'].detach().cpu().numpy()
        atoms = Atoms(positions=noisy_R, numbers=mol[properties.Z].detach().cpu().numpy())

        # distance of the loaded (noised) structure to the reference
        orig_rmsd = _rmsd_to_reference(atoms, original_R)

        # build optimizer
        atoms.calc = calculator
        optimizer = LBFGS(atoms, force_consistent=False, damping=0.7)

        # optimize
        optimizer.run(steps=max_steps, fmax=force_th)
        num_steps = optimizer.get_number_of_steps()

        # distance of the relaxed structure to the reference
        diff = _rmsd_to_reference(atoms, original_R)

        # Guard the ratio against a zero denominator; a NaN denominator still
        # propagates to a NaN ratio.
        ratio = diff / orig_rmsd if orig_rmsd else np.nan

        # NOTE(review): no level is set on the root logger in the visible
        # code, so these INFO records are filtered out unless a level is
        # configured elsewhere - confirm whether that is intended.
        logger.info(f"Result RMSD: {diff:.4f} Ang with ratio {ratio:.4f}")
        logger.info(f"Number of steps: {num_steps}")

        data = {
            "rmsd": diff,
            "orig_rmsd": orig_rmsd,
            "num_steps": num_steps,
            'energy': atoms.get_potential_energy(),
            'forces': atoms.get_forces(),
            'noisy_R': noisy_R,
            'original_R': original_R,
        }

        # save structure to new db; the connection is reopened per structure
        # so every finished system is written before the next one starts
        with ase.db.connect(denoised_db_path) as conn:
            conn.write(
                atoms,
                orig_id=idx + 1,
                rmsd=diff,
                data=data,
            )
         
    