"""
Create a mosaic for the Kaggle State Farm dataset showing one randomly picked image from each
of the 10 classes, arranged in two rows and labeled with the corresponding normalized class names.
"""
__author__ = "XYZ"  # module author metadata


import os
import argparse
import logging
from PIL import Image
import matplotlib.pyplot as plt
import random
from datetime import datetime

## Normalized, human-readable name for each State Farm class label (c0..c9), in label order
LABELS = {
  f'c{index}': name
  for index, name in enumerate((
    'Normal Driving',
    'Texting Right',
    'Talking on the Phone Right',
    'Texting Left',
    'Talking on the Phone Left',
    'Operating the Radio',
    'Drinking',
    'Reaching Behind',
    'Hair and Makeup',
    'Talking to Passenger',
  ))
}

## Function to configure logger
def configure_logger(output_path):
  """Create ``output_path`` if needed and send root logging to a timestamped file inside it.

  Args:
    output_path: directory that will hold the log file (created, including parents, if missing).

  Returns:
    The log file path, ``<output_path>/mosaic_<ddmmyy_HHMMSS>.log``.

  Note:
    ``logging.basicConfig`` is a no-op when the root logger already has handlers, so only
    the first call in a process actually attaches the file handler.
  """
  # exist_ok avoids the check-then-create race of os.path.exists() followed by os.makedirs().
  os.makedirs(output_path, exist_ok=True)
  timestamp = datetime.now().strftime("%d%m%y_%H%M%S")
  log_file = os.path.join(output_path, f'mosaic_{timestamp}.log')
  logging.basicConfig(filename=log_file, level=logging.DEBUG, format='%(asctime)s %(message)s')
  logging.info("Logger initialized")
  return log_file


## Function to create a mosaic for the Kaggle State Farm dataset

# Grid layout of the mosaic: 2 rows x 5 columns = 10 cells (one per State Farm class).
_MOSAIC_ROWS = 2
_MOSAIC_COLS = 5


def _select_images(train_dir, rng):
  """Pick one random file from every class sub-directory of ``train_dir``.

  Hidden entries (names starting with '.') are ignored. Only directories count as
  classes, and only regular files count as images. Both listings are sorted because
  ``os.listdir`` order is arbitrary; sorting makes a seeded ``rng`` pick the same
  images on every machine.

  Args:
    train_dir: directory whose sub-directories are the classes.
    rng: object with a ``choice`` method (``random.Random`` instance or the ``random`` module).

  Returns:
    List of ``(class_name, image_path)`` tuples ordered by class name.

  Raises:
    FileNotFoundError: if there are no class directories, or a class directory has no files.
  """
  classes = sorted(
    name for name in os.listdir(train_dir)
    if not name.startswith('.') and os.path.isdir(os.path.join(train_dir, name))
  )
  if not classes:
    raise FileNotFoundError(f"No class directories found in: {train_dir}")

  selected_images = []
  for cls in classes:
    class_dir = os.path.join(train_dir, cls)
    images = sorted(
      name for name in os.listdir(class_dir)
      if not name.startswith('.') and os.path.isfile(os.path.join(class_dir, name))
    )
    if not images:
      raise FileNotFoundError(f"No images found in class directory: {class_dir}")
    selected_images.append((cls, os.path.join(class_dir, rng.choice(images))))
  return selected_images


def _render_mosaic(selected_images, labels, output_file, img_width, img_height):
  """Draw ``selected_images`` on the 2x5 grid with their labels and save the PNG to ``output_file``."""
  n_rows, n_cols = _MOSAIC_ROWS, _MOSAIC_COLS

  # Figure size is derived from pixel sizes assuming 100 px per inch. The PNG is saved at
  # 300 dpi, so the output is roughly 3x these pixel sizes (before the tight-bbox crop).
  col_width = img_width + 2     # 2px of horizontal spacing per column
  row_height = img_height + 17  # extra vertical room for the labels between the rows
  figure_width = (col_width * n_cols) / 100
  figure_height = (row_height * n_rows) / 100

  fig, axs = plt.subplots(
    n_rows,
    n_cols,
    figsize=(figure_width, figure_height),
    gridspec_kw={'wspace': 0.025, 'hspace': 0.2},
  )
  try:
    # NOTE(review): tight_layout runs before anything is drawn, so it sizes the spacing from
    # the still-empty axes. Kept as-is to preserve the existing layout.
    fig.tight_layout(pad=0.0)

    # Hide every cell up front so unused cells (fewer classes than cells) stay blank.
    for ax in axs.flat:
      ax.axis('off')

    for idx, (cls, img_path) in enumerate(selected_images):
      row_idx, col_idx = divmod(idx, n_cols)
      ax = axs[row_idx, col_idx]
      # The context manager closes the source file; resize() returns an independent image.
      with Image.open(img_path) as src:
        img = src.resize((img_width, img_height))
      ax.imshow(img)

      # Label with the real class name; unknown classes fall back to the bare directory name.
      name = labels.get(cls)
      text = f'{cls}:{name}' if name is not None else cls
      # Labels sit in the gap between the rows: below the top row, above the bottom row.
      if row_idx == 0:
        ax.text(0.5, -0.03, text, ha='center', va='top', fontsize=4.5, transform=ax.transAxes)
      else:
        ax.text(0.5, 1.03, text, ha='center', va='bottom', fontsize=4.5, transform=ax.transAxes)

    fig.subplots_adjust(left=0, right=1, top=1, bottom=0.1)  # leave room for the labels
    fig.savefig(output_file, bbox_inches='tight', pad_inches=0.01, dpi=300)
  finally:
    # Always release the figure, even if loading or saving an image fails.
    plt.close(fig)


