################################################################################
# spectral/tsp_comparison.py
#
# 
# 
# 
# 2024
#
# Loads a trained TSP model and compares its results against other algorithms
# for different problem sizes.

import json
import numpy as np
import torch

from argparse     import ArgumentParser
from numpy.typing import NDArray
from tqdm         import tqdm
from typing       import Any, List, Optional, Tuple, Union

from experilog.logger  import Logger, JSONType
from modules.tsp.loss  import nodes_to_edges
from modules.tsp.model import load_from_json_dict as load_tsp

# Short aliases for commonly referenced torch types (used in annotations).
DataLoader = torch.utils.data.DataLoader
Device     = torch.device
Module     = torch.nn.Module
Tensor     = torch.Tensor

# Grid-type selectors understood by ``generate_grids``.
ASYMMETRIC = 0  # Independent uniform random distance for every ordered pair.
SYMMETRIC  = 1  # Uniform random distances mirrored across the diagonal.
EUCLIDEAN  = 2  # Pairwise distances between uniformly random points.

def generate_grids(
    # Arguments:
    batch_size: int,
    grid_size:  int,
    # Keyword Arguments:
    euclidean_dim: int = 2,
    grid_type:     int = ASYMMETRIC
  ) -> Tensor:
  """
  Generates a batch of random adjacency (distance) grids for TSP. Functional
  version of the method from ``TSPRandomGrids``.

  Args:
    batch_size (int):
      The number of grids to generate.
    grid_size (int):
      The number of nodes per grid (each grid is ``grid_size`` by
      ``grid_size``).
    euclidean_dim (int, optional):
      The dimensionality of the random points. Only used when ``grid_type``
      is ``EUCLIDEAN`` (2).
      Defaults to ``2``.
    grid_type (int, optional):
      The type of grid to generate. Either ``ASYMMETRIC`` (0),
      ``SYMMETRIC`` (1), or ``EUCLIDEAN`` (2). Any value other than
      ``SYMMETRIC`` or ``EUCLIDEAN`` yields an asymmetric grid.
      Defaults to ``ASYMMETRIC``.

  Returns:
    Tensor:
      An unformatted batch of grids with shape
      ``(batch_size, 1, grid_size, grid_size)``, scaled so that the largest
      entry across the whole batch is ``1``.
  """
  if grid_type == EUCLIDEAN:
    # Distances between uniformly random points in the unit hypercube.
    coords = torch.rand((batch_size, grid_size, euclidean_dim))
    distances = torch.cdist(coords, coords)
  else:
    distances = torch.rand((batch_size, grid_size, grid_size))
    if grid_type == SYMMETRIC:
      # Keep the lower triangle (with diagonal) and mirror its strictly-lower
      # part into the upper triangle.
      lower_with_diag = torch.tril(distances)
      mirrored_upper = torch.tril(distances, -1).mT
      distances = lower_with_diag + mirrored_upper
  # Normalise by the batch-wide maximum; the clamp guards against dividing by
  # zero.
  scale = torch.max(distances).clamp(1e-12)
  normalised = distances / scale
  return normalised.unsqueeze(1)

def format_grid(
    # Arguments:
    grid: Tensor,
    # Keyword Arguments:
    diag_dist: float = 1.0
  ) -> Tensor:
  """
  Formats a generated grid into the proper form, with diagonal distances and
  masking. Functional version of the method from ``TSPRandomGrids``.

  The result stacks two kinds of channels along dimension 1: the distances
  with the diagonal overwritten by ``diag_dist``, followed by an ``info``
  channel that is ``+1`` above the diagonal, ``-1`` below it and ``0`` on it.

  Args:
    grid (Tensor):
      The grid to format, with shape ``(batch, channels, n, n)`` (as returned
      by ``generate_grids``, ``channels`` is ``1``). May live on any device.
    diag_dist (float, optional):
      The distance along the diagonal.
      Defaults to ``1.0``.

  Returns:
    Tensor:
      The formatted grid with shape ``(batch, channels + 1, n, n)``, on the
      same device as ``grid``.
  """
  grid_size = grid.size(-1)
  # Build the mask on the grid's own device so that grids already living on
  # an accelerator can be formatted without a device-mismatch error.
  tri = torch.tril(
    torch.ones((grid_size, grid_size), device = grid.device),
    -1
  )
  # +1 strictly above the diagonal, -1 strictly below, 0 on the diagonal.
  info = tri.mT - tri
  info = torch.repeat_interleave(
    info.view(1, 1, grid_size, grid_size),
    grid.size(0),
    dim = 0
  )
  # 1 everywhere except on the diagonal.
  off_diagonal = torch.abs(info)
  # Keep the off-diagonal distances and overwrite the diagonal.
  grid = off_diagonal * grid + (1 - off_diagonal) * diag_dist
  return torch.concat([grid, info], dim = 1)

