import os
import logging
import sys 

import numpy as np

import fasteners

from ase import Atoms
import ase.db
from tqdm import tqdm
import itertools

import shutil

from schnetpack.diffusion.noise_schedule import *
from schnetpack.diffusion.utils import *
from schnetpack.data import AtomsLoader, ASEAtomsData
from schnetpack import properties


# Route all log output of this script through a single console handler that is
# attached to the root logger.
# NOTE(review): no level is set here; the root logger defaults to WARNING, so
# the logger.info(...) calls in this script are only emitted if the level is
# lowered elsewhere (e.g. by an imported package) — confirm.
logger = logging.getLogger()
logFormatter = logging.Formatter("%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s]  %(message)s")
consoleHandler = logging.StreamHandler()
consoleHandler.setFormatter(logFormatter)
# Replace the first pre-installed root handler, if there is one. A fresh
# interpreter has no root handlers, so the removal must be guarded to avoid an
# IndexError on `logger.handlers[0]`.
if logger.handlers:
    logger.removeHandler(logger.handlers[0])
logger.addHandler(consoleHandler)


# ---------------------------------------------------------------------------
# Run configuration. The path settings are placeholders (None) and have to be
# filled in before the script is executed.
# ---------------------------------------------------------------------------

# Source ASE database read by prepare_noisy_db().
db_path = None
# Database of noisy structures written by prepare_noisy_db() and loaded as the
# dataset that gets denoised.
target_path = None
# Output database receiving one row per denoised molecule.
denoised_db_path = None
# File loaded with np.load(); the entry named `split_name` selects the systems.
split_path = None
# Serialized model loaded with torch.load().
model_path = None
# Optional working directory; when set, the databases are copied there first.
data_workdir = None
# Resume writing into an already existing `denoised_db_path` instead of failing.
continue_old = True

# Load the model onto the CPU and pass use_cpu=True to the samplers.
use_cpu = True

# Diffusion settings handed to the noise schedule and the sampling functions.
T = 1000
cutoff = 5.
noise_schedule = PolynomialSchedule(T=T, s=1e-5)
# Passed as `start` to sample_R (only used when predict_time is False).
start_step = 400

# Passed as `max_steps` to sample_R_time (only used when predict_time is True).
max_steps = 1000
batch_size = 3
# Number of noisy structures selected per (group, stereo isomer, conformer).
n_struc_per_mol = 7
num_workers = 4
split_name = "val_idx"
# True -> sample_R_time, False -> sample_R.
predict_time = True


# Pre-processing applied to every structure loaded from `target_path`.
# NOTE(review): `trn` is not imported explicitly in this file; it is presumably
# provided by one of the wildcard imports above — confirm.
transforms = [
    trn.SubtractCenterOfGeometry(),
    # Use the shared `cutoff` so the neighbor list cannot drift away from the
    # cutoff that is passed to the sampling functions.
    trn.MatScipyNeighborList(cutoff=cutoff),
    trn.CastTo32(),
]

