import os
import os.path as osp

import argparse

import numpy as np
import torch
import torch.nn.functional as F

from scipy import stats

from experiments.metrics_dict import metrics_dict
from models.models_dict import models_dict
from models.xgboost_predictor import XGBoostPredictor



parser = argparse.ArgumentParser()
# Dataset folder name, used under datasets/data and experiments/saved_models.
parser.add_argument("--dataset", default="cube", type=str)
# CUDA device index (used as f"cuda:{index}"), or "cpu" to force the CPU.
parser.add_argument("--device", default="0", type=str)
args = parser.parse_args()



# Set the device: the requested CUDA index when CUDA is available, otherwise
# (or when "cpu" is requested explicitly) the CPU.
device = torch.device(
  f"cuda:{args.device}"
  if args.device != "cpu" and torch.cuda.is_available()
  else "cpu"
)


# Get the best first feature, either by knowing it (for the Invase datasets) or
# by taking the most common best first feature across the fixed-order MLP
# repeats 1-5.
if args.dataset in ("invase_4", "invase_5", "invase_6"):
  best_first_feature = 10
else:
  # Get best first feature from the mlp fixed orders.
  best_features = []
  mlp_path = osp.join("experiments", "saved_models", args.dataset, "fixed_mlp")
  for repeat in range(1, 5+1):
    # Load onto the CPU: only the argmax is needed, and this keeps loading
    # working on machines without the GPU the scores may have been saved from.
    fixed_order_scores = torch.load(
      osp.join(mlp_path, f"repeat_{repeat}", "fixed_order_scores.pt"),
      map_location="cpu",
    )
    best_features.append(torch.argmax(fixed_order_scores).item())
  # scipy's mode returns the smallest value on ties; cast the numpy scalar to a
  # plain int so it can be used directly for indexing and tensor fills.
  best_first_feature = int(stats.mode(np.array(best_features), keepdims=False).mode)




