## Pytorch
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn.functional as F


class PowerIndexDataset(Dataset):
    ''' Custom dataset class to work with Shapley values or Banzhaf indices.

    Each sample pairs row ``index`` of ``X`` (weights) with the matching
    row of ``Y`` (payoffs).

    Args:
        X: tensor in shape (n_samples, n_players).
        Y: tensor in shape (n_samples, n_players).

    Raises:
        ValueError: if ``X`` and ``Y`` have a different number of samples.
    '''
    def __init__(self, X: torch.Tensor, Y: torch.Tensor):
        # __len__ only looks at X, so a shorter Y would otherwise surface as an
        # IndexError partway through an epoch, and a longer Y would be
        # silently truncated. Fail fast at construction instead.
        if X.shape[0] != Y.shape[0]:
            raise ValueError(
                "X and Y must have the same number of samples, "
                f"got {X.shape[0]} and {Y.shape[0]}"
            )
        self.X = X
        self.Y = Y

    def __len__(self):
        ''' Return the total number of samples in the dataset. '''
        return self.X.shape[0]

    def __getitem__(self, index):
        ''' Return the 1D weights row and its payoffs row, both cast via ``.float()``. '''
        return (
            self.X[index, :].float(),
            self.Y[index, :].float(),
        )

class PayoffsDataset(Dataset):
    ''' Custom dataset class for Least core payoffs.

    Each sample pairs row ``index`` of ``X_stack`` with the matching row of
    ``Y_stack``. Both tensors are indexed as ``[index, :]``, so they must be
    at least 2-D with samples along the first axis. The presumed layout is
    (n_samples, n_features) and (n_samples, n_players); TODO confirm
    against callers.

    Args:
        X_stack: input tensor, one sample per row.
        Y_stack: target payoff tensor, one sample per row.

    Raises:
        ValueError: if ``X_stack`` and ``Y_stack`` have a different number
            of samples.
    '''
    def __init__(self, X_stack: torch.Tensor, Y_stack: torch.Tensor):
        # __len__ only looks at X_stack, so a length mismatch would otherwise
        # surface late (IndexError) or be silently ignored.
        if X_stack.shape[0] != Y_stack.shape[0]:
            raise ValueError(
                "X_stack and Y_stack must have the same number of samples, "
                f"got {X_stack.shape[0]} and {Y_stack.shape[0]}"
            )
        self.X_stack = X_stack
        self.Y_stack = Y_stack

    def __len__(self):
        ''' Return the total number of samples in the dataset. '''
        return self.X_stack.shape[0]

    def __getitem__(self, index):
        ''' Return row ``index`` of X_stack and Y_stack, both cast via ``.float()``. '''
        return (
            self.X_stack[index, :].float(),
            self.Y_stack[index, :].float(),
        )
