import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as nn_init
import torch.nn.functional as F
from torch import Tensor

import typing as ty
import math
import os

def gmmpara_init(n_centroid, latent_dim):
    """Create the initial trainable parameters of a diagonal Gaussian mixture.

    Args:
        n_centroid: number of mixture components.
        latent_dim: dimensionality of the latent space.

    Returns:
        Tuple ``(theta_p, u_p, lambda_p)`` of float32 ``nn.Parameter`` objects:
        uniform mixture weights of shape ``[n_centroid]``, all-zero means of
        shape ``[latent_dim, n_centroid]`` and all-one variances of shape
        ``[latent_dim, n_centroid]``.
    """
    theta_init = torch.ones(n_centroid) / n_centroid
    u_init = torch.zeros(latent_dim, n_centroid)
    lambda_init = torch.ones(latent_dim, n_centroid)

    # The tensors above are freshly allocated, so they can be wrapped directly.
    # Passing an existing tensor through torch.tensor(...) makes a redundant
    # copy and emits a "copy construct" UserWarning on every call.
    theta_p = torch.nn.Parameter(theta_init.float())
    u_p = torch.nn.Parameter(u_init.float())
    lambda_p = torch.nn.Parameter(lambda_init.float())

    return theta_p, u_p, lambda_p


def log_gmm_probs(z, gmm_means, gmm_log_variances, gmm_pi):
    """
    Joint log-density log(pi_k) + log N(z | mean_k, diag(var_k)) per component.

    z: latent variables, shape [batch_size, latent_dim]
    gmm_means: component means, shape [n_clusters, latent_dim]
    gmm_log_variances: component log variances, shape [n_clusters, latent_dim]
    gmm_pi: mixture coefficients, shape [n_clusters]

    Returns a tensor of shape [batch_size, n_clusters].
    """
    latent_dim = z.shape[1]

    # -D/2 * log(2*pi): normalisation constant shared by every component.
    norm_const = -0.5 * latent_dim * torch.log(torch.tensor(2 * np.pi))
    # -1/2 * sum_d log(var_kd): log-determinant of each diagonal covariance.
    half_log_det = -0.5 * gmm_log_variances.sum(dim=1)
    # -1/2 * squared Mahalanobis distance of every sample to every component.
    centered = z.unsqueeze(1) - gmm_means
    quadratic = -0.5 * torch.sum((centered ** 2) / gmm_log_variances.exp(), dim=2)

    # Gaussian log-likelihood first, then the log mixture coefficients.
    return norm_const + half_log_det + quadratic + torch.log(gmm_pi)

def logsumexp(inputs, dim=None, keepdim=False):
    """Numerically stable log(sum(exp(inputs))).

    Args:
        inputs: tensor to reduce.
        dim: dimension to reduce over. ``None`` flattens ``inputs`` and
            reduces over all elements.
        keepdim: keep the reduced dimension with size 1. With ``dim=None``
            the result then has shape ``[1]``.

    Returns:
        The reduced tensor.
    """
    if dim is None:
        # reshape (unlike view) also accepts non-contiguous tensors.
        inputs = inputs.reshape(-1)
        dim = 0
    # The built-in reduction returns -inf (not NaN) when every reduced element
    # is -inf, which the hand-rolled max-shift formula turned into NaN.
    return torch.logsumexp(inputs, dim=dim, keepdim=keepdim)

def compute_log_posteriors(log_probs):
    """Normalise joint log-probabilities over dim 1.

    Subtracts the log-sum-exp taken over dim 1 from ``log_probs`` so that the
    exponentiated result sums to one along that dimension.
    """
    return log_probs - logsumexp(log_probs, dim=1, keepdim=True)

def gmm_prob(z, gmm_means, gmm_log_variances, gmm_pi):
    """Posterior component probabilities of ``z`` under the mixture.

    Takes the same arguments as ``log_gmm_probs``; the result is the
    exponentiated output of ``compute_log_posteriors`` applied to those joint
    log-probabilities.
    """
    joint_log_probs = log_gmm_probs(z, gmm_means, gmm_log_variances, gmm_pi)
    return compute_log_posteriors(joint_log_probs).exp()

