import torch
from torch.nn import functional as F

from equivariant_diffusion import utils as diffusion_utils
from geo_ldm.encoder import EGNN_encoder
from geo_ldm.decoder import EGNN_decoder
from equivariant_diffusion.en_diffusion import gaussian_KL, gaussian_KL_for_dimension
from bond_type_prediction.losses import adjacency_matrix_loss, atom_types_and_formal_charges_loss
from bond_type_prediction.utils import compute_class_weight


class EnHierarchicalVAE(torch.nn.Module):
    """
    The E(n) Hierarchical VAE Module.
    """
    def __init__(
            self,
            encoder: EGNN_encoder,
            decoder: EGNN_decoder,
            in_node_nf: int, n_dims: int, latent_node_nf: int,
            kl_weight: float,
            norm_values=(1., 1., 1.), norm_biases=(None, 0., 0.), 
            include_atomic_numbers=False,
            num_edge_types=5,
            noise_sigma=None,
            use_focal_loss=False,
            n_formal_charges=3):
        super().__init__()

        self.include_atomic_numbers = include_atomic_numbers

        self.encoder = encoder
        self.decoder = decoder

        self.in_node_nf = in_node_nf
        self.n_dims = n_dims
        self.latent_node_nf = latent_node_nf
        self.n_formal_charges = n_formal_charges
        self.num_classes = self.in_node_nf - self.include_atomic_numbers - n_formal_charges # one hot formal charges are always included
        self.kl_weight = kl_weight

        self.norm_values = norm_values
        self.norm_biases = norm_biases
        self.num_edge_types = num_edge_types
        print(f'noise_sigma: {noise_sigma}')
        self.noise_sigma = noise_sigma
        self.use_focal_loss = use_focal_loss
        if self.use_focal_loss:
            print('Using focal loss')
        if noise_sigma is not None:
            # if training with noise, freeze encoder
            self.encoder = self.encoder.eval()
            self.encoder.train = lambda model: model
            for param in self.encoder.parameters():
                param.requires_grad = False

        self.is_encoder_frozen = False

        self.register_buffer('buffer', torch.zeros(1))

    def prepare_class_weights(self, dataset_info, device):
        print('Preparing class weights...')
        self.class_weight_dict = {}
        for target in ['atom_types', 'formal_charges', 'edge_types']:
            self.class_weight_dict[target] = torch.Tensor(dataset_info[f'class_weights_{target}']).to(device)
        print(f'Got class weights: {self.class_weight_dict}')

        #self.class_weight_dict = compute_class_weight(dataloader, dataset_info, recompute_class_weight=recompute_class_weight)
        #self.class_weight_dict = {key: value.to(device) for key, value in self.class_weight_dict.items()}

    def subspace_dimensionality(self, node_mask):
        """Compute the dimensionality on translation-invariant linear subspace where distributions on x are defined."""
        number_of_nodes = torch.sum(node_mask.squeeze(2), dim=1)
        return (number_of_nodes - 1) * self.n_dims

    def compute_reconstruction_error(self, xh_rec, xh):
        """Compute the per-sample reconstruction error between ``xh_rec`` and ``xh``.

        Both tensors are indexed as ``[batch, node, feature]``. The first
        ``n_dims`` feature columns are compared with a squared error, the next
        ``num_classes`` columns with a cross-entropy (``xh_rec`` supplies the
        scores, the argmax of ``xh`` supplies the target class). Each term is
        reduced with ``diffusion_utils.sum_except_batch``.

        If the instance carries a truthy ``include_charges`` attribute, the
        trailing ``include_charges`` columns are additionally compared with a
        squared error. ``__init__`` does not define that attribute, so by
        default this term is zero.

        In training mode the total is divided by
        ``(n_dims + in_node_nf) * n_nodes``.
        """
        bs, n_nodes, _ = xh.shape

        # Error on positions.
        x_rec = xh_rec[:, :, :self.n_dims]
        x = xh[:, :, :self.n_dims]
        error_x = diffusion_utils.sum_except_batch((x_rec - x) ** 2)

        # Error on classes: flatten nodes so cross_entropy sees (N, C) scores.
        class_cols = slice(self.n_dims, self.n_dims + self.num_classes)
        h_cat_rec = xh_rec[:, :, class_cols].reshape(bs * n_nodes, self.num_classes)
        h_cat = xh[:, :, class_cols].reshape(bs * n_nodes, self.num_classes)
        error_h_cat = F.cross_entropy(h_cat_rec, h_cat.argmax(dim=1), reduction='none')
        error_h_cat = error_h_cat.reshape(bs, n_nodes, 1)
        error_h_cat = diffusion_utils.sum_except_batch(error_h_cat)

        # Error on charges. `include_charges` is not set by __init__, so reading
        # it directly raised AttributeError; treat a missing attribute as
        # "no integer charge channel".
        include_charges = getattr(self, 'include_charges', 0)
        if include_charges:
            h_int_rec = xh_rec[:, :, -include_charges:]
            h_int = xh[:, :, -include_charges:]
            error_h_int = diffusion_utils.sum_except_batch((h_int_rec - h_int) ** 2)
        else:
            error_h_int = 0.

        error = error_x + error_h_cat + error_h_int

        if self.training:
            denom = (self.n_dims + self.in_node_nf) * xh.shape[1]
            error = error / denom

        return error

    # TODO: rethink stuff here: how to reduce, how to deal with dummy preds, etc.
    def compute_2d_reconstruction_error(self, adj_rec, h_rec, adj, h, node_mask, edge_mask):
        """Sum the edge-type, atom-type and formal-charge reconstruction losses.

        Args:
            adj_rec: reconstructed adjacency passed to ``adjacency_matrix_loss``.
            h_rec: reconstructed node features passed to
                ``atom_types_and_formal_charges_loss``.
            adj: ground-truth adjacency.
            h: dict providing ``'atomic_numbers_one_hot'`` and
                ``'formal_charges_one_hot'``; both are cast to ``long``.
            node_mask: forwarded to the node loss.
            edge_mask: forwarded to the edge loss.

        Requires ``prepare_class_weights`` to have been called, since the
        class weights are read from ``self.class_weight_dict``.
        """
        class_weights = self.class_weight_dict
        focal = self.use_focal_loss

        # Edge types.
        edge_loss = adjacency_matrix_loss(
            adj_rec, adj, edge_mask,
            n_classes=self.num_edge_types,
            reduction='mean',
            weight=class_weights['edge_types'],
            use_focal_loss=focal)

        # Atom types and formal charges.
        atom_type_targets = h['atomic_numbers_one_hot'].long()
        formal_charge_targets = h['formal_charges_one_hot'].long()
        atom_type_loss, formal_charge_loss = atom_types_and_formal_charges_loss(
            h_rec, atom_type_targets, formal_charge_targets, node_mask,
            weight_dict=class_weights,
            use_focal_loss=focal)

        return edge_loss + atom_type_loss + formal_charge_loss
    
    def sample_normal(self, mu, sigma, node_mask, fix_noise=False):
        """Draw ``mu + sigma * eps`` with ``eps`` from the combined noise sampler.

        With ``fix_noise=True`` a single noise sample (batch size 1) is drawn
        and broadcast against ``mu``; otherwise one sample per batch element.
        """
        n_noise_samples = mu.size(0)
        if fix_noise:
            n_noise_samples = 1
        noise = self.sample_combined_position_feature_noise(
            n_noise_samples, mu.size(1), node_mask)
        return mu + sigma * noise

    def inject_noise(self, x, h, node_mask, p_mol, p_atom, sigma_max):
        """
        Randomly corrupt positions ``x`` and features ``h`` with Gaussian noise.

        A molecule is selected for corruption with probability ``p_mol``.
        Inside a selected molecule, the position and the features of each atom
        are selected independently, each with probability ``p_atom / 2``.
        Every selected entry gets its own noise level drawn uniformly from
        ``[0, sigma_max)`` and is replaced by
        ``sqrt(1 - sigma^2) * value + sigma * eps``; unselected entries have
        ``sigma = 0`` and pass through unchanged. The corrupted positions are
        re-centred with ``remove_mean_with_mask``.

        Args:
            p_mol (float): probability to corrupt a molecule in the batch
            p_atom (float): probability to corrupt an atom in a corrupted molecule
            sigma_max (float): upper bound of the per-atom noise level

        Returns:
            Concatenation of corrupted ``x`` and corrupted ``h`` along dim 2.
        """
        batch_size, num_nodes, _ = x.size()
        per_atom = (batch_size, num_nodes,)

        # Molecule-level selection, broadcast over the atoms of each molecule.
        mol_selected = torch.bernoulli(p_mol * torch.ones((batch_size,))).unsqueeze(1)
        # Atom-level selection, done separately for positions and features.
        x_selected = torch.bernoulli(p_atom / 2 * torch.ones(per_atom)) * mol_selected
        h_selected = torch.bernoulli(p_atom / 2 * torch.ones(per_atom)) * mol_selected

        # Per-atom noise level; zero wherever the atom was not selected.
        sigma_x = torch.empty(per_atom).uniform_(0, sigma_max) * x_selected
        sigma_h = torch.empty(per_atom).uniform_(0, sigma_max) * h_selected
        sigma_x = sigma_x.unsqueeze(2).to(x.device)
        sigma_h = sigma_h.unsqueeze(2).to(x.device)

        noise_x = diffusion_utils.sample_center_gravity_zero_gaussian_with_mask(
            size=(batch_size, num_nodes, self.n_dims),
            device=node_mask.device,
            node_mask=node_mask)
        noise_h = diffusion_utils.sample_gaussian_with_mask(
            size=(batch_size, num_nodes, self.latent_node_nf),
            device=node_mask.device,
            node_mask=node_mask)

        x_noisy = torch.sqrt(1 - sigma_x * sigma_x) * x + sigma_x * noise_x
        h_noisy = torch.sqrt(1 - sigma_h * sigma_h) * h + sigma_h * noise_h

        x_noisy = diffusion_utils.remove_mean_with_mask(x_noisy, node_mask)

        return torch.cat([x_noisy, h_noisy], dim=2)
    
    def compute_loss(self, x, h, adj_gt, node_mask, edge_mask, context):
        """Compute the training loss: reconstruction error plus an optional KL term.

        Steps:
            1. Encode ``(x, h)`` into latent means and sigmas.
            2. If ``kl_weight > 0``, compute KL terms for the invariant and
               equivariant latents. Unit sigmas are passed in place of the
               encoder sigmas, so only the means are regularised.
            3. Sample the latent ``z_xh``. When ``noise_sigma`` is set, the
               means are first corrupted with ``inject_noise`` and the sample
               is detached so no gradient reaches the encoder.
            4. Decode and score against ``adj_gt`` / ``h`` with
               ``compute_2d_reconstruction_error``.

        Returns:
            ``(loss, info)`` where ``info`` holds the squeezed total loss under
            ``'loss_t'`` and the squeezed reconstruction loss under
            ``'rec_error'``.
        """
        # Encoder output.
        z_x_mu, z_x_sigma, z_h_mu, z_h_sigma = self.encode(x, h, node_mask, edge_mask, context)

        use_kl = self.kl_weight > 0.
        if use_kl:
            # KL for invariant features.
            zeros, ones = torch.zeros_like(z_h_mu), torch.ones_like(z_h_sigma)
            loss_kl_h = gaussian_KL(z_h_mu, ones, zeros, ones, node_mask)
            # KL for equivariant features; sigma must be constant per sample.
            assert z_x_sigma.mean(dim=(1,2), keepdim=True).expand_as(z_x_sigma).allclose(z_x_sigma, atol=1e-7)
            zeros, ones = torch.zeros_like(z_x_mu), torch.ones_like(z_x_sigma.mean(dim=(1,2)))
            subspace_d = self.subspace_dimensionality(node_mask)
            loss_kl_x = gaussian_KL_for_dimension(z_x_mu, ones, zeros, ones, subspace_d)
            loss_kl = loss_kl_h + loss_kl_x

        # The position sigma is shared by all coordinates: widen it to n_dims
        # columns (previously hard-coded to 3) so it lines up with z_x_mu.
        z_xh_sigma = torch.cat([z_x_sigma.expand(-1, -1, self.n_dims), z_h_sigma], dim=2)

        if self.noise_sigma is None:
            # Infer latent z.
            z_xh_mean = torch.cat([z_x_mu, z_h_mu], dim=2)
            diffusion_utils.assert_correctly_masked(z_xh_mean, node_mask)
            z_xh = self.sample_normal(z_xh_mean, z_xh_sigma, node_mask)
        else:
            diffusion_utils.assert_correctly_masked(z_x_mu, node_mask)
            diffusion_utils.assert_correctly_masked(z_h_mu, node_mask)
            z_xh = self.inject_noise(z_x_mu, z_h_mu, node_mask, p_mol=0.75, p_atom=0.2, sigma_max=self.noise_sigma)
            z_xh = self.sample_normal(z_xh, z_xh_sigma, node_mask)
            # Keep the encoder fixed during robust decoder training.
            z_xh = z_xh.detach()

        diffusion_utils.assert_correctly_masked(z_xh, node_mask)
        diffusion_utils.assert_mean_zero_with_mask(z_xh[:, :, :self.n_dims], node_mask)

        # Decoder output (reconstruction), as logits.
        adj_recon, h_recon = self.decoder._forward(z_xh, node_mask, edge_mask, context)
        loss_recon = self.compute_2d_reconstruction_error(adj_recon, h_recon, adj_gt, h, node_mask, edge_mask)

        if use_kl:
            # Combining the terms.
            assert loss_recon.size() == loss_kl.size()
            loss = loss_recon + self.kl_weight * loss_kl
        else:
            loss = loss_recon

        # NOTE: the loss is not guaranteed to have only a batch dimension, so
        # no shape assertion is made on it here.

        return loss, {'loss_t': loss.squeeze(), 'rec_error': loss_recon.squeeze()}

    def forward(self, x, h, adj_gt, node_mask=None, edge_mask=None, context=None):
        """
        Return the loss produced by ``compute_loss`` for the given batch.

        The auxiliary metrics dict returned by ``compute_loss`` is discarded.
        """
        neg_log_pxh, _ = self.compute_loss(x, h, adj_gt, node_mask, edge_mask, context)
        return neg_log_pxh

    def sample_combined_position_feature_noise(self, n_samples, n_nodes, node_mask):
        """
        Sample noise for the full latent: mean-centered normal noise for the
        ``n_dims`` position columns, followed by standard normal noise for the
        ``latent_node_nf`` feature columns. Both are drawn on
        ``node_mask.device`` and concatenated along dim 2.
        """
        device = node_mask.device
        position_noise = diffusion_utils.sample_center_gravity_zero_gaussian_with_mask(
            size=(n_samples, n_nodes, self.n_dims),
            device=device,
            node_mask=node_mask)
        feature_noise = diffusion_utils.sample_gaussian_with_mask(
            size=(n_samples, n_nodes, self.latent_node_nf),
            device=device,
            node_mask=node_mask)
        return torch.cat([position_noise, feature_noise], dim=2)
    
    def encode(self, x, h, node_mask=None, edge_mask=None, context=None):
        """Compute the parameters of q(z|x).

        The encoder input is ``x`` concatenated with either
        ``h['extra_atom_features']`` (when the encoder reports
        ``n_extra_atomic_features > 0``) or the one-hot atomic numbers and
        one-hot formal charges.

        Returns:
            ``(z_x_mu, sigma_0_x, z_h_mu, sigma_0_h)``. The sigmas predicted by
            the encoder are discarded and replaced by the constant 0.0032,
            shaped ``(bs, 1, 1)`` for positions and
            ``(bs, 1, latent_node_nf)`` for features.
        """
        if self.encoder.n_extra_atomic_features > 0:
            node_features = [h['extra_atom_features']]
        else:
            node_features = [h['atomic_numbers_one_hot'], h['formal_charges_one_hot']]
        xh = torch.cat([x, *node_features], dim=2)

        diffusion_utils.assert_mean_zero_with_mask(xh[:, :, :self.n_dims], node_mask)

        # TODO: remove stuff related to z_sigma since we're not using it.
        z_x_mu, _, z_h_mu, _ = self.encoder._forward(xh, node_mask, edge_mask, context)

        bs, _, _ = z_x_mu.size()
        fixed_sigma = 0.0032
        sigma_0_x = z_x_mu.new_ones(bs, 1, 1) * fixed_sigma
        sigma_0_h = z_h_mu.new_ones(bs, 1, self.latent_node_nf) * fixed_sigma

        return z_x_mu, sigma_0_x, z_h_mu, sigma_0_h
    
    def decode(self, z_xh, node_mask=None, edge_mask=None, context=None, valency_check=False):
        """
        Computes p(x|z).

        The decoder output for nodes is split into atom-type scores (all but
        the last ``n_formal_charges`` columns) and formal-charge scores (the
        last ``n_formal_charges`` columns).

        With ``valency_check=True`` (work in progress) the raw decoder outputs
        are returned: ``(adj_recon, atom_type_scores, formal_charge_scores)``.

        Otherwise computes:
            1. edge types as class labels 0, 1, 2, 3, 4
            2. atom types as class labels 0, 1, 2, ... used as indices of the dataset_decoder
            3. formal charges as actual charges -1, 0, 1
        """
        # Decoder output (reconstruction), as logits.
        adj_recon, h_recon = self.decoder._forward(z_xh, node_mask, edge_mask, context)
        diffusion_utils.assert_correctly_masked(h_recon, node_mask)

        # Use the configured channel count in both branches; the valency-check
        # branch previously hard-coded 3.
        n_fc = self.n_formal_charges
        atom_type_scores = h_recon[:, :, :-n_fc]
        formal_charge_scores = h_recon[:, :, -n_fc:]

        if valency_check:
            # TODO: WIP, finish this
            return adj_recon, atom_type_scores, formal_charge_scores

        adj_recon = torch.argmax(adj_recon, -1)

        # NOTE(review): the bare squeeze() also drops the batch or node axis
        # when it has size 1; kept as-is because callers may rely on it.
        atom_types_recon = torch.argmax(atom_type_scores, -1, keepdim=True) * node_mask
        atom_types_recon = atom_types_recon.squeeze().int()

        # TODO: remove hard-coded -1 to decode formal charges
        formal_charges_recon = (torch.argmax(formal_charge_scores, -1, keepdim=True) - 1) * node_mask
        formal_charges_recon = formal_charges_recon.squeeze().int()

        return adj_recon, atom_types_recon, formal_charges_recon

    @torch.no_grad()
    def reconstruct(self, x, h, node_mask=None, edge_mask=None, context=None, inject_noise=False):
        """Encode ``(x, h)`` and decode the resulting latent without gradients.

        Args:
            inject_noise: when False, the latent is sampled around the encoder
                mean with the encoder's fixed sigma. When True, the mean is
                corrupted with ``inject_noise`` (every molecule selected,
                ``p_atom=0.2``, noise level up to ``self.noise_sigma``).

        Returns:
            ``(adj_recon, atom_types_recon, formal_charges_recon)`` as produced
            by ``decode``.

        Raises:
            ValueError: if ``inject_noise`` is True but ``noise_sigma`` is None.
        """
        if inject_noise and self.noise_sigma is None:
            # Without this check the failure surfaces deep inside uniform_().
            raise ValueError('reconstruct(inject_noise=True) requires noise_sigma to be set')

        # Encoder output.
        z_x_mu, z_x_sigma, z_h_mu, z_h_sigma = self.encode(x, h, node_mask, edge_mask, context)

        if not inject_noise:
            # Infer latent z.
            z_xh_mean = torch.cat([z_x_mu, z_h_mu], dim=2)
            diffusion_utils.assert_correctly_masked(z_xh_mean, node_mask)
            # Widen the shared position sigma to n_dims columns (was hard-coded 3).
            z_xh_sigma = torch.cat([z_x_sigma.expand(-1, -1, self.n_dims), z_h_sigma], dim=2)
            z_xh = self.sample_normal(z_xh_mean, z_xh_sigma, node_mask)
            diffusion_utils.assert_correctly_masked(z_xh, node_mask)
            diffusion_utils.assert_mean_zero_with_mask(z_xh[:, :, :self.n_dims], node_mask)
        else:
            diffusion_utils.assert_correctly_masked(z_x_mu, node_mask)
            diffusion_utils.assert_correctly_masked(z_h_mu, node_mask)
            z_xh = self.inject_noise(z_x_mu, z_h_mu, node_mask, p_mol=1.0, p_atom=0.2, sigma_max=self.noise_sigma)

        # Decoder output (reconstruction).
        adj_recon, atom_types_recon, formal_charges_recon = self.decode(z_xh, node_mask, edge_mask, context)
        return adj_recon, atom_types_recon, formal_charges_recon


    def log_info(self):
        """
        Some info logging of the model.
        """
        info = None
        print(info)

        return info

    def freeze_encoder(self):
        if not self.is_encoder_frozen:
            self.encoder = self.encoder.eval()
            self.encoder.train = lambda x: x
            for param in self.encoder.parameters():
                param.requires_grad = False

            self.is_encoder_frozen = True