def tour_length(
    # Arguments:
    grid: NDArray,
    tour: List[int]
  ) -> float:
  """
  Sums the edge costs along a closed tour, including the final edge that
  returns from the last node to the first.

  Args:
    grid (NDArray):
      The grid of distances, indexed as ``grid[source, destination]``.
    tour (list[int]):
      The ordered nodes of the tour. Must not be empty.

  Returns:
    float:
      The total distance of the closed tour.
  """
  # Pair every node with its successor; the last node's successor is the
  # first node, which closes the loop.
  successors = tour[1:] + [tour[0]]
  return sum(grid[src, dst] for src, dst in zip(tour, successors))


def nearest_neighbor_step(
    # Arguments:
    grid: NDArray,
    curr: List[int]
  ) -> List[int]:
  """
  Extends a partial tour to a full tour with the nearest-neighbor heuristic.

  Starting from the last node of ``curr``, the closest node that is not yet
  on the path is appended repeatedly until every node has been visited. The
  work is done iteratively with a visited mask, so the grid itself is never
  copied or modified and large grids do not hit the recursion limit.

  Args:
    grid (NDArray):
      The square grid of distances to operate on, indexed as
      ``grid[source, destination]``.
    curr (list[int]):
      The current partial tour (path). Must contain at least the starting
      node unless the grid is empty. The list is extended in place.

  Returns:
    list[int]:
      The full tour from nearest-neighbor (the same list object as
      ``curr``).
  """
  size = grid.shape[-1]
  # Mark every node already on the path, not just the most recent one, so a
  # multi-node partial tour can never revisit one of its earlier nodes.
  visited = np.zeros(size, dtype = bool)
  visited[np.asarray(curr, dtype = int)] = True
  while len(curr) < size:
    # Distances out of the current node, with visited nodes made unreachable.
    candidates = np.where(visited, np.inf, grid[curr[-1]])
    nearest = int(np.argmin(candidates))
    visited[nearest] = True
    curr.append(nearest)
  return curr

def nearest_neighbor_bound(
    # Arguments:
    grid: NDArray,
    # Keyword Arguments:
    find_best: bool = False
  ) -> Tuple[float, List[int]]:
  """
  Computes a nearest-neighbor tour and uses its length as an upper bound on
  the optimal tour length.

  By default a single tour is built from a randomly chosen starting node.
  With ``find_best`` a tour is built from every starting node and the
  shortest one is kept.

  Args:
    grid (NDArray):
      The square grid of distances to operate on. It is not modified; a copy
      with an infinite diagonal is used internally.
    find_best (bool, optional):
      Whether to find the best NN tour (by starting at each node instead of
      a random one).
      Defaults to ``False``.

  Returns:
    tuple[float, list[int]]:
      The tour distance and the tour taken.
  """
  size = grid.shape[-1]
  # Work on a private copy whose self-loops are unusable.
  masked = grid.copy()
  np.fill_diagonal(masked, np.inf)

  def _tour_from(start):
    # Builds the NN tour from ``start`` and pairs it with its length.
    tour = nearest_neighbor_step(masked, [start])
    return (tour_length(masked, tour), tour)

  if not find_best:
    return _tour_from(np.random.randint(0, size))
  best = (np.inf, [])
  for start in range(size):
    candidate = _tour_from(start)
    # Strict comparison: the earliest start wins ties.
    if candidate[0] < best[0]:
      best = candidate
  return best

