import os
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple

# Configuration parameters: module-level switches read by
# list_cost_npy_files_in_plots(). When enabled, the steps run on every loaded
# array in this order: clip -> monotonic -> scaling.

# Pass each array, and afterwards each cluster mean, through make_monotonic_decreasing()
MONOTONIC = True
# Clip array values to [CLIP_MIN, CLIP_MAX] before any other transformation
CLIP_VALUES = True
# Lower clipping bound; only used when CLIP_VALUES is True
CLIP_MIN = 0.0
# Upper clipping bound; only used when CLIP_VALUES is True
CLIP_MAX = 10000.0
# Pass each array through apply_min_max_scaling() after the clip and monotonic steps
MIN_MAX_SCALING = True

def make_monotonic_decreasing(arr):
    """
    Return a copy of ``arr`` that is monotonically non-increasing.

    Every element is replaced by the running minimum of all elements up to and
    including it, so a value that rises above an earlier one is flattened down
    to that earlier level. This is the vectorized equivalent of walking the
    array left to right and replacing any element that exceeds its (already
    corrected) predecessor with that predecessor.

    Args:
        arr: Numeric array or array-like, normally 1D. Inputs with more
            dimensions are processed along the first axis.

    Returns:
        A new numpy array with the same shape as ``arr``. The input is never
        modified.

    Note:
        NaN values propagate: once a NaN is encountered, every later element
        of the result is NaN as well.
    """
    values = np.asarray(arr)
    if values.size == 0:
        # Nothing to accumulate; still hand back a fresh array like the
        # non-empty path does.
        return values.copy()
    # Running minimum along the first axis, computed in native code instead of
    # a Python-level loop (cost curves can be long).
    return np.minimum.accumulate(values, axis=0)

def apply_min_max_scaling(arr):
    """
    Scale an array by its maximum so that the largest element becomes 1.

    Despite the name, the minimum is NOT subtracted: every value is simply
    divided by the array maximum. For non-negative input (for example after
    clipping with a lower bound of 0) the result lies in [0, 1]; negative
    values stay negative.

    NOTE(review): a textbook min-max scaling would be
    ``(arr - min) / (max - min)``. The divide-by-max behavior is kept here so
    that numeric results on valid input do not change - confirm which of the
    two is actually intended.

    Edge cases:
        * Empty input is returned as an empty float array.
        * A maximum of exactly 0 cannot be used as a divisor; the values are
          returned unscaled (as floats) instead of producing NaN/inf.

    Args:
        arr: Numeric array or array-like of any shape.

    Returns:
        A new floating-point numpy array with the same shape as ``arr``.
    """
    values = np.asarray(arr)
    if values.size == 0:
        return values.astype(float)

    # np.max works on any shape and runs in native code; the builtin max()
    # iterates element by element and fails on multi-dimensional arrays.
    max_val = np.max(values)

    # Dividing by 1 when the maximum is 0 keeps the values and still yields a
    # float result, without a division-by-zero warning.
    divisor = max_val if max_val != 0 else 1
    return values / divisor

def _find_cost_npy_files(plots_dir):
    """Return sorted (relative path, full path) pairs for every cost*.npy file under plots_dir."""
    found = []
    for root, _dirs, files in os.walk(plots_dir):
        for name in files:
            if name.startswith('cost') and name.endswith('.npy'):
                full_path = os.path.join(root, name)
                found.append((os.path.relpath(full_path, plots_dir), full_path))
    # A single sort by relative path makes the order independent of the
    # traversal order os.walk happens to use.
    found.sort()
    return found


def _load_cost_array(full_path):
    """Load one .npy file; return it if it is a non-empty 1D array, else print why and return None."""
    try:
        # NOTE(security): allow_pickle=True lets np.load unpickle arbitrary
        # objects, which can execute code. Only point this script at files you
        # produced yourself; drop the flag if the cost files are plain numeric
        # arrays.
        array = np.load(full_path, allow_pickle=True)
        print(f"   - Shape: {array.shape}")
        print(f"   - Type: {array.dtype}")
    except Exception as e:
        # Best effort: one unreadable file must not abort the whole batch.
        print(f"   - Error loading file: {e}")
        return None

    # Check the shape before transforming, so that unsuitable arrays get this
    # warning instead of failing somewhere inside a transformation.
    if array.ndim != 1:
        print("   - Warning: Array is not 1D, skipping for statistics")
        return None
    if array.size == 0:
        # Empty arrays cannot be edge-padded or contribute a final value.
        print("   - Warning: Array is empty, skipping for statistics")
        return None
    return array


