import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
import numpy as np
import pandas as pd


##### TimeDataset #####
##### Sliding-window samples (input window, next-step target, label) for a DataLoader

class TimeDataset(Dataset):
    """Sliding-window samples over a multivariate series for a DataLoader.

    ``raw_data[:-1]`` is encoded with ``transform.transform`` into a 2-D
    matrix whose second axis is walked in windows of ``config['slide_win']``
    columns. For every window end ``t`` (from ``slide_win`` up to the last
    column) one sample is produced:

    * feature: columns ``t - slide_win`` .. ``t - 1``
    * target:  column ``t``
    * label:   ``raw_data[-1][t]``

    NOTE(review): ``config['slide_stride']`` must be present but is not
    applied (every window end is used), and ``mode`` is stored but never
    read here -- confirm this is intended.
    """

    def __init__(self, raw_data, edge_index, transform, mode='train', config=None):
        self.raw_data = raw_data
        self.config = config
        self.edge_index = edge_index
        self.mode = mode

        # The last entry of raw_data holds the labels; the rest are features.
        feature_rows, label_row = raw_data[:-1], raw_data[-1]

        encoded = torch.tensor(transform.transform(feature_rows)).double()
        label_values = torch.tensor(label_row).double()

        self.x, self.y, self.labels = self.process(encoded, label_values)

    def __len__(self):
        return self.x.size(0)

    def process(self, data, labels):
        """Cut ``data`` (rows x time) into stacked windows, targets and labels.

        Returns:
            ``(x, y, labels)`` as contiguous tensors of shapes
            ``(n, rows, slide_win)``, ``(n, rows)`` and ``(n,)`` (for 1-D
            ``labels``), where ``n = time - slide_win``.

        Raises:
            KeyError: ``config`` lacks ``'slide_win'`` or ``'slide_stride'``.
            RuntimeError: from ``torch.stack`` when no complete window exists.
        """
        cfg = self.config
        # Both keys are required; only the window length is actually used.
        slide_win, _unused_stride = cfg['slide_win'], cfg['slide_stride']
        _, total_time_len = data.shape

        window_ends = range(slide_win, total_time_len)
        windows = [data[:, end - slide_win:end] for end in window_ends]
        targets = [data[:, end] for end in window_ends]
        end_labels = [labels[end] for end in window_ends]

        return (
            torch.stack(windows).contiguous(),
            torch.stack(targets).contiguous(),
            torch.stack(end_labels).contiguous(),
        )

    def __getitem__(self, idx):
        """Return ``(feature, target, label, edge_index)`` for sample ``idx``."""
        feature, target = self.x[idx].double(), self.y[idx].double()
        edge_index = self.edge_index.long()
        return feature, target, self.labels[idx].double(), edge_index
    
    
    
##### Transform #####
##### 1. Split features into continuous and discrete (categorical) columns
##### 2. One-hot encode the discrete columns (and decode them back)

