# Copyright 2024 the Llamole team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os 
import yaml
import json

import torch
import torch.nn as nn
import torch.nn.functional as F

from . import diffusion_utils as utils
from .molecule_utils import graph_to_smiles, check_valid
from .transformer import Transformer

class GraphDiT(nn.Module):
    def __init__(
        self,
        model_config_path,
        data_info_path,
        model_dtype,
    ):
        super().__init__()

        dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)

        input_dims = data_info.input_dims
        output_dims = data_info.output_dims
        nodes_dist = data_info.nodes_dist
        active_index = data_info.active_index

        self.model_config = dm_cfg
        self.data_info = data_info
        self.T = dm_cfg.diffusion_steps
        self.guide_scale = dm_cfg.guide_scale
        self.Xdim = input_dims["X"]
        self.Edim = input_dims["E"]
        self.ydim = input_dims["y"]
        self.Xdim_output = output_dims["X"]
        self.Edim_output = output_dims["E"]
        self.ydim_output = output_dims["y"]
        self.node_dist = nodes_dist
        self.active_index = active_index
        self.max_n_nodes = data_info.max_n_nodes
        self.train_loss = TrainLossDiscrete(dm_cfg.lambda_train)
        self.atom_decoder = data_info.atom_decoder
        self.hidden_size = dm_cfg.hidden_size
        self.text_input_size = 768

        self.denoiser = Transformer(
            max_n_nodes=self.max_n_nodes,
            hidden_size=dm_cfg.hidden_size,
            depth=dm_cfg.depth,
            num_heads=dm_cfg.num_heads,
            mlp_ratio=dm_cfg.mlp_ratio,
            drop_condition=dm_cfg.drop_condition,
            Xdim=self.Xdim,
            Edim=self.Edim,
            ydim=self.ydim,
            text_dim=self.text_input_size
        )
        # self.model_dtype = next(self.denoiser.parameters()).dtype
        self.model_dtype = model_dtype
        # self.device = next(self.denoiser.parameters()).device

        # model_params = torch.load(model_params_path, map_location='cpu')
        # self.denoiser.load_state_dict(model_params)
        
        self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
            dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
        )
        x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
            data_info.node_types.to(self.model_dtype)
        )
        e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
            data_info.edge_types.to(self.model_dtype)
        )
        x_marginals = x_marginals / x_marginals.sum()
        e_marginals = e_marginals / e_marginals.sum()

        xe_conditions = data_info.transition_E.to(self.model_dtype)
        xe_conditions = xe_conditions[self.active_index][:, self.active_index]

        xe_conditions = xe_conditions.sum(dim=1)
        ex_conditions = xe_conditions.t()
        xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
        ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)

        self.transition_model = utils.MarginalTransition(
            x_marginals=x_marginals,
            e_marginals=e_marginals,
            xe_conditions=xe_conditions,
            ex_conditions=ex_conditions,
            y_classes=self.ydim_output,
            n_nodes=self.max_n_nodes,
        )
        self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)

    def init_model(self, model_dir, verbose=False):
        model_file = os.path.join(model_dir, 'model.pt')
        if os.path.exists(model_file):
            self.denoiser.load_state_dict(torch.load(model_file, map_location='cpu', weights_only=True))
        else:
            raise FileNotFoundError(f"Model file not found: {model_file}")
        
        if verbose:
            print('GraphDiT Denoiser Model initialized.')
            print('Denoiser model:\n', self.denoiser)

    def save_pretrained(self, output_dir):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        
        # Save model
        model_path = os.path.join(output_dir, 'model.pt')
        torch.save(self.denoiser.state_dict(), model_path)
        
        # Save model config
        config_path = os.path.join(output_dir, 'model_config.yaml')
        with open(config_path, 'w') as f:
            yaml.dump(vars(self.model_config), f)
        
        # Save data info
        data_info_path = os.path.join(output_dir, 'data.meta.json')
        data_info_dict = {
            "active_atoms": self.data_info.active_atoms,
            "max_node": self.data_info.max_n_nodes,
            "n_atoms_per_mol_dist": self.data_info.n_nodes.tolist(),
            "bond_type_dist": self.data_info.edge_types.tolist(),
            "transition_E": self.data_info.transition_E.tolist(),
            "atom_type_dist": self.data_info.node_types.tolist(),
            "valencies": self.data_info.valency_distribution.tolist()
        }
        with open(data_info_path, 'w') as f:
            json.dump(data_info_dict, f, indent=2)
        
        print('GraphDiT Model and configurations saved to:', output_dir)

    def disable_grads(self):
        self.denoiser.disable_grads()
    
    def forward(
        self, x, edge_index, edge_attr, graph_batch, properties, text_embedding, no_label_index
    ):
        properties = torch.where(properties == no_label_index, float("nan"), properties)
        data_x = F.one_hot(x, num_classes=118).to(self.model_dtype)[
            :, self.active_index
        ]
        data_edge_attr = F.one_hot(edge_attr, num_classes=5).to(self.model_dtype)

        dense_data, node_mask = utils.to_dense(
            data_x, edge_index, data_edge_attr, graph_batch, self.max_n_nodes
        )
        X, E = dense_data.X, dense_data.E

        dense_data = dense_data.mask(node_mask)
        noisy_data = self.apply_noise(X, E, properties, node_mask)
        pred = self._forward(noisy_data, text_embedding)
        loss = self.train_loss(
            masked_pred_X=pred.X,
            masked_pred_E=pred.E,
            true_X=X,
            true_E=E,
            node_mask=node_mask,
        )
        return loss

    def _forward(self, noisy_data, text_embedding, unconditioned=False):
        noisy_x, noisy_e, properties = (
            noisy_data["X_t"].to(self.model_dtype),
            noisy_data["E_t"].to(self.model_dtype),
            noisy_data["y_t"].to(self.model_dtype).clone(),
        )
        node_mask, timestep, text_embedding = (
            noisy_data["node_mask"],
            noisy_data["t"],
            text_embedding.to(self.model_dtype),
        )
        
        pred = self.denoiser(
            noisy_x,
            noisy_e,
            node_mask,
            properties,
            text_embedding,
            timestep,
            unconditioned=unconditioned,
        )
        return pred

    def apply_noise(self, X, E, y, node_mask):
        """Sample noise and apply it to the data."""

        # Sample a timestep t.
        # When evaluating, the loss for t=0 is computed separately
        lowest_t = 0 if self.training else 1
        t_int = torch.randint(
            lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device
        ).to(
            self.model_dtype
        )  # (bs, 1)
        s_int = t_int - 1

        t_float = t_int / self.T
        s_float = s_int / self.T

        # beta_t and alpha_s_bar are used for denoising/loss computation
        beta_t = self.noise_schedule(t_normalized=t_float)  # (bs, 1)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float)  # (bs, 1)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float)  # (bs, 1)

        Qtb = self.transition_model.get_Qt_bar(
            alpha_t_bar, X.device
        )  # (bs, dx_in, dx_out), (bs, de_in, de_out)

        bs, n, d = X.shape
        X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
        prob_all = X_all @ Qtb.X
        probX = prob_all[:, :, : self.Xdim_output]
        probE = prob_all[:, :, self.Xdim_output :].reshape(bs, n, n, -1)

        sampled_t = utils.sample_discrete_features(
            probX=probX, probE=probE, node_mask=node_mask
        )

        X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
        E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
        assert (X.shape == X_t.shape) and (E.shape == E_t.shape)

        y_t = y
        z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)

        noisy_data = {
            "t_int": t_int,
            "t": t_float,
            "beta_t": beta_t,
            "alpha_s_bar": alpha_s_bar,
            "alpha_t_bar": alpha_t_bar,
            "X_t": z_t.X,
            "E_t": z_t.E,
            "y_t": z_t.y,
            "node_mask": node_mask,
        }
        return noisy_data

    @torch.no_grad()
    def generate(
        self,
        properties,
        text_embedding,
        no_label_index,
    ):
        properties = torch.where(properties == no_label_index, float("nan"), properties)
        batch_size = properties.size(0)
        device = text_embedding.device
        # print('self.denoiser.dtype', self.model_dtype)
        n_nodes = self.node_dist.sample_n(batch_size, device)
        arange = (
            torch.arange(self.max_n_nodes, device=device)
            .unsqueeze(0)
            .expand(batch_size, -1)
        )
        node_mask = arange < n_nodes.unsqueeze(1)

        z_T = utils.sample_discrete_feature_noise(
            limit_dist=self.limit_dist, node_mask=node_mask
        )
        X, E = z_T.X, z_T.E

        assert (E == torch.transpose(E, 1, 2)).all()

        # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
        y = properties
        for s_int in reversed(range(0, self.T)):
            s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
            t_array = s_array + 1
            s_norm = s_array / self.T
            t_norm = t_array / self.T

            # Sample z_s
            sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(
                s_norm, t_norm, X, E, y, text_embedding, node_mask
            )
            X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

        # Sample
        sampled_s = sampled_s.mask(node_mask, collapse=True)
        X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

        molecule_list = []
        for i in range(batch_size):
            n = n_nodes[i]
            atom_types = X[i, :n].cpu()
            edge_types = E[i, :n, :n].cpu()
            molecule_list.append([atom_types, edge_types])

        smiles_list = graph_to_smiles(molecule_list, self.atom_decoder)
        return smiles_list

    def check_valid(self, smiles):
        return check_valid(smiles)
    
    def sample_p_zs_given_zt(
        self, s, t, X_t, E_t, properties, text_embedding, node_mask
    ):
        """Samples from zs ~ p(zs | zt). Only used during sampling.
        if last_step, return the graph prediction as well"""
        bs, n, _ = X_t.shape
        beta_t = self.noise_schedule(t_normalized=t)  # (bs, 1)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)

        # Neural net predictions
        noisy_data = {
            "X_t": X_t,
            "E_t": E_t,
            "y_t": properties,
            "t": t,
            "node_mask": node_mask,
        }

        def get_prob(noisy_data, text_embedding, unconditioned=False):
            pred = self._forward(noisy_data, text_embedding, unconditioned=unconditioned)

            # Normalize predictions
            pred_X = F.softmax(pred.X, dim=-1)  # bs, n, d0
            pred_E = F.softmax(pred.E, dim=-1)  # bs, n, n, d0

            device = text_embedding.device
            # Retrieve transitions matrix
            Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
            Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
            Qt = self.transition_model.get_Qt(beta_t, device)

            Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
            predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)

            unnormalized_probX_all = utils.reverse_diffusion(
                predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
            )

            unnormalized_prob_X = unnormalized_probX_all[:, :, : self.Xdim_output]
            unnormalized_prob_E = unnormalized_probX_all[
                :, :, self.Xdim_output :
            ].reshape(bs, n * n, -1)

            unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
            unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5

            prob_X = unnormalized_prob_X / torch.sum(
                unnormalized_prob_X, dim=-1, keepdim=True
            )  # bs, n, d_t-1
            prob_E = unnormalized_prob_E / torch.sum(
                unnormalized_prob_E, dim=-1, keepdim=True
            )  # bs, n, d_t-1
            prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])

            return prob_X, prob_E

        prob_X, prob_E = get_prob(noisy_data, text_embedding)

        ### Guidance
        if self.guide_scale is not None and self.guide_scale != 1:
            uncon_prob_X, uncon_prob_E = get_prob(
                noisy_data, text_embedding, unconditioned=True
            )
            prob_X = (
                uncon_prob_X
                * (prob_X / uncon_prob_X.clamp_min(1e-5)) ** self.guide_scale
            )
            prob_E = (
                uncon_prob_E
                * (prob_E / uncon_prob_E.clamp_min(1e-5)) ** self.guide_scale
            )
            prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5)
            prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5)

        # assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-3).all()
        # assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-3).all()

        sampled_s = utils.sample_discrete_features(
            prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item()
        )

        X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).to(self.model_dtype)
        E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).to(self.model_dtype)

        assert (E_s == torch.transpose(E_s, 1, 2)).all()
        assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)

        out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
        out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=properties)

        return out_one_hot.mask(node_mask).type_as(properties), out_discrete.mask(
            node_mask, collapse=True
        ).type_as(properties)
    

