import numpy as np
import torch
from torch import nn, optim
from torch.autograd import Variable
import torch.nn.functional as F
import pandas as pd
import time
import math
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
import os
from os import listdir
from os.path import isfile, join
import joblib

def read_data_cont(task_type, df):
    """Return the continuous feature columns used by *task_type* as an array.

    Args:
        task_type: Either 'SONNET' or 'SONNET_base'; selects which columns
            of *df* are read.
        df: pandas DataFrame containing the columns listed for the task.

    Returns:
        ``df[columns].values`` -- one row per row of *df*, one column per
        selected feature, in the order listed below.

    Raises:
        ValueError: If *task_type* is not one of the supported names.
        KeyError: Raised by pandas when *df* lacks one of the columns.
    """
    feature_columns = {
        'SONNET': ['consumption', 'solar', 'DNI', 'DHI', 'temperature', 'relativehumidity'],
        'SONNET_base': ['total_grid', 'DNI', 'DHI', 'temperature', 'relativehumidity'],
    }
    if task_type not in feature_columns:
        # Previously this fell through and failed later with an
        # UnboundLocalError that did not mention the offending value.
        raise ValueError(
            f"unsupported task_type {task_type!r}; "
            f"expected one of {sorted(feature_columns)}"
        )
    return df[feature_columns[task_type]].values

def create_datasets_onechunk(configs, scaler, df):
    """Slice one chunk of rows into sliding-window model inputs and targets.

    Each window consists of ``look_back`` history rows, a skipped ``gap``,
    and ``pred_len`` forecast rows.  The forecast rows keep every feature
    except the first ``configs.masked`` ones, which are replaced by zeros.

    Args:
        configs: Object exposing ``location``, ``gap``, ``look_back``,
            ``pred_len``, ``truncate``, ``task_type`` and ``masked``.
        scaler: Object whose ``transform`` is applied to the continuous
            feature array returned by ``read_data_cont``.
        df: DataFrame with the task's continuous columns plus 'weekday',
            'dayofyear', 'timeofday', 'month', 'year' and 'total_grid'.

    Returns:
        ``(input_data, output_data)`` where ``input_data`` has shape
        ``(n_windows, n_features, look_back + pred_len)`` and
        ``output_data`` has shape ``(n_windows, 1, pred_len)`` holding the
        unscaled 'total_grid' values of the forecast rows.

    Raises:
        ValueError: If *df* has too few rows to produce a single window.
    """
    location = configs.location
    gap = configs.gap
    look_back = configs.look_back
    pred_len = configs.pred_len
    truncate = configs.truncate
    masked = configs.masked
    total_len = gap + look_back + pred_len
    task_type = configs.task_type

    data_cont = read_data_cont(task_type, df)
    data_time = df[['weekday', 'dayofyear', 'timeofday', 'month', 'year']].values
    data_target = df[['total_grid']].values
    data_cont = scaler.transform(data_cont)

    def _cyclical(values, period):
        # (2, n) array: sin row, then cos row.  The float32 cast reproduces
        # the rounding the previous torch.float32 round trip applied.
        angle = values / period * 2 * np.pi
        return np.stack((np.sin(angle), np.cos(angle))).astype(np.float32)

    weekday = _cyclical(data_time[:, 0], 7)     # week of day
    timeslot = _cyclical(data_time[:, 2], 24)   # timeslot of the day
    month = _cyclical(data_time[:, 3], 12)      # month of year
    if location != "TX":
        time_rows = (weekday, timeslot, month)
    else:
        # "TX" additionally gets day-of-year and a year offset from 2022.
        day = _cyclical(data_time[:, 1], 365)   # day of year
        year = (data_time[:, 4] - 2022)[np.newaxis, :]
        time_rows = (weekday, day, timeslot, month, year)
    data_time = np.concatenate(time_rows, 0).T
    data_context = np.concatenate((data_cont, data_time), 1)

    # NOTE(review): this count stops one short of the last start index that
    # would still fit a full window; kept as-is so sample counts are unchanged.
    n_windows = len(data_context) - total_len
    if n_windows <= 0:
        raise ValueError(
            f"chunk has {len(data_context)} rows but more than {total_len} "
            "(gap + look_back + pred_len) are required to build a window"
        )

    hist_input, future_input, target = [], [], []
    for start in range(n_windows):
        hist_end = start + look_back
        future_start = hist_end + gap
        future_end = start + total_len
        hist_input.append(data_context[start:hist_end, :])
        future_input.append(data_context[future_start:future_end, masked:])
        target.append(data_target[future_start:future_end, 0])

    hist_input = np.array(hist_input)
    future_input = np.array(future_input)
    target = np.expand_dims(np.array(target), -1)

    # Put zeros where the masked leading features were dropped so history
    # and forecast rows share one feature width and can be joined in time.
    zero_pad = np.zeros((future_input.shape[0], future_input.shape[1], masked))
    future_input = np.concatenate((zero_pad, future_input), axis=-1)
    input_data = np.concatenate((hist_input, future_input), axis=1)

    # (windows, time, features) -> (windows, features, time)
    input_data = np.swapaxes(input_data, 1, 2)
    output_data = np.swapaxes(target, 1, 2)

    if truncate:
        # Drop the last 2 * gap windows.  Clamp at zero: a negative stop
        # would wrap around and keep some windows instead of none.
        keep = max(input_data.shape[0] - gap * 2, 0)
        input_data = input_data[:keep]
        output_data = output_data[:keep]

    return input_data, output_data

