import torch
import torch.nn.functional as F
import shutil
import random
import hashlib
import time
import wandb
import itertools
import os
import argparse
import math
import numpy as np
from Memory_General_Torch import Memory

# Set WandB to offline mode
os.environ["WANDB_MODE"] = "offline"

# Setup: terminal width (column count) captured once at import time
terminal_width = shutil.get_terminal_size()[0]

# Check for GPU: use CUDA when available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# --- BATCHED GPU BASIS GENERATION ---
def compute_seed(role_indices):
    """Derive a deterministic seed from a collection of role indices.

    The indices are sorted before hashing, so any ordering of the same
    indices (list, tuple, set, frozenset) maps to the same seed. The seed is
    the first 8 hex digits of the SHA-256 digest of the sorted tuple's string
    form, i.e. an integer in ``[0, 2**32)``.
    """
    canonical = tuple(sorted(role_indices))
    digest = hashlib.sha256(str(canonical).encode('utf-8')).hexdigest()
    return int(digest[:8], 16)

def batched_qr_orthonormal(seeds, n_dim, num_vectors, device='cuda'):
    """
    Build one orthonormal basis per seed with a single batched QR call.

    For every seed, an ``(n_dim, num_vectors)`` Gaussian matrix is drawn from
    a generator seeded with that value; all matrices are then factorized
    together and the Q factors are returned with the vector axis first, i.e.
    shape ``(n_seeds, num_vectors, n_dim)``.

    No chunking or out-of-memory handling is done: all seeds are processed in
    one batch.

    NOTE: ``torch.linalg.qr`` is called in its default (reduced) mode, so if
    ``num_vectors`` exceeds ``n_dim`` each basis contains only ``n_dim``
    vectors.
    """
    # Every slice is overwritten below, so the buffer needs no initial fill.
    raw = torch.empty((len(seeds), n_dim, num_vectors), device=device)
    for slot, seed in enumerate(seeds):
        rng = torch.Generator(device=device)
        rng.manual_seed(seed)
        raw[slot] = torch.randn(n_dim, num_vectors, generator=rng, device=device)

    # Only the orthonormal factor is needed; the triangular factor is dropped.
    q_factor, _ = torch.linalg.qr(raw)

    # Swap the last two axes so each row of a basis is one orthonormal vector.
    return q_factor.transpose(-2, -1)

def generate_all_bases_batched(all_abstract_indices, n_dim, num_vectors, device='cuda'):
    """
    Pre-generate the bases for every distinct index set in one batched call.

    Each entry of ``all_abstract_indices`` is reduced to a ``frozenset`` (so
    order and repeats within an entry are ignored), duplicates across entries
    are collapsed, and a seed is derived per distinct set via
    ``compute_seed``. All seeds go through ``batched_qr_orthonormal``
    together.

    Returns a dict mapping each distinct ``frozenset`` of indices to its
    basis tensor (one slice of the batched result).
    """
    # Deduplicate first so each distinct index set costs exactly one QR slot.
    distinct_sets = list({frozenset(indices) for indices in all_abstract_indices})
    seed_values = [compute_seed(index_set) for index_set in distinct_sets]

    stacked_bases = batched_qr_orthonormal(seed_values, n_dim, num_vectors, device=device)

    # Position p in the batch corresponds to distinct_sets[p].
    return {
        index_set: stacked_bases[pos]
        for pos, index_set in enumerate(distinct_sets)
    }