def load_model_for_inference(
    # Arguments:
    file: str
  ) -> Tuple[Module, JSONType, Any]:
  """
  Loads a saved checkpoint onto the CPU and rebuilds the model it describes.

  The checkpoint is expected to provide a ``"config"`` entry (whose
  ``"model"`` entry is passed to the model loader) and a ``"model_state"``
  entry holding the trained weights.

  Args:
    file (str):
      The file location for the saved model.

  Returns:
    Tuple[Module, JSONType, Any]:
      The model with the trained weights, the JSON-ready config dictionary,
      and the full contents of the checkpoint file.
  """
  checkpoint = torch.load(file, map_location = "cpu")
  settings = checkpoint["config"]
  # Build the architecture from its config, then restore the trained weights.
  network = load_tsp(settings["model"])
  network.load_state_dict(checkpoint["model_state"])
  return (network, settings, checkpoint)

def random_permutation(
    # Arguments:
    n: int
  ) -> NDArray:
  """
  Draws a random ``n`` by ``n`` permutation matrix using NumPy's global
  random state.

  The matrix is an arbitrary permutation: it may contain fixed points
  (ones on the diagonal) and is not restricted to a single cycle.

  Args:
    n (int):
      The number of rows and columns.

  Returns:
    NDArray:
      A ``float32`` matrix with exactly one ``1`` in every row and column.
  """
  matrix = np.identity(n, dtype = np.float32)
  # Shuffle the rows first ...
  np.random.shuffle(matrix)
  # ... then the columns, by shuffling the rows of the transposed view (which
  # shares memory with ``matrix``).
  np.random.shuffle(matrix.T)
  return matrix

def perform_test(
    # Arguments:
    model:       Module,
    logger:      Logger,
    batch_size:  int,
    num_batches: int,
    grid_size:   int,
    # Keyword Arguments:
    device:         Optional[Union[Device, str]] = None,
    diag_dist:      float                        = 1.0,
    euclidean_dim:  int                          = 2,
    max_iterations: Optional[int]                = None,
    grid_type:      int                          = ASYMMETRIC,
    verbose:        bool                         = False
  ) -> None:
  """
  Evaluates the model on freshly generated random grids and records a
  summary of its tour costs next to three baselines.

  For every generated grid four costs are measured:

  * ``"model"``: the model's predicted edges weighted by the formatted grid.
  * ``"nn"``: a nearest-neighbor tour from one random starting node.
  * ``"bnn"``: the best nearest-neighbor tour over all starting nodes.
  * ``"random"``: a random permutation matrix weighted by the formatted grid.

  Args:
    model (Module):
      The model to evaluate. It is switched to evaluation mode but is not
      moved; only the inputs are sent to ``device``.
    logger (Logger):
      The logger that summarizes the measurements and records the results.
    batch_size (int):
      The number of grids per batch.
    num_batches (int):
      The number of batches to evaluate.
    grid_size (int):
      The number of nodes per grid.
    device (Device | str, optional):
      The device the inputs are moved to. ``None`` means the CPU.
      Defaults to ``None``.
    diag_dist (float, optional):
      The distance written along the diagonal of the formatted grids.
      Defaults to ``1.0``.
    euclidean_dim (int, optional):
      The point dimensionality for ``EUCLIDEAN`` grids.
      Defaults to ``2``.
    max_iterations (int, optional):
      The iteration limit passed to the model. ``None`` means the model's
      own ``max_iterations``.
      Defaults to ``None``.
    grid_type (int, optional):
      Either ``ASYMMETRIC`` (0), ``SYMMETRIC`` (1), or ``EUCLIDEAN`` (2).
      Defaults to ``ASYMMETRIC``.
    verbose (bool, optional):
      Whether to show a progress bar and print status messages.
      Defaults to ``False``.
  """
  def announce(message: str) -> None:
    # Status messages are only printed in verbose mode.
    if verbose:
      print(message)

  # Device.
  if device is None:
    device = torch.device("cpu")
  elif isinstance(device, str):
    device = torch.device(device)
  # Max iterations.
  if max_iterations is None:
    max_iterations = model.max_iterations
  # Model.
  model.eval()
  # One list of per-batch cost arrays for each method.
  model_costs, best_nn_costs, nn_costs, random_costs = [], [], [], []
  progress_bar = tqdm(total = num_batches) if verbose else None
  with torch.no_grad():
    for _ in range(num_batches):
      raw_grids = generate_grids(
        batch_size,
        grid_size,
        euclidean_dim = euclidean_dim,
        grid_type     = grid_type
      )
      input_batch = format_grid(raw_grids, diag_dist = diag_dist).to(device)
      # Channel 0 of the formatted batch holds the distances.
      distances = input_batch[:, 0]
      # Model: predicted edges weighted by the distances.
      node_output = model(input_batch, max_iterations = max_iterations)[:, 0]
      predicted_edges = nodes_to_edges(node_output)
      model_costs.append(
        torch.sum(predicted_edges * distances, dim = (-1, -2))
          .detach().cpu().numpy()
      )
      # Nearest neighbor from a single random start (runs on the raw grids).
      nn_costs.append(np.array([
        nearest_neighbor_bound(sample[0].numpy(), False)[0]
        for sample in raw_grids
      ]))
      # Best nearest neighbor over every start.
      best_nn_costs.append(np.array([
        nearest_neighbor_bound(sample[0].numpy(), True)[0]
        for sample in raw_grids
      ]))
      # Random permutation matrices weighted by the distances.
      permutations = np.stack(
        [random_permutation(grid_size) for _ in range(batch_size)],
        axis = 0
      )
      random_costs.append(
        np.sum(
          permutations * distances.detach().cpu().numpy(),
          axis = (-1, -2)
        )
      )
      if progress_bar is not None:
        progress_bar.update(1)
  if progress_bar is not None:
    progress_bar.close()
  announce("Concatenating measurements.")
  measurements = {
    "model":  np.concatenate(model_costs),
    "bnn":    np.concatenate(best_nn_costs),
    "nn":     np.concatenate(nn_costs),
    "random": np.concatenate(random_costs)
  }
  # To result. Only the summaries are recorded, not the raw measurements.
  announce("Summarizing results.")
  results = {
    method: {"summary": logger.array_summary(values)}
    for method, values in measurements.items()
  }
  announce("Saving results.")
  logger.record_result(results)
  announce("Done.")

