import torch
import numpy as np
from torch.nn import functional as F
import io
import gzip


def reparam_trick(mu, log_sigma):
    """
    Generate a sample from a normal distribution with the reparametrization trick.

    The sample is built as mu + sigma * noise with noise ~ N(0, I), so it is a
    deterministic function of mu and log_sigma given the noise.

    input args
        mu: mean of the Gaussian distribution for q(s|z,x) = N(mu, sigma^2*I).
        log_sigma: log of variance of the Gaussian distribution for q(s|z,x) = N(mu, sigma^2*I).

    return
        a sample from Gaussian distribution N(mu, sigma^2*I).
    """

    # exp(0.5 * log_var) is sqrt(exp(log_var)) without the intermediate
    # exp(log_var), which overflows for much smaller inputs.
    std = torch.exp(0.5 * log_sigma)
    # Sample the noise with the same shape, dtype and device as std instead of
    # allocating a float32 CPU tensor and copying it to the device.
    noise = torch.randn_like(std)
    return noise.mul(std).add(mu)



def sample_gumbel(shape, eps=1e-4):
    """
    Draws Gumbel noise by transforming uniform random numbers.

    input args
        shape: shape of the tensor of samples to generate (passed to torch.rand).
        eps: a small value added inside both logarithms to prevent numerical instability.

    return
        -log(-log(U + eps) + eps), where U ~ Uniform[0, 1) (tensor of the given shape)
    """

    uniform = torch.rand(shape)
    # Inner transform: -log(U + eps), shifted by eps so the outer log stays finite.
    inner = -torch.log(uniform + eps) + eps
    return -torch.log(inner)



def gumbel_softmax_sample(phi, temperature, eps=1e-4):
    """
    Draws a relaxed categorical sample with the Gumbel-softmax distribution.

    input args
        phi: probabilities of categories (categories along the last dimension).
        temperature: a hyperparameter that defines the shape of the distribution across categories.
        eps: a small value to prevent numerical instability.

    return
        Softmax over the last dimension of (log(phi + eps) + Gumbel noise) / temperature.
    """

    # Noise is generated by sample_gumbel and then moved to phi's device.
    noise = sample_gumbel(phi.size(), eps).to(phi.device)
    perturbed = torch.log(phi + eps) + noise
    return F.softmax(perturbed / temperature, dim=-1)



def gumbel_softmax(phi, latent_dim, categorical_dim, temperature, hard=False, gumble_noise=True, eps=1e-4):
    """
    Regular Gumbel-softmax and its Straight-Through (ST) one-hot variant.

    input args
        phi: probabilities of categories (categories along the last dimension).
        latent_dim: latent variable dimension.
        categorical_dim: number of categories of the latent variables.
        temperature: a hyperparameter that defines the shape of the distribution across categories.
        hard: if True, return one-hot samples (ST Gumbel-softmax); if False, return the soft samples.
        gumble_noise: if True, perturb phi with Gumbel noise; if False, use phi as is.
        eps: a small value to prevent numerical instability.

    return
        Samples from a categorical distribution, reshaped to (-1, latent_dim * categorical_dim).
    """

    flat_width = latent_dim * categorical_dim
    soft = gumbel_softmax_sample(phi, temperature, eps) if gumble_noise else phi

    if not hard:
        return soft.view(-1, flat_width)

    # Build a one-hot tensor marking the largest entry along the last dimension.
    full_shape = soft.size()
    _, winners = soft.max(dim=-1)
    one_hot = torch.zeros_like(soft).view(-1, full_shape[-1])
    one_hot.scatter_(1, winners.view(-1, 1), 1)
    one_hot = one_hot.view(*full_shape)

    # Straight-through: the value is the one-hot sample, while the only
    # non-detached term is `soft`, so gradients flow through it.
    straight_through = (one_hot - soft).detach() + soft
    return straight_through.view(-1, flat_width)


