"""
Synthetic Experiments for Knowledge Distillation

This module implements synthetic experiments to validate the effectiveness of different
similarity measures in knowledge distillation. The experiments compare various geometric
alignment methods on synthetic vector spaces to understand their properties and behavior.

The experiments include:
- CKA (Centered Kernel Alignment) minimization
- Linear/Procrustes alignment
- Kernel Frobenius norm minimization
- Shape similarity optimization

Key Features:
- Low inner product vector generation
- Luby's algorithm for maximal independent sets
- Multiple similarity measure comparisons
- Experiment tracking with wandb

Author: Feature Distillation Research Team
License: Apache 2.0
"""

import wandb
import torch
from tqdm import tqdm
import math
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import similarity_measures as sim
import similarity

# Seed PyTorch's global RNG so repeated runs draw the same random numbers
torch.manual_seed(69)

# Device configuration: use the GPU when one is visible, otherwise fall back
# to the CPU so that the module can still be imported and run without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def generate_low_inner_product_vectors(n, epsilon, device='cpu', seed=None):
    """
    Sample unit-norm random sign vectors and report their pairwise overlaps.

    Each vector has entries of +1/sqrt(n) or -1/sqrt(n) chosen uniformly at
    random, so every row has Euclidean norm 1. The function does not enforce
    the epsilon bound: it only measures the largest off-diagonal inner product
    and prints whether all of them fall strictly below epsilon.

    Args:
        n (int): Dimension of the vectors
        epsilon (float): Target bound on the absolute inner product between
            any two distinct vectors
        device (str): Device to create tensors on ('cpu' or 'cuda')
        seed (int, optional): When given, reseeds PyTorch's global RNG before
            sampling

    Returns:
        torch.Tensor: Matrix of shape (k, n), one vector per row

    Note:
        The number of vectors is k = int(exp((epsilon^2 * n) / 4))
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Vector count grows exponentially with epsilon^2 * n
    num_vectors = int(math.exp((epsilon**2 * n) / 4))

    # Draw {0, 1} bits, map them to {-1, +1}, then scale rows to unit length
    bits = torch.randint(0, 2, (num_vectors, n), device=device, dtype=torch.float32)
    vectors = (bits * 2 - 1) / math.sqrt(n)

    # Pairwise inner products; self-products are excluded from the check
    gram = vectors @ vectors.T
    gram.fill_diagonal_(0)
    abs_gram = gram.abs()

    largest_overlap = abs_gram.max().item()
    all_below = (abs_gram < epsilon).all().item()

    print(f"Generated {num_vectors} vectors in {n} dimensions.")
    print(f"Maximum off-diagonal inner product: {largest_overlap:.4f}")
    print("All inner products < ε:", all_below)

    return vectors


def luby_mis_from_dense(adj_dense, max_iters=1000, device=device):
    """
    Compute an independent set of a graph with a Luby-style randomized algorithm.

    In every round each remaining node draws a random priority. A node joins
    the set when its priority is strictly greater than the priority of every
    remaining neighbor; the node and all of its neighbors are then removed
    from further consideration. Priorities are compared edge by edge (a true
    "maximum over neighbors" test), which is what guarantees that two adjacent
    nodes can never be selected in the same round.

    The adjacency matrix is treated as undirected: an edge exists between i
    and j when either adj_dense[i, j] or adj_dense[j, i] is non-zero.
    Self-loops (non-zero diagonal entries) are ignored, since a node cannot
    conflict with itself.

    Args:
        adj_dense (torch.Tensor): Dense n x n adjacency matrix; non-zero
            entries are edges, zeros are non-edges
        max_iters (int): Maximum number of rounds. If the budget is exhausted
            the returned set is still independent but may not be maximal
        device (str or torch.device): Device on which to run

    Returns:
        torch.Tensor: 1-D tensor with the indices of the selected nodes

    References:
        Luby, M. (1986). A simple parallel algorithm for the maximal independent set problem.
        SIAM journal on computing, 15(4), 1036-1053.
    """
    adj_dense = adj_dense.to(device)
    n = adj_dense.size(0)

    # Directed edge list (row -> col) of the symmetrized graph, self-loops dropped
    has_edge = adj_dense != 0
    has_edge = has_edge | has_edge.T
    rows, cols = torch.nonzero(has_edge, as_tuple=True)
    off_diagonal = rows != cols
    rows, cols = rows[off_diagonal], cols[off_diagonal]

    # Initialize sets
    in_set = torch.zeros(n, dtype=torch.bool, device=device)
    remaining = torch.ones(n, dtype=torch.bool, device=device)

    for _ in range(max_iters):
        if not remaining.any():
            break

        # Fresh random priorities each round
        priorities = torch.rand(n, device=device)

        # An edge blocks its source node when the neighbor is still in play
        # and its priority is at least as high. Using >= means exact ties
        # block both endpoints, so adjacent nodes are never selected together.
        blocking = remaining[cols] & (priorities[cols] >= priorities[rows])
        blocked = torch.zeros(n, dtype=torch.bool, device=device)
        blocked[rows[blocking]] = True

        # Select remaining nodes that beat all of their remaining neighbors
        selected = remaining & ~blocked
        in_set |= selected

        # Remove the selected nodes together with all of their neighbors
        to_remove = selected.clone()
        to_remove[cols[selected[rows]]] = True
        remaining &= ~to_remove

    return torch.nonzero(in_set, as_tuple=False).squeeze(1)


def get_adjacency_matrix(vec, eps):
    """
    Build a binary "nearly orthogonal" indicator matrix from a set of vectors.

    Rows are normalized to unit length first, so the test is on the cosine of
    the angle between vectors and does not depend on their scale. Entry (i, j)
    of the result is 1.0 when the absolute inner product of the normalized
    vectors i and j is strictly below eps, and 0.0 otherwise. The diagonal is
    zeroed before thresholding, so diagonal entries are 1.0 whenever eps > 0.
    Callers that need the conflict graph (pairs that are NOT nearly
    orthogonal) can use ``1 - result``.

    The computation runs without autograd tracking and on the same device as
    ``vec``; ``vec`` itself is not modified.

    Args:
        vec (torch.Tensor): Matrix of vectors (n_vectors, n_features)
        eps (float): Threshold on the absolute normalized inner product

    Returns:
        torch.Tensor: float32 matrix of 0.0/1.0 values with shape
            (n_vectors, n_vectors)
    """
    # The thresholded output carries no gradient, so skip graph construction
    with torch.no_grad():
        # Normalize vectors; the clamp keeps all-zero rows at zero instead of NaN
        norms = vec.norm(dim=1, keepdim=True).clamp_min(torch.finfo(vec.dtype).tiny)
        norm_vec = vec / norms

        # Compute inner product matrix of the normalized vectors
        ip_matrix = norm_vec @ norm_vec.T
        ip_matrix.fill_diagonal_(0)  # Zero out diagonal (self-similarity)

        # Create binary indicator matrix
        return (ip_matrix.abs() < eps).float()


def copy_and_clone(stud_vec, teacher_vecs=None):
    """
    Create a trainable copy of the student vectors paired with teacher vectors.

    Args:
        stud_vec (torch.Tensor): Original student vectors; left untouched
        teacher_vecs (torch.Tensor, optional): Teacher vectors to pair with
            the student copy. Must have the same first dimension as
            ``stud_vec``. When omitted, a module-level ``vecs`` is used if one
            has been defined.

    Returns:
        tuple: (cloned_vectors, dataset) where cloned_vectors is a detached
            copy that requires gradients and dataset is a TensorDataset
            yielding (teacher_row, student_row) pairs

    Raises:
        NameError: If ``teacher_vecs`` is omitted and no module-level ``vecs``
            exists
    """
    if teacher_vecs is None:
        # Legacy call style: fall back to a global `vecs` when one is defined
        teacher_vecs = globals().get("vecs")
        if teacher_vecs is None:
            raise NameError(
                "copy_and_clone() needs teacher vectors: pass `teacher_vecs` "
                "or define a module-level `vecs`"
            )

    # Detach so the copy is an independent leaf tensor that can be optimized
    stud_vec_2 = stud_vec.clone().detach()
    stud_vec_2.requires_grad_(True)
    dataset = TensorDataset(teacher_vecs, stud_vec_2)
    return stud_vec_2, dataset


def get_kernel_frobenius(x_vec, y_vec):
    """
    Return the Frobenius distance between the Gram matrices of two vector sets.

    Each input is multiplied by its own transpose to form a linear-kernel
    (Gram) matrix; the result is the Frobenius norm of the difference of the
    two matrices.

    Args:
        x_vec (torch.Tensor): First set of vectors, one per row
        y_vec (torch.Tensor): Second set of vectors, one per row

    Returns:
        torch.Tensor: Scalar tensor holding the Frobenius norm
    """
    gram_x = torch.matmul(x_vec, x_vec.T)
    gram_y = torch.matmul(y_vec, y_vec.T)
    gram_gap = gram_x - gram_y
    return torch.norm(gram_gap, p="fro")


def run_cka_minimization_experiment(stud_vecs, vecs, epochs=10, batch_size=256,
                                    cka=None, lin=None):
    """
    Run CKA minimization experiment.

    A detached, trainable copy of the student vectors is optimized with AdamW
    to minimize ``1 - CKA`` against the teacher vectors, batch by batch. The
    linear (shape) measure, the kernel Frobenius distance and the size of an
    independent set of epsilon-orthogonal student vectors (epsilon = 0.2) are
    logged to wandb at every step as diagnostics.

    Args:
        stud_vecs (torch.Tensor): Student vectors; not modified
        vecs (torch.Tensor): Teacher vectors, same number of rows as stud_vecs
        epochs (int): Number of training epochs
        batch_size (int): Batch size for training
        cka (callable, optional): CKA measure called as ``cka(student, teacher)``.
            Defaults to ``sim.CKA(biased=False)``
        lin (callable, optional): Linear/shape measure called the same way.
            Defaults to ``sim.LinearMeasure(approx=True)``

    Returns:
        dict: ``final_cka`` (the last value of ``1 - CKA``) and ``final_shape``
            from the last step; both are NaN when no step was executed
    """
    # Build the measures locally so the function does not rely on globals
    if cka is None:
        cka = sim.CKA(biased=False)
    if lin is None:
        lin = sim.LinearMeasure(approx=True)

    # Independent trainable copy of the students, paired row-wise with teachers
    stud_vec_2_cka = stud_vecs.clone().detach().requires_grad_(True)
    dataset_cka = TensorDataset(vecs, stud_vec_2_cka)
    dl = DataLoader(dataset_cka, batch_size=batch_size, shuffle=False)
    optim = torch.optim.AdamW([stud_vec_2_cka])

    results = {"final_cka": float("nan"), "final_shape": float("nan")}

    wandb.init(project="", entity="", name="cka_minim_projected_stud")
    try:
        for epoch in range(epochs):
            for a, b in tqdm(dl, desc=f"Epoch {epoch+1}"):
                # a: teacher batch, b: student batch
                loss = 1 - (cka(b.unsqueeze(1), a.unsqueeze(1)))
                shape = lin(b.unsqueeze(1), a.unsqueeze(1))
                k_frob = get_kernel_frobenius(a, b)

                # Diagnostic only: count epsilon-orthogonal student vectors.
                # No gradients are needed, so avoid building an autograd graph
                # over the full pairwise matrix.
                with torch.no_grad():
                    adj = get_adjacency_matrix(stud_vec_2_cka, 0.2)
                    s = luby_mis_from_dense(1 - adj, device=adj.device)

                results = {"final_cka": loss.item(), "final_shape": shape.item()}

                # Log plain Python numbers rather than tensors
                wandb.log({
                    "shape": results["final_shape"],
                    "cka": results["final_cka"],
                    "eps-orth-vectors": len(s),
                    "kernel_frob": k_frob.item()
                })

                # Optimization step
                loss.backward()
                optim.step()
                optim.zero_grad()
    finally:
        # Close the wandb run even when a step raises
        wandb.finish()

    return results


def run_shape_minimization_experiment(stud_vecs, vecs, epochs=10, batch_size=256,
                                      cka=None, lin=None):
    """
    Run shape similarity minimization experiment.

    A detached, trainable copy of the student vectors is optimized with AdamW
    to minimize the linear (shape) measure against the teacher vectors, batch
    by batch. ``1 - CKA``, the kernel Frobenius distance and the size of an
    independent set of epsilon-orthogonal student vectors (epsilon = 0.2) are
    logged to wandb at every step as diagnostics.

    Args:
        stud_vecs (torch.Tensor): Student vectors; not modified
        vecs (torch.Tensor): Teacher vectors, same number of rows as stud_vecs
        epochs (int): Number of training epochs
        batch_size (int): Batch size for training
        cka (callable, optional): CKA measure called as ``cka(student, teacher)``.
            Defaults to ``sim.CKA(biased=False)``
        lin (callable, optional): Linear/shape measure called the same way.
            Defaults to ``sim.LinearMeasure(approx=True)``

    Returns:
        dict: ``final_shape`` and ``final_cka`` (the last value of ``1 - CKA``)
            from the last step; both are NaN when no step was executed
    """
    # Build the measures locally so the function does not rely on globals
    if cka is None:
        cka = sim.CKA(biased=False)
    if lin is None:
        lin = sim.LinearMeasure(approx=True)

    # Independent trainable copy of the students, paired row-wise with teachers
    stud_vec_2_shape = stud_vecs.clone().detach().requires_grad_(True)
    dataset_shape = TensorDataset(vecs, stud_vec_2_shape)
    dl = DataLoader(dataset_shape, batch_size=batch_size, shuffle=False)
    optim = torch.optim.AdamW([stud_vec_2_shape])

    results = {"final_shape": float("nan"), "final_cka": float("nan")}

    wandb.init(project="", entity="", name="shape_minim_projected_stud")
    try:
        for epoch in range(epochs):
            for a, b in tqdm(dl, desc=f"Epoch {epoch+1}"):
                # a: teacher batch, b: student batch
                ck = 1 - (cka(b.unsqueeze(1), a.unsqueeze(1)))
                loss = lin(b.unsqueeze(1), a.unsqueeze(1))
                k_frob = get_kernel_frobenius(a, b)

                # Diagnostic only: count epsilon-orthogonal student vectors.
                # No gradients are needed, so avoid building an autograd graph
                # over the full pairwise matrix.
                with torch.no_grad():
                    adj = get_adjacency_matrix(stud_vec_2_shape, 0.2)
                    s = luby_mis_from_dense(1 - adj, device=adj.device)

                results = {"final_shape": loss.item(), "final_cka": ck.item()}

                # Log plain Python numbers rather than tensors
                wandb.log({
                    "shape": results["final_shape"],
                    "cka": results["final_cka"],
                    "eps-orth-vectors": len(s),
                    "kernel_frob": k_frob.item()
                })

                # Optimization step
                loss.backward()
                optim.step()
                optim.zero_grad()
    finally:
        # Close the wandb run even when a step raises
        wandb.finish()

    return results


def run_kernel_frobenius_experiment(stud_vecs, vecs, epochs=10, batch_size=256,
                                    cka=None, lin=None):
    """
    Run kernel Frobenius norm minimization experiment.

    A detached, trainable copy of the student vectors is optimized with AdamW
    to minimize the Frobenius norm of the difference between the teacher and
    student Gram matrices, batch by batch. ``1 - CKA``, the linear (shape)
    measure and the size of an independent set of epsilon-orthogonal student
    vectors (epsilon = 0.2) are logged to wandb at every step as diagnostics.

    Args:
        stud_vecs (torch.Tensor): Student vectors; not modified
        vecs (torch.Tensor): Teacher vectors, same number of rows as stud_vecs
        epochs (int): Number of training epochs
        batch_size (int): Batch size for training
        cka (callable, optional): CKA measure called as ``cka(student, teacher)``.
            Defaults to ``sim.CKA(biased=False)``
        lin (callable, optional): Linear/shape measure called the same way.
            Defaults to ``sim.LinearMeasure(approx=True)``

    Returns:
        dict: ``final_kernel_frob`` and ``final_shape`` from the last step;
            both are NaN when no step was executed
    """
    # Build the measures locally so the function does not rely on globals
    if cka is None:
        cka = sim.CKA(biased=False)
    if lin is None:
        lin = sim.LinearMeasure(approx=True)

    # Independent trainable copy of the students, paired row-wise with teachers
    stud_vec_2_kfrob = stud_vecs.clone().detach().requires_grad_(True)
    dataset_kfrob = TensorDataset(vecs, stud_vec_2_kfrob)
    dl = DataLoader(dataset_kfrob, batch_size=batch_size, shuffle=False)
    optim = torch.optim.AdamW([stud_vec_2_kfrob])

    results = {"final_kernel_frob": float("nan"), "final_shape": float("nan")}

    wandb.init(project="", entity="", name="kernel_frobenius_diff_projected_stud")
    try:
        for epoch in range(epochs):
            for a, b in tqdm(dl, desc=f"Epoch {epoch+1}"):
                # a: teacher batch, b: student batch
                ck = 1 - (cka(b.unsqueeze(1), a.unsqueeze(1)))
                shape = lin(b.unsqueeze(1), a.unsqueeze(1))
                loss = get_kernel_frobenius(a, b)

                # Diagnostic only: count epsilon-orthogonal student vectors.
                # No gradients are needed, so avoid building an autograd graph
                # over the full pairwise matrix.
                with torch.no_grad():
                    adj = get_adjacency_matrix(stud_vec_2_kfrob, 0.2)
                    s = luby_mis_from_dense(1 - adj, device=adj.device)

                results = {
                    "final_kernel_frob": loss.item(),
                    "final_shape": shape.item(),
                }

                # Log metrics
                wandb.log({
                    "shape": results["final_shape"],
                    "cka": ck.item(),
                    "eps-orth-vectors": len(s),
                    "kernel_frob": results["final_kernel_frob"]
                })

                # Optimization step
                loss.backward()
                optim.step()
                optim.zero_grad()
    finally:
        # Close the wandb run even when a step raises
        wandb.finish()

    return results


def run_linear_projection_experiment(stud_vecs, vecs, epochs=10, batch_size=256,
                                     cka=None, lin=None):
    """
    Run linear projection experiment.

    A linear map from teacher space to student space and a detached, trainable
    copy of the student vectors are optimized jointly with AdamW to minimize
    ``||teacher @ W - student||`` batch by batch. ``1 - CKA``, the linear
    (shape) measure, the kernel Frobenius distance and the size of an
    independent set of epsilon-orthogonal student vectors (epsilon = 0.2) are
    logged to wandb at every step as diagnostics.

    Args:
        stud_vecs (torch.Tensor): Student vectors of shape (N, d_student);
            not modified
        vecs (torch.Tensor): Teacher vectors of shape (N, d_teacher)
        epochs (int): Number of training epochs
        batch_size (int): Batch size for training
        cka (callable, optional): CKA measure called as ``cka(student, teacher)``.
            Defaults to ``sim.CKA(biased=False)``
        lin (callable, optional): Linear/shape measure called the same way.
            Defaults to ``sim.LinearMeasure(approx=True)``

    Returns:
        dict: ``final_loss`` and ``final_shape`` from the last step; both are
            NaN when no step was executed
    """
    # Build the measures locally so the function does not rely on globals
    if cka is None:
        cka = sim.CKA(biased=False)
    if lin is None:
        lin = sim.LinearMeasure(approx=True)

    # Independent trainable copy of the students, paired row-wise with teachers
    stud_vec_2_linear = stud_vecs.clone().detach().requires_grad_(True)
    dataset_linear = TensorDataset(vecs, stud_vec_2_linear)
    dl = DataLoader(dataset_linear, batch_size=batch_size, shuffle=False)

    # Learnable (d_teacher x d_student) map, sized and placed to match the
    # data so the matmul below works for any input dimensions and device
    linear_transform = torch.rand(
        vecs.shape[1], stud_vecs.shape[1], device=vecs.device, dtype=vecs.dtype
    ).requires_grad_(True)
    optim = torch.optim.AdamW([stud_vec_2_linear, linear_transform])

    results = {"final_loss": float("nan"), "final_shape": float("nan")}

    wandb.init(project="", entity="", name="linear_projection_diff_projected_stud")
    try:
        for epoch in range(epochs):
            for a, b in tqdm(dl, desc=f"Epoch {epoch+1}"):
                # a: teacher batch, b: student batch
                ck = 1 - (cka(b.unsqueeze(1), a.unsqueeze(1)))
                shape = lin(b.unsqueeze(1), a.unsqueeze(1))
                kernel_frob = get_kernel_frobenius(a, b)

                # Compute linear projection loss
                loss = torch.norm(a @ linear_transform - b)

                # Diagnostic only: count epsilon-orthogonal student vectors.
                # No gradients are needed, so avoid building an autograd graph
                # over the full pairwise matrix.
                with torch.no_grad():
                    adj = get_adjacency_matrix(stud_vec_2_linear, 0.2)
                    s = luby_mis_from_dense(1 - adj, device=adj.device)

                results = {"final_loss": loss.item(), "final_shape": shape.item()}

                # Log plain Python numbers rather than tensors
                wandb.log({
                    "shape": results["final_shape"],
                    "cka": ck.item(),
                    "eps-orth-vectors": len(s),
                    "kernel_frob": kernel_frob.item(),
                    "loss": results["final_loss"]
                })

                # Optimization step
                loss.backward()
                optim.step()
                optim.zero_grad()
    finally:
        # Close the wandb run even when a step raises
        wandb.finish()

    return results


def main():
    """
    Main function to run all synthetic experiments.

    Generates teacher vectors with low pairwise inner products, derives
    student vectors from them through a random projection followed by row
    normalization, and then runs the CKA, shape, kernel Frobenius and linear
    projection experiments in turn, printing the metrics each one returns.
    """
    # Helpers in this module may resolve `vecs`, `cka` and `lin` as
    # module-level names instead of receiving them as arguments, so publish
    # them at module scope rather than keeping them local to main().
    global vecs, cka, lin

    # Experiment dimensions
    teacher_dim = 1000
    student_dim = 500
    epsilon = 0.2

    # Login to wandb for experiment tracking
    wandb.login()

    # Generate synthetic data on the module's compute device so that every
    # tensor used by the experiments lives on the same device
    print("Generating low inner product vectors...")
    vecs = generate_low_inner_product_vectors(teacher_dim, epsilon, device=device)

    # Create random projection from teacher space to student space
    proj = torch.randn(teacher_dim, student_dim, device=device)

    # Create projected student vectors with unit-norm rows
    stud_vecs = vecs @ proj
    stud_vecs = stud_vecs / stud_vecs.norm(dim=1, keepdim=True)

    # Initialize similarity measures
    lin = sim.LinearMeasure(approx=True)
    cka = sim.CKA(biased=False)

    # Run experiments
    print("Running CKA minimization experiment...")
    cka_results = run_cka_minimization_experiment(stud_vecs, vecs)

    print("Running shape minimization experiment...")
    shape_results = run_shape_minimization_experiment(stud_vecs, vecs)

    print("Running kernel Frobenius experiment...")
    kfrob_results = run_kernel_frobenius_experiment(stud_vecs, vecs)

    print("Running linear projection experiment...")
    linear_results = run_linear_projection_experiment(stud_vecs, vecs)

    # Print summary
    print("\nExperiment Summary:")
    print(f"CKA minimization: {cka_results}")
    print(f"Shape minimization: {shape_results}")
    print(f"Kernel Frobenius: {kfrob_results}")
    print(f"Linear projection: {linear_results}")


if __name__ == "__main__":
    main() 