# --- EXPERIMENT RUNNER ---
def run_orthhash_experiment(dim, depth_r, num_role_vecs, n_items_to_bundle,
                            p_factor, is_wedge, n_abstract_roles):
    """
    Runs a single trial of the orthogonal hash retrieval experiment.

    For each of ``n_items_to_bundle`` items, a random set of up to ``depth_r``
    distinct abstract role indices is drawn from ``range(n_abstract_roles)``.
    The basis seeded by that index set is bound to a fresh Gaussian filler in
    a new ``Memory`` instance. Afterwards every stored role is scored against
    the full list of stored fillers; an item counts as retrieved when the
    arg-max of its scores is its own position.

    Args:
        dim: Memory dimension; passed to ``Memory`` as ``N`` and used as the
            vector length of roles and fillers.
        depth_r: Number of abstract role indices drawn per item.
        num_role_vecs: Number of vectors requested per role basis.
        n_items_to_bundle: Number of role/filler pairs bound into the memory.
        p_factor: Passed to ``Memory`` as ``p``; also the number of filler
            rows generated per item.
        is_wedge: Passed to ``Memory`` as ``isWedge``.
        n_abstract_roles: Size of the pool the role indices are drawn from.

    Returns:
        dict with ``'retrieval_accuracy'`` (percentage of items whose own
        filler scored highest), ``'retrieval_avg_similarity'`` (mean score of
        each item's own filler) and ``'elapsed'`` (wall-clock seconds covering
        basis generation, encoding and retrieval).
    """
    start_time = time.time()

    # Use dim directly
    memory_dim = dim

    # Initialize Memory (uses the module-level `device`)
    model = Memory(
        N=memory_dim,
        p=p_factor,
        isWedge=is_wedge,
        orthogonalize_roles=False,
        device=device
    )

    # One index set per stored item. Bases are keyed by the *set* of indices,
    # so the draw order within an item is irrelevant. Only index sets that
    # are actually encoded are generated: retrieval below reuses the stored
    # roles, so no separate "testing" bases are needed.
    encoding_index_sets = [
        frozenset(torch.randperm(n_abstract_roles)[:depth_r].tolist())
        for _ in range(n_items_to_bundle)
    ]

    # Batch generate all bases
    basis_cache = generate_all_bases_batched(
        encoding_index_sets, memory_dim, num_role_vecs, device=device
    )

    # ENCODING
    stored_roles = []
    stored_fillers = []
    for index_set in encoding_index_sets:
        role_vecs = basis_cache[index_set]
        filler_vecs = torch.randn(p_factor, memory_dim, device=device)

        model.bind(role_vecs, filler_vecs)

        stored_roles.append(role_vecs)
        stored_fillers.append(filler_vecs)

    # RETRIEVAL TEST: the vocabulary is exactly the set of stored fillers
    vocab_tensor = torch.stack(stored_fillers)  # (n_items, p_factor, memory_dim)

    retrieval_success_count = 0
    retrieval_total_target_sim = 0.0

    for target_idx, role_vecs in enumerate(stored_roles):
        # Direct batched call (no loops, no OOM catching).
        # NOTE(review): assumes `scores` holds one value per vocabulary item,
        # indexed like `vocab_tensor` - confirm against Memory.
        scores = model.get_dual_contraction_score_batched(role_vecs, vocab_tensor)

        if torch.argmax(scores).item() == target_idx:
            retrieval_success_count += 1

        retrieval_total_target_sim += scores[target_idx].item()

    elapsed = time.time() - start_time

    # CALCULATE METRICS
    retrieval_accuracy = (retrieval_success_count / n_items_to_bundle) * 100.0
    retrieval_avg_similarity = retrieval_total_target_sim / n_items_to_bundle

    return {
        'retrieval_accuracy': retrieval_accuracy,
        'retrieval_avg_similarity': retrieval_avg_similarity,
        'elapsed': elapsed
    }

