import numpy as np
import os
import torch
from torch.utils.data import Dataset, DataLoader



class HSI_train(Dataset):
    """Cyclic view over ``train_set``, optionally restricted to one client's shard.

    The reported length is ``epoch_sum_num`` regardless of how many samples
    the view actually covers; indices wrap around inside the selected shard.

    Split modes (``trn_split``):
        0: no split -- every sample of ``train_set`` is used.
        1: even split -- ``train_set`` is cut into ``args.num_clients``
           contiguous shards of ``floor(len(train_set) / num_clients)``
           samples; client ``id`` reads shard number ``id``.
        2: ratio split -- ``train_set`` is cut into contiguous shards whose
           sizes are proportional to ``args.trn_split_ratio``; client ``id``
           reads the shard that matches ``trn_split_ratio[id]``.

    Args:
        args: configuration object. ``args.num_clients`` is read for modes 1
            and 2, ``args.trn_split_ratio`` for mode 2.
        train_set: indexable collection supporting ``len()`` and ``[]``.
        trn_split: split mode, one of 0, 1 or 2.
        epoch_sum_num: value returned by ``len()`` for this dataset.
        id: index of the client whose shard is read (needed for modes 1, 2).

    Raises:
        ValueError: if ``trn_split`` is not 0, 1 or 2, or if the ratio split
            configuration is inconsistent or does not divide ``train_set``
            evenly.
        TypeError: if ``id`` is ``None`` in ratio-split mode.
    """

    def __init__(self, args, train_set, trn_split, epoch_sum_num, id=None):
        super().__init__()
        self.train_set = train_set
        self.args = args
        self.id = id
        self.trn_split = trn_split
        self.epoch_sum_num = epoch_sum_num

        total = len(self.train_set)
        if self.trn_split == 0:  # no split
            self.usr_trn_sz = total
        elif self.trn_split == 1:  # even split
            self.usr_trn_sz = int(np.floor(total / self.args.num_clients))
        elif self.trn_split == 2:  # ratio split
            ratios = self.args.trn_split_ratio
            # Explicit raises instead of ``assert``: assertions are removed
            # under ``python -O`` and a bad configuration would then silently
            # produce wrong shard sizes and offsets.
            if len(ratios) != self.args.num_clients:
                raise ValueError('len of [trn_split_ratio] should be consistent with [num_clients]!')
            if total % sum(ratios) != 0:
                raise ValueError('Only support the dividable split')
            if self.id is None:
                raise TypeError('[id] is required when trn_split == 2 (ratio split)')
            self.usr_trn_sz = int(total / sum(ratios) * ratios[self.id])
        else:
            raise ValueError(
                f'trn_split must be 0 (no split), 1 (even split) or 2 (ratio split), got {self.trn_split!r}'
            )

    def __len__(self):
        """Return ``epoch_sum_num`` (not the number of underlying samples)."""
        return self.epoch_sum_num

    def __getitem__(self, item):
        """Return the sample that ``item`` maps to inside the selected shard.

        ``item`` is wrapped with a modulo, so any non-negative index is valid.
        Modes 0 and 1 start the cycle one sample later (``item + 1``) than
        mode 2 (``item``); this only changes the order within a cycle and is
        kept so existing index-to-sample mappings stay the same.

        Raises:
            TypeError: if ``id`` is ``None`` in even-split mode.
            ValueError: if ``trn_split`` is not 0, 1 or 2.
        """
        if self.trn_split == 0:
            return self.train_set[(item + 1) % len(self.train_set)]

        if self.trn_split == 1:
            if self.id is None:
                raise TypeError('[id] is required when trn_split == 1 (even split)')
            shard_start = self.id * self.usr_trn_sz
            return self.train_set[(item + 1) % self.usr_trn_sz + shard_start]

        if self.trn_split == 2:
            ratios = self.args.trn_split_ratio
            # Number of samples that one unit of ratio stands for.
            per_unit = len(self.train_set) / sum(ratios)
            # The shard starts after the shards of all lower-numbered clients.
            shard_start = per_unit * sum(ratios[:self.id])
            return self.train_set[int(item % self.usr_trn_sz + shard_start)]

        raise ValueError(
            f'trn_split must be 0 (no split), 1 (even split) or 2 (ratio split), got {self.trn_split!r}'
        )

