################################################################################
# spectral/testing/test_norm.py
#
# 
# 
# 
# 2024
#
# Performs testing (using the test dataset) of trained models, measuring the
# L2-norm of thought module outputs.

import json
import numpy as np
import torch

from argparse import ArgumentParser
from tqdm     import tqdm
from typing   import Optional, Union

from experilog.logger import Logger
from test_model      import *

# Short alias for torch.nn.Module, used in the type annotations below.
Module = torch.nn.Module

def _thought_norms(phi: torch.Tensor, mode: str) -> list:
  """
  Returns one L2-norm measurement per sample of the thought batch `phi`.

  With `mode == "whole"` each sample's entire tensor is reduced to a single
  norm. With any other mode the norm is taken over axis 1 (presumably the
  channel axis -- TODO confirm against the model) and then averaged over the
  remaining positions of each sample.
  """
  values = phi.detach().cpu().numpy()
  if mode == "whole":
    return [np.linalg.norm(sample) for sample in values]
  # Flatten everything after the batch axis so that the mean yields exactly
  # one value per sample.
  per_position = np.linalg.norm(values, axis = 1).reshape((phi.size(0), -1))
  return np.ravel(np.mean(per_position, axis = 1)).tolist()

def perform_test(
    # Arguments:
    model:      Module,
    dataloader: DataLoader,
    logger:     Logger,
    # Keyword Arguments:
    device:         Optional[Union[Device, str]] = None,
    max_iterations: Optional[int]                = None,
    mode:           str                          = "whole",
    verbose:        bool                         = False,
  ) -> None:
  """
  Measures the L2-norm of the model's thought output after each iteration and
  records a per-iteration summary with the logger.

  For every batch the model is first called with `max_iterations = 0` to get
  the initial thought `phi`, which is then advanced one iteration at a time by
  feeding it back into the model. The norms gathered after `i` iterations are
  accumulated under the key `str(i)` for `i` in `0 .. max_iterations`.

  Arguments:
    model:      Called as `model(inputs, max_iterations = ..., return_thought =
                True[, phi = ...])` and expected to return a pair whose second
                element is the thought tensor. It is switched to eval mode but
                is not moved to `device` here.
    dataloader: Yields `(input_batch, target_batch)` pairs; the targets are
                not used.
    logger:     Supplies `array_summary` for each array of measurements and
                receives the final dictionary through `record_result`.

  Keyword Arguments:
    device:         Device (or device string) the inputs are moved to; CPU
                    when None.
    max_iterations: Number of thought iterations to measure; falls back to
                    `model.max_iterations` when None.
    mode:           "whole" for one norm per sample; any other value selects
                    the axis-1 norm averaged per sample.
    verbose:        Show a tqdm progress bar over the batches.

  Returns:
    None. Results are recorded as `{str(i): {"summary": ...}}`.
  """
  # Device.
  if device is None:
    device = torch.device("cpu")
  elif isinstance(device, str):
    device = torch.device(device)
  # Max iterations.
  if max_iterations is None:
    max_iterations = model.max_iterations
  # Model.
  model.eval()
  measurements = {str(i): [] for i in range(max_iterations + 1)}
  progress_bar = tqdm(total = len(dataloader)) if verbose else None
  try:
    with torch.no_grad():
      for input_batch, _ in dataloader:
        input_batch = input_batch.to(device)
        # Initial thought, before any iteration has been applied.
        _, phi = model(
          input_batch,
          max_iterations = 0,
          return_thought = True
        )
        for i in range(max_iterations + 1):
          measurements[str(i)] += _thought_norms(phi, mode)
          # No step after the final measurement: its result would be thrown
          # away, costing one wasted forward pass per batch.
          if i < max_iterations:
            _, phi = model(
              input_batch,
              max_iterations = 1,
              return_thought = True,
              phi = phi
            )
        if progress_bar is not None:
          progress_bar.update(1)
  finally:
    # Close the bar even when a batch fails so the terminal is left clean.
    if progress_bar is not None:
      progress_bar.close()
  # To result.
  results = {
    key: {
      "summary": logger.array_summary(np.array(values))
    }
    for key, values in measurements.items()
  }
  logger.record_result(results)

if __name__ == "__main__":
  # Command-line interface: three positionals followed by the options.
  cli = ArgumentParser(description = "Tests a trained model.")
  for positional, description in (
      ("model",   "The model file."),
      ("dataset", "The root location of the dataset."),
      ("log_dir", "The directory for saving logs."),
    ):
    cli.add_argument(positional, help = description)
  for flags, fallback, kind in (
      (("-d", "--data"),       None,    str),
      (("-t", "--title"),      None,    str),
      (("-i", "--iterations"), None,    int),
      (("-m", "--mode"),       "whole", str),
    ):
    cli.add_argument(*flags, default = fallback, type = kind)
  cli.add_argument("-v", "--verbose", action = "store_true")
  args = cli.parse_args()

  # Pick the device: CUDA first, then MPS, otherwise CPU.
  if torch.cuda.is_available():
    device = "cuda"
    print("Using CUDA.")
  elif torch.backends.mps.is_available():
    device = "mps"
    print("Using MPS.")
  else:
    device = "cpu"
    print("Using CPU.")

  # Load the trained model together with its config and checkpoint contents.
  model, config, contents = load_model_for_inference(args.model)
  model.to(device)

  # Dataset settings come from the model config unless a JSON file is given.
  if args.data is None:
    dataset_config = config["testing"]
  else:
    with open(args.data, "r") as handle:
      dataset_config = json.load(handle)
  dataset_cls = get_dataset_type(config["problem"])
  test_dataloader = torch.utils.data.DataLoader(
    dataset_cls(args.dataset, **dataset_config["data"]),
    num_workers = 0,
    batch_size  = dataset_config["batch_size"]
  )

  # Logger, with the experiment's controls recorded before it starts.
  logger = Logger(args.log_dir, args.title)
  train_info = {
    key: contents[key]
    for key in ("timestamp", "epoch", "train_loss", "valid_loss")
  }
  logger.set_controls({
    "model_config": config,
    "data_config":  dataset_config,
    "iterations":   args.iterations,
    "device":       device,
    "train_info":   train_info
  })
  logger.start_experiment()

  # Run the norm measurements.
  perform_test(
    model,
    test_dataloader,
    logger,
    device         = device,
    max_iterations = args.iterations,
    mode           = args.mode,
    verbose        = args.verbose
  )

  # Finish the experiment and write the log.
  logger.stop_experiment()
  logger.write()