import torch
import torch.nn.functional as F
import numpy as np
import time
import wandb
import sys
import argparse
import os

# Force WandB into offline mode for the whole script. This runs at import
# time, so it is already in place before wandb.init() can be called.
# NOTE(review): because of this, even with --wandb the run is recorded in
# offline mode only — confirm that is intended.
os.environ["WANDB_MODE"] = "offline"

# Memory Import
from Memory_General_Torch import Memory

def generate_item_memory_batch(N, dim, p, num_roles, device):
    """
    Draw a pool of N random items (roles + fillers) for OrthHash.

    Every entry is sampled i.i.d. from a standard normal distribution. No
    explicit orthogonalisation is applied: random high-dimensional Gaussian
    vectors are quasi-orthogonal, and a per-item QR would be too expensive
    at pool sizes around 1e5 and above.

    Returns:
        pool_roles:   tensor of shape (N, num_roles, dim)
        pool_fillers: tensor of shape (N, p, dim)
    """
    role_shape = (N, num_roles, dim)
    filler_shape = (N, p, dim)

    # Roles are drawn before fillers; under a fixed seed this order
    # determines which random values land in which tensor.
    roles = torch.randn(role_shape, device=device)
    fillers = torch.randn(filler_shape, device=device)
    return roles, fillers

def run_trial_batch(dim, k, N, p, num_roles, pool_data, device):
    """
    Run one bundling/retrieval trial against a pre-generated item pool.

    A fresh Memory is created, k distinct items are drawn at random from the
    pool and bound into it, and each bound item's roles are then scored
    against the fillers of ALL N pool items. A query is a success when the
    arg-max score lands on the item's own index.

    Args:
        dim: vector dimension (passed to Memory as its ``N``).
        k: number of items bundled into memory; must satisfy 1 <= k <= N.
        N: pool size (first dimension of both pool tensors).
        p: filler vectors per item.
        num_roles: roles per item. Not used here (the shape comes from
            ``pool_data``); kept for interface compatibility.
        pool_data: ``(pool_roles, pool_fillers)`` as returned by
            ``generate_item_memory_batch``.
        device: torch device.

    Returns:
        Fraction of the k queries that retrieved their own item.

    Raises:
        ValueError: if k is outside [1, N]. Previously k > N crashed with an
            IndexError mid-trial and k == 0 with a ZeroDivisionError.
    """
    if not 1 <= k <= N:
        raise ValueError(f"k must be in [1, N={N}], got {k}")

    pool_roles, pool_fillers = pool_data

    # Memory storage is internal to the class, so each trial gets a new one.
    # isWedge=False: presumably the outer-product (non-wedge) variant — confirm
    # against Memory_General_Torch.
    model = Memory(N=dim, p=p, isWedge=False, device=device)

    # Pick k distinct items.
    indices = torch.randperm(N, device=device)[:k]
    selected_roles = pool_roles[indices]      # (k, num_roles, dim)
    selected_fillers = pool_fillers[indices]  # (k, p, dim)

    # Bundle: Memory.bind is stateful; assumed to accumulate into the model.
    for roles, fillers in zip(selected_roles, selected_fillers):
        model.bind(roles, fillers)

    # Retrieval: score each stored role set against every filler in the pool.
    # Move the target ids to the host once instead of one .item() per query.
    target_ids = indices.tolist()
    successes = 0
    for query_role, target_idx in zip(selected_roles, target_ids):
        # presumably returns one score per pool item, shape (N,) — confirm
        scores = model.get_dual_contraction_score_batched(query_role, pool_fillers)
        if torch.argmax(scores).item() == target_idx:
            successes += 1

    return successes / k

import math

