import torch
import torch.nn.functional as F
import time
import numpy as np
import wandb
import itertools
import os
import argparse

# Force WandB into offline mode for every run of this script. This assignment
# replaces any WANDB_MODE value already present in the environment. Logging can
# be switched off entirely with the --test_mode flag parsed in the __main__ block.
os.environ["WANDB_MODE"] = "offline"

# Import your VSA models
from MBAT import MBAT
from FHRR import FHRR
from HRR import HRR
from VTB import VTB
from HLB import HLB
from BSC import BSC
from BSDC_S import BSDC_S
from BSDC_SEG import BSDC_SEG
from BSDC_CDT import BSDC_CDT
from MAP_B import MAP_B
from MAP_C import MAP_C
from MAP_I import MAP_I
from CGR import CGR
from GHRR import GHRR

# ==========================================
# Experiment Runners (Updated for Metrics)
# ==========================================

def _batched_role_binding(model, role_vec, vocab_matrix, dimension):
    """Bind ``role_vec`` to every row of ``vocab_matrix`` with one tensor op.

    Args:
        model: VSA model instance; only its type is inspected.
        role_vec: Flat role vector of length ``dimension``.
        vocab_matrix: ``(N, dimension)`` matrix of candidate fillers.
        dimension: Vector dimensionality passed to the experiment.

    Returns:
        An ``(N, dimension)`` tensor of candidate chains, or ``None`` when no
        batched path is implemented for this model type.
    """
    n_vocab = vocab_matrix.shape[0]

    if isinstance(model, (HLB, MAP_I)):
        # Element-wise product; role_vec (D) broadcasts over (N, D).
        # NOTE(review): presumes HLB/MAP_I.bind is element-wise multiplication
        # -- verify against those classes.
        return role_vec * vocab_matrix

    if isinstance(model, FHRR):
        # Complex product (a + ib)(c + id) = (ac - bd) + i(ad + bc).
        # NOTE(review): assumes FHRR vectors are real tensors holding
        # interleaved (re, im) pairs -- confirm against FHRR.bind.
        n_complex = dimension // 2
        u = role_vec.reshape(1, n_complex, 2)
        v = vocab_matrix.reshape(-1, n_complex, 2)
        re = u[:, :, 0] * v[:, :, 0] - u[:, :, 1] * v[:, :, 1]
        im = u[:, :, 0] * v[:, :, 1] + u[:, :, 1] * v[:, :, 0]
        return torch.stack([re, im], dim=-1).reshape(n_vocab, -1)

    if isinstance(model, VTB):
        # Each filler is reshaped to (d_root, d_root) and right-multiplied by
        # the role matrix: (N, k, k) @ (k, k) -> (N, k, k).
        # NOTE(review): presumes VTB.bind(u, v) == v @ u on square reshapes and
        # that ``dimension`` is a perfect square -- verify against VTB.
        d_root = int(np.sqrt(dimension))
        role_mat = role_vec.reshape(d_root, d_root)
        filler_mats = vocab_matrix.reshape(-1, d_root, d_root)
        return torch.matmul(filler_mats, role_mat).reshape(n_vocab, -1)

    return None


