#!/usr/bin/env python3

import argparse
import math
import os
import sys

import numpy as np
import torch
from torch.func import jacrev, vmap
from tqdm import tqdm


# ----------------------------------------------------------------
# Module-level memo tables for the combinatorial helper tensors.
# Each dict maps (n, device, dtype, tag) -> tensor, so the 2^n-sized
# subset/sign tables are built once per configuration and then reused.
# Nothing in this file evicts entries.
# ----------------------------------------------------------------
_subsets_mask_cache, _ryser_sign_cache, _sign_vectors_cache = {}, {}, {}


def get_subsets_mask(n: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """
    Return the column-membership mask of all 2^n subsets of {0, ..., n-1}.

    Subset s (an integer in [0, 2^n)) contains column j iff bit j of s is set.
    The result has shape (n, 2^n) with mask[j, s] = 1 if j is in S_s, else 0,
    cast to ``dtype``. It is memoised in ``_subsets_mask_cache`` under
    (n, device, dtype, 'subsets_mask').

    For n == 0 the result is an empty (0, 1) mask: one (empty) subset and
    no columns. This is what Ryser's formula needs for a 0x0 matrix.
    """
    key = (n, device, dtype, 'subsets_mask')
    mask = _subsets_mask_cache.get(key)
    if mask is None:
        subsets = torch.arange(2**n, dtype=torch.long, device=device)       # (2^n,)
        bit_positions = torch.arange(n, dtype=torch.long, device=device)    # (n,)
        # Broadcast (1, 2^n) subset ids against (n, 1) shifts -> (n, 2^n) bits.
        # Unlike torch.cat over a per-bit list, this also works when n == 0.
        bits = (subsets.unsqueeze(0) >> bit_positions.unsqueeze(1)) & 1
        mask = bits.to(dtype=dtype)
        _subsets_mask_cache[key] = mask
    return mask


def get_ryser_sign(n: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """
    Return the inclusion-exclusion signs (-1)^(n - |S|) used by Ryser's formula.

    Subset S is encoded by the integer s in [0, 2^n); column j is in S iff
    bit j of s is set. The result has shape (2^n,), is cast to ``dtype``, and
    is memoised in ``_ryser_sign_cache`` under (n, device, dtype, 'ryser_sign').
    """
    cache_key = (n, device, dtype, 'ryser_sign')
    cached = _ryser_sign_cache.get(cache_key)
    if cached is not None:
        return cached

    subset_ids = torch.arange(2**n, dtype=torch.long, device=device)   # (2^n,)
    column_bits = 1 << torch.arange(n, device=device)                 # (n,) = 1, 2, 4, ...
    # membership[s, j] is True when column j belongs to subset s
    membership = (subset_ids[:, None] & column_bits) > 0               # (2^n, n)
    cardinality = membership.sum(dim=1)                                # |S| per subset
    # Integer power first (exact), then cast to the requested dtype.
    cached = ((-1) ** (n - cardinality)).to(dtype)
    _ryser_sign_cache[cache_key] = cached
    return cached


def get_sign_vectors(n: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """
    Generate all sign vectors {+1, -1}^n as a tensor of shape (2^n, n).

    Row s encodes the integer s bit by bit: column j is +1 when bit j of s is
    set and -1 otherwise. Because bit 0 is the least-significant bit, odd rows
    have +1 in column 0. The result is cast to ``dtype`` and memoised in
    ``_sign_vectors_cache`` under (n, device, dtype, 'sign_vectors').

    For n == 0 the result is a (1, 0) tensor: the single empty sign vector.
    """
    key = (n, device, dtype, 'sign_vectors')
    sign_vectors = _sign_vectors_cache.get(key)
    if sign_vectors is None:
        subsets = torch.arange(2**n, dtype=torch.long, device=device)       # (2^n,)
        bit_positions = torch.arange(n, dtype=torch.long, device=device)    # (n,)
        # (2^n, 1) >> (1, n) -> (2^n, n) integer bits in {0, 1}; works for n == 0
        bits = (subsets.unsqueeze(1) >> bit_positions.unsqueeze(0)) & 1
        # map 0 -> -1, 1 -> +1 in integer arithmetic, then cast once
        sign_vectors = (2 * bits - 1).to(dtype=dtype)
        _sign_vectors_cache[key] = sign_vectors
    return sign_vectors


# ----------------------------------------------------------------
# The core permanent computations (without row normalization)
# ----------------------------------------------------------------

def permanent_ryser(A: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """
    Compute permanents of square matrices with Ryser's formula.

        perm(A) = sum_S (-1)^(n - |S|) * prod_i sum_{j in S} A[i, j]

    S ranges over all 2^n column subsets, so time and memory grow as 2^n.

    Args:
        A: tensor of shape (..., n, n). The usual call pattern is a batch
            (n_data, n, n). An unbatched (n, n) matrix or extra leading batch
            dims are also accepted, because every reduction uses trailing axes.
        dtype: floating dtype for the whole computation
            (e.g. torch.float32 or torch.float64); A is cast to it.

    Returns:
        Tensor of shape A.shape[:-2] (e.g. (n_data,)) holding the permanents.
    """
    A = A.to(dtype)
    n = A.shape[-1]

    mask = get_subsets_mask(n, A.device, dtype)  # (n, 2^n): mask[j, s] = [j in S_s]
    sign = get_ryser_sign(n, A.device, dtype)    # (2^n,):  (-1)^(n - |S_s|)

    # row_sums[..., i, s] = sum_{j in S_s} A[..., i, j]
    row_sums = torch.matmul(A, mask)             # (..., n, 2^n)

    # Product over the row axis. It is indexed from the end, so the result is
    # right for any number of batch dims (a fixed dim=1 multiplies over the
    # subset axis when A is unbatched).
    row_products = row_sums.prod(dim=-2)         # (..., 2^n)

    # Signed sum over subsets; sign broadcasts along the trailing axis.
    return (row_products * sign).sum(dim=-1)


def permanent_glynn(A: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """
    Compute permanents of square matrices with Glynn's formula.

        perm(A) = 2^-(n-1) * sum_{delta in {+-1}^n, delta_0 = +1}
                      (prod_k delta_k) * prod_j sum_i delta_i * A[i, j]

    Each term is unchanged when delta is replaced by -delta (both products
    pick up the same factor (-1)^n). Summing only the half with delta_0 = +1
    is therefore exact, and it halves the work compared to summing all 2^n
    vectors and dividing by 2^n.

    Args:
        A: tensor of shape (..., n, n). The usual call pattern is a batch
            (n_data, n, n). An unbatched (n, n) matrix or extra leading batch
            dims are also accepted (matmul broadcasts the sign table).
        dtype: floating dtype for the whole computation
            (e.g. torch.float32 or torch.float64); A is cast to it.

    Returns:
        Tensor of shape A.shape[:-2] (e.g. (n_data,)) holding the permanents.
    """
    A = A.to(dtype)
    n = A.shape[-1]

    deltas = get_sign_vectors(n, A.device, dtype)  # (2^n, n); odd rows have delta_0 = +1
    if n > 0:
        deltas = deltas[1::2]                      # (2^(n-1), n)
        normalization = 2.0 ** (n - 1)
    else:
        # 0x0 matrix: a single empty sign vector, permanent 1.
        normalization = 1.0

    delta_prod = deltas.prod(dim=1)                # (m,)

    # col_sums[..., k, j] = sum_i deltas[k, i] * A[..., i, j]
    col_sums = torch.matmul(deltas, A)             # (..., m, n)

    # product over columns => (..., m)
    col_products = col_sums.prod(dim=-1)

    # signed sum over sign vectors, then normalise
    return (col_products * delta_prod).sum(dim=-1) / normalization


# ----------------------------------------------------------------
# LOG-SPACE VERSIONS WITH ROW NORMALIZATION
# ----------------------------------------------------------------

def permanent_ryser_log(A: torch.Tensor,
                        dtype: torch.dtype = torch.float32,
                        eps: float = 1e-40) -> torch.Tensor:
    """
    Compute log(permanent(A)) for a batch of nonnegative matrices with
    Ryser's formula, using row normalisation to mitigate overflow/underflow.

    Each row is divided by its maximum, so
        log perm(A) = log perm(A_norm) + sum_i log(max_i),
    and the exponential-size sum only ever sees entries in [0, 1].

    Args:
        A: (n_data, n, n) nonnegative matrices.
        dtype: torch.float32 or torch.float64.
        eps: small constant added before the log, so a zero permanent of the
            normalised matrix gives a finite value instead of -inf. The
            default 1e-40 is below float32's smallest normal number, so in
            float32 it is a subnormal.

    Returns:
        (n_data,) tensor of log-permanents; -inf for matrices with an
        all-zero row.
    """
    A = A.to(dtype)

    # 1) row maxima => (n_data, n, 1)
    row_scales = A.amax(dim=2, keepdim=True)
    row_scales_squeezed = row_scales.squeeze(2)             # (n_data, n)

    # 2) a row whose max is 0 is all zero (entries assumed nonnegative) => perm = 0
    zero_rows_mask = (row_scales_squeezed == 0).any(dim=1)  # (n_data,)

    # 3) normalise. Divide by a *safe* scale (1 where the row max is not
    #    positive). Dividing by the raw 0 inside torch.where still evaluates 0/0
    #    in the discarded branch, and its backward pass makes A's gradient NaN.
    #    The forward values are the same as a plain where(..., A / scale, 0).
    has_scale = row_scales > 0
    safe_scales = torch.where(has_scale, row_scales, torch.ones_like(row_scales))
    normalized = torch.where(has_scale, A / safe_scales, torch.zeros_like(A))

    # 4) permanent of the normalised matrices
    perm_norm = permanent_ryser(normalized, dtype=dtype)

    # 5) Ryser's alternating sum can round slightly below zero when the true
    #    permanent is ~0. Clamp to 0 so the log is finite instead of NaN.
    log_perm_norm = torch.log(perm_norm.clamp_min(0) + eps)

    # 6) undo the row scaling in log space
    sum_log_scales = row_scales_squeezed.clamp_min(eps).log().sum(dim=1)
    log_perm = log_perm_norm + sum_log_scales

    # 7) if any row is zero => permanent = 0 => log = -inf
    return torch.where(zero_rows_mask,
                       torch.full_like(log_perm, float('-inf')),
                       log_perm)


def permanent_glynn_log(A: torch.Tensor,
                        dtype: torch.dtype = torch.float32,
                        eps: float = 1e-40) -> torch.Tensor:
    """
    Compute log(permanent(A)) for a batch of nonnegative matrices with
    Glynn's formula, using row normalisation to mitigate overflow/underflow.

    Each row is divided by its maximum, so
        log perm(A) = log perm(A_norm) + sum_i log(max_i),
    and the exponential-size sum only ever sees entries in [0, 1].

    Args:
        A: (n_data, n, n) nonnegative matrices.
        dtype: torch.float32 or torch.float64.
        eps: small constant added before the log, so a zero permanent of the
            normalised matrix gives a finite value instead of -inf. The
            default 1e-40 is below float32's smallest normal number, so in
            float32 it is a subnormal.

    Returns:
        (n_data,) tensor of log-permanents; -inf for matrices with an
        all-zero row.
    """
    A = A.to(dtype)

    # 1) row maxima => (n_data, n, 1)
    row_scales = A.amax(dim=2, keepdim=True)
    row_scales_squeezed = row_scales.squeeze(2)             # (n_data, n)

    # 2) a row whose max is 0 is all zero (entries assumed nonnegative) => perm = 0
    zero_rows_mask = (row_scales_squeezed == 0).any(dim=1)  # (n_data,)

    # 3) normalise. Divide by a *safe* scale (1 where the row max is not
    #    positive). Dividing by the raw 0 inside torch.where still evaluates 0/0
    #    in the discarded branch, and its backward pass makes A's gradient NaN.
    #    The forward values are the same as a plain where(..., A / scale, 0).
    has_scale = row_scales > 0
    safe_scales = torch.where(has_scale, row_scales, torch.ones_like(row_scales))
    normalized = torch.where(has_scale, A / safe_scales, torch.zeros_like(A))

    # 4) permanent of the normalised matrices
    perm_norm = permanent_glynn(normalized, dtype=dtype)

    # 5) Glynn's signed sum can round slightly below zero when the true
    #    permanent is ~0. Clamp to 0 so the log is finite instead of NaN.
    log_perm_norm = torch.log(perm_norm.clamp_min(0) + eps)

    # 6) undo the row scaling in log space => (n_data,)
    sum_log_scales = row_scales_squeezed.clamp_min(eps).log().sum(dim=1)
    log_perm = log_perm_norm + sum_log_scales

    # 7) zero row => -inf
    return torch.where(zero_rows_mask,
                       torch.full_like(log_perm, float('-inf')),
                       log_perm)

def glynn(x_in, x_out, n, sigma):
    """
    Log-permanent of the Gaussian kernel matrix between two point sets.

    Builds K[i, j] = exp(-||x_in[i] - x_out[j]||^2 / (2 * sigma^2)) and
    returns log(perm(K)), computed by permanent_glynn_log on a batch of one
    matrix.

    Args:
        x_in: 2-D tensor of points, one point per row.
        x_out: 2-D tensor of points with the same feature width as x_in.
            It presumably has the same number of rows as x_in, since a
            permanent needs a square matrix. TODO confirm with callers.
        n: unused; kept so the positional signature (and score_glynn's vmap
            over every argument) stays unchanged.
        sigma: kernel bandwidth. It is indexed as sigma[None, None], so a
            0-d tensor is expected.

    Returns:
        Tensor of shape (1,) holding log(perm(K)).
    """
    # Compute squared distances directly. The gradient of ||v|| is 0/0 at v = 0,
    # so norm(...)**2 gave NaN gradients whenever x_in[i] == x_out[j].
    diff = x_in[:, None, :] - x_out[None, :, :]
    sq_dist = (diff * diff).sum(dim=-1)
    log_kernel = -(sq_dist / (2 * sigma[None, None] ** 2))

    # Shift each row by its log-space maximum before exponentiating. Without
    # the shift, a row whose entries are all very negative underflows to an
    # all-zero row, which permanent_glynn_log reports as -inf. The permanent
    # is linear in each row, so adding sum(row_shift) back restores the exact
    # log-permanent.
    row_shift = log_kernel.amax(dim=1, keepdim=True)  # (rows, 1)
    kernel = torch.exp(log_kernel - row_shift)
    return permanent_glynn_log(kernel[None, :, :]) + row_shift.sum()

# Per-sample score function built with torch.func transforms.
# jacrev(glynn, argnums=1) is the Jacobian of glynn's (1,)-shaped log-permanent
# with respect to x_out. vmap (default in_dims=0) maps it over a leading batch
# dimension of every positional argument, so each argument, including n,
# must carry that batch dimension.
score_glynn = vmap(jacrev(glynn, argnums=1))