"""Enhanced model evaluation script to generate a comprehensive classification report
with metrics, visualizations, and system information.
"""
__author__ = 'XYZ'

import os
import glob
import time

from collections import Counter
from datetime import datetime
from pathlib import Path

from tqdm import tqdm

try:
  import torch
  import torch.multiprocessing as mp
  mp.set_sharing_strategy('file_system')

  from torchvision import transforms
except ImportError:
  print('torch or torchvision is not installed')


from .core._log_ import logger
log = logger(__file__)

from .dataset import (
  DataSet,
  get_dataloader,
  get_splits,
  calculate_mean_std,
)
from .utils.modelstats import (
  get_device,
  gather_system_info,
  model_summary,
  model_perfstats,
  save_metrics,
  save_model_summary,
  save_model_perfstats,
  save_model_architecture,
  save_report,
)
from .utils.classificationmetrics import (
  save_confusion_matrix_with_fp_fn_imagelist,
  plot_curves,
  plot_and_save_curves,
)
from .utils.torchapi import loadmodel, unloadmodel


def evaluate_model(model, dataloader, device):
  """Run ``model`` over every batch of ``dataloader`` and collect its outputs.

  The model is switched to eval mode and run under ``torch.no_grad()``.
  Each batch must be a 3-tuple ``(images, labels, extra)``; ``extra`` is
  ignored here.

  Args:
    model: module called as ``model(images)``, returning one tensor per batch.
    dataloader: iterable of ``(images, labels, _)`` batches supporting ``len()``.
    device: device (string or ``torch.device``) the images are moved to.

  Returns:
    ``(logits, labels, total_inference_time)``:
      - ``logits`` and ``labels`` are every batch concatenated along dim 0, on CPU.
      - ``total_inference_time`` is the summed wall-clock seconds spent in the
        forward passes only; data loading and transfers are excluded.

  Raises:
    RuntimeError: if ``dataloader`` yields no batches.
  """
  all_logits, all_labels = [], []
  total_inference_time = 0.0
  ## CUDA kernels launch asynchronously. Without a synchronize around the
  ## forward pass, the timer would measure kernel launch, not the compute.
  sync_cuda = str(device).startswith('cuda') and torch.cuda.is_available()

  model.eval()
  with torch.no_grad():
    for images, labels, _ in tqdm(dataloader, total=len(dataloader), desc="Evaluating", unit="batch"):
      images = images.to(device)

      if sync_cuda:
        torch.cuda.synchronize(device)
      start_time = time.perf_counter()
      logits = model(images)
      if sync_cuda:
        torch.cuda.synchronize(device)
      total_inference_time += time.perf_counter() - start_time

      all_logits.append(logits.cpu())
      ## Labels are only consumed on the CPU, so they are never sent to `device`.
      all_labels.append(labels.cpu())

  if not all_logits:
    raise RuntimeError("evaluate_model: dataloader yielded no batches")

  return torch.cat(all_logits), torch.cat(all_labels), total_inference_time


