"""
Create one image mosaic per modality (Day/Night), with one row per category and
one column per camera view (Cam1-Cam4), sized to fit a quarter of a portrait A4 page.
"""
__author__ = 'XYZ'


import pdb
import os
import argparse
import logging
from PIL import Image
import matplotlib.pyplot as plt
import random
from datetime import datetime


## Common argument parser
def parse_args():
  """Define the command-line options of the mosaic creator and parse ``sys.argv``.

  Returns:
    argparse.Namespace with the attributes ``from_path``, ``to_path``,
    ``experiment``, ``split_name`` and ``random_seed``.
  """
  experiment_choices = ['traditional', 'cross-camera', 'cross-modality', 'cross-vehicle']
  split_choices = ['train', 'test', 'val']

  ## (flag, keyword options) pairs, listed in the order they appear in --help
  option_specs = [
    ('--from', dict(type=str, dest='from_path', required=True,
                    help='Base directory containing the data splits.')),
    ('--to', dict(type=str, dest='to_path', default=None,
                  help='Output directory to save the mosaic image.')),
    ('--experiment', dict(type=str, default='traditional', choices=experiment_choices,
                          help="Experimental setting (default: traditional)")),
    ('--split-name', dict(type=str, default='test', choices=split_choices,
                          help="Split name (default: test)")),
    ('--random-seed', dict(type=int, default=42,
                           help="Random seed for selecting images (default: 42)")),
  ]

  cli = argparse.ArgumentParser(
    description='Image mosaic creator',
    formatter_class=argparse.RawTextHelpFormatter,
  )
  for flag, options in option_specs:
    cli.add_argument(flag, **options)

  return cli.parse_args()


## Function to configure logger
def configure_logger(output_path):
  """Ensure ``output_path`` exists and send root-logger output to a timestamped file in it.

  The log file is named ``mosaic_DDMMYY_HHMMSS.log`` using the current local time.

  Args:
    output_path: Directory that receives the log file; created (with parents) if missing.
  """
  ## exist_ok=True replaces the separate os.path.exists() check and its check-then-create race
  os.makedirs(output_path, exist_ok=True)
  log_file = os.path.join(output_path, f'mosaic_{datetime.now().strftime("%d%m%y_%H%M%S")}.log')
  ## NOTE: logging.basicConfig() is a no-op when the root logger already has handlers
  logging.basicConfig(filename=log_file, level=logging.DEBUG, format='%(asctime)s %(message)s')
  logging.info("Logger initialized")


## Function to load image paths and dynamically extract categories from the split text file
def load_image_paths_and_categories(file_path):
  """Parse a split text file into image paths and the sorted list of categories.

  Every non-blank line must hold at least two tab-separated fields; the second
  field (stripped) is collected as the image path. The category is the text
  before the first '/' on the line (e.g., "C1_Drive_Safe").

  Args:
    file_path: Path to the split text file.

  Returns:
    Tuple ``(img_paths, categories)``: the image paths in file order and the
    unique category names in sorted order.

  Raises:
    IndexError: If a non-blank line has no second tab-separated field.
  """
  img_paths = []
  categories = set()  ## Using a set to avoid duplicates

  with open(file_path, 'r') as file:
    for line in file:
      ## Blank lines (such as a trailing newline at the end of the file) hold no entry
      if not line.strip():
        continue
      img_paths.append(line.split('\t')[1].strip())
      categories.add(line.split('/')[0])

  return img_paths, sorted(categories)  ## Return sorted category names for consistency


## Function to get a random image from the specific category and camera
def get_random_image(base_path, camera, category, random_seed):
  """Return the path of a seeded pseudo-random entry of ``base_path/camera/category``.

  Args:
    base_path: Root directory that contains the camera folders.
    camera: Camera folder name.
    category: Category folder name inside the camera folder.
    random_seed: Seed that determines which directory entry is selected.

  Returns:
    Full path of the selected directory entry.

  Raises:
    IndexError: If the category directory contains no entries.
  """
  category_dir = os.path.join(base_path, camera, category)
  ## os.listdir() yields entries in arbitrary order; sorting makes a given seed
  ## select the same file on every run and every machine
  images = sorted(os.listdir(category_dir))
  if not images:
    raise IndexError(f"No images found in {category_dir}")
  ## A private generator keeps the choice seeded without resetting the global random state
  rng = random.Random(random_seed)
  return os.path.join(category_dir, rng.choice(images))


