from typing import Dict
import torch
from torch import nn
from schnetpack import properties

import schnetpack.properties as properties
from schnetpack.interfaces import AtomsConverter
import schnetpack.transform as trn
from ase import Atoms

from schnetpack.diffusion import NoiseSchedule
from schnetpack.diffusion.utils import batch_center_systems


class BatchSubtractCenterOfMass(trn.Transform):
    """
    Center the coordinate part of a batched per-atom property per system.

    Post-processing transform. The first ``R_dim`` entries along the last axis
    of ``inputs[name]`` are treated as coordinates and passed through
    ``batch_center_systems`` together with ``idx_m`` and ``n_atoms``
    (presumably subtracting each system's mean -- confirm against that helper).
    Any remaining trailing feature columns (presumably the atom-type channel
    appended when Z is diffused) are passed through unchanged.

    Args:
        name: key of the property in ``inputs`` to center.
        R_dim: number of leading feature columns holding coordinates.
    """

    is_preprocessor: bool = False
    is_postprocessor: bool = True
    force_apply: bool = True

    def __init__(
        self,
        name: str = "eps_pred",
        R_dim: int = 3,
    ):
        super().__init__()
        self.name = name
        self.R_dim = R_dim

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """
        Replace ``inputs[self.name]`` by its per-system centered version.

        Args:
            inputs: batch dictionary; must contain ``self.name``,
                ``properties.idx_m`` and ``properties.n_atoms``.

        Returns:
            The same dictionary with ``inputs[self.name]`` updated.

        Raises:
            ValueError: if the property has fewer than two axes or fewer than
                ``R_dim`` columns along its last axis.
        """
        values = inputs[self.name]
        # Check the same axis that the branches below slice and concatenate on.
        if values.dim() < 2 or values.shape[-1] < self.R_dim:
            raise ValueError(
                f"Property {self.name} has less than {self.R_dim} dimensions. Cannot subtract center of mass."
            )

        idx_m = inputs[properties.idx_m]
        n_atoms = inputs[properties.n_atoms]

        if values.shape[-1] == self.R_dim:
            inputs[self.name] = batch_center_systems(values, idx_m, n_atoms)
        else:
            # Slicing returns views and torch.cat allocates a fresh tensor, so
            # no in-place op is involved and autograd history is preserved.
            coords = values[..., : self.R_dim]
            features = values[..., self.R_dim :]
            centered = batch_center_systems(coords, idx_m, n_atoms)
            # torch.cat keeps the inputs' device; no explicit .to() is needed.
            inputs[self.name] = torch.cat((centered, features), dim=-1)
        return inputs


