import torch
import torch.nn.functional as F
import numpy as np
import time
import wandb
import sys

# Import VSA models
from HLB import HLB
from FHRR import FHRR
from MAP_I import MAP_I
from VTB import VTB
# Add others as needed

def _random_bipolar(n_rows, dim, device):
    """Return an (n_rows, dim) float tensor whose entries are -1.0 or +1.0."""
    return (torch.rand((n_rows, dim), device=device) > 0.5).float() * 2 - 1


def _random_phasors(n_rows, dim, device):
    """Return (n_rows, dim) unit phasors stored as interleaved (cos, sin) pairs."""
    # One angle in [-pi, pi) per complex component; dim // 2 components per row.
    theta = (torch.rand((n_rows, dim // 2), device=device) * 2 * np.pi) - np.pi
    # stack(..., dim=-1) gives (n_rows, dim // 2, 2); flattening the last two
    # axes yields the layout [cos_1, sin_1, cos_2, sin_2, ...].
    return torch.stack([torch.cos(theta), torch.sin(theta)], dim=-1).reshape(n_rows, dim)


def generate_item_memory_batch(model, N, device):
    """
    Builds a pool of N random key/value pairs together with their bound composites.

    Models whose class name is "HLB", "FHRR", "MAP_I" or "VTB" are sampled and
    bound with vectorised tensor operations instead of calling the model N
    times. Any other model falls back to per-item ``model.generate_vector()``
    and ``model.bind()`` calls.

    NOTE(review): the fast paths re-implement sampling and binding locally
    rather than calling the model; they are assumed to be equivalent to the
    corresponding model classes -- confirm against HLB.py / FHRR.py /
    MAP_I.py / VTB.py.

    Args:
        model: VSA model instance. ``model.d`` is read as the vector dimension;
            ``model.d_root`` is additionally read on the VTB path.
        N: number of key/value pairs to generate.
        device: torch device used by the fast paths. The fallback path does not
            use it; tensors there come from the model's own methods.

    Returns:
        (composites, keys, values). On the fast paths each is an (N, d) tensor.

    Raises:
        ValueError: if an FHRR model has an odd ``d``, or a VTB model has
            ``d != d_root ** 2``.
    """
    d = model.d
    kind = model.__class__.__name__

    if kind in ("HLB", "MAP_I"):
        # Bipolar {-1, +1} vectors; binding is the elementwise product.
        keys = _random_bipolar(N, d, device)
        values = _random_bipolar(N, d, device)
        return keys * values, keys, values

    if kind == "FHRR":
        if d % 2 != 0:
            raise ValueError(
                f"FHRR batch generation needs an even dimension (got d={d})"
            )
        keys = _random_phasors(N, d, device)
        values = _random_phasors(N, d, device)

        # Binding is complex multiplication on the (re, im) pairs:
        # (a + bi)(c + di) = (ac - bd) + (ad + bc)i
        k_pairs = keys.view(N, -1, 2)
        v_pairs = values.view(N, -1, 2)
        k_re, k_im = k_pairs[..., 0], k_pairs[..., 1]
        v_re, v_im = v_pairs[..., 0], v_pairs[..., 1]
        re = k_re * v_re - k_im * v_im
        im = k_re * v_im + k_im * v_re

        composites = torch.stack([re, im], dim=-1).reshape(N, d)
        return composites, keys, values

    if kind == "VTB":
        d_root = model.d_root
        if d_root * d_root != d:
            raise ValueError(
                f"VTB batch generation needs d == d_root ** 2 (got d={d}, d_root={d_root})"
            )
        # Gaussian vectors scaled to unit L2 norm.
        keys = F.normalize(torch.randn(N, d, device=device), p=2, dim=1)
        values = F.normalize(torch.randn(N, d, device=device), p=2, dim=1)

        # Each vector is viewed as a (d_root, d_root) matrix and the composite
        # is the batched product value @ key.
        k_mat = keys.view(N, d_root, d_root)
        v_mat = values.view(N, d_root, d_root)
        composites = torch.matmul(v_mat, k_mat).view(N, d)
        return composites, keys, values

    # Generic fallback: slow per-item loop through the model's own interface.
    keys = torch.stack([model.generate_vector() for _ in range(N)])
    values = torch.stack([model.generate_vector() for _ in range(N)])
    composites = torch.stack(
        [model.bind(key, value) for key, value in zip(keys, values)]
    )
    return composites, keys, values

def run_trial_batch(model_class, model_instance, k, N, pool_data, device):
    """
    Runs a single bundle-and-retrieve trial against a pre-generated pool.

    k distinct pool items are drawn at random, their composites are bundled
    into one memory vector, each selected key is unbound from that memory, and
    the result is matched against every pool value by cosine similarity. A
    retrieval counts as correct when the best-scoring pool value is the one
    that was bound to the query key.

    Args:
        model_class: unused here; kept so existing callers keep working.
        model_instance: object providing ``bundle(tensor)`` and
            ``unbind(query, memory)``. ``bundle`` receives a (k, D) tensor, not
            a list -- NOTE(review): confirm every model's bundle accepts that.
        k: number of items bundled per trial; must satisfy 1 <= k <= N.
        N: number of items in the pool.
        pool_data: (composites, keys, values) as produced by
            ``generate_item_memory_batch``.
        device: torch device on which the random selection is drawn.

    Returns:
        Fraction of the k queries retrieved correctly, as a Python float.

    Raises:
        ValueError: if k is outside [1, N]. Previously this surfaced as an
            IndexError (k > N) or an empty-stack error (k == 0).
    """
    if not 1 <= k <= N:
        raise ValueError(f"k must satisfy 1 <= k <= N (got k={k}, N={N})")

    pool_composites, pool_keys, pool_values = pool_data

    # Draw k distinct pool indices directly on the target device.
    indices = torch.randperm(N, device=device)[:k]

    # Superpose the selected composites into a single memory vector.
    memory = model_instance.bundle(pool_composites[indices])

    # Unbind each selected key from the memory. This loops per item because
    # unbind is model-specific; k is small compared to the N-wide cleanup below.
    queries = pool_keys[indices]
    recovered = torch.stack(
        [model_instance.unbind(query, memory) for query in queries]
    )

    # Cleanup: cosine similarity of every recovered vector against every pool
    # value. Both sides are L2-normalised so the matmul yields cosines.
    recovered = F.normalize(recovered, p=2, dim=1)
    vocab = F.normalize(pool_values, p=2, dim=1)
    scores = torch.matmul(recovered, vocab.t())  # (k, N)

    # Both best_indices and indices index into the global pool (0..N-1).
    best_indices = torch.argmax(scores, dim=1)
    correct = (best_indices == indices).sum().item()

    return correct / k

def adjust_dim_for_model(model_class, d):
    """Return ``d`` rounded to a dimension the given model class can use.

    A class named "VTB" gets ``ceil(sqrt(d)) ** 2``, i.e. ``d`` rounded up to a
    perfect square. Every other class gets ``d`` back unchanged.
    """
    if model_class.__name__ != "VTB":
        return d

    side = int(np.ceil(np.sqrt(d)))
    return side * side

def check_accuracy(model_class, d, k, N, device, n_trials=20):
    """
    Returns the mean retrieval accuracy over several trials that share one pool.

    The requested dimension is first passed through ``adjust_dim_for_model``,
    a model is built at that dimension, one pool of N items is generated, and
    ``n_trials`` independent trials are run against that same pool.

    Args:
        model_class: VSA model class, called as ``model_class(d, device=device)``.
        d: requested dimension (may be adjusted for the model).
        k: number of items bundled per trial.
        N: pool size.
        device: torch device passed to the model and the pool generator.
        n_trials: number of trials to average over (default 20).

    Returns:
        Mean accuracy in [0, 1]. If constructing the model raises, the error is
        printed and 0.0 is returned.

    Raises:
        ValueError: if ``n_trials`` is less than 1.
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be >= 1 (got {n_trials})")

    d = adjust_dim_for_model(model_class, d)

    # Construction failures are reported and scored as 0.0 so a dimension
    # search can carry on; errors after construction are not caught.
    try:
        model = model_class(d, device=device)
    except Exception as e:
        print(f"Model Init Error (Dim {d}): {e}")
        return 0.0

    # The pool is generated once and reused by every trial below.
    pool_data = generate_item_memory_batch(model, N, device)

    total = sum(
        run_trial_batch(model_class, model, k, N, pool_data, device)
        for _ in range(n_trials)
    )
    return total / n_trials

def find_min_dimension(model_class, k, N, start_dim, device="cuda:0"):
    """
    Estimates the minimum dimension that reaches >= 99% retrieval accuracy.

    Strategy:
      1. Coarse search in steps of 100, downwards if ``start_dim`` already
         passes and upwards otherwise, until a failing and a passing
         dimension bracket the boundary.
      2. Fine search in steps of 10 from the lower end of the bracket.
      3. Each fine candidate is confirmed with a heavier check (5 accuracy
         runs on the candidate and 5 on the dimension one fine step below).
         The candidate is accepted only if it passes and the one below fails;
         otherwise the fine search resumes up or down.

    Accuracy measurements are random, so the result is an estimate.

    Args:
        model_class: VSA model class forwarded to ``check_accuracy``.
        k: number of items bundled per trial.
        N: pool size.
        start_dim: dimension at which the search starts.
        device: torch device string forwarded to ``check_accuracy``.

    Returns:
        The selected dimension after ``adjust_dim_for_model``. If no dimension
        up to 100000 passes, a warning is printed and the adjusted cap (or the
        last passing coarse bound, from the fine search) is returned.
    """
    COARSE_STEP = 100
    FINE_STEP = 10
    THRESHOLD = 0.99
    MIN_DIM = 10
    MAX_DIM = 100000  # Safety cap shared by the coarse and the fine search
    CONFIRM_RUNS = 5

    current_dim = start_dim
    print(f"\n--- Searching Min Dim for N={N} (Start: {current_dim}) ---")

    # 1. Check Start Point
    acc = check_accuracy(model_class, current_dim, k, N, device)

    lower_bound = 0
    upper_bound = 0

    if acc >= THRESHOLD:
        # Started at or above the boundary: walk down coarsely.
        upper_bound = current_dim
        while True:
            next_d = max(MIN_DIM, current_dim - COARSE_STEP)
            if next_d == current_dim:
                # Hit the floor while still passing. Report the dimension that
                # is actually used by the model, like every other return.
                return adjust_dim_for_model(model_class, current_dim)

            acc = check_accuracy(model_class, next_d, k, N, device)
            if acc < THRESHOLD:
                # Found the bracket (lower_bound, upper_bound].
                lower_bound = next_d
                upper_bound = current_dim
                break
            current_dim = next_d
            upper_bound = current_dim

    else:
        # Started below the boundary: walk up coarsely.
        lower_bound = current_dim
        while True:
            next_d = current_dim + COARSE_STEP
            if next_d > MAX_DIM:
                # Without this cap the loop never ends when accuracy cannot
                # reach the threshold (e.g. check_accuracy keeps returning 0.0
                # because the model fails to initialise).
                print(f"Warning: no dimension <= {MAX_DIM} reached the accuracy threshold for N={N}")
                return adjust_dim_for_model(model_class, MAX_DIM)

            acc = check_accuracy(model_class, next_d, k, N, device)
            if acc >= THRESHOLD:
                # lower_bound stays at the last failing dimension.
                upper_bound = next_d
                break
            current_dim = next_d
            lower_bound = current_dim

    # 2. Fine Search
    # The boundary lies in (lower_bound, upper_bound]; step up from lower_bound.
    current_dim = lower_bound + FINE_STEP

    # The loop alternates between light single checks (SEARCH_UP / SEARCH_DOWN)
    # and the heavy confirmation (CONFIRM).
    phase = "SEARCH_UP"

    while True:
        if phase == "SEARCH_UP":
            acc = check_accuracy(model_class, current_dim, k, N, device)
            if acc >= THRESHOLD:
                phase = "CONFIRM"
            else:
                current_dim += FINE_STEP
                if current_dim > MAX_DIM:
                    return adjust_dim_for_model(model_class, upper_bound)

        elif phase == "SEARCH_DOWN":
            acc = check_accuracy(model_class, current_dim, k, N, device)
            if acc < THRESHOLD:
                # First failing dimension: the candidate is one step above it.
                current_dim += FINE_STEP
                phase = "CONFIRM"
            else:
                current_dim -= FINE_STEP
                if current_dim < MIN_DIM:
                    return adjust_dim_for_model(model_class, MIN_DIM)

        elif phase == "CONFIRM":
            # Heavy check of the candidate.
            mean_cand = np.mean([
                check_accuracy(model_class, current_dim, k, N, device)
                for _ in range(CONFIRM_RUNS)
            ])

            # Heavy check of the dimension one fine step below.
            below_dim = max(MIN_DIM, current_dim - FINE_STEP)
            mean_below = np.mean([
                check_accuracy(model_class, below_dim, k, N, device)
                for _ in range(CONFIRM_RUNS)
            ])

            if mean_cand >= THRESHOLD:
                if mean_below < THRESHOLD:
                    # Candidate passes and the one below fails: boundary found.
                    return adjust_dim_for_model(model_class, current_dim)
                # The dimension below also passes: candidate is too high.
                current_dim = below_dim
                phase = "SEARCH_DOWN"
            else:
                # Candidate failed the heavy check: too low.
                current_dim += FINE_STEP
                phase = "SEARCH_UP"


if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--wandb", action="store_true", help="Enable WandB logging")
    cli.add_argument("--project", type=str, default="VSA_Item_Memory", help="WandB Project Name")
    args = cli.parse_args()

    # Models to benchmark, in the order they are run.
    MODEL_CLASSES = [HLB, FHRR, VTB, MAP_I]

    K_BUNDLES = 10
    DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Pool sizes to test: 1..9 times each power of ten from 10 to 100000
    # (10, 20, ..., 90, 100, 200, ..., 900000), followed by 1000000.
    ITEM_SIZES = [
        mantissa * 10 ** exponent
        for exponent in range(1, 6)
        for mantissa in range(1, 10)
    ] + [1000000]

    for model_cls in MODEL_CLASSES:
        model_name = model_cls.__name__
        start_dim = 10

        print("\n========================================")
        print(f"Testing Model: {model_name}")
        print(f"Bundles (k): {K_BUNDLES}")
        print(f"Device: {DEVICE}")

        # Each model is logged as its own WandB run when logging is enabled.
        run = None
        if args.wandb:
            run_config = {"model": model_name, "k": K_BUNDLES}
            run = wandb.init(
                project=args.project,
                name=f"{model_name}_ItemMemSearch",
                config=run_config,
                reinit=True,
            )

        min_dims = {}

        for n_items in ITEM_SIZES:
            found = find_min_dimension(model_cls, K_BUNDLES, n_items, start_dim, DEVICE)
            print(f"==> Min Dim for N={n_items}: {found}")
            min_dims[n_items] = found

            if args.wandb:
                wandb.log({"N": n_items, "Min_Dimension": found})

            # Begin the next search a little below this result so the boundary
            # is re-checked for the next pool size.
            start_dim = max(10, found - 20)

        print(f"\nFINAL RESULTS ({model_name}):")
        for n_items, dim in min_dims.items():
            print(f"N={n_items} -> Dim={dim}")

        if args.wandb and run:
            print("Syncing WandB...")
            run.finish()