def check_accuracy(dim, k, N, p, device, n_trials=20):
    """
    Return the mean retrieval accuracy at dimension ``dim`` (pool reuse).

    One item pool of N items (``floor(sqrt(dim))`` roles each, at least one)
    is generated and shared by all ``n_trials`` trials. Each trial draws its
    own k items from that pool.

    An out-of-memory RuntimeError is reported and scored as 0.0, so the
    dimension search can continue. This applies whether it is raised while
    building the pool or while running the trials; previously only the pool
    was covered, and an OOM in a trial crashed the whole search. Any other
    RuntimeError propagates.

    Args:
        dim, k, N, p, device: experiment parameters (see run_trial_batch).
        n_trials: number of trials averaged (default 20).

    Returns:
        Mean accuracy in [0, 1] (0.0 on OOM).

    Raises:
        ValueError: if n_trials < 1.
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be >= 1, got {n_trials}")

    # Dynamic roles: exact integer square root of dim, at least one.
    num_roles = max(1, math.isqrt(dim))

    try:
        # Generate the pool ONCE per call and reuse it for every trial.
        pool_data = generate_item_memory_batch(N, dim, p, num_roles, device)
        acc_sum = sum(
            run_trial_batch(dim, k, N, p, num_roles, pool_data, device)
            for _ in range(n_trials)
        )
    except RuntimeError as e:
        if "out of memory" not in str(e):
            raise
        print(f"OOM at Dim {dim}")
        return 0.0

    avg = acc_sum / n_trials
    print(f"Dim {dim}: {avg*100:.1f}% (Roles: {num_roles})")
    return avg

# ---------------------------------------------------------------------------
# Minimum-dimension search: verify_candidate (heavy check) and
# find_min_dimension (multi-stage search) follow.
# ---------------------------------------------------------------------------


def verify_candidate(dim, k, N, p, device, n_runs=5, threshold=0.99):
    """
    Heavy check: does ``dim`` reliably reach the accuracy threshold?

    Calls ``check_accuracy`` ``n_runs`` times. Each call generates a fresh
    item pool and averages its own trials, so the default is
    5 x 20 = 100 trials. The mean accuracy across runs is then compared
    with ``threshold``.

    Args:
        dim, k, N, p, device: forwarded unchanged to ``check_accuracy``.
        n_runs: number of independent ``check_accuracy`` calls (default 5).
        threshold: minimum mean accuracy required to pass (default 0.99).

    Returns:
        True if the mean accuracy over all runs is >= ``threshold``.

    Raises:
        ValueError: if n_runs < 1 (the mean would be undefined).
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    print(f"--- Verifying Candidate {dim} (Heavy Check) ---")
    accuracies = [check_accuracy(dim, k, N, p, device) for _ in range(n_runs)]
    mean_acc = float(np.mean(accuracies))
    print(f">> Heavy Check {dim}: {mean_acc*100:.2f}%")
    # Plain bool rather than numpy.bool_.
    return mean_acc >= threshold

def _coarse_bracket(k, N, p, start_dim, device, coarse_step, floor_dim, threshold, max_dim):
    """Walk coarsely from ``start_dim`` to bracket the accuracy threshold.

    Returns ``(lower_bound, upper_bound)``. ``upper_bound`` passed a single
    ``check_accuracy`` call. ``lower_bound`` failed one, or is 0 when the
    downward walk reached ``floor_dim`` while still passing. Returns ``None``
    when the upward walk would exceed ``max_dim`` without any dimension
    passing.
    """
    current_dim = start_dim
    if check_accuracy(current_dim, k, N, p, device) >= threshold:
        print(f"-> Starting accuracy >= {threshold:.0%}, searching DOWN (Coarse)...")
        while True:
            next_d = max(floor_dim, current_dim - coarse_step)
            if next_d >= current_dim:
                # Cannot go lower; the fine search then starts from its lowest step.
                return 0, current_dim
            if check_accuracy(next_d, k, N, p, device) < threshold:
                return next_d, current_dim
            current_dim = next_d

    print(f"-> Starting accuracy < {threshold:.0%}, searching UP (Coarse)...")
    while True:
        next_d = current_dim + coarse_step
        if next_d > max_dim:
            # Without this cap the walk never ends when no dimension can pass
            # (e.g. every pool allocation OOMs and check_accuracy reports 0.0).
            return None
        if check_accuracy(next_d, k, N, p, device) >= threshold:
            return current_dim, next_d
        current_dim = next_d


