from typing import Optional

from utils import input_from_text, output_from_text

def read_file(filename: str, **kwargs):
    """Split a text file into groups of consecutive non-blank lines.

    Every line is stripped of leading/trailing whitespace. Lines that are
    empty after stripping act as separators; runs of several separators
    (and separators at the start or end of the file) never produce empty
    groups.

    Args:
        filename: Path of the text file to read.
        **kwargs: Accepted and ignored.

    Returns:
        A list of groups, each a list of stripped line strings, in file order.
    """
    chunks = []
    pending = []

    with open(filename, 'r') as handle:
        for raw in handle:
            stripped = raw.strip()
            if stripped:
                pending.append(stripped)
                continue
            # Blank line: close off the current group, if there is one.
            if pending:
                chunks.append(pending)
                pending = []

    # The file may end without a trailing blank line; keep the final group.
    if pending:
        chunks.append(pending)

    return chunks


class SymmetricSudokuDataset:
    """Parsed input chunks, optionally paired with parsed output chunks.

    Each file is split into blank-line-separated chunks by ``read_file``.
    Input chunks are converted with ``input_from_text`` and output chunks with
    ``output_from_text`` (both from ``utils``). Items are returned as
    ``{"input": ..., "output": ...}`` dicts, where ``"output"`` is ``None``
    when no output file was supplied.
    """

    def __init__(self, input_dataset_file: str, output_dataset_file: Optional[str], **kwargs):
        """Load and parse the dataset files.

        Args:
            input_dataset_file: Path to the file of input chunks.
            output_dataset_file: Path to the file of output chunks, or ``None``
                for an inputs-only dataset.
            **kwargs: Forwarded unchanged to ``read_file`` and to the
                ``utils`` parsing helpers; also stored on ``self.kwargs``.

        Raises:
            ValueError: If an output file is given and its chunk count differs
                from the input file's.
        """
        self.input_dataset_file = input_dataset_file
        self.output_dataset_file = output_dataset_file
        self.outputs_present = output_dataset_file is not None
        self.kwargs = kwargs

        raw_inputs = read_file(input_dataset_file, **kwargs)
        self.inputs = [input_from_text(chunk, **kwargs) for chunk in raw_inputs]

        # Always define the attribute so readers need no hasattr() checks.
        self.outputs = None
        if self.outputs_present:
            raw_outputs = read_file(output_dataset_file, **kwargs)
            self.outputs = [output_from_text(chunk, **kwargs) for chunk in raw_outputs]
            # Raise explicitly instead of using ``assert``: asserts are stripped
            # under ``python -O`` and the mismatch would go unnoticed.
            if len(self.inputs) != len(self.outputs):
                raise ValueError(
                    f"input file {input_dataset_file!r} has {len(self.inputs)} items but "
                    f"output file {output_dataset_file!r} has {len(self.outputs)} items"
                )

    def __len__(self):
        """Return the number of input items."""
        return len(self.inputs)

    def __getitem__(self, idx: int):
        """Return item ``idx`` as ``{"input": ..., "output": ...}``.

        Negative indices count from the end, as for a list. Raising
        ``IndexError`` past the end also lets ``for item in dataset`` stop.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        # Check both bounds so negative indices get the same error message.
        if not -size <= idx < size:
            raise IndexError(f"index: {idx} out of range for dataset of length: {size}")
        return {
            "input": self.inputs[idx],
            "output": self.outputs[idx] if self.outputs_present else None,
        }

def DataSet():
    """Return the dataset class itself (not an instance) defined in this module."""
    dataset_class = SymmetricSudokuDataset
    return dataset_class

if __name__ == "__main__":
    # Smoke test: load data/train.txt with no output file and show the first
    # and last items.
    dataset = SymmetricSudokuDataset("data/train.txt", None, n=4)
    print(f"Length of dataset: {len(dataset)}")
    if len(dataset) == 0:
        # Indexing an empty dataset would raise IndexError; report it instead.
        print("Dataset is empty; no sample items to show.")
    else:
        print(f"Sample Item: {dataset[0]}")
        print(f"Sample Item: {dataset[len(dataset) - 1]}")