import csv
import math
import os
import pathlib
import torch
import urllib.request
import zipfile
import numpy as np 

import pandas as pd 

# Directory that contains this module, with symlinks resolved.
here = pathlib.Path(__file__).resolve().parent

# Raw-data directory next to this module. In the code visible here it is not
# read at module level; get_data() rebinds the same name as a local variable.
base_base_loc = here / 'data'



# String form of this module's directory; download() and _process_data()
# build their input/output file paths from it.
# NOTE(review): this uses os.path.abspath (symlinks NOT resolved) while `here`
# uses Path.resolve() (symlinks resolved), so the two can differ when the
# module is reached through a symlink -- confirm that is intended.
DATA_PATH = os.path.dirname(os.path.abspath(__file__))
def download(data_dir=None):
    """Convert the USHCN CSV into one NaN-padded tensor and save it to disk.

    Despite its name, this function performs no network access: it reads
    ``<root>/data/USHCN/small_chunked_sporadic.csv`` and writes
    ``<root>/data/USHCN/full_data_250.pt``.

    The CSV must provide the columns ``ID``, ``Time``, ``Value_0..Value_4``
    and ``Mask_0..Mask_4``.  For every ``ID`` with more than 250 rows the rows
    are sorted by ``Time``, each ``Value_i`` is multiplied by ``Mask_i`` after
    mask zeros have been turned into NaN (so unobserved entries become NaN),
    and the ``ID``/``Time``/``Mask_*`` columns are dropped.  Series shorter
    than the longest series in the file are padded at the end with NaN rows.

    The saved tensor has shape ``(n_kept_ids, max_rows_per_id, n_columns)``;
    if no ``ID`` has more than 250 rows an empty tensor is saved instead.

    Args:
        data_dir: Optional root directory that contains ``data/USHCN``.
            Defaults to the module-level ``DATA_PATH``.
    """
    root = DATA_PATH if data_dir is None else data_dir
    ushcn_dir = os.path.join(root, 'data', 'USHCN')
    ushcn_data = pd.read_csv(os.path.join(ushcn_dir, 'small_chunked_sporadic.csv'))

    # Series with this many rows or fewer are skipped; the threshold is also
    # baked into the output file name.
    min_rows_exclusive = 250
    value_cols = ['Value_%d' % i for i in range(5)]
    mask_cols = ['Mask_%d' % i for i in range(5)]

    # sort=False keeps the IDs in order of first appearance, which fixes the
    # row order of the saved tensor.
    stations = ushcn_data.groupby('ID', sort=False)
    rows_per_station = stations.size()
    # The padding target is the longest series over *all* IDs.
    max_timelen = int(rows_per_station.max()) if len(rows_per_station) else 0

    padded_series = []
    for _, station in stations:
        if len(station) <= min_rows_exclusive:
            continue

        station = station.sort_values(by=['Time'])
        for value_col, mask_col in zip(value_cols, mask_cols):
            # A mask of 0 becomes NaN, so the product is NaN wherever the
            # value was not observed; other mask values scale the value.
            mask = station[mask_col]
            station[value_col] = station[value_col] * mask.where(mask != 0)

        values = station.drop(columns=['ID', 'Time'] + mask_cols).to_numpy()
        series = torch.Tensor(values)

        missing_rows = max_timelen - series.shape[0]
        if missing_rows > 0:
            # float('nan') instead of np.NaN: that alias was removed in NumPy 2.0.
            filler = torch.full((missing_rows, series.shape[1]), float('nan'))
            series = torch.cat([series, filler])
        padded_series.append(series)

    # Stack once at the end instead of growing a tensor with torch.cat in the
    # loop (which re-copies everything accumulated so far on each iteration).
    full_tensor = torch.stack(padded_series) if padded_series else torch.Tensor()
    torch.save(full_tensor, os.path.join(ushcn_dir, 'full_data_250.pt'))

