# src/target_functions.py
"""
Defines candidate target functions used to generate ground-truth labels.
Each function takes a PyTorch tensor of binary inputs and a device string.
Most return a long tensor of 0/1 labels; the exceptions are `noisy_poly`
(arbitrary integers) and `dyck2` / `patternmatch*` (float 0/1 values).
"""
import torch
import torch.nn.functional as F
import hashlib
from typing import Callable, Dict
import numpy as np

# --- Existing Functions ---

def parity_first_half(v: torch.Tensor, device: str) -> torch.Tensor:
    """Return the parity of the first half of each row.

    The leading ``v.shape[1] // 2`` columns are summed per row and the sum is
    reduced modulo 2. ``device`` is accepted for signature uniformity only
    and is not used.
    """
    half_width = v.shape[1] // 2
    leading_columns = v[:, :half_width]
    return torch.remainder(leading_columns.sum(dim=1), 2).long()

def sum_greater_than_5(v: torch.Tensor, device: str) -> torch.Tensor:
    """Return 1 for rows whose element sum is strictly greater than 5, else 0.

    ``device`` is unused; the result lives on the same device as ``v``.
    """
    # Class balance is not guaranteed: it depends on the row length and on
    # how the input bits are distributed.
    row_totals = v.sum(dim=1)
    return torch.gt(row_totals, 5).long()

def automata_parity(v: torch.Tensor, device: str) -> torch.Tensor:
    """Parity of each row after one step of an elementary cellular automaton.

    Every row is zero-padded by one cell on each side, each 3-cell
    neighbourhood is encoded as ``4*left + 2*centre + right`` and the next
    state of the centre cell is read from a lookup table (the table used here
    is Wolfram rule 30). The result is the parity of the updated cells.

    The helper tensors are created on ``device``.
    """
    neighbourhood_weights = torch.tensor([4, 2, 1], device=device, dtype=torch.float)
    next_state = torch.tensor([0, 1, 1, 1, 1, 0, 0, 0], device=device, dtype=torch.float)

    padded = F.pad(v.float(), (1, 1), 'constant', 0)
    # One window of width 3 per original cell: shape (N, L, 3).
    windows = padded.unfold(1, 3, 1)
    codes = windows.matmul(neighbourhood_weights).long()
    updated_cells = next_state[codes]
    return (updated_cells.sum(1) % 2).long()

def sha256_parity(v: torch.Tensor, device: str) -> torch.Tensor:
    """Parity of the number of set bits in the SHA-256 hash of each row.

    Each row is rendered as a string of its integer-cast elements (e.g.
    ``"0110"``), hashed with SHA-256, and the count of 1-bits in the digest is
    reduced modulo 2. The result is a long tensor created on ``device``.
    """
    def _digest_bit_parity(bits) -> int:
        text = ''.join(str(bit) for bit in bits)
        digest = hashlib.sha256(text.encode()).digest()
        # The '0b' prefix produced by bin() contains no '1', so counting on
        # the whole string is safe.
        return bin(int.from_bytes(digest, 'big')).count('1') % 2

    parities = [_digest_bit_parity(row.int().tolist()) for row in v]
    return torch.tensor(parities, device=device, dtype=torch.long)

def is_palindrome(v: torch.Tensor, device: str) -> torch.Tensor:
    """Return 1 for rows that read the same forwards and backwards, else 0.

    ``device`` is unused; the result lives on the same device as ``v``.
    """
    # Heavily imbalanced on random inputs: palindromes become rarer as the
    # row length grows.
    mirrored = v.flip(dims=[1])
    return (v == mirrored).all(dim=1).long()

# --- New Simple Functions (Likely 50/50 on random inputs) ---

def parity_all(v: torch.Tensor, device: str) -> torch.Tensor:
    """Return the parity of every coordinate in each row.

    ``device`` is unused; the result lives on the same device as ``v``.
    """
    # Parity tends to be balanced (about 50/50) on uniformly random bits.
    row_totals = v.sum(dim=1)
    return torch.remainder(row_totals, 2).long()