def _fine_boundary(lower_bound, k, N, p, device, fine_step, threshold, max_dim):
    """Locate the step-``fine_step`` boundary above ``lower_bound``.

    A candidate is accepted when it passes the heavy check and the dimension
    one fine step below it fails the heavy check. Uses the same
    SEARCH_UP / SEARCH_DOWN / CONFIRM state machine as before.

    Returns the boundary dimension, or ``None`` if the upward search ran past
    ``max_dim``.
    """
    current_dim = lower_bound + fine_step
    phase = "SEARCH_UP"
    while True:
        if phase == "SEARCH_UP":
            if check_accuracy(current_dim, k, N, p, device) >= threshold:
                print(f"-> Candidate found (Step {fine_step}): {current_dim}")
                phase = "CONFIRM"
            else:
                current_dim += fine_step
                if current_dim > max_dim:
                    return None

        elif phase == "SEARCH_DOWN":
            if check_accuracy(current_dim, k, N, p, device) < threshold:
                # First failure going down: the dimension above is the candidate.
                current_dim += fine_step
                print(f"-> Candidate found (Step {fine_step} drop): {current_dim}")
                phase = "CONFIRM"
            else:
                current_dim -= fine_step
                if current_dim < fine_step:
                    return fine_step

        else:  # CONFIRM: heavy check of the candidate, then of the step below
            if not verify_candidate(current_dim, k, N, p, device):
                print(f"-> Too Low (Candidate {current_dim} failed). Searching Up...")
                current_dim += fine_step
                phase = "SEARCH_UP"
                continue

            below_dim = max(fine_step, current_dim - fine_step)
            if below_dim == current_dim:
                # Already at the floor: there is no lower step to compare with,
                # so re-verifying the same dimension would be pointless.
                print(f"-> Confirmed Step {fine_step} Boundary: {current_dim}")
                return current_dim
            if not verify_candidate(below_dim, k, N, p, device):
                print(f"-> Confirmed Step {fine_step} Boundary: {current_dim}")
                return current_dim

            print(f"-> Too High (Below {below_dim} passes). Searching Down...")
            current_dim = below_dim
            phase = "SEARCH_DOWN"


def _refine_below(boundary, width, step, k, N, p, device, threshold):
    """Scan ``(boundary - width, boundary)`` upward in increments of ``step``.

    Returns the first dimension that passes both a single accuracy check and
    the heavy check, or ``boundary`` itself if none does.
    """
    print(f"\n--- Refining (Step {step}) in range [{max(1, boundary - width)}, {boundary}] ---")
    curr = max(1, boundary - width + step)
    while curr < boundary:
        if (check_accuracy(curr, k, N, p, device) >= threshold
                and verify_candidate(curr, k, N, p, device)):
            print(f"-> Refined Step {step} Candidate: {curr}")
            return curr
        curr += step
    print(f"-> Kept previous boundary: {boundary}")
    return boundary


def find_min_dimension(k, N, p, start_dim, device, max_dim=100000):
    """
    Find the minimum dimension whose retrieval accuracy is >= 99%.

    Stages:
      1. Coarse walk (step 25) from ``start_dim`` to bracket the threshold.
         Then a fine walk (step 10) inside the bracket; a step-10 boundary
         is accepted once it passes the heavy check and the dimension 10
         below it fails the heavy check.
      2. Refine below that boundary in steps of 3.
      3. Refine below the step-3 result in steps of 1.

    Args:
        k: number of items bundled per trial.
        N: item pool size.
        p: filler vectors per item.
        start_dim: dimension the coarse walk starts from.
        device: torch device string.
        max_dim: upper limit for upward searches (default 100000).
            - If the fine walk passes it, the coarse upper bound is returned
              without refinement.
            - If the coarse walk passes it with no dimension reaching the
              threshold, -1 is returned instead of looping forever.

    Returns:
        The smallest dimension found to pass, or -1 if none up to ``max_dim``
        did.
    """
    COARSE_STEP = 25
    FINE_STEP = 10
    THRESHOLD = 0.99

    # --- STAGE 1: coarse bracket, then step-10 boundary ---
    print(f"\n--- Searching Min Dim for N={N} (Stage 1: Step {FINE_STEP}) ---")
    bracket = _coarse_bracket(k, N, p, start_dim, device,
                              COARSE_STEP, FINE_STEP, THRESHOLD, max_dim)
    if bracket is None:
        print(f"-> No dimension up to {max_dim} reached {THRESHOLD:.0%}; giving up for N={N}.")
        return -1
    lower_bound, upper_bound = bracket

    d_10 = _fine_boundary(lower_bound, k, N, p, device, FINE_STEP, THRESHOLD, max_dim)
    if d_10 is None:
        # Fine walk ran past max_dim; fall back to the coarse passing bound.
        return upper_bound

    # --- STAGE 2: refine in steps of 3 within (d_10 - FINE_STEP, d_10] ---
    d_3 = _refine_below(d_10, FINE_STEP, 3, k, N, p, device, THRESHOLD)

    # --- STAGE 3: refine in steps of 1 within (d_3 - 3, d_3] ---
    return _refine_below(d_3, 3, 1, k, N, p, device, THRESHOLD)


