import io
import pathlib
import re
import tokenize

import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import AutoTokenizer


class RapidcadpyDataset(Dataset):
    """Dataset over the ``*.stl`` and ``*.obj`` files found in ``config.data_dir``.

    Every file is read as text once at construction time and split into
    whitespace-delimited tokens (a placeholder tokenization).  The results are
    kept in a ``pandas.DataFrame`` with one row per file and the columns
    ``filepath`` and ``tokens``.

    A tokenizer is loaded from ``config.model_name`` and stored on
    ``self.tokenizer``; the placeholder tokenization in this class does not
    use it.
    """

    def __init__(self, config):
        """Load the tokenizer, discover the CAD files and preprocess them.

        Args:
            config: Object exposing at least ``model_name`` and ``data_dir``.
        """
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        self.filepaths = self._get_filepaths(config.data_dir)
        self.data = self._load_data()

    def _get_filepaths(self, data_dir):
        """Retrieve all CAD file paths from the given directory.

        Only direct children of ``data_dir`` are considered (no recursion).
        Each group is sorted so that dataset indices do not depend on the
        order in which the filesystem happens to list the files.

        Returns:
            A list of ``pathlib.Path``: all ``*.stl`` files, then all ``*.obj``.
        """
        data_dir = pathlib.Path(data_dir)
        return sorted(data_dir.glob("*.stl")) + sorted(data_dir.glob("*.obj"))

    def _load_data(self):
        """Load and preprocess data from CAD files.

        Returns:
            A ``pandas.DataFrame`` with one row per file and the columns
            ``filepath`` and ``tokens``.
        """
        # NOTE(review): files are opened in text mode with the platform
        # default encoding; binary STL files would presumably fail to decode
        # here -- confirm that only ASCII CAD files are expected.
        records = []
        for filepath in self.filepaths:
            with open(filepath, "r") as file:
                content = file.read()
            records.append(
                {"filepath": filepath, "tokens": self._tokenize_content(content)}
            )
        # Explicit columns keep the frame's schema stable when no file matched.
        return pd.DataFrame(records, columns=["filepath", "tokens"])

    def _tokenize_content(self, content):
        """Tokenize the content of a CAD file.

        Returns:
            A list of ``re.Match`` objects, one per run of non-whitespace
            characters in ``content``.
        """
        # This is a placeholder for actual tokenization logic
        return list(re.finditer(r"\S+", content))

    def __len__(self):
        """Return the number of loaded files."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample stored at positional index ``idx``.

        Returns:
            A dict with
            ``input_ids``: 1-D ``torch.long`` tensor holding the start offset
            of every token within the file content,
            ``attention_mask``: 1-D ``torch.long`` tensor of ones with the
            same length, and
            ``filepath``: the path of the source file.
        """
        item = self.data.iloc[idx]
        # ``tokens`` is a list of ``re.Match`` objects; the list itself has no
        # ``start`` attribute, so the offsets are collected match by match.
        token_starts = [match.start() for match in item.tokens]
        return {
            "input_ids": torch.tensor(token_starts, dtype=torch.long),
            "attention_mask": torch.ones(len(token_starts), dtype=torch.long),
            "filepath": item.filepath,
        }

    @staticmethod
    def create_splits(config, **ds_kwargs):
        """Build train/validation/test data loaders from one dataset.

        The dataset is split randomly with the fractions 0.8 / 0.1 / 0.1.

        Args:
            config: Object exposing ``model_name``, ``data_dir``,
                ``batch_size`` and ``data.num_workers``.
            **ds_kwargs: Accepted for backward compatibility; currently unused.

        Returns:
            A tuple ``(train_dl, val_dl, test_dl)`` of ``DataLoader`` objects.
            Only the training loader shuffles.
        """
        dataset = RapidcadpyDataset(config)
        train_ds, val_ds, test_ds = random_split(dataset, [0.8, 0.1, 0.1])
        # The batching function is the module-level ``collate_fn``; the
        # dataset class does not define a ``collate`` attribute.
        # NOTE(review): ``drop_last=True`` is also applied to the validation
        # and test loaders, so trailing samples that do not fill a batch are
        # skipped during evaluation -- confirm this is intended.
        loader_kwargs = {
            "batch_size": config.batch_size,
            "num_workers": config.data.num_workers,
            "collate_fn": collate_fn,
            "pin_memory": True,
            "drop_last": True,
        }
        train_dl = DataLoader(train_ds, shuffle=True, **loader_kwargs)
        val_dl = DataLoader(val_ds, shuffle=False, **loader_kwargs)
        test_dl = DataLoader(test_ds, shuffle=False, **loader_kwargs)
        return train_dl, val_dl, test_dl


def collate_fn(batch):
    """Merge a list of samples into one batch dictionary.

    ``input_ids`` and ``attention_mask`` are padded to a common length with
    ``pad_sequence`` (using its defaults); ``filepath`` values are gathered
    into a plain list in batch order.
    """
    ids = [sample["input_ids"] for sample in batch]
    masks = [sample["attention_mask"] for sample in batch]
    paths = [sample["filepath"] for sample in batch]

    collated = {}
    collated["input_ids"] = pad_sequence(ids)
    collated["attention_mask"] = pad_sequence(masks)
    collated["filepath"] = paths
    return collated


def pad_sequence(sequences, batch_first=True, padding_value=0):
    """Pad a list of sequences to the same length.

    Args:
        sequences: Non-empty list of tensors whose first dimension is the
            sequence length.  Any trailing dimensions must agree across all
            tensors.
        batch_first: If true (default) the result has shape
            ``(batch, max_len, *trailing)``; otherwise
            ``(max_len, batch, *trailing)``.
        padding_value: Value written to the padded positions.

    Returns:
        A new tensor with the dtype and device of ``sequences[0]``.

    Raises:
        ValueError: If ``sequences`` is empty.
    """
    if not sequences:
        raise ValueError("pad_sequence() requires at least one sequence")

    first = sequences[0]
    max_len = max(seq.size(0) for seq in sequences)
    # Keep trailing feature dimensions so that inputs beyond 1-D also work.
    shape = (len(sequences), max_len, *first.shape[1:])
    padded_sequences = first.new_empty(shape).fill_(padding_value)
    for i, seq in enumerate(sequences):
        padded_sequences[i, : seq.size(0)] = seq

    if batch_first:
        return padded_sequences
    # Honour ``batch_first=False`` by moving the time axis to the front.
    return padded_sequences.transpose(0, 1).contiguous()