# --- SMART SEARCH ---
def run_smart_search(dimensions_list, args):
    """
    Grid search over dimensions and num_role_vecs (k) around sqrt(dim).

    For every ``dim`` in ``dimensions_list`` the search bundles ``15 * dim``
    items and sweeps k from ``max(2, int(sqrt(dim)) - 5)`` up to
    ``int(sqrt(dim)) + 15``, capped at ``dim`` because no more than ``dim``
    orthonormal vectors exist in a ``dim``-dimensional space. Each k is
    evaluated with 5 trials of ``run_orthhash_experiment`` and the mean
    retrieval accuracy is recorded. After a dimension finishes, accuracies
    are also expressed relative to that dimension's peak. A summary table of
    all results is printed at the end.

    Args:
        dimensions_list: Integer memory dimensions to evaluate.
        args: Parsed CLI namespace; only ``args.test_mode`` is read. When it
            is true, no WandB run is created.

    Returns:
        None. Results are printed and, unless in test mode, logged to WandB.
    """
    # Fixed experiment configuration; identical for every (dim, k) pair, so
    # it is defined once instead of inside the sweep.
    depth_r = 16  # Fixed as per original override
    p_factor = 2  # Fixed/Default
    is_wedge = False
    n_abstract_roles = 100
    n_trials = 5  # 5 trials for robustness

    print("\n" + "="*80)
    print("SMART K SEARCH STARTING")
    print(f"Dimensions: {dimensions_list}")
    print(f"WandB Logging: {'DISABLED' if args.test_mode else 'ENABLED'}")
    print("="*80 + "\n")

    overall_results = []  # Store tuple: (dim, bundles, k, acc, norm_acc)

    for dim in dimensions_list:
        # 1. Dynamic Bundles
        # Heuristic: capacity scales with D, so the bundle count does too.
        n_bundles = int(15 * dim)

        # 2. Key Range around sqrt(dim)
        sqrt_d = int(math.sqrt(dim))
        k_min = max(2, sqrt_d - 5)
        # Extend further out to capture behavior at higher k (+15), but never
        # beyond dim: a request for more than dim orthonormal vectors would
        # be silently truncated by the reduced QR and reported under the
        # wrong K.
        k_max = min(sqrt_d + 15, dim)
        k_values = list(range(k_min, k_max + 1))

        print(f"--- Processing Dim: {dim} | Bundles: {n_bundles} | K Range: {k_values} ---")

        dim_results = []  # (k, acc)

        for k in k_values:
            # --- RUN TRIALS ---
            acc_list = [
                run_orthhash_experiment(
                    dim, depth_r, k, n_bundles,
                    p_factor, is_wedge, n_abstract_roles
                )['retrieval_accuracy']
                for _ in range(n_trials)
            ]

            avg_acc = sum(acc_list) / len(acc_list)
            dim_results.append((k, avg_acc))

            # WandB logging if enabled
            if not args.test_mode:
                config = {
                    'dimension': dim, 'n_bundles': n_bundles, 'num_role_vecs': k,
                    'depth': depth_r, 'p_factor': p_factor, 'model_class': 'OrthHash'
                }
                run_name = f"SmartSearch_D{dim}_K{k}"
                with wandb.init(project="Smart_K_Search", name=run_name, config=config,
                                reinit=True, settings=wandb.Settings(init_timeout=300)) as run:
                    run.log({'retrieval_accuracy': avg_acc})

        # --- NORMALIZATION ---
        # Find peak accuracy for this dimension (0.0 when nothing was run)
        max_acc = max((acc for _, acc in dim_results), default=0.0)

        for k, acc in dim_results:
            norm_acc = (acc / max_acc * 100.0) if max_acc > 0 else 0.0
            overall_results.append((dim, n_bundles, k, acc, norm_acc))

        print(f"   Done Dim {dim}. Peak Acc: {max_acc:.2f}%")

    # --- PRINT FINAL SUMMARY TABLE ---
    print("\n\n" + "="*80)
    print(f"{'Dim':<8} | {'Bundles':<10} | {'K':<5} | {'Acc':<10} | {'Norm Acc':<10}")
    print("-" * 55)
    for res_dim, res_bundles, res_k, res_acc, res_norm in overall_results:
        print(f"{res_dim:<8} | {res_bundles:<10} | {res_k:<5} | {res_acc:<10.2f} | {res_norm:<10.2f}")
    print("="*80 + "\n")


if __name__ == "__main__":
    # Command-line interface: a single flag that turns off WandB logging.
    cli = argparse.ArgumentParser()
    cli.add_argument("--test_mode", action="store_true", help="Disable WandB logging")
    args = cli.parse_args()

    if args.test_mode:
        os.environ["WANDB_MODE"] = "offline"

    # User defined dimensions list for coverage: 50 through 200 in steps of 10
    dimensions = list(range(50, 201, 10))

    run_smart_search(dimensions, args)