def run_vsa_experiment_with_cleanup(vsa_class, dimension, n_recursive, n_bundles):
    """Bundle bound chains into one memory and score recognition-based retrieval.

    For each of ``n_bundles`` chains, ``n_recursive + 1`` vectors are drawn with
    ``model.generate_vector()`` and folded left-to-right with ``model.bind``.
    All chains are combined with ``model.bundle``. During retrieval the role
    (every vector of a chain except the last, folded with ``bind``) is bound to
    each stored leaf vector, every candidate is compared with the memory, and
    the arg-max candidate is the prediction.

    HLB, MAP_I, FHRR and VTB use a batched path scored with cosine similarity
    against L2-normalised leaf vectors; any other model falls back to a loop
    over ``model.bind`` / ``model.similarity`` on the un-normalised leaves.

    Args:
        vsa_class: Class instantiated as ``vsa_class(dimension)``.
        dimension: Vector dimensionality.
        n_recursive: Number of bind operations per chain (must be >= 1).
        n_bundles: Number of chains bundled into the memory (must be >= 1).

    Returns:
        dict with ``retrieval_accuracy`` (percent of chains whose own leaf is
        ranked first), ``retrieval_avg_similarity`` (mean score of the correct
        leaf) and ``elapsed`` (seconds spent building and querying).

    Raises:
        ValueError: If ``n_bundles`` or ``n_recursive`` is smaller than 1.
    """
    # Fail early with a clear message instead of an IndexError on an empty
    # role or an error from stacking/bundling an empty list.
    if n_bundles < 1:
        raise ValueError(f"n_bundles must be >= 1, got {n_bundles}")
    if n_recursive < 1:
        raise ValueError(f"n_recursive must be >= 1, got {n_recursive}")

    print(f"   -> Running {vsa_class.__name__}: Dim={dimension}, Depth={n_recursive}, Bundles={n_bundles}...")

    model = vsa_class(dimension)

    vectors_to_bundle = []
    vocab_list = []   # leaf (filler) vector of each chain; index == label
    role_keys = []    # role vectors (all but the leaf) of each chain

    # perf_counter is monotonic, so the interval cannot be skewed by clock
    # adjustments.
    start_time = time.perf_counter()

    # 1. BUNDLING LOOP
    for _ in range(n_bundles):
        chain_vectors = [model.generate_vector() for _ in range(n_recursive + 1)]

        chain_result = chain_vectors[0]
        for vec in chain_vectors[1:]:
            chain_result = model.bind(chain_result, vec)

        vectors_to_bundle.append(chain_result)
        role_keys.append(chain_vectors[:-1])
        vocab_list.append(chain_vectors[-1])

    memory = model.bundle(vectors_to_bundle)

    # Vocabulary matrix (n_bundles, dimension), rows L2-normalised.
    vocab_matrix = F.normalize(torch.stack(vocab_list), p=2, dim=1)

    # 2. RETRIEVAL LOOP (recognition-based)
    retrieval_success_count = 0
    retrieval_total_target_sim = 0.0

    for target_idx, keys in enumerate(role_keys):
        # Rebuild the role k1 * k2 * ... (the chain without its leaf).
        role_vec = keys[0]
        for key in keys[1:]:
            role_vec = model.bind(role_vec, key)

        candidate_chains = _batched_role_binding(model, role_vec, vocab_matrix, dimension)
        if candidate_chains is not None:
            scores = F.cosine_similarity(memory.unsqueeze(0), candidate_chains, dim=1)
        else:
            # Fallback for models without a batched path above.
            scores = torch.zeros(len(vocab_list), device=role_vec.device)
            for v_idx, filler in enumerate(vocab_list):
                scores[v_idx] = model.similarity(memory, model.bind(role_vec, filler))

        if torch.argmax(scores).item() == target_idx:
            retrieval_success_count += 1

        # Accumulate the score of the CORRECT leaf, whether or not it won.
        retrieval_total_target_sim += scores[target_idx].item()

    elapsed = time.perf_counter() - start_time

    return {
        'retrieval_accuracy': (retrieval_success_count / n_bundles) * 100.0,
        'retrieval_avg_similarity': retrieval_total_target_sim / n_bundles,
        'elapsed': elapsed
    }


