import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import os
from tqdm import tqdm
import matplotlib.pyplot as plt


class MazeDataProcessor:
    """Load pre-processed maze arrays from disk and split them into train/val/test.

    Reads ``<PROCESSED_DATA_DIR>/<DATASET_NAME>/all_input_mazes.npy`` and
    ``all_solution_mazes.npy``; both must have shape
    ``(N, GRID_SIZE, GRID_SIZE, 3)`` (checked in ``_convert_to_class_indices``).
    """

    # Tolerance for rounding noise in ``1.0 - train - val``; for example
    # ``1.0 - 0.8 - 0.2`` evaluates to about -5.6e-17 rather than 0.0.
    _SPLIT_EPS = 1e-9

    def __init__(self, config):
        """Read dataset location, split ratios and grid size from ``config``.

        Raises:
            ValueError: if a split ratio is negative or train + val exceeds 1.
        """
        self.processed_data_dir = config.PROCESSED_DATA_DIR
        self.dataset_name = config.DATASET_NAME
        # Optional cap on the number of samples used; None or <= 0 means "all".
        self.max_samples = getattr(config, 'MAX_PUZZLES', None)

        self.train_split = config.TRAIN_SPLIT
        self.val_split = config.VAL_SPLIT
        if self.train_split < 0 or self.val_split < 0:
            raise ValueError(
                f"Error: split ratios must be non-negative, got "
                f"train({self.train_split}) val({self.val_split})"
            )
        test_split = 1.0 - self.train_split - self.val_split
        if test_split < -self._SPLIT_EPS:
            raise ValueError(f"Error: train({self.train_split})+val({self.val_split})>1")
        # Clamp floating-point noise so the remainder is never slightly negative.
        self.test_split = max(test_split, 0.0)

        self.grid_size = config.GRID_SIZE
        self.seq_len = self.grid_size * self.grid_size

    def _convert_to_class_indices(self, grid_data, is_solution=False):
        """Convert raw 3-channel grids into flat per-cell class indices.

        Input: (N, H, W, 3) with H == W == grid_size
            channel 0: -1=wall, 1=else
            channel 1: 0=start, 1=end, -1=else
            channel 2: 1=path, -1=else
        Output: (N, H*W) array of dtype np.longlong
            Puzzle: 1=wall, 2=start, 3=end, 0=else
                (assigned in that order, so a later class overrides an earlier one)
            Solution: 1=path, 0=else

        Raises:
            ValueError: if ``grid_data`` is not (N, grid_size, grid_size, 3).
        """
        expected = (self.grid_size, self.grid_size, 3)
        if grid_data.ndim != 4 or tuple(grid_data.shape[1:]) != expected:
            raise ValueError(
                f"Incorrect data dim, (N, {self.grid_size}, {self.grid_size}, 3) "
                f"needed but got {grid_data.shape}"
            )

        n = grid_data.shape[0]
        flat_data = grid_data.reshape(n, self.seq_len, 3)
        indices = np.zeros((n, self.seq_len), dtype=np.longlong)

        if is_solution:
            indices[flat_data[..., 2] == 1] = 1
        else:
            indices[flat_data[..., 0] == -1] = 1  # wall
            indices[flat_data[..., 1] == 0] = 2   # start
            indices[flat_data[..., 1] == 1] = 3   # end
        return indices

    def load_data(self):
        """Load, convert, sanity-check and split the dataset.

        The split is sequential (no shuffling): the first samples go to train,
        the next to val, and the remainder to test.

        Returns:
            ((train_in, train_sol), (val_in, val_sol), (test_in, test_sol)),
            each an (n, seq_len) array produced by ``_convert_to_class_indices``.

        Raises:
            FileNotFoundError: if either .npy file is missing.
            ValueError: if the puzzle and solution counts differ, or the grid
                shape does not match GRID_SIZE.
        """
        input_path = os.path.join(self.processed_data_dir, self.dataset_name, "all_input_mazes.npy")
        sol_path = os.path.join(self.processed_data_dir, self.dataset_name, "all_solution_mazes.npy")
        for path in (input_path, sol_path):
            if not os.path.exists(path):
                raise FileNotFoundError(f"Dataset file missing: {path}. Please make sure the file path matches the dataset name!")

        print(f"----- Loading dataset: {self.processed_data_dir}/{self.dataset_name} -----")
        print(f"----- Puzzle path: {input_path} -----")
        print(f"----- Solution path: {sol_path} -----")

        # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
        # objects; only load these files from a trusted source.
        inputs_raw = np.load(input_path, allow_pickle=True).astype(np.int8)
        solutions_raw = np.load(sol_path, allow_pickle=True).astype(np.int8)

        # Pairs are matched by position, so differing counts would silently
        # misalign puzzles and solutions in every split.
        if len(inputs_raw) != len(solutions_raw):
            raise ValueError(
                f"Puzzle/solution count mismatch: {len(inputs_raw)} puzzles vs "
                f"{len(solutions_raw)} solutions"
            )

        if self.max_samples is not None and self.max_samples > 0:
            inputs_raw = inputs_raw[:self.max_samples]
            solutions_raw = solutions_raw[:self.max_samples]

        N = len(inputs_raw)
        print(f">>> Original data shape: Input = {inputs_raw.shape} | Label = {solutions_raw.shape} | Total number of samples = {N}")

        inputs_indices = self._convert_to_class_indices(inputs_raw, is_solution=False)
        solutions_indices = self._convert_to_class_indices(solutions_raw, is_solution=True)

        self._verify_data_distribution(inputs_indices, solutions_indices)

        train_size = int(N * self.train_split)
        val_size = int(N * self.val_split)
        val_end = train_size + val_size

        train_in = inputs_indices[:train_size]
        train_sol = solutions_indices[:train_size]
        val_in = inputs_indices[train_size:val_end]
        val_sol = solutions_indices[train_size:val_end]
        test_in = inputs_indices[val_end:]
        test_sol = solutions_indices[val_end:]

        print(f">>> Dataset division completed: Train={len(train_in)} | Val={len(val_in)} | Test={len(test_in)}")
        return (train_in, train_sol), (val_in, val_sol), (test_in, test_sol)

    def _verify_data_distribution(self, inputs_indices, solutions_indices):
        """Print per-class cell counts for puzzles and solutions to sanity-check the label mapping."""
        input_classes = {
            0: "else",
            1: "wall",
            2: "start",
            3: "end",
        }
        sol_classes = {
            0: "else",
            1: "path",
        }

        print("\n Data Class Distribution Check (verifying label mappings):")
        print("=" * 50)
        print("Input (maze grid):")
        self._print_class_counts(inputs_indices, input_classes)
        print("\nSolution (path annotation):")
        self._print_class_counts(solutions_indices, sol_classes)
        print("=" * 50)

    @staticmethod
    def _print_class_counts(indices, classes):
        """Print the count and percentage of each class id from ``classes`` found in ``indices``."""
        flat = indices.flatten()
        total = flat.size
        for cls_id, cls_name in classes.items():
            count = int(np.sum(flat == cls_id))
            # Guard against empty input instead of producing nan with a warning.
            ratio = count / total * 100 if total else 0.0
            print(f"  {cls_name}({cls_id}): {count:,} ({ratio:.2f}%)")