def load_dataset(args, splits, flag='test', has_to_calculate_mean_std=False, shuffle=True, timestamp=None):
  """Build the dataset and dataloader for one split, plus a description of it.

  Images are resized to ``args.input_size`` and converted to tensors. When
  ``has_to_calculate_mean_std`` is set, ``calculate_mean_std`` first runs over
  the un-normalized split. A ``Normalize`` step with the resulting statistics
  is then appended to the dataset transform.

  Args:
    args: CLI namespace. Reads ``input_size``, ``sample_limit``,
      ``batch_size``, ``num_workers`` and ``dataset``.
    splits: mapping that must contain ``'loadertxt'`` and ``f'{flag}loadertxt'``.
    flag: split name used to pick the list file, e.g. ``'test'``.
    has_to_calculate_mean_std: compute and apply mean/std normalization.
    shuffle: forwarded to ``get_dataloader``.
    timestamp: value stored as ``dataset_info['timestamp']``. Defaults to the
      current time formatted as ``%d%m%y_%H%M%S``.

  Returns:
    ``(loader, dataset, label_counts, dataset_info)``:
      - ``label_counts`` is a ``Counter`` over ``dataset.labels``.
      - ``dataset_info`` is a dict describing the split.

  Raises:
    KeyError: if a required key is missing from ``splits``.
  """
  log.info(f"splits:: {splits}")
  log.info(f"flag:: {flag}")

  loadertxt = splits['loadertxt']
  flagloadertxt = splits[f'{flag}loadertxt']

  ## Preprocessing shared by the statistics pass and the final loader.
  base_transforms = [
    transforms.Resize(args.input_size),
    transforms.ToTensor(),
  ]

  dataset = DataSet(
    root=loadertxt,
    lists=flagloadertxt,
    input_size=args.input_size,
    transform=transforms.Compose(base_transforms),
    sample_limit=args.sample_limit,
  )

  def _make_loader():
    """Create a DataLoader over `dataset` with the caller's loader settings."""
    return get_dataloader(
      dataset=dataset,
      batch_size=args.batch_size,
      num_workers=args.num_workers,
      shuffle=shuffle,
    )

  final_transforms = list(base_transforms)
  if has_to_calculate_mean_std:
    ## Statistics must come from un-normalized tensors, so this loader is
    ## built while the dataset still carries only the base transform.
    mean, std = calculate_mean_std(_make_loader())
    final_transforms.append(transforms.Normalize(mean=mean, std=std))

  ## Install the final transform, then build the loader callers will use.
  dataset.transform = transforms.Compose(final_transforms)
  loader = _make_loader()

  label_counts = Counter(dataset.labels)
  index_to_label = dataset.index_to_label
  total_images_per_label = {index_to_label[idx]: count for idx, count in label_counts.items()}
  class_names = list(index_to_label.values())

  # NOTE(review): mean/std are applied to the transform but not recorded in dataset_info.
  dataset_info = {
    "creator": __author__,
    "timestamp": timestamp or datetime.now().strftime("%d%m%y_%H%M%S"),
    "dataset_name": args.dataset,
    "num_samples": len(dataset),
    "num_classes": len(class_names),
    "batch_size": args.batch_size,
    "split": flag,
    "class_names": class_names,
    "loadertxt": loadertxt,
    "flagloadertxt": flagloadertxt,
    "input_size": args.input_size,
    "num_unique_labels": len(dataset.unique_label_names),
    "label_counts": dict(label_counts),
    "total_images_per_label": total_images_per_label,
    "label_to_index": dict(dataset.label_to_index),
    "labels": list(dataset.label_to_index.keys()),
    "label_ids": list(dataset.label_to_index.values()),
    "index_to_label": dict(index_to_label),
  }

  return loader, dataset, label_counts, dataset_info


def get_weights_path(weights_basepath, model_name, suffix='final'):
  """Return the weights file inside the newest ``<model_name>-*`` run directory.

  Searches ``weights_basepath`` for sub-directories whose name starts with
  ``f"{model_name}-"``. It picks the one with the largest ``os.path.getctime``
  and returns ``<that_dir>/<model_name>-<suffix>.pth``.

  Note: ``getctime`` is the creation time on Windows but the last metadata
  change time on Unix, so "newest" is only approximate on Unix.

  Args:
    weights_basepath: directory holding one sub-directory per training run.
    model_name: model name used as the run-directory prefix and file stem.
    suffix: weights file suffix, e.g. ``'final'``.

  Returns:
    Path string of the weights file.

  Raises:
    FileNotFoundError: if no matching directory exists, or the selected
      directory has no ``<model_name>-<suffix>.pth``.
  """
  ## Escape both parts so names containing '[', '*' or '?' are matched
  ## literally instead of being interpreted as glob syntax.
  pattern = os.path.join(
    glob.escape(os.fspath(weights_basepath)),
    f"{glob.escape(model_name)}-*",
  )

  ## Only directories are candidates. A stray file such as
  ## '<model_name>-notes.txt' must never be picked as the run directory.
  matching_dirs = [path for path in glob.glob(pattern) if os.path.isdir(path)]

  if not matching_dirs:
    log.warning(f"No matching directory found for model '{model_name}' in '{weights_basepath}'")
    raise FileNotFoundError(f"No matching directory found for model '{model_name}' in '{weights_basepath}'")

  selected_dir = max(matching_dirs, key=os.path.getctime)

  weights_path = os.path.join(selected_dir, f"{model_name}-{suffix}.pth")

  if not os.path.exists(weights_path):
    log.warning(f"Weights file '{weights_path}' not found.")
    raise FileNotFoundError(f"Weights file '{weights_path}' not found.")

  return weights_path


