import numpy as np
import os
import time
import torch
from sklearn.model_selection import train_test_split
import pdb

# ----------------------------------------------------------------------------------
# GPU Acceleration Setup
# ----------------------------------------------------------------------------------
try:
    import cupy as cp
    # Importing CuPy only proves the package is installed; make sure a CUDA
    # device is actually reachable before routing array work through it.
    if cp.cuda.runtime.getDeviceCount() < 1:
        raise RuntimeError("no CUDA device visible to CuPy")
    has_cupy = True
    has_gpu = True
    print("GPU acceleration available: CuPy detected")
except ImportError:
    cp = None
    has_cupy = False
    has_gpu = False
    print("CuPy not available. Using NumPy for array operations.")
except Exception as e:  # e.g. CUDA driver/runtime errors raised by getDeviceCount()
    cp = None
    has_cupy = False
    has_gpu = False
    print(f"CuPy is installed but no usable GPU was found ({e}). Using NumPy for array operations.")

# Check for PyTorch GPU
try:
    _torch_cuda_ok = torch.cuda.is_available()
    torch_device = torch.device("cuda:0" if _torch_cuda_ok else "cpu")
    if _torch_cuda_ok:
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
        has_gpu = True
    else:
        print("PyTorch couldn't find a GPU - falling back to CPU")
        has_gpu = has_cupy  # Only True if CuPy is available
except Exception as e:
    print(f"Error initializing PyTorch GPU: {e}. Falling back to CPU.")
    # Always leave a valid device behind so later code can rely on the name
    torch_device = torch.device("cpu")
    has_gpu = has_cupy  # Only True if CuPy is available

