################################################################################
# spectral/testing/test_iterations.py
#
# 
# 
# 
# 2024
#
# Performs testing (using the test dataset) of trained models, specifically
# running extended iterations.

import json
import numpy as np
import torch

from argparse import ArgumentParser
from tqdm     import tqdm
from typing   import Optional, Union

from experilog.logger import Logger
from test_model      import *

Module = torch.nn.Module

def perform_test(
    # Arguments:
    model:      Module,
    dataloader: DataLoader,
    logger:     Logger,
    # Keyword Arguments:
    device:         Optional[Union[Device, str]] = None,
    max_iterations: Optional[int]                = None,
    metrics:        Optional[MetricsDict]        = None,
    verbose:        bool                         = False,
  ) -> None:
  """Measures a model's metrics after every iteration from 0 to max_iterations.

  For each batch the model is first run with `max_iterations = 0`, which
  returns its initial prediction and thought `phi`. It is then advanced one
  iteration at a time by feeding `phi` back in. Every selected metric is
  measured on the prediction at each iteration count. The per-batch
  measurements are concatenated across batches, summarized with
  `logger.array_summary`, and recorded with `logger.record_result` as
  `{str(i): {metric_name: {"summary": ...}}}`.

  Args:
    model: Model to test; it is switched to eval mode. It is called with the
      `max_iterations`, `return_thought` and `phi` keyword arguments and is
      expected to return a `(prediction, phi)` pair.
    dataloader: Iterable of `(input_batch, target_batch)` pairs. `len()` is
      only taken when `verbose` is set (for the progress bar).
    logger: Logger that receives the summarized results.
    device: Device the batches are moved to (a `torch.device` or a string);
      defaults to the CPU. The model itself is not moved here.
    max_iterations: Largest iteration count to measure; defaults to
      `model.max_iterations`.
    metrics: Names of the metrics to compute, each looked up in `METRICS`
      (unknown names raise `KeyError`); defaults to every entry of `METRICS`.
    verbose: Whether to show a progress bar and status messages.

  Raises:
    ValueError: If `dataloader` yields no batches while there is at least one
      metric to measure.
  """
  # Device.
  if device is None:
    device = torch.device("cpu")
  elif isinstance(device, str):
    device = torch.device(device)
  # Max iterations.
  if max_iterations is None:
    max_iterations = model.max_iterations
  # Metrics: restrict METRICS to the requested names.
  metrics = METRICS if metrics is None else {m: METRICS[m] for m in metrics}
  # Model.
  model.eval()
  # measurements[str(i)][metric] collects one array per batch, measured on
  # the prediction obtained after i iterations.
  measurements = {
    str(i): {m: [] for m in metrics}
    for i in range(max_iterations + 1)
  }
  num_batches = 0
  if verbose:
    progress_bar = tqdm(total = len(dataloader))
  with torch.no_grad():
    for input_batch, target_batch in dataloader:
      num_batches += 1
      input_batch  = input_batch.to(device)
      target_batch = target_batch.to(device)
      # Iteration 0: the model's initial prediction and thought.
      predicted_batch, phi = model(
        input_batch,
        max_iterations = 0,
        return_thought = True
      )
      for i in range(max_iterations + 1):
        for metric in metrics:
          # Metric outputs are converted to arrays and later concatenated
          # across batches (presumably one value per sample -- confirm
          # against METRICS).
          measurements[str(i)][metric].append(
            metrics[metric](
              predicted_batch,
              target_batch
            ).detach().cpu().numpy()
          )
        # Advance one iteration, resuming from the previous thought. After
        # the last measured iteration no further step is needed, so the
        # forward pass whose result would be discarded is skipped.
        if i < max_iterations:
          predicted_batch, phi = model(
            input_batch,
            max_iterations = 1,
            return_thought = True,
            phi = phi
          )
      if verbose:
        progress_bar.update(1)
  if verbose:
    progress_bar.close()
  # With no batches, np.concatenate below would fail on an empty list with an
  # unhelpful message; report the actual cause instead.
  if num_batches == 0 and any(measurements.values()):
    raise ValueError("The dataloader yielded no batches; nothing to test.")
  if verbose:
    print("Concatenating measurements.")
  measurements = {
    i: {m: np.concatenate(measurements[i][m]) for m in measurements[i]}
    for i in measurements
  }
  if verbose:
    print("Summarizing results.")
  results = {
    i: {
      m: {"summary": logger.array_summary(measurements[i][m])}
      for m in measurements[i]
    }
    for i in measurements
  }
  if verbose:
    print("Saving results.")
  logger.record_result(results)
  if verbose:
    print("Done.")

if __name__ == "__main__":
  # Construct an argument parser.
  parser = ArgumentParser(description = "Tests a trained model.")
  parser.add_argument("model", help = "The model file.")
  parser.add_argument("dataset", help = "The root location of the dataset.")
  parser.add_argument("log_dir", help = "The directory for saving logs.")
  parser.add_argument(
    "-d", "--data", default = None, type = str,
    help = "A JSON file with the test dataset configuration (defaults to "
           "the model config's \"testing\" section)."
  )
  parser.add_argument(
    "-t", "--title", default = None, type = str,
    help = "The experiment title passed to the logger."
  )
  parser.add_argument(
    "-i", "--iterations", default = None, type = int,
    help = "The maximum number of iterations to test (defaults to the "
           "model's own maximum)."
  )
  parser.add_argument(
    "-v", "--verbose", action = "store_true",
    help = "Show a progress bar and status messages."
  )
  args = parser.parse_args()
  # Get the device, preferring CUDA, then Apple MPS, then the CPU.
  if torch.cuda.is_available():
    print("Using CUDA.")
    device = "cuda"
  elif torch.backends.mps.is_available():
    print("Using MPS.")
    device = "mps"
  else:
    print("Using CPU.")
    device = "cpu"
  # Get the model and config. perform_test only moves the batches, so the
  # model is moved to the device here.
  model, config, contents = load_model_for_inference(args.model)
  model.to(device)
  # Get dataset: an explicit JSON config overrides the model's testing config.
  if args.data is not None:
    # JSON is UTF-8; don't depend on the platform's default encoding.
    with open(args.data, "r", encoding = "utf-8") as file:
      dataset_config = json.load(file)
  else:
    dataset_config = config["testing"]
  dataset_type = get_dataset_type(config["problem"])
  test_dataset = dataset_type(args.dataset, **dataset_config["data"])
  test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    num_workers = 0,
    batch_size  = dataset_config["batch_size"]
  )
  # Set up logger, recording everything needed to reproduce this test run.
  logger = Logger(args.log_dir, args.title)
  logger.set_controls({
    "model_config": config,
    "data_config":  dataset_config,
    "iterations":   args.iterations,
    "device":       device,
    "train_info": {
      "timestamp":  contents["timestamp"],
      "epoch":      contents["epoch"],
      "train_loss": contents["train_loss"],
      "valid_loss": contents["valid_loss"]
    }
  })
  logger.start_experiment()
  # Get results.
  perform_test(
    model,
    test_dataloader,
    logger,
    device         = device,
    max_iterations = args.iterations,
    metrics        = dataset_config.get("metrics"),
    verbose        = args.verbose
  )
  # End experiment.
  logger.stop_experiment()
  logger.write()