import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import os
import pandas as pd
import numpy as np
import argparse

from Models.AR_t import config, AR_Transformer
from metrics import rmse, mape

# Device handed to every model's predict_index call further down.
# NOTE(review): hard-coded to CUDA — presumably fails on a CPU-only host once
# a model actually uses it; confirm whether a CPU fallback is wanted.
device = torch.device('cuda')

parser = argparse.ArgumentParser(
                    prog='AR1248_finetune',
                    description='Finetune AR1248')
# All three options feed the data and output paths built below, so they are
# mandatory: a missing one would otherwise become None and only fail later
# with an obscure path or shape error.
parser.add_argument("--lag", help="The context size",
                    type=int, required=True)
parser.add_argument("--horizon", help="The horizon size",
                    type=int, required=True)
parser.add_argument("--index", help="The name of the index used for finetuning",
                    type=str, required=True)

args = parser.parse_args()

lag = args.lag
horizon = args.horizon
index = args.index

print(f"Dataset : {index}, lag : {lag}, horizon : {horizon}")

# Output directory for this (index, lag, horizon) run. makedirs also creates
# the parent 'Predictions_nofinetuning' directory when it is missing, and
# exist_ok avoids the check-then-create race of exists() + mkdir().
os.makedirs(f'Predictions_nofinetuning/{index}_lag{lag}_horizon{horizon}',
            exist_ok=True)

# Directory holding the test data; each entry returned by listdir is later
# treated as a sub-directory of per-series CSV files.
data_path = f'Data/{index}_{lag}/test/horizon{horizon}'
days = os.listdir(data_path)

class TimeSeriesDataset(Dataset):
    """Dataset pairing each series with its name for one day.

    Args:
        data: Indexable collection of series (one item per series).
        names: Indexable collection of series names, aligned with ``data``.
        day: Identifier of the day this dataset belongs to; stored as-is.
    """

    def __init__(self, data, names, day):
        # The original ``super(TimeSeriesDataset).__init__()`` only built an
        # unbound super object and never ran the parent initializer.
        super().__init__()
        self.data = data
        self.names = names
        self.day = day

    def __len__(self) -> int:
        """Return the number of series stored for this day."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(series, name)`` pair at position ``idx``."""
        return self.data[idx], self.names[idx]

# Row of ones prepended to every series read from disk.
first_obs = pd.DataFrame({"open" : 1.0, "high" : 1.0, "low" : 1.0, "close" : 1.0, "volume" : 1.0, "abs_close" : 1.0}, index=[0])
# One TimeSeriesDataset per day directory, consumed by the model sections below.
data_days = []


# "Last" baseline: a constant vector of ones of length `lag`, scored against
# column 3 of each series (the "close" column, given first_obs' column order).
last = np.ones(lag)
# The baseline never changes, so its tensor is built once outside the loops.
last_tensor = torch.tensor(last)

print("Last")

last_mape = 0
last_rmse = 0
count = 0
for d in days:
    day_dir = os.path.join(data_path, d)
    data = []
    names = []
    for s in os.listdir(day_dir):
        try:
            df = pd.read_csv(os.path.join(day_dir, s))
        except pd.errors.EmptyDataError:
            # Empty CSV files carry no series; skip them.
            continue
        new_df = pd.concat((first_obs, df)).reset_index(drop=True)
        # Keep only series that are exactly `lag` rows long once the
        # synthetic first row has been added.
        if len(new_df) != lag:
            continue
        values = new_df.values
        data.append(values)
        names.append(s)
        target = torch.tensor(values[:, 3])
        last_mape += mape(last_tensor, target)
        last_rmse += rmse(last_tensor, target)
        count += 1
    data_days.append(TimeSeriesDataset(data, names, d))

# Without this guard an empty or mismatched data directory ends in a bare
# ZeroDivisionError on the averages below.
if count == 0:
    raise SystemExit(
        f"No series with {lag} rows (after prepending the first observation) "
        f"found under {data_path}"
    )

print("RMSE : ", round(last_rmse/count, 8))
print('MAPE : ', round(last_mape/count, 8))
print(" ")

############## AR transformers (AR312 / AR1248 / AR2048)
# The three evaluations differ only in their name and three hyperparameters,
# so they are driven from one table instead of three copy-pasted sections.
# Each entry is (name, latent_dim, num_layers, n_head).
ar_variants = [
    ('AR312', 312, 12, 24),
    ('AR1248', 1248, 24, 24),
]
if lag < 80:
    # The largest variant is only evaluated for context sizes below 80.
    ar_variants.append(('AR2048', 2048, 32, 32))

for name, latent_dim, num_layers, n_head in ar_variants:
    print(name)

    # `config` is the settings object imported from Models.AR_t; it is
    # updated in place before each model is built, exactly as the separate
    # per-model sections did.
    config['n_features'] = 5
    config['lag'] = lag - 1
    config["latent_dim"] = latent_dim
    config["num_layers"] = num_layers
    config['n_head'] = n_head
    config["dropout"] = 0.0
    config["path"] = f"Weights_nofinetuning/{name}_{index}_lag{lag}_h{horizon}"

    model = AR_Transformer(config)

    # Evaluate on every day's dataset; column 3 is passed as the target.
    model.predict_index(
        data_days = data_days,
        metric = rmse,
        target = 3,
        device = device,
        save_path = f'Predictions_nofinetuning/{index}_lag{lag}_horizon{horizon}/{name}'
    )
    print(" ")

########### SCINet
from Models.Scinet import config, SCINet

name = "scinet"
print(name)

# The imported `config` is replaced wholesale by the settings below.
config = dict(
    output_len=1,               # Horizon
    input_len=lag - 2,          # Lag
    input_dim=5,                # Input dimension
    hid_size=128,
    num_stacks=1,               # Number of SCINet stacks
    num_levels=1,
    num_decoder_layer=1,
    concat_len=0,
    groups=1,
    kernel=5,                   # Kernel of convolution
    dropout=0.0,                # Dropout
    single_step_output_One=0,
    input_len_seg=0,
    positionalE=True,           # Positional encoding
    modified=False,
    RIN=False,
    path=f"Weights_nofinetuning/{name}_{index}_lag{lag}_h{horizon}",
)

model = SCINet(config)

# Evaluate on every day's dataset; column 3 is passed as the target.
scinet_save_path = f'Predictions_nofinetuning/{index}_lag{lag}_horizon{horizon}/{name}'
model.predict_index(
    data_days=data_days,
    metric=rmse,
    target=3,
    device=device,
    save_path=scinet_save_path,
)
print(" ")

######### DARNN
from Models.DA_RNN import config, DualAttentionRNN

name = "DARNN"
print(name)

# Settings written into the imported DA_RNN `config`, in the same order as
# the individual assignments they replace.
darnn_overrides = {
    "n_features": 4,
    "lag": lag-1,
    "lstm_units_encoder": 128,
    "lstm_units_decoder": 128,
    "path": f"Weights_nofinetuning/{name}_{index}_lag{lag}_h{horizon}",
}
for option, value in darnn_overrides.items():
    config[option] = value

model = DualAttentionRNN(config)

# Evaluate on every day's dataset; column 3 is passed as the target.
darnn_save_path = f'Predictions_nofinetuning/{index}_lag{lag}_horizon{horizon}/{name}'
model.predict_index(
    data_days=data_days,
    metric=rmse,
    target=3,
    device=device,
    save_path=darnn_save_path,
)
print(" ")

