"""Launch several seeded training runs (repeats) of one model on one dataset."""


import os
import os.path as osp
import argparse

import numpy as np
import torch

from torch.utils.data import TensorDataset

from experiments.hyperparameters_dict import hyperparameters_dict
from experiments.metrics_dict import metrics_dict
from models.models_dict import models_dict


# Command-line interface. Every option has a default, so the script can be
# launched with no arguments.
parser = argparse.ArgumentParser()
parser.add_argument(
  "--model",
  type=str,
  choices=list(models_dict.keys()),
  default="ours",
  help="Model to train (a key of models_dict).",
)
parser.add_argument(
  "--dataset",
  type=str,
  default="cube",
  help="Dataset folder name under datasets/data/.",
)
parser.add_argument(
  "--run_init",
  type=int,
  default=1,
  help="First repeat index to run; raise it to resume a paused sweep.",
)
parser.add_argument(
  "--num_repeats",
  type=int,
  default=5,
  help="Last repeat index to run (repeats run_init..num_repeats, inclusive).",
)
parser.add_argument(
  "--device",
  type=str,
  default="0",
  help='CUDA device index (e.g. "0"), or "cpu".',
)
args = parser.parse_args()


# Set the device. "--device cpu" forces the CPU; any other value is used as a
# CUDA device index (e.g. "0" -> cuda:0). If CUDA is unavailable we still fall
# back to the CPU, but announce it rather than silently running on the CPU.
if args.device == "cpu":
  device = torch.device("cpu")
elif torch.cuda.is_available():
  device = torch.device(f"cuda:{args.device}")
else:
  print(f"CUDA is not available; ignoring --device {args.device} and using the CPU.")
  device = torch.device("cpu")


# Load the data. The "ours" model reads the "_cdf" variant of the X feature
# files; every other model reads the "_std" variant (presumably CDF- vs.
# standard-scaled features -- confirm against the dataset preparation code).
# The y and M tensors are the same for all models.
data_path = osp.join("datasets", "data", args.dataset)
# Dataset metadata; it supplies at least the "metric" key and is merged into
# the model config further down.
dataset_dict = torch.load(osp.join(data_path, "dataset_dict.pt"))

dataset_x_ending = "_cdf" if args.model == "ours" else "_std"

X_train = torch.load(osp.join(data_path, f"X_train{dataset_x_ending}.pt"))
y_train = torch.load(osp.join(data_path, "y_train.pt"))
M_train = torch.load(osp.join(data_path, "M_train.pt"))

X_val = torch.load(osp.join(data_path, f"X_val{dataset_x_ending}.pt"))
y_val = torch.load(osp.join(data_path, "y_val.pt"))
M_val = torch.load(osp.join(data_path, "M_val.pt"))

# Each dataset item is an (X, y, M) triple indexed along the first dimension.
train_data = TensorDataset(X_train, y_train, M_train)
val_data = TensorDataset(X_val, y_val, M_val)


# All outputs for this (dataset, model) pair go under
# experiments/saved_models/<dataset>/<model>; reuse the directory if present.
path = osp.join(osp.join("experiments", "saved_models"), args.dataset, args.model)
os.makedirs(name=path, exist_ok=True)


# Resolve the predictive-metric function named by the dataset metadata.
metric_name = dataset_dict["metric"]
metric_f = metrics_dict[metric_name]


# Construct the run configuration from the hyperparameters for this
# (dataset, model) pair plus the dataset metadata; dataset metadata wins on
# key collisions. Merge into a fresh dict so the entry stored in the shared
# hyperparameters_dict registry is not mutated in place.
config = {**hyperparameters_dict[args.dataset][args.model], **dataset_dict}
torch.save(config, osp.join(path, "config.pt"))

# Train one freshly constructed model per repeat, run_init..num_repeats
# inclusive. Each repeat's seed depends only on its index, so re-running a
# repeat (e.g. after resuming with --run_init) reuses the same seed.
for repeat in range(args.run_init, args.num_repeats + 1):
  rpt_path = osp.join(path, f"repeat_{repeat}")
  os.makedirs(rpt_path, exist_ok=True)

  print(f"\n\nRepeat {repeat} out of {args.num_repeats}")

  # Seed the torch and NumPy global RNGs from the repeat index.
  seed = 383 + 8406 * repeat
  torch.manual_seed(seed)
  np.random.seed(seed)

  # Build the model, move it to the chosen device, and fit it; fit receives
  # the repeat's output directory and the metric function.
  model_cls = models_dict[args.model]
  model = model_cls(config).to(device)
  model.fit(train_data, val_data, rpt_path, metric_f)