import os
from collections import Counter

import numpy as np
import pandas as pd
import torch




def split_data_by_ratio(data, val_ratio, test_ratio):
    """Split ``data`` along axis 1 into consecutive train/val/test parts.

    With ``T = data.shape[1]``:
      * the last ``int(T * test_ratio)`` columns form the test split,
      * the columns up to ``int(T * (test_ratio + val_ratio))`` from the end
        (excluding the test columns) form the validation split,
      * everything before that is training data.

    Args:
        data: array with at least two dimensions; split along axis 1.
        val_ratio: fraction of columns assigned to validation.
        test_ratio: fraction of columns assigned to test.

    Returns:
        ``(train_data, val_data, test_data)`` as slices of ``data``.
    """
    data_len = data.shape[1]
    test_len = int(data_len * test_ratio)
    holdout_len = int(data_len * (test_ratio + val_ratio))
    # Use explicit, non-negative boundaries. A negative slice start of -0 is
    # the same as 0, so the old ``data[:, -int(...):]`` form selected the WHOLE
    # array whenever a split rounded down to zero columns. Clamping also
    # prevents wrap-around when the ratios sum to more than 1.
    train_end = max(data_len - holdout_len, 0)
    val_end = max(data_len - test_len, 0)
    test_data = data[:, val_end:]
    val_data = data[:, train_end:val_end]
    train_data = data[:, :train_end]
    return train_data, val_data, test_data




class StandardScaler:
    """Z-score normaliser using fixed ``mean`` / ``std`` statistics.

    ``transform`` computes ``(x - mean) / (std + eps)`` and
    ``inverse_transform`` computes ``x * (std + eps) + mean``, so the two are
    exact inverses of each other. ``eps`` guards against division by zero for
    constant features.

    Both methods accept NumPy arrays or torch tensors. When the input is a
    tensor, the statistics are converted per call to the input's dtype and
    device. The stored ``mean`` / ``std`` are never replaced, so mixing NumPy
    and tensor inputs (or tensors on different devices) is safe.
    """

    def __init__(self, mean, std, eps=10e-5):
        self.mean = mean
        self.std = std
        # NOTE: 10e-5 == 1e-4; kept as the default to preserve existing scaling.
        self.eps = eps

    def _stats_like(self, data):
        """Return ``(mean, std)`` in a form that combines cleanly with ``data``."""
        if isinstance(data, torch.Tensor):
            return (
                torch.as_tensor(self.mean, dtype=data.dtype, device=data.device),
                torch.as_tensor(self.std, dtype=data.dtype, device=data.device),
            )
        return self.mean, self.std

    def transform(self, data):
        mean, std = self._stats_like(data)
        return (data - mean) / (std + self.eps)

    def inverse_transform(self, data):
        mean, std = self._stats_like(data)
        # Include eps so this undoes ``transform`` exactly.
        return data * (std + self.eps) + mean


def get_normalized_data(data_name: str, data_root: str = "./data"):
    """Load a ``.npy`` dataset, split it 60/20/20 and z-score normalise it.

    The dataset is chosen by the prefix of ``data_name``. The file
    ``<data_root>/<sub_dir>/<data_name>.npy`` is loaded, cast to float32 and
    transposed so that axis 1 is the axis split and windowed later. For
    ``har*`` / ``syn*`` the array is first reshaped to ``(-1, 9)`` /
    ``(-1, 10)``. For other datasets the stored array is assumed to be
    ``(time, channels)`` — TODO confirm.

    Normalisation statistics are computed per row (axis 1) on the training
    split only, then applied to all three splits.

    Args:
        data_name: dataset file stem; must start with one of ``"solar"``,
            ``"PEMS3-CL"``, ``"har"`` or ``"syn"``.
        data_root: root directory holding the per-dataset sub-directories.

    Returns:
        ``(data, scaler)`` where ``data`` has keys ``'train'``, ``'val'``,
        ``'test'`` (raw splits) and ``'<split>_normalized'`` (normalised
        splits), and ``scaler`` is the fitted ``StandardScaler``.

    Raises:
        ValueError: if ``data_name`` matches no known dataset prefix.
    """
    # (name prefix, sub-directory under data_root, channels to reshape to or None)
    dataset_specs = (
        ("solar", "solar-cl", None),
        ("PEMS3-CL", "PEMS3-cl", None),
        ("har", "har-cl", 9),
        ("syn", "synthetic-cl", 10),
    )
    spec = next((s for s in dataset_specs if data_name.startswith(s[0])), None)
    if spec is None:
        raise ValueError(f"Unknown dataset name: {data_name!r}")
    _, sub_dir, n_channels = spec

    X = np.load(os.path.join(data_root, sub_dir, data_name + ".npy"))
    # astype returns a new array; the result must be kept.
    X = X.astype(np.float32)
    if n_channels is not None:
        X = np.reshape(X, (-1, n_channels))
    X = np.transpose(X)

    data_train, data_val, data_test = split_data_by_ratio(X, val_ratio=0.2, test_ratio=0.2)
    mean = data_train.mean(axis=1, keepdims=True)
    std = data_train.std(axis=1, keepdims=True)

    scaler = StandardScaler(mean, std)
    data = {
        'train': data_train,
        'train_normalized': scaler.transform(data_train),
        'val': data_val,
        'val_normalized': scaler.transform(data_val),
        'test': data_test,
        'test_normalized': scaler.transform(data_test),
    }
    return data, scaler