def main(args):
  """Evaluate a trained classifier on one dataset split and write a report.

  Pipeline:
    1. Resolve the split with ``get_splits`` and build the loader with
       ``load_dataset``.
    2. Load the model with ``loadmodel`` and run inference.
    3. Save the confusion matrix, FP/FN image lists and curves.
    4. Compute the model summary and performance stats.
    5. Persist metrics, summaries and a combined report into ``save_dir``.

  ``save_dir`` defaults to the directory containing ``args.weights_path``.
  When ``--use_to_path`` is set, or no weights path was given, it is
  ``<output_dir>/<args.net>`` instead. ``output_dir`` is ``args.to_path``
  or ``logs/eval_<split>-<timestamp>``.

  The ``__DATASET_ROOT__`` environment variable is passed to ``get_splits``.
  The model is always released via ``unloadmodel``, even if a step fails.
  """
  dataset_root = os.getenv('__DATASET_ROOT__')
  timestamp = datetime.now().strftime("%d%m%y_%H%M%S")

  split = args.split
  splits = get_splits(args.dataset, args.datasetcfg, dataset_root)

  ## Evaluate in a fixed order. The FP/FN helper below receives both the
  ## logits and the dataset and presumably pairs them by position (TODO
  ## confirm in utils.classificationmetrics); shuffling would break that.
  dataloader, dataset, _, dataset_info = load_dataset(
    args,
    splits,
    flag=split,
    has_to_calculate_mean_std=args.mean_std,
    shuffle=False,
    timestamp=timestamp,
  )

  class_names = list(dataset_info['index_to_label'].values())
  num_class = len(class_names)

  output_dir = Path(args.to_path or f"logs/eval_{split}-{timestamp}")

  ## NOTE(review): `device` is used for summary/perfstats, while inference
  ## uses get_device(args). Confirm both resolve to the same device.
  device = 'cuda' if args.gpu and torch.cuda.is_available() else 'cpu'
  log.info(f'torch.version.cuda: {torch.version.cuda}')
  log.info(f'loadmodel::device: {device}')
  log.info(f'loadmodel::args.gpu: {args.gpu}')

  model_name = args.net
  weights_path = args.weights_path
  log.info(f'weights_path: {weights_path}')

  if args.use_to_path or not weights_path:
    ## Without a weights path there is no directory to write next to.
    save_dir = output_dir / model_name
    save_dir.mkdir(parents=True, exist_ok=True)
  else:
    ## `or '.'` covers a bare filename, whose dirname is the empty string.
    save_dir = os.path.dirname(weights_path) or '.'

  model = loadmodel(args)
  try:
    model.eval()
    save_model_architecture(model, save_dir, model_name)

    log.info(f"Running inference for model: {model_name}...")
    logits, labels, _ = evaluate_model(model, dataloader, get_device(args))

    ## Visualization
    metrics = save_confusion_matrix_with_fp_fn_imagelist(logits, labels, class_names, save_dir, dataset)
    plot_curves(logits, labels, class_names, save_dir)
    plot_and_save_curves(logits, labels, class_names, save_dir)

    input_shape = (3, *args.input_size)

    log.info("Generating model summary...")
    key_stats, summary_info = model_summary(
      model=model,
      input_size=input_shape,
      device=device,
      verbose=False,
      num_iterations=args.num_iterations,
      weights_path=weights_path or None,
      dnnarch=model_name,
      num_class=num_class,
      depth=10,
    )

    log.info("Generating model information...")
    perfstats = model_perfstats(
      model=model,
      input_size=input_shape,
      device=device,
      verbose=False,
      num_iterations=args.num_iterations,
      weights_path=weights_path or None,
      dnnarch=model_name,
      num_class=num_class,
    )

    report_data = {
      "system_info": gather_system_info(),
      "metrics": metrics,
      "dataset_info": {split: dataset_info},
      "model_info": {
        "architecture": model.__class__.__name__,
        "total_layers": len(list(model.modules())),
        "weights_path": args.weights_path or "N/A",
        "input_size": args.input_size,
        "creator": __author__,
      },
      "modelsummary": key_stats,
      "modelperfstats": perfstats,
    }

    save_metrics(save_dir, metrics)
    save_model_summary(save_dir, key_stats, summary_info)
    save_model_perfstats(save_dir, perfstats)
    save_report(report_data, save_dir)

    log.info(f"Evaluation completed for model: {model_name}")
  finally:
    ## Release the model even if any step above raised.
    unloadmodel(model)

  log.info("All evaluations completed.")


