from typing import Dict, Sequence, Tuple, Union

from torch import nn
import torch

from src.schnetpack.diffusion.utils import *
from src.schnetpack.diffusion.noise_schedule import NoiseSchedule
import schnetpack.properties as properties
from schnetpack.interfaces import AtomsConverter
import schnetpack.transform as trn
from ase import Atoms


class Diffuse(nn.Module):
    """
    Diffuse the input systems.

    Stores a noise schedule and exposes ``sample``, a reverse-diffusion loop
    over a per-atom scalar latent that is rounded to integer atomic numbers
    before every call to the denoising model.
    """

    def __init__(self, noise_schedule: NoiseSchedule, diffuse_z: bool = True):
        """
        Args:
            noise_schedule: callable schedule; its ``T`` attribute gives the
                number of diffusion steps and it is called in ``sample`` with
                time steps normalized by ``T``.
            diffuse_z: stored on the instance; not read by ``sample``.
        """
        super().__init__()
        self.noise_schedule = noise_schedule
        self.T = noise_schedule.T
        self.diffuse_z = diffuse_z

    @staticmethod
    def _to_atomic_numbers(z: torch.Tensor) -> torch.Tensor:
        """Round the latent to the nearest integer and clamp negatives to 0."""
        # clamp(min=0) is equivalent to where(x < 0, 0, x) on an integer tensor.
        return torch.round(z.squeeze(-1)).long().clamp(min=0)

    @staticmethod
    def _raise_if_nan(name: str, tensor: torch.Tensor, step: int) -> None:
        """Raise ``RuntimeError`` if ``tensor`` contains any NaN entry."""
        nan_mask = torch.isnan(tensor)
        if nan_mask.any():
            raise RuntimeError(
                f"NaN encountered in {name} at step s={step}: "
                f"{int(nan_mask.sum())} of {tensor.numel()} entries are NaN"
            )

    @torch.no_grad()
    def sample(
        self,
        model,
        R,
        n_samples: int = 1,
        n_atoms: Union[int, Sequence[int]] = 19,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Run the reverse diffusion loop from step ``T - 1`` down to ``0``.

        Args:
            model: callable taking the converted input dict and returning a
                mapping that contains ``"eps_pred"``.
            R: atom positions handed to ``ase.Atoms``; presumably one row per
                atom, i.e. ``sum(n_atoms)`` rows in total (TODO confirm).
            n_samples: number of samples; sets the size of the time-step
                arrays.
            n_atoms: atoms per sample, either a single int used for every
                sample or a sequence with one entry per sample.

        Returns:
            ``(zt, systems)``: the latent after the last step, and the input
            dict produced by ``AtomsConverter``. ``systems[properties.Z]``
            holds the rounded latent that was fed to the model in the last
            iteration, not the rounded final ``zt``.

        Raises:
            ValueError: if ``n_atoms`` is a sequence whose length differs from
                ``n_samples``.
            RuntimeError: if NaN values appear in any intermediate quantity.
        """
        # TODO: the tensors created here live on the default device; no device
        # is specified yet.
        systems = {}
        # Kept for callers that read ``self.model``; if ``model`` is an
        # nn.Module this assignment also registers it as a submodule.
        self.model = model

        # Resolve the per-sample atom counts first so that both the int and the
        # sequence form of ``n_atoms`` work for every tensor created below.
        if isinstance(n_atoms, int):
            atoms_per_sample = torch.full((n_samples,), n_atoms, dtype=torch.long)
        else:
            if len(n_atoms) != n_samples:
                raise ValueError(
                    f"n_atoms has {len(n_atoms)} entries but n_samples is "
                    f"{n_samples}"
                )
            atoms_per_sample = torch.as_tensor(n_atoms, dtype=torch.long)
        n_total = int(atoms_per_sample.sum())

        systems[properties.idx_m] = torch.arange(n_samples).repeat_interleave(
            atoms_per_sample
        )
        systems[properties.n_atoms] = atoms_per_sample

        # TODO: redundant with the forward function.
        # Start from pure Gaussian noise in the latent.
        eps_h = torch.randn(n_total, 1)
        zt = eps_h
        systems[properties.R] = R
        systems[properties.Z] = self._to_atomic_numbers(eps_h)
        print("systems infos before converting")
        for key in systems.keys():
            print(key, systems[key][:5])

        # The converter output replaces the hand-built dict, so the idx_m and
        # n_atoms entries set above only feed the debug output; everything
        # below reads the converter's values.
        converter = AtomsConverter(
            neighbor_list=trn.ASENeighborList(cutoff=5.0), dtype=torch.float32
        )
        print("converter input", systems[properties.Z], systems[properties.R])
        atoms = Atoms(numbers=systems[properties.Z], positions=systems[properties.R])
        systems = converter(atoms)
        print("systems added info after converting")
        for key in systems.keys():
            print(key, systems[key][:5])

        # Sampling loop for s from T-1 down to 0, i.e. from alpha_T to alpha_1
        # as the noise schedule is defined here. This matches the "improved
        # diffusion" paper. E3D also iterates s from T-1 to 0 with t = s + 1 and
        # alpha_{t|s} = alpha_t / alpha_s, but adds an extra sampling step of
        # x, h given z_0 at the end, whereas other diffusion implementations
        # return x_0 (i.e. z_0) as the sample.
        for s in range(self.T - 1, -1, -1):
            print("********************************** iter ", s)
            self._raise_if_nan("zt", zt, s)
            print("zt", zt[:5])

            s_array = torch.full((n_samples, 1), fill_value=s)
            t_array = s_array + 1  # step s = t-1
            s_array = s_array / self.T
            t_array = t_array / self.T
            print("time steps s, t", s_array[:5].T, t_array[:5].T)

            # Per-atom schedule values, gathered through the molecule index.
            alpha_t_bar = self.noise_schedule(t_array)[systems[properties.idx_m]]
            alpha_s_bar = self.noise_schedule(s_array)[systems[properties.idx_m]]
            print("alpha s t bar", alpha_s_bar[:5].T, alpha_t_bar[:5].T)
            alpha_t = alpha_t_bar / alpha_s_bar
            beta_t = (1 - alpha_t) ** 0.5
            alpha_t = alpha_t**0.5
            print("alpha beta", alpha_t[:5].T, beta_t[:5].T)

            # The model consumes integer atomic numbers, so the continuous
            # latent is rounded (and clamped at 0) before the forward pass.
            print("rounded zt", torch.round(zt.squeeze(-1)))
            systems[properties.Z] = self._to_atomic_numbers(zt)
            print("input NN R Z", systems[properties.R], systems[properties.Z])
            eps_t = self.model(systems)["eps_pred"]
            self._raise_if_nan("eps_t", eps_t, s)
            print("eps_t", eps_t[:5])

            # NOTE(review): the mean is built from the rounded
            # systems[properties.Z] rather than the continuous zt. Open
            # question from the original author whether zt should be used here;
            # behavior is left unchanged.
            mu_first = systems[properties.Z].unsqueeze(-1).float() / alpha_t
            mu_coeff = beta_t**2 / alpha_t / (1 - alpha_t_bar) ** 0.5
            mu = mu_first - mu_coeff * eps_t
            print("first part mu", mu_first)
            print("second part mu", mu_coeff)
            self._raise_if_nan("mu", mu, s)
            print("mu", mu)

            # Compute sigma for p(zs | zt).
            # As in image diffusion papers, no noise is added at the last step
            # (s == 0): the output is the mean itself. E3D has no such case
            # because its s = 0 is z_0, from which x, h are sampled after the
            # loop; in this implementation the final values are produced here.
            if s == 0:
                sigma = torch.zeros_like(mu)
            else:
                sigma = beta_t * (1 - alpha_s_bar) ** 0.5 / (1 - alpha_t_bar) ** 0.5
            print(
                "alpha part sigma",
                ((1 - alpha_s_bar) ** 0.5 / (1 - alpha_t_bar) ** 0.5)[:5],
            )
            self._raise_if_nan("sigma", sigma, s)
            print("sigma", sigma[:5])

            # Sample zs given the parameters derived from zt.
            # TODO: redundant with the forward function.
            eps_h = torch.randn(n_total, 1)
            print("eps_h", eps_h)

            zs = mu + sigma * eps_h
            self._raise_if_nan("zs", zs, s)
            print("zs", zs)
            # No centering is applied to zs; the label is kept so the debug
            # output stays the same.
            print("zs after centering", zs[:5])

            zt = zs  # next step

        return zt, systems

        # Unlike E3D, this implementation already samples the final x, h at s=0
        # inside the loop, so the E3D-style extra step sketched below is not used.
        # Finally sample p(x, h | z_0).
        # x, h = self.sample_p_xh_given_z0(z, node_mask, edge_mask, context)

        # xh = torch.cat([x, h['categorical'], h['integer']], dim=2)
        # chain[0] = xh  # Overwrite last frame with the resulting x and h.

        # chain_flat = chain.view(n_samples * keep_frames, *z.size()[1:])
