import collections as co
import numpy as np
import os
import pathlib

import torch
import urllib.request
import zipfile
import time_dataset


# Directory containing this module, with symlinks resolved; get_data() builds
# its "processed_data" cache location from this path.
here = pathlib.Path(__file__).resolve().parents[0]

# Directory containing this module as an absolute string (symlinks NOT resolved).
DATA_PATH = os.path.split(os.path.abspath(__file__))[0]

def _process_data(missing_rate, look_window, forecast_window, stride_window):
    """Build, corrupt, split and save the Google sliding-window dataset.

    Reads ``data/Google/google_challenge.csv`` (relative to this module, first
    row skipped), normalizes it with ``time_dataset.normalize`` and cuts it
    into windows.  Input window ``k`` holds rows
    ``[k*stride_window, k*stride_window + look_window)``.  Its target holds the
    ``forecast_window`` rows that immediately follow.  Only windows whose target
    fits entirely inside the series are kept.

    In every input window, ``int(look_window * missing_rate)`` distinct,
    randomly chosen rows are overwritten with NaN.  The draws come from a
    generator with a fixed seed, so they are reproducible.  Windows are split
    chronologically into the first 70% (train), the next 15% (val) and the rest
    (test).  Each split is saved with ``torch.save`` as ``*_seq_data.pt`` /
    ``*_y_seq_data.pt``, together with ``times.pt`` (``0 .. look_window - 1``
    as float32), into ``processed_data/TGoogle0910_look_..._Missing_Rate_...``.
    That is the same directory ``get_data`` returns.

    Args:
        missing_rate: fraction of rows in each input window to replace by NaN.
        look_window: number of rows per input window.
        forecast_window: number of rows per target window.
        stride_window: offset, in rows, between consecutive window starts.

    Raises:
        ValueError: if ``stride_window`` is smaller than 1.
    """
    if stride_window < 1:
        raise ValueError(f"stride_window must be >= 1, got {stride_window}")

    X_times = np.loadtxt(
        os.path.join(DATA_PATH, "data", "Google", "google_challenge.csv"),
        delimiter=",",
        skiprows=1,
    )
    X_times = time_dataset.normalize(X_times)
    # Convert once instead of once per window.  float32 matches the dtype the
    # previous ``torch.Tensor([...])`` construction produced.
    series = torch.as_tensor(np.asarray(X_times), dtype=torch.float32)
    feature_shape = tuple(series.shape[1:])

    # Number of complete (input, target) pairs that fit in the series.
    n_windows = max(
        0, (series.shape[0] - look_window - forecast_window) // stride_window + 1
    )
    starts = range(0, n_windows * stride_window, stride_window)

    if n_windows:
        # torch.stack copies its inputs, so every window owns its memory even
        # when windows overlap (stride < look_window).  The in-place NaN
        # masking below therefore cannot leak into neighbouring windows.
        full_seq_data = torch.stack([series[s:s + look_window] for s in starts])
        full_y_seq_data = torch.stack(
            [series[s + look_window:s + look_window + forecast_window] for s in starts]
        )
    else:
        # Series too short for a single window: keep well-shaped empty tensors.
        full_seq_data = series.new_empty((0, look_window, *feature_shape))
        full_y_seq_data = series.new_empty((0, forecast_window, *feature_shape))

    generator = torch.Generator().manual_seed(56789)
    n_missing = int(look_window * missing_rate)
    for window in full_seq_data:
        # One permutation per window, drawn from the single seeded generator
        # in window order, so the mask pattern is reproducible.
        removed_points = torch.randperm(look_window, generator=generator)[:n_missing]
        window[removed_points] = float("nan")  # window is a view: edits full_seq_data

    n = full_seq_data.shape[0]
    train_end, val_end = int(n * 0.7), int(n * 0.85)

    # Same base directory get_data() uses (``here``), so the files land where
    # the caller will look for them even if this module is reached via a symlink.
    save_dir = here / "processed_data" / (
        f"TGoogle0910_look_{look_window}_forecast_{forecast_window}"
        f"_stride_{stride_window}_Missing_Rate_{missing_rate}"
    )
    save_dir.mkdir(parents=True, exist_ok=True)

    outputs = {
        "train_seq_data": full_seq_data[:train_end],
        "train_y_seq_data": full_y_seq_data[:train_end],
        "times": torch.arange(look_window, dtype=torch.float32),
        "val_seq_data": full_seq_data[train_end:val_end],
        "val_y_seq_data": full_y_seq_data[train_end:val_end],
        "test_seq_data": full_seq_data[val_end:],
        "test_y_seq_data": full_y_seq_data[val_end:],
    }
    for name, tensor in outputs.items():
        # clone(): torch.save writes a view's entire underlying storage.
        # Saving the raw slices would store the full dataset in every file.
        torch.save(tensor.clone(), save_dir / f"{name}.pt")
    

def get_data(missing_rate, look_window, forecast_window, stride_window):
    """Return the processed-data directory for these parameters, building it if needed.

    The directory is ``processed_data/TGoogle0910_look_<look>_forecast_<forecast>
    _stride_<stride>_Missing_Rate_<rate>`` next to this module.  If any of the
    files ``_process_data`` writes is missing, the directory is created when
    absent and ``_process_data`` is run to (re)build it.

    Args:
        missing_rate: fraction of rows in each input window to replace by NaN.
        look_window: number of rows per input window.
        forecast_window: number of rows per target window.
        stride_window: offset, in rows, between consecutive window starts.

    Returns:
        pathlib.Path: the dataset directory.
    """
    loc = here / "processed_data" / (
        f"TGoogle0910_look_{look_window}_forecast_{forecast_window}"
        f"_stride_{stride_window}_Missing_Rate_{missing_rate}"
    )
    expected_files = (
        "train_seq_data.pt",
        "train_y_seq_data.pt",
        "times.pt",
        "val_seq_data.pt",
        "val_y_seq_data.pt",
        "test_seq_data.pt",
        "test_y_seq_data.pt",
    )
    # Check the files, not just the directory.  A previous run that crashed
    # after the directory was created would otherwise be treated as a valid
    # cache forever.
    if not all((loc / name).exists() for name in expected_files):
        loc.mkdir(parents=True, exist_ok=True)
        _process_data(missing_rate, look_window, forecast_window, stride_window)

    return loc