def parse_args(**kwargs):
  """Parse and validate the command-line options for model evaluation.

  Args:
    **kwargs: optional ``argv``, a list of argument strings to parse instead
      of ``sys.argv[1:]`` (useful for tests and programmatic use). Any other
      keyword is ignored.

  Returns:
    argparse.Namespace whose ``input_size`` is a tuple of two positive ints.

  Exits:
    with status 1 (via ``sys.exit``) when ``--input_size`` is not a tuple of
    two positive integers, e.g. ``'(224,224)'``.
  """
  import argparse
  import ast
  import sys
  parser = argparse.ArgumentParser(description='Input parser', formatter_class=argparse.RawTextHelpFormatter)

  # Model and GPU-related arguments
  parser.add_argument('--net', type=str, help='net type (e.g., mobilenet_v2, efficientnet_b0)')
  parser.add_argument('--loss', type=str, default='CrossEntropyLoss', help='define the lossfunction to be used')
  ## store_false: the GPU is on by default, and passing --gpu turns it OFF.
  parser.add_argument('--gpu', action='store_false', default=True, help='GPU is used by default when available; pass --gpu to force CPU')
  parser.add_argument('--batch_size', type=int, default=64, help='batch size for dataloader')
  parser.add_argument('--warm', type=int, default=2, help='warm-up training phase')
  parser.add_argument('--lr', type=float, default=0.01, help='initial learning rate')
  parser.add_argument('--resume', action='store_true', default=False, help='resume training if checkpoint available')
  parser.add_argument('--resume_epoch', type=int, default=0, help='epoch to resume training from')
  parser.add_argument('--pretrain', action='store_true', default=False, help='whether the pretrain model is used')
  parser.add_argument('--num_iterations', type=int, default=100, help='total number of iterations for FPS calculations')

  # Dataset and evaluation-specific arguments
  parser.add_argument('--dataset', default='100-driver-day-cam1', type=str, help='dataset name or path')
  parser.add_argument('--num_class', type=int, default=22, help='number of classes for classification')
  parser.add_argument('--epochs', type=int, default=100, help='number of epochs for training')
  parser.add_argument('--weights_path', type=str, help='path to model weights')
  parser.add_argument('--weights_basepath', type=str, help='path to model weights basepath for batch processing only')
  parser.add_argument('--checkpoint_path', type=str, help='path to checkpoint for resuming')
  parser.add_argument('--datasetcfg', default="data/ddd-datasets.yml", help='dataset configuration file path')
  parser.add_argument('--sample_limit', type=int, help='limit number of samples per epoch for testing')
  parser.add_argument('--input_size', type=str, default='(224,224)', help="input size for the DNN")
  parser.add_argument('--split', type=str, default='test', help="Single valued dataset split to be used for evaluating on.")
  ## store_false: normalization stats are on by default, and passing --mean_std turns them OFF.
  parser.add_argument('--mean_std', action='store_false', default=True, help='mean/std normalization is computed by default; pass --mean_std to skip it')

  ## Output path argument
  parser.add_argument('--use_to_path', action='store_true', default=False, help='Write results under the --to directory (or logs/eval_<split>-<timestamp>)/<net> instead of the directory containing --weights_path (the default).')
  parser.add_argument('--to', dest='to_path', type=str, help='Output directory for saving results')

  ## DataLoader configuration
  parser.add_argument('--num_workers', type=int, default=4, help='number of worker processes for data loading')

  ## None -> argparse falls back to sys.argv[1:].
  args = parser.parse_args(kwargs.get('argv'))

  ## Validate and parse the input_size argument
  try:
    input_size = ast.literal_eval(args.input_size)
  except (ValueError, TypeError, SyntaxError):
    input_size = None

  is_valid = (
    isinstance(input_size, tuple)
    and len(input_size) == 2
    ## bool is an int subclass, so reject it explicitly.
    and all(isinstance(v, int) and not isinstance(v, bool) and v > 0 for v in input_size)
  )
  if not is_valid:
    log.error("Error: --input_size should be a tuple of two integers, e.g., '(224,224)'")
    sys.exit(1)
  args.input_size = input_size

  ## Check if weights path exists
  if args.weights_path and not os.path.exists(args.weights_path):
    log.warning(f"Weights path '{args.weights_path}' does not exist. Proceeding without loading weights.")

  return args


def print_args(args):
  """Echo every parsed option to stdout as ``name: value``, after a header line."""
  print("Arguments:")
  for name, value in vars(args).items():
    print(f"{name}: {value}")


if __name__ == '__main__':
  ## Script entry: parse the CLI, echo the options for reproducibility, then evaluate.
  cli_args = parse_args()
  print_args(cli_args)
  main(cli_args)