def zinb_error(rec_x, x_p, x_r, X, eps=1e-6):
    """
    Loss function based on the zero inflated negative binomial (ZINB)
    distribution for log(x|s,z) for gene expression data.

    input args
        rec_x: used (after adding eps) as the negative binomial parameter r.
        x_p: used (after the eps adjustment) as the negative binomial probability p.
        x_r: used (after the eps adjustment) as the zero-inflation probability z.
        X: input data; treated as log(1 + count), since counts are recovered as exp(X) - 1.
        eps: a small constant value that keeps r > 0 and p, z away from 0.

    return
        l_zinb: mean over all elements of the negative ZINB log-likelihood
                (the count-factorial term is left out).
    """

    k = X.exp() - 1.  # log(1 + count) --> count

    # eps added for stability.
    r = rec_x + eps
    p = (1 - eps) * (x_p + eps)
    z = (1 - eps) * (x_r + eps)

    # 1.0 where the observed value is positive, 0.0 where it is zero.
    mask_nonzeros = (X > 0).to(torch.float32)

    # Zero entries: -log(z + (1 - z) * (1 - p)^r); (mask - 1) is -1 there and 0 elsewhere.
    loss_zero_counts = (mask_nonzeros - 1) * (z + (1 - z) * (1 - p).pow(r)).log()
    # Non-zero entries: negative log of the zinb density, excluding the x! term.
    loss_nonzero_counts = mask_nonzeros * (
        -(k + r).lgamma() + r.lgamma() - k * p.log() - r * (1 - p).log() - (1 - z).log()
    )

    l_zinb = (loss_zero_counts + loss_nonzero_counts).mean()

    return l_zinb


def safe_item(tensor):
    """
    Convert a scalar-like value into a plain number.

    input args
        tensor: a torch tensor, a numpy array, a list, or any other object.

    return
        - torch tensor: its .item() value.
        - numpy array or list: the sum of its entries (via .item() when the sum has one).
        - anything else: .item() when the object provides it, otherwise 0.
    """
    if isinstance(tensor, torch.Tensor):
        return tensor.data.item()

    if isinstance(tensor, (np.ndarray, list)):
        total = sum(tensor)
        if hasattr(total, "item"):
            return total.item()
        return total

    if hasattr(tensor, "item"):
        return tensor.item()
    # Objects without .item() (e.g. None or plain Python numbers) map to 0.
    return 0.
 

def custom_collate_fn(batch):
    """
    Collate a batch field by field, densifying each entry first.

    input args
        batch: a sequence of samples; every sample is indexable and has at least
               as many fields as the first sample.

    return
        A tuple with one collated element per field of the first sample.
    """
    n_fields = len(batch[0])
    collate = torch.utils.data._utils.collate.default_collate
    # For each field position, gather that field from every sample (converted
    # with to_dense) and hand the list to torch's default collate.
    return tuple(
        collate([to_dense(sample[idx]) for sample in batch])
        for idx in range(n_fields)
    )


def to_dense(tensor):
    """Return tensor.to_dense() when the object has that method, else the object unchanged."""
    if hasattr(tensor, "to_dense"):
        return tensor.to_dense()
    return tensor


def jaccard_distance(target: torch.Tensor, prediction: torch.Tensor, scaled: bool = True) -> torch.Tensor:
    """
    Compute the Jaccard distance between two binary tensors.

    Both inputs are cast to bool, so every non-zero entry counts as a member.

    Args:
        target (torch.Tensor): Target tensor.
        prediction (torch.Tensor): Prediction tensor.
        scaled (bool): If True, multiply the distance by the size of the union,
            i.e. return (1 - index) * union. If False, return 1 - index.

    Returns:
        torch.Tensor: 0-dim tensor with the (optionally scaled) Jaccard distance.
            When the union is empty the index is taken as 0, so the result is
            0 if scaled and 1 otherwise.
    """
    # Ensure binary tensors
    y_true = target.bool() if hasattr(target, "bool") else target
    y_pred = prediction.bool() if hasattr(prediction, "bool") else prediction

    # Compute intersection and union
    intersection = torch.sum(y_true & y_pred)
    union = torch.sum(y_true | y_pred)

    # Jaccard index; the empty-union fallback is created on the same device as
    # the inputs so the result's device does not depend on the branch taken.
    if union > 0:
        jaccard_index = intersection / union
    else:
        jaccard_index = torch.zeros((), device=union.device)

    distance = 1 - jaccard_index
    return distance * union if scaled else distance
    
    

