import os
import logging

from tqdm import tqdm

import torch
import ase.db
import numpy as np


import schnetpack as spk
from schnetpack.interfaces.ase_interface import AtomsConverter
from schnetpack.interfaces import SpkCalculator
from schnetpack.datasets import QM7X
from ase.optimize.lbfgs import LBFGS
from schnetpack.diffusion.utils import rmsd


# Route all log output of the (root) logger through a single console handler
# with a timestamped format.
logger = logging.getLogger()
logFormatter = logging.Formatter("%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s]  %(message)s")
consoleHandler = logging.StreamHandler()
consoleHandler.setFormatter(logFormatter)
# Drop any handlers already attached so messages are not emitted twice. The
# root logger has no handlers by default, so indexing `handlers[0]` would
# raise IndexError; iterating over a copy handles zero or many handlers.
for _handler in list(logger.handlers):
    logger.removeHandler(_handler)
logger.addHandler(consoleHandler)
# The root logger defaults to WARNING, which would hide the INFO-level result
# messages emitted by the optimization loop.
logger.setLevel(logging.INFO)

# --- paths (placeholders; must be set before running) ------------------------
raw_path = None          # raw QM7-X data path; only used when update_postprocessor is True
db_path = None           # source ASE db whose metadata holds 'groups_ids'
target_path = None       # ASE db that prepare_noisy_db() fills with noisy structures
denoised_db_path = None  # ASE db receiving the relaxed (denoised) structures
split_path = None        # split file read with np.load; the array under `split_name` is used
model_path = None        # trained model file, loaded with torch.load and by SpkCalculator
continue_old = False     # resume a previous run by appending to denoised_db_path


# --- settings ----------------------------------------------------------------
cutoff = 5.              # cutoff of MatScipyNeighborList (presumably in Ang — confirm)
force_th = 0.000005      # LBFGS convergence threshold `fmax` (ASE units, eV/Ang)
max_steps = 1000         # maximum number of LBFGS steps per structure
split_name = "val_idx"   # key of the index array read from split_path
n_struc_per_mol = 10     # noisy structures sampled per (molecule, stereo isomer, conformer)
batch_size = 3           # batch size of the QM7X datamodule (update_postprocessor only)

update_postprocessor = False  # refit energy postprocessor atomref/mean and re-save the model
use_cpu = True                # run the model on CPU instead of CUDA


def prepare_noisy_db(max_groups=None):
    """Build the db of noisy (non-equilibrium) structures to be denoised.

    Molecule groups are taken from the ``split_name`` indices of
    ``split_path`` applied to the equilibrium (``step_id == 0``) entries of
    the ``db_path`` metadata. For every (group, stereo isomer, conformer)
    combination, the non-equilibrium structures (``step_id != 0``) are sorted
    by their stored ``rmsd``. ``n_struc_per_mol`` of them are picked at evenly
    spaced ranks, from lowest to highest RMSD. Each pick is written to
    ``target_path`` together with the positions of the matching equilibrium
    structure (``data["orig_R"]``) and its RMSD (``data["rmsd"]``).

    Does nothing (besides a warning) if ``target_path`` already exists.

    Args:
        max_groups: optional upper bound on the number of molecule groups to
            process, e.g. for quick debugging runs. ``None`` processes all.

    Raises:
        ValueError: if ``db_path`` does not exist.
    """
    if os.path.exists(target_path):
        logger.warning(f"Target path: {target_path} already exists."
                        f"the function will return and not overwrite the existing db.")
        return

    # Check before connecting: reading metadata through ase.db.connect would
    # otherwise silently create an empty db file at db_path.
    if not os.path.exists(db_path):
        raise ValueError(f"Source db path: {db_path} does not exist")

    md = ase.db.connect(db_path).metadata
    group_info = md['groups_ids']
    groups_ids = np.array(group_info['smiles_id'])
    steps_ids = np.array(group_info['step_id'])
    stereo_iso_ids = np.array(group_info['stereo_iso_id'])
    conform_ids = np.array(group_info['conform_id'])
    row_ids = np.array(group_info['id'])
    # np.load on an .npz returns an NpzFile holding an open file handle.
    with np.load(split_path) as split_file:
        split = split_file[split_name]

    # Unique molecule groups of the selected (equilibrium-indexed) split.
    group = np.unique(groups_ids[steps_ids == 0][split])
    if max_groups is not None:
        group = group[:max_groups]

    with ase.db.connect(db_path) as source, ase.db.connect(target_path) as target:
        for g in tqdm(group):
            g_mask = groups_ids == g
            iso_id = np.unique(stereo_iso_ids[g_mask])
            conf_id = np.unique(conform_ids[g_mask])
            for i in iso_id:
                iso_mask = g_mask & (stereo_iso_ids == i)
                for c in conf_id:
                    mol_mask = iso_mask & (conform_ids == c)
                    # Not every (isomer, conformer) pair exists in a group.
                    idx = np.where(mol_mask & (steps_ids != 0))[0]
                    if len(idx) == 0:
                        continue
                    assert len(np.unique(steps_ids[idx])) == 100
                    ids = row_ids[idx]
                    rows = [source.get(int(row_id)) for row_id in ids]
                    rmsd_list = [row.data['rmsd'].item() for row in rows]

                    # The equilibrium reference is the same for every sample of
                    # this (group, isomer, conformer), so fetch it only once.
                    orig_idx = np.where(mol_mask & (steps_ids == 0))[0]
                    assert len(orig_idx) == 1
                    orig_mol = source.get(int(row_ids[orig_idx[0]]))

                    # Evenly spaced ranks over the RMSD-sorted structures.
                    rmsd_rank = np.argsort(rmsd_list)
                    samples_idx = rmsd_rank[np.linspace(0, len(ids) - 1, n_struc_per_mol, dtype=int)]
                    for samp_idx in samples_idx:
                        row_data = {
                            "orig_R": orig_mol.positions,
                            "rmsd": np.array([rmsd_list[samp_idx]]),
                        }
                        target.write(rows[samp_idx].toatoms(),
                                     group_id=g,
                                     stereo_iso_id=i,
                                     conform_id=c,
                                     step_id=steps_ids[idx][samp_idx],
                                     orig_id=ids[samp_idx],
                                     data=row_data)

    # The groups_ids metadata is already in memory; no need to reopen source.
    with ase.db.connect(target_path) as target:
        target.metadata = {
            "split_path": split_path,
            "orig_db": db_path,
            "groups_ids": group_info,
            "_property_unit_dict": {
                "orig_R": "Ang",
                "rmsd": "Ang"
            },
            "_distance_unit": "Ang",
            "atomrefs": []
        }
        
        
