import numpy as np
import torch
from torch import nn, optim
from torch.autograd import Variable
import torch.nn.functional as F
import pandas as pd
import time
import datetime
from datetime import timedelta
import math
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import MinMaxScaler
import joblib
from scipy.stats import norm
import warnings
import os
warnings.filterwarnings("ignore")

import argparse


def _parse_cli():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        ``(parser, namespace)``. The namespace carries three string options:
        ``location``, ``time_now`` (a ``"MM-DD"`` day) and ``model_year``.
    """
    cli = argparse.ArgumentParser(description="get prediction")
    options = (
        ("--location", "OR"),
        ("--time_now", '06-18'),
        ("--model_year", '2005'),
    )
    for flag, fallback in options:
        cli.add_argument(flag, default=fallback)
    return cli, cli.parse_args()


parser, args = _parse_cli()

# Every input location is derived from the three CLI values.
_site, _year = args.location, args.model_year
model_path = f'./saved models/{_site}load_{_year}.pt'
data_path = f'./testing/{_site}_2023-{args.time_now}/'
traindata_path = f'./training/{_year}_seed_60/{_site}_0.csv'

class LSTM(nn.Module):
    """LSTM over a history window plus a linear term from horizon covariates.

    ``forward(x, z)`` reshapes ``x`` to ``(batch, 336, -1)`` (the last axis
    must be 6 to match ``input_size``) and ``z`` to ``(batch, 24, -1)`` (the
    last axis must be 10 to match ``ann2``). It returns a tensor of shape
    ``(batch, 24, 1, 11)`` whose last axis is non-decreasing: a free first
    value followed by cumulative softplus increments. The script labels these
    11 outputs p00..p100, so they are presumably quantiles.
    """

    def __init__(self):
        super(LSTM, self).__init__()

        self.lstm = nn.LSTM(
            input_size=6,
            hidden_size=200,
            num_layers=1,
            batch_first=True,
        )

        # Maps each horizon step's 10 covariates to a single scalar.
        self.ann2 = nn.Sequential(
            nn.Linear(10, 1),
        )
        # Maps each 200-d hidden vector to 11 raw outputs (base + 10 increments).
        self.out = nn.Sequential(
            nn.Linear(200, 11),
        )
        # Final (h_n, c_n) of the most recent forward pass; it is stored for
        # inspection only and is not fed back into later calls.
        self.state = None

    def forward(self, x, z):
        """Return non-decreasing outputs of shape ``(batch, 24, 1, 11)``."""
        batch = x.shape[0]
        tx = torch.reshape(x, (batch, 336, -1))
        # Start from nn.LSTM's default zero initial (h, c) on every call.
        # The original passed an undefined name `state`, which only worked if
        # the caller happened to define a global of that name.
        tx, self.state = self.lstm(tx)
        tz = torch.reshape(z, (batch, 24, -1))
        # (batch, 24, 1): one scalar per horizon step, broadcast across all
        # 200 hidden units of the last 24 LSTM outputs.
        temp = self.ann2(tz)
        out = self.out(tx[:, -24:, :] + temp)

        # Enforce ordering along the 11 outputs: keep the first value as-is,
        # make the remaining 10 positive via softplus, then cumulative-sum.
        out = torch.unsqueeze(out, 2)
        out1, out2 = out[:, :, :, 0], out[:, :, :, 1:]
        out1 = torch.unsqueeze(out1, -1)
        out2 = F.softplus(out2)
        out = torch.cat((out1, out2), dim=-1)
        out = torch.cumsum(out, dim=-1)

        return out

def load_model(model_path):
    """Create an ``LSTM``, load its weights from ``model_path`` and return it.

    Weights are mapped onto the CPU. The returned module is converted to
    float64 and switched to evaluation mode.

    NOTE(review): ``torch.load`` unpickles the file; only load checkpoints
    from a trusted source.
    """
    weights = torch.load(model_path, map_location=torch.device('cpu'))
    net = LSTM()
    net.load_state_dict(weights)
    # Module.double() and Module.eval() both return the module itself.
    return net.double().eval()

def read_data_cont(df):
    """Return the six continuous feature columns of ``df`` as a 2-D array."""
    columns = ['consumption', 'solar', 'DNI', 'DHI', 'temperature', 'relativehumidity']
    return df[columns].to_numpy()
def read_data_target(df):
    """Return the ``total_grid`` column of ``df`` as an ``(n_rows, 1)`` array."""
    return df.loc[:, ['total_grid']].to_numpy()
def read_data_time(df):
    """Return ``dayofweek``, ``timeofday`` and ``month`` (in that order) as a 2-D array."""
    time_columns = ['dayofweek', 'timeofday', 'month']
    selected = df[time_columns]
    return selected.to_numpy()





def _require_path(path, description):
    """Stop the script with exit status 1 if ``path`` does not exist.

    First prints ``"Error: <description> <path> does not exist."``, so the
    failure is explained up front. Without this, it would surface later as a
    pandas/torch traceback.
    """
    if not os.path.exists(path):
        print(f"Error: {description} {path} does not exist.")
        raise SystemExit(1)


# Test inputs for the requested day: history and forecast-horizon CSVs.
_require_path(data_path, "Directory")
df_history = pd.read_csv(os.path.join(data_path, f"{args.location}_history.csv"))
df_future = pd.read_csv(os.path.join(data_path, f"{args.location}_future.csv"))
# Add an all-zero 'total_grid' column as the first column of the future frame.
# NOTE(review): the feature code (read_test) does not read this column;
# presumably it is kept so both frames share a schema.
new_col = np.zeros(df_future.shape[0])
df_future.insert(0, 'total_grid', new_col)

# Fit the MinMax scaler on the training CSV, restricted to all of 2022 plus
# 2023 up to day-of-year 134.
# NOTE(review): this window presumably mirrors the one used in training;
# confirm before changing it.
_require_path(traindata_path, "File")
rawdata = pd.read_csv(traindata_path)
df_raw = rawdata[(rawdata["year"] == 2022) | ((rawdata["year"] == 2023) & (rawdata["dayofyear"] <= 134))]
rawdata_cont = read_data_cont(df_raw)
scaler = MinMaxScaler(feature_range=(0, 1))
scaler.fit(rawdata_cont)

print("Data loaded successfully.")

_require_path(model_path, "Model file")
model = load_model(model_path)
print("Saved model loaded successfully.")


def _cyclical_encoding(values, period):
    """Return a float32 ``(2, n)`` array: sin and cos of ``values`` on a cycle of ``period``."""
    angle = values / period * 2 * np.pi
    # The float32 cast matches the values the previous torch-based version produced.
    return np.stack((np.sin(angle), np.cos(angle))).astype(np.float32)


def read_test(df):
    """Build the feature matrix for ``df``.

    The result has shape ``(n_rows, 12)``:

    - columns 0-5: the continuous columns from ``read_data_cont``, transformed
      by the module-level fitted ``scaler``;
    - columns 6-11: sin/cos pairs for ``dayofweek`` (period 7), ``timeofday``
      (period 24) and ``month - 1`` (period 12). The ``- 1`` presumably maps
      1-based months to 0-11.
    """
    test_cont = scaler.transform(read_data_cont(df))

    raw_time = read_data_time(df)
    time_features = np.concatenate(
        (
            _cyclical_encoding(raw_time[:, 0], 7),
            _cyclical_encoding(raw_time[:, 1], 24),
            _cyclical_encoding(raw_time[:, 2] - 1, 12),
        ),
        0,
    ).T

    return np.concatenate((test_cont, time_features), 1)

# Model inputs.
# - History keeps its 6 scaled continuous columns.
# - Future drops the first two scaled continuous columns ('consumption',
#   'solar') and keeps the other 4 plus the 6 sin/cos time features: 10 columns.
history_context = read_test(df_history)[:, :6]
future_context = read_test(df_future)[:, 2:]
# Add a leading batch axis of size 1; use float64 to match the model's weights.
test_X = torch.tensor(history_context, dtype=torch.double).unsqueeze(0)
test_Z = torch.tensor(future_context, dtype=torch.double).unsqueeze(0)

with torch.no_grad():
    # Keep a module-level `state` defined. LSTM.forward as originally written
    # passes a global of this name to nn.LSTM (None = zero initial state).
    state = None
    preds = model(test_X, test_Z)
    preds = torch.squeeze(preds)

# Convert explicitly to NumPy, so the frame holds plain floats regardless of
# how the installed pandas version handles tensor input.
df_pred = pd.DataFrame(
    preds.numpy(),
    columns=["p00", "p10", "p20", "p30", "p40", "p50", "p60", "p70", "p80", "p90", "p100"],
)
df_pred.index = df_future['localTime']

# The output file is named after the day following --time_now.
# Parse with an explicit year (2023, the same year as data_path): Python 3.13
# deprecates strptime of a day of month without a year.
current_date = datetime.datetime.strptime(f"2023-{args.time_now}", "%Y-%m-%d")
next_date = current_date + timedelta(days=1)
formatted_next_date = next_date.strftime("%m-%d")
filename = f"preds_{args.location}_data{args.model_year}_{formatted_next_date}.csv"
export_path = './preds_results'
# Create the output directory if needed, instead of failing inside to_csv.
os.makedirs(export_path, exist_ok=True)
file_path = os.path.join(export_path, filename)
df_pred.to_csv(file_path, index=True)
print("results saved successfully.")