# ----------------------------------------------------------------------------------
# Main Processor Class
# ----------------------------------------------------------------------------------
class AptosProcessor:
    """Compare training vs. validation cosine similarity of Aptos theta vectors.

    Workflow (see ``run_experiment``): load theta vectors, split them into a
    training set ``theta_T`` and a validation set ``theta_V``, then use every
    training vector in turn as a query ``Q``:

    * ρ_j^k - leave-one-out mean cosine similarity of ``Q`` to the training set;
    * ρ_j^m - cosine similarity of ``Q`` to each validation sample;
    * σ_j^m = ρ_j^m - ρ_avg, pooled over all queries into one empirical
      confidence interval.

    Array work runs on CuPy when ``use_gpu`` is requested and CuPy is usable,
    otherwise on NumPy; ``self.xp`` holds the active array module.
    """

    def __init__(self,
                 theta_path="/drive2/Kuntal/Pysindy-experiment/aptos_theta_data/aptos_all_thetas.npy",
                 ids_path="/drive2/Kuntal/Pysindy-experiment/aptos_theta_data/aptos_theta_ids.npy",
                 train_size=0.6,
                 random_state=42,
                 use_gpu=True):
        """
        Initialize the Aptos data processor.

        Parameters:
        -----------
        theta_path : str
            Path to the numpy file containing theta values
        ids_path : str
            Path to the numpy file containing theta IDs
        train_size : float
            Fraction of samples used for training (default: 0.6 for a 60/40 split)
        random_state : int
            Random seed for the train/validation split and for dummy data
        use_gpu : bool
            Whether to use GPU acceleration if available; falls back to NumPy
            when CuPy/GPU support is missing
        """
        self.theta_path = theta_path
        self.ids_path = ids_path
        self.train_size = train_size
        self.random_state = random_state
        self.data = None  # All theta vectors, set by load_data()
        self.labels = None  # Theta IDs aligned with self.data, set by load_data()
        self.theta_T = None  # Training parameters
        self.theta_V = None  # Validation parameters
        self.theta_shape = None  # Shape of a single theta vector

        # Only use the GPU when it was requested AND CuPy is actually usable
        self.use_gpu = bool(use_gpu and has_gpu and has_cupy)

        # Array module used for all numeric work: CuPy on GPU, NumPy otherwise
        self.xp = cp if self.use_gpu else np

    # ----------------------------------------------------------------------------------
    # Array Helpers
    # ----------------------------------------------------------------------------------
    def _to_device(self, array):
        """Return ``array`` in the active array module (CuPy on GPU, else NumPy)."""
        return cp.asarray(array) if self.use_gpu else np.asarray(array)

    def _to_host(self, array):
        """Return ``array`` as a host-side NumPy array."""
        return cp.asnumpy(array) if self.use_gpu else np.asarray(array)

    def _as_rows(self, samples):
        """Return a 2-D ``(n_samples, n_features)`` reshape of ``samples``.

        The reshape is local, so the stored training/validation arrays keep
        their original shape.
        """
        samples = self.xp.asarray(samples)
        n_features = int(np.prod(samples.shape[1:], dtype=np.int64))
        return samples.reshape(samples.shape[0], n_features)

    def _cosine_similarities(self, Q, samples):
        """Cosine similarity between query ``Q`` and every row of 2-D ``samples``.

        ``Q`` is flattened and must have as many elements as one row of
        ``samples``. A small epsilon in the denominator keeps all-zero vectors
        from producing NaN/inf.
        """
        xp = self.xp
        query = xp.asarray(Q).reshape(1, -1)
        if query.shape[1] != samples.shape[1]:
            raise ValueError(
                f"Query has {query.shape[1]} elements but samples have "
                f"{samples.shape[1]} features per row"
            )
        dot_products = xp.sum(query * samples, axis=1)
        norms = xp.linalg.norm(query) * xp.linalg.norm(samples, axis=1)
        return dot_products / (norms + 1e-8)

    # ----------------------------------------------------------------------------------
    # Data Loading and Preparation
    # ----------------------------------------------------------------------------------
    def load_data(self):
        """
        Load the Aptos dataset from theta files.
        If the files are missing or cannot be read, generate dummy data instead.

        Returns:
        --------
        data : array
            All theta vectors (first axis indexes samples)
        labels : array
            Theta IDs aligned with ``data``
        """
        try:
            if os.path.exists(self.theta_path) and os.path.exists(self.ids_path):
                print(f"Loading theta data from {self.theta_path}")
                thetas_np = np.load(self.theta_path)
                print(f"Loading theta IDs from {self.ids_path}")
                ids_np = np.load(self.ids_path)

                print(f"Loaded {len(thetas_np)} theta vectors with shape {thetas_np.shape}")
                self.theta_shape = thetas_np.shape[1:]  # Shape of one theta vector

                self.data = self._to_device(thetas_np)
                self.labels = self._to_device(ids_np)
                if self.use_gpu:
                    print("Data transferred to GPU")
                return self.data, self.labels
        except Exception as e:
            # Best-effort loading: any read/transfer failure falls back to dummy data below
            print(f"Error loading .npy files: {e}")
            print("Falling back to dummy data")

        print("Using generated dummy data.")
        # Seeded so repeated runs on dummy data are reproducible
        rng = np.random.RandomState(self.random_state)
        self.data = self._to_device(rng.rand(3000, 20))
        self.labels = self._to_device(rng.randint(0, 5, size=3000))
        self.theta_shape = (20,)  # Single dimension
        return self.data, self.labels

    def split_data(self):
        """
        Split the loaded data into training (``theta_T``) and validation (``theta_V``) sets.

        The split is done by scikit-learn's ``train_test_split`` on host (NumPy)
        arrays using ``self.train_size`` and ``self.random_state``; the results
        are moved back to the GPU when it is in use.

        Raises:
        -------
        RuntimeError
            If ``load_data()`` has not been called yet.
        """
        if self.data is None or self.labels is None:
            raise RuntimeError("No data loaded; call load_data() before split_data()")

        X_train_np, X_valid_np, _, _ = train_test_split(
            self._to_host(self.data),
            self._to_host(self.labels),
            train_size=self.train_size,
            random_state=self.random_state,
        )

        self.theta_T = self._to_device(X_train_np)
        self.theta_V = self._to_device(X_valid_np)

        print(f"Training set: {len(X_train_np)} samples")
        print(f"Validation set: {len(X_valid_np)} samples")

        return self.theta_T, self.theta_V

    # ----------------------------------------------------------------------------------
    # Core Statistical Calculations
    # ----------------------------------------------------------------------------------
    def calculate_rho_k_gpu(self, Q_k):
        """
        Leave-one-out mean cosine similarity of ``Q_k`` against the training set.

        For each training sample j, ρ_j^k is the mean cosine similarity between
        ``Q_k`` and every training sample i != j.

        Parameters:
        -----------
        Q_k : array
            Query vector; any shape whose size equals one flattened theta vector

        Returns:
        --------
        rho_j_k : array, shape (n_training,)
            ρ_j^k for each training sample j
        rho_avg : scalar
            Mean of ``rho_j_k`` (algebraically equal to the mean similarity over
            the whole training set)

        Raises:
        -------
        ValueError
            If the training set has fewer than two samples (leave-one-out is
            undefined) or ``Q_k`` does not match the theta size.
        """
        xp = self.xp
        theta_T = self._as_rows(self.theta_T)
        n_samples = theta_T.shape[0]
        if n_samples < 2:
            raise ValueError("Leave-one-out ρ_k needs at least two training samples")

        similarities = self._cosine_similarities(Q_k, theta_T)
        # Sum over i != j equals the total minus sample j's own term; vectorised
        # instead of recomputing the total inside a per-sample loop.
        rho_j_k = (xp.sum(similarities) - similarities) / (n_samples - 1)
        rho_avg = xp.mean(rho_j_k)
        return rho_j_k, rho_avg

    def calculate_rho_m_gpu(self, Q_m, samples=None):
        """
        Cosine similarity of ``Q_m`` against each held-out sample.

        Parameters:
        -----------
        Q_m : array
            Query vector; any shape whose size equals one flattened theta vector
        samples : array, optional
            Samples to compare against (first axis indexes samples). Defaults to
            the validation set ``self.theta_V``.

        Returns:
        --------
        rho_m : scalar
            Mean of all ρ_j^m values
        rho_j_m : array, shape (n_samples,)
            Cosine similarity of ``Q_m`` with each sample j

        Raises:
        -------
        ValueError
            If there are no samples or ``Q_m`` does not match the theta size.
        """
        if samples is None:
            samples = self.theta_V
        rows = self._as_rows(samples)
        if rows.shape[0] == 0:
            raise ValueError("Cannot compute ρ_m on an empty sample set")

        rho_j_m = self._cosine_similarities(Q_m, rows)
        rho_m = self.xp.mean(rho_j_m)
        return rho_m, rho_j_m

    def calculate_sigma_m(self, rho_m, rho_avg):
        """
        Calculate σ_m = ρ_m - ρ_avg
        Using the ρ_m value and the average ρ_k value
        """
        return rho_m - rho_avg

    def empirical_confidence_interval(self, subtracted_sigma_j_m_values, alpha=0.05):
        """
        Calculate empirical confidence interval for σ_m
        based on the formula from the whiteboard.

        Parameters:
        -----------
        subtracted_sigma_j_m_values : array
            σ_j^m values (ρ_j^m - ρ_avg) for each validation sample j; flattened
            before sorting
        alpha : float
            Significance level (default: 0.05 for 95% confidence)

        Returns:
        --------
        confidence_interval : list
            Lower and upper bounds [-|selected_sigma_m|, |selected_sigma_m|]
        selected_sigma_m : float
            The sorted σ_j^m value at the formula index

        Raises:
        -------
        ValueError
            If no σ_j^m values are given.
        """
        sorted_sigma_m = np.sort(self._to_host(subtracted_sigma_j_m_values).ravel())
        n_valid = len(sorted_sigma_m)
        if n_valid == 0:
            raise ValueError("Cannot build a confidence interval from an empty set of σ_j^m values")

        # NOTE(review): (n/2 + 1) * (1 - alpha) lands near the middle of the sorted
        # values (about 0.475 * n for alpha=0.05), not in a tail; confirm against
        # the intended whiteboard formula.
        formula_index = int((n_valid / 2 + 1) * (1 - alpha))
        formula_index = min(max(0, formula_index), n_valid - 1)

        selected_sigma_m = sorted_sigma_m[formula_index]

        # Symmetric interval: negative on the left, positive on the right
        lower_bound = -abs(selected_sigma_m)
        upper_bound = abs(selected_sigma_m)

        print(f"Selected index using formula: {formula_index} (of {len(sorted_sigma_m)} values)")
        print(f"Selected σ_m value: {selected_sigma_m}")
        print(f"Confidence interval: [{lower_bound}, {upper_bound}]")

        return [lower_bound, upper_bound], selected_sigma_m

    # ----------------------------------------------------------------------------------
    # Main Experiment Runner
    # ----------------------------------------------------------------------------------
    def run_experiment(self, normalize_theta=True):
        """
        Run the experiment using each theta in the training set as Q.

        For every training vector Q this computes ρ_avg (leave-one-out on the
        training set), ρ_m (on the validation set) and σ_m = ρ_m - ρ_avg, then
        pools the per-validation-sample σ_j^m values into one confidence
        interval. Data is loaded and split first if that has not happened yet.

        Parameters:
        -----------
        normalize_theta : bool
            Whether to scale each Q to unit length (all-zero vectors are left
            unchanged). Cosine similarity is scale-invariant, so this has
            essentially no effect on the similarities themselves.

        Returns:
        --------
        dict with keys 'rho_avgs', 'rho_ms', 'sigma_ms' (per-Q arrays in the
        active array module), 'confidence_interval', 'selected_sigma_m',
        'avg_rho_k' (float) and 'runtime' (seconds).
        """
        if self.theta_T is None or self.theta_V is None:
            self.load_data()
            self.split_data()

        xp = self.xp

        print(f"Running experiment with {'GPU' if self.use_gpu else 'CPU'}...")
        start_time = time.time()

        n_training = len(self.theta_T)
        n_validation = len(self.theta_V)
        print(f"Will process {n_training} training samples as Q vectors")
        print(f"Each Q will be evaluated against {n_validation} validation samples")

        # Per-Q results
        rho_avgs = xp.zeros(n_training)
        rho_ms = xp.zeros(n_training)
        sigma_ms = xp.zeros(n_training)

        # Host-side σ_j^m arrays, one per Q, pooled for the confidence interval
        all_subtracted_sigma_j_m_values = []

        for i, theta_as_q in enumerate(self.theta_T):
            if i % 100 == 0:
                print(f"Processing theta {i+1}/{n_training}...")

            theta_flat = theta_as_q.flatten()

            if normalize_theta:
                norm = xp.linalg.norm(theta_flat)
                # Skip zero vectors rather than turning them into NaN
                if norm > 0:
                    theta_flat = theta_flat / norm

            # ρ_k: leave-one-out on the training set
            _, rho_avg = self.calculate_rho_k_gpu(theta_flat)
            rho_avgs[i] = rho_avg

            # ρ_m: held-out validation set
            rho_m, rho_j_m = self.calculate_rho_m_gpu(theta_flat)
            rho_ms[i] = rho_m

            sigma_ms[i] = self.calculate_sigma_m(rho_m, rho_avg)

            # σ_j^m for each validation sample, moved to the host for pooling
            all_subtracted_sigma_j_m_values.append(self._to_host(rho_j_m - rho_avg))

            if i % 100 == 0:
                print(f"  - Processed {n_validation} validation samples for this theta")

        avg_rho_k = xp.mean(rho_avgs)

        # Pool all subtracted sigma_j_m values for one overall confidence interval
        pooled_subtracted_sigma_j_m = np.concatenate(all_subtracted_sigma_j_m_values)
        print(f"Total number of subtracted sigma_j_m values for confidence interval: {len(pooled_subtracted_sigma_j_m)}")

        confidence_interval, selected_sigma_m = self.empirical_confidence_interval(pooled_subtracted_sigma_j_m)

        print("\nExperiment Settings:")
        print(f"- Using {'GPU' if self.use_gpu else 'CPU'} acceleration")
        if normalize_theta:
            print("- Theta vectors normalized to unit length")

        print(f"\nResults (averaged over {n_training} theta vectors):")
        print(f"- Average ρ_avg (training): {avg_rho_k}")
        print(f"- Overall confidence interval: {confidence_interval}")

        return {
            'rho_avgs': rho_avgs,
            'rho_ms': rho_ms,
            'sigma_ms': sigma_ms,
            'confidence_interval': confidence_interval,
            'selected_sigma_m': selected_sigma_m,
            'avg_rho_k': float(avg_rho_k),
            'runtime': time.time() - start_time
        }