def add_window(data, name, lag=32, horizon=12):
    """Cut ``data`` into overlapping (input, target) windows along axis 1.

    Window ``i`` uses columns ``[i, i + lag)`` as input and
    ``[i + lag, i + lag + horizon)`` as target. Consecutive windows start one
    column apart, and only complete windows are produced.

    Args:
        data: array-like with at least two dimensions, e.g. ``(C, T)``;
            windows slide along axis 1.
        name: unused; kept for compatibility with existing callers.
        lag: number of input steps per window.
        horizon: number of target steps per window.

    Returns:
        ``(features, target)`` torch tensors of shapes
        ``(N, C, lag, ...)`` and ``(N, C, horizon, ...)`` where
        ``N = max(T - lag - horizon + 1, 0)``.
    """
    data = np.asarray(data)
    window = lag + horizon
    n_windows = max(data.shape[1] - window + 1, 0)
    if n_windows == 0:
        # np.array([]) would collapse to shape (0,); keep the per-window layout
        # so downstream code still sees the expected rank.
        tail = data.shape[2:]
        features = np.empty((0, data.shape[0], lag) + tail, dtype=data.dtype)
        target = np.empty((0, data.shape[0], horizon) + tail, dtype=data.dtype)
    else:
        features = np.stack([data[:, i:i + lag] for i in range(n_windows)])
        target = np.stack([data[:, i + lag:i + window] for i in range(n_windows)])
    return torch.from_numpy(features), torch.from_numpy(target)



### dataloader function


class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
    """Sample dataset indices with replacement, weighted by inverse label frequency.

    Each index is drawn with probability proportional to ``1 / count(label)``,
    so every label gets roughly the same total sampling mass.

    Arguments:
        dataset: indexable dataset. By default the label of item ``idx`` is
            ``dataset[idx][0].item()``.
        indices (list, optional): indices to sample from; defaults to
            ``range(len(dataset))``.
        num_samples (int, optional): number of indices drawn per iteration;
            defaults to ``len(indices)``.
        callback_get_label (callable, optional): ``f(dataset, idx) -> label``
            used instead of the default label lookup. Returned labels must be
            hashable.
    """

    def __init__(self, dataset, indices=None, num_samples=None, callback_get_label=None):
        # If indices is not provided, all elements in the dataset are considered.
        self.indices = list(range(len(dataset))) if indices is None else indices
        self.callback_get_label = callback_get_label
        # If num_samples is not provided, draw len(indices) samples per iteration.
        self.num_samples = len(self.indices) if num_samples is None else num_samples

        # Fetch each label once (dataset indexing may be costly) and reuse it
        # for both the class histogram and the per-sample weights.
        labels = [self._get_label(dataset, idx) for idx in self.indices]
        label_to_count = Counter(labels)
        self.weights = torch.DoubleTensor([1.0 / label_to_count[label] for label in labels])

    def _get_label(self, dataset, idx):
        if self.callback_get_label is not None:
            return self.callback_get_label(dataset, idx)
        return dataset[idx][0].item()

    def __iter__(self):
        # torch.multinomial rejects empty weights / zero draws; yield nothing instead.
        if self.num_samples <= 0 or len(self.weights) == 0:
            return iter(())
        drawn = torch.multinomial(self.weights, self.num_samples, replacement=True)
        return (self.indices[i] for i in drawn.tolist())

    def __len__(self):
        return self.num_samples


def data_loader(year, X, Y, X_normalized, Y_normalized, batch_size, shuffle=True, drop_last=True):
    """Wrap aligned tensors in a ``torch.utils.data.DataLoader``.

    Each sample is the tuple ``(year, X, Y, X_normalized, Y_normalized)`` at
    one index, so all five tensors must share their first dimension.

    Args:
        year, X, Y, X_normalized, Y_normalized: tensors indexed along dim 0.
        batch_size: samples per batch.
        shuffle: if True, indices are drawn by ``ImbalancedDatasetSampler``
            (with replacement, weighted by inverse frequency of the per-sample
            ``year`` value). If False, every sample is served exactly once,
            in order. This suits validation/test evaluation.
        drop_last: drop the final incomplete batch.

    Returns:
        The configured ``DataLoader``.
    """
    dataset = torch.utils.data.TensorDataset(year, X, Y, X_normalized, Y_normalized)
    # Previously ``shuffle`` was ignored and every split was resampled with
    # replacement; only use the imbalanced sampler when shuffling is requested.
    sampler = ImbalancedDatasetSampler(dataset) if shuffle else None
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        drop_last=drop_last,
        # Pinned host memory only benefits host->GPU copies; skip it without CUDA.
        pin_memory=torch.cuda.is_available(),
    )




def get_forcasting_dataloader(year, window_set, batch_size=32):
    """Build one DataLoader per split ('train', 'val', 'test').

    Args:
        year: mapping from split name to the per-sample ``year`` tensor.
        window_set: mapping with keys ``'<split>_X'``, ``'<split>_Y'``,
            ``'<split>_normalized_X'`` and ``'<split>_normalized_Y'``.
        batch_size: batch size for every loader.

    Returns:
        dict mapping ``'train'`` / ``'val'`` / ``'test'`` to loaders built by
        ``data_loader``. Only the training loader is requested with
        ``shuffle=True``. All loaders use ``drop_last=True``.
    """
    def _build(split):
        return data_loader(
            year[split],
            window_set[split + '_X'],
            window_set[split + '_Y'],
            window_set[split + '_normalized_X'],
            window_set[split + '_normalized_Y'],
            batch_size=batch_size,
            shuffle=(split == 'train'),
            drop_last=True,
        )

    return {split: _build(split) for split in ('train', 'val', 'test')}


