"""
Copyright 2025 [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import random
import numpy as np
import torch
from torch import Tensor, einsum, vmap
from torch.linalg import solve_triangular
from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence



def fix_random(seed: int=0):
    """
    Seed every random number generator used here and put PyTorch into deterministic mode.

    Parameters
    ----------
    seed: int
        Seed handed to Python's `random`, NumPy, and PyTorch (CPU and CUDA).
    """
    # Workspace setting cuBLAS needs for deterministic results (PyTorch
    # reproducibility notes). NOTE(review): presumably only effective if set
    # before the first cuBLAS call in the process — confirm.
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    # Autotuning may pick different kernels from run to run, so turn it off.
    cudnn.benchmark = False
    cudnn.deterministic = True
    # warn_only=True: ops lacking a deterministic implementation warn instead of raising.
    torch.use_deterministic_algorithms(True, warn_only=True)



def nanstd(tensor: Tensor, dim: int, unbiased: bool=False):
    """
    Standard deviation along one dimension, skipping NaN entries.

    NaN entries are excluded both from the mean and from the count of
    samples. A slice that is entirely NaN (or has a single sample when
    `unbiased=True`) yields NaN, because the division is 0/0.

    Parameters
    ----------
    tensor: Tensor
        Input values; may contain NaN.
    dim: int
        Dimension that is reduced.
    unbiased: bool
        If True, divide by (count - 1) instead of count.

    Returns
    -------
    std: Tensor
        Standard deviation with `dim` removed.
    """
    centered = tensor - torch.nanmean(tensor, dim=dim, keepdim=True)
    is_nan = centered.isnan()
    # Zero the NaN positions so they contribute nothing to the sum of squares.
    centered = centered.masked_fill(is_nan, 0.0)

    count = (~is_nan).sum(dim=dim)
    if unbiased:
        count = count - 1

    return torch.sqrt(torch.nansum(centered.square(), dim=dim) / count)



def count_parameters(model, log_path: str=None):
    """
    Print (and optionally append to a log file) the number of trainable parameters of a model.

    Prints one line with the total, followed by one line per trainable
    parameter (`name: numel`).

    Parameters
    ----------
    model: torch.nn.Module
        Model whose parameters with `requires_grad=True` are counted.
    log_path: str, optional
        If given, the same lines are appended to this file.

    Returns
    -------
    n_params: int
        Total number of trainable parameters.
    """
    trainable = [
        (name, param.numel())
        for name, param in model.named_parameters()
        if param.requires_grad
    ]
    n_params = sum(numel for _, numel in trainable)

    lines = [f'Number of parameters: {n_params}']
    lines.extend(f'{name}: {numel}' for name, numel in trainable)

    for line in lines:
        print(line)

    if log_path is not None:
        # Open the log once instead of once per parameter.
        with open(log_path, 'a', encoding='utf-8') as f:
            f.writelines(line + '\n' for line in lines)

    return n_params



def vector_to_lower_triangular_matrix(vector: Tensor, dim: int):
    """
    Unpack a parameter vector into a lower triangular matrix.

    The first `dim` entries give the diagonal as exp(0.5 * x), so each
    diagonal element is positive and x = log(diagonal**2). The remaining
    dim*(dim-1)//2 entries fill the strictly lower triangle in row-major
    order: (1, 0), (2, 0), (2, 1), (3, 0), ... Everything above the
    diagonal is zero.

    Parameters
    ----------
    vector: Tensor, shape (..., dim*(dim+1)//2)
        Packed values; leading dimensions are treated as batch dimensions.
    dim: int
        Size of the square output matrix.

    Returns
    -------
    lower_triangular_matrix: Tensor, shape (..., dim, dim)
        Lower triangular matrix with the same dtype and device as `vector`.
    """
    batch_shape = vector.shape[:-1]
    log_diagonal = vector[..., :dim]
    strictly_lower = vector[..., dim:]

    matrix = vector.new_zeros(batch_shape + (dim, dim))

    diag_positions = torch.arange(dim)
    matrix[..., diag_positions, diag_positions] = torch.exp(0.5 * log_diagonal)

    # offset=-1 selects the strictly lower triangle, in row-major order.
    rows, cols = torch.tril_indices(dim, dim, offset=-1)
    matrix[..., rows, cols] = strictly_lower

    return matrix



def broadcastable_cat(tensors, dim=0):
    """
    Concatenate tensors along one dimension, broadcasting all other dimensions.

    Tensors with fewer dimensions are left-padded with size-1 dimensions,
    which follows the usual broadcasting alignment of trailing dimensions.
    Every dimension except `dim` is then expanded to the largest size among
    the inputs. `dim` (including a negative `dim`) refers to the broadcast
    rank, i.e. the rank of the highest-dimensional input.

    Parameters
    ----------
    tensors: list of Tensor
        Tensors to concatenate.
    dim: int
        Dimension along which to concatenate.

    Returns
    -------
    concatenated_tensor: Tensor
        Concatenated tensor. An empty `tensors` list yields `torch.tensor([])`.

    Raises
    ------
    IndexError
        If `dim` is out of range for the broadcast rank.
    """
    if not tensors:
        return torch.tensor([])

    ndim = max(tensor.dim() for tensor in tensors)
    # Normalize against the broadcast rank, not the rank of tensors[0]:
    # the two differ whenever the first tensor has fewer dimensions.
    axis = dim + ndim if dim < 0 else dim
    if not 0 <= axis < ndim:
        raise IndexError(f'dim {dim} is out of range for broadcast rank {ndim}')

    # reshape (unlike view) works for any memory layout.
    padded = [
        tensor.reshape((1,) * (ndim - tensor.dim()) + tuple(tensor.shape))
        for tensor in tensors
    ]
    target = [max(tensor.shape[i] for tensor in padded) for i in range(ndim)]

    expanded = []
    for tensor in padded:
        shape = list(target)
        shape[axis] = tensor.shape[axis]  # keep each input's own length along the cat axis
        expanded.append(tensor.expand(shape))

    return torch.cat(expanded, dim=axis)



def triangular_inverse(L: Tensor, double_precision: bool=False):
    """
    Compute the inverse of a lower triangular matrix L.

    The inverse comes from solving L X = I with a triangular solve; only the
    lower triangle of L is used.

    Parameters
    ----------
    L: Tensor, shape (..., n, n)
        Lower triangular matrix (batched inputs are supported).
    double_precision: bool
        If True, the solve runs in float64 (complex128 for complex input) and
        the result is cast back to L.dtype.

    Returns
    -------
    L_inv: Tensor, shape (..., n, n)
        Inverse of L, with the same dtype as L.
    """
    if double_precision:
        work = L.to(torch.complex128 if L.is_complex() else torch.float64)
    else:
        work = L

    identity = torch.eye(work.size(-1), dtype=work.dtype, device=work.device).expand_as(work)
    L_inv = solve_triangular(work, identity, upper=False)

    return L_inv.to(L.dtype)



def positive_definite_inverse(A: Tensor, eps: float=0.0, double_precision: bool=False):
    """
    Compute the inverse of a positive definite matrix A via its Cholesky factor.

    With A + eps*I = L L^H, the inverse is L^{-H} L^{-1}.

    Parameters
    ----------
    A: Tensor, shape (..., n, n)
        Positive definite matrix (batched inputs are supported).
    eps: float
        Multiple of the identity added to A before factorization, for
        numerical stability.
    double_precision: bool
        If True, compute in float64 (complex128 for complex input) and cast
        the result back to A.dtype.

    Returns
    -------
    A_inv: Tensor, shape (..., n, n)
        Inverse of A + eps*I, with the same dtype as A.

    Raises
    ------
    torch.linalg.LinAlgError
        If A + eps*I is not positive definite.
    """
    if double_precision:
        work = A.to(torch.complex128 if A.is_complex() else torch.float64)
    else:
        work = A

    # Build the identity in the working dtype so eps is not rounded to float32.
    identity = torch.eye(work.size(-1), dtype=work.dtype, device=work.device)
    L = torch.linalg.cholesky(work + eps * identity)
    # Solved here directly: the previous call triangular_inverse(L, double_precision)
    # raised TypeError against the one-argument helper.
    L_inv = solve_triangular(L, identity.expand_as(L), upper=False)
    # mH == mT for real input; the conjugate is needed for complex Hermitian A.
    A_inv = L_inv.mH @ L_inv

    return A_inv.to(A.dtype)



def flatten_packed_sequence(packed_sequence: PackedSequence):
    """
    Flatten a packed sequence into one tensor, sequence by sequence.

    The output lists every valid time step of the first sequence, then of the
    second, and so on. Sequences are in the order restored by
    `pad_packed_sequence`. Padding positions are dropped.

    Parameters
    ----------
    packed_sequence: PackedSequence
        Packed sequence.

    Returns
    -------
    flat_sequence: Tensor, shape (n_total, ...)
        Concatenation of all valid time steps.
    """
    padded, lengths = pad_packed_sequence(packed_sequence, batch_first=True)

    # The padded width already equals the longest length, so no Python-level
    # max() over the lengths tensor is needed.
    steps = torch.arange(padded.size(1), device=padded.device)
    # lengths may live on the CPU regardless of the data's device; align the two.
    mask = steps.unsqueeze(0) < lengths.to(padded.device).unsqueeze(1)

    return padded[mask]



def logdet_triangular(L: Tensor):
    """
    Log-determinant of a (lower) triangular matrix.

    The determinant of a triangular matrix is the product of its diagonal,
    so this returns the sum of the logs of the diagonal entries. A
    non-positive diagonal entry yields NaN or -inf.

    Parameters
    ----------
    L: Tensor, shape (..., n, n)
        Triangular matrix, e.g. a Cholesky factor.

    Returns
    -------
    logdet: Tensor, shape (...)
        Log-determinant of L.
    """
    diagonal = torch.diagonal(L, dim1=-2, dim2=-1)
    return diagonal.log().sum(dim=-1)



def trace_covinv_covcholesky(cov: Tensor, cov_cholesky: Tensor):
    """
    Compute tr(cov^{-1} C C^T), where C = cov_cholesky.

    If C is the Cholesky factor of a second covariance S = C C^T, the result
    is tr(cov^{-1} S). The full S is never formed: the value is computed as
    the elementwise sum of (cov^{-1} C) * C.

    Parameters
    ----------
    cov: Tensor, shape (..., n, n)
        Covariance matrix that is (implicitly) inverted.
    cov_cholesky: Tensor, shape (..., n, n)
        Cholesky factor C of the second covariance matrix.

    Returns
    -------
    trace: Tensor, shape (...)
        tr(cov^{-1} C C^T).
    """
    cov_inv_chol = torch.linalg.solve(cov, cov_cholesky)
    # sum_ij (cov^{-1} C)_ij * C_ij == tr(C^T cov^{-1} C) == tr(cov^{-1} C C^T)
    trace = (cov_inv_chol * cov_cholesky).sum(dim=(-2, -1))

    return trace



def maharanobis_norm_squared(cov, v):
    """
    Squared Mahalanobis norm v^T cov^{-1} v.

    (The function name keeps its historical spelling, "maharanobis".)

    Parameters
    ----------
    cov: Tensor, shape (..., n, n)
        Covariance matrix; it is solved against, not explicitly inverted.
    v: Tensor, shape (..., n)
        Vector(s); leading dimensions broadcast against `cov`.

    Returns
    -------
    maharanobis_norm_squared: Tensor, shape (...)
        v^T cov^{-1} v for each batch element.
    """
    # Solve with an explicit column vector so torch.linalg.solve never has to
    # guess whether `v` is a vector or a matrix.
    column = v.unsqueeze(-1)
    cov_inv_v = torch.linalg.solve(cov, column).squeeze(-1)
    return torch.einsum('...i,...i->...', cov_inv_v, v)



def maharanobis_norm_squared_cholesky(cov_cholesky, v):
    """
    Squared Mahalanobis norm v^T (L L^T)^{-1} v, given the lower Cholesky factor L.

    Solves L z = v with a triangular solve and returns ||z||^2, since
    v^T (L L^T)^{-1} v = ||L^{-1} v||^2.
    (The function name keeps its historical spelling, "maharanobis".)

    Parameters
    ----------
    cov_cholesky: Tensor, shape (..., n, n)
        Lower triangular Cholesky factor L of the covariance matrix.
    v: Tensor, shape (..., n)
        Vector(s); leading dimensions broadcast against `cov_cholesky`.

    Returns
    -------
    maharanobis_norm_squared: Tensor, shape (...)
        Squared norm for each batch element.
    """
    whitened = torch.linalg.solve_triangular(cov_cholesky, v.unsqueeze(-1), upper=False)
    return whitened.pow(2).sum(dim=(-2, -1))



def regularize_positive_definite(A: Tensor, name_A: str=None):
    """
    Add the smallest tried multiple of the identity that makes A Cholesky-factorizable.

    Tries factors 0, 1e-10, 1.1e-9, 1.11e-8, ... (each step is
    `factor*10 + 1e-10`) and returns A + factor*I for the first factor whose
    Cholesky factorization succeeds. For a batched A, one factor is shared by
    the whole batch.

    Parameters
    ----------
    A: Tensor, shape (..., n, n)
        Matrix.
    name_A: str, optional
        Name used in progress and error messages; progress is printed only
        when it is given.

    Returns
    -------
    A_reg: Tensor, shape (..., n, n)
        Regularized matrix.

    Raises
    ------
    ValueError
        If the next factor would exceed 1.0 (chained to the last factorization error).
    """
    regulator = torch.eye(A.shape[-1], device=A.device, dtype=A.dtype)
    regulator_factor = 0.0
    while True:
        A_reg = A + regulator*regulator_factor
        try:
            torch.linalg.cholesky(A_reg)
        except torch.linalg.LinAlgError as err:  # public alias of torch._C._LinAlgError
            if name_A is not None:
                print(f'Regularization of {name_A} failed at regulator factor = {regulator_factor}.')
            regulator_factor = regulator_factor*10 + 1e-10
            if regulator_factor > 1.0:
                if name_A is not None:
                    raise ValueError(f'Regularization of {name_A} failed.') from err
                raise ValueError('Regularization failed.') from err
        else:
            return A_reg



def safe_positive_definite_inverse(A: Tensor, name_A: str=None):
    """
    Invert a positive definite matrix A through a regularized Cholesky factor.

    Computes L = safe_cholesky(A) (which may add a small multiple of the
    identity to A) and returns L^{-T} L^{-1}. The result is symmetric by
    construction.

    Parameters
    ----------
    A: Tensor, shape (..., n, n)
        Matrix.
    name_A: str, optional
        Name of the matrix, used in safe_cholesky's messages. When omitted,
        'A' is used.

    Returns
    -------
    A_inv: Tensor, shape (..., n, n)
        Inverse of (the possibly regularized) A.
    """
    # Previously name_A was ignored and a fixed label was always passed;
    # keep that exact label when no name is given.
    label = 'A' if name_A is None else name_A
    A_cholesky = safe_cholesky(A, f'{label} in safe_positive_definite_inverse')
    A_cholesky_inv = triangular_inverse(A_cholesky)
    A_inv = A_cholesky_inv.mT @ A_cholesky_inv

    return A_inv



def safe_cholesky(A: Tensor, name_A: str=None):
    """
    Cholesky factor of A, adding a growing multiple of the identity if needed.

    Tries factors 0, 1e-10, 1.1e-9, 1.11e-8, ... (each step is
    `factor*10 + 1e-10`) and returns the Cholesky factor of the first
    A + factor*I that factorizes. For a batched A, one factor is shared by
    the whole batch.

    Parameters
    ----------
    A: Tensor, shape (..., n, n)
        Matrix.
    name_A: str, optional
        Name used in progress and error messages; progress is printed only
        when it is given.

    Returns
    -------
    L: Tensor, shape (..., n, n)
        Lower triangular Cholesky factor of A + factor*I.

    Raises
    ------
    ValueError
        If the next factor would exceed 1.0 (chained to the last factorization error).
    """
    regulator = torch.eye(A.shape[-1], device=A.device, dtype=A.dtype)
    regulator_factor = 0.0
    while True:
        try:
            return torch.linalg.cholesky(A + regulator*regulator_factor)
        except torch.linalg.LinAlgError as err:  # public alias of torch._C._LinAlgError
            if name_A is not None:
                print(f'Cholesky decomposition of {name_A} failed at regulator factor = {regulator_factor}.')
            regulator_factor = regulator_factor*10 + 1e-10
            if regulator_factor > 1.0:
                if name_A is not None:
                    raise ValueError(f'Cholesky decomposition of {name_A} failed.') from err
                raise ValueError('Cholesky decomposition failed.') from err



def safe_logdet(A: Tensor, name_A: str=None):
    """
    Log-determinant of A, adding a growing multiple of the identity if it is not finite.

    `torch.logdet` returns NaN for a negative determinant and -inf for a
    singular matrix; both are treated as failures. The function tries
    factors 0, 1e-10, 1.1e-9, 1.11e-8, ... (each step is
    `factor*10 + 1e-10`) and returns logdet(A + factor*I) for the first
    factor where no batch element fails. For a batched A, one factor is
    shared by the whole batch.

    Parameters
    ----------
    A: Tensor, shape (..., n, n)
        Matrix.
    name_A: str, optional
        Name used in progress and error messages; progress is printed only
        when it is given.

    Returns
    -------
    logdet: Tensor, shape (...)
        Log-determinant of A + factor*I.

    Raises
    ------
    ValueError
        If the next factor would exceed 1.0.
    """
    regulator = torch.eye(A.shape[-1], device=A.device, dtype=A.dtype)
    regulator_factor = 0.0
    while True:
        logdet = torch.logdet(A + regulator*regulator_factor)
        # A singular matrix gives -inf rather than NaN, and must also trigger regularization.
        failed = torch.isnan(logdet) | torch.isneginf(logdet)
        if not failed.any():
            return logdet
        if name_A is not None:
            print(f'Log-determinant of {name_A} failed at regulator factor = {regulator_factor}.')
        regulator_factor = regulator_factor*10 + 1e-10
        if regulator_factor > 1.0:
            if name_A is not None:
                raise ValueError(f'Log-determinant of {name_A} failed.')
            raise ValueError('Log-determinant failed.')
