"""
Create separate image mosaic for each modality and camera, with categories arranged in two rows (11 per row),
fitting into 1/4 of an A4 page in non-landscape mode. Optionally blur faces in the images.
"""
__author__ = 'XYZ'


import argparse
import functools
import logging
import os
import random

from datetime import datetime

import cv2  ## OpenCV for face detection and blurring
import matplotlib.pyplot as plt
import numpy as np

from PIL import Image
from tqdm import tqdm


## Predefined list of categories, in class-index order: "C<index>_<Name_With_Underscores>".
## The index prefix is generated from the position, so entries cannot be mis-numbered.
CATEGORIES = [
  f"C{index}_{name}"
  for index, name in enumerate(
    (
      'Drive_Safe',
      'Sleep',
      'Yawning',
      'Talk_Left',
      'Talk_Right',
      'Text_Left',
      'Text_Right',
      'Make_Up',
      'Look_Left',
      'Look_Right',
      'Look_Up',
      'Look_Down',
      'Smoke_Left',
      'Smoke_Right',
      'Smoke_Mouth',
      'Eat_Left',
      'Eat_Right',
      'Operate_Radio',
      'Operate_GPS',
      'Reach_Behind',
      'Leave_Steering_Wheel',
      'Talk_to_Passenger',
    ),
    start=1,
  )
]

## Function to normalize the category names for display
def normalize_category_name(category):
  return category.split('_', 1)[1].replace('_', ' ')  ## Split by first underscore and replace remaining with spaces


## Function to configure logger
def configure_logger(output_path):
  """Create *output_path* if needed and send root-logger output to a timestamped file inside it.

  The log file is named ``mosaic_<ddmmyy_HHMMSS>.log`` and records DEBUG and above with an
  ``asctime`` prefix.

  NOTE: ``logging.basicConfig`` does nothing when the root logger already has handlers, so
  only the first configuration in a process takes effect.
  """
  # exist_ok=True removes the exists()/makedirs() race. An empty path means the current
  # directory, which needs no creation (os.makedirs('') would raise).
  if output_path:
    os.makedirs(output_path, exist_ok=True)
  log_file = os.path.join(output_path, f'mosaic_{datetime.now().strftime("%d%m%y_%H%M%S")}.log')
  logging.basicConfig(filename=log_file, level=logging.DEBUG, format='%(asctime)s %(message)s')
  logging.info("Logger initialized")

## Load the pre-trained frontal-face Haar cascade shipped with OpenCV (once per process)
@functools.lru_cache(maxsize=1)
def _get_face_cascade():
  """Return OpenCV's frontal-face Haar cascade, reading it from disk only on first use."""
  cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
  face_cascade = cv2.CascadeClassifier(cascade_path)
  if face_cascade.empty():
    raise RuntimeError(f"Could not load Haar cascade from {cascade_path}")
  return face_cascade


## Function to detect and blur faces in an image (in-memory, without saving)
def detect_and_blur_faces(image):
  """Return a copy of *image* with every detected frontal face Gaussian-blurred.

  Args:
    image: PIL image. It is converted to RGB first, so grayscale, RGBA and palette images
      are accepted (OpenCV's RGB->BGR conversion requires exactly three channels).

  Returns:
    A new RGB PIL image; *image* itself is not modified.

  Raises:
    RuntimeError: if the OpenCV face cascade cannot be loaded.
  """
  face_cascade = _get_face_cascade()

  # PIL (RGB) -> OpenCV (BGR) array, plus a grayscale copy for the detector.
  cv_img = cv2.cvtColor(np.array(image.convert('RGB')), cv2.COLOR_RGB2BGR)
  gray = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)

  faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))

  # Overwrite each detected face region of the BGR copy with a strongly blurred version.
  for (x, y, w, h) in faces:
    face = cv_img[y:y+h, x:x+w]
    cv_img[y:y+h, x:x+w] = cv2.GaussianBlur(face, (99, 99), 30)

  # OpenCV (BGR) -> PIL (RGB)
  return Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))