def _add_weekday_column(frame):
    """Return a copy of *frame* with 'date' and 'weekday' columns added."""
    # Work on an explicit copy: *frame* may be a boolean-filtered selection,
    # and assigning columns onto that triggers pandas' SettingWithCopyWarning.
    frame = frame.copy()
    frame['date'] = pd.to_datetime(frame['year'].astype(str) + frame['dayofyear'].astype(str), format='%Y%j')
    frame['weekday'] = frame['date'].dt.weekday
    return frame


def create_datasets_location(configs):
    """Build windowed training (and optionally validation) arrays for a location.

    Reads every file in ``./datasets/processed-dataset/`` whose name contains
    ``configs.location``, splits the rows into chunks by 'train_group' (and
    'valid_group' for validation), fits a ``StandardScaler`` on the training
    chunks' continuous features, saves it with joblib, and windows each chunk
    that has more than ``gap + look_back + pred_len`` rows.

    Args:
        configs: Object exposing ``location``, ``gap``, ``look_back``,
            ``pred_len``, ``task_type``, ``valid`` and ``path``, plus the
            attributes ``create_datasets_onechunk`` reads.

    Returns:
        ``(train_inputs, train_outputs, valid_inputs, valid_outputs)``.
        When ``configs.valid`` is falsy the training arrays are returned in
        both positions.

    Raises:
        ValueError: If no training rows are found, or if no chunk of a
            requested split is long enough to produce a window.
    """
    location = configs.location
    gap = configs.gap
    look_back = configs.look_back
    total_len = gap + look_back + configs.pred_len
    task_type = configs.task_type
    valid = configs.valid

    datapath = './datasets/processed-dataset/'
    selected_datasets = [
        f for f in os.listdir(datapath)
        if isfile(join(datapath, f)) and location in f
    ]

    data_conts = []
    dfs = []
    dfs_valid = []
    for selected_dataset in selected_datasets:
        data = pd.read_csv(join(datapath, selected_dataset))
        if valid:
            in_2022 = data["year"] == 2022
            in_2023 = data["year"] == 2023
            # Training: all of 2022 plus 2023 up to day-of-year 134.
            # Validation: 2023 from day-of-year 121 on, so the two ranges
            # share days 121-134 (presumably as look-back context -- confirm).
            df = data[in_2022 | (in_2023 & (data["dayofyear"] <= 134))]
            df_valid = data[in_2023 & (data["dayofyear"] >= 121)]
        else:
            df = data

        df = _add_weekday_column(df)
        # NOTE(review): set iteration order is not guaranteed, so the order
        # of chunks (and of the windows in the output) follows it.
        for group in set(df["train_group"]):
            df_sub = df[df["train_group"] == group]
            dfs.append(df_sub)
            data_conts.append(read_data_cont(task_type, df_sub))

        if valid:
            df_valid = _add_weekday_column(df_valid)
            for group in set(df_valid["valid_group"]):
                dfs_valid.append(df_valid[df_valid["valid_group"] == group])

    if not data_conts:
        raise ValueError(
            f"no training data found for location {location!r} in {datapath}"
        )

    # The scaler is fitted on training chunks only and reused for validation.
    scaler = StandardScaler()
    scaler.fit(np.concatenate(data_conts, axis=0))
    joblib.dump(scaler, configs.path+location+'_L='+str(look_back)+'_gap='+str(gap)+'_'+task_type+'_scaler.gz')

    def _window_chunks(chunks, split_name):
        # Window every chunk long enough to yield at least one sample and
        # stack the results along the sample axis.
        inputs, outputs = [], []
        for chunk in chunks:
            if chunk.shape[0] > total_len:
                chunk_inputs, chunk_outputs = create_datasets_onechunk(configs, scaler, chunk)
                inputs.append(chunk_inputs)
                outputs.append(chunk_outputs)
        if not inputs:
            raise ValueError(
                f"no {split_name} chunk for location {location!r} has more "
                f"than {total_len} rows"
            )
        return np.concatenate(inputs, axis=0), np.concatenate(outputs, axis=0)

    input_datas, output_datas = _window_chunks(dfs, "training")
    if not valid:
        return input_datas, output_datas, input_datas, output_datas

    input_datas_validation, output_datas_validation = _window_chunks(dfs_valid, "validation")
    return input_datas, output_datas, input_datas_validation, output_datas_validation
    
if __name__ == "__main__":
    # Smoke run: build the datasets for one location and print their shapes.
    class ModelParams:
        location = "HI"
        gap = 48
        look_back = 336
        pred_len = 24
        truncate = False
        # read_data_cont only accepts 'SONNET' or 'SONNET_base'.
        task_type = 'SONNET'
        # The attributes below are read by the dataset functions but were
        # missing here.  Values are assumptions -- TODO confirm.
        masked = 2      # presumably 'consumption' and 'solar' are unknown for the forecast rows
        valid = False   # no validation split: training arrays are returned twice
        path = './'     # prefix for the saved scaler file

    configs = ModelParams
    # create_datasets_location always returns four arrays.
    input_datas, output_datas, _, _ = create_datasets_location(configs)
    print(input_datas.shape, output_datas.shape)