class Tokenizer(nn.Module):
    """Embed numerical and categorical features as a sequence of tokens.

    The output sequence is ``[CLS], numerical tokens..., categorical tokens...``.
    Each numerical feature is scaled by its own learned ``d_token`` vector and
    each categorical feature is looked up in a shared embedding table, using
    per-column offsets into that table.

    Args:
        d_numerical: number of numerical features.
        categories: sequence with the cardinality of every categorical
            feature (list, tuple or array), or ``None`` when there are none.
        d_token: embedding size of each token.
        bias: whether to add a learned per-feature bias (never to [CLS]).
    """

    def __init__(self, d_numerical, categories, d_token, bias):
        super().__init__()
        if categories is None:
            d_bias = d_numerical
            self.category_offsets = None
            self.category_embeddings = None
        else:
            # Normalise to a list of ints first: `[0] + categories[:-1]` raises
            # for tuples and silently broadcasts (wrong offsets) for arrays.
            cardinalities = [int(c) for c in categories]
            d_bias = d_numerical + len(cardinalities)
            # Offset of each column's first row in the shared embedding table.
            category_offsets = torch.tensor([0] + cardinalities[:-1]).cumsum(0)
            self.register_buffer('category_offsets', category_offsets)
            self.category_embeddings = nn.Embedding(sum(cardinalities), d_token)
            nn_init.kaiming_uniform_(self.category_embeddings.weight, a=math.sqrt(5))

        # take [CLS] token into account
        self.weight = nn.Parameter(torch.empty(d_numerical + 1, d_token))
        self.bias = nn.Parameter(torch.empty(d_bias, d_token)) if bias else None
        # The initialization is inspired by nn.Linear
        nn_init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            nn_init.kaiming_uniform_(self.bias, a=math.sqrt(5))

    @property
    def n_tokens(self):
        """Number of tokens produced per sample, including [CLS]."""
        return len(self.weight) + (
            0 if self.category_offsets is None else len(self.category_offsets)
        )

    def forward(self, x_num, x_cat):
        """Tokenize a batch.

        Args:
            x_num: numerical features ``[batch, d_numerical]`` or ``None``.
            x_cat: integer category indices ``[batch, n_categorical]`` or
                ``None``. At least one of the two must be given.

        Returns:
            Tensor of shape ``[batch, n_tokens, d_token]``.
        """
        x_some = x_num if x_cat is None else x_cat
        assert x_some is not None, 'at least one of x_num / x_cat is required'
        # Prepend a constant 1 so that weight[0] alone forms the [CLS] token.
        x_num = torch.cat(
            [torch.ones(len(x_some), 1, device=x_some.device)]  # [CLS]
            + ([] if x_num is None else [x_num]),
            dim=1,
        )

        x = self.weight[None] * x_num[:, :, None]

        if x_cat is not None:
            x = torch.cat(
                [x, self.category_embeddings(x_cat + self.category_offsets[None])],
                dim=1,
            )
        if self.bias is not None:
            # A zero row keeps [CLS] bias-free; it follows the bias dtype so the
            # concatenation also works for modules cast to another precision.
            bias = torch.cat(
                [
                    torch.zeros(
                        1, self.bias.shape[1], device=x.device, dtype=self.bias.dtype
                    ),
                    self.bias,
                ]
            )
            x = x + bias[None]

        return x

class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Dropout -> Linear.

    Args:
        input_dim: size of the input features.
        hidden_dim: size of the hidden layer.
        output_dim: size of the output features.
        dropout: dropout probability applied after the hidden activation.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.5):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # Sub-modules are registered in this order: fc1, fc2, dropout.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Map ``x`` of shape ``[..., input_dim]`` to ``[..., output_dim]``."""
        hidden = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(hidden)