def _transform_cost_array(array):
    """Apply the enabled clip / monotonic / scaling steps to one 1D array, printing what changed."""
    if CLIP_VALUES:
        clipped_count = np.sum((array < CLIP_MIN) | (array > CLIP_MAX))
        if clipped_count > 0:
            print(f"   - Clipped {clipped_count}/{len(array)} values to range [{CLIP_MIN}, {CLIP_MAX}]")
        array = np.clip(array, CLIP_MIN, CLIP_MAX)

    if MONOTONIC:
        before = array.copy()
        array = make_monotonic_decreasing(array)
        modified_count = np.sum(before != array)
        if modified_count > 0:
            print(f"   - Modified {modified_count}/{len(array)} values to ensure monotonic decrease")

    if MIN_MAX_SCALING:
        low_before, high_before = np.min(array), np.max(array)
        array = apply_min_max_scaling(array)
        # Report the range actually obtained rather than assuming [0, 1].
        print(
            f"   - Min-max scaled: original range [{low_before:.4f}, {high_before:.4f}]"
            f" → [{np.min(array):.4f}, {np.max(array):.4f}]"
        )

    return array


def _cluster_mean_and_std(arrays):
    """Pad 1D arrays to a common length with their last value; return per-position (mean, std)."""
    max_length = max(len(arr) for arr in arrays)
    # A pad width of 0 leaves an array's values untouched, so full-length
    # arrays need no special case.
    padded = [np.pad(arr, (0, max_length - len(arr)), 'edge') for arr in arrays]
    stacked = np.vstack(padded)
    return np.mean(stacked, axis=0), np.std(stacked, axis=0)