class TrainLossDiscrete(nn.Module):
    """Train with Cross entropy"""

    def __init__(self, lambda_train):
        super().__init__()
        self.lambda_train = lambda_train

    def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, node_mask):

        true_X = torch.reshape(true_X, (-1, true_X.size(-1)))  # (bs * n, dx)
        true_E = torch.reshape(true_E, (-1, true_E.size(-1)))  # (bs * n * n, de)
        masked_pred_X = torch.reshape(
            masked_pred_X, (-1, masked_pred_X.size(-1))
        )  # (bs * n, dx)
        masked_pred_E = torch.reshape(
            masked_pred_E, (-1, masked_pred_E.size(-1))
        )  # (bs * n * n, de)

        # Remove masked rows
        mask_X = (true_X != 0.0).any(dim=-1)
        mask_E = (true_E != 0.0).any(dim=-1)

        flat_true_X = true_X[mask_X, :]
        flat_pred_X = masked_pred_X[mask_X, :]

        flat_true_E = true_E[mask_E, :]
        flat_pred_E = masked_pred_E[mask_E, :]

        target_X = torch.argmax(flat_true_X, dim=-1)
        loss_X = F.cross_entropy(flat_pred_X, target_X, reduction="mean")

        target_E = torch.argmax(flat_true_E, dim=-1)
        loss_E = F.cross_entropy(flat_pred_E, target_E, reduction="mean")

        total_loss = self.lambda_train[0] * loss_X + self.lambda_train[1] * loss_E

        return total_loss
