import torch
from torch import nn
import torch.nn.functional as F
import argparse
import os.path

from Models2.Scinet import config, SCINet, TimeSeriesIterableDatasetFinetuning
from metrics import rmse

# Prefer the GPU, but fall back to the CPU so the script still runs on hosts
# without CUDA (on CUDA hosts this is identical to torch.device('cuda')).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Command-line interface.
# --lag, --horizon and --index have no usable default: they are interpolated
# into the data/log paths and --lag is used arithmetically (lag - 2) further
# down. They are therefore required, so that a missing option yields a clear
# argparse usage error instead of "None" in a path or a TypeError later on.
parser = argparse.ArgumentParser(
    prog='Scinet_finetune',
    description='Scinet AR1248',
)
parser.add_argument("--lag", help="The context size",
                    type=int, required=True)
parser.add_argument("--horizon", help="The horizon size",
                    type=int, required=True)
parser.add_argument("--index", help="The name of the index used for finetuning",
                    type=str, required=True)
parser.add_argument("--n_epochs", help="The maximum number of epochs",
                    type=int, default=1000)

args = parser.parse_args()

# Short aliases used throughout the rest of the script.
lag = args.lag
h = args.horizon
index = args.index
n_epochs = args.n_epochs

# Label embedded in the log directory name for this run.
name = "scinet"

# Both values are forwarded unchanged to model.finetuning further down
# (presumably early-stopping controls -- confirm against SCINet.finetuning).
patience = 100
min_delta = 10 ** -8

# Input data location, keyed by index name, lag and horizon.
dir_data = "Data/{}_{}/horizon{}".format(index, lag, h)

# Directory passed to model.finetuning as log_dir.
dir_save = "runs_nofinetuning/{}_{}_lag{}_h{}".format(name, index, lag, h)

# Hyper-parameters handed to SCINet(config).
# NOTE(review): this rebinds the name `config` imported from Models2.Scinet.
# NOTE(review): output_len is fixed at 1 although a --horizon option exists;
# confirm this is intentional.
config = dict(
    output_len=1,                # horizon
    input_len=lag - 2,           # lag
    input_dim=5,                 # input dimension
    hid_size=128,
    num_stacks=1,                # number of SCINet stacks
    num_levels=1,
    num_decoder_layer=1,
    concat_len=0,
    groups=1,
    kernel=5,                    # convolution kernel
    dropout=0.0,
    single_step_output_One=0,
    input_len_seg=0,
    positionalE=True,            # positional encoding
    modified=False,
    RIN=False,
    path="Weights/Scinet_128",
)

# Step size used by the SGD optimizer below.
learning_rate = 1e-4


def _make_loader(phase):
    """Return (dataset, DataLoader) for *phase*, read from dir_data in batches of 64."""
    data = TimeSeriesIterableDatasetFinetuning(dir_data, phase=phase)
    return data, torch.utils.data.DataLoader(data, batch_size=64)


# Note: the *_dataset names hold DataLoader objects, the *_data names hold
# the underlying datasets.
train_data, train_dataset = _make_loader('train')
val_data, val_dataset = _make_loader('val')

# Model built from the config dict defined above.
model = SCINet(config)

# Mean-squared-error objective, optimized with plain SGD.
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Path passed to model.finetuning as load_path (presumably the pretrained
# weights -- confirm). A separate file is used for the 360-step lag.
load_path = "Weights/Scinet_128"
if lag == 360:
    load_path += "_lag360"

# Skip the run entirely when its output file already exists.
# NOTE(review): this assumes model.finetuning writes a regular file at exactly
# this path (no added extension) -- confirm against SCINet.finetuning.
output_file = f"Weights_nofinetuning/{name}_{index}_lag{lag}_h{h}"
if os.path.isfile(output_file):
    print(f"{output_file} already exists; skipping.")
    # SystemExit is raised directly: the bare exit() helper is injected by the
    # `site` module for interactive use and is not guaranteed to exist.
    raise SystemExit(0)

model.finetuning(
    train_data=train_dataset,
    epochs=n_epochs,
    optimizer=optimizer,
    loss_fn=loss_fn,
    metric=rmse,
    device=device,
    target=3,  # NOTE(review): presumably the index of the predicted feature -- confirm
    patience=patience,
    min_delta=min_delta,
    val_data=val_dataset,
    log_dir=dir_save,
    load_path=load_path,
    save_path=output_file,
)
