import torch
import torch.nn.functional as F
import time
import numpy as np
import wandb
import itertools
import os

# Set WandB to offline mode
os.environ["WANDB_MODE"] = "offline"

# Import your VSA models
from MBAT import MBAT
from FHRR import FHRR
from HRR import HRR
from VTB import VTB
from HLB import HLB
from BSC import BSC
from BSDC_S import BSDC_S
from BSDC_SEG import BSDC_SEG
from BSDC_CDT import BSDC_CDT
from MAP_B import MAP_B
from MAP_C import MAP_C
from MAP_I import MAP_I
from CGR import CGR
from GHRR import GHRR

# ==========================================
# Experiment Runners (Updated for Metrics)
# ==========================================

def _bind_chain(model, vectors):
    """Left-fold ``model.bind`` over ``vectors``; a single vector is returned as is."""
    bound = vectors[0]
    for vec in vectors[1:]:
        bound = model.bind(bound, vec)
    return bound


def run_vsa_experiment_with_cleanup(vsa_class, dimension, n_recursive, n_bundles):
    """
    Store ``n_bundles`` bound chains in one bundled memory and measure how well
    they can be retrieved (unbind + vocabulary cleanup) and recognized.

    Args:
        vsa_class: Model class. It is instantiated as ``vsa_class(dimension)``
            and must provide ``generate_vector``, ``bind``, ``unbind``,
            ``bundle`` and ``similarity``.
        dimension: Forwarded unchanged to the model constructor.
        n_recursive: Number of bind operations per chain, i.e. every chain is
            built from ``n_recursive + 1`` freshly generated vectors. Must be >= 0.
        n_bundles: Number of chains bundled into the memory. Must be >= 1.

    Returns:
        dict with
        'retrieval_accuracy'       percent of chains whose leaf wins the cleanup,
        'retrieval_avg_similarity' mean cleanup score of the correct leaf,
        'recognition_accuracy'     percent of chains more similar to memory
                                   than a freshly generated wrong chain,
        'elapsed'                  seconds spent on all stages.

    Raises:
        ValueError: if ``n_recursive`` is negative or ``n_bundles`` is < 1.
    """
    if n_recursive < 0:
        raise ValueError(f"n_recursive must be >= 0, got {n_recursive}")
    if n_bundles < 1:
        raise ValueError(f"n_bundles must be >= 1, got {n_bundles}")

    print(f"   -> Running {vsa_class.__name__}: Dim={dimension}, Depth={n_recursive}, Bundles={n_bundles}...")

    model = vsa_class(dimension)

    stored_chains = []   # bound chain per bundle (what goes into memory)
    chain_keys = []      # all vectors of a chain except the last one
    vocab_list = []      # last vector of every chain; row index == bundle index

    # perf_counter is monotonic, so the interval cannot be skewed by clock adjustments
    start_time = time.perf_counter()

    # 1. BUNDLING LOOP
    for _ in range(n_bundles):
        chain_vectors = [model.generate_vector() for _ in range(n_recursive + 1)]
        stored_chains.append(_bind_chain(model, chain_vectors))
        chain_keys.append(chain_vectors[:-1])
        vocab_list.append(chain_vectors[-1])

    # Bundle Memory
    memory = model.bundle(stored_chains)

    # Vocabulary matrix, one L2-normalized leaf per row
    vocab_matrix = F.normalize(torch.stack(vocab_list), p=2, dim=1)

    # 2. RETRIEVAL LOOP
    retrieval_success_count = 0
    retrieval_total_target_sim = 0.0

    for target_idx, keys in enumerate(chain_keys):
        recovered = memory
        for key in keys:
            recovered = model.unbind(key, recovered)

        # 3. FAST BATCH CLEANUP: score the recovered vector against every leaf at once
        recovered = F.normalize(recovered, p=2, dim=0)
        scores = torch.mv(vocab_matrix, recovered)

        if torch.argmax(scores).item() == target_idx:
            retrieval_success_count += 1

        # Accumulate the score of the CORRECT target, whether or not it won
        retrieval_total_target_sim += scores[target_idx].item()

    # 4. RECOGNITION TASK
    # A stored chain counts as recognized when it is more similar to memory
    # than a freshly generated chain of the same length.
    recognition_success_count = 0

    for stored_chain in stored_chains:
        # Built exactly like a stored chain, so depth 0 (a single vector) works
        # too; the previous hand-rolled version indexed an empty key list there.
        wrong_vectors = [model.generate_vector() for _ in range(n_recursive + 1)]
        wrong_chain = _bind_chain(model, wrong_vectors)

        sim_correct = model.similarity(memory, stored_chain)
        sim_wrong = model.similarity(memory, wrong_chain)

        if sim_correct > sim_wrong:
            recognition_success_count += 1

    elapsed = time.perf_counter() - start_time

    return {
        'retrieval_accuracy': (retrieval_success_count / n_bundles) * 100.0,
        'retrieval_avg_similarity': retrieval_total_target_sim / n_bundles,
        'recognition_accuracy': (recognition_success_count / n_bundles) * 100.0,
        'elapsed': elapsed
    }


