import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import math

from models.conditions import ClusterContinuousEmbedder


class LatentSchedule(nn.Module):
    """Polynomial noise schedule ``alpha_t = 1 - ((1 - eps) * t) ** k`` with per-slot exponents.

    Each of the ``max_n_nodes`` node slots owns one scalar in ``node_vocab``. Node keys
    come from that scalar and edge keys from the mean of the two endpoint scalars
    (both passed through ``F.normalize``); every exponent is ``softplus(key) + k_offset``.
    """

    def __init__(self, X_dim, E_dim, ydim, context_dim, timesteps, max_n_nodes,
                 eps=1e-4, k_init=-6, k_offset=1):
        """
        Args:
            X_dim, E_dim, ydim, context_dim: not used by this class.
            timesteps: total number of diffusion steps; integer times ``t_int`` are
                divided by this to obtain normalized time.
            max_n_nodes: number of node slots (rows of ``node_vocab``).
            eps: time is multiplied by ``(1 - eps)``, so rescaled time lies in
                ``[0, 1 - eps]`` for normalized time in ``[0, 1]``.
            k_init: not used by this class.
            k_offset: added to the softplus output, so every exponent exceeds it.
        """
        super().__init__()
        self.timesteps = timesteps
        self.eps = eps
        self.k_offset = k_offset

        # One learnable scalar per node slot; exponents are derived from it in `alpha`.
        self.node_vocab = nn.Parameter(torch.randn(max_n_nodes, 1))
        self.max_n_nodes = max_n_nodes

    def k(self, node_vocab):
        """Map keys to exponents ``softplus(node_vocab) + k_offset`` (elementwise)."""
        return F.softplus(node_vocab) + self.k_offset

    def _hashkeys(self, bs, n):
        """Build node keys ``[bs, n]`` and edge keys ``[bs, n, n]`` from the first ``n`` slots.

        In training mode the slots are randomly permuted before the first ``n`` are taken,
        so every call draws a fresh node-to-slot assignment. With ``n == max_n_nodes``
        this is exactly the full (permuted) vocabulary.

        Raises:
            ValueError: if ``n`` exceeds ``max_n_nodes``.
        """
        if n > self.max_n_nodes:
            raise ValueError(f"n={n} exceeds max_n_nodes={self.max_n_nodes}")
        if self.training:
            perm = torch.randperm(self.max_n_nodes, device=self.node_vocab.device)
            vocab = self.node_vocab[perm[:n]]
        else:
            vocab = self.node_vocab[:n]

        # NOTE(review): `vocab` has a last dimension of size 1, so F.normalize(dim=-1)
        # maps each entry to its sign (+-1, or 0). The keys are therefore piecewise
        # constant in `node_vocab` and pass no gradient back to it — confirm intended.
        x = F.normalize(vocab, p=2, dim=-1).view(1, n).expand(bs, -1)
        pair_mean = (vocab.unsqueeze(1) + vocab.unsqueeze(0)) / 2  # [n, n, 1]
        e = F.normalize(pair_mean, p=2, dim=-1).view(1, n, n).expand(bs, -1, -1)
        return x, e

    def alpha(self, bs, n, device, t_normalized=None, t_int=None, prop=None,
              force_drop_ids=None, dgamma=False):
        """Compute ``alpha_t = 1 - ((1 - eps) * t) ** k`` for every node and edge.

        Exactly one of ``t_normalized`` / ``t_int`` must be given; ``t_int`` is divided
        by ``timesteps``. Time is expected as ``[bs, 1]`` so it broadcasts against the
        per-node exponents. ``device``, ``prop`` and ``force_drop_ids`` are unused here.

        Returns:
            ``(alpha_node [bs, n], alpha_edge [bs, n, n])``; when ``dgamma`` is True,
            ``((alpha_node, alpha_edge), (k_node, k_edge))``.
        """
        assert int(t_normalized is None) + int(t_int is None) == 1
        if t_normalized is None:
            t_normalized = t_int / self.timesteps

        x, e = self._hashkeys(bs, n)
        k_node = self.k(x)  # [bs, n]
        k_edge = self.k(e)  # [bs, n, n]

        # Scale time into [0, 1 - eps] so alpha stays strictly positive at t = 1.
        t_rescaled = (1 - self.eps) * t_normalized
        alpha_node = 1.0 - t_rescaled ** k_node
        alpha_edge = 1.0 - t_rescaled.unsqueeze(-1) ** k_edge

        if dgamma:
            return (alpha_node, alpha_edge), (k_node, k_edge)
        return alpha_node, alpha_edge

    def dgamma_times_alpha(self, bs, n, device, t_normalized=None, t_int=None, prop=None,
                           force_drop_ids=None):
        """Compute ``-d(alpha)/d(t_r) / (1 - alpha)`` with ``t_r = (1 - eps) * t``.

        ``-d(alpha)/d(t_r) = k * t_r ** (k - 1)`` and ``1 - alpha = t_r ** k``, so for
        ``t_r > 0`` the result is ``k / t_r``. The denominator is clamped at 1e-8.
        The derivative is taken w.r.t. the rescaled time ``t_r`` (no ``(1 - eps)``
        chain-rule factor is applied).

        Returns:
            ``(weights_node [bs, n], weights_edge [bs, n, n])``.
        """
        assert int(t_normalized is None) + int(t_int is None) == 1
        if t_normalized is None:
            t_normalized = t_int / self.timesteps

        # A single `alpha` call so alpha and k share the same (possibly permuted) keys.
        (alpha_node, alpha_edge), (k_node, k_edge) = self.alpha(
            bs, n, device, t_normalized=t_normalized, dgamma=True
        )

        t_rescaled = (1 - self.eps) * t_normalized
        deriv_node = k_node * t_rescaled ** (k_node - 1)
        deriv_edge = k_edge * t_rescaled.unsqueeze(-1) ** (k_edge - 1)

        weights_node = deriv_node / (1 - alpha_node).clamp(min=1e-8)
        weights_edge = deriv_edge / (1 - alpha_edge).clamp(min=1e-8)
        return weights_node, weights_edge