## Function to get image from the split file based on category
def get_image_from_split(split_file, base_path, modality, camera, category):
  """Return the full path of the first image in *split_file* whose listed path contains *category*.

  Each line of the split file is split on whitespace. The second column is assumed to be the
  image path relative to ``<base_path>/<modality>/<camera>``. Lines with fewer than two columns
  (e.g. blank lines) are skipped instead of raising IndexError.

  Returns:
    ``os.path.join(base_path, modality, camera, relative_path)`` for the first match, or None
    when no line matches.

  Raises:
    OSError: if *split_file* cannot be opened.
  """
  with open(split_file, 'r') as split_lines:
    for line in split_lines:
      columns = line.split()
      if len(columns) < 2:
        continue  # blank or malformed line: no path column to inspect
      image_path = columns[1]
      # Plain substring match: relies on no category name occurring inside another one's path.
      if category in image_path:
        return os.path.join(base_path, modality, camera, image_path)
  return None

def _split_file_path(base_path, experiment_setting, modality, camera, split_name):
  """Return the split-file path for *camera*.

  Layout: ``<base>/data-splits/<setting>/<modality>/<camera>/<M><index>_<split>.txt``, where
  ``<M>`` is the upper-cased first letter of *modality* and ``<index>`` is *camera* with "Cam"
  removed (e.g. ``.../Day/Cam1/D1_test.txt``).
  """
  camera_index = camera.replace("Cam", "")
  modality_initial = modality[0].upper() if modality else ''
  return os.path.join(
    base_path,
    f"data-splits/{experiment_setting}/{modality}/{camera}/{modality_initial}{camera_index}_{split_name}.txt"
  )


def _load_tile(img_path, img_width, img_height, blur_faces):
  """Open *img_path*, optionally blur faces, and return an RGB tile of size (img_width, img_height)."""
  with Image.open(img_path) as img:  # closes the file handle once the tile has been built
    tile = detect_and_blur_faces(img) if blur_faces else img
    # Force RGB: imshow maps single-channel (e.g. mode 'L') images through a colormap
    # (viridis by default), which would render grayscale frames in false colour.
    return tile.convert('RGB').resize((img_width, img_height))


def create_single_mosaic(base_path, modality, cameras, experiment_setting, split_name, output_path, categories, img_width=100, img_height=100, blur_faces=True):
  """
  Create a single mosaic for all cameras with rows as categories and columns as cameras.

  Each cell shows the first image listed for that category in the camera's split file, resized
  to ``img_width`` x ``img_height`` and optionally face-blurred. Cells whose image is missing stay
  blank. The first column is labelled "<index>. <name>" and the first row carries camera titles.

  Args:
    base_path: Dataset root containing ``data-splits/`` and the per-modality image folders.
    modality: Modality directory name (e.g. 'Day').
    cameras: Camera directory names (e.g. ['Cam1', ...]); one column each.
    experiment_setting: Experiment-setting directory name under ``data-splits/``.
    split_name: Split to read ('train', 'test' or 'val').
    output_path: Existing directory the PNG is written to.
    categories: Category identifiers; one row each.
    img_width, img_height: Tile size in pixels (the figure uses 100 px per inch).
    blur_faces: Whether to blur detected faces before placing each tile.

  Returns:
    Path of the written PNG, ``<output_path>/<modality>-All-Cameras-Single-Mosaic.png``.

  Raises:
    ValueError: if *categories* or *cameras* is empty.
    OSError: if a split file cannot be read.
  """
  n_categories = len(categories)
  n_cameras = len(cameras)
  if n_categories == 0 or n_cameras == 0:
    raise ValueError("create_single_mosaic needs at least one category and one camera")

  # Figure size in inches, at 100 tile pixels per inch.
  figure_width = (img_width * n_cameras) / 100  # Width for all camera columns
  figure_height = (img_height * n_categories) / 100  # Height for all category rows

  # squeeze=False keeps axs 2-D even for a single row or column, so axs[r, c] always works.
  fig, axs = plt.subplots(
    n_categories, n_cameras,
    figsize=(figure_width, figure_height),
    gridspec_kw={'wspace': 0.05, 'hspace': 0.05},  # Adjust spacing
    squeeze=False,
  )

  try:
    # Split-file paths depend only on the camera, so build them once instead of once per cell.
    split_files = [
      _split_file_path(base_path, experiment_setting, modality, camera, split_name)
      for camera in cameras
    ]

    for row_idx, category in enumerate(categories):
      for col_idx, (camera, split_file) in enumerate(zip(cameras, split_files)):
        ax = axs[row_idx, col_idx]

        img_path = get_image_from_split(split_file, base_path, modality, camera, category)
        if img_path and os.path.exists(img_path):
          ax.imshow(_load_tile(img_path, img_width, img_height, blur_faces))
        ax.axis('off')  # no ticks/frame; cells without an image simply stay blank

        # Label the first column with the category and the first row with the camera.
        if col_idx == 0:
          ax.text(
            -0.1, 0.5, f"{category.split('_')[0].replace('C', '')}. {normalize_category_name(category)}",
            ha='right', va='center', fontsize=5, transform=ax.transAxes
          )
        if row_idx == 0:
          ax.set_title(f"{camera}", fontsize=6, pad=5)

    # Save the combined mosaic
    output_file = os.path.join(output_path, f"{modality}-All-Cameras-Single-Mosaic.png")
    fig.savefig(output_file, bbox_inches='tight', pad_inches=0.01, dpi=300)
  finally:
    plt.close(fig)  # release the figure even if loading or saving failed

  print(f"Single mosaic for all cameras saved at {output_file}")
  return output_file