class MultiheadAttention(nn.Module):
    """Scaled dot-product attention with separate query/key/value projections.

    Args:
        d: token embedding size; must be divisible by ``n_heads`` when
            ``n_heads > 1``.
        n_heads: number of attention heads. With a single head no output
            projection is created.
        dropout: dropout probability on the attention weights; a falsy value
            disables dropout.
        initialization: ``'xavier'`` or ``'kaiming'``. ``'xavier'`` re-initialises
            the projection weights; ``'kaiming'`` leaves the ``nn.Linear``
            default weights untouched. Biases are zeroed in both cases.
    """

    def __init__(self, d, n_heads, dropout, initialization = 'kaiming'):
        assert n_heads <= 1 or d % n_heads == 0
        assert initialization in ['xavier', 'kaiming']

        super().__init__()
        self.W_q = nn.Linear(d, d)
        self.W_k = nn.Linear(d, d)
        self.W_v = nn.Linear(d, d)
        self.W_out = nn.Linear(d, d) if n_heads > 1 else None
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout) if dropout else None

        use_xavier = initialization == 'xavier'
        for projection in (self.W_q, self.W_k, self.W_v):
            # With a single head the value projection keeps its default weights.
            is_single_head_value = n_heads <= 1 and projection is self.W_v
            if use_xavier and not is_single_head_value:
                # gain is needed since W_qkv is represented with 3 separate layers
                nn_init.xavier_uniform_(projection.weight, gain=1 / math.sqrt(2))
            nn_init.zeros_(projection.bias)
        if self.W_out is not None:
            nn_init.zeros_(self.W_out.bias)

    def _reshape(self, x):
        """Split the last dim into heads: [B, T, d] -> [B * n_heads, T, d_head]."""
        batch_size, n_tokens, d = x.shape
        d_head = d // self.n_heads
        per_head = x.reshape(batch_size, n_tokens, self.n_heads, d_head).transpose(1, 2)
        return per_head.reshape(batch_size * self.n_heads, n_tokens, d_head)

    def forward(self, x_q, x_kv, key_compression = None, value_compression = None):
        """Attend from ``x_q`` to ``x_kv``.

        Args:
            x_q: query tokens, ``[batch, n_q_tokens, d]``.
            x_kv: key/value tokens, ``[batch, n_kv_tokens, d]``.
            key_compression: optional callable applied along the token axis of
                the keys; must be given together with ``value_compression``.
            value_compression: the counterpart of ``key_compression`` for values.

        Returns:
            Tensor of shape ``[batch, n_q_tokens, d]``.
        """
        queries = self.W_q(x_q)
        keys = self.W_k(x_kv)
        values = self.W_v(x_kv)
        assert all(t.shape[-1] % self.n_heads == 0 for t in (queries, keys, values))

        if key_compression is None:
            assert value_compression is None
        else:
            assert value_compression is not None
            # The compression modules act on the token axis, hence the transposes.
            keys = key_compression(keys.transpose(1, 2)).transpose(1, 2)
            values = value_compression(values.transpose(1, 2)).transpose(1, 2)

        batch_size = len(queries)
        n_q_tokens = queries.shape[1]
        d_head_key = keys.shape[-1] // self.n_heads
        d_head_value = values.shape[-1] // self.n_heads

        scores = self._reshape(queries) @ self._reshape(keys).transpose(1, 2)
        attention = F.softmax(scores / math.sqrt(d_head_key), dim=-1)
        if self.dropout is not None:
            attention = self.dropout(attention)

        context = attention @ self._reshape(values)
        # Merge the heads back: [B * H, T, d_head] -> [B, T, H * d_head].
        context = context.reshape(batch_size, self.n_heads, n_q_tokens, d_head_value)
        context = context.transpose(1, 2)
        context = context.reshape(batch_size, n_q_tokens, self.n_heads * d_head_value)

        if self.W_out is not None:
            context = self.W_out(context)
        return context
        