class LatentPropertySchedule(nn.Module):
    """Property-conditioned polynomial schedule ``alpha_t = 1 - ((1 - eps) * t) ** k``.

    Per-slot keys from ``node_vocab`` (nodes) and endpoint means (edges) are concatenated
    with a summed property embedding ``c`` and projected by ``xc_lin`` / ``ec_lin``; the
    projection goes through ``softplus(.) + k_offset`` to give each exponent ``k``.
    """

    def __init__(self, X_dim, E_dim, ydim, context_dim, timesteps,
                 guidance_target, drop_condition=0.01, use_provided_drop_ids=False,
                 max_n_nodes=None, eps_max=1e-3, eps=1e-4, k_offset=1):
        """
        Args:
            X_dim, E_dim: not used by this class.
            ydim: number of property columns in ``prop``.
            context_dim: width of each property embedding.
            timesteps: total number of diffusion steps; ``t_int`` is divided by this.
            guidance_target: ``"QM9"`` / ``"ZINC"`` embed every property column
                separately; any other value embeds the first two columns jointly and
                each remaining column separately.
            drop_condition: forwarded to every ClusterContinuousEmbedder.
            use_provided_drop_ids: forwarded to every embedder call.
            max_n_nodes: number of node slots (rows of ``node_vocab``); must be set.
            eps_max: not used by this class.
            eps: time is multiplied by ``(1 - eps)``.
            k_offset: added to the softplus output, so every exponent exceeds it.
        """
        super().__init__()
        self.timesteps = timesteps
        self.eps = eps
        self.k_offset = k_offset

        # One learnable scalar per node slot, used as a positional key.
        self.node_vocab = nn.Parameter(torch.randn(max_n_nodes, 1))
        self.max_n_nodes = max_n_nodes

        self.guidance_target = guidance_target
        self.use_provided_drop_ids = use_provided_drop_ids
        self.ydim = ydim
        self.prop_embedding_list = nn.ModuleList()
        if guidance_target in ["QM9", "ZINC"]:
            n_scalar_props = ydim
        else:
            # First embedder consumes property columns 0 and 1 together.
            self.prop_embedding_list.append(ClusterContinuousEmbedder(2, context_dim, drop_condition))
            n_scalar_props = ydim - 2
        for _ in range(n_scalar_props):
            self.prop_embedding_list.append(ClusterContinuousEmbedder(1, context_dim, drop_condition))

        self.xc_lin = nn.Linear(1 + context_dim, 1)
        self.ec_lin = nn.Linear(1 + context_dim, 1)

    def k(self, node_vocab):
        """Map projected keys to exponents ``softplus(node_vocab) + k_offset`` (elementwise)."""
        return F.softplus(node_vocab) + self.k_offset

    def get_hashkey(self, bs, n, device):
        """Build node keys ``[bs, n, 1]`` and edge keys ``[bs, n, n, 1]`` from the first ``n`` slots.

        In training mode the slots are randomly permuted before the first ``n`` are
        taken, so each call draws a fresh node-to-slot assignment. ``device`` is unused;
        tensors live on ``node_vocab``'s device.

        Raises:
            ValueError: if ``n`` exceeds ``max_n_nodes``.
        """
        if n > self.max_n_nodes:
            raise ValueError(f"n={n} exceeds max_n_nodes={self.max_n_nodes}")
        if self.training:
            perm = torch.randperm(self.max_n_nodes, device=self.node_vocab.device)
            vocab = self.node_vocab[perm[:n]]
        else:
            vocab = self.node_vocab[:n]

        # NOTE(review): `vocab` has a last dimension of size 1, so F.normalize(dim=-1)
        # maps each entry to its sign (+-1, or 0) and passes no gradient back to
        # `node_vocab` — confirm intended.
        x = F.normalize(vocab, p=2, dim=-1).unsqueeze(0).expand(bs, n, 1)
        pair_mean = (vocab.unsqueeze(1) + vocab.unsqueeze(0)) / 2  # [n, n, 1]
        e = F.normalize(pair_mean, p=2, dim=-1).unsqueeze(0).expand(bs, n, n, 1)
        return x, e

    def _property_inputs(self, prop):
        """Pair each property embedder with the column slice of ``prop`` it consumes, in call order."""
        if self.guidance_target in ["QM9", "ZINC"]:
            return [(self.prop_embedding_list[i], prop[:, i:i + 1]) for i in range(self.ydim)]
        if self.ydim < 2:
            return []
        pairs = [(self.prop_embedding_list[0], prop[:, :2])]
        pairs += [(self.prop_embedding_list[i - 1], prop[:, i:i + 1]) for i in range(2, self.ydim)]
        return pairs

    def get_embedding(self, bs, n, device, prop=None, force_drop_ids=None):
        """Project per-slot keys plus the summed property embedding to raw exponent logits.

        Returns:
            ``(x [bs, n], e [bs, n, n])``, the outputs of ``xc_lin`` / ``ec_lin``.

        Raises:
            ValueError: if the configuration yields no property embedder to call.
        """
        inputs = self._property_inputs(prop)
        if not inputs:
            raise ValueError(
                f"no property embedders for guidance_target={self.guidance_target!r} "
                f"with ydim={self.ydim}"
            )

        hashkey_node, hashkey_edge = self.get_hashkey(bs, n, device)

        # Sum the embeddings left to right, calling embedders in list order.
        c = None
        for embedder, columns in inputs:
            emb = embedder(columns, self.training, force_drop_ids,
                           use_provided_drop_ids=self.use_provided_drop_ids)
            c = emb if c is None else c + emb

        x = self.xc_lin(torch.cat([hashkey_node, c.unsqueeze(1).expand(bs, n, -1)], dim=-1))
        e = self.ec_lin(torch.cat([hashkey_edge, c.unsqueeze(1).unsqueeze(1).expand(bs, n, n, -1)], dim=-1))
        return x.squeeze(-1), e.squeeze(-1)

    def alpha(self, bs, n, device, t_normalized=None, t_int=None, prop=None,
              force_drop_ids=None, dgamma=False):
        """Compute ``alpha_t = 1 - ((1 - eps) * t) ** k`` for every node and edge.

        Exactly one of ``t_normalized`` / ``t_int`` must be given; ``t_int`` is divided
        by ``timesteps``. Time is expected as ``[bs, 1]`` so it broadcasts against the
        per-node exponents.

        Returns:
            ``(alpha_node [bs, n], alpha_edge [bs, n, n])``; when ``dgamma`` is True,
            ``((alpha_node, alpha_edge), (k_node, k_edge))``.
        """
        assert int(t_normalized is None) + int(t_int is None) == 1
        if t_normalized is None:
            t_normalized = t_int / self.timesteps

        x, e = self.get_embedding(bs, n, device, prop=prop, force_drop_ids=force_drop_ids)
        k_node = self.k(x)  # [bs, n]
        k_edge = self.k(e)  # [bs, n, n]

        # Scale time into [0, 1 - eps] so alpha stays strictly positive at t = 1.
        t_rescaled = (1 - self.eps) * t_normalized
        alpha_node = 1.0 - t_rescaled ** k_node
        alpha_edge = 1.0 - t_rescaled.unsqueeze(-1) ** k_edge

        if dgamma:
            return (alpha_node, alpha_edge), (k_node, k_edge)
        return alpha_node, alpha_edge

    def dgamma_times_alpha(self, bs, n, device, t_normalized=None, t_int=None, prop=None,
                           force_drop_ids=None):
        """Compute ``-d(alpha)/d(t_r) / (1 - alpha)`` with ``t_r = (1 - eps) * t``.

        ``-d(alpha)/d(t_r) = k * t_r ** (k - 1)`` and ``1 - alpha = t_r ** k``, so for
        ``t_r > 0`` the result is ``k / t_r``. The denominator is clamped at 1e-8.
        The derivative is w.r.t. the rescaled time (no ``(1 - eps)`` factor).

        Returns:
            ``(weights_node [bs, n], weights_edge [bs, n, n])``.
        """
        assert int(t_normalized is None) + int(t_int is None) == 1
        if t_normalized is None:
            t_normalized = t_int / self.timesteps

        # A single `alpha` call so alpha and k share the same keys and embeddings.
        (alpha_node, alpha_edge), (k_node, k_edge) = self.alpha(
            bs, n, device, t_normalized=t_normalized, prop=prop,
            force_drop_ids=force_drop_ids, dgamma=True,
        )

        t_rescaled = (1 - self.eps) * t_normalized
        deriv_node = k_node * t_rescaled ** (k_node - 1)
        deriv_edge = k_edge * t_rescaled.unsqueeze(-1) ** (k_edge - 1)

        weights_node = deriv_node / (1 - alpha_node).clamp(min=1e-8)
        weights_edge = deriv_edge / (1 - alpha_edge).clamp(min=1e-8)
        return weights_node, weights_edge


