import torch
from torch import nn
import torch.nn.functional as F
import argparse
import os.path

from Models2.AR_t import config, AR_Transformer, TimeSeriesIterableDatasetFinetuning
from metrics import rmse

# Training device; the script assumes a CUDA GPU is available.
device = torch.device('cuda')


parser = argparse.ArgumentParser(
                    prog='AR312_finetune',
                    description='Finetune AR312')
# --lag / --horizon / --index are used to build data and weight paths and the
# model config; there is no sensible default, so fail fast at parse time
# instead of crashing later on e.g. `None - 1`.
parser.add_argument("--lag", help="The context size",
                    type=int, required=True)
parser.add_argument("--horizon", help="The horizon size",
                    type=int, required=True)
parser.add_argument("--index", help="The name of the index used for finetuning",
                    type=str, required=True)
parser.add_argument("--n_layers_freeze", help="The number of layers to freeze",
                    type=int, default=0)
parser.add_argument("--n_epochs", help="The maximum number of epochs",
                    type=int, default=1000)

args = parser.parse_args()

lag = args.lag
h = args.horizon
index = args.index
n_layers_freeze = args.n_layers_freeze
n_epochs = args.n_epochs

name = "AR312"

# Early-stopping settings forwarded to model.finetuning.
patience = 100
min_delta = 10**-8

dir_data = f'Data/{index}_{lag}/horizon{h}'

# NOTE(review): this script is named "finetune", yet logs and weights go to
# *_nofinetuning directories — confirm this is the intended destination.
dir_save = f'runs_nofinetuning/{name}_{index}_lag{lag}_h{h}'
output_file = f"Weights_nofinetuning/{name}_{index}_lag{lag}_h{h}"

# Skip runs whose output already exists *before* building datasets and the
# model, so re-launching a sweep does not pay that setup cost for finished runs.
# `raise SystemExit` avoids relying on the site-provided `exit()` builtin.
if os.path.isfile(output_file):
    raise SystemExit(0)

config['n_features'] = 5
# NOTE(review): the model is configured with one step less than the CLI lag;
# presumably one step of each window is consumed by the dataset — TODO confirm.
config['lag'] = lag - 1
config["latent_dim"] = 312
config["num_layers"] = 12
config['n_head'] = 24
config["dropout"] = 0.0
config["path"] = "Weights/" + dir_save

learning_rate = 1e-4

train_data = TimeSeriesIterableDatasetFinetuning(dir_data, phase='train')
train_dataset = torch.utils.data.DataLoader(train_data, batch_size=8)

val_data = TimeSeriesIterableDatasetFinetuning(dir_data, phase='val')
val_dataset = torch.utils.data.DataLoader(val_data, batch_size=64)

model = AR_Transformer(config)

if n_layers_freeze > 0:
    # Freeze the input side of the network plus the first `n_layers_freeze`
    # encoder layers; slicing clamps to the number of available layers.
    frozen_modules = [
        model.embedding,
        model.positional_encoder,
        model.encoder_layer,
        *model.transformer_encoder.layers[:n_layers_freeze],
    ]
    for module in frozen_modules:
        for param in module.parameters():
            param.requires_grad = False

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Pretrained checkpoint to start from; its name is derived from the same config
# values as the model so the two cannot drift apart.
load_path = (
    f"Weights/dim{config['latent_dim']}"
    f"_layers{config['num_layers']}_head{config['n_head']}"
)
if lag == 360:
    load_path += "_lag360"

model.finetuning(
    train_data=train_dataset,
    epochs=n_epochs,
    optimizer=optimizer,
    loss_fn=loss_fn,
    metric=rmse,
    device=device,
    # NOTE(review): presumably the index of the feature column to predict
    # (out of n_features=5) — confirm in AR_Transformer.finetuning.
    target=3,
    patience=patience,
    min_delta=min_delta,
    val_data=val_dataset,
    log_dir=dir_save,
    load_path=load_path,
    save_path=output_file,
)