def run_bsdc_cdt_experiment(vsa_class, dimension, n_recursive, n_bundles):
    """Run the BSDC-CDT special case: score first-of-set items against a memory.

    For each of ``n_bundles`` sets, ``n_recursive + 1`` vectors are drawn with
    ``model.generate_vector()`` and folded left-to-right with ``model.bind``
    into one composite. All composites are combined with ``model.bundle``. The
    first item of every set forms the vocabulary, and each vocabulary item is
    scored with ``model.similarity(global_memory, item)``.

    NOTE(review): every query is scored against the same global memory, so the
    ranking is identical for all queries and exactly one target (the arg-max)
    is counted as retrieved. The reported accuracy is therefore always
    ``100 / n_bundles`` -- confirm this is the intended experiment.

    Args:
        vsa_class: Class instantiated as ``vsa_class(dimension)``.
        dimension: Vector dimensionality.
        n_recursive: Number of bind operations per set (must be >= 0).
        n_bundles: Number of composites bundled into the memory (must be >= 1).

    Returns:
        dict with ``retrieval_accuracy`` (percent), ``retrieval_avg_similarity``
        (mean score of the target items) and ``elapsed`` (seconds).

    Raises:
        ValueError: If ``n_bundles < 1`` or ``n_recursive < 0``.
    """
    if n_bundles < 1:
        raise ValueError(f"n_bundles must be >= 1, got {n_bundles}")
    if n_recursive < 0:
        raise ValueError(f"n_recursive must be >= 0, got {n_recursive}")

    print(f"   -> Running {vsa_class.__name__} (Set Check): Dim={dimension}, Items={n_recursive+1}, Bundles={n_bundles}...")

    model = vsa_class(dimension)

    composite_vectors = []  # one bound composite per set
    target_items = []       # first item of each set; doubles as the vocabulary

    # perf_counter is monotonic, so the interval cannot be skewed by clock
    # adjustments.
    start_time = time.perf_counter()

    # 1. GENERATION & BINDING LOOP
    for _ in range(n_bundles):
        items = [model.generate_vector() for _ in range(n_recursive + 1)]

        composite_vector = items[0]
        for item in items[1:]:
            composite_vector = model.bind(composite_vector, item)

        composite_vectors.append(composite_vector)
        target_items.append(items[0])

    # 2. BUNDLING STEP
    global_memory = model.bundle(composite_vectors)

    # 3. RETRIEVAL TASK
    # The scores depend only on the memory and the vocabulary, not on which
    # target is being queried, so they are computed once instead of once per
    # query (n similarity calls instead of n^2).
    scores = torch.zeros(len(target_items), device=target_items[0].device)
    for i, vocab_item in enumerate(target_items):
        scores[i] = model.similarity(global_memory, vocab_item)

    best_idx = torch.argmax(scores).item()

    # A target counts as retrieved when it is the arg-max of the shared ranking.
    retrieval_success_count = sum(best_idx == idx for idx in range(len(target_items)))
    retrieval_total_target_sim = sum(scores.tolist())

    elapsed = time.perf_counter() - start_time

    return {
        'retrieval_accuracy': (retrieval_success_count / n_bundles) * 100.0,
        'retrieval_avg_similarity': retrieval_total_target_sim / n_bundles,
        'elapsed': elapsed
    }

# ==========================================
# Grid Search Engine
# ==========================================

