import argparse
import os
import sys
import numpy as np
from multiprocessing import freeze_support
import matplotlib.pyplot as plt

# Maze dataset core dependencies
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators

# Custom utility for large data handling (ZANJ)
from zanj import ZANJ

# --------------------------
# Global Configuration (Fixed for Reproducibility)
# --------------------------
# Default on-disk root for maze datasets.
# NOTE(review): this constant is not referenced anywhere in the visible part of
# this file; the CLI default for --local_base_path is "data/maze" instead.
LOCAL_DATA_ROOT: str = "data/maze_dataset/"
# Shared ZANJ instance passed as `zanj=` to MazeDataset.from_config in the
# generation pipeline.
# NOTE(review): external_list_threshold=256 presumably controls when long lists
# are stored externally by ZANJ - confirm against the ZANJ documentation.
ZANJ_INSTANCE: ZANJ = ZANJ(external_list_threshold=256)

def str2bool(value: str) -> bool:
    """Convert a command-line string to a boolean (argparse ``type=`` helper).

    Only the first non-whitespace character is inspected, case-insensitively:
    ``0``/``n``/``f`` mean False (e.g. "0", "no", "False") and ``1``/``y``/``t``
    mean True (e.g. "1", "yes", "True"). Values such as "on"/"off" are not
    recognized.

    Args:
        value: Text to interpret. A ``bool`` is returned unchanged so the
            function can also be applied to already-parsed defaults.

    Returns:
        The boolean the text stands for.

    Raises:
        ValueError: If the text is empty, whitespace-only, or starts with an
            unrecognized character.
    """
    if isinstance(value, bool):
        return value
    # Slice instead of indexing so an empty string falls through to the
    # ValueError below rather than raising IndexError.
    first_char = value.strip().lower()[:1]
    if first_char in ('0', 'n', 'f'):
        return False
    if first_char in ('1', 'y', 't'):
        return True
    raise ValueError(f"Invalid boolean value: {value}")

# --------------------------
# Core Maze Dataset Generation Pipeline
# --------------------------
# Channel semantics of the encoded maze tensor (last axis has 3 channels).
# Fixed for model compatibility - critical for reproducibility.
CHANNEL_0_WALL: int = -1    # Channel 0: -1=wall, 1=path (free space)
CHANNEL_0_PATH: int = 1
CHANNEL_1_START: int = 0    # Channel 1: 0=start, 1=end
CHANNEL_1_END: int = 1
CHANNEL_2_SOLUTION: int = 1 # Channel 2: 1=solution path, -1=unused

