import os
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from utils import normalize

class CaliforniaHousing:
    """Per-participant view of a California Housing feature file, served as torch tensors.

    On construction the arrays stored at
    ``<cfg['dataset_path']>/features/<role>/<role_id>.npz`` (keys ``'features'``
    and ``'labels'``) are passed through ``utils.normalize`` and split 90/10
    into a "large" and a "small" part with ``train_test_split`` seeded by
    ``cfg['seed']``.

    Which part each split name returns depends on ``role``:

    ``specification``  -> large if role == 'learnware', otherwise small
    ``train``          -> large; only allowed when role == 'learnware'
    ``eval`` / ``test``-> small if role == 'learnware', otherwise large
    """

    def __init__(self, cfg, role, role_id):
        """Load, normalise and split the data for one participant.

        Args:
            cfg: Mapping providing at least ``'dataset_path'`` and ``'seed'``.
            role: Participant role; also the sub-directory under ``features/``.
                ``'learnware'`` is treated specially by the split accessors.
            role_id: Identifier used as the ``.npz`` file stem.
        """
        self.role = role
        self.role_id = role_id
        self.data_path = os.path.join(cfg['dataset_path'], 'features', role, f'{role_id}.npz')

        # Only the raw array reads need the archive open; indexing an NpzFile
        # returns an in-memory ndarray, so the file can be closed right away.
        with np.load(self.data_path) as data:
            raw_features = data['features']
            raw_labels = data['labels']

        self.features = normalize(raw_features)
        # Labels are reshaped to a single column for normalize() and flattened
        # back afterwards (presumably normalize expects 2-D input — confirm in utils).
        self.labels = normalize(raw_labels.reshape(-1, 1)).ravel()

        X_large, X_small, y_large, y_small = train_test_split(
            self.features, self.labels, test_size=0.1, random_state=cfg['seed']
        )
        self.X_large = torch.from_numpy(X_large)
        self.X_small = torch.from_numpy(X_small)
        self.y_large = torch.from_numpy(y_large)
        self.y_small = torch.from_numpy(y_small)

    def get_loader(self, split=None):
        """Return the ``(X, y)`` tensor pair for *split*.

        Despite the name this returns tensors, not a ``DataLoader``; it is an
        alias for :meth:`get_data`. ``split`` must be provided — the ``None``
        default does not name a split and raises ``AttributeError``.
        """
        return self.get_data(split)

    def get_data(self, split):
        """Dispatch to the ``<split>_data()`` method and return its ``(X, y)`` pair.

        Args:
            split: One of ``'specification'``, ``'train'``, ``'eval'``, ``'test'``.

        Raises:
            AttributeError: if there is no ``<split>_data`` method.
            AssertionError: for ``'train'`` when the role is not ``'learnware'``.
        """
        # getattr instead of eval: split is only ever used as a method-name
        # prefix, so arbitrary expressions can no longer be executed through it.
        method = getattr(self, f'{split}_data', None)
        if not callable(method):
            raise AttributeError(
                f"Unknown split {split!r}; expected one of "
                "'specification', 'train', 'eval', 'test'."
            )
        return method()

    def specification_data(self):
        """Return the split used to build the specification.

        A learnware uses its large (training) part; other roles use their small part.
        """
        if self.role == 'learnware':
            return self.X_large, self.y_large
        return self.X_small, self.y_small

    def train_data(self):
        """Return the large part as training data; only valid for a learnware.

        Raises:
            AssertionError: if ``role`` is not ``'learnware'``.
        """
        # Explicit raise rather than `assert` so the check survives `python -O`;
        # AssertionError is kept so existing callers catching it still work.
        if self.role != 'learnware':
            raise AssertionError("Train data is only available for learnware.")
        return self.X_large, self.y_large

    def eval_data(self):
        """Return the evaluation split, which is identical to :meth:`test_data`."""
        return self.test_data()

    def test_data(self):
        """Return the held-out split: the small part for a learnware, the large part otherwise."""
        if self.role == 'learnware':
            return self.X_small, self.y_small
        return self.X_large, self.y_large