def parity_even_indices(v: torch.Tensor, device: str) -> torch.Tensor:
    """Return the parity of the coordinates at even positions (0, 2, 4, ...).

    ``device`` is unused; the result lives on the same device as ``v``.
    """
    # Parity over a fixed subset of random bits is typically balanced.
    even_columns = v[:, 0::2]
    return torch.remainder(even_columns.sum(dim=1), 2).long()

def parity_odd_indices(v: torch.Tensor, device: str) -> torch.Tensor:
    """Return the parity of the coordinates at odd positions (1, 3, 5, ...).

    Rows with fewer than two columns have no odd position, so the result is
    an all-zero long tensor created on ``device`` in that case.
    """
    # Parity over a fixed subset of random bits is typically balanced.
    num_rows, num_cols = v.shape[0], v.shape[1]
    if num_cols >= 2:
        odd_columns = v[:, 1::2]
        return torch.remainder(odd_columns.sum(dim=1), 2).long()
    return torch.zeros(num_rows, device=device, dtype=torch.long)

def first_bit(v: torch.Tensor, device: str) -> torch.Tensor:
    """Return the first column of ``v`` as a long tensor.

    ``device`` is unused; the result lives on the same device as ``v``.
    """
    # Trivially balanced when the input bits themselves are 50/50.
    leading = torch.select(v, 1, 0)
    return leading.long()

def last_bit(v: torch.Tensor, device: str) -> torch.Tensor:
    """Return the last column of ``v`` as a long tensor.

    ``device`` is unused; the result lives on the same device as ``v``.
    """
    # Trivially balanced when the input bits themselves are 50/50.
    trailing = torch.select(v, 1, -1)
    return trailing.long()

def middle_bit(v: torch.Tensor, device: str) -> torch.Tensor:
    """Return the column at index ``v.shape[1] // 2`` as a long tensor.

    For an even row length this is the first column of the second half.
    ``device`` is unused; the result lives on the same device as ``v``.
    """
    # Trivially balanced when the input bits themselves are 50/50.
    centre = v.shape[1] // 2
    return torch.select(v, 1, centre).long()

def xor_first_last(v: torch.Tensor, device: str) -> torch.Tensor:
    """Return 1 where the first and last columns differ, else 0.

    For binary inputs "differ" is the same as XOR, i.e. ``(a + b) % 2``.
    ``device`` is unused; the result lives on the same device as ``v``.
    """
    # Balanced when the first and last bits are independent fair coin flips.
    head = v[:, 0]
    tail = v[:, -1]
    return torch.ne(head, tail).long()

def first_equals_last(v: torch.Tensor, device: str) -> torch.Tensor:
    """Return 1 where the first and last columns are equal, else 0.

    ``device`` is unused; the result lives on the same device as ``v``.
    """
    # Logical complement of XOR-ing the two end bits, so equally balanced.
    head = v[:, 0]
    tail = v[:, -1]
    return torch.eq(head, tail).long()



# --- New tough functions ---
def noisy_poly(x: torch.Tensor, device: str) -> torch.Tensor:
    """Evaluate ``s**2 + 2*s + 1`` plus Gaussian noise, where ``s`` is the row sum.

    The noise is ``10 * N(0, 1)`` drawn from the global torch RNG, so the
    output is only reproducible if the caller seeds that RNG. The float
    result is cast to long, so values are integers but not restricted to
    0/1. ``device`` is unused.
    """
    row_sum = x.sum(dim=1).float()
    polynomial = row_sum**2 + 2*row_sum + 1
    noise = 10 * torch.randn_like(row_sum)
    return (polynomial + noise).long()