if __name__ == "__main__":
  # Command-line interface: two positionals, three optional values, one flag.
  parser = ArgumentParser(description = "Compares TSP methods.")
  parser.add_argument("model", help = "The model file.")
  parser.add_argument("log_dir", help = "The directory for saving logs.")
  for short_flag, long_flag, value_type in (
      ("-d", "--data",       str),
      ("-t", "--title",      str),
      ("-i", "--iterations", int)
    ):
    parser.add_argument(
      short_flag, long_flag, default = None, type = value_type
    )
  parser.add_argument("-v", "--verbose", action = "store_true")
  args = parser.parse_args()
  # Pick the device: CUDA first, then MPS, then the CPU.
  if torch.cuda.is_available():
    device, device_message = "cuda", "Using CUDA."
  elif torch.backends.mps.is_available():
    device, device_message = "mps", "Using MPS."
  else:
    device, device_message = "cpu", "Using CPU."
  print(device_message)
  # Load the checkpoint and move the model to the chosen device.
  model, config, contents = load_model_for_inference(args.model)
  model.to(device)
  # The dataset config comes from the JSON file given with --data, or from
  # the "testing" section of the model's own config.
  if args.data is None:
    dataset_config = config["testing"]
  else:
    with open(args.data, "r") as file:
      dataset_config = json.loads(file.read())
  # Set up the logger and record what this experiment was run with.
  logger = Logger(args.log_dir, args.title)
  train_info = {
    key: contents[key]
    for key in ("timestamp", "epoch", "train_loss", "valid_loss")
  }
  logger.set_controls({
    "model_config": config,
    "data_config":  dataset_config,
    "iterations":   args.iterations,
    "device":       device,
    "train_info":   train_info
  })
  logger.start_experiment()
  if args.verbose:
    print(dataset_config)
  # Run the comparison; missing data options fall back to the defaults here.
  data_options = dataset_config["data"]
  perform_test(
    model,
    logger,
    data_options.get("batch_size", 1),
    data_options.get("batch_count", 1),
    data_options.get("grid_size", 11),
    device         = device,
    diag_dist      = data_options.get("diag_dist", 1.0),
    euclidean_dim  = data_options.get("euclidean_dim", 2),
    max_iterations = args.iterations,
    grid_type      = data_options.get("grid_type", ASYMMETRIC),
    verbose        = args.verbose
  )
  # End the experiment and write the log.
  logger.stop_experiment()
  logger.write()