import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import os
from tqdm import tqdm

try:
    from data_process.sdk.solution import is_valid_board
except ImportError:
    # The project validator is optional: when its package cannot be imported,
    # fall back to a permissive stand-in so this module still loads.
    def is_valid_board(x):
        """Fallback validator that accepts every board unconditionally."""
        return True

class SudokuDataProcessor:
    """Load pre-processed Sudoku arrays from disk and split them into train/val/test.

    The data is read from
    ``<PROCESSED_DATA_DIR>/<dataset>/<dataset>_puzzles.npy`` and
    ``<PROCESSED_DATA_DIR>/<dataset>/<dataset>_solutions.npy``.
    """

    def __init__(self, config):
        """
        Initialize Sudoku data processor with configuration.

        config: Instance of Config object. The uppercase attributes
            PROCESSED_DATA_DIR, MAX_PUZZLES, TRAIN_SPLIT and VAL_SPLIT are read.

        Raises:
            ValueError: if TRAIN_SPLIT + VAL_SPLIT exceeds 1.0.
        """
        # Adapt to uppercase attribute names of Config class
        self.processed_data_dir = config.PROCESSED_DATA_DIR
        self.valid_datasets = ["big_kaggle", "minimal_17", "multiple_sol"]

        # Limit the number of puzzles (None means "use everything")
        self.max_puzzles = config.MAX_PUZZLES

        # Split ratios
        self.train_split = config.TRAIN_SPLIT
        self.val_split = config.VAL_SPLIT

        # The test ratio is whatever is left over. A small negative tolerance
        # absorbs floating-point noise (e.g. 0.7 + 0.3).
        remainder = 1.0 - self.train_split - self.val_split
        if remainder < -1e-9:
            # Raise explicitly: an `assert` would be stripped under `python -O`.
            raise ValueError(
                f"Sum of split ratios ({self.train_split+self.val_split}) exceeds 1.0"
            )
        # Clamp so rounding noise never leaves a tiny negative ratio behind.
        self.test_split = max(0.0, remainder)

    @staticmethod
    def _flatten_boards(boards):
        """Return 3-D board arrays reshaped to [N, 81]; other ranks pass through."""
        if boards.ndim == 3:
            return boards.reshape(len(boards), 81)
        return boards

    def load_processed_data(self, dataset_name="big_kaggle"):
        """
        Load one dataset and split it sequentially into train/val/test parts.

        dataset_name: One of ``self.valid_datasets``.

        Returns:
            ``(train_puzzles, train_solutions), (val_puzzles, val_solutions),
            (test_puzzles, test_solutions)``. The split is by position (no
            shuffling); the test part receives every sample left after the
            train and validation parts and may be empty.

        Raises:
            ValueError: if the dataset name is unknown, or the puzzle and
                solution files hold a different number of samples.
            FileNotFoundError: if the puzzle or solution file is missing.
        """
        if dataset_name not in self.valid_datasets:
            raise ValueError(f"Invalid dataset name. Valid options: {self.valid_datasets}")

        # Path concatenation
        dataset_dir = os.path.join(self.processed_data_dir, dataset_name)
        puzzle_path = os.path.join(dataset_dir, f"{dataset_name}_puzzles.npy")
        solution_path = os.path.join(dataset_dir, f"{dataset_name}_solutions.npy")

        # Check that BOTH files exist so a missing solution file is reported
        # with the same explicit message as a missing puzzle file.
        for path in (puzzle_path, solution_path):
            if not os.path.exists(path):
                raise FileNotFoundError(f"Data file not found: {path}")

        # Load data
        print(f"Loading data from: {puzzle_path} ...")
        # NOTE(security): allow_pickle=True lets np.load unpickle arbitrary
        # objects, which can execute code. It is kept for compatibility with
        # existing data files; only load files from a trusted source.
        # Slicing with [:None] keeps everything, so no separate branch is
        # needed for max_puzzles=None.
        puzzles = np.load(puzzle_path, allow_pickle=True)[:self.max_puzzles]
        solutions = np.load(solution_path, allow_pickle=True)[:self.max_puzzles]

        if len(puzzles) != len(solutions):
            raise ValueError(
                f"Puzzle/solution count mismatch: {len(puzzles)} puzzles "
                f"vs {len(solutions)} solutions"
            )

        N = len(puzzles)
        print(f"Loaded data size: {N} samples")

        # Ensure shape is [N, 81]. Each array is checked on its own so that a
        # flat puzzle file paired with a 9x9 solution file is still normalized.
        puzzles = self._flatten_boards(puzzles)
        solutions = self._flatten_boards(solutions)

        # Calculate split points; remaining samples are for the test set
        train_end = int(N * self.train_split)
        val_end = train_end + int(N * self.val_split)

        # Index slicing
        train_puzzles, train_solutions = puzzles[:train_end], solutions[:train_end]
        val_puzzles, val_solutions = puzzles[train_end:val_end], solutions[train_end:val_end]
        test_puzzles, test_solutions = puzzles[val_end:], solutions[val_end:]

        print(f"Data split: Train={len(train_puzzles)}, Val={len(val_puzzles)}, Test={len(test_puzzles)}")

        return (train_puzzles, train_solutions), (val_puzzles, val_solutions), (test_puzzles, test_solutions)