def is_sorted_binary_1d(v: torch.Tensor, device: str) -> torch.Tensor:
    """Return 1 if a binary sequence is non-decreasing (0s before 1s), else 0.

    The check runs along the last dimension, so the function accepts either
    a single 1-D sequence or a batch of sequences:

    * 1-D input of shape ``(L,)`` -> 0-dim long tensor (scalar label).
    * 2-D input of shape ``(N, L)`` -> long tensor of shape ``(N,)`` with one
      label per row.

    Sequences of length 0 or 1 are trivially sorted. The result is moved to
    ``device``.
    """
    # Compare every element with its predecessor along the sequence axis.
    # Slicing with ``...`` keeps any leading batch dimension intact; a plain
    # ``v[1:]`` would compare whole rows of a batch against each other.
    non_decreasing = v[..., 1:] >= v[..., :-1]
    return non_decreasing.all(dim=-1).long().to(device)

def dyck2(v: torch.Tensor, device: str) -> torch.Tensor:
    """Check whether each row encodes a balanced two-bracket (Dyck-2) word.

    Consecutive bit pairs are decoded as ``00 -> (``, ``01 -> )``,
    ``10 -> [`` and ``11 -> ]``; a row is valid when the decoded brackets are
    properly nested and fully closed.

    Args:
        v: ``(N, L)`` tensor of 0/1 values with even ``L``. Any dtype that
            casts cleanly to integers (long, float, bool) is accepted.
        device: Device on which the result tensor is created.

    Returns:
        ``(N,)`` float tensor holding 1.0 for valid rows and 0.0 otherwise.

    Raises:
        ValueError: If the row width is odd, since the trailing bit cannot be
            paired into a bracket.
        KeyError: If a row contains a value other than 0 or 1.
    """
    if v.shape[-1] % 2:
        raise ValueError(
            f"dyck2 needs an even number of bits per row, got {v.shape[-1]}"
        )

    pair_to_symbol = {(0, 0): "(", (0, 1): ")", (1, 0): "[", (1, 1): "]"}
    closer_to_opener = {")": "(", "]": "["}

    def is_valid(bits) -> int:
        stack = []
        for pair in zip(bits[0::2], bits[1::2]):
            symbol = pair_to_symbol[pair]
            expected_opener = closer_to_opener.get(symbol)
            if expected_opener is None:
                # Opening bracket: remember it until its closer shows up.
                stack.append(symbol)
            elif not stack or stack.pop() != expected_opener:
                return 0
        # Leftover openers mean the word was never closed.
        return int(not stack)

    # Cast to integers first so float/bool inputs produce the plain 0/1
    # values used as lookup keys (0.0 or True would otherwise miss the table).
    rows = v.long().tolist()
    return torch.tensor([is_valid(row) for row in rows], device=device, dtype=torch.float)

# --- Random-subset parity functions ---
def _parity_of_seeded_subset(v: torch.Tensor, k: int, seed: int = 42) -> torch.Tensor:
    """Parity over the first ``k`` entries of a seeded column permutation.

    A private generator is used instead of ``torch.manual_seed`` so that
    computing labels never resets the process-wide RNG (which would make any
    later sampling in the caller repeat itself). The permutation is always
    drawn on the CPU so the chosen columns do not depend on the active
    device. If ``k`` exceeds the number of columns, all columns are used.
    """
    generator = torch.Generator(device='cpu')
    generator.manual_seed(seed)
    permutation = torch.randperm(v.shape[1], generator=generator, device='cpu')
    columns = permutation[:k].to(v.device)
    return (v[:, columns].sum(dim=1) % 2).long()


def parity_rand_10(v: torch.Tensor, device: str) -> torch.Tensor:
    """Parity of 10 columns chosen by a fixed (seed 42) permutation."""
    return _parity_of_seeded_subset(v, 10)


def parity_rand_5(v: torch.Tensor, device: str) -> torch.Tensor:
    """Parity of 5 columns chosen by a fixed (seed 42) permutation."""
    return _parity_of_seeded_subset(v, 5)


def parity_rand_1(v: torch.Tensor, device: str) -> torch.Tensor:
    """Value of 1 column chosen by a fixed (seed 42) permutation."""
    return _parity_of_seeded_subset(v, 1)


