import torch
import math
from torch.distributions.categorical import Categorical


def sample_worst_case_dist(
    batch_size: int, 
    num_points: int, 
    var_scale: float = 2, 
    normalize: bool = False
) -> torch.Tensor:
    """Generates challenging distributions for nearest neighbor search using binary tree structure.
    
    Creates data where nearest neighbor search is difficult by organizing points in a binary tree
    hierarchy and scaling variance with depth. This structure forces algorithms to explore multiple
    tree levels for accurate search.

    Every point index is a node of a binary tree built over ``range(num_points)`` with the root at
    the middle index. Each node draws a uniform offset whose width shrinks geometrically with its
    depth (``var_scale ** (max_depth - depth)``; the root offset is centered on zero). A point's
    final value is the signed sum of the offsets of all nodes on its root-to-node path, with sign
    -1 for nodes reached through a left branch and +1 for the root and right branches.
    
    Args:
        batch_size: Number of data batches to generate
        num_points: Number of points per batch
        var_scale: Variance scaling factor per tree level
        normalize: Whether to normalize data to [-1, 1] range (divides by the largest
            attainable absolute value)
        
    Returns:
        Generated data tensor with shape (batch_size, num_points); the order of the points is
        shuffled independently within every batch row.
    """
    # Initialize tree structure tracking
    # -1 marks an index that has not been placed in the tree yet.
    node_depth = torch.zeros(num_points, dtype=torch.int32) - 1
    # node_path[i, j] holds the sign of node j if j lies on the root-to-i path, else 0.
    node_path = torch.zeros(num_points, num_points)

    def build_tree_structure(
        current_depth: int,
        position: int,
        parent_pos: int,
        sign: int,
        low: int,
        high: int
    ) -> None:
        """Recursively constructs binary tree paths and depths for each node.

        ``position`` is placed at ``current_depth``, inherits its parent's path row and adds its
        own ``sign``; its subtree covers the index range ``[low, high]``.
        """
        if position < 0 or position >= num_points:
            return
            
        node_depth[position] = current_depth
        if parent_pos != -1:
            node_path[position] = node_path[parent_pos].clone()
        node_path[position, position] = sign

        # Calculate child positions and recurse
        # Children sit at the midpoints of [low, position-1] (rounded down) and
        # [position+1, high] (rounded up). The `== -1` check skips indices that were already
        # placed, e.g. when a sub-range is empty and the formula points back at this node.
        left_child = (low + position - 1) // 2
        right_child = (high + position + 1) // 2
        
        if left_child >= 0 and node_depth[left_child] == -1:
            build_tree_structure(current_depth+1, left_child, position, -1, low, position-1)
        if right_child < num_points and node_depth[right_child] == -1:
            build_tree_structure(current_depth+1, right_child, position, 1, position+1, high)

    # Build tree from middle point
    build_tree_structure(0, (num_points - 1) // 2, -1, 1, 0, num_points - 1)
    
    # Generate scaled variances based on node depth
    # NOTE(review): assumes no node is deeper than floor(log2(num_points)), i.e. depths fit in
    # depth_scales (length max_depth); not verified for every num_points.
    max_depth = int(math.log2(num_points)) + 1
    # depth_scales[d] = var_scale ** (max_depth - d): the root (depth 0) gets the widest range.
    depth_scales = var_scale ** torch.arange(1, max_depth + 1).flip(0)
    point_scales = depth_scales[node_depth]
    
    # Create base data and apply tree structure
    # Per-node offsets are uniform in [0, scale).
    data = torch.rand(batch_size, num_points) * point_scales[None, :]
    data[:, (num_points - 1) // 2] -= depth_scales[0] / 2  # Center root node
    data = data @ node_path.T  # Project data through tree paths

    # Normalize if requested
    if normalize:
        # Bound on |value|: half the root range plus the full range of every deeper level.
        abs_max = depth_scales[0]/2 + depth_scales[1:].sum()
        # Affine map of [-abs_max, abs_max] onto [-1, 1] (equivalent to data / abs_max).
        data = (data + abs_max) / (2 * abs_max) * 2 - 1

    # Randomize point order within each batch
    return data.gather(1, torch.rand((batch_size, num_points)).argsort())


def sample_zipf(
    num_values: int, 
    alpha: float, 
    sample_shape: tuple
) -> torch.Tensor:
    """Draw Zipf-distributed, zero-based ranks via a categorical distribution.

    Rank ``k`` (0-based) is drawn with probability proportional to ``1 / (k + 1) ** alpha``.

    Args:
        num_values: Number of distinct ranks (support is ``0 .. num_values - 1``).
        alpha: Zipf exponent; larger values concentrate mass on small ranks.
        sample_shape: Shape of the returned sample tensor.

    Returns:
        Integer tensor of shape ``sample_shape`` holding the sampled ranks.
    """
    ranks = torch.arange(1, num_values + 1, dtype=torch.float32)
    # Unnormalised weights; Categorical normalises them internally.
    weights = ranks.reciprocal().pow(alpha)
    zipf = Categorical(probs=weights)
    return zipf.sample(sample_shape)


def sample_saved_data(
    dataset: torch.Tensor,
    num_points: int,
    batch_size: int,
    dim: int,
    n_resample: int = 1,
    device: str = "cuda"
) -> tuple:
    """Samples nearest neighbor problems from a pre-existing dataset.

    ``batch_size // n_resample`` groups of ``num_points`` rows are drawn uniformly (with
    replacement) from ``dataset`` and tiled ``n_resample`` times; every tiled copy gets its own
    query row, also drawn from ``dataset``. As in ``generate_data``, only
    ``(batch_size // n_resample) * n_resample`` problems are produced, so all outputs share one
    batch size even when ``batch_size`` is not a multiple of ``n_resample``.

    Args:
        dataset: Tensor whose rows are data points; each row must reshape to ``dim`` values.
        num_points: Number of candidate points per problem.
        batch_size: Requested number of problems.
        dim: Dimensionality of each point.
        n_resample: Number of times each sampled group of inputs is tiled.
        device: Output device for the returned tensors.

    Returns:
        Tuple ``(input_data, query_data, target_data, nn_ids)`` with shapes
        ``[B, num_points, dim]``, ``[B, dim]``, ``[B, dim]`` and ``[B]``, where ``nn_ids``
        minimises the mean squared distance to the query.
    """
    base_batch = batch_size // n_resample
    # Tiling can only produce a multiple of n_resample rows; keep queries aligned with that.
    batch_size = base_batch * n_resample

    data_ids = torch.randint(0, len(dataset), (base_batch * num_points,))
    input_data = dataset[data_ids].reshape(base_batch, num_points, dim)
    input_data = input_data.repeat(n_resample, 1, 1)

    query_data = dataset[torch.randint(0, len(dataset), (batch_size,))]
    distances = (input_data - query_data.unsqueeze(1)).pow(2).mean(-1)
    _, nn_ids = distances.min(dim=1)
    target_data = input_data[torch.arange(batch_size, device=input_data.device), nn_ids]

    return (
        input_data.to(device),
        query_data.to(device),
        target_data.to(device),
        nn_ids.to(device)
    )


def generate_data(
    num_points: int,
    batch_size: int,
    dim: int,
    data_range: float,
    query_range: float,
    worst_case_var: float,
    dist_type: str = "uniform",
    nn_is_query: bool = True,
    n_resample: int = 1,
    device: str = "cuda",
    sort_inputs: bool = False,
    random_query: bool = False,
    query_corr: float = 1.0,
    flip_ratio: float = 0.5,
    n_vals: int = None,
    zipf_alpha: float = 1.25,
    saved_data: torch.Tensor = None
) -> tuple:
    """Generates synthetic datasets for nearest neighbor search experiments.
    
    Supports multiple distribution types and query generation strategies.
    
    Args:
        num_points: Number of data points per batch
        batch_size: Requested number of batches. Only ``(batch_size // n_resample) * n_resample``
            batches are produced, because the base inputs are tiled ``n_resample`` times.
        dim: Data dimensionality (ignored by "fixed"/"zipf", which always use dim 1)
        data_range: Half-width of the value range for the "uniform" distribution
        query_range: Half-width of the value range for independent random queries
        worst_case_var: Variance scaling for the "worst_case" distribution
        dist_type: Data distribution type (uniform, normal, binary, worst_case, fixed, zipf)
        nn_is_query: If True (and dist_type != "zipf"), the query is a copy of a random input
            point, which is therefore its own nearest neighbor
        n_resample: Number of times each base batch of inputs is tiled; every tiled copy gets
            its own query
        device: Output device for tensors
        sort_inputs: Whether to order the points of each batch by their first coordinate.
            Points are kept intact and ``nn_ids`` is remapped to index the sorted points.
        random_query: When queries are not data points, draw them uniformly from
            [-query_range, query_range] instead of perturbing an input point
        query_corr: Standard deviation of the Gaussian noise added to an input point to form the
            query (perturbed queries for every distribution except "normal")
        flip_ratio: Weight of the orthogonal component for "normal" distribution queries
        n_vals: Number of distinct values for "fixed"/"zipf" (required there, must be
            >= num_points)
        zipf_alpha: Skew parameter for Zipf distribution
        saved_data: Pre-existing dataset to sample from (takes precedence over dist_type)
    
    Returns:
        Tuple containing:
        - input_data: Batch of data points (shape: [batch_size, num_points, dim])
        - query_data: Batch of query points (shape: [batch_size, dim])
        - target_data: Nearest neighbor targets (shape: [batch_size, dim])
        - nn_ids: Indices of nearest neighbors into input_data (shape: [batch_size])

    Raises:
        ValueError: Unknown ``dist_type``, or ``n_vals`` missing / smaller than ``num_points``
            for "fixed"/"zipf".
    """
    if saved_data is not None:
        return sample_saved_data(saved_data, num_points, batch_size, dim, n_resample, device)

    if dist_type in ("fixed", "zipf"):
        # Inputs are num_points distinct values out of n_vals, so n_vals must cover them.
        if n_vals is None:
            raise ValueError(f"n_vals is required for dist_type={dist_type!r}")
        if n_vals < num_points:
            raise ValueError(
                f"n_vals ({n_vals}) must be >= num_points ({num_points}) "
                f"for dist_type={dist_type!r}"
            )

    # Generate base input data
    base_batch = batch_size // n_resample
    if dist_type == "uniform":
        input_data = torch.empty(base_batch, num_points, dim, dtype=torch.float32).uniform_(
            -data_range, data_range
        )
    elif dist_type == "normal":
        # Gaussian directions projected onto the unit sphere.
        input_data = torch.randn(base_batch, num_points, dim)
        input_data = input_data / input_data.norm(dim=-1, keepdim=True)
    elif dist_type == "binary":
        input_data = torch.randn(base_batch, num_points, dim).sign()
    elif dist_type == "worst_case":
        input_data = torch.zeros(base_batch, num_points, dim)
        for i in range(dim):
            input_data[..., i] = sample_worst_case_dist(base_batch, num_points, worst_case_var)
    elif dist_type in ("fixed", "zipf"):
        # A random permutation of range(n_vals) per row; the first num_points are distinct.
        input_data = torch.randn(base_batch, n_vals).argsort(dim=1)
        input_data = input_data.float() / n_vals * 2 - 1  # Scale to [-1, 1)
        input_data = input_data[:, :num_points, None]  # Add dummy dim
    else:
        raise ValueError(f"Unknown distribution type: {dist_type}")

    # Resample and prepare input data
    input_data = input_data.repeat(n_resample, 1, 1)
    batch_size = input_data.size(0)  # Update in case of resampling
    batch_idx = torch.arange(batch_size)

    # Generate query points
    query_is_data_point = nn_is_query and dist_type != "zipf"
    if query_is_data_point:
        nn_ids = torch.randint(0, num_points, (batch_size,))
        query_data = input_data[batch_idx, nn_ids]
    elif dist_type == "fixed":
        query_data = torch.randint(0, n_vals, (batch_size, 1)).float()
        query_data = query_data / n_vals * 2 - 1
    elif dist_type == "zipf":
        query_data = sample_zipf(n_vals, zipf_alpha, (batch_size, 1)).float()
        query_data = query_data / n_vals * 2 - 1
    elif random_query:
        query_data = torch.empty(batch_size, dim, dtype=torch.float32).uniform_(
            -query_range, query_range
        )
    else:
        # Generate queries by perturbing data points
        rand_idx = torch.randint(0, num_points, (batch_size,))
        base_points = input_data[batch_idx, rand_idx]

        if dist_type == "normal":
            # Mix the (unit-norm) point with a unit perturbation orthogonal to it.
            perturbation = torch.randn(batch_size, dim)
            projection = (perturbation * base_points).sum(-1, keepdim=True)
            perturbation -= projection * base_points
            perturbation /= perturbation.norm(dim=-1, keepdim=True)
            query_data = base_points * (1 - flip_ratio) + perturbation * flip_ratio
        else:
            # Add scaled noise for other distributions
            query_data = base_points + torch.randn_like(base_points) * query_corr

    # Find nearest neighbors if not already set
    if not query_is_data_point:
        distances = (input_data - query_data.unsqueeze(1)).pow(2).sum(-1)
        _, nn_ids = distances.min(dim=1)

    # Prepare final outputs
    target_data = input_data[batch_idx, nn_ids]
    if sort_inputs:
        # Reorder whole points (not each coordinate separately) by their first coordinate;
        # for dim == 1 this equals a plain sort. order[b, j] = old index now at slot j, and
        # its argsort is the inverse permutation (old index -> new slot) used to remap nn_ids.
        order = input_data[..., 0].argsort(dim=1)
        input_data = input_data.gather(1, order.unsqueeze(-1).expand_as(input_data))
        nn_ids = order.argsort(dim=1)[batch_idx, nn_ids]

    # Move all tensors to specified device
    return (
        input_data.to(device),
        query_data.to(device),
        target_data.to(device),
        nn_ids.to(device)
    )


def generate_mnist_data(N, B, n_resample=1, device='cuda', sort_inputs=False, images=None, feature_model=None, n_vals=50, n_image_samples=10, return_imgs=False, feats=None):
    """Generates nearest-neighbor problems over image-valued data points.

    Each problem uses N values drawn from a random permutation of ``range(n_vals)`` as inputs;
    the query value is the next entry of its own row's permutation. The nearest neighbor is
    found on these integer values. Every value is then represented by one of
    ``n_image_samples`` image samples, looked up as ``images[value, sample]`` (or
    ``feats[value, sample]`` when pre-computed features are given).

    Only ``(B // n_resample) * n_resample`` problems are produced: the base inputs are tiled
    ``n_resample`` times and every copy gets its own query.

    Args:
        N: Number of input points per problem (must be < n_vals).
        B: Requested number of problems.
        n_resample: Number of times each base row of inputs is tiled.
        device: Output device for all returned tensors.
        sort_inputs: Whether to sort input values within each problem before the neighbor search.
        images: Tensor indexed as ``images[value, sample]``; reshaped to ``(-1, 28, 84)`` before
            ``feature_model`` is applied. Required unless ``feats`` is given and
            ``return_imgs`` is False.
        feature_model: Callable mapping a batch of images to features; required unless ``feats``
            is given.
        n_vals: Number of distinct values.
        n_image_samples: Number of image samples per value.
        return_imgs: Also return the input images and the sampled image ids.
        feats: Optional pre-computed features indexed as ``feats[value, sample]``.

    Returns:
        ``(input_feats, query_feats, target_feats, nn_ids, input_vals, query_vals, target_vals)``,
        followed by ``(input_images, input_sample_ids, query_sample_ids)`` when ``return_imgs``.
    """
    # Tiling base rows only yields a multiple of n_resample rows; shrink B up front so every
    # output (inputs, queries, ids) shares the same batch size.
    B = (B // n_resample) * n_resample
    base = B // n_resample

    # One random permutation of range(n_vals) per row.
    rand_nums = torch.randn(B, n_vals).sort().indices
    input_vals = rand_nums[:base, :N].repeat(n_resample, 1)
    if sort_inputs:
        input_vals = input_vals.sort(-1).values
    query_vals = rand_nums[:, N]
    nn_ids = (input_vals - query_vals[:, None]).square().argmin(-1)
    batch_idx = torch.arange(B)
    target_vals = input_vals[batch_idx, nn_ids]
    input_sample_ids = torch.randint(0, n_image_samples, (base, N)).repeat(n_resample, 1)
    query_sample_ids = torch.randint(0, n_image_samples, (B,))

    flat_vals = input_vals.reshape(-1)
    flat_samples = input_sample_ids.reshape(-1)
    input_images = None

    if feats is not None:  # Use pre-processed feats
        query_feats = feats[query_vals, query_sample_ids]
        input_feats = feats[flat_vals, flat_samples].reshape(B, N, 1)
        target_feats = input_feats[batch_idx, nn_ids]
    elif B * N > n_vals * n_image_samples and not return_imgs:
        # Fewer unique images than lookups: featurise every image once, then index.
        image_feats = feature_model(images.view(n_vals * n_image_samples, 28, 84)).view(n_vals, n_image_samples, -1)
        query_feats = image_feats[query_vals, query_sample_ids]
        input_feats = image_feats[flat_vals, flat_samples].reshape(B, N, 1)
        target_feats = input_feats[batch_idx, nn_ids]
    else:
        query_images = images[query_vals, query_sample_ids]
        input_images = images[flat_vals, flat_samples].view(B, N, 28, 84)
        target_images = input_images[batch_idx, nn_ids]
        input_feats = feature_model(input_images.view(B * N, 28, 84)).view(B, N, 1)
        query_feats = feature_model(query_images)
        target_feats = feature_model(target_images)

    outputs = (input_feats, query_feats, target_feats, nn_ids, input_vals, query_vals, target_vals)
    if return_imgs:
        if input_images is None:
            # The feats path skipped the image lookup; fetch the input images explicitly.
            input_images = images[flat_vals, flat_samples].view(B, N, 28, 84)
        outputs = outputs + (input_images, input_sample_ids, query_sample_ids)
    return tuple(t.to(device) for t in outputs)


def generate_freq_data(B, N, n_vals, dist_type, zipf_alpha, device, shuffle_ordering, n_resample=1, use_features=False, heavy_subspace=None, n_heavy=1, query_is_stream=False):
    """Generates streams of integer items together with running frequency counts.

    Args:
        B: Requested number of streams. Only ``(B // n_resample) * n_resample`` are produced,
            because base streams are tiled ``n_resample`` times.
        N: Stream length.
        n_vals: Number of distinct item values (items are in ``[0, n_vals)``).
        dist_type: "zipf" or "uniform" item distribution.
        zipf_alpha: Zipf exponent (used when dist_type == "zipf").
        device: Device on which all tensors are built and returned.
        shuffle_ordering: Relabel item values with a random permutation per base stream.
        n_resample: Number of times each base stream is tiled.
        use_features, heavy_subspace, n_heavy: Accepted for interface compatibility; unused here.
        query_is_stream: Use the stream itself as the query sequence.

    Returns:
        ``(stream, queries, all_counts, query_counts)`` where ``all_counts[b, t, v]`` is the
        number of occurrences of ``v`` in ``stream[b, :t + 1]`` and
        ``query_counts[b, t] = all_counts[b, t, queries[b, t]]``.

    Raises:
        ValueError: Unknown ``dist_type``.
    """
    if dist_type == 'zipf':
        stream, queries = sample_zipf(n_vals, zipf_alpha, (2, B, N)).unbind(0)
    elif dist_type == 'uniform':
        stream, queries = torch.randint(0, n_vals, (2, B, N)).unbind(0)
    else:
        raise ValueError(f"Unknown distribution type: {dist_type}")

    base = B // n_resample
    stream = stream[:base].repeat(n_resample, 1).to(device)
    # Keep queries row-aligned with the (possibly shorter) tiled stream.
    queries = queries[:stream.size(0)].to(device)
    if query_is_stream:
        queries = stream

    if shuffle_ordering:
        # Per-base-stream random relabelling of item values, shared by its tiled copies.
        rand_perms = torch.randn(base, n_vals).sort(-1).indices.repeat(n_resample, 1).to(device)
        stream = rand_perms.gather(1, stream)
        queries = rand_perms.gather(1, queries)
    all_counts = torch.nn.functional.one_hot(stream, n_vals).cumsum(dim=1)
    query_counts = all_counts.gather(2, queries[..., None]).squeeze(-1)

    return stream, queries, all_counts, query_counts