class Transform(Dataset):
    """Fit per-column metadata on a DataFrame and one-hot encode discrete columns.

    A column is treated as CATEGORICAL when it has between 2 and
    ``max_categories`` distinct values, otherwise as CONTINUOUS (constant
    columns therefore count as continuous). Categorical columns are expanded
    to ``max_size`` one-hot rows, where ``max_size`` is the largest category
    count of any categorical column, so every categorical block has the same
    height.

    Note: this class subclasses ``Dataset`` but defines neither ``__len__``
    nor ``__getitem__``; only its encode/decode methods are used here.
    """

    # Columns with more distinct values than this are treated as continuous.
    max_categories = 16

    def __init__(self, raw_data, max_categories=16):
        """
        Args:
            raw_data: pandas DataFrame (samples x columns) used to fit the
                column metadata.
            max_categories: Largest distinct-value count for which a column
                is still treated as categorical (default 16, the previously
                hard-coded threshold).
        """
        self.raw_data = raw_data
        self.max_categories = max_categories
        self.output_info = []
        self.meta, self.conti_ind, self.cate_ind = self.get_metadata(raw_data)

    def get_metadata(self, data):
        """Classify each column and record what is needed to encode/decode it.

        Also sets ``self.max_size`` to the largest number of categories of any
        categorical column (0 if there are none).

        Returns:
            ``(meta, conti_ind, cate_ind)``. ``meta`` holds one dict per column,
            in column order: ``{"name", "type": "CATEGORICAL", "size", "i2s"}``
            where ``i2s`` lists the categories from most to least frequent, or
            ``{"name", "type": "CONTINUOUS", "min", "max"}``. The two index
            lists hold the positional indices of each kind of column.
        """
        meta = []
        conti_ind = []
        cate_ind = []
        self.max_size = 0

        for ith in range(data.shape[1]):
            # Positional access stays correct even with duplicate column labels.
            column = data.iloc[:, ith]
            col_name = data.columns[ith]
            n_distinct = len(set(column))

            if n_distinct > self.max_categories or n_distinct == 1:
                conti_ind.append(ith)
                meta.append({
                    "name": col_name,
                    "type": "CONTINUOUS",
                    "min": column.min(),
                    "max": column.max(),
                })
            else:
                cate_ind.append(ith)
                # value_counts() sorts categories by descending frequency.
                mapper = column.value_counts().index.tolist()
                self.max_size = max(self.max_size, len(mapper))
                meta.append({
                    "name": col_name,
                    "type": "CATEGORICAL",
                    "size": len(mapper),
                    "i2s": mapper,
                })

        return meta, conti_ind, cate_ind

    def col_index(self):
        """Return ``(conti_ind, cate_ind, output_info)``.

        ``output_info`` reflects the most recent ``transform`` call (empty
        before the first one).
        """
        return self.conti_ind, self.cate_ind, self.output_info

    def transform(self, data):
        """Encode ``data`` column-by-column according to the fitted metadata.

        Args:
            data: Array-like of shape (n_columns, n_samples); row ``i`` holds
                the values of fitted column ``i`` (e.g. ``df.values.T``).

        Returns:
            float64 ``np.ndarray`` with one row per continuous column and
            ``max_size`` one-hot rows per categorical column, in column order.
            ``self.output_info`` is reset to one entry per column.

        Raises:
            ValueError: a categorical value is not one of the fitted categories.
        """
        # Build float64 directly: torch.tensor() on Python floats infers
        # float32, and a later .double() cannot restore the lost precision.
        data = torch.tensor(data, dtype=torch.float64)
        n_samples = data.shape[1]
        sample_pos = np.arange(n_samples)
        data_t = []
        self.output_info = []
        for id_, info in enumerate(self.meta):
            col = data[id_, :]
            if info['type'] == "CONTINUOUS":
                data_t.append(col.reshape(1, -1).numpy())
                self.output_info.append(1)
                continue

            # Round incoming values and fitted categories the same way so tiny
            # floating-point differences cannot defeat exact matching.
            col = torch.round(col, decimals=6)
            cats = torch.round(torch.tensor(info['i2s'], dtype=torch.float64), decimals=6)
            matches = col.unsqueeze(1) == cats.unsqueeze(0)     # (n_samples, size)
            found = matches.any(dim=1)
            if not bool(found.all()):
                missing = col[~found][0].item()
                raise ValueError(
                    f"{missing!r} is not a known category of column {info['name']!r}"
                )
            idx = matches.long().argmax(dim=1)                  # first matching category
            col_t = np.zeros([self.max_size, n_samples])
            col_t[idx.numpy(), sample_pos] = 1
            data_t.append(col_t)
            # NOTE(review): the block just appended has max_size rows, but
            # output_info records info['size']; a consumer that slices the
            # encoded output by output_info is misaligned whenever
            # size < max_size. Confirm which one downstream code expects.
            self.output_info.append(info['size'])

        return np.concatenate(data_t, axis=0)

    def inverse_transform(self, data):
        """Decode an encoded matrix back to one value per fitted column.

        Args:
            data: Array-like of shape (n_samples, width), laid out like the
                transpose of ``transform``'s output: one column per continuous
                feature and ``max_size`` columns per categorical feature, in
                column order. Categorical blocks may hold one-hot vectors or
                scores; the highest-scoring real category wins. Extra trailing
                columns are ignored.

        Returns:
            float64 ``np.ndarray`` of shape (n_samples, n_columns).

        Raises:
            ValueError: ``data`` has fewer columns than the encoding requires.
        """
        data = np.asarray(data)
        widths = [1 if info['type'] == "CONTINUOUS" else self.max_size
                  for info in self.meta]
        if data.shape[1] < sum(widths):
            raise ValueError(
                f"expected at least {sum(widths)} encoded columns, got {data.shape[1]}"
            )

        data_t = np.zeros([len(data), len(self.meta)])
        offset = 0
        for id_, (info, width) in enumerate(zip(self.meta, widths)):
            if info['type'] == "CONTINUOUS":
                data_t[:, id_] = data[:, offset]
            else:
                # Only the first info['size'] slots are real categories; the
                # rest merely pad the block to max_size and must never win.
                scores = data[:, offset:offset + info['size']]
                idx = np.argmax(scores, axis=1)
                data_t[:, id_] = np.asarray(info['i2s'])[idx]
            offset += width
        return data_t
    