# ----------------------------------------------------------------------------------
# Main Execution
# ----------------------------------------------------------------------------------
if __name__ == "__main__":
    # Example usage with explicit paths
    processor = AptosProcessor(
        theta_path="/drive2/Kuntal/Pysindy-experiment/aptos_theta_data/aptos_all_thetas.npy",
        ids_path="/drive2/Kuntal/Pysindy-experiment/aptos_theta_data/aptos_theta_ids.npy",
        use_gpu=True  # Will automatically fall back to CPU if GPU is not available
    )

    # Run the experiment, using every training theta in turn as the query vector Q
    results = processor.run_experiment(
        normalize_theta=True   # Normalize theta vectors to unit length
    )

    print("\nExperiment Results Summary:")
    print(f"Total runtime: {results['runtime']:.2f} seconds")
    print(f"Average ρ_avg (training): {results['avg_rho_k']}")
    print(f"Overall confidence interval: {results['confidence_interval']}")

    # The interval is built as [-|σ|, |σ|], so it always contains 0 and a
    # "0 in interval" test could never report a difference. Instead check
    # whether the observed σ_m (validation ρ_m minus training ρ_avg, averaged
    # over all Q vectors) lies inside the interval.
    lower_bound, upper_bound = results['confidence_interval']
    observed_sigma_m = float(results['sigma_ms'].mean())
    print(f"Observed mean σ_m: {observed_sigma_m}")
    if lower_bound <= observed_sigma_m <= upper_bound:
        print("\nConclusion: No significant difference between training and validation performance.")
    else:
        print("\nConclusion: Significant difference detected between training and validation performance.")