def _process_data(look_window,forecast_window,stride_window):
    
    full_data= torch.load(DATA_PATH+'/data/USHCN/full_data_250.pt')
    # import pdb ; pdb.set_trace()
    train_data = full_data[:int(full_data.shape[0]*0.7)]
    val_data = full_data[int(full_data.shape[0]*0.7):int(full_data.shape[0]*0.85)]
    test_data = full_data[int(full_data.shape[0]*0.85):]
    timelen = train_data.shape[1]
    train_seq_data = torch.Tensor()
    train_y_seq_data = torch.Tensor()
    
    val_seq_data = torch.Tensor()
    val_y_seq_data = torch.Tensor()
    
    test_seq_data = torch.Tensor()
    test_y_seq_data = torch.Tensor()
    
    for id in range(train_data.shape[0]):
        
        temp_id_data = train_data[id]
        for _len in range(int((timelen-look_window-forecast_window-stride_window)/stride_window)):
            
            train_seq_temp = temp_id_data[(_len*stride_window):(_len*stride_window)+look_window].unsqueeze(0)
            train_y_seq_temp = temp_id_data[(_len*stride_window)+look_window:(_len*stride_window)+look_window+forecast_window].unsqueeze(0)
            
            train_seq_data=torch.cat([train_seq_data,train_seq_temp])
            train_y_seq_data = torch.cat([train_y_seq_data,train_y_seq_temp])
    
    DATA_PATH_SAVE = DATA_PATH + '/processed_data/USHCN_look_'+str(look_window)+'_forecast_'+str(forecast_window)+'_stride_'+str(stride_window)
    torch.save(train_seq_data,DATA_PATH_SAVE+'/train_seq_data.pt')
    
    times = torch.Tensor(np.arange(look_window))
    torch.save(times,DATA_PATH_SAVE+'/times.pt')
    # torch.save(train_seq_data,loc+'/train_seq_data.pt')
    
    torch.save(train_y_seq_data,DATA_PATH_SAVE+'/train_y_seq_data.pt')
    for id in range(val_data.shape[0]):
        
        temp_id_data = val_data[id]
        for _len in range(int((timelen-look_window-forecast_window-stride_window)/stride_window)):
            
            val_seq_temp = temp_id_data[(_len*stride_window):(_len*stride_window)+look_window].unsqueeze(0)
            val_y_seq_temp = temp_id_data[(_len*stride_window)+look_window:(_len*stride_window)+look_window+forecast_window].unsqueeze(0)
            # import pdb ; pdb.set_trace()
            val_seq_data=torch.cat([val_seq_data,val_seq_temp])
            val_y_seq_data = torch.cat([val_y_seq_data,val_y_seq_temp])
    # import pdb ; pdb.set_trace()
    torch.save(val_seq_data,DATA_PATH_SAVE+'/val_seq_data.pt')
    torch.save(val_y_seq_data,DATA_PATH_SAVE+'/val_y_seq_data.pt')
    
    for id in range(test_data.shape[0]):
        
        temp_id_data = test_data[id]
        for _len in range(int((timelen-look_window-forecast_window-stride_window)/stride_window)):
            
            test_seq_temp = temp_id_data[(_len*stride_window):(_len*stride_window)+look_window].unsqueeze(0)
            test_y_seq_temp = temp_id_data[(_len*stride_window)+look_window:(_len*stride_window)+look_window+forecast_window].unsqueeze(0)
            # import pdb ; pdb.set_trace()
            test_seq_data=torch.cat([test_seq_data,test_seq_temp])
            test_y_seq_data = torch.cat([test_y_seq_data,test_y_seq_temp])
    # import pdb ; pdb.set_trace()
    torch.save(test_seq_data,DATA_PATH_SAVE+'/test_seq_data.pt')
    torch.save(test_y_seq_data,DATA_PATH_SAVE+'/test_y_seq_data.pt')
    


def get_data(look_window,forecast_window,stride_window):
    """Return the directory holding the windowed USHCN tensors, building it if needed.

    The directory is
    ``<module dir>/processed_data/USHCN_look_<L>_forecast_<F>_stride_<S>``.
    ``_process_data`` is run when any of the files it writes is missing.
    Checking the files rather than the directory means a directory left
    behind by an earlier failed run is not treated as a valid cache.

    Args:
        look_window: Number of time steps in each input window.
        forecast_window: Number of time steps in each target window.
        stride_window: Step between the starts of consecutive windows.

    Returns:
        ``pathlib.Path`` of the directory containing the ``.pt`` files.
    """
    processed_root = here / 'processed_data'
    loc = processed_root / ('USHCN' + '_look_'+str(look_window)+'_forecast_'+str(forecast_window)+'_stride_'+str(stride_window))

    # Every file _process_data is expected to write for one configuration.
    expected_files = (
        'train_seq_data.pt',
        'train_y_seq_data.pt',
        'val_seq_data.pt',
        'val_y_seq_data.pt',
        'test_seq_data.pt',
        'test_y_seq_data.pt',
        'times.pt',
    )

    if not all((loc / file_name).exists() for file_name in expected_files):
        # download() is intentionally not called here: the padded tensor
        # full_data_250.pt is expected to exist already.
        # parents/exist_ok replace the separate exists()+mkdir() checks and
        # avoid failing when the directory appears between the two calls.
        loc.mkdir(parents=True, exist_ok=True)
        _process_data(look_window,forecast_window,stride_window)

    return loc