def run_grid_search(grid, n_trials, project_name, use_wandb=True):
    """Run every parameter combination of ``grid`` for ``n_trials`` trials.

    Each combination gets its own WandB run (when ``use_wandb`` is true) that
    receives one log entry per trial and a final entry with the mean and
    standard deviation of every collected metric. A summary is also printed to
    stdout.

    Args:
        grid: Mapping of parameter name to list of values. Must contain the
            keys ``model_class``, ``dimension``, ``depth`` and ``n_bundles``.
        n_trials: Number of repetitions per combination (must be >= 1).
        project_name: WandB project name passed to ``wandb.init``.
        use_wandb: When false, no WandB call is made at all.

    Raises:
        ValueError: If ``n_trials`` is smaller than 1 (the statistics would be
            computed over empty lists).
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be >= 1, got {n_trials}")

    # Cartesian product of all parameter values, in the grid's key order.
    param_names = list(grid)
    combinations = list(itertools.product(*grid.values()))
    n_configs = len(combinations)

    print("\n=== Starting Grid Search ===")
    print(f"Total Configurations: {n_configs}")
    print(f"Trials per Config: {n_trials}")
    print(f"Total Runs: {n_configs * n_trials}")
    print(f"WandB Logging: {'Enabled' if use_wandb else 'Disabled'}\n")

    for i, config_values in enumerate(combinations):
        config = dict(zip(param_names, config_values))

        model_cls = config['model_class']
        dim = config['dimension']
        depth = config['depth']
        bundles = config['n_bundles']

        model_name = model_cls.__name__

        print(f"[{i+1}/{n_configs}] Testing {model_name} | Dim:{dim} | Depth:{depth} | Bundles:{bundles}")

        run_name = f"{model_name}_D{dim}_L{depth}_B{bundles}"

        # One WandB run per configuration; None when logging is disabled.
        # NOTE(review): config['model_class'] is a class object -- confirm
        # WandB records it in the form you want to filter on.
        run_obj = None
        if use_wandb:
            run_obj = wandb.init(
                project=project_name,
                name=run_name,
                config=config,
                resume="allow",
                settings=wandb.Settings(init_timeout=300)
            )

        try:
            trial_metrics = {
                'retrieval_accuracy': [],
                'retrieval_avg_similarity': [],
                'elapsed': []
            }

            # BSDC_CDT uses the set-membership experiment; everything else
            # uses the chain-recognition experiment.
            if model_name == "BSDC_CDT":
                experiment = run_bsdc_cdt_experiment
            else:
                experiment = run_vsa_experiment_with_cleanup

            # --- TRIAL LOOP ---
            for t in range(n_trials):
                results = experiment(model_cls, dim, depth, bundles)

                log_dict = {"trial": t + 1}
                for key, value in results.items():
                    log_dict[f"trial_{key}"] = value
                    if key in trial_metrics:
                        trial_metrics[key].append(value)

                if use_wandb:
                    wandb.log(log_dict)

            # --- AGGREGATE STATS ---
            summary_log = {}
            for metric_name, values in trial_metrics.items():
                if values:  # skip metrics the experiment did not report
                    summary_log[f"mean_{metric_name}"] = np.mean(values)
                    summary_log[f"std_{metric_name}"] = np.std(values)

            ret_acc = trial_metrics['retrieval_accuracy']
            ret_sim = trial_metrics['retrieval_avg_similarity']
            time_vals = trial_metrics['elapsed']

            print(f"   >>> Retrieval Acc={np.mean(ret_acc):.2f}% (±{np.std(ret_acc):.2f}) | Sim={np.mean(ret_sim):.4f} (±{np.std(ret_sim):.4f})")
            print(f"   >>> Time={np.mean(time_vals):.2f}s (±{np.std(time_vals):.2f})\n")

            if use_wandb:
                wandb.log(summary_log)

        finally:
            # Always close the run, even if a trial raised.
            if run_obj is not None:
                run_obj.finish()

        # Pause between WandB runs to avoid rate limiting. Pointless without
        # WandB, so test-mode sweeps are not slowed down; also skipped after
        # the last configuration.
        if use_wandb and i < n_configs - 1:
            time.sleep(1)

# ==========================================
# Main Configuration
# ==========================================

if __name__ == "__main__":
    # 1. Parse command-line options
    cli = argparse.ArgumentParser(description="Run VSA Experiments with Discriminative Retrieval")
    cli.add_argument("--test_mode", action="store_true", help="Disable WandB logging for testing")
    args = cli.parse_args()

    # 2. Warm up CUDA (a tiny allocation initialises the device before timing)
    if not torch.cuda.is_available():
        print("WARNING: CUDA not available. Running on CPU (will be slow).")
    else:
        print(f"CUDA Device: {torch.cuda.get_device_name(0)}")
        torch.zeros(1).cuda()

    # 3. Parameter grid: every combination of these values is evaluated
    GRID = {
        'model_class': [HLB, MAP_I, FHRR, VTB],
        'dimension': [4096, 8100],
        'depth': [1],
        # 1, then 10..100 in steps of 10, then 200..1000 in steps of 100
        'n_bundles': [1, *range(10, 101, 10), *range(200, 1001, 100)],
    }

    # 4. Run experiments ('n_trials' sets how many times each config runs)
    run_grid_search(
        GRID,
        n_trials=5,
        project_name="DiscrimRetrieval_VSATest",
        use_wandb=not args.test_mode,
    )

    # 5. Closing banner
    rule = "=" * 80
    if args.test_mode:
        closing_lines = ["TEST RUN COMPLETED (No WandB Logging)"]
    else:
        closing_lines = [
            "RUNS COMPLETED IN OFFLINE MODE",
            "To sync to WandB, run:",
            "  wandb sync /tmp/wandb/",
            "Or wherever your wandb logs are stored",
        ]

    print("\n" + rule)
    for line in closing_lines:
        print(line)
    print(rule)
