#!/usr/bin/env python
"""
Runtime and memory profiling for QADRIFT components.
"""

import numpy as np
import time
import psutil
import os
from qisk_implementation import QISK
from real_world_datasets import get_real_world_datasets
from physically_correct_quantum_kernel import PhysicallyCorrectQuantumKernel, StreamingNystromApproximation
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
import json

class RuntimeProfiler:
    """Profile runtime and memory usage of QADRIFT components.

    Timings are taken with ``time.perf_counter`` (monotonic, high resolution)
    rather than the wall clock, so they cannot go negative when the system
    clock is adjusted. Memory figures are differences in the resident set
    size (RSS) of the current process, expressed in MB.
    """

    # Immutable defaults; the public methods take ``None`` and fall back to
    # these so no mutable object is shared between calls.
    _DEFAULT_KERNEL_SAMPLES = (50, 100, 200, 400)
    _DEFAULT_KERNEL_ANCHORS = (8, 16, 32)
    _DEFAULT_MEMORY_ANCHORS = (8, 16, 32, 64)

    def __init__(self):
        self.process = psutil.Process(os.getpid())

    def get_memory_usage(self):
        """Get current memory usage (RSS of this process) in MB."""
        return self.process.memory_info().rss / 1024 / 1024

    @staticmethod
    def _ratio(numerator, denominator, default):
        """Return ``numerator / denominator``, or ``default`` if the denominator is not positive."""
        return numerator / denominator if denominator > 0 else default

    def profile_kernel_computation(self, n_samples_list=None, n_anchors_list=None):
        """Profile kernel computation time vs problem size.

        For every (n_samples, n_anchors) pair, times the full kernel matrix
        and the Nyström fit / transform / reconstruction on seeded random
        data, and records the relative Frobenius error of the reconstruction.

        Args:
            n_samples_list: Sample counts to test. Defaults to
                ``[50, 100, 200, 400]`` when ``None``.
            n_anchors_list: Anchor counts to test. Defaults to
                ``[8, 16, 32]`` when ``None``.

        Returns:
            A list with one dict per combination, holding the timings in
            seconds, the RSS delta in MB, the approximation error and the
            speedup factor. The speedup is ``inf`` if the Nyström stages
            took no measurable time.
        """
        if n_samples_list is None:
            n_samples_list = self._DEFAULT_KERNEL_SAMPLES
        if n_anchors_list is None:
            n_anchors_list = self._DEFAULT_KERNEL_ANCHORS

        print("Profiling Kernel Computation")
        print("=" * 40)

        results = []

        for n_samples in n_samples_list:
            for n_anchors in n_anchors_list:
                print(f"Testing n_samples={n_samples}, n_anchors={n_anchors}")

                # Generate test data (re-seeded so every combination is reproducible)
                np.random.seed(42)
                X = np.random.randn(n_samples, 4)

                # Create kernel and Nyström approximation
                kernel = PhysicallyCorrectQuantumKernel(n_qubits=4)
                nystrom = StreamingNystromApproximation(kernel, n_anchors=n_anchors)

                # Time full kernel matrix computation
                start_time = time.perf_counter()
                K_full = kernel.compute_kernel_matrix(X, X)
                full_kernel_time = time.perf_counter() - start_time

                # Sample memory before starting the clock so the psutil call
                # is not billed to the Nyström fit.
                start_memory = self.get_memory_usage()

                # Time Nyström approximation, stage by stage
                start_time = time.perf_counter()
                nystrom.fit(X)
                fit_time = time.perf_counter() - start_time

                start_time = time.perf_counter()
                Phi = nystrom.transform(X)
                transform_time = time.perf_counter() - start_time

                start_time = time.perf_counter()
                K_approx = Phi @ Phi.T
                reconstruct_time = time.perf_counter() - start_time

                end_memory = self.get_memory_usage()
                memory_usage = end_memory - start_memory

                total_nystrom_time = fit_time + transform_time + reconstruct_time

                # Relative Frobenius error; cast to a plain float so the
                # result stays JSON-serialisable whatever dtype the kernel uses.
                approx_error = float(
                    np.linalg.norm(K_full - K_approx, 'fro') / np.linalg.norm(K_full, 'fro')
                )

                result = {
                    'n_samples': n_samples,
                    'n_anchors': n_anchors,
                    'full_kernel_time': full_kernel_time,
                    'nystrom_fit_time': fit_time,
                    'nystrom_transform_time': transform_time,
                    'nystrom_reconstruct_time': reconstruct_time,
                    'total_nystrom_time': total_nystrom_time,
                    'memory_usage_mb': memory_usage,
                    'approximation_error': approx_error,
                    'speedup_factor': self._ratio(
                        full_kernel_time, total_nystrom_time, float('inf')
                    ),
                }

                results.append(result)

                print(f"  Full kernel: {full_kernel_time:.3f}s")
                print(f"  Nyström:     {total_nystrom_time:.3f}s")
                print(f"  Speedup:     {result['speedup_factor']:.1f}x")
                print(f"  Error:       {approx_error:.4f}")
                print()

        return results

    def profile_per_window_components(self, window_size=200):
        """Profile per-window runtime breakdown.

        Times preprocessing, DRO weighting, Nyström fitting, a shortened
        SPSA loop, SVM training and prediction for one training window.

        Args:
            window_size: Number of leading samples used as the training
                window; the following 50 samples form the test set.

        Returns:
            A dict mapping each component name to its elapsed time in seconds.
        """
        # Imported before any timer starts so that import cost is never
        # attributed to the SPSA stage.
        from drqka_lite_implementation import center_kernel_weighted, weighted_kernel_target_alignment

        print("Profiling Per-Window Component Times")
        print("=" * 40)

        test_size = 50
        # 400 samples as before, grown only when the window would otherwise
        # leave fewer than `test_size` samples for testing.
        n_samples = max(400, window_size + test_size)

        # Generate test data
        # NOTE(review): DataGenerator and DRQKALite are not imported anywhere
        # in this file, so these lines raise NameError as written — confirm
        # which module provides them and import it.
        X, y, _ = DataGenerator.generate_sea(n_samples=n_samples, seed=42)
        X_train, y_train = X[:window_size], y[:window_size]
        X_test = X[window_size:window_size + test_size]

        # Initialize components
        drqka = DRQKALite(n_qubits=4)

        print(f"Testing with window_size={window_size}, test_size={test_size}")

        # Time components
        timings = {}

        # 1. Preprocessing
        start_time = time.perf_counter()
        X_train_processed = drqka.preprocess_features(X_train, fit=True)
        X_test_processed = drqka.preprocess_features(X_test, fit=False)
        timings['preprocessing'] = time.perf_counter() - start_time

        # 2. DRO weighting (simulate with previous data)
        start_time = time.perf_counter()
        X_prev = np.random.randn(window_size, 4)
        X_prev_processed = drqka.preprocess_features(X_prev, fit=False)
        weights = drqka.dro_weighting.compute_weights(X_train_processed, X_prev_processed)
        timings['dro_weighting'] = time.perf_counter() - start_time

        # 3. Nyström fitting
        start_time = time.perf_counter()
        drqka.nystrom.fit(X_train_processed)
        timings['nystrom_fitting'] = time.perf_counter() - start_time

        # 4. SPSA optimization (simplified)
        start_time = time.perf_counter()
        for _ in range(5):  # Reduced iterations for profiling
            # Simulate parameter update
            current_params = drqka.kernel.get_parameters()
            perturbation = np.random.normal(0, 0.1, 4)
            new_params = np.maximum(0.1, current_params + perturbation)
            drqka.kernel.update_parameters(new_params)

            # Recompute kernel and KTA; the alignment value itself is not
            # needed, only the cost of evaluating it is being measured.
            K_train = drqka.nystrom.get_kernel_matrix(X_train_processed)
            K_centered = center_kernel_weighted(K_train, weights)
            weighted_kernel_target_alignment(K_centered, y_train, weights)
        timings['spsa_optimization'] = time.perf_counter() - start_time

        # 5. SVM training
        start_time = time.perf_counter()
        drqka.nystrom.fit(X_train_processed)  # Refit with final params
        Phi_train = drqka.nystrom.transform(X_train_processed)
        K_final = Phi_train @ Phi_train.T
        classifier = SVC(kernel='precomputed', random_state=42)
        classifier.fit(K_final, y_train, sample_weight=weights)
        timings['svm_training'] = time.perf_counter() - start_time

        # 6. Prediction (the predictions are discarded; only the time matters)
        start_time = time.perf_counter()
        Phi_test = drqka.nystrom.transform(X_test_processed)
        K_test = Phi_test @ Phi_train.T
        classifier.predict(K_test)
        timings['prediction'] = time.perf_counter() - start_time

        # Total time
        total_time = sum(timings.values())

        print("Component breakdown:")
        for component, timing in timings.items():
            percentage = self._ratio(timing, total_time, 0.0) * 100
            print(f"  {component:20}: {timing:.3f}s ({percentage:.1f}%)")

        print(f"  {'Total':20}: {total_time:.3f}s")

        return timings

    def profile_memory_scaling(self, n_anchors_list=None):
        """Profile memory usage vs number of anchors.

        Args:
            n_anchors_list: Anchor counts to test. Defaults to
                ``[8, 16, 32, 64]`` when ``None``.

        Returns:
            A list with one dict per anchor count, holding the RSS delta in
            MB, the same delta per sample in KB, and the shapes of the
            feature map and of the reconstructed kernel matrix.
        """
        if n_anchors_list is None:
            n_anchors_list = self._DEFAULT_MEMORY_ANCHORS

        print("Profiling Memory Usage vs Anchors")
        print("=" * 40)

        results = []
        window_size = 200

        # Generate test data
        # NOTE(review): DataGenerator is not imported anywhere in this file,
        # so this line raises NameError as written — confirm its module.
        X, _, _ = DataGenerator.generate_sea(n_samples=window_size, seed=42)

        for n_anchors in n_anchors_list:
            print(f"Testing n_anchors={n_anchors}")

            # Measure initial memory
            initial_memory = self.get_memory_usage()

            # Create components
            kernel = PhysicallyCorrectQuantumKernel(n_qubits=4)
            nystrom = StreamingNystromApproximation(kernel, n_anchors=n_anchors)
            scaler = StandardScaler()

            # Process data
            X_scaled = scaler.fit_transform(X)
            nystrom.fit(X_scaled)
            Phi = nystrom.transform(X_scaled)
            K = Phi @ Phi.T

            # RSS after processing. This is a point-in-time sample, not a
            # true peak: transient allocations freed earlier are not seen.
            peak_memory = self.get_memory_usage()
            memory_delta = peak_memory - initial_memory

            result = {
                'n_anchors': n_anchors,
                'memory_usage_mb': memory_delta,
                'memory_per_sample_kb': (memory_delta * 1024) / window_size,
                'feature_map_size': Phi.shape,
                'kernel_matrix_size': K.shape
            }

            results.append(result)

            print(f"  Memory usage: {memory_delta:.1f} MB")
            print(f"  Per sample:   {result['memory_per_sample_kb']:.1f} KB")
            print()

        return results

