import torch
import torch.nn.functional as F
import shutil
import random
import hashlib
import time
import wandb
import itertools
import os
from Memory_General_Torch import Memory

# Keep Weights & Biases from contacting its server; runs are recorded locally.
os.environ["WANDB_MODE"] = "offline"

# Column count of the attached terminal.
terminal_width = shutil.get_terminal_size().columns

# Compute device: CUDA when a GPU is visible to torch, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# --- BATCHED GPU BASIS GENERATION ---
def compute_seed(role_indices):
    """Return a deterministic 32-bit seed for a collection of role indices.

    The indices are sorted before hashing, so the seed does not depend on the
    order in which they are supplied (a list and a frozenset holding the same
    indices give the same seed). The seed is the first 4 bytes of the SHA-256
    digest of the sorted tuple's string form, read as a big-endian integer.
    """
    canonical = tuple(sorted(role_indices))
    digest = hashlib.sha256(str(canonical).encode('utf-8')).digest()
    return int.from_bytes(digest[:4], 'big')

def batched_qr_orthonormal(seeds, n_dim, num_vectors, device='cuda'):
    """
    Generate one orthonormal basis per seed using a single batched QR decomposition.

    For every seed, a ``(n_dim, num_vectors)`` standard-normal matrix is drawn
    from a dedicated ``torch.Generator`` seeded with that value. The same seed
    on the same device therefore always yields the same basis. All matrices are
    factorised together with one reduced-mode ``torch.linalg.qr`` call. No
    chunking or OOM handling is done: every seed is processed at once.

    Args:
        seeds: Sequence of integer seeds, one per basis.
        n_dim: Length of each basis vector (ambient dimension).
        num_vectors: Orthonormal vectors per basis; must not exceed ``n_dim``.
        device: Device on which the random draws and the QR run.

    Returns:
        Tensor of shape ``(len(seeds), num_vectors, n_dim)``. Within each batch
        entry the rows are orthonormal.

    Raises:
        ValueError: if ``num_vectors > n_dim``. A reduced QR would then silently
            return only ``n_dim`` vectors, violating the shape above.
    """
    n_seeds = len(seeds)
    if num_vectors > n_dim:
        raise ValueError(
            f"cannot build {num_vectors} orthonormal vectors in {n_dim} dimensions"
        )

    # One independent, seeded Gaussian draw per basis. Every slot is overwritten.
    batch_matrices = torch.empty((n_seeds, n_dim, num_vectors), device=device)
    for i, seed in enumerate(seeds):
        generator = torch.Generator(device=device).manual_seed(seed)
        batch_matrices[i] = torch.randn(n_dim, num_vectors, generator=generator, device=device)

    # Reduced QR: Q has orthonormal columns, shape (n_seeds, n_dim, num_vectors).
    # The triangular factor R is not needed.
    q = torch.linalg.qr(batch_matrices, mode='reduced').Q

    # Columns -> rows: (n_seeds, num_vectors, n_dim).
    return q.transpose(1, 2)

def generate_all_bases_batched(all_abstract_indices, n_dim, num_vectors, device='cuda'):
    """Build a lookup of orthonormal bases for every distinct role-index set.

    Each entry of ``all_abstract_indices`` is treated as an unordered set.
    Repeated entries, including permutations of the same indices, therefore
    share one basis. Each set's seed comes from ``compute_seed``. All bases are
    produced by a single ``batched_qr_orthonormal`` call, and the generation
    time is printed.

    Returns:
        dict mapping ``frozenset`` of indices -> that set's basis (the
        corresponding entry of the batched result).
    """
    distinct_sets = list({frozenset(indices) for indices in all_abstract_indices})
    seed_list = [compute_seed(index_set) for index_set in distinct_sets]

    print(f"Generating {len(seed_list)} unique bases on {device}...")
    t0 = time.time()
    bases = batched_qr_orthonormal(seed_list, n_dim, num_vectors, device=device)
    t1 = time.time()
    print(f"Batch generation took {t1 - t0:.3f}s")

    # Iterating the batched tensor walks its first dimension, pairing each set
    # with its own basis.
    return dict(zip(distinct_sets, bases))