class Transformer(nn.Module):
    """Stack of attention + feed-forward blocks with residual connections.

    Every layer holds a ``MultiheadAttention``, two linear layers forming the
    feed-forward part, and layer norms. With pre-normalization the first
    layer has no ``norm0``. ``forward`` returns the full token sequence;
    ``head``, ``last_activation`` and ``last_normalization`` are created but
    not applied by ``forward``.

    Args:
        n_layers: number of blocks.
        d_token: token embedding size.
        n_heads: attention heads per block.
        d_out: output size of ``head``.
        d_ffn_factor: hidden size of the feed-forward part as a multiple of
            ``d_token``.
        attention_dropout: dropout on the attention weights.
        ffn_dropout: dropout after the feed-forward activation.
        residual_dropout: dropout on each residual branch.
        activation: accepted for compatibility; ReLU is always used.
        prenormalization: normalise before (True) or after (False) each
            sub-block.
        initialization: forwarded to ``MultiheadAttention``.
    """

    def __init__(
        self,
        n_layers: int,
        d_token: int,
        n_heads: int,
        d_out: int,
        d_ffn_factor: int,
        attention_dropout = 0.0,
        ffn_dropout = 0.0,
        residual_dropout = 0.0,
        activation = 'relu',
        prenormalization = True,
        initialization = 'kaiming',      
    ):
        super().__init__()

        d_hidden = int(d_token * d_ffn_factor)
        self.layers = nn.ModuleList()
        for layer_idx in range(n_layers):
            block = nn.ModuleDict(
                {
                    'attention': MultiheadAttention(
                        d_token, n_heads, attention_dropout, initialization
                    ),
                    'linear0': nn.Linear(d_token, d_hidden),
                    'linear1': nn.Linear(d_hidden, d_token),
                    'norm1': nn.LayerNorm(d_token),
                }
            )
            # In pre-norm mode the very first attention sees un-normalised input.
            if layer_idx > 0 or not prenormalization:
                block['norm0'] = nn.LayerNorm(d_token)
            self.layers.append(block)

        self.activation = nn.ReLU()
        self.last_activation = nn.ReLU()
        self.prenormalization = prenormalization
        self.last_normalization = nn.LayerNorm(d_token) if prenormalization else None
        self.ffn_dropout = ffn_dropout
        self.residual_dropout = residual_dropout
        self.head = nn.Linear(d_token, d_out)

    def _start_residual(self, x, layer, norm_idx):
        """Return the input of a residual branch (normalised in pre-norm mode)."""
        if not self.prenormalization:
            return x
        norm_key = f'norm{norm_idx}'
        return layer[norm_key](x) if norm_key in layer else x

    def _end_residual(self, x, x_residual, layer, norm_idx):
        """Add the branch output to ``x`` (normalising after in post-norm mode)."""
        if self.residual_dropout:
            x_residual = F.dropout(x_residual, self.residual_dropout, self.training)
        merged = x + x_residual
        if self.prenormalization:
            return merged
        return layer[f'norm{norm_idx}'](merged)

    def forward(self, x):
        """Apply every block to ``x`` and return a tensor of the same shape."""
        for layer in self.layers:
            # Self-attention over all tokens (queries and keys/values coincide).
            attn_in = self._start_residual(x, layer, 0)
            attn_out = layer['attention'](attn_in, attn_in)
            x = self._end_residual(x, attn_out, layer, 0)

            # Position-wise feed-forward part.
            hidden = self._start_residual(x, layer, 1)
            hidden = self.activation(layer['linear0'](hidden))
            if self.ffn_dropout:
                hidden = F.dropout(hidden, self.ffn_dropout, self.training)
            x = self._end_residual(x, layer['linear1'](hidden), layer, 1)
        return x


