import torch
from torch.utils.data import DataLoader, TensorDataset,  Dataset
from torchvision import transforms

from sklearn.preprocessing import StandardScaler

import json, os

class FEMNISTDataset(Dataset):
    """Client-partitioned FEMNIST data held entirely in memory.

    The dataset is indexed by *client*: ``dataset[i]`` returns the
    ``(features, labels)`` tensors of client ``i`` and ``len(dataset)`` is the
    number of clients, not the number of samples.  The pooled samples of all
    clients are available as ``X`` / ``y`` and through ``get_full_loader``.

    Two loading modes are supported:

    * ``load_tensors=False`` (default): ``data_path`` is a JSON file with a
      ``'users'`` list and a ``'user_data'`` mapping whose entries hold ``'x'``
      (images) and ``'y'`` (labels).  Images are flattened to 784 values,
      samples with a label >= 10 are dropped, clients left without samples
      are removed, the features are standardized and a constant column of
      ones is appended.
    * ``load_tensors=True``: ``data_path`` is a directory of ``.pt`` files,
      one 2-D tensor per client, whose last column holds the labels and whose
      remaining columns hold the features.  No filtering or scaling is
      applied in this mode.

    Attributes:
        data_path: The path passed to the constructor.
        scaler: The scaler used for standardization.  In JSON mode this is
            the supplied scaler or a ``StandardScaler`` fitted on this data;
            in tensor mode it is simply the ``scaler`` argument (``None`` by
            default).
        images: List with one float feature tensor per client.
        labels: List with one ``int64`` label tensor per client.
        X: All client feature tensors concatenated along dimension 0.
        y: All client label tensors concatenated along dimension 0.
        data: The parsed JSON document (JSON mode only).
    """

    def __init__(self, data_path, scaler=None, load_tensors: bool = False):
        """Load the per-client data and build the pooled ``X`` / ``y`` tensors.

        Args:
            data_path: JSON file (default mode) or directory containing
                ``.pt`` files (``load_tensors=True``).
            scaler: Optional already-fitted scaler exposing ``transform``.
                When ``None`` in JSON mode, a ``StandardScaler`` is fitted on
                the pooled samples of this dataset.
            load_tensors: Select the ``.pt`` directory mode instead of JSON.
        """
        self.data_path = data_path
        # Define the attribute in both modes so ``dataset.scaler`` never
        # raises AttributeError, whichever way the data was loaded.
        self.scaler = scaler

        if load_tensors:
            self.images, self.labels = self._load_tensor_files(data_path)
        else:
            self.images, self.labels = self._load_json(data_path, scaler)

        with torch.no_grad():
            self.X = torch.concatenate(self.images)
            self.y = torch.concatenate(self.labels)

    @staticmethod
    def _load_tensor_files(data_path):
        """Read every ``.pt`` file in ``data_path`` and split features/labels."""
        # os.listdir returns entries in arbitrary order; sorting makes the
        # client order (and therefore ``dataset[i]``) reproducible.
        file_names = sorted(
            name for name in os.listdir(data_path) if name.endswith('.pt')
        )
        # NOTE(security): torch.load unpickles the file contents.  Only plain
        # tensors are expected here, so loading is restricted with
        # weights_only=True; still, only point this at trusted directories.
        client_tensors = [
            torch.load(os.path.join(data_path, name), weights_only=True)
            for name in file_names
        ]
        images = [tensor[:, :-1] for tensor in client_tensors]
        labels = [tensor[:, -1].to(dtype=torch.int64) for tensor in client_tensors]
        return images, labels

    def _load_json(self, data_path, scaler):
        """Parse the JSON file, filter, standardize and add the ones column."""
        with open(data_path, 'r') as f:
            self.data = json.load(f)

        user_data = self.data['user_data']
        images, labels = [], []
        with torch.no_grad():
            for user in self.data['users']:
                # Flatten every 28x28 image into a row of 784 features.
                client_X = torch.tensor(
                    user_data[user]['x'], dtype=torch.float64
                ).reshape(-1, 28 * 28)
                client_y = torch.tensor(user_data[user]['y'], dtype=torch.int64)

                # Keep only samples labelled 0-9 (presumably the digit
                # classes of FEMNIST -- confirm against the data source).
                keep = client_y < 10
                client_X, client_y = client_X[keep], client_y[keep]

                # Clients left with no samples are dropped entirely.
                if client_X.size()[0] > 0:
                    images.append(client_X)
                    labels.append(client_y)

            if scaler is None:
                # Only the fit is needed here; each client is transformed
                # individually below.
                self.scaler = StandardScaler().fit(torch.concatenate(images).numpy())

            scaled = [
                torch.tensor(self.scaler.transform(X.numpy()), dtype=torch.float64)
                for X in images
            ]
            # Append a constant column of ones to every client's features
            # (presumably acting as a bias/intercept term for the model).
            images = [
                torch.hstack([X, torch.ones(X.size()[0], 1, dtype=torch.float64)])
                for X in scaled
            ]
        return images, labels

    def __len__(self) -> int:
        """Return the number of clients."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(features, labels)`` tensors of client ``idx``."""
        images = self.images[idx]
        labels = self.labels[idx]

        return images, labels

    def get_full_loader(self, batch_size: int = 1024) -> DataLoader:
        """Return a shuffling ``DataLoader`` over the pooled samples of all clients."""
        return DataLoader(TensorDataset(self.X, self.y), batch_size=batch_size, shuffle=True)

    def get_client_loaders(self, batch_size: int = 1024) -> list:
        """Return one shuffling ``DataLoader`` per client, in client order."""
        return [
            DataLoader(TensorDataset(X, y), batch_size=batch_size, shuffle=True)
            for X, y in zip(self.images, self.labels)
        ]