class SudokuDataset(Dataset):
    """Map-style dataset pairing each Sudoku puzzle with its solution.

    Each item is a dict with the keys ``"puzzle"``, ``"solution"`` and
    ``"mask"``; the mask is True wherever the puzzle cell equals 0 (empty).
    """

    def __init__(self, puzzles, solutions):
        """
        Store puzzles and solutions as int64 tensors.

        puzzles: Array-like of puzzle boards (ndarray, tensor or nested list).
        solutions: Array-like of solved boards, one per puzzle.

        Raises:
            ValueError: if the number of puzzles and solutions differ.
        """
        # Convert to Tensor. torch.as_tensor handles ndarrays like
        # torch.from_numpy does, and additionally accepts lists and tensors.
        self.puzzles = torch.as_tensor(puzzles).long()
        self.solutions = torch.as_tensor(solutions).long()
        # Raise explicitly: an `assert` would be stripped under `python -O`
        # and let mismatched data through.
        if len(self.puzzles) != len(self.solutions):
            raise ValueError("Mismatch between number of puzzles and solutions")

    def __len__(self):
        """Return the number of puzzle/solution pairs."""
        return len(self.puzzles)

    def __getitem__(self, idx):
        """Return the puzzle, its solution and the empty-cell mask at ``idx``."""
        puzzle = self.puzzles[idx]
        solution = self.solutions[idx]
        # 0 indicates empty cell
        mask = puzzle == 0

        return {
            "puzzle": puzzle,
            "solution": solution,
            "mask": mask,
        }


def get_data_loaders(config):
    """
    Get DataLoader instances for training, validation and test.

    config: Instance of Config object (contains uppercase attributes). This
        function reads DATASET_NAME, BATCH_SIZE, NUM_WORKERS and DEVICE, and
        passes the object on to SudokuDataProcessor.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        shuffles. ``test_loader`` is None when the test split is empty.
    """
    processor = SudokuDataProcessor(config)

    # Load data
    (tr_p, tr_s), (val_p, val_s), (te_p, te_s) = processor.load_processed_data(
        config.DATASET_NAME
    )

    # Pin host memory whenever the target device is CUDA. Comparing with
    # `== "cuda"` would miss indexed names such as "cuda:0" (and a
    # torch.device object may not compare equal to a plain string), so the
    # check is done on the string form instead.
    use_pinned_memory = str(config.DEVICE).startswith("cuda")

    # Settings shared by all three loaders
    loader_kwargs = {
        "batch_size": config.BATCH_SIZE,
        "num_workers": config.NUM_WORKERS,
        "pin_memory": use_pinned_memory,
    }

    # Build Dataset + DataLoader pairs
    train_loader = DataLoader(SudokuDataset(tr_p, tr_s), shuffle=True, **loader_kwargs)
    val_loader = DataLoader(SudokuDataset(val_p, val_s), shuffle=False, **loader_kwargs)

    # An empty test split yields no loader at all
    test_loader = None
    if len(te_p) > 0:
        test_loader = DataLoader(
            SudokuDataset(te_p, te_s), shuffle=False, **loader_kwargs
        )

    return train_loader, val_loader, test_loader