def get_acquisition_array(dataset_name, ablation_name, first_acquisition="active"):
  """Run sequential feature acquisition with an ablated model on the test set.

  The ablated model lives in
  ``experiments/saved_models/<dataset_name>/ours_<ablation_name>``. For each of
  its repeats 1-5, features are acquired one at a time. Predictions and the
  dataset's metric are recorded before any acquisition and after every
  acquisition, for four predictors:
  - the ablated model itself ("own");
  - an XGBoost predictor, trained and saved if it cannot be loaded;
  - the fixed-order MLP (repeat 6);
  - the "ours" model (repeat 6).

  Args:
    dataset_name: Dataset folder name under ``datasets/data``.
    ablation_name: Suffix of the ablated model folder (``ours_<ablation_name>``).
    first_acquisition: How the first feature is chosen:
      - ``"active"``: by the ablated model;
      - ``"fixed"``: the module-level ``best_first_feature``;
      - ``"random"``: a random feature among those with non-zero ``M_test``.
      All later acquisitions are made actively by the ablated model.

  Returns:
    A dict with the following keys:
    - ``own_metric``, ``xgb_metric``, ``mlp_metric``, ``our_metric``: float
      tensors stacked over repeats, one entry per acquisition step.
    - ``own_preds``, ``xgb_preds``, ``mlp_preds``, ``our_preds``: float CPU
      tensors stacked over repeats and acquisition steps.
    - ``selected``: the acquired feature indices as a long tensor.

  Raises:
    ValueError: If ``first_acquisition`` is not a supported mode.
  """
  # Validate up front; otherwise an unknown mode would leave `selection`
  # unbound (or stale from a previous repeat) further down.
  if first_acquisition not in ("active", "fixed", "random"):
    raise ValueError(
      f"Unknown first_acquisition {first_acquisition!r}; "
      "expected 'active', 'fixed' or 'random'."
    )

  dataset_path = osp.join("datasets", "data", dataset_name)
  metric_f = metrics_dict[torch.load(osp.join(dataset_path, "dataset_dict.pt"))["metric"]]

  # Load in an XGBoost model, train one if it cannot be loaded.
  xgb_predictor = XGBoostPredictor()
  xgb_path = osp.join("experiments", "saved_models", dataset_name, "xgb_predictor")
  try:
    xgb_predictor.load(xgb_path)
  except Exception:
    # Any ordinary load failure (e.g. nothing saved yet) triggers training;
    # KeyboardInterrupt/SystemExit are deliberately not caught here.
    print("Training an XGBoost model from scratch.\n")
    os.makedirs(xgb_path, exist_ok=True)
    X_train = torch.load(osp.join(dataset_path, "X_train_std.pt"))
    y_train = torch.load(osp.join(dataset_path, "y_train.pt"))
    M_train = torch.load(osp.join(dataset_path, "M_train.pt"))
    train_set = (X_train, y_train, M_train)
    X_val = torch.load(osp.join(dataset_path, "X_val_std.pt"))
    y_val = torch.load(osp.join(dataset_path, "y_val.pt"))
    M_val = torch.load(osp.join(dataset_path, "M_val.pt"))
    val_set = (X_val, y_val, M_val)
    xgb_predictor.fit(train_set, val_set, xgb_path, num_subsamples=10)

  # Load in the fixed-order MLP model (repeat 6).
  mlp_path = osp.join("experiments", "saved_models", dataset_name, "fixed_mlp")
  mlp_predictor = models_dict["fixed_mlp"](torch.load(osp.join(mlp_path, "config.pt"))).to(device)
  mlp_predictor.load(osp.join(mlp_path, "repeat_6"))

  # Load in our model (repeat 6).
  our_path = osp.join("experiments", "saved_models", dataset_name, "ours")
  our_predictor = models_dict["ours"](torch.load(osp.join(our_path, "config.pt"))).to(device)
  our_predictor.load(osp.join(our_path, "repeat_6"))

  # Load in test data. XGBoost and the MLP consume the *_std features; the
  # ablated model and "ours" consume the *_cdf features.
  X_test_std = torch.load(osp.join(dataset_path, "X_test_std.pt")).to(device)
  X_test_cdf = torch.load(osp.join(dataset_path, "X_test_cdf.pt")).to(device)
  y_test = torch.load(osp.join(dataset_path, "y_test.pt")).to(device)
  M_test = torch.load(osp.join(dataset_path, "M_test.pt")).to(device)
  X_test_model = X_test_cdf  # already on `device`

  # Create the ablated model; its weights are loaded per repeat below.
  model_path = osp.join("experiments", "saved_models", dataset_name, f"ours_{ablation_name}")
  model = models_dict["ours"](torch.load(osp.join(model_path, "config.pt"))).to(device)

  own_metric = []
  xgb_metric = []
  mlp_metric = []
  our_metric = []

  own_preds = []
  xgb_preds = []
  mlp_preds = []
  our_preds = []

  selected = []

  for repeat in range(1, 5+1):
    print(f"Acquisition: {first_acquisition}, Ablation: {ablation_name}, Repeat: {repeat}/5")
    model.load(osp.join(model_path, f"repeat_{repeat}"))

    own_metric_tmp = []
    xgb_metric_tmp = []
    mlp_metric_tmp = []
    our_metric_tmp = []

    own_preds_tmp = []
    xgb_preds_tmp = []
    mlp_preds_tmp = []
    our_preds_tmp = []

    selected_tmp = []

    # Start with nothing acquired.
    mask = torch.zeros_like(M_test)

    def append_preds_metrics(selection=None):
      """Record every predictor's predictions/metric under the current mask.

      `mask` is read from the enclosing scope at call time, so this always
      uses the latest acquisition mask.
      """
      own_preds_curr = model.predict(X_test_model, mask).detach()
      xgb_preds_curr = xgb_predictor.predict(X_test_std, mask).detach()
      mlp_preds_curr = mlp_predictor.predict(X_test_std, mask).detach()
      our_preds_curr = our_predictor.predict(X_test_cdf, mask).detach()

      own_metric_tmp.append(metric_f(own_preds_curr, y_test))
      xgb_metric_tmp.append(metric_f(xgb_preds_curr, y_test))
      mlp_metric_tmp.append(metric_f(mlp_preds_curr, y_test))
      our_metric_tmp.append(metric_f(our_preds_curr, y_test))

      own_preds_tmp.append(own_preds_curr.cpu())
      xgb_preds_tmp.append(xgb_preds_curr.cpu())
      mlp_preds_tmp.append(mlp_preds_curr.cpu())
      our_preds_tmp.append(our_preds_curr.cpu())

      if selection is not None:
        selected_tmp.append(selection.detach().cpu())

    # Results with no features.
    append_preds_metrics(None)

    # Carry out the first acquisition based on function input.
    if first_acquisition == "active":
      mask, selection = model.acquire(X_test_model, mask, M_test, True)
    elif first_acquisition == "fixed":
      # NOTE: sets this column even for samples where M_test marks it missing.
      mask[:, best_first_feature] = 1.0
      # Long indices, matching the argmax-based selections of the other modes.
      selection = torch.full_like(mask[:, 0], int(best_first_feature), dtype=torch.long)
    else:  # "random"
      # Random scores restricted to available features. A row with no
      # available features would fall back to index 0 via argmax.
      scores = torch.rand_like(M_test)*M_test
      selection = torch.argmax(scores, dim=-1)
      # Keep the same dtype as the zeros_like(M_test) mask used by other modes.
      mask = F.one_hot(selection, M_test.shape[-1]).to(mask.dtype)
    append_preds_metrics(selection)

    # Carry out the remaining acquisitions actively.
    for _ in range(M_test.shape[-1]-1):
      mask, selection = model.acquire(X_test_model, mask, M_test, True)
      append_preds_metrics(selection)

    # Append to repeated arrays.
    own_metric.append(torch.tensor(own_metric_tmp))
    xgb_metric.append(torch.tensor(xgb_metric_tmp))
    mlp_metric.append(torch.tensor(mlp_metric_tmp))
    our_metric.append(torch.tensor(our_metric_tmp))

    own_preds.append(torch.stack(own_preds_tmp, dim=0))
    xgb_preds.append(torch.stack(xgb_preds_tmp, dim=0))
    mlp_preds.append(torch.stack(mlp_preds_tmp, dim=0))
    our_preds.append(torch.stack(our_preds_tmp, dim=0))

    selected.append(torch.stack(selected_tmp, dim=0))

  # Stack across the repeats and return.
  own_metric = torch.stack(own_metric, dim=0).float()
  xgb_metric = torch.stack(xgb_metric, dim=0).float()
  mlp_metric = torch.stack(mlp_metric, dim=0).float()
  our_metric = torch.stack(our_metric, dim=0).float()

  own_preds = torch.stack(own_preds, dim=0).float()
  xgb_preds = torch.stack(xgb_preds, dim=0).float()
  mlp_preds = torch.stack(mlp_preds, dim=0).float()
  our_preds = torch.stack(our_preds, dim=0).float()

  selected = torch.stack(selected, dim=0).long()

  return {
    "own_metric": own_metric,
    "xgb_metric": xgb_metric,
    "mlp_metric": mlp_metric,
    "our_metric": our_metric,
    "own_preds": own_preds,
    "xgb_preds": xgb_preds,
    "mlp_preds": mlp_preds,
    "our_preds": our_preds,
    "selected": selected,
  }