def run_bsdc_cdt_experiment(vsa_class, dimension, n_recursive, n_bundles):
    """
    BSDC-CDT Special Case (Set membership with retrieval and recognition)

    Builds ``n_bundles`` composites, each obtained by binding ``n_recursive + 1``
    freshly generated items, bundles them into one global memory and then
    scores the first item of every composite directly against that memory
    (no unbinding step).

    Args:
        vsa_class: Model class. It is instantiated as ``vsa_class(dimension)``
            and must provide ``generate_vector``, ``bind``, ``bundle`` and
            ``similarity``.
        dimension: Forwarded unchanged to the model constructor.
        n_recursive: Number of bind operations per composite. Must be >= 0.
        n_bundles: Number of composites bundled into memory. Must be >= 1.

    Returns:
        dict with 'retrieval_accuracy' (percent), 'retrieval_avg_similarity',
        'recognition_accuracy' (percent) and 'elapsed' (seconds).

    Raises:
        ValueError: if ``n_recursive`` is negative or ``n_bundles`` is < 1.
    """
    if n_recursive < 0:
        raise ValueError(f"n_recursive must be >= 0, got {n_recursive}")
    if n_bundles < 1:
        raise ValueError(f"n_bundles must be >= 1, got {n_bundles}")

    print(f"   -> Running {vsa_class.__name__} (Set Check): Dim={dimension}, Items={n_recursive+1}, Bundles={n_bundles}...")

    model = vsa_class(dimension)

    target_items = []        # first item of every composite; doubles as the vocabulary
    composite_vectors = []   # one bound composite per bundle

    # perf_counter is monotonic, so the interval cannot be skewed by clock adjustments
    start_time = time.perf_counter()

    # 1. GENERATION & BINDING LOOP
    for _ in range(n_bundles):
        items = [model.generate_vector() for _ in range(n_recursive + 1)]

        # Fold bind over the items: Z = Item1 U Item2 U ...
        composite_vector = items[0]
        for item in items[1:]:
            composite_vector = model.bind(composite_vector, item)

        composite_vectors.append(composite_vector)
        target_items.append(items[0])

    # 2. BUNDLING STEP
    global_memory = model.bundle(composite_vectors)

    # 3. RETRIEVAL TASK
    # The similarity of each vocabulary item to the memory does not depend on
    # which target is being checked, so the score vector is computed once
    # instead of once per target (n similarity calls instead of n * n).
    scores = torch.zeros(len(target_items), device=target_items[0].device)
    for i, vocab_item in enumerate(target_items):
        scores[i] = model.similarity(global_memory, vocab_item)
    best_idx = torch.argmax(scores).item()

    # NOTE(review): because the scores are query-independent, best_idx is the
    # same for every target and at most one target can count as retrieved.
    # Kept as is to preserve the reported metric -- confirm this is intended.
    retrieval_success_count = 0
    retrieval_total_target_sim = 0.0
    for idx in range(len(target_items)):
        if best_idx == idx:
            retrieval_success_count += 1
        retrieval_total_target_sim += scores[idx].item()

    # 4. RECOGNITION TASK
    # A stored item counts as recognized when it is more similar to the memory
    # than a freshly generated distractor.
    recognition_success_count = 0

    for target_item in target_items:
        distractor = model.generate_vector()

        sim_target = model.similarity(global_memory, target_item)
        sim_distractor = model.similarity(global_memory, distractor)

        if sim_target > sim_distractor:
            recognition_success_count += 1

    elapsed = time.perf_counter() - start_time

    return {
        'retrieval_accuracy': (retrieval_success_count / n_bundles) * 100.0,
        'retrieval_avg_similarity': retrieval_total_target_sim / n_bundles,
        'recognition_accuracy': (recognition_success_count / n_bundles) * 100.0,
        'elapsed': elapsed
    }

# ==========================================
# Grid Search Engine
# ==========================================