## Function to create and save the mosaic image
def create_mosaic(base_path, modality, experiment_setting, split_name, output_path, random_seed):
  """Build and save one mosaic for ``modality``: one row per category, one column per camera.

  The category list is read from the first camera's split file; each cell shows
  one image picked by ``get_random_image`` for that camera and category. The
  result is written to ``<output_path>/<modality>_mosaic.png`` at 300 dpi.

  Args:
    base_path: Root directory holding ``data-splits`` and the camera folders.
    modality: Modality name; its first letter (upper-cased) prefixes the split file name.
    experiment_setting: Name of the experiment folder below ``data-splits``.
    split_name: Split identifier used in the split file name.
    output_path: Directory that receives the mosaic image; created if missing.
    random_seed: Seed forwarded to ``get_random_image``.

  Raises:
    ValueError: If the split file yields no categories.
  """
  cameras = [f"Cam{i}" for i in range(1, 5)]

  ## The split file is only needed for the category list; images come from the folders on disk
  split_file = os.path.join(base_path, f"data-splits/{experiment_setting}/{modality}/{cameras[0]}/{modality[0].upper()}1_{split_name}.txt")
  _, categories = load_image_paths_and_categories(split_file)
  if not categories:
    raise ValueError(f"No categories found in {split_file}")

  n_rows = len(categories)
  n_cols = len(cameras)

  ## A4 size / 2; squeeze=False keeps axs two-dimensional even with a single category
  fig, axs = plt.subplots(n_rows, n_cols, figsize=(8.27/2, 11.69/2), squeeze=False)
  try:
    for row, category in enumerate(categories):
      for col, camera in enumerate(cameras):
        ax = axs[row, col]

        ## NOTE(review): the lookup below does not involve `modality` -- confirm that
        ## base_path/<camera>/<category> is the intended image location
        img_path = get_random_image(base_path, camera, category, random_seed)
        ## imshow copies the pixel data, so the file can be closed right away
        with Image.open(img_path) as img:
          ax.imshow(img)

        ## Hide ticks and frame rather than calling axis('off'), which would also hide the y-label
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
          spine.set_visible(False)

        ## Label y-axis with category names
        if col == 0:
          ax.set_ylabel(category, rotation=0, labelpad=40, fontsize=8)

        ## Label x-axis with camera names
        if row == 0:
          ax.set_title(f"{camera}", fontsize=8)

    ## Lay out after titles/labels exist so they are taken into account; small pad minimizes whitespace
    fig.tight_layout(pad=0.1)

    ## Save the final mosaic image
    os.makedirs(output_path, exist_ok=True)
    output_file = os.path.join(output_path, f"{modality}_mosaic.png")
    fig.savefig(output_file, bbox_inches='tight', dpi=300)
  finally:
    ## Release the figure even if loading or saving failed
    plt.close(fig)

  ## Log the output path (the mosaic covers all cameras, so only the modality is reported)
  logging.info(f"Saved mosaic for {modality} at {output_file}")
  print(f"Saved mosaic for {modality} at {output_file}")


## Main function to execute the mosaic generation process
def main():
  """Parse the CLI arguments, set up logging and create one mosaic per modality."""
  args = parse_args()

  ## Set default output path if not provided
  output_path = args.to_path or f"logs/{datetime.now().strftime('%d%m%y_%H%M%S')}"
  configure_logger(output_path)

  base_path = args.from_path
  ## str.capitalize() upper-cases only the first character, e.g. 'cross-camera' -> 'Cross-camera-setting'
  experiment_setting = f"{args.experiment.capitalize()}-setting"
  modalities = ["Day", "Night"]

  for modality in modalities:
    create_mosaic(base_path, modality, experiment_setting, args.split_name, output_path, args.random_seed)
    logging.info(f"Mosaic for {modality} created at {output_path}")

  logging.info("All mosaics generated successfully.")


## Run the mosaic generation only when the file is executed as a script, not when it is imported
if __name__ == "__main__":
  main()