# --- EXPERIMENT RUNNER ---
def _sample_index_lists(count, n_abstract_roles, depth_r):
    """Draw ``count`` role-index lists, each ``randperm(n_abstract_roles)[:depth_r]``."""
    return [torch.randperm(n_abstract_roles)[:depth_r].tolist() for _ in range(count)]


def _basis_for(basis_cache, index_set, memory_dim, num_role_vecs):
    """Return the basis for ``index_set``, generating and caching it when absent."""
    basis = basis_cache.get(index_set)
    if basis is None:
        seed = compute_seed(index_set)
        basis = batched_qr_orthonormal([seed], memory_dim, num_role_vecs, device=device)[0]
        basis_cache[index_set] = basis
    return basis


def run_orthhash_experiment(dim, depth_r, num_role_vecs, n_items_to_bundle,
                            p_factor, is_wedge, n_abstract_roles):
    """
    Run a single trial of the orthogonal-hash (OSC) memory experiment.

    Steps:
      1. Sample ``n_items_to_bundle`` role-index sets for encoding, and the
         same number as recognition distractors. Each set is
         ``randperm(n_abstract_roles)[:depth_r]``.
      2. Generate one orthonormal role basis per distinct index set.
      3. Encoding: bind each item's role basis to ``p_factor`` random filler
         vectors in a single ``Memory``.
      4. Recognition: an item counts as recognised when its stored
         (role, filler) pair scores strictly higher than a distractor pair.
         The distractor uses a different index set's basis and fresh random
         fillers.
      5. Retrieval: for each stored role basis, score every stored filler
         group. A success is counted when the item's own fillers score highest.

    Args:
        dim: Memory dimension, also the length of the role and filler vectors.
        depth_r: Slice length applied to each random permutation of role indices.
        num_role_vecs: Number of orthonormal role vectors per basis.
        n_items_to_bundle: Number of items stored; must be >= 1.
        p_factor: Filler vectors per item; also passed to ``Memory`` as ``p``.
        is_wedge: Passed to ``Memory`` as ``isWedge``.
        n_abstract_roles: Size of the pool of abstract role indices.

    Returns:
        dict with the following keys:
          - ``recognition_accuracy`` and ``retrieval_accuracy``, in percent.
          - ``retrieval_avg_similarity``, the mean score of each item's own
            fillers.
          - ``elapsed``, the wall-clock seconds for the trial.

    Raises:
        ValueError: in either of these cases:
          - ``n_items_to_bundle < 1``.
          - The sampling parameters admit only one possible index set. No
            distinct distractor can then exist, and resampling would loop
            forever.
    """
    if n_items_to_bundle < 1:
        raise ValueError(f"n_items_to_bundle must be >= 1, got {n_items_to_bundle}")
    # Length of randperm(n_abstract_roles)[:depth_r] under Python slice rules
    # (e.g. depth_r=-1 keeps all but one index). A length of 0 or of the whole
    # pool means every draw yields the same set.
    subset_size = len(range(n_abstract_roles)[:depth_r])
    if subset_size in (0, n_abstract_roles):
        raise ValueError(
            f"depth_r={depth_r} with n_abstract_roles={n_abstract_roles} admits only "
            "one role-index set; recognition needs a distinct distractor set"
        )

    start_time = time.time()

    memory_dim = dim
    model = Memory(
        N=memory_dim,
        p=p_factor,
        isWedge=is_wedge,
        orthogonalize_roles=False,
        device=device
    )

    # Sample every index set up front so all bases are generated in one batch.
    encoding_abstract_indices = _sample_index_lists(n_items_to_bundle, n_abstract_roles, depth_r)
    testing_abstract_indices = _sample_index_lists(n_items_to_bundle, n_abstract_roles, depth_r)
    basis_cache = generate_all_bases_batched(
        encoding_abstract_indices + testing_abstract_indices,
        memory_dim, num_role_vecs, device=device
    )

    # ENCODING
    stored_data = []
    for abstract_indices in encoding_abstract_indices:
        index_set = frozenset(abstract_indices)
        role_vecs = basis_cache[index_set]
        filler_vecs = torch.randn(p_factor, memory_dim, device=device)

        model.bind(role_vecs, filler_vecs)

        stored_data.append({
            "role": role_vecs,
            "abstract_indices": index_set,
            "filler": filler_vecs
        })

    # RECOGNITION TEST
    recognition_success_count = 0
    for item, testing_indices in zip(stored_data, testing_abstract_indices):
        wrong_index_set = frozenset(testing_indices)
        # The pre-sampled distractor may coincide with the stored set. Resample
        # until it differs; this ends with probability 1 given the check above.
        while wrong_index_set == item['abstract_indices']:
            wrong_index_set = frozenset(
                torch.randperm(n_abstract_roles)[:depth_r].tolist()
            )

        wrong_role_vecs = _basis_for(basis_cache, wrong_index_set, memory_dim, num_role_vecs)
        wrong_filler_vecs = torch.randn(p_factor, memory_dim, device=device)

        score_correct = model.get_dual_contraction_score(item['role'], item['filler'])
        score_wrong = model.get_dual_contraction_score(wrong_role_vecs, wrong_filler_vecs)

        if score_correct > score_wrong:
            recognition_success_count += 1

    # RETRIEVAL TEST (tied: reuses the items stored for recognition)
    vocab_tensor = torch.stack([item['filler'] for item in stored_data])  # (n_items, p_factor, memory_dim)

    retrieval_success_count = 0
    retrieval_total_target_sim = 0.0
    for target_idx, item in enumerate(stored_data):
        # Presumably one score per stored item (indexed by target_idx below);
        # TODO confirm against Memory.get_dual_contraction_score_batched.
        scores = model.get_dual_contraction_score_batched(item['role'], vocab_tensor)

        if torch.argmax(scores).item() == target_idx:
            retrieval_success_count += 1

        retrieval_total_target_sim += scores[target_idx].item()

    elapsed = time.time() - start_time

    # CALCULATE METRICS
    recognition_accuracy = (recognition_success_count / n_items_to_bundle) * 100.0
    retrieval_accuracy = (retrieval_success_count / n_items_to_bundle) * 100.0
    retrieval_avg_similarity = retrieval_total_target_sim / n_items_to_bundle

    return {
        'recognition_accuracy': recognition_accuracy,
        'retrieval_accuracy': retrieval_accuracy,
        'retrieval_avg_similarity': retrieval_avg_similarity,
        'elapsed': elapsed
    }