class AE(nn.Module):
    """Auto-encoder whose encoder and decoder are single self-attention layers.

    Args:
        hid_dim: token embedding size.
        n_head: number of attention heads.
    """

    def __init__(self, hid_dim, n_head):
        super(AE, self).__init__()

        self.hid_dim = hid_dim
        self.n_head = n_head

        # MultiheadAttention has no default for `dropout`; the previous
        # two-argument call raised TypeError. 0.0 disables attention dropout.
        self.encoder = MultiheadAttention(hid_dim, n_head, 0.0)
        self.decoder = MultiheadAttention(hid_dim, n_head, 0.0)

    def get_embedding(self, x):
        """Return the encoder output for ``x``, detached from the graph."""
        return self.encoder(x, x).detach()

    def forward(self, x):
        """Encode ``x`` with self-attention, then decode the result the same way."""
        z = self.encoder(x, x)
        h = self.decoder(z, z)

        return h

class VAE(nn.Module):
    """Token-level variational auto-encoder built from Transformer stacks.

    The input features are tokenized (with a leading [CLS] token), two separate
    transformers produce the posterior mean and log-variance for every token,
    and a third transformer decodes a sample with the [CLS] position dropped.

    Args:
        d_numerical: number of numerical features.
        categories: cardinalities of the categorical features (or ``None``).
        num_layers: number of layers in each transformer.
        hid_dim: token embedding size.
        n_head: attention heads per layer.
        factor: feed-forward width as a multiple of ``hid_dim``.
        bias: whether the tokenizer uses a per-feature bias.
    """

    def __init__(self, d_numerical, categories, num_layers, hid_dim, n_head = 1, factor = 4, bias = True):
        super(VAE, self).__init__()

        self.d_numerical = d_numerical
        self.categories = categories
        self.hid_dim = hid_dim
        d_token = hid_dim
        self.n_head = n_head

        self.Tokenizer = Tokenizer(d_numerical, categories, d_token, bias = bias)

        self.encoder_mu = Transformer(num_layers, hid_dim, n_head, hid_dim, factor)
        self.encoder_logvar = Transformer(num_layers, hid_dim, n_head, hid_dim, factor)

        self.decoder = Transformer(num_layers, hid_dim, n_head, hid_dim, factor)

    def get_embedding(self, x):
        """Return the detached posterior mean for already-tokenized input ``x``."""
        # Transformer.forward takes a single input; the previous call passed
        # `x` twice and raised TypeError.
        return self.encoder_mu(x).detach()

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x_num, x_cat, return_z=False):
        """Encode, sample and decode a batch.

        Args:
            x_num: numerical features or ``None``.
            x_cat: categorical indices or ``None``.
            return_z: also return the sampled latent.

        Returns:
            ``(h, mu_z, logvar_z)`` or, with ``return_z``, ``(h, mu_z,
            logvar_z, z)``. ``mu_z``, ``logvar_z`` and ``z`` still contain the
            [CLS] position; ``h`` is decoded from ``z`` without it.
        """
        x = self.Tokenizer(x_num, x_cat)

        mu_z = self.encoder_mu(x)
        # The second encoder's output is consumed as a log-variance.
        logvar_z = self.encoder_logvar(x)

        z = self.reparameterize(mu_z, logvar_z)

        # Drop the [CLS] token before decoding.
        h = self.decoder(z[:, 1:])

        if return_z:
            return h, mu_z, logvar_z, z
        return h, mu_z, logvar_z