# Number of leading samples rendered to PNG for qualitative validation.
N_VISUALIZED_SAMPLES: int = 5


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser with every pipeline option and its default."""
    parser = argparse.ArgumentParser(description="Reproducible Maze Dataset Generation (with Visualization and NPY Merging)")

    # Dataset Metadata (Key for Reproducibility)
    parser.add_argument("--dataset_name", type=str, default="Maze5-30000",
                        help="Unique identifier for the dataset")
    parser.add_argument("--grid_n", type=int, default=5,
                        help="Size of square maze lattice (rows = columns = grid_n)")
    parser.add_argument("--n_mazes", type=int, default=30000,
                        help="Total number of mazes to generate")
    parser.add_argument("--maze_ctor", type=str, default="gen_dfs",
                        help="Maze generation algorithm (from LatticeMazeGenerators)")

    # Pipeline Control Flags
    parser.add_argument("--do_download", type=str2bool, default=False,
                        help="Download pre-existing dataset (not used for custom generation)")
    parser.add_argument("--load_local", type=str2bool, default=False,
                        help="Load dataset from local storage")
    parser.add_argument("--do_generate", type=str2bool, default=True,
                        help="Execute custom maze generation (core flag)")
    parser.add_argument("--save_local", type=str2bool, default=False,
                        help="Save raw dataset to local storage")
    parser.add_argument("--gen_parallel", type=str2bool, default=False,
                        help="Enable parallel maze generation (for large n_mazes)")

    # Path and Filter Configuration
    parser.add_argument("--local_base_path", type=str, default="data/maze",
                        help="Root directory for local dataset storage")
    parser.add_argument("--min_length", type=int, default=3,
                        help="Minimum path length filter for generated mazes")
    parser.add_argument("--max_length", type=int, default=10,
                        help="Maximum path length filter for generated mazes")

    # Verbosity and Visualization
    parser.add_argument("--verbose", type=str2bool, default=True,
                        help="Print dataset statistics and progress")
    return parser


def _cell_to_pixel(cell) -> tuple:
    """Map a lattice (row, col) cell to its (y, x) position in the pixel grid.

    Cells sit at odd pixel indices (cell * 2 + 1). Coordinates are cast to
    Python ints so the arithmetic cannot overflow if the library hands back
    fixed-width NumPy integers.
    """
    return int(cell[0]) * 2 + 1, int(cell[1]) * 2 + 1


def _encode_maze(maze_sample) -> np.ndarray:
    """Encode one maze as an int8 array with 3 semantic channels on the last axis.

    Every entry starts at CHANNEL_0_WALL (-1). Channel 0 marks free space,
    channel 1 marks the start (0) and end (1) cells, and channel 2 marks the
    solution cells plus the pixel between each pair of consecutive solution
    cells. Out-of-range coordinates are skipped silently.

    Args:
        maze_sample: Maze object exposing ``_as_pixels_bw()``, ``start_pos``,
            ``end_pos`` and ``solution``.

    Returns:
        Array of shape ``(*pixel_grid.shape, 3)`` with dtype int8.
    """
    binary_grid = maze_sample._as_pixels_bw()  # expected (H, W) bool: True = path
    semantic_grid = np.full((*binary_grid.shape, 3), CHANNEL_0_WALL, dtype=np.int8)
    height, width = semantic_grid.shape[:2]

    def mark(y: int, x: int, channel: int, value: int) -> None:
        # Bounds-checked write; mirrors the defensive checks of the encoding spec.
        if 0 <= y < height and 0 <= x < width:
            semantic_grid[y, x, channel] = value

    # Channel 0: free space.
    semantic_grid[np.equal(binary_grid, True), 0] = CHANNEL_0_PATH

    # Channel 1: start first, then end (end wins if both coincide).
    mark(*_cell_to_pixel(maze_sample.start_pos), 1, CHANNEL_1_START)
    mark(*_cell_to_pixel(maze_sample.end_pos), 1, CHANNEL_1_END)

    # Channel 2: solution cells and the connecting pixel between neighbours.
    solution_cells = [(int(row), int(col)) for row, col in maze_sample.solution]
    for cell in solution_cells:
        mark(*_cell_to_pixel(cell), 2, CHANNEL_2_SOLUTION)
    for (curr_row, curr_col), (next_row, next_col) in zip(solution_cells[:-1], solution_cells[1:]):
        curr_y, curr_x = _cell_to_pixel((curr_row, curr_col))
        mark(curr_y + (next_row - curr_row), curr_x + (next_col - curr_col), 2, CHANNEL_2_SOLUTION)

    return semantic_grid


def _save_visualization(semantic_grid: np.ndarray, start_pixel, end_pixel, title: str, out_path: str) -> None:
    """Render an encoded maze to a PNG (black = wall, white = path, light blue = solution).

    Args:
        semantic_grid: Encoded maze as produced by ``_encode_maze``.
        start_pixel: (y, x) pixel position of the start marker (green).
        end_pixel: (y, x) pixel position of the end marker (red).
        title: Figure title.
        out_path: Destination image path.
    """
    vis_image = np.zeros((*semantic_grid.shape[:2], 3), dtype=np.float32)
    vis_image[semantic_grid[:, :, 0] == CHANNEL_0_PATH] = [1.0, 1.0, 1.0]
    # Solution is painted after the path so it stays visible on top of it.
    vis_image[semantic_grid[:, :, 2] == CHANNEL_2_SOLUTION] = [0.2, 0.7, 0.9]

    start_y, start_x = start_pixel
    end_y, end_x = end_pixel

    plt.figure(figsize=(8, 8))
    try:
        plt.imshow(vis_image)
        plt.scatter(start_x, start_y, color='limegreen', s=180, edgecolor='darkgreen', linewidth=2, zorder=10)
        plt.scatter(end_x, end_y, color='crimson', s=180, edgecolor='darkred', linewidth=2, zorder=10)

        # Add labels and formatting
        plt.axis('on')
        plt.xticks(np.arange(0, semantic_grid.shape[1], 2.5))
        plt.yticks(np.arange(0, semantic_grid.shape[0], 2.5))
        plt.legend(handles=[
            plt.scatter([], [], color='limegreen', edgecolor='darkgreen', linewidth=2, label='Start'),
            plt.scatter([], [], color='crimson', edgecolor='darkred', linewidth=2, label='End')
        ], loc='upper right', framealpha=0.8, fontsize=12)
        plt.title(title, fontsize=14, pad=10)

        plt.savefig(out_path, dpi=100, bbox_inches='tight', pad_inches=0.1)
    finally:
        # Always release the figure, even if rendering or saving fails.
        plt.close()


if __name__ == '__main__':
    # Initialize multiprocessing and argument parser
    freeze_support()
    parser = _build_arg_parser()

    # Honour the real command line. parse_known_args tolerates arguments that
    # the host injects (e.g. a notebook kernel's "-f <file>") instead of
    # discarding every user-supplied option.
    args, unknown_args = parser.parse_known_args()
    if unknown_args:
        print(f"Warning: ignoring unrecognized arguments: {unknown_args}", file=sys.stderr)

    # Nest outputs under a directory that records the generating parameters.
    args.local_base_path = os.path.join(
        args.local_base_path,
        args.dataset_name,
        f"grid_n-{args.grid_n}_n_mazes-{args.n_mazes}_min-{args.min_length}_max-{args.max_length}"
    )

    # --------------------------
    # Step 1: Maze Dataset Initialization via Config (Reproducible Core)
    # --------------------------
    maze_config: MazeDatasetConfig = MazeDatasetConfig(
        name=args.dataset_name,
        grid_n=args.grid_n,
        n_mazes=args.n_mazes,
        maze_ctor=getattr(LatticeMazeGenerators, args.maze_ctor),
    )

    maze_dataset: MazeDataset = MazeDataset.from_config(
        maze_config,
        do_download=args.do_download,
        load_local=args.load_local,
        do_generate=args.do_generate,
        save_local=args.save_local,
        local_base_path=args.local_base_path,
        verbose=args.verbose,
        zanj=ZANJ_INSTANCE,
        gen_parallel=args.gen_parallel,
    )

    # --------------------------
    # Step 2: Filter Mazes by Path Length (Quality Control)
    # --------------------------
    # NOTE(review): only --min_length is applied here. --max_length currently
    # affects nothing but the output directory name; confirm whether an upper
    # bound on path length is intended before relying on it.
    filtered_dataset: MazeDataset = maze_dataset.filter_by.path_length(min_length=args.min_length)
    total_samples = len(filtered_dataset)

    # --------------------------
    # Step 3: Prepare output directories for visualization and NPY files
    # --------------------------
    vis_dir = os.path.join(args.local_base_path, "maze_visualizations")
    npy_save_dir = os.path.join(args.local_base_path, f"N-{total_samples}")
    os.makedirs(vis_dir, exist_ok=True)
    os.makedirs(npy_save_dir, exist_ok=True)

    # --------------------------
    # Step 4: Batch Encode and Visualize Mazes
    # --------------------------
    # Input: solution channel cleared; solution: full semantic information.
    all_input_mazes = []
    all_solution_mazes = []

    for sample_idx in range(total_samples):
        maze_sample = filtered_dataset[sample_idx]
        solution_maze = _encode_maze(maze_sample)

        input_maze = solution_maze.copy()
        input_maze[:, :, 2] = CHANNEL_0_WALL  # Clear solution channel for input

        all_input_mazes.append(input_maze)
        all_solution_mazes.append(solution_maze)

        # Step 5: Visualize the first few samples (qualitative validation)
        if sample_idx < N_VISUALIZED_SAMPLES:
            _save_visualization(
                solution_maze,
                _cell_to_pixel(maze_sample.start_pos),
                _cell_to_pixel(maze_sample.end_pos),
                f"Maze Sample {sample_idx+1} (Grid Size: {args.grid_n}x{args.grid_n})",
                os.path.join(vis_dir, f"maze_visualization_{sample_idx+1}.png"),
            )

    # --------------------------
    # Step 6: Save Encoded Mazes as NPY (For Model Training)
    # --------------------------
    # Stack once and reuse for both saving and the summary.
    input_array = np.array(all_input_mazes)
    solution_array = np.array(all_solution_mazes)
    input_npy_path = os.path.join(npy_save_dir, "input_mazes.npy")
    solution_npy_path = os.path.join(npy_save_dir, "solution_mazes.npy")
    np.save(input_npy_path, input_array)
    np.save(solution_npy_path, solution_array)

    # --------------------------
    # Final Reproducibility Summary
    # --------------------------
    if args.verbose:
        print("=" * 60)
        print("Maze Dataset Generation Completed Successfully")
        print("=" * 60)
        print(f"1. Visualizations Saved to: {vis_dir}")
        # Report the paths that were actually written.
        print(f"2. Encoded Input Mazes: {input_npy_path}")
        print(f"3. Encoded Solution Mazes: {solution_npy_path}")
        print(f"4. Dataset Shape (Samples, H, W, Channels): {input_array.shape}")
        print(f"5. Filtered Samples (Path Length ≥ {args.min_length}): {total_samples}")
        print("=" * 60)