# --- GRID SEARCH ---
def _num_role_vecs_for(p_factor, dim):
    """Return the number of role vectors per basis for ``p_factor`` and ``dim``.

    Raises:
        ValueError: for a ``p_factor`` without a schedule (only 2 and 3 exist).
    """
    if p_factor == 3:
        return 10 if dim <= 80 else 15
    if p_factor == 2:
        if dim <= 80:
            return 10
        if dim <= 150:
            return 15
        return 20
    raise ValueError(f"no num_role_vecs schedule for p_factor={p_factor}")


def _mean_std(values):
    """Return the mean and population standard deviation of ``values`` as floats."""
    floats = [float(v) for v in values]
    mean = sum(floats) / len(floats)
    std = (sum((x - mean) ** 2 for x in floats) / len(floats)) ** 0.5
    return mean, std


def run_grid_search(grid, n_trials, project_name, depth_override=16):
    """
    Run every configuration in the Cartesian product of ``grid``, logging to WandB.

    ``grid`` maps parameter names to lists of values. It must contain
    ``dimension``, ``depth``, ``n_bundles``, ``p_factor``, ``is_wedge`` and
    ``n_abstract_roles``. Each configuration gets one WandB run containing
    ``n_trials`` calls to ``run_orthhash_experiment``. Per-trial metrics are
    logged as ``trial_<metric>``, followed by a ``mean_<metric>`` /
    ``std_<metric>`` summary (population std).

    Args:
        grid: Mapping of parameter name -> list of values.
        n_trials: Trials per configuration; must be >= 1.
        project_name: WandB project to log into.
        depth_override: Depth actually passed to the experiment, replacing the
            grid's ``depth`` value (which still appears in the run name). The
            value used is recorded as ``effective_depth`` in the run config.
            ``None`` uses the grid value. Defaults to 16, the previously
            hard-coded override.

    Raises:
        ValueError: if ``n_trials < 1``, or if a configuration's ``p_factor``
            has no role-vector schedule. The second check is raised before that
            configuration's WandB run starts.
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be >= 1, got {n_trials}")

    param_names = list(grid.keys())
    combinations = list(itertools.product(*grid.values()))

    print(f"\n=== Starting OSC Grid Search ===")
    print(f"Total Configurations: {len(combinations)}")
    print(f"Trials per Config: {n_trials}")
    print(f"Total Runs: {len(combinations) * n_trials}\n")

    for i, config_values in enumerate(combinations):
        config = dict(zip(param_names, config_values))
        config['model_class'] = 'OSC'

        # Extract parameters
        dim = config['dimension']
        depth_r = config['depth']
        n_bundles = config['n_bundles']
        p_factor = config['p_factor']
        is_wedge = config['is_wedge']
        n_abstract_roles = config['n_abstract_roles']

        # Role-basis size depends on p_factor and dimension; fail before any
        # WandB run is opened if the combination is unsupported.
        num_role_vecs = _num_role_vecs_for(p_factor, dim)
        config['num_role_vecs'] = num_role_vecs

        print(f"[{i+1}/{len(combinations)}] Testing OSC | Dim:{dim} | Depth:{depth_r} | Bundles:{n_bundles} | P:{p_factor} | RoleVecs:{num_role_vecs} | Wedge:{is_wedge}")

        wedge_str = "W" if is_wedge else "O"
        run_name = f"OSC{dim}_L{depth_r}_B{n_bundles}_P{p_factor}_R{num_role_vecs}_{wedge_str}"

        # The grid depth names the run, but the experiment uses the override.
        # Record the value actually used so the logged config is not misleading.
        if depth_override is not None:
            depth_r = depth_override
        config['effective_depth'] = depth_r

        with wandb.init(project=project_name, name=run_name, config=config, reinit=True) as run:

            trial_metrics = {
                'recognition_accuracy': [],
                'retrieval_accuracy': [],
                'retrieval_avg_similarity': [],
                'elapsed': []
            }

            for t in range(n_trials):
                results = run_orthhash_experiment(
                    dim, depth_r, num_role_vecs, n_bundles,
                    p_factor, is_wedge, n_abstract_roles
                )

                log_dict = {"trial": t + 1}
                for key, value in results.items():
                    log_dict[f"trial_{key}"] = value
                    if key in trial_metrics:
                        trial_metrics[key].append(value)

                run.log(log_dict)

            summary_log = {}
            for metric_name, metric_values in trial_metrics.items():
                if metric_values:
                    mean_val, std_val = _mean_std(metric_values)
                    summary_log[f"mean_{metric_name}"] = mean_val
                    summary_log[f"std_{metric_name}"] = std_val

            run.log(summary_log)

            # Print summary (reuses the means just logged)
            mean_ret_acc = summary_log['mean_retrieval_accuracy']
            mean_rec_acc = summary_log['mean_recognition_accuracy']
            mean_time = summary_log['mean_elapsed']

            print(f"   >>> Retrieval Acc={mean_ret_acc:.2f}% | Recognition Acc={mean_rec_acc:.2f}%")
            print(f"   >>> Time={mean_time:.2f}s\n")

if __name__ == "__main__":
    def _osc_grid(dimensions, n_bundles, p_factor):
        """Build an OSC grid; only dimensions, bundle counts and p_factor vary."""
        return {
            'model_class': ["OSC"],
            'dimension': dimensions,
            'depth': [-1],
            'n_bundles': n_bundles,
            'p_factor': [p_factor],
            'is_wedge': [False],
            'n_abstract_roles': [100],
        }

    # (grid, WandB project) pairs, executed in this order.
    experiments = [
        (
            _osc_grid(
                [25, 50, 75, 100, 125, 150, 175, 200],
                [1, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000],
                2,
            ),
            "VSA_Comparison",
        ),
        (
            _osc_grid(
                [25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80],
                [1, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000],
                3,
            ),
            "Memory_Comp3",
        ),
        (
            _osc_grid(
                [25, 50, 75, 100, 125, 150, 175, 200],
                [1, 20, 40, 60, 80, 100, 120, 140, 160, 180, 200, 220, 240, 260, 280, 300, 400, 500, 600, 700, 800, 900, 1000],
                2,
            ),
            "Memory_Comp2",
        ),
    ]

    for grid, project in experiments:
        run_grid_search(grid, n_trials=5, project_name=project)