## Common argument parser
def parse_args(argv=None):
  """Parse the mosaic script's command-line options.

  Args:
    argv: Argument list to parse; None (the default) means ``sys.argv[1:]``.

  Returns:
    argparse.Namespace with from_path, to_path, modality, experiment, split_name, random_seed
    and blur_faces.
  """
  parser = argparse.ArgumentParser(description='Image mosaic creator with optional face blurring', formatter_class=argparse.RawTextHelpFormatter)
  parser.add_argument('--from', type=str, dest='from_path', required=True, help='Base directory containing the data splits.')
  parser.add_argument('--to', type=str, dest='to_path', default=None, help='Output directory to save the mosaic image.')
  parser.add_argument('--modality', type=str, default='day', choices=['day', 'night'], help="Modality (default: day)")
  parser.add_argument('--experiment', type=str, default='traditional', choices=['traditional', 'cross-camera', 'cross-modality', 'cross-vehicle'],
                      help="Experimental setting (default: traditional)")
  parser.add_argument('--split-name', type=str, default='test', choices=['train', 'test', 'val'], help="Split name (default: test)")
  # NOTE(review): random_seed is parsed but nothing in this script reads it (each cell uses the
  # first matching line of the split file) — confirm whether it is still needed.
  parser.add_argument('--random-seed', type=int, default=42, help="Random seed for selecting images (default: 42)")
  # store_true: blurring is off unless the flag is given.
  parser.add_argument('--blur-faces', action='store_true', help="Flag to enable face blurring (default: disabled)")
  return parser.parse_args(argv)


## Main function to execute the mosaic generation process
def main(args):
  """Run the mosaic pipeline for the parsed CLI *args* (as produced by ``parse_args``).

  The log file and the PNG go to ``args.to_path``. When that is empty or None, a fresh
  ``logs/<ddmmyy_HHMMSS>`` directory is used instead.
  """
  output_path = args.to_path
  if not output_path:
    ## No --to given: fall back to a timestamped directory under logs/
    output_path = f"logs/{datetime.now().strftime('%d%m%y_%H%M%S')}"
  configure_logger(output_path)

  ## Capitalised to match the dataset's directory names (e.g. 'Traditional-setting', 'Day')
  dataset_root = args.from_path
  setting_dir = f"{args.experiment.capitalize()}-setting"
  modality_dir = args.modality.capitalize()
  camera_names = [f"Cam{number}" for number in range(1, 5)]

  logging.info("Starting single mosaic creation for all cameras and categories")
  create_single_mosaic(
    dataset_root,
    modality_dir,
    camera_names,
    setting_dir,
    args.split_name,
    output_path,
    CATEGORIES,
    img_width=50,
    img_height=50,
    blur_faces=args.blur_faces,
  )
  logging.info(f"Single mosaic for all cameras created at {output_path}")


if __name__ == "__main__":
  ## Script entry point: parse CLI options, then build the mosaic
  args = parse_args()
  main(args)