class Reconstructor(nn.Module):
    """Map decoded tokens back to numerical values and categorical logits.

    Args:
        d_numerical: number of numerical features.
        categories: cardinalities of the categorical features, or ``None``
            when there are none.
        d_token: token embedding size.
    """

    def __init__(self, d_numerical, categories, d_token):
        super(Reconstructor, self).__init__()

        self.d_numerical = d_numerical
        self.categories = categories
        self.d_token = d_token

        self.weight = nn.Parameter(torch.empty(d_numerical, d_token))
        nn.init.xavier_uniform_(self.weight, gain=1 / math.sqrt(2))
        self.cat_recons = nn.ModuleList()

        # `categories=None` means "no categorical features", so no heads are
        # built instead of failing on iteration.
        for d in (categories if categories is not None else ()):
            recon = nn.Linear(d_token, d)
            nn.init.xavier_uniform_(recon.weight, gain=1 / math.sqrt(2))
            self.cat_recons.append(recon)

    def forward(self, h):
        """Reconstruct features from tokens.

        Args:
            h: tokens of shape ``[batch, n_features, d_token]``; the first
                ``d_numerical`` tokens are numerical, the rest categorical.

        Returns:
            ``(recon_x_num, recon_x_cat)``: a ``[batch, d_numerical]`` tensor
            and a list with one ``[batch, cardinality]`` logits tensor per
            categorical feature.
        """
        h_num = h[:, :self.d_numerical]
        h_cat = h[:, self.d_numerical:]

        # Each numerical value is the dot product of its token with its weight row.
        recon_x_num = torch.mul(h_num, self.weight.unsqueeze(0)).sum(-1)
        recon_x_cat = [recon(h_cat[:, i]) for i, recon in enumerate(self.cat_recons)]

        return recon_x_num, recon_x_cat