def prepare_noisy_db():
    """Create the database of noisy structures at ``target_path``.

    Reads the ``groups_ids`` metadata of the source database at ``db_path``
    (keys ``smiles_id``, ``step_id``, ``stereo_iso_id``, ``conform_id`` and
    ``id``) and the index array ``split_name`` stored in ``split_path``.
    For every selected (group, stereo isomer, conformer) combination the rows
    with ``step_id != 0`` are ranked by their stored ``rmsd`` and
    ``n_struc_per_mol`` of them, evenly spaced over that ranking, are written
    to the target database. Each written row carries the positions of the
    matching ``step_id == 0`` row as ``orig_R`` and its own ``rmsd``.

    If ``target_path`` already exists, a warning is logged and nothing is
    written.

    Raises:
        ValueError: if ``db_path`` does not exist.
        AssertionError: if a combination does not have exactly 100 distinct
            non-zero step ids or exactly one ``step_id == 0`` row.
    """
    # Validate the paths before any database is opened; checking the source
    # path only after reading from it would be too late to be useful.
    if os.path.exists(target_path):
        logger.warning(f"Target path: {target_path} already exists. "
                       f"the function will return and not overwrite the existing db.")
        return
    if not os.path.exists(db_path):
        raise ValueError(f"Source db path: {db_path} does not exist")

    md = ase.db.connect(db_path).metadata
    groups_md = md['groups_ids']
    groups_ids = np.array(groups_md['smiles_id'])
    steps_ids = np.array(groups_md['step_id'])
    stereo_iso_ids = np.array(groups_md['stereo_iso_id'])
    conform_ids = np.array(groups_md['conform_id'])
    # Database row ids, converted once instead of once per inner iteration.
    row_ids = np.array(groups_md['id'])
    split = np.load(split_path)[split_name]

    is_step_zero = steps_ids == 0
    # `split` indexes into the step-0 entries; collect the groups it selects.
    group = np.unique(groups_ids[is_step_zero][split])

    with ase.db.connect(db_path) as source, ase.db.connect(target_path) as target:
        for g in tqdm(group):
            g_mask = groups_ids == g
            iso_id = np.unique(stereo_iso_ids[g_mask])
            conf_id = np.unique(conform_ids[g_mask])
            for i in iso_id:
                for c in conf_id:
                    combo_mask = g_mask & (stereo_iso_ids == i) & (conform_ids == c)
                    idx = np.where(combo_mask & ~is_step_zero)[0]
                    if len(idx) == 0:
                        # This isomer/conformer pairing does not occur.
                        continue
                    assert len(np.unique(steps_ids[idx])) == 100

                    # The step-0 row is the same for every sample of this
                    # combination, so it is looked up once.
                    orig_idx = np.where(combo_mask & is_step_zero)[0]
                    assert len(orig_idx) == 1
                    orig_mol = source.get(int(row_ids[orig_idx[0]]))

                    ids = row_ids[idx]
                    rows = [source.get(int(row_id)) for row_id in ids]
                    _rmsd_list = [row.data['rmsd'].item() for row in rows]

                    # Pick entries evenly spaced over the RMSD ranking
                    # (smallest to largest).
                    _rmsd_sort = np.argsort(_rmsd_list)
                    _samples_idx = _rmsd_sort[np.linspace(0, len(ids) - 1, n_struc_per_mol, dtype=int)]
                    for samp_idx in _samples_idx:
                        row_data = {
                            "orig_R": orig_mol.positions,
                            "rmsd": np.array([_rmsd_list[samp_idx]]),
                        }
                        target.write(rows[samp_idx].toatoms(),
                                     group_id=g,
                                     stereo_iso_id=i,
                                     conform_id=c,
                                     step_id=steps_ids[idx][samp_idx],
                                     orig_id=ids[samp_idx],
                                     data=row_data)
            # NOTE(review): this stops after the first group, so only one
            # group is ever written. It looks like a debugging leftover —
            # confirm and remove if the whole split should be processed.
            break # !!!!!!!!!!!!!!!

    with ase.db.connect(target_path) as target:
        target.metadata = {
            "split_path": split_path,
            "orig_db": db_path,
            "groups_ids": groups_md,
            "_property_unit_dict": {
                "orig_R": "Ang",
                "rmsd": "Ang"
            },
            "_distance_unit": "Ang",
            "atomrefs": []
        }

def _copy_to_workdir():
    """Stage the noisy database inside ``data_workdir``.

    Creates ``data_workdir`` if it is missing and copies ``target_path`` into
    it unless a file of the same name is already there. The copy is guarded
    by an inter-process lock file located in ``data_workdir``.

    Returns:
        Tuple ``(datapath, denoised_data_path)``: the location of the noisy
        database inside the work directory and the location at which the
        denoised database is expected to be created there.

    Raises:
        ValueError: if the denoised database already exists in the work
            directory.
    """
    if not os.path.exists(data_workdir):
        os.makedirs(data_workdir, exist_ok=True)

    # Keep only the file names; both files live directly in the work dir.
    noisy_name = target_path.rpartition("/")[2]
    denoised_name = denoised_db_path.rpartition("/")[2]
    local_noisy = os.path.join(data_workdir, noisy_name)
    local_denoised = os.path.join(data_workdir, denoised_name)

    if os.path.exists(local_denoised):
        raise ValueError(f"Target path for denoised systems: {local_denoised} already exists.")

    lock_file = os.path.join(data_workdir, f"dataworkdir_{noisy_name}.lock")
    with fasteners.InterProcessLock(lock_file):
        # Check again under the lock: another process may have completed the
        # copy while this one was waiting.
        already_staged = os.path.exists(local_noisy)
        if not already_staged:
            shutil.copy(target_path, local_noisy)

    return local_noisy, local_denoised