class Diffuse(trn.Transform):
    """
    Diffuse the input systems.

    Noises the positions R (and optionally the atomic numbers Z) at one or
    more diffusion timesteps t via
    ``z_t = sqrt_alpha_bar(t) * x + sqrt_beta_bar(t) * eps``.
    The first timestep is the one trained on. Extra timesteps (``include_t``,
    or every timestep ``0..T-1`` when ``diffuse_all``) and ``t = 0`` are
    appended, so that their noised samples and noise targets are available as
    well (e.g. for a VLB loss or metrics).

    NOTE(review): the position noise is centered over *all* atoms in
    ``inputs``. This assumes the transform runs per system before batching
    (it is flagged as a preprocessor) -- confirm.

    Args:
        noise_schedule: schedule exposing ``T`` and returning a dict with
            ``sqrt_alpha_bar`` / ``sqrt_beta_bar`` for normalized timesteps
            ``t / T``.
        diffuse_z: also diffuse Z, treated as a continuous feature and rounded
            back to non-negative integers.
        diffuse_all: additionally diffuse at all timesteps ``0..T-1`` but train
            on only the first one.
        use_forces: negate the position-noise target ("forces trick": the
            target then has the sign of a force, i.e. minus the derivative of
            an energy w.r.t. positions).
        t_train: fixed training timestep. If None, one is sampled uniformly on
            every call.
        include_t: extra timesteps to diffuse at (scalar, sequence or tensor).
        exclude_eps_0: exclude ``t = 0`` from training-timestep sampling, e.g.
            when training on the VLB, which always includes the L_0 term.
        per_atom_step: store ``diff_step`` once per atom (repeated ``n_atoms``
            times) instead of once per system.
    """

    is_preprocessor: bool = True
    is_postprocessor: bool = False

    def __init__(
        self,
        noise_schedule: NoiseSchedule,
        diffuse_z: bool = True,
        diffuse_all: bool = False,
        use_forces: bool = False,
        t_train: int = None,
        include_t: torch.Tensor = None,
        exclude_eps_0: bool = False,
        per_atom_step: bool = True,
    ):
        super().__init__()
        self.noise_schedule = noise_schedule
        self.T = noise_schedule.T
        self.diffuse_z = diffuse_z
        # diffuse at all time steps but train on only one.
        self.diffuse_all = diffuse_all
        self.use_forces = use_forces
        # if None, the training time step is sampled randomly on every call.
        self.t_train = torch.tensor([t_train]) if t_train is not None else None
        # Normalize to a 1-D CPU long tensor. This accepts scalars, sequences
        # and tensors: torch.tensor(<tensor>) warns, and a 0-dim tensor
        # cannot be concatenated.
        self.include_t = (
            torch.as_tensor(include_t, dtype=torch.long, device="cpu").reshape(-1)
            if include_t is not None
            else None
        )
        self.exclude_eps_0 = exclude_eps_0
        self.per_atom_step = per_atom_step

    def _sample_timesteps(self) -> torch.Tensor:
        """
        Build the 1-D long tensor of timesteps to diffuse at.

        Index 0 holds the training timestep (fixed or sampled). It is followed
        by the extra timesteps and finally ``0``, which is always appended so
        that eps_0 is available for the VLB loss or metric.
        """
        if self.t_train is not None:
            if self.exclude_eps_0 and (self.t_train == 0).any():
                raise ValueError(
                    "set 't_train' to something different than 0 when 'exclude_eps_0 = True' !"
                )
            t = self.t_train
        else:
            low = 1 if self.exclude_eps_0 else 0
            t = torch.randint(low, self.T, size=(1,), dtype=torch.long)

        # Local variable: do not overwrite self.include_t on every call.
        extra_t = (
            torch.arange(self.T, dtype=torch.long)
            if self.diffuse_all
            else self.include_t
        )
        if extra_t is not None:
            t = torch.concat((t, extra_t))

        return torch.concat((t, torch.zeros(1, dtype=torch.long)))

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Noise ``inputs`` in place and store the diffusion targets.

        Writes ``t``, ``diff_step``, ``eps_all``, ``eps``, ``all_diff_R`` and
        ``original_R``, and replaces ``R`` with the sample noised at the
        training timestep. When ``diffuse_z`` is set, it also writes
        ``all_diff_Z`` and ``original_Z`` and replaces ``Z``.

        Raises:
            ValueError: if ``t_train == 0`` while ``exclude_eps_0`` is set.
        """
        device = inputs[properties.R].device
        x = inputs[properties.R]

        # NOTE on timestep conventions (author's notes, compared with E3D):
        # - E3D samples t from 0..T (L_0 at t=0 and a different L_T term);
        #   here t is sampled from 0..T-1 (1..T-1 with exclude_eps_0).
        # - In this noise schedule there is no alpha_0: index t maps to
        #   alpha_{t+1}, i.e. t = 0 -> alpha_1 and t = T-1 -> alpha_T.
        # - E3D's sampler iterates t = T..1 with s = t-1 and uses alpha_{t|s}
        #   (precomputed here as alpha_t). Sampling here therefore runs
        #   t = T-1..0, i.e. alpha_T..alpha_1.
        # - E3D ends sampling at z_0 (from which x, h are drawn); other
        #   implementations end at the original sample x_0.
        t = self._sample_timesteps().to(device=device)
        # the first timestep is always the one to be trained on.
        idx_t_train = 0

        t_norm = t.float() / self.T  # normalize timesteps by T

        noise_params = self.noise_schedule(t_norm)
        # Add two trailing axes so the per-timestep factors broadcast over
        # (n_t, n_atoms, features). Assumes the schedule returns 1-D tensors.
        sqrt_beta_bar = noise_params["sqrt_beta_bar"].unsqueeze(-1).unsqueeze(-1)
        sqrt_alpha_bar = noise_params["sqrt_alpha_bar"].unsqueeze(-1).unsqueeze(-1)
        n_t = sqrt_alpha_bar.shape[0]

        # Positions: sample N(0, I) noise per timestep and project it onto the
        # zero-center-of-geometry subspace (translation invariance).
        eps_x = torch.randn((n_t,) + x.shape, device=device)
        eps_x = eps_x - eps_x.mean(dim=1, keepdim=True)
        z_x_t = sqrt_alpha_bar * x + sqrt_beta_bar * eps_x

        # Forces trick: the target is the negated noise.
        eps = -eps_x if self.use_forces else eps_x

        if self.diffuse_z:
            h = inputs[properties.Z].float().unsqueeze(-1)
            eps_h = torch.randn((n_t,) + h.shape, device=device)
            z_h_t = sqrt_alpha_bar * h + sqrt_beta_bar * eps_h
            # Round back to integer atomic numbers and clip negatives to 0.
            # NOTE(review): there is no upper bound, so values beyond the
            # largest valid Z can occur.
            z_h_t = torch.round(z_h_t).long().clamp(min=0)
            eps = torch.cat((eps, eps_h), dim=-1)

        diff_step = (t[idx_t_train] / self.T).float()
        if self.per_atom_step:
            diff_step = diff_step.repeat(inputs[properties.n_atoms])
        else:
            diff_step = diff_step.unsqueeze(0)
        inputs["t"] = t.unsqueeze(0)
        inputs["diff_step"] = diff_step

        # Integer indexing already drops the timestep axis. The previous extra
        # .squeeze(0) collapsed single-atom systems to the wrong rank.
        inputs["eps_all"] = eps.transpose(1, 0)
        inputs["eps"] = inputs["eps_all"][:, idx_t_train]
        inputs["all_diff_R"] = z_x_t.transpose(1, 0)
        inputs["original_R"] = inputs[properties.R]
        inputs[properties.R] = z_x_t[idx_t_train]
        if self.diffuse_z:
            inputs["all_diff_Z"] = z_h_t.squeeze(-1).T
            inputs["original_Z"] = inputs[properties.Z]
            inputs[properties.Z] = z_h_t[idx_t_train].squeeze(-1)

        return inputs