class Model_VAE(nn.Module):
    """VAE plus feature reconstructor, with optional GMM parameters and bottleneck.

    Args:
        num_layers: number of layers in each transformer of the VAE.
        d_numerical: number of numerical features.
        categories: cardinalities of the categorical features.
        d_token: token embedding size.
        n_head: attention heads per layer.
        factor: feed-forward width as a multiple of ``d_token``.
        bias: whether the tokenizer uses a per-feature bias.
        compress_dim: if given, create linear layers that squeeze the flattened
            decoder output to this size and expand it back.
        num_clusters: if positive, create mixture parameters (``_pi`` logits,
            ``mu``, ``logvar``) over the flattened latent space.
    """

    def __init__(
            self, 
            num_layers, 
            d_numerical, 
            categories, 
            d_token, 
            n_head = 1, 
            factor = 4,  
            bias = True, 
            compress_dim=None,
            num_clusters=-1
        ):
        super(Model_VAE, self).__init__()

        self.VAE = VAE(d_numerical, categories, num_layers, d_token, n_head = n_head, factor = factor, bias = bias)
        self.Reconstructor = Reconstructor(d_numerical, categories, d_token)
        self.num_clusters = num_clusters

        # Flattened size of one row: a d_token vector per feature column
        # ([CLS] excluded). No categorical columns count as zero.
        n_categorical = 0 if categories is None else len(categories)
        flat_dim = d_token * (d_numerical + n_categorical)

        if num_clusters > 0:
            self.latent_dim = flat_dim
            # Mixture logits (see `weights`), component means and log-variances.
            # They start from a generic initialisation; values fitted elsewhere
            # can be copied into `.data` after construction.
            self._pi = nn.Parameter(torch.zeros(num_clusters))
            self.mu = nn.Parameter(torch.randn(num_clusters, self.latent_dim))
            self.logvar = nn.Parameter(torch.randn(num_clusters, self.latent_dim))

        if compress_dim is not None:
            self.compress_layer = nn.Linear(flat_dim, compress_dim)
            self.uncompress_layer = nn.Linear(compress_dim, flat_dim)
            # initialize weights
            nn.init.xavier_uniform_(self.compress_layer.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.uncompress_layer.weight, gain=1 / math.sqrt(2))

    @property
    def weights(self):
        """Mixture weights: softmax of the ``_pi`` logits (needs num_clusters > 0)."""
        return torch.softmax(self._pi, dim=0)

    def get_gamma(self, z, u_p, lambda_p, theta_p):
        """Cluster responsibilities of ``z`` under a diagonal Gaussian mixture.

        Args:
            z: flattened latents, ``[batch, latent_dim]``.
            u_p: component means, expected ``[num_clusters, latent_dim]``.
            lambda_p: component variances (not log-variances), same shape as
                ``u_p``.
            theta_p: unnormalised mixture logits, ``[num_clusters]``.

        Returns:
            Tensor ``[batch, num_clusters]`` whose rows sum to one.

        Requires the model to have been built with ``num_clusters > 0``
        (``latent_dim`` is only defined then).
        """
        # theta_p is 1-D; pass dim explicitly instead of relying on the
        # deprecated implicit-dimension behaviour of F.softmax.
        theta_p = F.softmax(theta_p, dim=0)
        batch_size = z.shape[0]
        # Bring everything to [batch, latent_dim, num_clusters].
        temp_Z = z.unsqueeze(2).expand(-1, -1, self.num_clusters)
        temp_u_tensor3 = u_p.expand(batch_size, -1, -1).permute(0, 2, 1)
        temp_lambda_tensor3 = lambda_p.expand(batch_size, -1, -1).permute(0, 2, 1)
        temp_theta_tensor3 = theta_p.unsqueeze(0).unsqueeze(1).expand(batch_size, self.latent_dim, -1)

        # The small constant keeps the normalisation below away from 0 / 0.
        temp_p_c_z = torch.exp(torch.sum((torch.log(temp_theta_tensor3) - 0.5 * torch.log(2 * math.pi * temp_lambda_tensor3) - 
                                          (temp_Z - temp_u_tensor3) ** 2 / (2 * temp_lambda_tensor3)), dim=1)) + 1e-10

        gamma = temp_p_c_z / torch.sum(temp_p_c_z, dim=-1, keepdim=True)
        return gamma

    def get_embedding(self, x_num, x_cat):
        """Return the detached posterior-mean tokens (including [CLS]) of a batch."""
        # The tokenizer and mean encoder live on the inner VAE; this class has
        # no `Tokenizer` attribute of its own, so the previous lookup raised
        # AttributeError.
        x = self.VAE.Tokenizer(x_num, x_cat)
        return self.VAE.encoder_mu(x).detach()

    def vade_forward(self, x_num, x_cat):
        """Forward pass that also returns flattened latent statistics.

        Returns:
            ``(recon_x_num, recon_x_cat, z_mean, z_log_var, z)`` where the last
            three have the [CLS] position removed and are flattened to
            ``[batch, n_features * d_token]``.
        """
        h, z_mean, z_log_var, z = self.VAE(x_num, x_cat, return_z=True)
        recon_x_num, recon_x_cat = self.Reconstructor(h)

        # Drop [CLS], then flatten the tokens of each row.
        z = z[:, 1:]
        z_mean = z_mean[:, 1:]
        z_log_var = z_log_var[:, 1:]

        z = z.reshape(z.size(0), -1)
        z_mean = z_mean.reshape(z_mean.size(0), -1)
        z_log_var = z_log_var.reshape(z_log_var.size(0), -1)

        return recon_x_num, recon_x_cat, z_mean, z_log_var, z

    def classify(self, x_num, x_cat):
        """Assign every row to its most probable mixture component.

        The latent used here is the sample returned by the VAE forward pass,
        so repeated calls can give different assignments.

        Returns:
            Long tensor ``[batch]`` with the index of the chosen component.
        """
        with torch.no_grad():
            _, _, _, z = self.VAE(x_num, x_cat, return_z=True)
            z = z[:, 1:]
            z = z.reshape(z.size(0), -1)
            z_expanded = z.unsqueeze(1)

            mu_diff = z_expanded - self.mu
            squared_diff = mu_diff.pow(2)
            # NOTE(review): this score has no -0.5 * sum(logvar) term, unlike
            # log_gmm_probs; confirm whether that omission is intended.
            neg_log_likelihood = -0.5 * (squared_diff / self.logvar.exp()).sum(dim=2)
            log_norm = torch.logsumexp(neg_log_likelihood, dim=1, keepdim=True)
            log_p_z_given_c = neg_log_likelihood - log_norm - math.log(2 * math.pi)
            log_p_z_c = log_p_z_given_c + torch.log(self.weights)
            log_gamma = log_p_z_c - torch.logsumexp(log_p_z_c, dim=1, keepdim=True)

            y = log_gamma.exp()
            pred = torch.argmax(y, dim=1)
        return pred

    def forward(self, x_num, x_cat, compress=False):
        """Reconstruct a batch, optionally through the linear bottleneck.

        Args:
            x_num: numerical features or ``None``.
            x_cat: categorical indices or ``None``.
            compress: route the decoder output through ``compress_layer`` /
                ``uncompress_layer`` (requires ``compress_dim`` at construction).

        Returns:
            ``(recon_x_num, recon_x_cat, mu_z, logvar_z)`` and, when
            ``compress`` is set, the compressed representation as a fifth item.
            ``mu_z`` and ``logvar_z`` still include the [CLS] position.
        """
        h, mu_z, logvar_z, z = self.VAE(x_num, x_cat, return_z=True)

        if compress:
            h_compressed = self.compress_layer(h.reshape(h.size(0), -1))
            h = self.uncompress_layer(h_compressed)
            # Back to [batch, n_features, d_token] for the reconstructor.
            h = h.view(h.size(0), -1, self.VAE.hid_dim)

        recon_x_num, recon_x_cat = self.Reconstructor(h)

        if compress:
            return recon_x_num, recon_x_cat, mu_z, logvar_z, h_compressed
        return recon_x_num, recon_x_cat, mu_z, logvar_z


