import torch
from torch import nn
import torch.nn.functional as F
import argparse
import os.path

from  Models2.DA_RNN import config, DualAttentionRNN, TimeSeriesIterableDatasetFinetuning
from metrics import rmse

# Train on the GPU when one is visible to PyTorch; otherwise fall back to the
# CPU so the script can still run on hosts without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Command-line interface of the fine-tuning script.
#
# --lag, --horizon and --index have no usable default, so they are marked as
# required: when --lag is omitted args.lag is None and the later `lag - 1`
# raises a TypeError, and a missing --index/--horizon would be formatted into
# the data and output paths as the literal text "None".
parser = argparse.ArgumentParser(
    prog='DARNN_finetune',
    description='Finetune DARNN',
)
parser.add_argument(
    "--lag", help="The context size", type=int, required=True,
)
parser.add_argument(
    "--horizon", help="The horizon size", type=int, required=True,
)
parser.add_argument(
    "--index", help="The name of the index used for finetuning",
    type=str, required=True,
)
parser.add_argument(
    "--n_epochs", help="The maximum number of epochs",
    type=int, default=1000,
)

args = parser.parse_args()

# Short module-level aliases used by the rest of the script.
lag = args.lag
h = args.horizon
index = args.index
n_epochs = args.n_epochs

# Model tag used when composing run and weight paths.
name = "DARNN"

# Scalar settings handed to the optimizer and to model.finetuning further
# down (patience / min_delta are presumably early-stopping controls —
# confirm against DualAttentionRNN.finetuning).
learning_rate = 1e-4
patience = 100
min_delta = 10**-8

# Input data directory and the run directory (later passed as log_dir), both
# derived from the command-line arguments.
dir_data = f'Data/{index}_{lag}/horizon{h}'
dir_save = f'runs_nofinetuning/{name}_{index}_lag{lag}_h{h}'

# Overrides applied on top of the imported default model configuration.
for _key, _value in {
    "n_features": 4,
    "lag": lag - 1,
    "lstm_units_encoder": 128,
    "lstm_units_decoder": 128,
    "path": f"Weights/{dir_save}",
}.items():
    config[_key] = _value

# Train / validation splits read from the same data directory.
train_data = TimeSeriesIterableDatasetFinetuning(dir_data, phase="train")
val_data = TimeSeriesIterableDatasetFinetuning(dir_data, phase="val")

# Both splits are served in mini-batches of 64 samples.
train_dataset = torch.utils.data.DataLoader(dataset=train_data, batch_size=64)
val_dataset = torch.utils.data.DataLoader(dataset=val_data, batch_size=64)

# Network built from the (overridden) configuration, trained with plain SGD
# on a mean-squared-error objective.
model = DualAttentionRNN(config)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)
                            
# Path later handed to model.finetuning as load_path; a lag of 360 selects the
# variant with the "_lag360" suffix.
load_path = "Weights/DARNN_128"
if lag == 360:
    load_path += "_lag360"

# Path later handed to model.finetuning as save_path. If that file already
# exists the script stops with a success status (presumably so a finished run
# is not repeated — confirm).
output_file = f"Weights_nofinetuning/{name}_{index}_lag{lag}_h{h}"
if os.path.isfile(output_file):
    # SystemExit is raised directly: the bare exit() helper is injected by the
    # site module and is not guaranteed to exist (e.g. under `python -S`).
    raise SystemExit(0)
                            
# Launch fine-tuning with everything prepared above.
model.finetuning(
    # Data streams.
    train_data=train_dataset,
    val_data=val_dataset,
    # Optimisation set-up.
    optimizer=optimizer,
    loss_fn=loss_fn,
    metric=rmse,
    device=device,
    epochs=n_epochs,
    # presumably the index of the feature to predict — confirm against
    # DualAttentionRNN.finetuning.
    target=3,
    # presumably early-stopping controls — confirm.
    patience=patience,
    min_delta=min_delta,
    # Logging directory, weights to load, and where to save the result.
    log_dir=dir_save,
    load_path=load_path,
    save_path=output_file,
)
