""" sudoku_data.py
    Sudoku-related dataloaders

    Adapted from the maze dataloaders
    for use with DeepThinking-style models
    July 2025
"""

import pandas as pd
import torch
from torch.utils import data
from sklearn.model_selection import train_test_split
import torch.nn.functional as F

class SudokuDataset(data.Dataset):
    """Map-style dataset of one-hot encoded Sudoku puzzles and their solutions.

    Each row of ``dataframe`` must provide a ``'puzzle'`` string and a
    ``'solution'`` string, each 81 characters long (a 9x9 board in row-major
    order). In ``puzzle`` a blank cell is written as ``'.'`` or ``'0'``;
    ``solution`` must contain only digits.
    """

    # Cells on a 9x9 board, and encoding channels: channel 0 marks blank
    # cells, channels 1-9 mark the given digit in that cell.
    _N_CELLS = 81
    _N_CHANNELS = 10

    def __init__(self, dataframe):
        # Reset to a 0..n-1 index so the positional idx passed by a DataLoader
        # can be used with .loc regardless of the caller's original index.
        self.data = dataframe.reset_index(drop=True)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(puzzle_onehot, solution)`` for row ``idx``.

        Returns:
            puzzle_onehot: float tensor of shape (10, 9, 9); entry ``[c, r, k]``
                is 1 when cell (r, k) holds digit ``c``, where ``c == 0`` means
                the cell is blank.
            solution: long tensor of shape (81,) holding the solved digits in
                row-major order.

        Raises:
            ValueError: if the puzzle or solution is not exactly 81 characters
                long, or contains a character that is not a digit (or '.' in
                the puzzle).
        """
        # str() guards against non-string cells (e.g. NaN for a missing value),
        # which then fail the length check below with a readable message.
        puzzle = str(self.data.loc[idx, "puzzle"])
        solution = str(self.data.loc[idx, "solution"])
        if len(puzzle) != self._N_CELLS or len(solution) != self._N_CELLS:
            raise ValueError(
                f"Row {idx}: expected 81-character puzzle and solution, "
                f"got lengths {len(puzzle)} and {len(solution)}"
            )

        # '.' and '0' both denote a blank and map to index 0; digits 1-9 keep
        # their own value, so they already index the right one-hot channel
        # and no shifting is needed.
        puzzle_tensor = torch.tensor(
            [0 if c == "." else int(c) for c in puzzle], dtype=torch.long
        )
        solution_tensor = torch.tensor([int(c) for c in solution], dtype=torch.long)

        # (81,) -> one-hot (81, 10) -> channels-first (10, 81) -> (10, 9, 9).
        puzzle_onehot = (
            F.one_hot(puzzle_tensor, num_classes=self._N_CHANNELS)
            .float()
            .transpose(0, 1)
            .reshape(self._N_CHANNELS, 9, 9)
        )
        return puzzle_onehot, solution_tensor

def prepare_sudoku_loader(
    csv_path,
    train_batch_size=64,
    test_batch_size=64,
    shuffle=True,
    train_data=0.3,
    test_data=0.6,
    test_window=0.5,
    val_size=0.1,
    seed=42,
):
    """Build train/val/test Sudoku dataloaders split by puzzle difficulty.

    Train and validation rows are those with ``difficulty < train_data``.
    Test rows satisfy ``test_data <= difficulty < test_data + test_window``,
    which lets the test set be harder than anything seen during training.

    Args:
        csv_path (str): path to a CSV file with at least the columns
            'puzzle', 'solution' and 'difficulty'.
        train_batch_size (int): batch size of the training loader.
        test_batch_size (int): batch size of the validation and test loaders.
        shuffle (bool): whether the training loader reshuffles each epoch.
        train_data (float): exclusive upper difficulty bound for train/val rows.
        test_data (float): inclusive lower difficulty bound for test rows.
        test_window (float): width of the test difficulty range.
        val_size (float): fraction of the training rows held out for validation.
        seed (int): random seed for the train/val split and the test-row order.

    Returns:
        dict: ``{"train": DataLoader, "val": DataLoader, "test": DataLoader}``.

    Raises:
        ValueError: if no rows fall inside the training difficulty range.
    """
    # Force string dtype so pandas never parses a digit-only board as a
    # number (which would drop leading zeros or lose precision).
    df = pd.read_csv(csv_path, dtype={"puzzle": str, "solution": str})

    # Select the train and test pools separately, based on difficulty.
    df_train = df[df["difficulty"] < train_data]
    df_test = df[
        (df["difficulty"] >= test_data)
        & (df["difficulty"] < test_data + test_window)
    ]

    if df_train.empty:
        raise ValueError(
            f"No puzzles with difficulty < {train_data} in {csv_path}; "
            "cannot build the training set."
        )

    # Hold out part of the training pool for validation.
    train_df, val_df = train_test_split(df_train, test_size=val_size, random_state=seed)
    # Deterministically shuffle the test rows (the test loader itself does not shuffle).
    test_df = df_test.sample(frac=1.0, random_state=seed)

    trainset = SudokuDataset(train_df)
    valset = SudokuDataset(val_df)
    testset = SudokuDataset(test_df)

    # drop_last only on the training loader so every training batch is full;
    # val/test keep every sample.
    trainloader = data.DataLoader(trainset,
                                  num_workers=0,
                                  batch_size=train_batch_size,
                                  shuffle=shuffle,
                                  drop_last=True)
    valloader = data.DataLoader(valset,
                                num_workers=0,
                                batch_size=test_batch_size,
                                shuffle=False,
                                drop_last=False)
    testloader = data.DataLoader(testset,
                                 num_workers=0,
                                 batch_size=test_batch_size,
                                 shuffle=False,
                                 drop_last=False)

    return {"train": trainloader, "val": valloader, "test": testloader}


if __name__ == "__main__":
    # Script entry point: build the loaders, training on puzzles with
    # difficulty < 0.3 and testing on puzzles with 0.6 <= difficulty < 1.1
    # (the loader's test range starts at test_data and is 0.5 wide).
    # NOTE(review): the CSV path below is an absolute, machine-specific path;
    # adjust it (or make it a command-line argument) before running elsewhere.
    loaders = prepare_sudoku_loader(
        csv_path="/home/fis/workspace_AI/AI-RnD/hieutb2/deep-thinking/data/2/sudoku-3m.csv",
        train_data=0.3,
        test_data=0.6,
        train_batch_size=64,
        test_batch_size=64
    )

    train_loader = loaders["train"]