if __name__ == "__main__":

  # Create save folder.
  os.makedirs(osp.join("experiments", "results", "ablations"), exist_ok=True)

  # The same ablations are evaluated for every dataset (Invase included).
  ablations_list = [
    "ib",
    "train_sample",
  ]

  print("Loading Original Data")
  original_results = torch.load(osp.join("experiments", "results", f"{args.dataset}.pt"))

  # Predictor key in the saved results -> key prefix in the dict returned by
  # get_acquisition_array (note "ours" -> "our_metric" / "our_preds").
  predictor_prefixes = {"own": "own", "xgb": "xgb", "mlp": "mlp", "ours": "our"}

  results = {}

  # Run all the acquisitions in series. We parallelize by the dataset in screen sessions.
  for first_acquisition in ["active", "random", "fixed"]:
    acq_results = {
      "metrics": {key: {} for key in predictor_prefixes},
      "predictions": {key: {} for key in predictor_prefixes},
      "selections": {},
    }
    results[first_acquisition] = acq_results

    for ablation_name in ablations_list:
      with torch.no_grad():
        ablation_results = get_acquisition_array(args.dataset, ablation_name, first_acquisition)

      for key, prefix in predictor_prefixes.items():
        acq_results["metrics"][key][ablation_name] = ablation_results[f"{prefix}_metric"]
        acq_results["predictions"][key][ablation_name] = ablation_results[f"{prefix}_preds"]
      acq_results["selections"][ablation_name] = ablation_results["selected"]

      print("")

    # Copy the "ours" entries from the original results file so every
    # ablation can be compared against them in one place.
    original = original_results[first_acquisition]
    for key in predictor_prefixes:
      acq_results["metrics"][key]["ours"] = original["metrics"][key]["ours"]
      acq_results["predictions"][key]["ours"] = original["predictions"][key]["ours"]
    acq_results["selections"]["ours"] = original["selections"]["ours"]

  # Save the results.
  torch.save(results, osp.join("experiments", "results", "ablations", f"{args.dataset}.pt"))