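"""Run a hyperparameter search for one model, dataset and config set.

Every candidate configuration is trained over several seeded repeats; the mean
validation scores are printed and saved to a text file.
"""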


import os
import os.path as osp
import argparse
import copy

import numpy as np
import torch

from torch.utils.data import TensorDataset

from experiments.metrics_dict import metrics_dict
from experiments.tuning.trial_hyperparams import trial_hyperparams
from models.models_dict import models_dict


# Command-line options: the model to tune, the dataset, which config set of
# trial hyperparameters to try, how many seeded repeats per configuration,
# and the device ("cpu", or a CUDA device index).
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, choices=list(models_dict), default="ours")
parser.add_argument("--dataset", type=str, default="cube")
parser.add_argument("--configs", type=str, choices=["123", "456", "789"], default="123")
parser.add_argument("--num_repeats", type=int, default=3)
parser.add_argument("--device", type=str, default="0")
args = parser.parse_args()


# Scratch directory for the checkpoints written while training; it must exist
# before training starts, so create it (no error if it is already there).
run_name = f"{args.model}{args.configs}"
path = osp.join("experiments", "tuning", "tmp", args.dataset, run_name)
os.makedirs(path, exist_ok=True)


# Pick the device: "cpu" forces the CPU; any other value is used as the CUDA
# device index, falling back to the CPU when CUDA is not available.
use_cuda = args.device != "cpu" and torch.cuda.is_available()
device = torch.device(f"cuda:{args.device}" if use_cuda else "cpu")


# Load the data. The "ours" model reads the "_cdf" X files and every other
# model reads the "_std" ones; the y and M files are shared by all models.
data_path = osp.join("datasets", "data", args.dataset)
dataset_dict = torch.load(osp.join(data_path, "dataset_dict.pt"))

if args.model == "ours":
  dataset_x_ending = "_cdf"
else:
  dataset_x_ending = "_std"


def _load_split(split):
  """Load and return the (X, y, M) tensors saved for the given split name."""
  X = torch.load(osp.join(data_path, f"X_{split}{dataset_x_ending}.pt"))
  y = torch.load(osp.join(data_path, f"y_{split}.pt"))
  M = torch.load(osp.join(data_path, f"M_{split}.pt"))
  return X, y, M


X_train, y_train, M_train = _load_split("train")
X_val, y_val, M_val = _load_split("val")

train_data = TensorDataset(X_train, y_train, M_train)
val_data = TensorDataset(X_val, y_val, M_val)


# Look up the evaluation function named by the dataset's "metric" entry; it is
# handed to each model's fit().
metric_name = dataset_dict["metric"]
metric_f = metrics_dict[metric_name]


# Train one hyperparameter configuration several times and collect its scores.
def get_aucs(config):
  """Train the selected model ``args.num_repeats`` times with one configuration.

  Args:
    config: Hyperparameter dict for one candidate configuration. It is left
      unmodified; the model receives a merged copy in which the entries of
      ``dataset_dict`` are added (and win on key clashes).

  Returns:
    np.ndarray with one entry per repeat, each loaded from ``val_auc.pt`` in
    the scratch directory ``path`` (presumably written there by ``fit``).
  """
  # Merge into a new dict rather than writing into the caller's config, so the
  # candidate hyperparameters stay free of dataset entries.
  model_config = {**config, **dataset_dict}

  aucs = []
  for repeat in range(1, args.num_repeats + 1):
    print(f"\n\nRepeat {repeat} out of {args.num_repeats}")
    # Clear files left by the previous repeat / configuration so the val_auc.pt
    # read below comes from this run.
    for file in os.listdir(path):
      os.remove(osp.join(path, file))

    # Fixed, repeat-specific seed so every configuration sees the same seeds.
    seed = 1690 * repeat + 241
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Setup and train the model.
    model = models_dict[args.model](model_config).to(device)
    model.fit(train_data, val_data, path, metric_f)
    aucs.append(torch.load(osp.join(path, "val_auc.pt")))

  return np.array(aucs)


# Random search: score every candidate configuration of the chosen config set.
# A deep copy is kept first so the hyperparameters reported later exclude any
# dataset entries that get_aucs may add to the configs it is given.
hyperparams_dict = trial_hyperparams[args.model][args.configs]
hyperparams_dict_no_dataset = copy.deepcopy(hyperparams_dict)
aucs_dict = {key: get_aucs(config) for key, config in hyperparams_dict.items()}


# Print the per-configuration results and report the best configuration.
print("\n\n\nTuning results:\n")
# Start below every real score so the first configuration with a usable mean
# is always selected (a 0 start left best_key undefined when no mean was > 0).
best_key = None
best_mean = -np.inf
for key, aucs in aucs_dict.items():
  mean = np.mean(aucs)
  # Standard error of the mean across repeats (np.std defaults to ddof=0).
  std_err = np.std(aucs) / np.sqrt(len(aucs))
  print(f"{key}: {mean:.3f} +- {std_err:.3e}")
  # Higher is treated as better; NaN means never compare greater, so they are
  # never selected. Ties keep the earliest configuration.
  if mean > best_mean:
    best_mean = mean
    best_std_err = std_err
    best_key = key

if best_key is None:
  print("\n\nBest: none (no configuration produced a valid score)\n")
else:
  print(f"\n\nBest: {best_key}, {best_mean:.3f} +- {best_std_err:.3e}\n")
  print(hyperparams_dict_no_dataset[best_key])


# Save the results in a text file to be used later: first one summary line per
# configuration, then each configuration's hyperparameters as a dict literal.
path = osp.join("experiments", "tuning", "results", args.dataset)
os.makedirs(path, exist_ok=True)

summary_lines = []
for key, aucs in aucs_dict.items():
  mean, std_err = np.mean(aucs), np.std(aucs) / np.sqrt(len(aucs))
  summary_lines.append(f"{key} AUC: {mean:.3f} +- {std_err:.3e}\n")

config_blocks = []
for key in aucs_dict:
  entries = "".join(
    f"  \"{k}\": {v},\n" for k, v in hyperparams_dict_no_dataset[key].items()
  )
  config_blocks.append(f"\n\n{key} = " + "{\n" + entries + "}")

# The with-statement closes the file; no explicit close() is needed.
with open(osp.join(path, f"{args.model}{args.configs}.txt"), "w") as f:
  f.write("".join(summary_lines) + "".join(config_blocks))


# Remove the scratch checkpoint directory. os.rmdir only deletes empty
# directories, so the files inside are removed first.
path = osp.join("experiments", "tuning", "tmp", args.dataset, f"{args.model}{args.configs}")
leftover_files = [osp.join(path, name) for name in os.listdir(path)]
for leftover in leftover_files:
  os.remove(leftover)
os.rmdir(path)