def wrapper_rmsd(batch, reference):
    """Compute one RMSD value per molecule of a batch.

    For every unique molecule index in ``batch[properties.idx_m]`` the
    positions ``batch[properties.R]`` and the matching rows of ``reference``
    are wrapped into ``ase.Atoms`` objects (same atomic numbers) and handed
    to ``rmsd``.

    Args:
        batch: mapping providing ``properties.idx_m``, ``properties.R`` and
            ``properties.Z`` tensors.
        reference: tensor of reference positions, indexed with the same
            per-atom mask as ``batch[properties.R]``.

    Returns:
        1D tensor with one entry per molecule, ordered by ascending molecule
        index. Entries are NaN where the RMSD computation raised.
    """
    idx_m = batch[properties.idx_m]
    res = []
    for m in idx_m.unique():
        mask = idx_m == m
        R = batch[properties.R][mask].detach().cpu().numpy()
        R_0 = reference[mask].detach().cpu().numpy()
        Z = batch[properties.Z][mask].detach().cpu().numpy()
        ref_mol = Atoms(positions=R_0, numbers=Z)
        mol = Atoms(positions=R, numbers=Z)
        try:
            diff = rmsd(ref_mol, mol)
        except Exception:
            # Best effort: a failing molecule must not abort the whole batch.
            # Only Exception is caught so that KeyboardInterrupt / SystemExit
            # still stop the run.
            logger.warning("RMSD calculation failed and rmsd set to NaN. Maybe very different structures !")
            diff = torch.nan

        res.append(diff)
    return torch.tensor(res)
        