# def save_model_with_zstd(model, filepath, compression_level=3):  # Default compression level
#     """Saves a PyTorch model with zstd compression.

#     Args:
#         model: The PyTorch model to save.
#         filepath: The path to save the compressed model file (e.g., 'model.pth.zst').
#         compression_level: The zstd compression level (1-19, higher is better compression but slower).
#                            Defaults to 3.  Consider higher values (e.g., 9) if size is critical.
#     """
#     try:
#         # 1. Save the model's state_dict to a memory buffer (in-memory serialization)
#         buffer = io.BytesIO() # Use an in-memory buffer
#         torch.save(model.state_dict().detach().cpu(), buffer)
#         buffer.seek(0)  # Go back to the beginning of the buffer

#         # 2. Compress the buffer's contents using zstd
#         cctx = zstd.ZstdCompressor(level=compression_level)
#         compressed_data = cctx.compress(buffer.getvalue())

#         # 3. Write the compressed data to the file
#         with open(filepath, 'wb') as f:
#             f.write(compressed_data)

#         print(f"Model saved and compressed to {filepath} (zstd level {compression_level})")

#     except Exception as e:
#         print(f"Error saving model with zstd: {e}")
        

# def load_model_with_zstd(model, filepath):
#     """Loads a PyTorch model from a zstd-compressed file.

#     Args:
#         model: The PyTorch model instance (must have the same architecture).
#         filepath: The path to the compressed model file (e.g., 'model.pth.zst').

#     Returns:
#         The loaded state_dict or None if an error occurred.
#     """
#     try:
#         # 1. Read the compressed data from the file
#         with open(filepath, 'rb') as f:
#             compressed_data = f.read()

#         # 2. Decompress the data using zstd
#         dctx = zstd.ZstdDecompressor()
#         decompressed_data = dctx.decompress(compressed_data)

#         # 3. Load the state_dict from the decompressed data (using the in-memory buffer)
#         buffer = io.BytesIO(decompressed_data)
#         state_dict = torch.load(buffer)

#         model.load_state_dict(state_dict) # Load to model
#         print(f"Model loaded from {filepath} (zstd)")
#         return model.state_dict() # Return the state dict
#     except Exception as e:
#         print(f"Error loading model with zstd: {e}")
#         return None


def save_model_with_gzip(state_dict, filepath, compression_level=3):
    """
    Serialize an object with torch.save and write it to a gzip-compressed file.

    input args
        state_dict: the object to save (passed as is to torch.save).
        filepath: path of the gzip file to write.
        compression_level: gzip compresslevel used when writing.
    """
    # Serialize in memory first, then compress the bytes in a single write.
    serialized = io.BytesIO()
    torch.save(state_dict, serialized)
    payload = serialized.getvalue()

    with gzip.open(filepath, "wb", compresslevel=compression_level) as gz_file:
        gz_file.write(payload)

    print(f"Model saved to {filepath}")


def load_model_with_gzip(model, filepath):
    """
    Load a gzip-compressed state dict from disk into a model.

    input args
        model: object exposing load_state_dict; it is updated in place.
        filepath: path of the gzip file containing data written with torch.save.

    return
        The same model after loading, or None if any step failed (the error is printed).
    """
    try:
        with gzip.open(filepath, "rb") as gz_file:
            raw_bytes = gz_file.read()

        # NOTE(review): torch.load deserializes pickled data; only use this on
        # files from a trusted source.
        weights = torch.load(io.BytesIO(raw_bytes), map_location="cpu")
        model.load_state_dict(weights)
        print(f"Model loaded successfully from {filepath}")
        return model

    except Exception as e:
        # Best effort: report the problem and signal failure with None.
        print(f"Error loading model with gzip: {e}")
        return None