def resolve_device(requested_device):
    """Return a usable torch device string for ``requested_device``.

    - A CUDA request is honoured only if CUDA is available. An out-of-range
      index (e.g. ``cuda:3`` on a 2-GPU machine) falls back to ``cuda:0``.
    - CUDA requested but unavailable falls back to ``"cpu"`` with a warning;
      previously this fallback was silent.
    - Any other request maps to ``"cpu"``.
    """
    if "cuda" not in requested_device:
        return "cpu"
    if not torch.cuda.is_available():
        print(f"WARNING: Requested {requested_device} but CUDA is not available.")
        print("Falling back to cpu")
        return "cpu"

    count = torch.cuda.device_count()
    if ":" in requested_device:
        idx = int(requested_device.split(":")[-1])
        if idx >= count:
            print(f"WARNING: Requested {requested_device} but only {count} GPUs found.")
            print("Falling back to cuda:0")
            return "cuda:0"
    return requested_device


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--wandb", action="store_true", help="Enable WandB logging")
    parser.add_argument("--project", type=str, default="VSA_Item_Memory", help="WandB Project Name")
    parser.add_argument("--p", type=int, default=2, help="OrthHash p factor (vectors per filler)")
    parser.add_argument("--device", type=str, default="cuda:1", help="Device to run on")
    args = parser.parse_args()

    # Settings
    K_BUNDLES = 10
    DEVICE = resolve_device(args.device)
    P_FACTOR = args.p  # fixed p from args

    # Item pool sizes to test: m * 10^e for m = 1..9 and e = 1..5, then 10^6.
    # That is 10, 20, ..., 90, 100, 200, ..., 900000, 1000000.
    ITEM_SIZES = [m * 10**e for e in range(1, 6) for m in range(1, 10)] + [10**6]

    start_dim = 10

    print("Model: OrthHash")
    print(f"Bundles (k): {K_BUNDLES}")
    print(f"P values: {P_FACTOR}")
    print("Role Vecs: Dynamic sqrt(d)")
    print(f"Device: {DEVICE}")
    print(f"WandB: {'Enabled' if args.wandb else 'Disabled'}")

    if args.wandb:
        wandb.init(project=args.project, name=f"OrthHash_P{P_FACTOR}_DynamicRoles_Search", config={
            "model": "OrthHash",
            "k": K_BUNDLES,
            "p": P_FACTOR,
            "roles": "dynamic_sqrt_d"
        })

    results = {}

    for N in ITEM_SIZES:
        min_dim = find_min_dimension(K_BUNDLES, N, P_FACTOR, start_dim, DEVICE)
        print(f"==> Min Dim for N={N}: {min_dim}")
        results[N] = min_dim

        if args.wandb:
            wandb.log({"N": N, "Min_Dimension": min_dim})

        # Warm-start the next (larger) N just below this result.
        start_dim = max(10, min_dim - 10)

    print("\nFINAL RESULTS:")
    for N, d in results.items():
        print(f"N={N} -> Dim={d}")

    if args.wandb:
        print("Syncing WandB...")
        wandb.finish()
