import matplotlib.pyplot as plt
import numpy as np
import torch

def closest_divisor_leq(n: int, k: int) -> int:
    """
    Return the largest divisor of ``n`` that does not exceed ``k``.

    Candidates are scanned downward starting from ``min(k, n)``. If none of
    them divides ``n`` (e.g. when ``min(k, n) < 1``), ``1`` is returned.
    """
    upper = min(k, n)
    candidates = (d for d in range(upper, 0, -1) if n % d == 0)
    return next(candidates, 1)

# Module-level constant. Its consumer is not visible in this chunk.
# NOTE(review): confirm where alpha is used and document what it controls.
alpha: float = 0.5

def majority(x):
    """
    Majority function over n bits.

    Returns 1 when ``x.sum()`` is strictly greater than half of
    ``x.shape[0]``, and 0 otherwise (ties give 0). For a 1-D {0, 1} vector
    this means "strictly more than half of the bits are set".
    """
    threshold = x.shape[0] / 2
    ones = x.sum()
    return 1 if ones > threshold else 0

def parity(x):
    """
    Parity bit of ``x``.

    Returns 1 if the number of 1s in x is odd (``x.sum() % 2 != 0``), and 0
    otherwise.
    """
    is_odd = (x.sum() % 2) != 0
    return 1 if is_odd else 0

def create_junta_parity(n, k=3, seed=None):
    """
    Creates a k-junta parity function on n bits.

    The resulting function depends on only k of its n input coordinates. Its
    output is the parity (sum mod 2) of those k bits.

    Args:
        n: Total number of bits.
        k: Number of relevant coordinates (default 3). If k > n, all n
            coordinates are selected, because the selection is a prefix of a
            random permutation of range(n).
        seed: Random seed for reproducibility. NOTE: when given, this seeds
            torch's *global* RNG via ``torch.manual_seed``. That also affects
            any later torch random draws in the process.

    Returns:
        A tuple ``(junta, relevant_coords)``:
          - ``junta(x, rel_coords=None)`` returns the parity of ``x`` at
            ``rel_coords`` as a Python int (0 or 1). When ``rel_coords`` is
            omitted, the coordinates selected here are used.
          - ``relevant_coords`` is the list of selected coordinate indices.
    """
    # Seed for reproducibility if provided. The global-RNG side effect is
    # kept for backward compatibility with callers that may rely on it.
    if seed is not None:
        torch.manual_seed(seed)

    # Randomly select (up to) k distinct coordinates that will be relevant.
    relevant_coords = torch.randperm(n)[:k].tolist()

    def junta(x, rel_coords=None):
        """Computes the parity of the relevant bits of ``x``.

        ``x`` is either a 1-D tensor of bits or a tensor whose leading
        singleton (batch) dimension is dropped before indexing. The parity
        computation assumes the {0, 1} bit representation.
        """
        if rel_coords is None:
            rel_coords = relevant_coords

        # Only drop a leading dimension on multi-dimensional inputs. Squeezing
        # a 1-D input of length 1 (n == 1) would yield a 0-d tensor that
        # cannot be indexed.
        if x.dim() > 1:
            x = x.squeeze(0)

        # as_tensor accepts a list or an existing tensor without a copy warning.
        index = torch.as_tensor(rel_coords, dtype=torch.long, device=x.device)

        # Parity for {0,1} inputs: sum of the selected bits modulo 2.
        return int(x[index].sum() % 2)

    return junta, relevant_coords