def list_cost_npy_files_in_plots(plots_dir=None, output_dir=None, cluster_size=5) -> Tuple[List[Tuple[str, str]], List[dict]]:
    """
    Find cost*.npy files, group them into clusters and save per-cluster statistics.

    All .npy files whose name starts with 'cost' are collected from
    ``plots_dir`` and its subfolders, sorted by their path relative to
    ``plots_dir`` and split into consecutive clusters of ``cluster_size``
    files (the last cluster may be smaller). Each file is loaded and, if it is
    a non-empty 1D array, transformed according to the module flags:

    * If CLIP_VALUES is True, clip values to [CLIP_MIN, CLIP_MAX].
    * If MONOTONIC is True, pass the array through make_monotonic_decreasing().
    * If MIN_MAX_SCALING is True, pass the array through apply_min_max_scaling().

    Within a cluster, shorter arrays are padded with their last value to the
    longest length, then the per-position mean and standard deviation are
    computed (the mean is made monotonic again if MONOTONIC is True) and saved
    as ``cluster_<n>_mean.npy`` and ``cluster_<n>_std.npy`` in ``output_dir``.
    Progress is printed along the way. Files that cannot be loaded or
    processed are reported and skipped.

    Args:
        plots_dir: Folder to scan. Defaults to ``plots/iter7`` inside the
            parent of this file's directory.
        output_dir: Folder the statistics are written to (created if needed).
            Defaults to ``plots/statistics`` inside the parent of this file's
            directory.
        cluster_size: Number of files per cluster; must be at least 1.

    Returns:
        A tuple ``(cost_npy_files, cluster_results)``:
        ``cost_npy_files`` is the sorted list of (relative path, full path)
        pairs; ``cluster_results`` has one dict per cluster with usable arrays,
        with keys 'cluster_idx', 'mean', 'std', 'files', 'clipped',
        'clip_range', 'monotonic' and 'normalized'. Both lists are empty when
        ``plots_dir`` does not exist.

    Raises:
        ValueError: If ``cluster_size`` is smaller than 1.
    """
    if cluster_size < 1:
        raise ValueError("cluster_size must be at least 1")

    # abspath keeps the "two levels up" lookup correct even when __file__ is a
    # relative path.
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    if plots_dir is None:
        plots_dir = os.path.join(project_root, 'plots/iter7')
    if output_dir is None:
        output_dir = os.path.join(project_root, 'plots', 'statistics')

    if not os.path.exists(plots_dir):
        print(f"Error: Directory '{plots_dir}' does not exist.")
        # Same shape as the normal return value, so callers can always unpack.
        return [], []

    cost_npy_files = _find_cost_npy_files(plots_dir)
    total_files = len(cost_npy_files)
    print(f"Found {total_files} cost*.npy files in '{plots_dir}' and its subdirectories:")

    os.makedirs(output_dir, exist_ok=True)

    cluster_results = []
    cluster_starts = range(0, total_files, cluster_size)
    for cluster_number, start in enumerate(cluster_starts, 1):
        end = min(start + cluster_size, total_files)
        cluster_files = cost_npy_files[start:end]

        print(f"\n--- Cluster {cluster_number}: Files {start+1}-{end} ---")

        cluster_arrays = []
        for file_number, (rel_path, full_path) in enumerate(cluster_files, start + 1):
            print(f"{file_number}. {rel_path}")

            array = _load_cost_array(full_path)
            if array is None:
                continue
            try:
                cluster_arrays.append(_transform_cost_array(array))
            except Exception as e:
                # Best effort: skip arrays the transformations cannot handle
                # (e.g. non-numeric dtypes) and keep going.
                print(f"   - Error processing file: {e}")

        if not cluster_arrays:
            print(f"\nCluster {cluster_number}: No valid arrays for statistics calculation")
            continue

        mean_array, std_array = _cluster_mean_and_std(cluster_arrays)

        # Averaging padded curves can reintroduce small rises, so enforce the
        # monotonic shape on the mean as well.
        if MONOTONIC:
            original_mean = mean_array.copy()
            mean_array = make_monotonic_decreasing(mean_array)
            modified_count = np.sum(original_mean != mean_array)
            if modified_count > 0:
                print(f"   - Modified {modified_count}/{len(mean_array)} values in mean array to ensure monotonic decrease")

        print(f"\nCluster {cluster_number} Statistics:")
        print(f"   - Mean shape: {mean_array.shape}")
        print(f"   - Final mean value: {mean_array[-1]:.4f}")
        print(f"   - Final std deviation: {std_array[-1]:.4f}")

        cluster_name = f"cluster_{cluster_number}"
        mean_path = os.path.join(output_dir, f"{cluster_name}_mean.npy")
        std_path = os.path.join(output_dir, f"{cluster_name}_std.npy")
        np.save(mean_path, mean_array)
        np.save(std_path, std_array)

        print(f"Saved mean array to {mean_path}")
        print(f"Saved std array to {std_path}")

        # 'files' lists every file of the cluster, including skipped ones.
        cluster_results.append({
            'cluster_idx': cluster_number,
            'mean': mean_array,
            'std': std_array,
            'files': [rel_path for rel_path, _ in cluster_files],
            'clipped': CLIP_VALUES,
            'clip_range': [CLIP_MIN, CLIP_MAX] if CLIP_VALUES else None,
            'monotonic': MONOTONIC,
            'normalized': MIN_MAX_SCALING,
        })

    return cost_npy_files, cluster_results

if __name__ == "__main__":
    # Script entry point: run the scan/statistics pass, then print a summary.
    # Tolerate a None result (an early exit without a file list) instead of
    # crashing with a TypeError while unpacking.
    result = list_cost_npy_files_in_plots()
    cost_npy_files, cluster_results = result if result is not None else ([], [])

    # Summary
    if cost_npy_files:
        print(f"\nTotal: {len(cost_npy_files)} cost*.npy files")
        print(f"Processed {len(cluster_results)} clusters with statistics")

        # Describe the transformations that were enabled for this run, in the
        # order they are applied.
        settings = []
        if CLIP_VALUES:
            settings.append(f"clipped to range [{CLIP_MIN}, {CLIP_MAX}]")
        if MONOTONIC:
            settings.append("monotonic decreasing")
        if MIN_MAX_SCALING:
            settings.append("min-max scaled")

        if settings:
            print(f"Applied transformations: {', '.join(settings)}")
    else:
        print("No cost*.npy files found")