def run_grid_search(grid, n_trials, project_name):
    """
    Iterates over the grid of parameters.
    For each combination, runs 'n_trials'.
    Logs mean/std to WandB.

    One WandB run is opened per configuration. Every trial is logged as
    ``trial_<metric>`` values plus its 1-based ``trial`` number; after the
    trials, ``mean_<metric>`` / ``std_<metric>`` (NumPy defaults, i.e.
    population standard deviation) are logged and printed.

    Args:
        grid: Mapping of parameter name -> list of candidate values. Each
            combination must provide 'model_class', 'dimension', 'depth' and
            'n_bundles'; the Cartesian product follows the mapping's order.
        n_trials: Number of repetitions per configuration.
        project_name: WandB project the runs are filed under.
    """
    metric_names = (
        'retrieval_accuracy',
        'retrieval_avg_similarity',
        'recognition_accuracy',
        'elapsed',
    )

    # Generate all combinations of parameters
    param_names = list(grid.keys())
    combinations = list(itertools.product(*grid.values()))
    n_configs = len(combinations)

    print("\n=== Starting Grid Search ===")
    print(f"Total Configurations: {n_configs}")
    print(f"Trials per Config: {n_trials}")
    print(f"Total Runs: {n_configs * n_trials}\n")

    for i, config_values in enumerate(combinations):
        # Reconstruct the config dictionary for this run
        config = dict(zip(param_names, config_values))

        model_cls = config['model_class']
        dim = config['dimension']
        depth = config['depth']
        bundles = config['n_bundles']

        model_name = model_cls.__name__

        print(f"[{i+1}/{n_configs}] Testing {model_name} | Dim:{dim} | Depth:{depth} | Bundles:{bundles}")

        # BSDC_CDT is evaluated with the set-membership runner, everything
        # else with the unbind + cleanup runner. Decided once per config.
        if model_name == "BSDC_CDT":
            experiment = run_bsdc_cdt_experiment
        else:
            experiment = run_vsa_experiment_with_cleanup

        # We start one run for the set of trials to aggregate them cleanly
        run_name = f"{model_name}_D{dim}_L{depth}_B{bundles}"
        with wandb.init(
            project=project_name,
            name=run_name,
            config=config,
            resume="allow",
            settings=wandb.Settings(init_timeout=300)
        ):

            trial_metrics = {name: [] for name in metric_names}

            # --- TRIAL LOOP ---
            for t in range(n_trials):
                results = experiment(model_cls, dim, depth, bundles)

                # Log Trial Data; only known metrics feed the aggregate stats
                log_dict = {"trial": t + 1}
                for key, value in results.items():
                    log_dict[f"trial_{key}"] = value
                    if key in trial_metrics:
                        trial_metrics[key].append(value)

                wandb.log(log_dict)

            # --- AGGREGATE STATS ---
            summary_log = {}
            for metric_name, values in trial_metrics.items():
                if values:  # Only compute stats for metrics that exist
                    summary_log[f"mean_{metric_name}"] = np.mean(values)
                    summary_log[f"std_{metric_name}"] = np.std(values)

            # Log Summary to WandB
            wandb.log(summary_log)

            # Print summary. Every line is guarded the same way as the
            # aggregation above: np.mean([]) would print nan and emit a
            # RuntimeWarning when no trial produced the metric (n_trials == 0).
            ret_acc = trial_metrics['retrieval_accuracy']
            ret_sim = trial_metrics['retrieval_avg_similarity']
            rec_acc = trial_metrics['recognition_accuracy']
            time_vals = trial_metrics['elapsed']

            if ret_acc and ret_sim:
                print(f"   >>> Retrieval Acc={np.mean(ret_acc):.2f}% (±{np.std(ret_acc):.2f}) | Sim={np.mean(ret_sim):.4f} (±{np.std(ret_sim):.4f})")

            if rec_acc:
                print(f"   >>> Recognition Acc={np.mean(rec_acc):.2f}% (±{np.std(rec_acc):.2f})")

            if time_vals:
                print(f"   >>> Time={np.mean(time_vals):.2f}s (±{np.std(time_vals):.2f})\n")

        # Small delay to avoid WandB rate limiting
        if i < n_configs - 1:  # Don't sleep after last run
            time.sleep(1)

# ==========================================
# Main Configuration
# ==========================================

if __name__ == "__main__":

    # 1. Warm up CUDA, or warn that everything will run on the CPU
    if not torch.cuda.is_available():
        print("WARNING: CUDA not available. Running on CPU (will be slow).")
    else:
        print(f"CUDA Device: {torch.cuda.get_device_name(0)}")
        torch.zeros(1).cuda()

    # 2. Sweeps, executed in order. Each entry is (WandB project, grid);
    #    the key order inside a grid fixes the order of the Cartesian product.
    SWEEPS = [
        # All models: two dimensions, shallow vs. deep chains
        (
            "VSA_Comparison",
            {
                'model_class': [
                    HLB, MBAT, FHRR, GHRR, HRR, VTB, BSC, BSDC_S, BSDC_SEG, BSDC_CDT, MAP_B, MAP_C, MAP_I, CGR
                ],
                'dimension': [4096, 8100],
                'depth':     [1, 48],
                # 1, then 10..90 in steps of 10, then 100..1000 in steps of 100
                'n_bundles': [1, *range(10, 100, 10), *range(100, 1001, 100)],
            },
        ),
        # HLB only: cubic dimensions 25^3 .. 80^3
        (
            "Memory_Comp3",
            {
                'model_class': [HLB],
                'dimension': [side ** 3 for side in range(25, 81, 5)],
                'depth':     [1],
                # 1, then 100..1900 in steps of 100, then 2000..10000 in steps of 1000
                'n_bundles': [1, *range(100, 2000, 100), *range(2000, 10001, 1000)],
            },
        ),
        # HLB only: square dimensions 25^2 .. 200^2
        (
            "Memory_Comp2",
            {
                'model_class': [HLB],
                'dimension': [side ** 2 for side in range(25, 201, 25)],
                'depth':     [1],
                # 1, then 20..280 in steps of 20, then 300..1000 in steps of 100
                'n_bundles': [1, *range(20, 300, 20), *range(300, 1001, 100)],
            },
        ),
    ]

    for project, GRID in SWEEPS:
        run_grid_search(GRID, n_trials=5, project_name=project)