def main():
    """Run all profiling experiments, save the results as JSON and print a summary.

    Results are written to ``experimental_results/profiling_results.json``
    relative to the current working directory; the directory is created if
    it does not exist.
    """
    output_path = 'experimental_results/profiling_results.json'

    # Create the output directory before the (slow) profiling runs, so a
    # missing directory cannot discard the results at the very end.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    profiler = RuntimeProfiler()

    print("QADRIFT Runtime and Memory Profiling")
    print("=" * 50)

    all_results = {}

    # 1. Kernel computation scaling
    print("\n1. Kernel Computation Scaling")
    kernel_results = profiler.profile_kernel_computation()
    all_results['kernel_scaling'] = kernel_results

    # 2. Per-window component breakdown
    print("\n2. Per-Window Component Breakdown")
    component_timings = profiler.profile_per_window_components()
    all_results['component_breakdown'] = component_timings

    # 3. Memory usage scaling
    print("\n3. Memory Usage Scaling")
    memory_results = profiler.profile_memory_scaling()
    all_results['memory_scaling'] = memory_results

    # Save results
    with open(output_path, 'w') as f:
        json.dump(all_results, f, indent=2)

    # Print summary
    print("\n" + "=" * 50)
    print("PROFILING SUMMARY")
    print("=" * 50)

    # Each summary falls back to the first result when the headline
    # configuration is not among the ones that were profiled.
    print("\nKernel Scaling (200 samples, 16 anchors):")
    example_result = next((r for r in kernel_results if r['n_samples'] == 200 and r['n_anchors'] == 16), kernel_results[0])
    print(f"  Nyström speedup: {example_result['speedup_factor']:.1f}x")
    print(f"  Approximation error: {example_result['approximation_error']:.4f}")

    print("\nPer-Window Timing (200 samples):")
    total_time = sum(component_timings.values())
    for component, timing in component_timings.items():
        # Guard against a zero total so the summary cannot raise ZeroDivisionError.
        share = timing / total_time * 100 if total_time > 0 else 0.0
        print(f"  {component:20}: {timing:.3f}s ({share:.1f}%)")

    print("\nMemory Usage (16 anchors, 200 samples):")
    example_memory = next((r for r in memory_results if r['n_anchors'] == 16), memory_results[0])
    print(f"  Total memory: {example_memory['memory_usage_mb']:.1f} MB")
    print(f"  Per sample: {example_memory['memory_per_sample_kb']:.1f} KB")

    print(f"\nResults saved to: {output_path}")

# Run the full profiling suite only when executed as a script, not on import.
if __name__ == "__main__":
    main()