class MazeDataset(Dataset):
    """Map-style dataset pairing each encoded puzzle with its solution.

    Both NumPy arrays are converted once, up front, to ``torch.long`` tensors.
    Indexing returns a dict with the keys ``"puzzle"`` and ``"solution"``.
    """

    @staticmethod
    def _as_long_tensor(array):
        # torch.from_numpy requires a NumPy array and shares its memory.
        return torch.from_numpy(array).long()

    def __init__(self, puzzles, solutions):
        self.puzzles = self._as_long_tensor(puzzles)
        self.solutions = self._as_long_tensor(solutions)

    def __len__(self):
        # Length follows the puzzle tensor; solutions are assumed to match it.
        return len(self.puzzles)

    def __getitem__(self, idx):
        puzzle = self.puzzles[idx]
        solution = self.solutions[idx]
        return dict(puzzle=puzzle, solution=solution)


def _precompute_valid_regions(puzzles, grid_size, batch_size=128):
    """Compute valid path regions for ``puzzles`` in batches and register them.

    Clears the TLAD.maze.utils precomputed-region cache first. When ``puzzles``
    is empty, nothing is registered after clearing (torch.cat would otherwise
    fail on an empty list).
    """
    # Project import kept function-local as in the original code (possibly to
    # avoid an import cycle); hoisted out of the batch loop.
    from TLAD.maze.utils import (
        clear_precomputed_valid_regions,
        get_valid_path_region,
        register_precomputed_valid_regions,
    )

    clear_precomputed_valid_regions()

    puzzles_torch = torch.from_numpy(puzzles).long()
    if len(puzzles_torch) == 0:
        return

    with torch.no_grad():
        regions = [
            get_valid_path_region(puzzles_torch[start:start + batch_size], grid_size)
            for start in range(0, len(puzzles_torch), batch_size)
        ]
    register_precomputed_valid_regions(puzzles_torch, torch.cat(regions, dim=0))


def get_maze_loaders(config):
    """Build train/val/test DataLoaders for the maze dataset described by ``config``.

    Reads the MazeDataProcessor settings from ``config`` plus GRID_SIZE, DEVICE,
    BATCH_SIZE and NUM_WORKERS.

    Side effect: clears and re-populates the TLAD.maze.utils precomputed
    valid-region cache with the train + val puzzles.

    Returns:
        (train_loader, val_loader, test_loader). Only the train loader shuffles
        and drops its last incomplete batch.
    """
    processor = MazeDataProcessor(config)
    (tr_p, tr_s), (val_p, val_s), (te_p, te_s) = processor.load_data()

    # NOTE(review): test puzzles are not registered here; confirm that their
    # consumers compute valid regions on demand.
    _precompute_valid_regions(np.concatenate([tr_p, val_p], axis=0), config.GRID_SIZE)

    train_dataset = MazeDataset(tr_p, tr_s)
    val_dataset = MazeDataset(val_p, val_s)
    test_dataset = MazeDataset(te_p, te_s)

    # Compare only the device type so "cuda:0" and torch.device("cuda", 1) also
    # enable pinned memory (an exact == "cuda" comparison missed them).
    device_type = str(config.DEVICE).split(":", 1)[0]
    pin_memory = device_type == "cuda" and torch.cuda.is_available()

    loader_kwargs = dict(
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
        pin_memory=pin_memory,
    )

    train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader, test_loader

