import torch
import numpy as np
from pathlib import Path
from collections import defaultdict
from torch.utils.data import Dataset
class SequenceDataset(Dataset):
    """Next-item prediction samples built from integer item-ID sequences.

    Every stored sequence is split into an input history (all items but the
    last) and a label (the last item). Histories are right-padded with 0 up
    to ``config['max_seq_length']``; histories longer than that are reduced
    to their most recent ``max_seq_length`` items, so every sample yields an
    ``item_seqs`` tensor of the same length.
    """

    def __init__(self, config, sequences):
        """Store the configuration and the raw sequences.

        Args:
            config: Mapping that provides ``'max_seq_length'``.
            sequences: Indexable collection of item-ID sequences. Each
                sequence needs at least one item (the label).
        """
        self.sequences = sequences
        self.config = config

    def __len__(self):
        """Return the number of stored sequences."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Build the sample for sequence ``idx``.

        Returns:
            A dict with:
              * ``'item_seqs'``: long tensor of length ``max_seq_length``
                holding the (possibly truncated) history, zero-padded on
                the right.
              * ``'labels'``: long scalar tensor with the last item.
              * ``'seq_lengths'``: int, number of real (non-padding) items
                in ``'item_seqs'``.

        Raises:
            IndexError: If the stored sequence is empty.
        """
        seq = self.sequences[idx]
        max_len = self.config['max_seq_length']

        label = seq[-1]
        # Copy into a list so padding never touches the stored sequence and
        # tuple sequences work as well.
        history = list(seq[:-1])

        # Without truncation an over-long history would produce a tensor
        # longer than max_seq_length, and samples of different lengths
        # cannot be stacked into one batch. Keep the most recent items.
        if len(history) > max_len:
            history = history[len(history) - max_len:]

        seq_length = len(history)
        history.extend([0] * (max_len - seq_length))

        return {
            'item_seqs': torch.tensor(history, dtype=torch.long),
            'labels': torch.tensor(label, dtype=torch.long),
            'seq_lengths': seq_length,
        }


class RecDataReader:
    """
    Reads and preprocesses recommendation data for a single domain.

    Files are looked up as ``<data_path>/<domain name>/<mode>data.txt``,
    where ``<mode>`` is ``''`` (full data), ``'train_'``, ``'valid_'`` or
    ``'test_'``. Each line of a file is one sequence of whitespace-separated
    integer item IDs.
    """
    def __init__(self, config: dict):
        """Store the config and register the known domain codes.

        Args:
            config: Run configuration. ``load_data`` reads
                ``config['dataset_code']``. The dict is mutated: a
                ``'source_dict'`` entry mapping domain codes to directory
                names is added to it.
        """
        self.config = config
        self.data_path = Path('data/')
        self.source_dict = {
            'O': 'Sports_and_Outdoors',
            'T': 'Toys_and_Games',
            'B': 'Beauty',
        }
        self.config['source_dict'] = self.source_dict

    def _read_data_from_file(self, domain_code, mode=''):
        """Private helper to read sequence data from a file.

        Args:
            domain_code: Key of ``self.source_dict`` selecting the domain.
            mode: File-name prefix; ``''`` selects the full data file.

        Returns:
            For ``mode == ''`` a tuple ``(item_seqs, item_num)`` where
            ``item_num`` is the largest item ID in the file as a plain
            ``int`` (0 if the file holds no items). For any other mode only
            ``item_seqs``. ``item_seqs`` is a list with one list of ints per
            line; a blank line yields an empty list.

        Raises:
            KeyError: If ``domain_code`` is not a known domain code.
            FileNotFoundError: If the data file does not exist.
            ValueError: If a token in the file is not an integer.
        """
        try:
            domain_name = self.source_dict[domain_code]
        except KeyError:
            known_codes = ', '.join(sorted(self.source_dict))
            raise KeyError(
                f"Unknown dataset code {domain_code!r}; "
                f"expected one of: {known_codes}"
            ) from None

        file_path = self.data_path / domain_name / f'{mode}data.txt'
        with file_path.open('r', encoding='utf-8') as file:
            item_seqs = [[int(token) for token in line.split()] for line in file]

        if mode != '':
            return item_seqs

        # Built-in max keeps item_num a plain int (np.max would return a
        # NumPy scalar) and `default` covers files without any items.
        item_num = max((item for seq in item_seqs for item in seq), default=0)
        return item_seqs, item_num

    def load_data(self):
        """Load the train/valid/test splits for ``config['dataset_code']``.

        Returns:
            A tuple ``(train_dataset, valid_dataset, test_dataset,
            select_pool, item_num)``. The datasets are ``SequenceDataset``
            instances, ``item_num`` is the largest item ID in the full data
            file and ``select_pool`` is ``[1, item_num + 1]``.

        Raises:
            KeyError: If the dataset code is missing from the config or is
                not a known domain code.
            FileNotFoundError: If one of the data files does not exist.
        """
        dataset_code = self.config['dataset_code']
        print(f"Loading data: {self.source_dict.get(dataset_code, 'Unknown')}")

        # The full data file is only needed for the largest item ID.
        _, item_num = self._read_data_from_file(dataset_code)

        train_seqs = self._read_data_from_file(dataset_code, mode='train_')
        valid_seqs = self._read_data_from_file(dataset_code, mode='valid_')
        test_seqs = self._read_data_from_file(dataset_code, mode='test_')

        train_dataset = SequenceDataset(self.config, train_seqs)
        valid_dataset = SequenceDataset(self.config, valid_seqs)
        test_dataset = SequenceDataset(self.config, test_seqs)

        # NOTE(review): presumably the [low, high) item-ID bounds used for
        # sampling by the consumer — confirm against the caller.
        select_pool = [1, item_num + 1]

        return (
            train_dataset,
            valid_dataset,
            test_dataset,
            select_pool,
            item_num
        )