class Encoder_model(nn.Module):
    """Stand-alone encoder: tokenizer followed by a transformer.

    Its two sub-modules have the same structure as the tokenizer and mean
    encoder of a ``Model_VAE`` so that their weights can be copied over with
    ``load_weights``.
    """

    def __init__(self, num_layers, d_numerical, categories, d_token, n_head, factor, bias = True):
        super().__init__()
        self.Tokenizer = Tokenizer(d_numerical, categories, d_token, bias)
        self.VAE_Encoder = Transformer(num_layers, d_token, n_head, d_token, factor)

    def load_weights(self, Pretrained_VAE):
        """Copy tokenizer and mean-encoder weights from ``Pretrained_VAE.VAE``."""
        source = Pretrained_VAE.VAE
        self.Tokenizer.load_state_dict(source.Tokenizer.state_dict())
        self.VAE_Encoder.load_state_dict(source.encoder_mu.state_dict())

    def forward(self, x_num, x_cat):
        """Tokenize the inputs and return the encoder output for all tokens."""
        tokens = self.Tokenizer(x_num, x_cat)
        return self.VAE_Encoder(tokens)

class Decoder_model(nn.Module):
    """Stand-alone decoder: transformer followed by a feature reconstructor.

    Its two sub-modules have the same structure as the decoder and
    reconstructor of a ``Model_VAE`` so that their weights can be copied over
    with ``load_weights``. ``n_head``-style arguments mirror ``Encoder_model``;
    ``bias`` is accepted but not used here.
    """

    def __init__(self, num_layers, d_numerical, categories, d_token, n_head, factor, bias = True):
        super().__init__()
        self.VAE_Decoder = Transformer(num_layers, d_token, n_head, d_token, factor)
        self.Detokenizer = Reconstructor(d_numerical, categories, d_token)

    def load_weights(self, Pretrained_VAE):
        """Copy decoder and reconstructor weights from ``Pretrained_VAE``."""
        decoder_state = Pretrained_VAE.VAE.decoder.state_dict()
        self.VAE_Decoder.load_state_dict(decoder_state)
        detokenizer_state = Pretrained_VAE.Reconstructor.state_dict()
        self.Detokenizer.load_state_dict(detokenizer_state)

    def forward(self, z):
        """Decode latent tokens ``z`` into numerical values and categorical logits."""
        decoded = self.VAE_Decoder(z)
        recon_num, recon_cat = self.Detokenizer(decoded)
        return recon_num, recon_cat