def parity_rand_3(v: torch.Tensor, device: str) -> torch.Tensor:
    """Parity of 3 columns chosen by a fixed (seed 42) permutation."""
    return _parity_of_seeded_subset(v, 3)


def parity_rand_4(v: torch.Tensor, device: str) -> torch.Tensor:
    """Parity of 4 columns chosen by a fixed (seed 42) permutation."""
    return _parity_of_seeded_subset(v, 4)

def patternmatch1(v: torch.Tensor, device: str) -> torch.Tensor:
    """Flag rows containing ``0 1 0 1 0 1 0 1`` as a contiguous run.

    Returns a float tensor of shape ``(N,)`` with 1.0 where the pattern
    occurs at least once in the row and 0.0 otherwise. Rows shorter than the
    pattern can never match. The pattern tensor is created on ``device``
    with the dtype of ``v``.
    """
    target = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1], dtype=v.dtype, device=device)
    num_rows, width = v.shape
    span = target.size(0)

    if width < span:
        # No window of the required size exists, so nothing can match.
        return torch.zeros(num_rows, dtype=torch.bool, device=device).float()

    # All length-`span` sliding windows at once: shape (N, width - span + 1, span).
    windows = v.unfold(1, span, 1)
    window_hits = (windows == target).all(dim=2)
    return window_hits.any(dim=1).float()

def patternmatch2(v: torch.Tensor, device: str) -> torch.Tensor:
    """Flag rows containing ``0 0 1 1 1 1 1 1`` as a contiguous run.

    Returns a float tensor of shape ``(N,)`` with 1.0 where the pattern
    occurs at least once in the row and 0.0 otherwise. Rows shorter than the
    pattern can never match. The pattern tensor is created on ``device``
    with the dtype of ``v``.
    """
    target = torch.tensor([0, 0, 1, 1, 1, 1, 1, 1], dtype=v.dtype, device=device)
    num_rows, width = v.shape
    span = target.size(0)

    if width < span:
        # No window of the required size exists, so nothing can match.
        return torch.zeros(num_rows, dtype=torch.bool, device=device).float()

    # All length-`span` sliding windows at once: shape (N, width - span + 1, span).
    windows = v.unfold(1, span, 1)
    window_hits = (windows == target).all(dim=2)
    return window_hits.any(dim=1).float()

# --

# --- Updated Dictionary ---
# Registry mapping a target name to its implementation. Every key equals the
# name of the function it points at; insertion order is significant only to
# callers that iterate the dict.
TARGET_FUNCTIONS: Dict[str, Callable[[torch.Tensor, str], torch.Tensor]] = {
    # Original targets
    'parity_first_half': parity_first_half,
    'sum_greater_than_5': sum_greater_than_5,  # class balance not guaranteed
    'automata_parity': automata_parity,
    'sha256_parity': sha256_parity,  # per-row Python loop, slower to compute
    'is_palindrome': is_palindrome,  # heavily imbalanced on random inputs

    # Simple targets (likely 50/50 on random inputs)
    'parity_all': parity_all,
    'parity_even_indices': parity_even_indices,
    'parity_odd_indices': parity_odd_indices,
    'first_bit': first_bit,
    'last_bit': last_bit,
    'middle_bit': middle_bit,
    'xor_first_last': xor_first_last,
    'first_equals_last': first_equals_last,

    # Harder / structured targets
    'noisy_poly': noisy_poly,  # integer-valued and noisy, not a 0/1 label
    'is_sorted_binary_1d': is_sorted_binary_1d,
    'dyck2': dyck2,  # returns float 0/1

    # Parity over a fixed pseudo-random subset of columns
    'parity_rand_10': parity_rand_10,
    'parity_rand_5': parity_rand_5,
    'parity_rand_4': parity_rand_4,
    'parity_rand_3': parity_rand_3,
    'parity_rand_1': parity_rand_1,

    # Contiguous pattern detection (return float 0/1)
    'patternmatch1': patternmatch1,
    'patternmatch2': patternmatch2,
}