"""The code to launch ablations for our method."""


import argparse
import copy
import os
import os.path as osp

import numpy as np
import torch
from torch.utils.data import TensorDataset

from experiments.hyperparameters_dict import hyperparameters_dict
from experiments.metrics_dict import metrics_dict
from models.models_dict import models_dict


parser = argparse.ArgumentParser(description="Run ablations of our method.")
parser.add_argument(
  "--ablation",
  type=str,
  choices=["ib", "train_sample"],
  default="ib",
  help='Which ablation to run: "ib" sets ib_beta to 0.0; '
       '"train_sample" sets num_samples_train to 1.',
)
parser.add_argument(
  "--dataset",
  type=str,
  default="cube",
  help="Dataset name; files are read from datasets/data/<dataset>/.",
)
parser.add_argument(
  "--run_init",
  type=int,
  default=1,
  help="First repeat to run (1-based). Set above 1 to resume a paused sweep.",
)
parser.add_argument(
  "--num_repeats",
  type=int,
  default=5,
  help="Index of the last repeat; repeats run_init..num_repeats are run.",
)
parser.add_argument(
  "--device",
  type=str,
  default="0",
  help='"cpu", or a CUDA device index (falls back to CPU when CUDA is unavailable).',
)
args = parser.parse_args()

# Repeats are numbered from 1 up to num_repeats inclusive. Reject a start index
# outside that range: otherwise the loop either silently runs nothing
# (run_init > num_repeats) or runs an unexpected extra "repeat 0".
if not 1 <= args.run_init <= args.num_repeats:
  parser.error(
    f"--run_init must be between 1 and --num_repeats ({args.num_repeats}), "
    f"got {args.run_init}"
  )


# Choose where the model runs. "cpu" forces the CPU; any other value is used as a
# CUDA device index, falling back to the CPU when CUDA is not available.
if args.device != "cpu" and torch.cuda.is_available():
  device = torch.device(f"cuda:{args.device}")
else:
  device = torch.device("cpu")


# Load the saved dataset from datasets/data/<dataset>/. dataset_dict carries
# dataset-level settings (its "metric" entry selects the metric, and all of its
# entries are merged into the model config). Features come from the "_cdf"
# variants of the X files (presumably CDF-transformed features — confirm).
data_path = osp.join("datasets", "data", args.dataset)
dataset_dict = torch.load(osp.join(data_path, "dataset_dict.pt"))


def _load_tensors(*file_names):
  """Load each named .pt file from data_path, in the order given."""
  return [torch.load(osp.join(data_path, name)) for name in file_names]


X_train, y_train, M_train = _load_tensors("X_train_cdf.pt", "y_train.pt", "M_train.pt")
X_val, y_val, M_val = _load_tensors("X_val_cdf.pt", "y_val.pt", "M_val.pt")

# Each dataset item is an (X, y, M) triple.
train_data = TensorDataset(X_train, y_train, M_train)
val_data = TensorDataset(X_val, y_val, M_val)


# All outputs of this ablation (the config and one sub-directory per repeat) live
# under experiments/saved_models/<dataset>/ours_<ablation>/; create it if needed.
path = osp.join("experiments", "saved_models", args.dataset, "ours_" + args.ablation)
os.makedirs(path, exist_ok=True)


# Resolve the predictive metric named by the dataset; it is passed to model.fit.
metric_name = dataset_dict["metric"]
metric_f = metrics_dict[metric_name]


# Build the model config. Start from the "ours" hyperparameters for this dataset,
# then overlay every entry of dataset_dict (dataset values win on shared keys).
# Work on a shallow copy so the shared hyperparameters_dict entry is not mutated
# in place; only top-level keys are assigned below.
config = copy.copy(hyperparameters_dict[args.dataset]["ours"])
config.update(dataset_dict)

# Apply the requested ablation. "ib" zeroes ib_beta (presumably the
# information-bottleneck weight — confirm in the model); "train_sample" sets
# num_samples_train to 1.
if args.ablation == "ib":
  config["ib_beta"] = 0.0
elif args.ablation == "train_sample":
  config["num_samples_train"] = 1

# Persist the exact config shared by every repeat alongside the repeat folders.
torch.save(config, osp.join(path, "config.pt"))



# Train one fresh model per repeat. Starting at --run_init lets an interrupted
# sweep resume without redoing completed repeats.
for repeat in range(args.run_init, args.num_repeats + 1):
  # Each repeat writes into its own sub-directory of the ablation path.
  rpt_path = osp.join(path, f"repeat_{repeat}")
  os.makedirs(rpt_path, exist_ok=True)

  print(f"\n\nRepeat {repeat} out of {args.num_repeats}")

  # The seed depends only on the repeat index, so any single repeat can be
  # reproduced on its own regardless of where the sweep started.
  seed = 383 + 8406 * repeat
  np.random.seed(seed)
  torch.manual_seed(seed)

  # Build the model on the chosen device and fit it.
  model_cls = models_dict["ours"]
  model = model_cls(config).to(device)
  model.fit(train_data, val_data, rpt_path, metric_f)