import numpy as np
import torch

from ..utils import map_elementwise

class Data:
    '''Standard data format. 
    '''
    def __init__(self, X_train=None, y_train=None, X_test=None, y_test=None):
        self.X_train = X_train
        self.y_train = y_train
        self.X_test = X_test
        self.y_test = y_test
        
        self.__device = None
        self.__dtype = None
    
    def get_batch(self, batch_size):
        @map_elementwise
        def batch_mask(X, num):
            #print(X.shape)
            return np.random.choice(X.size(0), num, replace=False)
        @map_elementwise
        def batch(X, mask):
            if self.__device == 'cpu':
                return X[mask]
            elif self.__device == 'gpu':
                return X[mask].cuda()
        if batch_size == None:
            if self.__device == 'cpu':
                return self.X_train, self.y_train
            elif self.__device == 'gpu':
                return self.X_train.cuda(), self.y_train.cuda()
        else:
            mask = batch_mask(self.X_train, batch_size)
            return batch(self.X_train, mask), batch(self.y_train, mask)
        
    def get_batch_test(self, batch_size):
        @map_elementwise
        def batch_mask(X, num):
            return np.random.choice(X.size(0), num, replace=False)
        @map_elementwise
        def batch(X, mask):
            if self.__device == 'cpu':
                return X[mask]
            elif self.__device == 'gpu':
                return X[mask].cuda()
        if batch_size == None:
            if self.__device == 'cpu':
                return self.X_test, self.y_test
            elif self.__device == 'gpu':
                return self.X_test.cuda(), self.y_test.cuda()
        else:
#             mask = batch_mask(self.X_train, batch_size)
#             return batch(self.X_train, mask), batch(self.y_train, mask)
            mask = batch_mask(self.X_test, batch_size)
            return batch(self.X_test, mask), batch(self.y_test, mask)
    
    @property
    def device(self):
        return self.__device
    
    @property
    def dtype(self):
        return self.__dtype
    
    @device.setter    
    def device(self, d):
        self.__to_cpu()
        self.__device = d
    
    @dtype.setter     
    def dtype(self, d):
        if d == 'float':
            self.__to_float()
            self.__dtype = torch.float32
        elif d == 'double':
            self.__to_double()
            self.__dtype = torch.float64
        else:
            raise ValueError
    
    @property
    def dim(self):
        if isinstance(self.X_train, np.ndarray):
            return self.X_train.shape[-1]
        elif isinstance(self.X_train, torch.Tensor):
            return self.X_train.size(-1)
    
    @property
    def K(self):
        if isinstance(self.y_train, np.ndarray):
            return self.y_train.shape[-1]
        elif isinstance(self.y_train, torch.Tensor):
            return self.y_train.size(-1)
    
    @property
    def X_train_np(self):
        return Data.to_np(self.X_train)
    
    @property
    def y_train_np(self):
        return Data.to_np(self.y_train)
    
    @property
    def X_test_np(self):
        return Data.to_np(self.X_test)
    
    @property
    def y_test_np(self):
        return Data.to_np(self.y_test)
    
    @staticmethod
    @map_elementwise
    def to_np(d):
        if isinstance(d, np.ndarray) or d is None:
            return d
        elif isinstance(d, torch.Tensor):
            return d.cpu().detach().numpy()
        else:
            raise ValueError
    
    def __to_cpu(self):
        @map_elementwise
        def trans(d):
            if isinstance(d, np.ndarray):
                return torch.DoubleTensor(d)
            elif isinstance(d, torch.Tensor):
                return d.cpu()
        for d in ['X_train', 'y_train', 'X_test', 'y_test']:
            setattr(self, d, trans(getattr(self, d)))
    
    def __to_gpu(self):
        @map_elementwise
        def trans(d):
            if isinstance(d, np.ndarray):
                return torch.cuda.DoubleTensor(d)
            elif isinstance(d, torch.Tensor):
                return d.cuda()
        for d in ['X_train', 'y_train', 'X_test', 'y_test']:
            setattr(self, d, trans(getattr(self, d)))
    
    def __to_float(self):
        if self.device is None: 
            raise RuntimeError('device is not set')
        @map_elementwise
        def trans(d):
            if isinstance(d, torch.Tensor):
                return d.float()
        for d in ['X_train', 'y_train', 'X_test', 'y_test']:
            setattr(self, d, trans(getattr(self, d)))
    
    def __to_double(self):
        if self.device is None: 
            raise RuntimeError('device is not set')
        @map_elementwise
        def trans(d):
            if isinstance(d, torch.Tensor):
                return d.double()
        for d in ['X_train', 'y_train', 'X_test', 'y_test']:
            setattr(self, d, trans(getattr(self, d)))