if __name__ == "__main__":
    # Script entry point: build the noisy database, denoise it batch by batch
    # with the loaded model and store every denoised molecule together with
    # its metrics in `denoised_db_path`.

    if os.path.exists(denoised_db_path) and not continue_old:
        raise ValueError(f"Target path for denoised systems: {denoised_db_path} already exists.")

    # Reject the unsupported configuration up front, before the noisy
    # database is built.
    if continue_old and data_workdir is not None:
        raise ValueError("continue_old and using a data_workdir are not supported.")

    prepare_noisy_db()

    if data_workdir is not None:
        logger.info(f"copying data to workdir {data_workdir}")
        # Remember the original locations; the result is copied back at the end.
        old_target_path = target_path
        old_denoised_db_path = denoised_db_path
        target_path, denoised_db_path = _copy_to_workdir()

    if continue_old and os.path.exists(denoised_db_path):
        with ase.db.connect(denoised_db_path) as conn:
            # NOTE(review): with N rows already written this resumes at batch
            # N // batch_size + 1, i.e. one batch after the last (possibly
            # partial) one is skipped — confirm that this is intended.
            current_batch = (conn.count() // batch_size) + 1
            logger.warning(f"Continuing from batch {current_batch}")
    else:
        current_batch = 0

    with ase.db.connect(denoised_db_path) as conn:
        conn.metadata = {
            "path_noisy_db": target_path,
            "model_path": model_path,
        }

    dataset = ASEAtomsData(
        target_path,
        transforms=transforms,
    )

    dataset_loader = AtomsLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=True,
    )

    model = torch.load(model_path, map_location=torch.device('cpu') if use_cpu else None)

    model.eval()

    # Skip the first `current_batch` batches when resuming.
    for batch in tqdm(itertools.islice(dataset_loader, current_batch, None)):

        if predict_time:
            denoi_R, history_R, steps, num_steps, *rest = sample_R_time(
                batch,
                model,
                noise_schedule,
                cutoff=cutoff,
                T=T,
                start=-1,
                random=False,
                use_forces=True,
                save_progress=True,
                progress_stride=1,
                use_cpu=use_cpu,
                recompute_neighbors=False,
                aggregate_atomwise=False,
                max_steps=max_steps,
                convergence_step=0,
                check_stability=False,
                bonds_data=None,
                min_steps=0,
                return_stability=False
            )
        else:
            denoi_R, history_R, *rest = sample_R(
                batch,
                model,
                noise_schedule,
                cutoff=cutoff,
                T=T,
                start=start_step,
                random=False,
                use_forces=True,
                save_progress=True,
                progress_stride=1,
                use_cpu=use_cpu,
                recompute_neighbors=False,
                use_orig=True
            )

        # From here on the batch carries the denoised positions.
        batch[properties.R] = denoi_R

        rmsd_list = wrapper_rmsd(batch, batch["orig_R"])

        bonds, stable_atoms, stable_molecules, \
        connected, stable_atoms_wo_h, stable_molecules_wo_h, \
        connected_wo_h = check_validity(
            batch, *generate_bonds_data().values(), progress_bar=False
        )

        # RMSD after denoising relative to the RMSD stored with the noisy input.
        rmsd_ratio = rmsd_list / batch["rmsd"]

        if predict_time and num_steps is not None:
            logger.info(f"Fraction of converged denoising: {(num_steps < max_steps).float().mean().item():.3f}")
        logger.info(f"Fraction of stable molecules: {np.mean(stable_molecules):.3f}")
        logger.info(f"New batch denoised. new RMSD: mean {rmsd_list.mean().item():.3f}, std {rmsd_list.std().item():.3f}, median {rmsd_list.median().item():.3f}"
                     f", min {rmsd_list.min().item():.3f}, max {rmsd_list.max().item():.3f}")
        logger.info(f"Ratio of old and new RMSD denoising: "
                     f"mean {rmsd_ratio.mean().item():.3f}, median: {rmsd_ratio.median().item():.3f}, "
                     f"std {rmsd_ratio.std().item():.3f}, min {rmsd_ratio.min().item():.3f}, max {rmsd_ratio.max().item():.3f}")

        # Iterate the molecule indices in ascending order — the same order in
        # which wrapper_rmsd fills `rmsd_list` — so that position `i` refers to
        # the molecule selected by `mask`. An unsorted unique() gives no
        # ordering guarantee and could pair atoms with another molecule's
        # metrics.
        for i, mol_idx in enumerate(batch[properties.idx_m].unique()):

            mask = batch[properties.idx_m] == mol_idx

            mol = Atoms(
                positions=batch[properties.R][mask].detach().cpu().numpy(),
                numbers=batch[properties.Z][mask].detach().cpu().numpy()
            )

            # The optional entries are None when the sampler did not return them.
            data = {
                "rmsd": rmsd_list[i].item(),
                "orig_rmsd": batch["rmsd"][i].item(),
                "history_R": history_R[:,mask,:].detach().cpu().numpy() if history_R is not None else None,
                "steps": steps[:,mask].detach().cpu().numpy() if predict_time and steps is not None else None,
                "num_steps": num_steps[i].cpu().numpy() if predict_time and num_steps is not None else None,
                "stable": int(stable_molecules[i]),
            }

            # One connection per molecule, so every row is stored as soon as
            # it is ready.
            with ase.db.connect(denoised_db_path) as conn:
                conn.write(
                    mol,
                    # +1 presumably converts the 0-based dataset index to the
                    # 1-based row id of the noisy database — TODO confirm.
                    orig_id=batch[properties.idx][i].item()+1,
                    rmsd=rmsd_list[i].item(),
                    data=data
                )

        logger.info("Done writing batch in DB.")

    if data_workdir is not None:
        # Move the result from the work directory back to the requested path.
        logger.info(f"Copying denoised DB back to {old_denoised_db_path}")
        shutil.copy(denoised_db_path, old_denoised_db_path)
        logger.info(f"removing {denoised_db_path}")
        os.remove(denoised_db_path)
                
    
    