def create_mosaic(base_path, output_path, random_seed, labels, img_width=150, img_height=150, blur_faces=True):
  """
  Create a mosaic for the Kaggle State Farm dataset with 10 classes arranged in two rows.

  One image is picked at random from each class directory under ``<base_path>/imgs/train``.
  Each image is resized to ``img_width`` x ``img_height`` and drawn on a 2x5 grid. The
  result is saved as ``sf3dd-mosaic.png`` in ``output_path``, which is created if missing.

  Args:
    base_path: dataset root containing ``imgs/train/<class>/``.
    output_path: directory the PNG is written to.
    random_seed: seed for the image picks. When None, the global ``random`` state is used.
    labels: dict from class directory name to display name. Classes missing from it are
      labeled with the directory name alone.
    img_width, img_height: size in pixels each image is resized to.
    blur_faces: accepted for interface compatibility. Face blurring is not implemented
      here, so a warning is logged when it is True.

  Raises:
    FileNotFoundError: if the train directory is missing, has no class directories,
      or a class directory has no files.
    ValueError: if there are more classes than the 2x5 grid can hold.
  """
  train_dir = os.path.join(base_path, 'imgs', 'train')
  if not os.path.isdir(train_dir):
    raise FileNotFoundError(f"Train directory not found: {train_dir}")

  if blur_faces:
    logging.warning("Face blurring was requested but is not implemented; images are used unmodified.")

  # A dedicated RNG makes the picks depend only on random_seed, not on global state. For a
  # given seed it draws the same sequence as random.seed(random_seed) + random.choice().
  rng = random.Random(random_seed) if random_seed is not None else random
  selected_images = _select_images(train_dir, rng)

  max_cells = _MOSAIC_ROWS * _MOSAIC_COLS
  if len(selected_images) > max_cells:
    raise ValueError(
      f"Found {len(selected_images)} classes in {train_dir}, but the mosaic holds at most {max_cells}"
    )

  os.makedirs(output_path, exist_ok=True)
  output_file = os.path.join(output_path, "sf3dd-mosaic.png")
  _render_mosaic(selected_images, labels, output_file, img_width, img_height)

  ## Log the output path
  logging.info(f"Saved mosaic for StateFarm at {output_file}")
  print(f"Saved mosaic for StateFarm at {output_file}")


def parse_args(argv=None):
  """Parse the command-line options for the mosaic script.

  Args:
    argv: optional list of argument strings. When None, argparse reads ``sys.argv[1:]``,
      which is the original behaviour.

  Returns:
    argparse.Namespace with ``from_path``, ``to_path``, ``random_seed`` and ``blur_faces``.
  """
  parser = argparse.ArgumentParser(description='Image mosaic creator with optional face blurring', formatter_class=argparse.RawTextHelpFormatter)
  parser.add_argument('--from', type=str, dest='from_path', required=True, help='Base directory containing the data splits.')
  parser.add_argument('--to', type=str, dest='to_path', default=None, help='Output directory to save the mosaic image.')
  parser.add_argument('--random-seed', type=int, default=42, help="Random seed for selecting images (default: 42)")
  # store_true flags default to False; the help text previously claimed True.
  parser.add_argument('--blur-faces', action='store_true', help="Flag to enable face blurring (default: False)")
  return parser.parse_args(argv)


def main(args):
  """Run the mosaic pipeline for parsed command-line ``args``.

  Resolves the output directory, configures file logging, seeds the global RNG and
  builds the State Farm mosaic with 150x150 tiles.
  """
  output_path = args.to_path
  if not output_path:
    # No --to given: fall back to a timestamped folder under logs/.
    output_path = f"logs/{datetime.now().strftime('%d%m%y_%H%M%S')}"
  configure_logger(output_path)

  dataset_root = args.from_path
  random.seed(args.random_seed)

  logging.info("Starting mosaic creation for Kaggle State Farm dataset")
  create_mosaic(
    dataset_root,
    output_path,
    args.random_seed,
    LABELS,
    img_width=150,
    img_height=150,
    blur_faces=args.blur_faces,
  )

  logging.info(f"Mosaic created and saved at {output_path}")


if __name__ == "__main__":
  # Script entry point: parse the CLI options and hand them straight to main().
  main(parse_args())
