import torch
import os
import pickle
import logging

from tqdm import tqdm
import numpy as np
from ase import Atoms
from pytorch_lightning.callbacks import Callback
from schnetpack.diffusion.utils import sample_R, sample_R_time, sample_Z, rmsd
from schnetpack.diffusion.sampling_analysis import check_validity, generate_bonds_data
from schnetpack.diffusion.noise_schedule import NoiseSchedule
from schnetpack import properties


class SamplerCallback(Callback):
    def __init__(
        self,
        noise_schedule: NoiseSchedule = None,
        sampled_property: str = "R",
        cutoff: float = 20.0,
        store_path: str = "samples",
        every_n_batchs: int = 1,
        every_n_epochs: int = 1,
        save_progress: bool = True,
        progress_stride: int = 1,
        use_forces: bool = True,
        random=True,
        start=None,
        T: int = 1000,
        id: str = None,
        use_cpu: bool = True,
        check_conn: bool = True,
        log_rmsd: bool = False,
        relax_coef: float = 1.17,
        name: str = "sampling",
        size_intern_batch: int = 2048,
        start_epoch: int = 1,
        recompute_neighbors: bool = False,
        bonds_data_path: str = None,
        save_predictions: list = None,
        # the parameters below only apply to models that predict the time step
        predict_time: bool = False,
        aggregate_atomwise: bool = False,
        max_steps: int = 8000,
        convergence_step: int = 0,
        check_stability: bool = True,
        min_steps: int = 0,
    ):
        """Configure the sampling callback and create the output directory.

        Args:
            noise_schedule: schedule forwarded to ``sample_R`` /
                ``sample_R_time``; required when ``sampled_property`` is "R".
            sampled_property: "R" or "Z"; selects which sampler ``sample``
                calls.
            cutoff: forwarded to the sampler functions.
            store_path: directory where ``save_samples`` writes its files.
            every_n_batchs: validation batches are only accumulated when
                ``batch_idx`` is a multiple of this value.
            every_n_epochs: validation sampling only runs when the current
                epoch is a multiple of this value.
            save_progress: forwarded to the sampler functions.
            progress_stride: forwarded to the sampler functions and stored in
                the saved file.
            use_forces: forwarded to ``sample_R`` / ``sample_R_time``.
            random: forwarded to the sampler functions.
            start: forwarded to the sampler functions.
            T: forwarded to the sampler functions.
            id: written as ``"model_id"`` into the saved file.
            use_cpu: if True, ``sample`` moves the batch and the model to the
                CPU before sampling; also forwarded to the sampler functions.
            check_conn: if True, ``_step`` runs ``check_validity`` on the
                samples and logs the resulting fractions.
            log_rmsd: if True, ``_step`` logs the RMSD between the samples
                and the ``'original_R'`` entry of the accumulated batch.
            relax_coef: stored on the instance; falsy values (``None``, ``0``)
                fall back to ``1.1``.
            name: inserted into the names of the logged metrics.
            size_intern_batch: sampling is triggered once the accumulated
                ``n_atoms`` entry reaches this length.
            start_epoch: first epoch at which validation sampling may run.
            recompute_neighbors: forwarded to the sampler functions.
            bonds_data_path: passed to ``generate_bonds_data``.
            save_predictions: forwarded to ``sample_Z``.
            predict_time: if True, positions are sampled with
                ``sample_R_time`` instead of ``sample_R``.
            aggregate_atomwise: forwarded to ``sample_R_time``.
            max_steps: forwarded to ``sample_R_time``; also the threshold for
                the ``frac_converged_sampling`` metric.
            convergence_step: forwarded to ``sample_R_time``.
            check_stability: forwarded to ``sample_R_time``.
            min_steps: forwarded to ``sample_R_time``.

        Raises:
            ValueError: if ``sampled_property`` is "R" and no noise schedule
                is given.
        """
        super().__init__()

        if sampled_property == "R" and noise_schedule is None:
            raise ValueError("Noise schedule must be provided for sampling R")

        self.store_path = store_path
        self.noise_schedule = noise_schedule

        self.sampled_property = sampled_property
        self.cutoff = cutoff
        self.every_n_batchs = every_n_batchs
        self.every_n_epochs = every_n_epochs
        self.save_progress = save_progress
        self.progress_stride = progress_stride
        self.use_forces = use_forces
        self.random = random
        self.start = start
        self.T = T
        self.id = id
        # sampling results of the current phase, keyed by batch index
        self.samples = {}
        self.use_cpu = use_cpu
        self.check_conn = check_conn
        self.log_rmsd = log_rmsd
        # NOTE(review): the fallback (1.1) differs from the signature default
        # (1.17) -- confirm this is intended.
        self.relax_coef = relax_coef or 1.1
        self.name = name
        self.size_intern_batch = size_intern_batch
        # batches accumulated by `_concat` until enough data is collected
        self.intern_batch = None
        self.start_epoch = start_epoch
        self.recompute_neighbors = recompute_neighbors
        self.bonds_data = generate_bonds_data(bonds_data_path)
        self.predict_time = predict_time
        self.aggregate_atomwise = aggregate_atomwise
        self.max_steps = max_steps
        self.convergence_step = convergence_step
        self.check_stability = check_stability
        self.min_steps = min_steps
        self.save_predictions = save_predictions

        # exist_ok avoids the check-then-create race when several processes
        # construct the callback at the same time
        os.makedirs(self.store_path, exist_ok=True)

    
    def wrapper_rmsd(self, batch, reference):
        """Compute one RMSD value per molecule between ``batch`` and ``reference``.

        Args:
            batch: mapping holding at least ``properties.idx_m``,
                ``properties.R`` and ``properties.Z``.
            reference: tensor of reference positions, indexed with the same
                per-atom mask as ``batch[properties.R]``.

        Returns:
            A tensor with one entry per unique value of
            ``batch[properties.idx_m]``; entries for which ``rmsd`` raised are
            NaN.
        """
        # move everything to host memory once instead of once per molecule
        idx_m = batch[properties.idx_m].detach().cpu()
        positions = batch[properties.R].detach().cpu()
        ref_positions = reference.detach().cpu()
        numbers = batch[properties.Z].detach().cpu()

        res = []
        for m in idx_m.unique():
            mask = idx_m == m
            Z = numbers[mask].numpy()
            ref_mol = Atoms(positions=ref_positions[mask].numpy(), numbers=Z)
            mol = Atoms(positions=positions[mask].numpy(), numbers=Z)
            try:
                diff = rmsd(ref_mol, mol)
            except Exception:
                # a failed RMSD is kept as NaN on purpose; `Exception` (not a
                # bare except) lets KeyboardInterrupt/SystemExit propagate
                logging.warning("RMSD calculation failed and rmsd set to NaN. Maybe very different structures !")
                diff = torch.nan

            res.append(diff)
        return torch.tensor(res)
                        
    
    def sample(self, pl_module, batch):
        """Run the configured sampler on ``batch`` and collect the results.

        Depending on ``self.sampled_property`` this calls ``sample_R`` (or
        ``sample_R_time`` when ``self.predict_time`` is set) or ``sample_Z``
        with ``pl_module.model``.

        Args:
            pl_module: object exposing the model to sample from as ``.model``.
            batch: mapping from property name to tensor.

        Returns:
            A dict with CPU copies of the samples, the optional progress, the
            ``Z``/``R``, ``idx_m``, ``idx`` and ``n_atoms`` entries of the
            batch and a ``"predict_time"`` flag. ``sample_R_time`` adds
            ``"time_steps"`` and ``"num_steps"``; ``sample_Z`` adds
            ``"predictions"``.

        Raises:
            ValueError: if ``self.sampled_property`` is neither "R" nor "Z"
                (previously an empty dict was returned silently).
        """
        if self.sampled_property not in ("R", "Z"):
            raise ValueError(
                f"Unsupported sampled_property '{self.sampled_property}', "
                "expected 'R' or 'Z'"
            )

        model = pl_module.model
        restore_device = None
        if self.use_cpu:
            batch = {p: batch[p].cpu() for p in batch}
            # `Module.cpu()` moves the module in place, so remember where it
            # lived to hand it back on the same device after sampling.
            parameters = getattr(model, "parameters", None)
            first_param = (
                next(iter(parameters()), None) if callable(parameters) else None
            )
            if first_param is not None and first_param.device.type != "cpu":
                restore_device = first_param.device
            model = model.cpu()

        # keyword arguments shared by all three sampler functions
        shared_kwargs = {
            "cutoff": self.cutoff,
            "T": self.T,
            "start": self.start,
            "random": self.random,
            "save_progress": self.save_progress,
            "progress_stride": self.progress_stride,
            "use_cpu": self.use_cpu,
            "recompute_neighbors": self.recompute_neighbors,
        }

        results = {}

        try:
            if self.sampled_property == "R":
                if self.predict_time:
                    samples, progress, time_steps, num_steps, _ = sample_R_time(
                        batch,
                        model,
                        self.noise_schedule,
                        use_forces=self.use_forces,
                        aggregate_atomwise=self.aggregate_atomwise,
                        max_steps=self.max_steps,
                        convergence_step=self.convergence_step,
                        check_stability=self.check_stability,
                        bonds_data=self.bonds_data,
                        min_steps=self.min_steps,
                        return_stability=False,
                        **shared_kwargs,
                    )

                    results.update(
                        {
                            "predict_time": True,
                            "time_steps": time_steps.float().cpu()
                            if time_steps is not None
                            else None,
                            "num_steps": num_steps.float().cpu(),
                        }
                    )
                else:
                    samples, progress = sample_R(
                        batch,
                        model,
                        self.noise_schedule,
                        use_forces=self.use_forces,
                        **shared_kwargs,
                    )

                    results.update({"predict_time": False})

                results.update(
                    {
                        properties.R: samples.cpu(),
                        "progress_R": progress.cpu() if progress is not None else None,
                        properties.Z: batch[properties.Z].cpu(),
                        properties.idx_m: batch[properties.idx_m].cpu(),
                        properties.idx: batch[properties.idx].cpu(),
                        properties.n_atoms: batch[properties.n_atoms].cpu(),
                    }
                )
            else:
                samples, progress, preds = sample_Z(
                    batch,
                    model,
                    save_predictions=self.save_predictions,
                    **shared_kwargs,
                )

                results.update({"predict_time": False})

                results.update(
                    {
                        properties.R: batch[properties.R].cpu(),
                        properties.Z: samples.cpu(),
                        "progress_Z": progress.cpu() if progress is not None else None,
                        properties.idx_m: batch[properties.idx_m].cpu(),
                        properties.idx: batch[properties.idx].cpu(),
                        properties.n_atoms: batch[properties.n_atoms].cpu(),
                        "predictions": preds,
                    }
                )
        finally:
            # put the model back even if sampling failed
            if restore_device is not None:
                model.to(restore_device)

        return results

    def save_samples(self, trainer, test=False):
        """Write the collected samples to a new file and clear the buffer.

        Nothing happens when no samples were collected. Otherwise the samples
        and a few sampling settings are stored with ``torch.save`` under
        ``samples_{phase}_{epoch}_{file_idx}.pt`` inside ``self.store_path``,
        where ``file_idx`` is the first index whose file does not exist yet.

        Args:
            trainer: object exposing ``current_epoch``.
            test: selects the "test" instead of the "val" phase in the file
                name.
        """
        if len(self.samples) == 0:
            return

        payload = {
            "samples": self.samples,
            "model_id": self.id,
            "start": self.start,
            "T": self.T,
            "random": self.random,
            "progress_stride": self.progress_stride,
        }
        phase = "test" if test else "val"

        def candidate_path():
            # path for the current value of self.file_idx
            return os.path.join(
                self.store_path,
                f"samples_{phase}_{trainer.current_epoch}_{self.file_idx}.pt",
            )

        # advance the index until an unused file name is found
        self.file_idx = 0
        target = candidate_path()
        while os.path.exists(target):
            self.file_idx += 1
            target = candidate_path()

        with open(target, "wb") as f:
            torch.save(payload, f)

        self.samples = {}

    def _concat(self, batch):
        """Append ``batch`` to the internally accumulated batch.

        The first batch is cloned as is. Every later batch is concatenated
        along dim 0, the neighbor-list related entries are dropped, and the
        ``idx_m`` values of the new batch are shifted by the length of the
        accumulated ``n_atoms`` entry.
        """
        if self.intern_batch is None:
            self.intern_batch = {key: batch[key].clone() for key in batch}
            return

        accumulated = self.intern_batch
        previous_idx_m = accumulated[properties.idx_m]
        # offset applied to the idx_m values of the incoming batch
        idx_m_offset = len(accumulated[properties.n_atoms])

        # entries that are not carried over once batches are merged
        skipped = [
            properties.lidx_i,
            properties.lidx_j,
            properties.idx_i,
            properties.idx_j,
            properties.offsets,
            "one_hot_Z",
        ]

        merged = {}
        for key in batch.keys():
            if key in skipped:
                continue
            merged[key] = torch.cat((accumulated[key], batch[key]), dim=0)

        shifted_idx_m = batch[properties.idx_m] + idx_m_offset
        merged[properties.idx_m] = torch.cat((previous_idx_m, shifted_idx_m), dim=0)
        self.intern_batch = merged

    def _step(self, pl_module, batch_idx, test=False):
        """Sample from the accumulated batch, log metrics and store the result.

        Args:
            pl_module: module used for sampling (``.model``) and for logging
                (``.log``).
            batch_idx: key under which the results are stored in
                ``self.samples``.
            test: selects the "test" instead of the "val" prefix of the
                logged metric names.
        """
        results = self.sample(pl_module, self.intern_batch)
        metrics = {}

        if self.check_conn:
            (
                bonds,
                atom_flags,
                mol_flags,
                conn_flags,
                atom_flags_wo_h,
                mol_flags_wo_h,
                conn_flags_wo_h,
            ) = check_validity(results, *self.bonds_data.values())

            # the two stable-atom outputs are concatenated, the other four
            # outputs are wrapped into arrays directly
            arrays = {
                "stable_atoms": np.concatenate(atom_flags),
                "stable_molecules": np.array(mol_flags),
                "stable_atoms_wo_h": np.concatenate(atom_flags_wo_h),
                "stable_molecules_wo_h": np.array(mol_flags_wo_h),
                "connectivity": np.array(conn_flags),
                "connectivity_wo_h": np.array(conn_flags_wo_h),
            }

            results["bonds"] = bonds
            for key in (
                "connectivity",
                "stable_atoms",
                "stable_molecules",
                "stable_atoms_wo_h",
                "stable_molecules_wo_h",
                "connectivity_wo_h",
            ):
                results[key] = torch.from_numpy(arrays[key])

            metric_sources = (
                ("frac_stable_atoms", "stable_atoms"),
                ("frac_stable_molecules", "stable_molecules"),
                ("frac_stable_atoms_wo_h", "stable_atoms_wo_h"),
                ("frac_stable_molecules_wo_h", "stable_molecules_wo_h"),
                ("frac_connected_molecules", "connectivity"),
                ("frac_connected_molecules_wo_h", "connectivity_wo_h"),
            )
            metrics = {
                metric: arrays[source].mean() for metric, source in metric_sources
            }

            # step statistics are only logged together with the validity check
            if "num_steps" in results:
                num_steps = results["num_steps"]
                metrics["avg_num_sampling_steps"] = num_steps.mean()
                metrics["med_num_sampling_steps"] = num_steps.median()
                metrics["std_num_sampling_steps"] = num_steps.std()
                converged = num_steps < self.max_steps
                metrics["frac_converged_sampling"] = converged.float().mean()

        if self.log_rmsd:
            rmsd_values = self.wrapper_rmsd(results, self.intern_batch['original_R'])
            results["rmsd"] = rmsd_values
            metrics["rmsd"] = rmsd_values.mean()

        phase = 'test' if test else 'val'
        for key, val in metrics.items():
            pl_module.log(
                f"{phase}_{self.name}_{key}",
                val,
                on_step=False,
                on_epoch=True,
                prog_bar=False,
            )

        self.samples[batch_idx] = results
        # the accumulated batch has been consumed
        self.intern_batch = None

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx = 0
    ):
        """Accumulate validation batches and sample once enough data is there.

        Sampling only happens in epochs that are at least ``start_epoch`` and
        a multiple of ``every_n_epochs``. Batches whose index is a multiple
        of ``every_n_batchs`` are appended to the internal batch. The
        internal batch is sampled as soon as its ``n_atoms`` entry reaches
        ``size_intern_batch`` or when the last validation batch is reached.
        """
        epoch = trainer.current_epoch
        if epoch < self.start_epoch or epoch % self.every_n_epochs != 0:
            return

        if batch_idx % self.every_n_batchs == 0:
            self._concat(batch)

        if self.intern_batch is None:
            return

        # The last-batch flush must not depend on the batch stride: otherwise
        # a partially filled buffer is discarded at the end of validation
        # whenever the final batch index is not a multiple of every_n_batchs.
        if (
            len(self.intern_batch[properties.n_atoms]) >= self.size_intern_batch
            or batch_idx == len(trainer.datamodule.val_dataloader()) - 1
        ):
            self._step(pl_module, batch_idx)

    def on_test_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx = 0
    ):
        """Accumulate every test batch and sample once enough data is there.

        The internal batch is sampled when its ``n_atoms`` entry reaches
        ``size_intern_batch`` or when the last test batch has been added.
        """
        self._concat(batch)

        # `_concat` always leaves a non-empty internal batch behind, so only
        # the size and the position in the dataloader need to be checked;
        # the dataloader length is only queried when the buffer is not full.
        buffer_full = (
            len(self.intern_batch[properties.n_atoms]) >= self.size_intern_batch
        )
        if buffer_full or batch_idx == len(trainer.datamodule.test_dataloader()) - 1:
            self._step(pl_module, batch_idx, test=True)

    def on_validation_end(self, trainer, pl_module):
        """Discard any unsampled internal batch and save the validation samples.

        A warning is logged when accumulated batches are left that were never
        sampled.
        """
        if self.intern_batch is not None:
            # logging.warn is a deprecated alias of logging.warning
            logging.warning("sampling on validation is over with non empty intern batch !")
        self.intern_batch = None
        self.save_samples(trainer)

    def on_test_end(self, trainer, pl_module):
        """Discard any unsampled internal batch and save the test samples.

        A warning is logged when accumulated batches are left that were never
        sampled.
        """
        if self.intern_batch is not None:
            # logging.warn is a deprecated alias of logging.warning
            logging.warning("sampling on test is over with non empty intern batch !")
        self.intern_batch = None
        self.save_samples(trainer, test=True)