if __name__ == "__main__":

    # Refuse to overwrite an earlier denoising run unless explicitly resuming it.
    if os.path.exists(denoised_db_path) and not continue_old:
        raise ValueError(f"Target path for denoised systems: {denoised_db_path} already exists.")

    # Build the db of noisy structures (returns early if target_path exists).
    prepare_noisy_db()

    # Preprocessing applied to every structure before it is passed to the model.
    transforms = [
        spk.transform.SubtractCenterOfGeometry(),
        spk.transform.MatScipyNeighborList(cutoff=cutoff),
        spk.transform.CastTo32(),
    ]

    if update_postprocessor:
        # The model object is only needed here; SpkCalculator below loads
        # model_path on its own.
        # NOTE(review): torch>=2.6 defaults to weights_only=True, which may
        # reject a fully pickled module; weights_only=False may be required.
        model = torch.load(model_path, map_location=torch.device('cpu') if use_cpu else None)

        with ase.db.connect(db_path) as conn:
            atomrefs = conn.metadata["atomrefs"]

        data_qm7 = QM7X(db_path, batch_size=batch_size, raw_data_path=raw_path, split_file=split_path,
                    only_equilibrium=False, num_train=30, num_val=3,
                    transforms=transforms, load_properties=[QM7X.energy, QM7X.forces, QM7X.RMSD], property_units={QM7X.energy: "eV", QM7X.forces:"eV/Ang", QM7X.RMSD:"Ang"}, num_workers=4)
        data_qm7.prepare_data()
        data_qm7.setup()

        # assumes postprocessors[1] is the energy postprocessor holding
        # atomref/mean — TODO confirm against the model definition
        model.postprocessors[1].atomref = torch.tensor(atomrefs[QM7X.energy])
        model.postprocessors[1].mean = data_qm7.get_stats(
                        QM7X.energy, True, True
                    )[0]

        # Overwrite the model file so the calculator below uses the new values.
        torch.save(model, model_path)

    # build atoms converter and calculator
    calculator = SpkCalculator(
            model_file=model_path,
            converter=AtomsConverter,
            device=torch.device("cpu" if use_cpu else "cuda"),
            neighbor_list=None,
            transforms=transforms,
            energy_unit="eV",
            position_unit="Ang",
        )

    with ase.db.connect(target_path) as conn:
        n_systems = conn.count()

    # When resuming, each row already in denoised_db_path is one finished
    # system, processed in id order, so the row count is the next index.
    if continue_old and os.path.exists(denoised_db_path):
        with ase.db.connect(denoised_db_path) as conn:
            current_idx = conn.count()
        logger.warning(f"Continuing from system with index {current_idx}")
    else:
        current_idx = 0

    for idx in tqdm(range(current_idx, n_systems)):
        # ase.db row ids are 1-based; assumes target_path ids are contiguous.
        row_id = idx + 1
        with ase.db.connect(target_path) as conn:
            row = conn.get(row_id)
            atoms = row.toatoms()
            orig_R = row.data['orig_R']
            orig_rmsd = row.data['rmsd'][0]

        # Relax the noisy structure using the model as the force field.
        atoms.calc = calculator
        optimizer = LBFGS(atoms, force_consistent=False, damping=0.7)
        optimizer.run(steps=max_steps, fmax=force_th)
        num_steps = optimizer.get_number_of_steps()

        # Compare the relaxed geometry with the equilibrium reference positions.
        ref_mol = atoms.copy()
        ref_mol.positions = orig_R
        try:
            diff = rmsd(ref_mol, atoms)
        except Exception:
            # Catch only regular errors (not KeyboardInterrupt/SystemExit) and
            # keep the traceback for diagnosis.
            logger.warning("RMSD calculation failed and rmsd set to NaN. Maybe very different structures !",
                           exc_info=True)
            diff = np.nan

        logger.info(f"Result RMSD: {diff:.4f} Ang with ratio {diff/orig_rmsd:.4f}")
        logger.info(f"Number of steps: {num_steps}")

        data = {
                "rmsd": diff,
                "orig_rmsd": orig_rmsd,
                "num_steps": num_steps,
                'energy': atoms.get_potential_energy(),
                'forces': atoms.get_forces(),
            }

        # Write after every system so an interrupted run can be resumed.
        with ase.db.connect(denoised_db_path) as conn:
            conn.write(
                atoms,
                orig_id=row_id,
                rmsd=diff,
                data=data
            )
    